🌑

Jenqyang

Chatglm2-6B医疗领域微调实战

10月中旬,赶着天池的一场中文医疗大模型比赛的尾巴,使用Chatglm2-6B进行了医疗领域的微调尝试,这里就稍稍讲解一下这场比赛以及对Chatglm2-6B使用PT方法进行微调的感悟。

比赛介绍

为推动LLM在医疗领域的发展和落地,将CBLUE基准进行二次开发,将18种不同的医疗场景NLP任务全部转化为基于提示的语言生成任务,形成首个中文医疗场景的LLM评测基准——PromptCBLUE。为了考察大模型领域的不同技术,对PromptCBLUE评测开放两个榜单:不微调和参数高效微调赛道(以下内容都是着眼于参数高效微调赛道)。

比赛规则

  1. 选手可以在开源的大模型主干之上添加参数高效微调模块(添加的额外参数量不得高于大模型主干参数量的1%), 且只能微调参数高效微调模块,不得微调大模型主干,且只能使用一组参数高效微调模块完成所有的任务。
  2. 选手可以通过扩展提示内容或者从训练集中选择示例等方式来增强模型表现,但不能更改或者删减指令的内容。
  3. 选手不能将LLM多次生成的结果进行集成(类似于self-consistency的方法),提交的每个测试样本预测结果必须是LLM单次回复生成的。

数据集介绍

{
        "input":  str,
        "target":  str,
        "type":  str,
        "answer_choices":  str,
        "sample_id":  str,
}

input字段字符串是LLM模型的输入,target字段也是一个字符串,则是LLM模型需要生成的文本序列。其他附加信息有: type是原任务类型(不能作为模型输入),answer_choices字段是选项,只有分类、术语标准化、推理类任务上该字段才会有实际取值,sample_id是样本编号。这些附加信息是不作为LLM的输入。

P-Tuning v2 微调

P-tuning v2微调方法是一种基于参数调整的微调方法,其主要思想是通过调整预训练模型的参数来提高模型的表现。这种方法将预训练模型的参数分成两个部分:固定参数和可调整参数。在微调过程中,可调整参数会根据任务数据进行调整,而固定参数则保持不变。这种方法能够提高模型的泛化性能,同时避免了过拟合问题。

在该项任务中进行如下的参数设置:

PRE_SEQ_LEN=128 # soft prompt 长度
LR=1e-2 # 学习率
NUM_GPUS=1 # GPU个数,因为我是单卡所以设置1

torchrun --standalone --nnodes=1 --nproc-per-node=$NUM_GPUS main.py \
    --do_train \
    --train_file PromptCBLUE/train.json \  # 训练集json文件路径,json-line格式
    --validation_file PromptCBLUE/dev.json \  # 验证集json文件路径,json-line格式
    --preprocessing_num_workers 10 \
    --prompt_column input \     # json中作为LLM输入的key
    --response_column target \   # json中作为LLM输出的key
    --overwrite_cache \
    --model_name_or_path /root/autodl-tmp/chatglm2-6b \   # 预训练模型名称或路径
    --output_dir output/adgen-chatglm2-6b-pt-$PRE_SEQ_LEN-$LR \  # 输出目录,用于保存模型和日志等文件
    --overwrite_output_dir \
    --max_source_length 64 \   # 输入文本的最大长度,单位是token
    --max_target_length 128 \  # 输出文本的最大长度,单位是token
    --per_device_train_batch_size 32 \  # 每个设备的训练批处理大小
    --per_device_eval_batch_size 1 \  # 每个设备的评估批处理大小
    --gradient_accumulation_steps 1 \  # 梯度累积次数
    --predict_with_generate \  # 是否使用生成模式进行预测
    --max_steps 1000 \  # 表示最大训练步数,即模型在训练集上的迭代次数,可以理解为训练多少轮
    --logging_steps 10 \  # 日志记录间隔(命令行loss、lr、epoch信息输出间隔)
    --save_steps 200 \  # 保存模型的步数间隔,即训练多少轮保存一次训练结果(checkpoint)
    --learning_rate $LR \
    --pre_seq_len $PRE_SEQ_LEN \
    # --quantization_bit # 量化等级,用于降低显存需求,由于使用32G的V100所以无需量化

参数讲解:

  1. max_source_lengthmax_target_length的最大长度是影响模型处理能力的重要参数,一般建议根据任务和数据集的特点选择合适的值,但不要超过ChatGLM-6B的序列长度限制(2048)。
  2. per_device_train_batch_size是每个设备上的训练批处理大小,即每个GPU/TPU内核/CPU上用于训练的数据量。这个参数会影响显存的占用和训练的速度。一般来说,批次大小越大,显存占用越高,训练速度越快。
  3. gradient_accumulation_steps是指梯度累积的次数,即在进行一次梯度更新之前,需要累积多少个小批量的梯度。这个参数可以用来解决显存不足的问题。通过减小每个小批量的大小,增加梯度累积的次数,来保持总的有效批量大小不变。
  4. gradient_accumulation_stepsper_device_train_batch_size的关系是:它们都可以影响训练过程中的有效批次大小。有效批次大小是指每次梯度更新时使用的数据量,等于$per_device_train_batch_size\times gradient_accumulation_steps\times NUM_GPUS$。有效批次大小会影响模型的训练速度和效果,一般来说,有效批次大小越大,训练速度越快,但是也可能导致梯度爆炸或者欠拟合。

模型表现

命令行执行bash train.sh,显存占用大约28GB(V100-32GB)。

例子1:

>>> display_answer(model, "下面的搜索词和页面标签的意思有多相同?\n搜索词:冬季上火了吃什么降火\n页面标签:宝宝上火吃什么药降火\n选项:完全不匹配或者没有参考价值,很少匹配有一些参考价值,部分匹配,完全匹配")

# 微调前
搜索词和页面标签的意思存在一定的相似性,都涉及到宝宝上火的问题,但表达方式和侧重点略有不同。
搜索词“冬季上火了吃什么降火”主要关注冬季宝宝上火时应该吃哪些食物来降火,强调的是在冬季这个特定的季节背景下,提供相应的食物建议。
页面标签“宝宝上火吃什么药降火”则更侧重于提供宝宝上火的药物推荐,帮助家长采取相应的药物治疗措施。
虽然这两个搜索词和页面标签都涉及到宝宝上火的问题,但表达方式和侧重点不同,搜索词更注重食物推荐,而页面标签更注重药物推荐。

# 微调后
完全匹配

例子2:

>>> display_answer(model, "请问说话人在最后一句话中表达的意图是什么?\n问诊对话历史:\n医生:n你为什么觉得和你的吻呢,这些多重要啊,\n患者:听不清啊\n医生:小孩现在十娘说啥,你是为什么感觉他和露露吧,\n患者:就是听着感觉像什么东西堵着喉咙\n可选意图:关于病因的询问,关于症状的回答,关于病因的回答,关于注意事项的提问,关于用药建议的提问,关于已有检查和治疗的提问,关于个人基本信息的询问,关于用药建议的解答")

# 微调前
关于病因的询问。

# 微调后
关于病因的询问

例子3:

>>> display_answer(model, "现在假定你是一名医生。有如下对话历史:\n患者:肠胃不好,胃有点胀,经常放屁。(男,22岁)\n医生:你好,这种情况多久了\n根据上述对话历史,给出你的回复")

# 微调前
医生:你好,你最近感到肠胃不适,包括胃胀和放屁吗?

# 微调后
这种情况持续了一段时间。一般肠胃不好引起的胀气,建议你注意饮食,多喝水,多吃蔬菜水果,不要吃辛辣油腻的食物。

例子4:

>>> display_answer(model, "根据医患对话内容判断临床发现实体的阴阳性情况:\n患者:拍了甲状腺的彩超,请医生帮忙看看\n医生:你好,能把彩超报告发过来吗?\n患者:可以的谢谢医生\n患者:图片因隐私问题无法显示\n患者:能看清楚吗\n医生:彩超结果问题不大,查过甲功吗\n患者:结节严重不\n患者:甲功查过了。也有检验单,貌似不好。\n医生:把甲功的单子发过来\n患者:稍等\n医生:好的\n患者:图片因隐私问题无法显示\n患者:9月底做的\n医生:用药了吗\n患者:这个单子可以看出我是什么问题吗\n患者:用药了。这个月了吧。\n患者:但是中间有间断过\n医生:有问题\n医生:用的什么药\n患者:是什么问题\n患者:我想全面的找您了解下,多看几个医生这样确诊些\n医生:甲亢\n医生:需要规范治疗\n患者:我开始吃的是甲硫,有起疹子,换成了丙硫\n医生:可以,继续吃吧,定期复查甲功\n患者:我想问下,甲亢治疗只有这两种药吗?还有其他的吗\n医生:这药效果好\n患者:丙硫氧嘧定可以是吧\n医生:必须在内分泌科医生知道下用药\n患者:我现在复查的结果是t3t4正常了,tsh还是低\n医生:可以\n医生:低问题不大\n医生:高了不行\n患者:也就是说我现在甲亢也只能吃这个药,不管去哪个医院医生基本都是开这个药对吧\n医生:是的\n患者:图片因隐私问题无法显示\n患者:这个是本月复查的检验单\n患者:您帮忙看看\n医生:这个结果还行\n患者:还行?\n患者:可是有一项值低啊,这个有影响吗\n医生:tsh低问题问题不大\n患者:一般要服药多久能正常呢\n患者:我两次复查的结果是第二次的t3t4值比第一次下降了一些,但还在正常范围内,有问题吗\n医生:就是有效果,控制住啦,逐渐减量\n该对话中涉及的临床发现实体包括:起疹子,甲亢\n请在以下阴阳性选项中做出判定:阳性,阴性,其他,不标注\n解释:临床发现实体用于描述患者的医学表现。对阴阳性标签说明如下:“阳性”是指已有症状疾病或者假设未来可能发生的疾病等;“阴性”是指未患有症状疾病,“不标注”是指无实际意义的不标注或者和病人当前的状态独立不标注,“其他”指没有回答、不知道、回答不明确或者模棱两可不好推断")

# 微调前
根据医患对话内容,可以判断出该对话中涉及的临床发现实体包括起疹子和甲亢,因此选项应为阳性。

# 微调后
阳性

例子5:

>>> display_answer(model, "在给定的ICD-10诊断标准词中,找出与原诊断描述匹配的归一词:\n原诊断描述:浸润性低分化鳞癌\n可供选择的ICD-10诊断标准词:非浸润性粉刺癌,假腺样鳞状细胞癌,浸润性小管性癌,浸润性导管癌和管状癌,浸润性小叶癌,非浸润性导管内乳头状腺癌,硬化性浸润性基底细胞癌,浸润性导管癌和粘液癌,浸润性导管腺癌,多形性低度腺癌,浸润性乳头状腺癌,弥漫性低分化淋巴细胞性恶性淋巴瘤,高分化胸腺癌,非浸润性小叶癌,浸润性微乳头状癌,硬化性肝癌,鳞状细胞癌,鳞状细胞原位癌,基底样鳞状细胞癌,鳞状毛囊角化病,浸润性导管癌合并其他型癌,非浸润性囊内癌,乳头状鳞状细胞癌\n请填写与原诊断描述匹配的归一词:")

# 微调前
非浸润性粉刺癌(非浸润性鳞癌)

# 微调后
鳞状细胞癌

例子6:

>>> display_answer(model, "句子A:标题,如:乙肝小三阳吃什么药\n句子B:治疗乙肝小三阳吃什么药\n这两句话表达的医学意思相同吗?\n选项:相同,不同\n答:")

# 微调前
这两句话表达的医学意思不同。
句子A是一个疑问句,询问乙肝小三阳患者应该吃什么药。句子B是一个陈述句,指出乙肝小三阳患者应该吃治疗乙肝小三阳的药物。

# 微调后
相同

从以上例子可以看出,微调后的模型回答更加简洁(例子1、4),前后文联系也更有逻辑性(例子3),同时在医疗领域的问答更加合理,有明显的准确率提升(例子5、6)。总之,使用P-Tuning方法微调ChatGLM能够提高在垂直领域的表现,不容易出现模型变笨的现象(因为模型整体参数变化很小),既能够保持原模型的“智力”,又能够提高在某个领域的专业能力。

— Oct 22, 2023

Made with ❤ and Hexo.js at Earth.