import threading

# from checkPlaceName import checkPlaceName
# from checkRepeatText import checkRepeatText
# from checkCompanyName import checkCompanyName
# from checkDocumentError import checkDocumentError
# from checkTitleName import checkTitleName
# from myLogger import outLog
# import time
# def run_check_company_name(filename, user_id):
#     for i in checkCompanyName(filename, user_id):
#         pass

# def run_get_document_error(filename, user_id):
#     for i in checkDocumentError(filename, user_id):
#         pass

# def runcheckTitleName(filename, user_id):
#     for i in checkTitleName(filename, user_id):
#         pass

# def runcheckRepeatText(filename, user_id):
#     for i in checkRepeatText(filename, user_id):
#         pass

# def runcheckPlaceName(filename, user_id):
#     for i in checkPlaceName(filename, user_id):
#         pass
# def get(user_id):
#     # Poll outLog until the checks for this user are done, draining the queue.
#     time.sleep(5)
#     while True:
#         if outLog.is_done(user_id):
#             break
#         q = outLog.get_queueData(user_id)
#         if q:
#             text = q.pop(0)
#             print(text)
#     print("打印结束")  # "printing finished"

# filename = "17.docx"
# # Create one worker thread per check, plus a reader thread for the log queue.
# thread1 = threading.Thread(target=run_check_company_name, args=(filename, "1"))
# thread2 = threading.Thread(target=run_get_document_error, args=(filename, "1"))
# thread3 = threading.Thread(target=runcheckTitleName, args=(filename, "1"))
# thread4 = threading.Thread(target=runcheckRepeatText, args=(filename, "1"))
# thread5 = threading.Thread(target=runcheckPlaceName, args=(filename, "1"))
# thread6 = threading.Thread(target=get, args=("1",))

# # Alternative wiring that routes every thread through getapp; note that it
# # rebinds thread1..thread6 and silently overrides the assignments above,
# # so only one of the two blocks should be active at a time.
# thread1 = threading.Thread(target=getapp, args=(filename,))
# thread2 = threading.Thread(target=getapp, args=(filename,))
# thread3 = threading.Thread(target=getapp, args=(filename,))
# thread4 = threading.Thread(target=getapp, args=(filename,))
# thread5 = threading.Thread(target=getapp, args=(filename,))
# thread6 = threading.Thread(target=getapp, args=("1",))

# # Start the threads
# thread1.start()
# thread2.start()
# thread3.start()
# thread4.start()
# thread5.start()
# thread6.start()

# # Wait for the threads to finish
# thread1.join()
# thread2.join()
# thread3.join()
# thread4.join()
# thread5.join()
# thread6.join()

# print("All tasks completed.")
# from pycorrector.macbert.macbert_corrector import MacBertCorrector

# m = MacBertCorrector("models")
# for _ in range(10):
#     # The input sentence is intentionally misspelled ("捡查", "覆行", "或得感")
#     # to exercise the corrector.
#     res = m.correct("行政捡查是行政机关覆行政府职能、管理经济社会事务的重要方式,开展计划统筹是行政检查控总量、提质效的重要措施和手段,直接影响改革或得感和社会满意度")
#     print(res)
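
# Note on the return shape (hedged; it depends on the installed pycorrector
# version): older releases return a (corrected_text, details) tuple, while
# newer ones return a dict such as
# {'source': ..., 'target': ..., 'errors': [('捡', '检', 2), ...]}.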
# import re
# import json
# import json_repair
# import math
# import os
# import platform
# import operator
# import torch
# import torch_npu
# from torch_npu.contrib import transfer_to_npu

# torch_device = "npu:4"  # valid NPU device ids are 0-7
# torch.npu.set_device(torch.device(torch_device))
# torch.npu.set_compile_mode(jit_compile=False)

# from transformers import BertTokenizerFast, BertForMaskedLM

# # option = {}
# # option["NPU_FUZZY_COMPILE_BLACKLIST"] = "Tril"
# # torch.npu.set_option(option)

# print("torch && torch_npu imported successfully")

# DEFAULT_CKPT_PATH = 'macbert4csc'  # checkpoint dir holding macbert4csc-base-chinese
# model = BertForMaskedLM.from_pretrained(
#     DEFAULT_CKPT_PATH,
#     torch_dtype=torch.float16,
#     device_map=torch_device
# ).npu().eval()
# tokenizer = BertTokenizerFast.from_pretrained(DEFAULT_CKPT_PATH)

# def get_errors(corrected_text, origin_text):
#     sub_details = []
#     for i, ori_char in enumerate(origin_text):
#         if ori_char in [' ', '“', '”', '‘', '’', '琊', '\n', '…', '—', '擤']:
#             # Re-insert characters the tokenizer treats as unknown or drops,
#             # so the corrected text stays aligned with the original.
#             corrected_text = corrected_text[:i] + ori_char + corrected_text[i:]
#             continue
#         if i >= len(corrected_text):
#             continue
#         if ori_char != corrected_text[i]:
#             if ori_char.lower() == corrected_text[i]:
#                 # Keep the original casing for English letters.
#                 corrected_text = corrected_text[:i] + ori_char + corrected_text[i + 1:]
#                 continue
#             sub_details.append((ori_char, corrected_text[i], i, i + 1))
#     sub_details = sorted(sub_details, key=operator.itemgetter(2))
#     return corrected_text, sub_details
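
# A quick sketch of what get_errors yields (hypothetical strings, not from a
# real model run): comparing origin "少先队员因该为老人让座" against corrected
# "少先队员应该为老人让座" returns the corrected text plus [('因', '应', 4, 5)].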

# result = []

# def getapp(gettext):
#     result = []
#     batchNum = 20
#     sentences = re.split(r'[。\n]', gettext)
#     # Drop empty strings left over from the split.
#     sentences = [sentence.strip() for sentence in sentences if sentence.strip()]
#     # Total number of sentences (not characters).
#     total_sentences = len(sentences)
#     # Number of batches of batchNum sentences each.
#     num_chunks = math.ceil(total_sentences / batchNum)
#     # Slice the sentence list into batches of batchNum.
#     chunks = [sentences[i:i + batchNum] for i in range(0, total_sentences, batchNum)]
#     err = []
#     for chunk_idx, chunk in enumerate(chunks):
#         inputs = tokenizer(chunk, padding=True, return_tensors='pt').to(torch_device)
#         with torch.no_grad():
#             outputs = model(**inputs)
#         for logit_tensor, sentence in zip(outputs.logits, chunk):
#             decode_tokens_new = tokenizer.decode(
#                 torch.argmax(logit_tensor, dim=-1), skip_special_tokens=True).split(' ')
#             decode_tokens_new = decode_tokens_new[:len(sentence)]
#             if len(decode_tokens_new) == len(sentence):
#                 probs = torch.max(torch.softmax(logit_tensor, dim=-1), dim=-1)[0].cpu().numpy()
#                 decode_str = ''
#                 for pos in range(len(sentence)):
#                     # probs[0] belongs to [CLS], so character pos maps to probs[pos + 1];
#                     # keep the model's token only when it is confident enough.
#                     if probs[pos + 1] >= 0.7:
#                         decode_str += decode_tokens_new[pos]
#                     else:
#                         decode_str += sentence[pos]
#                 corrected_text = decode_str
#             else:
#                 corrected_text = sentence
#             print(corrected_text)
#         # Alternative decoding paths, kept for reference:
#         # outputs = model(**tokenizer(chunk, padding=True, return_tensors='pt').to(torch_device))
#         # for ids, text in zip(outputs.logits, chunk):
#         #     _text = tokenizer.decode(torch.argmax(ids, dim=-1), skip_special_tokens=True).replace(' ', '')
#         #     corrected_text = _text[:len(text)]
#         #     corrected_text, details = get_errors(corrected_text, text)
#         #     print(text, ' => ', corrected_text, details)
#         #     result.append((corrected_text, details))
#         # for i, sent in enumerate(chunk):
#         #     decode_tokens = tokenizer.decode(outputs[i], skip_special_tokens=True).replace(' ', '')
#         #     corrected_sent = decode_tokens[:len(sent)]
#         #     print(corrected_sent)
#         #     corrected_sents.append(corrected_sent)
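
# Example call (a sketch; assumes the NPU model and tokenizer above are loaded):
# getapp("行政检查是行政机关履行政府职能的重要方式。开展计划统筹可以控总量、提质效。")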

# from flask import Flask, request, jsonify
# import threading
# import time
# import re
# import math
# from macbert_corrector import MacBertCorrector

# m = MacBertCorrector("macbert4csc")
# app = Flask(__name__)

# # A lock so the app handles one request at a time; extra requests queue up.
# lock = threading.Lock()

# @app.route('/taskflow/checkDocumentError', methods=['POST'])
# def process_request():
#     with lock:
#         data = request.get_json()
#         # Pull the text list out of the request body.
#         text_data = data.get('data', {}).get('text', [])
#         # Run the document-error check on the whole batch.
#         res = m.correct_batch(text_data)
#         # Return the response.
#         return jsonify({"status": "success", "data": res}), 200

# if __name__ == '__main__':
#     app.run(threaded=True, port=5001)

from fastapi import FastAPI
from fastapi.responses import JSONResponse
from pydantic import BaseModel
import uvicorn

from macbert_corrector import MacBertCorrector

app = FastAPI()
m = MacBertCorrector("macbert4csc")


class RequestData(BaseModel):
    data: dict


@app.post("/taskflow/checkDocumentError")
async def process_request(request: RequestData):
    # Pull the text list out of the request body.
    text_data = request.data.get('text')
    # Run the document-error check on the whole batch.
    res = m.correct_batch(text_data)
    # Return the response.
    return JSONResponse(content={"status": "success", "data": res}, status_code=200)
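
# Example request (a sketch; the payload shape mirrors RequestData above):
#   curl -X POST http://localhost:5001/taskflow/checkDocumentError \
#     -H "Content-Type: application/json" \
#     -d '{"data": {"text": ["少先队员因该为老人让座"]}}'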

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=5001)