# Earlier experiment: run each document checker in its own thread and print
# queued log output. Kept commented out for reference.
# import threading
# from checkPlaceName import checkPlaceName
# from checkRepeatText import checkRepeatText
# from checkCompanyName import checkCompanyName
# from checkDocumentError import checkDocumentError
# from checkTitleName import checkTitleName
# from myLogger import outLog
# import time
#
# def run_check_company_name(filename, user_id):
#     for i in checkCompanyName(filename, user_id):
#         pass
#
# def run_get_document_error(filename, user_id):
#     for i in checkDocumentError(filename, user_id):
#         pass
#
# def run_check_title_name(filename, user_id):
#     for i in checkTitleName(filename, user_id):
#         pass
#
# def run_check_repeat_text(filename, user_id):
#     for i in checkRepeatText(filename, user_id):
#         pass
#
# def run_check_place_name(filename, user_id):
#     for i in checkPlaceName(filename, user_id):
#         pass
#
# def get(user_id):
#     # Drain the shared log queue and print entries until the job is done.
#     time.sleep(5)
#     while True:
#         if outLog.is_done(user_id):
#             break
#         q = outLog.get_queueData(user_id)
#         if q:
#             text = q.pop(0)
#             print(text)
#     print("printing finished")
#
# filename = "17.docx"
#
# # Create one thread per checker, plus one printer thread.
# thread1 = threading.Thread(target=run_check_company_name, args=(filename, "1"))
# thread2 = threading.Thread(target=run_get_document_error, args=(filename, "1"))
# thread3 = threading.Thread(target=run_check_title_name, args=(filename, "1"))
# thread4 = threading.Thread(target=run_check_repeat_text, args=(filename, "1"))
# thread5 = threading.Thread(target=run_check_place_name, args=(filename, "1"))
# thread6 = threading.Thread(target=get, args=("1",))
#
# # Alternate wiring from an earlier run (every thread calling getapp):
# # thread1 = threading.Thread(target=getapp, args=(filename,))
# # thread2 = threading.Thread(target=getapp, args=(filename,))
# # thread3 = threading.Thread(target=getapp, args=(filename,))
# # thread4 = threading.Thread(target=getapp, args=(filename,))
# # thread5 = threading.Thread(target=getapp, args=(filename,))
# # thread6 = threading.Thread(target=getapp, args=("1",))
#
# # Start the threads.
# thread1.start()
# thread2.start()
# thread3.start()
# thread4.start()
# thread5.start()
# thread6.start()
#
# # Wait for the threads to finish.
# thread1.join()
# thread2.join()
# thread3.join()
# thread4.join()
# thread5.join()
# thread6.join()
# print("All tasks completed.")

# Earlier experiment: quick pycorrector MacBert smoke test. The input sentence
# deliberately contains typos (e.g. 捡查 for 检查).
# from pycorrector.macbert.macbert_corrector import MacBertCorrector
#
# m = MacBertCorrector("models")
# for i in range(10):
#     res = m.correct("行政捡查是行政机关覆行政府职能、管理经济社会事务的重要方式,开展计划统筹是行政检查控总量、提质效的重要措施和手段,直接影响改革或得感和社会满意度")
#     print(res)

# Earlier experiment: run macbert4csc directly on an Ascend NPU with
# transformers, decoding corrections from the masked-LM logits.
# import re
# import json
# import json_repair
# import math
# import os
# import platform
# import torch
# import torch_npu
# import operator
# from torch_npu.contrib import transfer_to_npu
#
# torch_device = "npu:4"  # valid device ids: 0-7
# torch.npu.set_device(torch.device(torch_device))
# torch.npu.set_compile_mode(jit_compile=False)
#
# from transformers import BertTokenizerFast, BertForMaskedLM
#
# # option = {}
# # option["NPU_FUZZY_COMPILE_BLACKLIST"] = "Tril"
# # torch.npu.set_option(option)
# print("torch && torch_npu imported successfully")
#
# DEFAULT_CKPT_PATH = 'macbert4csc'  # local copy of macbert4csc-base-chinese
# model = BertForMaskedLM.from_pretrained(
#     DEFAULT_CKPT_PATH,
#     torch_dtype=torch.float16,
#     device_map=torch_device,
# ).npu().eval()
# tokenizer = BertTokenizerFast.from_pretrained(DEFAULT_CKPT_PATH)
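# A minimal, commented-out sketch (not from the original code) of the
# argmax-decode idea that get_errors/getapp below build on: one masked-LM
# forward pass, then take the highest-scoring token at each character
# position. Assumes the same local `macbert4csc` checkpoint with model and
# inputs on the same device; the sentence is illustrative only.
#
# sentence = "行政捡查是行政机关的重要方式"
# enc = tokenizer([sentence], padding=True, return_tensors='pt')
# with torch.no_grad():
#     logits = model(**enc).logits[0]          # (seq_len, vocab)
# pred_ids = torch.argmax(logits, dim=-1)      # best token id per position
# # Positions 1..len(sentence) hold the characters; [CLS] sits at index 0.
# print(tokenizer.decode(pred_ids[1:len(sentence) + 1]).replace(' ', ''))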
# def get_errors(corrected_text, origin_text):
#     # Align the corrected string with the original and collect
#     # (original_char, corrected_char, start, end) substitution details.
#     sub_details = []
#     for i, ori_char in enumerate(origin_text):
#         if ori_char in [' ', '“', '”', '‘', '’', '琊', '\n', '…', '—', '擤']:
#             # Re-insert characters the tokenizer maps to [UNK].
#             corrected_text = corrected_text[:i] + ori_char + corrected_text[i:]
#             continue
#         if i >= len(corrected_text):
#             continue
#         if ori_char != corrected_text[i]:
#             if ori_char.lower() == corrected_text[i]:
#                 # Skip case-only differences in English characters.
#                 corrected_text = corrected_text[:i] + ori_char + corrected_text[i + 1:]
#                 continue
#             sub_details.append((ori_char, corrected_text[i], i, i + 1))
#     sub_details = sorted(sub_details, key=operator.itemgetter(2))
#     return corrected_text, sub_details
#
# result = []
#
# def getapp(gettext):
#     result = []
#     batchNum = 20
#     sentences = re.split(r'[。\n]', gettext)
#     # Drop empty strings.
#     sentences = [sentence.strip() for sentence in sentences if sentence.strip()]
#     # Number of sentences (not characters, despite the old comment).
#     total_sentences = len(sentences)
#     # Number of batches of batchNum sentences each.
#     num_chunks = math.ceil(total_sentences / batchNum)
#     chunks = [sentences[i:i + batchNum] for i in range(0, total_sentences, batchNum)]
#     err = []
#     for chunk_id, chunk in enumerate(chunks):
#         inputs = tokenizer(chunk, padding=True, return_tensors='pt').to(torch_device)
#         with torch.no_grad():
#             outputs = model(**inputs)
#         for sent_id, (logit_tensor, sentence) in enumerate(zip(outputs.logits, chunk)):
#             decode_tokens_new = tokenizer.decode(
#                 torch.argmax(logit_tensor, dim=-1), skip_special_tokens=True).split(' ')
#             decode_tokens_new = decode_tokens_new[:len(sentence)]
#             if len(decode_tokens_new) == len(sentence):
#                 probs = torch.max(torch.softmax(logit_tensor, dim=-1), dim=-1)[0].cpu().numpy()
#                 decode_str = ''
#                 for j in range(len(sentence)):
#                     # probs is offset by one because position 0 is [CLS];
#                     # keep the prediction only when the model is confident.
#                     if probs[j + 1] >= 0.7:
#                         decode_str += decode_tokens_new[j]
#                     else:
#                         decode_str += sentence[j]
#                 corrected_text = decode_str
#             else:
#                 corrected_text = sentence
#             print(corrected_text)
#
#         # Alternate decoding paths tried earlier:
#         # outputs = model(**tokenizer(chunk, padding=True, return_tensors='pt').to(torch_device))
#         # for ids, text in zip(outputs.logits, chunk):
#         #     _text = tokenizer.decode(torch.argmax(ids, dim=-1), skip_special_tokens=True).replace(' ', '')
#         #     corrected_text = _text[:len(text)]
#         #     corrected_text, details = get_errors(corrected_text, text)
#         #     print(text, ' => ', corrected_text, details)
#         #     result.append((corrected_text, details))
#         # for i, sent in enumerate(chunk):
#         #     decode_tokens = tokenizer.decode(outputs[i], skip_special_tokens=True).replace(' ', '')
#         #     corrected_sent = decode_tokens[:len(sent)]
#         #     print(corrected_sent)
#         #     corrected_sents.append(corrected_sent)

# Earlier version of the service: Flask behind a lock, so the model handles
# one request at a time and extra requests queue up.
# from flask import Flask, request, jsonify
# import threading
# import time
# import re
# import math
# from macbert_corrector import MacBertCorrector
#
# m = MacBertCorrector("macbert4csc")
# app = Flask(__name__)
# lock = threading.Lock()
#
# @app.route('/taskflow/checkDocumentError', methods=['POST'])
# def process_request():
#     with lock:
#         data = request.get_json()
#         # Extract the text payload.
#         text_data = data.get('data', {}).get('text', [])
#         # Run the batch corrector; plug the real document-error logic in here.
#         res = m.correct_batch(text_data)
#         # Return the response.
#         return jsonify({"status": "success", "data": res}), 200
#
# if __name__ == '__main__':
#     app.run(threaded=True, port=5001)

# Current version of the service: FastAPI + uvicorn.
from fastapi import FastAPI
from fastapi.responses import JSONResponse
from pydantic import BaseModel
import uvicorn

from macbert_corrector import MacBertCorrector

app = FastAPI()
m = MacBertCorrector("macbert4csc")


class RequestData(BaseModel):
    data: dict


@app.post("/taskflow/checkDocumentError")
async def process_request(request: RequestData):
    # Extract the text payload; requests look like {"data": {"text": [...]}}.
    text_data = request.data.get('text') or []
    # Run the batch corrector; plug the real document-error logic in here.
    res = m.correct_batch(text_data)
    # Return the corrections.
    return JSONResponse(content={"status": "success", "data": res}, status_code=200)
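# Note: m.correct_batch is a blocking call, so inside this `async def` endpoint
# it stalls the event loop while the model runs. A commented-out sketch of one
# way to avoid that, using Starlette's run_in_threadpool (bundled with
# FastAPI); the alternate route name is hypothetical, not part of the service.
#
# from starlette.concurrency import run_in_threadpool
#
# @app.post("/taskflow/checkDocumentErrorAsync")
# async def process_request_async(request: RequestData):
#     text_data = request.data.get('text') or []
#     # Hand the blocking model call to the default threadpool.
#     res = await run_in_threadpool(m.correct_batch, text_data)
#     return JSONResponse(content={"status": "success", "data": res}, status_code=200)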
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=5001)
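# Hypothetical client call (not part of the service) showing the expected
# payload shape; host and port match the uvicorn settings above.
#
# import requests
#
# resp = requests.post(
#     "http://127.0.0.1:5001/taskflow/checkDocumentError",
#     json={"data": {"text": ["行政捡查是行政机关覆行政府职能的重要方式"]}},
# )
# print(resp.json())  # {"status": "success", "data": [...]}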