import threading
# from checkPlaceName import checkPlaceName
# from checkRepeatText import checkRepeatText
# from checkCompanyName import checkCompanyName
# from checkDocumentError import checkDocumentError
# from checkTitleName import checkTitleName
# from myLogger import outLog
# import time
# def run_check_company_name(filename, user_id):
#     for i in checkCompanyName(filename, user_id):
#         pass
# def run_get_document_error(filename, user_id):
#     for i in checkDocumentError(filename, user_id):
#         pass
# def runcheckTitleName(filename, user_id):
#     for i in checkTitleName(filename, user_id):
#         pass
# def runcheckRepeatText(filename, user_id):
#     for i in checkRepeatText(filename, user_id):
#         pass
# def runcheckPlaceName(filename, user_id):
#     for i in checkPlaceName(filename, user_id):
#         pass
# def get(user_id):
#     time.sleep(5)
#     while True:
#         if outLog.is_done(user_id):
#             break
#         q = outLog.get_queueData(user_id)
#         if q:
#             text = q.pop(0)
#             print(text)
#     print("finished printing")
# filename = "17.docx"
# # Create the worker threads
# thread1 = threading.Thread(target=run_check_company_name, args=(filename,"1"))
# thread2 = threading.Thread(target=run_get_document_error, args=(filename,"1"))
# thread3 = threading.Thread(target=runcheckTitleName, args=(filename,"1"))
# thread4 = threading.Thread(target=runcheckRepeatText, args=(filename,"1"))
# thread5 = threading.Thread(target=runcheckPlaceName, args=(filename,"1"))
# thread6 = threading.Thread(target=get, args=("1",))
# thread1 = threading.Thread(target=getapp, args=(filename,))
# thread2 = threading.Thread(target=getapp, args=(filename,))
# thread3 = threading.Thread(target=getapp, args=(filename,))
# thread4 = threading.Thread(target=getapp, args=(filename,))
# thread5 = threading.Thread(target=getapp, args=(filename,))
# thread6 = threading.Thread(target=getapp, args=("1",))
# # Start the threads
# thread1.start()
# thread2.start()
# thread3.start()
# thread4.start()
# thread5.start()
# thread6.start()
# # Wait for the threads to finish
# thread1.join()
# thread2.join()
# thread3.join()
# thread4.join()
# thread5.join()
# thread6.join()
# print("Both tasks completed.")
# from pycorrector.macbert.macbert_corrector import MacBertCorrector
# m = MacBertCorrector("models")
# for i in range(10):
#     i = m.correct("行政捡查是行政机关覆行政府职能、管理经济社会事务的重要方式,开展计划统筹是行政检查控总量、提质效的重要措施和手段,直接影响改革或得感和社会满意度")
#     print(i)
# import re
# import json
# import json_repair
# import math
# import os
# import platform
# import torch
# import torch_npu
# import operator
# from torch_npu.contrib import transfer_to_npu
# torch_device = "npu:4" # 0~7
# torch.npu.set_device(torch.device(torch_device))
# torch.npu.set_compile_mode(jit_compile=False)
# from transformers import BertTokenizerFast,BertForMaskedLM
# # option = {}
# # option["NPU_FUZZY_COMPILE_BLACKLIST"] = "Tril"
# # torch.npu.set_option(option)
# print("torch && torch_npu import successfully")
# DEFAULT_CKPT_PATH = 'macbert4csc'
# # models: macbert4csc-base-chinese
# model = BertForMaskedLM.from_pretrained(
#     DEFAULT_CKPT_PATH,
#     torch_dtype=torch.float16,
#     device_map=torch_device
# ).npu().eval()
# tokenizer = BertTokenizerFast.from_pretrained(DEFAULT_CKPT_PATH)
# def get_errors(corrected_text, origin_text):
#     sub_details = []
#     for i, ori_char in enumerate(origin_text):
#         if ori_char in [' ', '“', '”', '‘', '’', '琊', '\n', '…', '—', '擤']:
#             # add unk word
#             corrected_text = corrected_text[:i] + ori_char + corrected_text[i:]
#             continue
#         if i >= len(corrected_text):
#             continue
#         if ori_char != corrected_text[i]:
#             if ori_char.lower() == corrected_text[i]:
#                 # pass english upper char
#                 corrected_text = corrected_text[:i] + ori_char + corrected_text[i + 1:]
#                 continue
#             sub_details.append((ori_char, corrected_text[i], i, i + 1))
#     sub_details = sorted(sub_details, key=operator.itemgetter(2))
#     return corrected_text, sub_details
# result = []
# def getapp(gettext):
#     result = []
#     batchNum = 20
#     sentences = re.split(r'[。\n]', gettext)
#     # Drop empty strings
#     sentences = [sentence.strip() for sentence in sentences if sentence.strip()]
#     # Count the sentences
#     total_chars = len(sentences)
#     # Work out how many batches are needed
#     num_chunks = math.ceil(total_chars / batchNum)
#     # Split into chunks of batchNum sentences each
#     chunks = [sentences[i:i + batchNum] for i in range(0, total_chars, batchNum)]
#     # Run the corrector on each chunk and print the result
#     err = []
#     for i, chunk in enumerate(chunks):
#         inputs = tokenizer(chunk, padding=True, return_tensors='pt').to(torch_device)
#         with torch.no_grad():
#             outputs = model(**inputs)
#         for id, (logit_tensor, sentence) in enumerate(zip(outputs.logits, chunk)):
#             decode_tokens_new = tokenizer.decode(
#                 torch.argmax(logit_tensor, dim=-1), skip_special_tokens=True).split(' ')
#             decode_tokens_new = decode_tokens_new[:len(sentence)]
#             if len(decode_tokens_new) == len(sentence):
#                 probs = torch.max(torch.softmax(logit_tensor, dim=-1), dim=-1)[0].cpu().numpy()
#                 decode_str = ''
#                 for i in range(len(sentence)):
#                     if probs[i + 1] >= 0.7:
#                         decode_str += decode_tokens_new[i]
#                     else:
#                         decode_str += sentence[i]
#                 corrected_text = decode_str
#             else:
#                 corrected_text = sentence
#             print(corrected_text)
#         # Alternative decoding paths kept for reference:
#         # outputs = model(**tokenizer(chunk, padding=True, return_tensors='pt').to(torch_device))
#         # for ids, text in zip(outputs.logits, chunk):
#         #     _text = tokenizer.decode(torch.argmax(ids, dim=-1), skip_special_tokens=True).replace(' ', '')
#         #     corrected_text = _text[:len(text)]
#         #     corrected_text, details = get_errors(corrected_text, text)
#         #     print(text, ' => ', corrected_text, details)
#         #     result.append((corrected_text, details))
#         # for i, sent in enumerate(chunk):
#         #     decode_tokens = tokenizer.decode(outputs[i], skip_special_tokens=True).replace(' ', '')
#         #     corrected_sent = decode_tokens[:len(sent)]
#         #     print(corrected_sent)
#         #     corrected_sents.append(corrected_sent)
# from flask import Flask, request, jsonify
# import threading
# import time
# import re
# import math
# from macbert_corrector import MacBertCorrector
# m = MacBertCorrector("macbert4csc")
# app = Flask(__name__)
# # Create a lock object
# lock = threading.Lock()
# # Multithreaded, but only one request is handled at a time; extra requests wait in a queue
# @app.route('/taskflow/checkDocumentError', methods=['POST'])
# def process_request():
#     with lock:
#         data = request.get_json()
#         # print("data", data)
#         # Extract the text data
#         text_data = data.get('data', {}).get('text', [])
#         # print(text_data)
#         # Process the text data, e.g. check for errors
#         # Here you could add the actual document-error-checking logic
#         res = m.correct_batch(text_data)
#         # Example: simply print the received text
#         # Return the response
#         return jsonify({"status": "success", "data": res}), 200
# if __name__ == '__main__':
#     app.run(threaded=True, port=5001)
from fastapi import FastAPI, Request
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
import uvicorn
from fastapi.responses import JSONResponse
from pydantic import BaseModel
app = FastAPI()
from macbert_corrector import MacBertCorrector
m = MacBertCorrector("macbert4csc")
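# NOTE: "macbert4csc" above refers to the local checkpoint directory of the MacBERT-based
# Chinese spelling-correction model (macbert4csc-base-chinese, per the commented notes earlier
# in this file); adjust the path if the weights live elsewhere.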
class RequestData(BaseModel):
    data: dict


@app.post("/taskflow/checkDocumentError")
async def process_request(request: RequestData):
    global m
    # Extract the text data
    text_data = request.data.get('text')
    # print(text_data)
    # Process the text data, e.g. check for errors
    # Here you could add the actual document-error-checking logic
    res = m.correct_batch(text_data)
    # Return the response
    return JSONResponse(content={"status": "success", "data": res}, status_code=200)


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=5001)
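# A minimal client-side sketch (not part of the service) showing how the endpoint above can be
# called. The payload shape {"data": {"text": [...]}} follows the RequestData model; the host and
# port are taken from the uvicorn.run call above, and the `requests` package is assumed to be
# available.
# import requests
# payload = {"data": {"text": ["行政捡查是行政机关覆行政府职能、管理经济社会事务的重要方式"]}}
# resp = requests.post("http://127.0.0.1:5001/taskflow/checkDocumentError", json=payload)
# print(resp.json())  # expected: {"status": "success", "data": <corrected sentences>}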