You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
226 lines
8.7 KiB
226 lines
8.7 KiB
5 months ago
|
#-*- coding:utf-8 -*-
|
||
|
# from pycorrector import MacBertCorrector
|
||
|
# m = MacBertCorrector("shibing624/macbert4csc-base-chinese")
|
||
|
from qwen_agent.agents import Assistant
|
||
|
from docx import Document
|
||
|
from pprint import pprint
|
||
|
import re
|
||
|
from paddlenlp import Taskflow
|
||
|
import json
|
||
|
import time
|
||
|
import json_repair
|
||
|
# NOTE(review): leftover smoke test of json_repair; it prints to stdout at import time.
print(json_repair.loads('{"name":""aaaa"}'))

# Wall-clock time captured once, when this module is imported.
start_time = time.time()

# PaddleNLP text-correction pipeline; process_batch uses it to find candidate typos.
corrector = Taskflow("text_correction")

# LLM backend configuration for the qwen_agent Assistant below.
llm_cfg = {
    # 'model': 'qwen1.5-72b-chat',
    'model': "qwen2-72b",
    'model_server': 'http://127.0.0.1:1025/v1',  # base_url, also known as api_base
    # 'api_key': supply via configuration/environment -- never commit real keys here.
    # NOTE(review): a real-looking key was previously committed in this comment; revoke it.
}

# LLM agent that is asked to confirm or reject each candidate typo.
bot = Assistant(llm=llm_cfg,
                name='Assistant',
                # description='使用RAG检索并回答,支持文件类型:PDF/Word/PPT/TXT/HTML。'
                )

# Earlier prompt variant, kept for reference (unused):
# prompt='''
# 是否存在错别字,若存在请指出,不做其他方面的校验,你只能在[存在,不存在,未知]选项中选择答案,
# 回答格式[{“placeName”:“原文”,"改正后":"改正的内容","回答":"答案"},{“placeName”:“原文”,"改正后":"改正的内容","回答":"答案"}],不做过多的解释,严格按回答格式作答;
# '''

# Instruction appended after the numbered questions. It asks for a yes/no answer
# per question, in a JSON list whose keys ("placeName", "回答", "解析") are the
# ones process_batch reads back. The previous template contained the malformed
# pair `"jianyi","解析"` (no colon), so the "解析" key the parser expects was
# never actually requested.
prompt='''
请回答以上问题,[是,否]选项中选择答案,原文内容,标点符号保持不变,如果有错请给出解析,没有错则不用给解析
回答格式请按照以下json格式[{"placeName":"序号","回答":"答案","解析":"解析"},{"placeName":"序号","回答":"答案","解析":"解析"}],不做过多的解释,严格按回答格式作答;
'''


def getDocxToTextAll(name, output_path="checkDocumentError.txt"):
    """Dump the non-blank paragraphs of a .docx file into a UTF-8 text file.

    Paragraphs whose text is empty or whitespace-only are skipped; the rest
    are kept verbatim (not stripped) and joined with newlines. The output
    file is overwritten.

    :param name: path of the .docx document to read
    :param output_path: text file to write; defaults to the original fixed
        name "checkDocumentError.txt" (relative to the working directory)
    :return: the text that was written
    """
    document = Document(name)
    # Read paragraph text once per paragraph and keep only those with visible content.
    texts = (paragraph.text for paragraph in document.paragraphs)
    words = [text for text in texts if text.strip()]
    print("checkDocumentError", len(words))
    text = '\n'.join(words)
    with open(output_path, 'w', encoding='utf-8') as txt_file:
        txt_file.write(text)
    return text


def getDocumentError(filename):
    """Run the typo check over a .docx file and return the confirmed errors.

    The document text is dumped to "checkDocumentError.txt" by
    getDocxToTextAll, read back in line-aligned batches by
    read_file_in_batches, and each batch is checked by process_batch.
    The combined result and this call's elapsed time are printed.

    :param filename: path of the .docx document to check
    :return: list of error dicts accumulated from every batch
    """
    # Time this call itself. The module-level start_time is set once at import,
    # so using it made every call report time since import (model loading
    # included) rather than its own duration.
    call_start = time.time()
    getDocxToTextAll(filename)
    error_places = []
    for batch in read_file_in_batches('checkDocumentError.txt'):
        # Extending with an empty list is a no-op, so no length guard is needed.
        error_places.extend(process_batch(batch))
    pprint(error_places)
    elapsed_time = time.time() - call_start
    print(f"checkDocumentError程序执行时间: {elapsed_time} 秒")
    return error_places


#
|
||
|
# 过滤掉填充的None(如果有的话)
|
||
|
# chunk = [line for line in chunk if line is not None]
|
||
|
# res = m.correct_batch(sentences)
|
||
|
# print("DocumentError",res)
|
||
|
# lines_with_greeting = [place for place in res if len( place['errors'])>0]
|
||
|
# error_places.extend(lines_with_greeting)
|
||
|
# pprint(error_places)
|
||
|
# if len(lines_with_greeting)>0:
|
||
|
# for t in error_places:
|
||
|
# keyword= t['source']
|
||
|
#
|
||
|
# errorWord=t["errors"]
|
||
|
# # 查找包含关键字的段落
|
||
|
# paragraphs = re.findall(r'.*?' + re.escape(keyword) + r'.*?\n', gettext)
|
||
|
# t["yuanwen"]=paragraphs[0]
|
||
|
# return error_places
|
||
|
# else:
|
||
|
# return error_places
|
||
|
# return lines_with_greeting
|
||
|
def read_file_in_batches(file_path, batch_size=5000):
    """Yield the contents of a UTF-8 text file in line-aligned chunks.

    Whole lines are accumulated until the chunk holds at least *batch_size*
    characters, and then the chunk is emitted. Lines are never split, so a
    chunk may be longer than *batch_size*. Any leftover lines form a final,
    shorter chunk.

    :param file_path: path of the text file to read
    :param batch_size: character count that triggers emitting a chunk
    :return: generator of strings, each the concatenation of whole lines
    """
    with open(file_path, 'r', encoding='utf-8') as handle:
        pending = []
        size = 0
        for row in handle:
            pending.append(row)
            size += len(row)
            if size < batch_size:
                continue
            yield ''.join(pending)
            pending, size = [], 0
        # Flush whatever did not reach the threshold.
        if pending:
            yield ''.join(pending)


def process_batch(batch):
    """Check one batch of text for typos and return the ones the LLM confirms.

    Steps:
      1. Split the batch into sentences on '。' and newlines; drop blank ones.
      2. Run them through ``corrector`` and keep sentences with at least one
         flagged error.
      3. Turn each flagged sentence into a numbered yes/no question.
      4. Send all questions plus ``prompt`` to ``bot`` in one request.
      5. Keep the answers marked '是' ("yes").

    :param batch: a chunk of document text
    :return: list of answer dicts parsed from the LLM reply. In each dict,
        "placeName" is replaced by the original sentence and "jianyi" holds
        the LLM's analysis. Returns an empty list when nothing is flagged,
        the LLM returns nothing, or no answer is confirmed.
    """
    sentences = [s.strip() for s in re.split(r'[。\n]', batch) if s.strip()]
    if not sentences:
        # Blank batch: nothing to check, so skip the corrector and the LLM.
        return []
    flagged = [place for place in corrector(sentences) if len(place['errors']) > 0]
    if not flagged:
        return []
    keyword_list, words = _build_typo_questions(flagged)
    data = _ask_llm(words + prompt)
    if data is None:
        return []
    err = _collect_confirmed_typos(data, keyword_list)
    pprint(err)
    return err


def _build_typo_questions(flagged):
    """Build numbered LLM questions from corrector hits; return (sources, text).

    sources[i] is the sentence asked about in question i. This lets the
    LLM's "placeName" number be mapped back to the original text.
    """
    sources = []
    questions = []
    for num, place in enumerate(flagged):
        keyword = place['source']
        sources.append(keyword)
        # The keys of each error's 'correction' mapping are listed as the
        # suspect characters (presumably {wrong: corrected} -- confirm against
        # PaddleNLP's text_correction output).
        suspects = [key for item in place["errors"] for key in item['correction']]
        questions.append("{}、原文:{}。问题:【{}】这些字是否为当前原文的错别字".format(
            num, keyword, ",".join(suspects)))
    return sources, "\n".join(questions)


def _ask_llm(content):
    """Send *content* to ``bot`` as one user message; return the final reply content.

    ``bot.run`` yields successive responses and only the last one is used.
    Returns None if the bot yields nothing.
    """
    messages = [{'role': 'user', 'content': [{'text': content}]}]
    print(content)
    last = None
    for rsp in bot.run(messages):
        last = rsp  # keep only the latest response instead of buffering every one
    if not last:
        return None
    return last[0]["content"]


def _collect_confirmed_typos(data, keyword_list):
    """Parse the LLM reply leniently and keep the answers marked '是'.

    Malformed entries are skipped instead of failing the whole batch. This
    covers entries that are not dicts and entries whose "placeName" is
    missing, not an integer, or outside the question range. The analysis is
    read from "解析", falling back to "jianyi", then to "".
    """
    pprint(data)
    # Remove backslashes and Markdown code-fence backticks before repair/parsing.
    parsed_data = json_repair.loads(data.replace("\\", "").replace('`', ''))
    if isinstance(parsed_data, dict):
        parsed_data = [parsed_data]  # a single answer object instead of a list
    if not isinstance(parsed_data, list):
        return []
    err = []
    for place in parsed_data:
        if not isinstance(place, dict) or str(place.get('回答', '')).strip() != '是':
            continue
        try:
            index = int(place.get("placeName"))
        except (TypeError, ValueError):
            continue
        # Reject negative indices too: they would silently map to the wrong sentence.
        if not 0 <= index < len(keyword_list):
            continue
        analysis = place.get("解析", place.get("jianyi", ""))
        err.append({**place, "placeName": keyword_list[index], "jianyi": analysis})
    return err


# from flask import Flask, request, jsonify
|
||
|
# import os
|
||
|
# # from checkPlaceName import checkPlaceName
|
||
|
# # from checkRepeatText import checkRepeatText
|
||
|
# # from checkCompanyName import checkCompanyName
|
||
|
# # from documentError import getDocumentError
|
||
|
# app = Flask(__name__)
|
||
|
# UPLOAD_FOLDER = 'uploads'
|
||
|
# if not os.path.exists(UPLOAD_FOLDER):
|
||
|
# os.makedirs(UPLOAD_FOLDER)
|
||
|
# @app.route('/upload', methods=['POST'])
|
||
|
# def upload_file():
|
||
|
# if 'file' not in request.files:
|
||
|
# return jsonify({"error": "No file part"}), 400
|
||
|
# file = request.files['file']
|
||
|
# if file.filename == '':
|
||
|
# return jsonify({"error": "No selected file"}), 400
|
||
|
# if file:
|
||
|
# filename = file.filename
|
||
|
# file.save(os.path.join(UPLOAD_FOLDER,filename))
|
||
|
# return jsonify({"message": "File uploaded successfully"}), 200
|
||
|
# # @app.route('/checkPlaceName/<filename>', methods=['GET'])
|
||
|
# # def checkPlaceNameWeb(filename):
|
||
|
# # return checkPlaceName(filename)
|
||
|
# # @app.route('/checkRepeatText/<filename>', methods=['GET'])
|
||
|
# # def checkRepeatTextWeb(filename):
|
||
|
# # return checkRepeatText(filename)
|
||
|
# # @app.route('/checkCompanyName/<filename>', methods=['GET'])
|
||
|
# # def checkCompanyNameWeb(filename):
|
||
|
# # return checkCompanyName(filename)
|
||
|
# # @app.route('/checkDocumentErrorWeb/<filename>', methods=['GET'])
|
||
|
# # def checkDocumentErrorWeb(filename):
|
||
|
# # return getDocumentError(filename)
|
||
|
# if __name__ == '__main__':
|
||
|
# app.run(host='0.0.0.0',port=80)
|
||
|
# from transformers import AutoTokenizer, AutoModel, GenerationConfig,AutoModelForCausalLM
|
||
|
# import os
|
||
|
# os.environ['NPU_VISIBLE_DEVICES']='0,1,2,3,4,5,6,7'
|
||
|
# os.environ['ASCEND_RT_VISIBLE_DEVICES']='0,1,2,3,4,5,6,7'
|
||
|
# import torch
|
||
|
# import torch_npu
|
||
|
# from torch_npu.contrib import transfer_to_npu
|
||
|
|
||
|
# from accelerate import Accelerator
|
||
|
|
||
|
# # device = 'cpu'
|
||
|
# accelerator = Accelerator()
|
||
|
# # torch_device = "npu" # 0~7
|
||
|
# # torch.npu.set_device(torch.device(torch_device))
|
||
|
# devices = []
|
||
|
# for i in range(8):
|
||
|
# devices.append(f"npu:{i}")
|
||
|
# print(devices)
|
||
|
# torch.npu.set_device(devices)
|
||
|
# torch.npu.set_compile_mode(jit_compile=False)
|
||
|
# model_name_or_path = '/mnt/sdc/qwen/Qwen2-72B-Instruct'
|
||
|
# tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=False)
|
||
|
# # model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True, device_map="auto",torch_dtype=torch.float16)
|
||
|
# model = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True, device_map=accelerator,torch_dtype=torch.float16).npu().eval()
|