# Earlier Flask + PaddleNLP Taskflow implementation (commented out):
#
# from paddlenlp import Taskflow
# from flask import Flask, request, jsonify
# import threading
#
# similarity1 = Taskflow("text_similarity", device_id=3, precision='fp16')  # checkRepeatText
#
# app = Flask(__name__)
#
# # Create a lock object
# lock = threading.Lock()
#
# @app.route('/taskflow/checkRepeatText', methods=['POST'])
# def process_request():
#     with lock:
#         data = request.get_json()
#         # print("data", data)
#         # Extract the text data
#         text_data = data.get('data', {}).get('text')
#         # Process the text data, e.g. check for errors
#         # The actual document-checking logic can be added here
#         res = similarity1(text_data)
#         # Example: simply print the received text
#         # Return the response
#         return jsonify({"status": "success", "data": res}), 200
#
# if __name__ == '__main__':
#     app.run(threaded=True, port=8192)
from sentence_transformers import SentenceTransformer, util
from fastapi import FastAPI
from fastapi.responses import JSONResponse
from pydantic import BaseModel
import torch
import uvicorn

app = FastAPI()

# Chinese sentence-embedding model, loaded once at startup.
model = SentenceTransformer("shibing624/text2vec-base-chinese", device="npu:5")


class RequestData(BaseModel):
    data: dict


@app.post("/taskflow/checkRepeatText")
async def check_repeat_text(request: RequestData):
    # Extract the text data: the first entry is expected to hold the two texts to compare.
    text_data = request.data.get('text')
    a = text_data[0][0]
    b = text_data[0][1]
    # Embed both texts and compute their cosine similarity.
    emb_a = model.encode(a)
    emb_b = model.encode(b)
    cos_sim = util.cos_sim(emb_a, emb_b)
    results = [{"text1": a, "text2": b, "similarity": cos_sim.item()}]
    # Return the response
    return JSONResponse(content={"status": "success", "data": results}, status_code=200)
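
# Illustrative request/response for /taskflow/checkRepeatText. This is a sketch inferred
# from how the handler above indexes request.data; the exact payload shape used by real
# callers is an assumption:
#   POST /taskflow/checkRepeatText
#   {"data": {"text": [["text A", "text B"]]}}
# returns
#   {"status": "success",
#    "data": [{"text1": "text A", "text2": "text B", "similarity": <cosine score as float>}]}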


@app.post("/taskflow/getRepeatText")
async def get_repeat_text(request: RequestData):
    # Extract the text data
    text_data = request.data.get('text')
    allcorpus = text_data[0]  # the full set of documents
    query = text_data[1]      # the document to query against the corpus
    # Embed the corpus and the query, then rank corpus entries by cosine similarity.
    corpus_embeddings = model.encode(allcorpus, convert_to_tensor=True)
    top_k = min(4, len(allcorpus))
    query_embedding = model.encode(query, convert_to_tensor=True)
    cos_scores = util.cos_sim(query_embedding, corpus_embeddings)[0]
    top_results = torch.topk(cos_scores, k=top_k)
    results = []
    for score, idx in zip(top_results[0], top_results[1]):
        print(allcorpus[idx], "(Score: {:.4f})".format(score.item()))
        results.append({"text1": allcorpus[idx], "text2": query, "similarity": score.item()})
    # Return the response
    return JSONResponse(content={"status": "success", "data": results}, status_code=200)
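
# Illustrative request/response for /taskflow/getRepeatText. Again a sketch inferred from
# the indexing above; the exact client payload shape is an assumption:
#   POST /taskflow/getRepeatText
#   {"data": {"text": [["doc 1", "doc 2", "doc 3"], "query text"]}}
# returns at most the top-4 corpus documents ranked by cosine similarity to the query:
#   {"status": "success",
#    "data": [{"text1": "doc 2", "text2": "query text", "similarity": <cosine score as float>},
#             ...]}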


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8192)
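

# Minimal client sketch (illustrative only; assumes the service above is running locally
# on port 8192 and that the `requests` package is installed):
#
# import requests
#
# payload = {"data": {"text": [["text A", "text B"]]}}
# resp = requests.post("http://127.0.0.1:8192/taskflow/checkRepeatText", json=payload)
# print(resp.json())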