1200字范文,内容丰富有趣,写作的好帮手!
1200字范文 > 中医药天池大数据竞赛——中医文献问题生成挑战(三)

中医药天池大数据竞赛——中医文献问题生成挑战(三)

时间:2019-06-30 04:40:13

相关推荐

中医药天池大数据竞赛——中医文献问题生成挑战(三)

前两篇主要写了数据预处理(/jasmine0244/article/details/108888236)和模型搭建(/jasmine0244/article/details/108902127),接下来就是K折模型验证评估。

官方的评估标准是以ROUGE-L(/anthology/W04-1013.pdf)为准。我在github上找到了一个免安装的版本,rouge4chinese(/hpzhao/nlp-metrics)。

首先我们将K折的验证集也输出来,方便和验证集的预测结果进行评估

from tqdm import tqdmk_folds = 5for mode in range(k_folds): valid_data = [data[j] for i, j in enumerate(random_order) if i % k_folds == mode]print(len(valid_data))with open("ref_{0}".format(mode), 'w', encoding='utf-8') as f:for d in tqdm(iter(valid_data), desc=u'正在输出(共%s条样本)' % len(valid_data)):s = '%s\t%s\t%s\n' % (d[1], d[2], d[0])f.write(s)f.flush()

又因为rouge4chinese用来评估的文件文本是分词后的结果,因此对输出文本均做分词处理,我这里用的是jieba分词器

import jiebadef cut(old_file, new_file):with open(old_file,"r") as f:lines = f.readlines()with open(new_file, 'w', encoding='utf-8') as f:for line in lines:s = " ".join(list(jieba.cut(line))) + "\n"f.write(s)f.flush()# 验证集预测输出分词for mode in range(k_folds):cut('qa_{0}.csv'.format(mode),'qa_{0}.txt'.format(mode))# 验证集原始分词for mode in range(k_folds):cut('ref_{0}'.format(mode),'ref_{0}.txt'.format(mode))

然后修改rouge4chinese里的run.sh脚本 中GEN和REF的路径即可

生成了5个sh文件,分别表示K=0-4这5种情况的评估结果,看下面记录,第0份结果比较好(第1份效果最差,已经删除了,没记录)

最后就是采用最优模型来预测:

def submit(qag,submit_file):with open(test_file,"r") as f:text = f.read()test_json = json.loads(text)print("test_json's size: {0}".format(len(test_json)))pre_data = []for line in tqdm(iter(test_json), desc=u'正在预测(共%s条样本)' % len(test_json)):one_json = dict()one_json['id'] = line['id']one_json['text'] = line['text']one_json['annotations'] = []for p in line['annotations']:if p['A']:ans = p['A']para = Nonefor t in text_segmentate(line['text'], max_p_len - 2, seps, strips):if ans in t:para = tbreakif para:q = qag.generate(para, ans)else:q = qag.generate(line['text'], ans)one_json['annotations'].append({'Q': q, 'A': p['A']})else:one_json['annotations'].append(p)pre_data.append(one_json)print("pre_data's size: {0}".format(len(pre_data)))with open(submit_file, 'w', encoding='utf-8') as f1:f1.write(json.dumps(pre_data, indent=4, ensure_ascii=False))

if __name__ == "__main__":#k_folds_train(k_folds)best_model='../best_model.weights_0'model = build_model()model.load_weights(best_model) qag = QuestionAnswerGeneration(start_id=None, end_id=tokenizer._token_end_id, maxlen=max_q_len)submit_file = "../prediction_result/result.json"submit(qag,submit_file)

最后输出的结果我没提交到天池平台,因为忘记在天池平台上点击支付宝实名认证被淘汰了。但是我第一次提交的成绩是

这个效果应该会比我第一次好,接下去我也没再做改进了,算法上也没有什么贡献,就止步了。

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。