You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
134 lines
4.8 KiB
134 lines
4.8 KiB
5 months ago
|
# -*- coding:utf-8 -*-
|
||
|
import time
|
||
|
from docx import Document
|
||
|
from paddlenlp import Taskflow
|
||
|
from qwen_agent.agents import Assistant
|
||
|
import re
|
||
|
import json_repair
|
||
|
wordtag = Taskflow("knowledge_mining")
|
||
|
|
||
|
prompt = '''
|
||
|
.根据上述文本判断,是否为具体的公司或组织名称,你可以使用工具利用互联网查询,
|
||
|
你只能在[具体的公司或组织名称,公益组织,简称,统称,泛化组织,政府单位,机关单位,学校,行业类型,其他]选项中选择答案,
|
||
|
回答格式[{“companyName”:“名称”,"回答":"答案"},{“companyName”:“名称”,"回答":"答案"}],不做过多的解释,严格按回答格式作答;
|
||
|
'''
|
||
|
llm_cfg = {
|
||
|
#'model': 'qwen1.5-72b-chat',
|
||
|
'model':"qwen2-72b",
|
||
|
'model_server': 'http://127.0.0.1:1025/v1', # base_url, also known as api_base
|
||
|
# 'api_key': 'sk-ea89cf04431645b185990b8af8c9bb13',
|
||
|
}
|
||
|
bot = Assistant(llm=llm_cfg,
|
||
|
name='Assistant',
|
||
|
# system_message="你是一个地理专家,可以准确的判断地理位置,如果你不确定,可以使用工具"
|
||
|
)
|
||
|
|
||
|
def getDocxToTextAll(name):
|
||
|
docxPath=name
|
||
|
document = Document(docxPath)
|
||
|
# 逐段读取docx文档的内容
|
||
|
levelList=[]
|
||
|
words=[]
|
||
|
addStart = False
|
||
|
levelText=""
|
||
|
i = 0
|
||
|
for paragraph in document.paragraphs:
|
||
|
# 判断该段落的标题级别
|
||
|
# 这里用isTitle()临时代表,具体见下文介绍的方法
|
||
|
text = paragraph.text
|
||
|
if text.strip():#非空判断
|
||
|
# print("非空")
|
||
|
words.append(text)
|
||
|
# 将所有段落文本拼接成一个字符串,并用换行符分隔
|
||
|
print("checkCompanyName",len(words))
|
||
|
text = '\n'.join(words)
|
||
|
|
||
|
# 将文本写入txt文件
|
||
|
with open("checkCompanyName.txt", 'w', encoding='utf-8') as txt_file:
|
||
|
txt_file.write(text)
|
||
|
def checkCompanyName(filename):
|
||
|
getDocxToTextAll(filename)
|
||
|
start_time=time.time()
|
||
|
error_places = []
|
||
|
for batch in read_file_in_batches('checkCompanyName.txt'):
|
||
|
res=process_batch(batch)
|
||
|
if(len(res)>0):
|
||
|
error_places.extend(res)
|
||
|
|
||
|
print(error_places)
|
||
|
end_time = time.time()
|
||
|
# 计算执行时间
|
||
|
elapsed_time = end_time - start_time
|
||
|
print(f"checkCompanyName程序执行时间: {elapsed_time} 秒")
|
||
|
return error_places
|
||
|
|
||
|
def read_file_in_batches(file_path, batch_size=5000):
|
||
|
"""
|
||
|
分批读取文本文件
|
||
|
:param file_path: 文件路径
|
||
|
:param batch_size: 每批处理的字符数
|
||
|
:return: 生成器,每次返回一批文本
|
||
|
"""
|
||
|
with open(file_path, 'r', encoding='utf-8') as file:
|
||
|
batch = []
|
||
|
char_count = 0
|
||
|
for line in file:
|
||
|
batch.append(line)
|
||
|
char_count += len(line)
|
||
|
if char_count >= batch_size:
|
||
|
yield ''.join(batch)
|
||
|
batch = []
|
||
|
char_count = 0
|
||
|
if batch:
|
||
|
yield ''.join(batch)
|
||
|
|
||
|
def process_batch(batch):
|
||
|
"""
|
||
|
处理一批文本
|
||
|
:param batch: 一批文本
|
||
|
"""
|
||
|
# 在这里添加你的处理逻辑
|
||
|
|
||
|
# sentences = re.split(r'[。\n]', batch)
|
||
|
# sentences = [sentence.strip() for sentence in sentences if sentence.strip()]
|
||
|
res=wordtag(batch)
|
||
|
placeList = []
|
||
|
isplace = False
|
||
|
for zuhe in res[0]['items']:
|
||
|
# 上一个的地名,这一个还是地名,就和上一个相加代替这个
|
||
|
zhi = zuhe.get("wordtag_label")
|
||
|
if isplace:
|
||
|
name = placeList[len(placeList) - 1]
|
||
|
if zhi.find("组织机构类")>=0 : # or zuhe[1] == "ns"
|
||
|
isplace = True
|
||
|
new_text = zuhe['item'].replace("\n", "")
|
||
|
placeList[len(placeList) - 1] = name + new_text
|
||
|
continue
|
||
|
if zhi.find("组织机构类")>=0 :
|
||
|
isplace = True
|
||
|
new_text = zuhe['item'].replace("\n", "")
|
||
|
placeList.append(new_text)
|
||
|
else:
|
||
|
isplace = False
|
||
|
placeList=list(dict.fromkeys(placeList))
|
||
|
placeStr = ",".join(placeList)
|
||
|
messages = [{'role': 'user', 'content': [{'text': placeStr+prompt}]}]
|
||
|
print("checkCompanyName",placeStr+prompt)
|
||
|
runList = []
|
||
|
for rsp in bot.run(messages):
|
||
|
runList.append(rsp)
|
||
|
data = runList[len(runList) - 1][0]["content"]
|
||
|
print("checkCompanyName",data)
|
||
|
parsed_data = json_repair.loads(data.replace('`', ''))
|
||
|
error_places = [place for place in parsed_data if place['回答'] == '具体的公司或组织名称']
|
||
|
print("checkCompanyName",error_places)
|
||
|
if len(error_places)>0:
|
||
|
for t in error_places:
|
||
|
keyword= t['companyName']
|
||
|
# 查找包含关键字的段落
|
||
|
paragraphs = re.findall(r'.*?' + re.escape(keyword) + r'.*?\n', batch)
|
||
|
t["yuanwen"]=paragraphs[0]
|
||
|
return error_places
|
||
|
else:
|
||
|
return error_places
|