Table of Contents
- Core Components of MOELoRA
- The Role of MOE in Multi-task Learning
- LoRA's Contribution to Parameter-Efficient Fine-tuning
- How MOELoRA Combines MOE and LoRA

Paper: When MOE Meets LLMs: Parameter Efficient Fine-tuning for Multi-task Medical Applications (https://arxiv.org/pdf/2310.18339)

## Core Components of MOELoRA

The core idea of MOELoRA rests on two key techniques: the Mixture-of-Experts architecture (MOE) and Low-Rank Adaptation (LoRA). MOE handles task allocation and expert collaboration in multi-task learning, while LoRA provides parameter-efficient model fine-tuning.

## The Role of MOE in Multi-task Learning

The MOE architecture uses a dynamic routing mechanism to dispatch inputs to different expert modules, each of which specializes in a particular task or data subset. This design lets the model handle multi-task scenarios flexibly without a significant increase in parameter count. MOE's strength is that it can adjust how expert capacity is allocated according to task complexity, improving performance under limited data and compute budgets.

## LoRA's Contribution to Parameter-Efficient Fine-tuning

LoRA introduces a small number of trainable parameters on top of a pretrained model through low-rank matrix decomposition, sharply reducing the resource cost of fine-tuning. Concretely, LoRA factorizes the weight update ΔW into the product of two low-rank matrices, ΔW = BA, where the rank of B and A is far smaller than that of the original weight matrix. This preserves the knowledge in the pretrained weights while enabling efficient task adaptation.

## How MOELoRA Combines MOE and LoRA

MOELoRA combines MOE's task-allocation ability with LoRA's parameter efficiency into a layered optimization structure. The MOE layer identifies the task and activates the corresponding expert modules, and each expert is fine-tuned internally with LoRA. This design avoids interference between tasks while reducing redundancy by sharing the frozen base-model parameters.

Implementation: https://github.com/liuqidong07/MOELoRA-peft/blob/master/src/MLoRA/peft/tuners/mmoelora.py
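To make this mechanism concrete before looking at the real implementation, here is a minimal, self-contained sketch of a MOELoRA-style linear layer: a frozen base projection plus a gate-weighted sum of per-expert low-rank updates, where the gate is driven by a task embedding. This is an illustration only, not the authors' code; the class name ToyMOELoRALinear and hyperparameters such as n_experts and task_emb_dim are invented for the example.

```python
# Conceptual sketch (illustrative, not the authors' implementation):
# output = W0 x + sum_i omega_i * (alpha / r) * B_i A_i x,
# where omega = softmax(gate(task_embedding(task_id))).
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyMOELoRALinear(nn.Module):
    def __init__(self, in_features, out_features, r=8, n_experts=4,
                 n_tasks=3, task_emb_dim=16, alpha=16):
        super().__init__()
        self.base = nn.Linear(in_features, out_features, bias=False)
        self.base.weight.requires_grad = False        # frozen pretrained weight
        expert_r = r // n_experts                     # total rank split across experts
        self.lora_A = nn.ModuleList([nn.Linear(in_features, expert_r, bias=False)
                                     for _ in range(n_experts)])
        self.lora_B = nn.ModuleList([nn.Linear(expert_r, out_features, bias=False)
                                     for _ in range(n_experts)])
        for B in self.lora_B:
            nn.init.zeros_(B.weight)                  # updates start at zero
        self.task_emb = nn.Embedding(n_tasks, task_emb_dim)
        self.gate = nn.Linear(task_emb_dim, n_experts, bias=False)
        self.scaling = alpha / r

    def forward(self, x, task_id):
        # expert weights depend only on the task id, not on the token
        omega = F.softmax(self.gate(self.task_emb(task_id)), dim=-1)  # (1, n_experts)
        out = self.base(x)
        for i, (A, B) in enumerate(zip(self.lora_A, self.lora_B)):
            out = out + B(A(x)) * self.scaling * omega[..., i]
        return out


layer = ToyMOELoRALinear(in_features=32, out_features=32)
x = torch.randn(2, 10, 32)                            # (batch, seq_len, hidden)
y = layer(x, task_id=torch.tensor([0]))
print(y.shape)                                        # torch.Size([2, 10, 32])
```

The actual implementation from the repository's mmoelora.py (linked above) is reproduced below.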
```python
# Excerpt from mmoelora.py; the original file's imports (warnings, torch,
# torch.nn as nn, torch.nn.functional as F, and peft's LoraLayer / transpose
# helpers) are assumed here.


class MMOELoraLayer(LoraLayer):

    def __init__(self, in_features: int, out_features: int, expert_num: int):
        super().__init__(in_features, out_features)
        self.expert_num = expert_num

    def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights):
        self.r[adapter_name] = r
        self.lora_alpha[adapter_name] = lora_alpha
        if lora_dropout > 0.0:
            lora_dropout_layer = nn.Dropout(p=lora_dropout)
        else:
            lora_dropout_layer = nn.Identity()

        self.lora_dropout.update(nn.ModuleDict({adapter_name: lora_dropout_layer}))
        # Actual trainable parameters
        if r > 0:
            self.lora_A.update(nn.ModuleDict(
                {adapter_name: MMOELinearA(self.in_features, r, self.expert_num)}))
            self.lora_B.update(nn.ModuleDict(
                {adapter_name: MMOELinearB(r, self.out_features, self.expert_num)}))
            self.scaling[adapter_name] = lora_alpha / r
        if init_lora_weights:
            self.reset_lora_parameters(adapter_name)
        self.to(self.weight.device)

    def reset_lora_parameters(self, adapter_name):
        if adapter_name in self.lora_A.keys():
            # initialize A the same way as the default for nn.Linear and B to zero
            for i in range(self.expert_num):
                nn.init.normal_(self.lora_A[adapter_name].loraA[i].mlp.weight, mean=0.0, std=0.01)
                nn.init.zeros_(self.lora_B[adapter_name].loraB[i].mlp.weight)


class MMOELoraLinear(nn.Linear, MMOELoraLayer):
    # Lora implemented in a dense layer
    # nn.Linear is the pretrained weights in LLM, MMOELoraLayer is the designed trainable Lora
    def __init__(
        self,
        adapter_name: str,
        in_features: int,
        out_features: int,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.0,
        fan_in_fan_out: bool = False,  # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
        **kwargs,
    ):
        init_lora_weights = kwargs.pop("init_lora_weights", True)
        self.expert_num = kwargs.pop("expert_num", True)
        self.task_num = kwargs.pop("task_num", True)
        self.te_dim = kwargs.pop("task_embedding_dim", True)

        nn.Linear.__init__(self, in_features, out_features, **kwargs)
        MMOELoraLayer.__init__(self, in_features=in_features, out_features=out_features,
                               expert_num=self.expert_num)

        # init the Gate network
        self.lora_task_embedding = nn.ModuleDict({})
        self.lora_gate = nn.ModuleDict({})
        self.lora_task_embedding.update(
            nn.ModuleDict({adapter_name: nn.Embedding(self.task_num + 1, self.te_dim)}))
        self.lora_gate.update(nn.ModuleDict({adapter_name: Gate(self.te_dim, self.expert_num)}))

        # Freezing the pre-trained weight matrix
        self.weight.requires_grad = False

        self.fan_in_fan_out = fan_in_fan_out
        if fan_in_fan_out:
            self.weight.data = self.weight.data.T

        nn.Linear.reset_parameters(self)
        self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
        self.active_adapter = adapter_name

    def merge(self, task_id):
        if self.active_adapter not in self.lora_A.keys():
            return
        if self.merged:
            warnings.warn("Already merged. Nothing to do.")
            return
        if self.r[self.active_adapter] > 0:
            expert_weight = self.lora_gate[self.active_adapter](
                self.lora_task_embedding[self.active_adapter](task_id))
            for i in range(self.expert_num):
                lora_A_weights = self.lora_A[self.active_adapter].loraA[i].mlp.weight
                lora_B_weights = self.lora_B[self.active_adapter].loraB[i].mlp.weight
                self.weight.data += (
                    transpose(
                        lora_B_weights @ lora_A_weights,
                        self.fan_in_fan_out,
                    )
                    * self.scaling[self.active_adapter]
                    * expert_weight[..., i]
                )
            self.merged = True

    def unmerge(self, task_id):
        if self.active_adapter not in self.lora_A.keys():
            return
        if not self.merged:
            warnings.warn("Already unmerged. Nothing to do.")
            return
        if self.r[self.active_adapter] > 0:
            expert_weight = self.lora_gate[self.active_adapter](
                self.lora_task_embedding[self.active_adapter](task_id))
            for i in range(self.expert_num):
                lora_A_weights = self.lora_A[self.active_adapter].loraA[i].mlp.weight
                lora_B_weights = self.lora_B[self.active_adapter].loraB[i].mlp.weight
                self.weight.data -= (
                    transpose(
                        lora_B_weights @ lora_A_weights,
                        self.fan_in_fan_out,
                    )
                    * self.scaling[self.active_adapter]
                    * expert_weight[..., i]
                )
            self.merged = False

    def forward(self, x: torch.Tensor, **kwargs):
        task_id = kwargs["task_id"]
        previous_dtype = x.dtype
        if self.active_adapter not in self.lora_A.keys():  # No adapter, directly use linear
            return F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
        if self.disable_adapters:  # No adapter
            if self.r[self.active_adapter] > 0 and self.merged:  # unmerge the adapter from linear
                self.unmerge(task_id)
            result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
        elif self.r[self.active_adapter] > 0 and not self.merged:  # general lora process
            result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)

            x = x.to(self.lora_A[self.active_adapter].loraA[0].weight.dtype)

            expert_weight = self.lora_gate[self.active_adapter](
                self.lora_task_embedding[self.active_adapter](task_id))
            for i in range(self.expert_num):
                result += (  # lora process
                    self.lora_B[self.active_adapter].loraB[i](
                        self.lora_A[self.active_adapter].loraA[i](
                            self.lora_dropout[self.active_adapter](x)),
                    )
                    * self.scaling[self.active_adapter]
                    * expert_weight[..., i].unsqueeze(-1).unsqueeze(0)
                )
        else:
            result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)

        result = result.to(previous_dtype)

        return result


class MMOELinearA(nn.Module):
    """MMOE based LoRA block"""

    def __init__(self, in_features, out_features, expert_num) -> None:
        super().__init__()

        self.expert_num = expert_num
        self.in_features, self.out_features = in_features, out_features
        self.loraA = nn.ModuleList([])

        assert self.out_features % self.expert_num == 0  # lora rank should be divided by expert number
        self.r = self.out_features // self.expert_num

        for _ in range(self.expert_num):
            self.loraA.append(Expert(self.in_features, self.r))

    def forward(self, x):
        """input x is a vector, return output is a list"""
        outputs = []
        for i in range(self.expert_num):
            outputs.append(self.loraA[i](x))

        return outputs


class MMOELinearB(nn.Module):
    """MMOE based LoRA block"""

    def __init__(self, in_features, out_features, expert_num) -> None:
        super().__init__()

        self.expert_num = expert_num
        self.in_features, self.out_features = in_features, out_features
        self.loraB = nn.ModuleList([])

        assert self.in_features % self.expert_num == 0
        self.r = self.in_features // self.expert_num

        for _ in range(self.expert_num):
            self.loraB.append(Expert(self.r, self.out_features))

    def forward(self, x):
        """input x is a list, return output is also a list"""
        outputs = []
        for i in range(self.expert_num):
            outputs.append(self.loraB[i](x[i]))

        return outputs
```
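One design choice worth noting in MMOELinearA and MMOELinearB: the total LoRA rank r is split evenly across the experts (hence the assertion that r is divisible by expert_num), so the expert ensemble trains the same number of low-rank parameters as a single rank-r LoRA adapter; the only extra trainable parameters come from the task embedding and the gate. A quick check with hypothetical dimensions:

```python
# Parameter-count check for the rank-splitting design of MMOELinearA/B
# (hypothetical dimensions): splitting rank r across expert_num experts
# keeps the LoRA parameter count identical to a single rank-r adapter.
in_features, out_features = 4096, 4096
r, expert_num = 16, 8
assert r % expert_num == 0
expert_r = r // expert_num                    # rank held by each expert

moelora_params = expert_num * (in_features * expert_r + expert_r * out_features)
plain_lora_params = in_features * r + r * out_features
print(moelora_params, plain_lora_params)      # both 131072
```

The Expert and Gate building blocks used above are defined as follows.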
```python
class Expert(nn.Module):

    def __init__(self, in_features, out_features):
        super().__init__()

        self.in_features, self.out_features = in_features, out_features
        self.mlp = nn.Linear(self.in_features, self.out_features, bias=False)
        self.weight = self.mlp.weight

    def forward(self, x):
        # LoRA A or B block
        y = self.mlp(x)

        return y


class Gate(nn.Module):

    def __init__(self, input_size, expert_num):
        super().__init__()
        # use an embedding (of the task id) in place of a linear projection of the input
        self.GateL = nn.Linear(input_size, expert_num, bias=False)
        self.act = nn.Softmax(dim=1)  # dim 0 is the batch size

    def forward(self, x):
        y = self.GateL(x)
        y = self.act(y)

        return y
```
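For orientation, here is a sketch of constructing and calling MMOELoraLinear directly. The import path and every hyperparameter value below are assumptions made for illustration, not taken from the repository's documentation.

```python
# Hypothetical usage sketch; the import path and hyperparameters are assumptions.
import torch
from MLoRA.peft.tuners.mmoelora import MMOELoraLinear  # assumed import path

layer = MMOELoraLinear(
    adapter_name="default",
    in_features=4096,
    out_features=4096,
    r=16,                      # total LoRA rank, split across experts
    lora_alpha=32,
    lora_dropout=0.1,
    expert_num=8,              # number of LoRA experts
    task_num=5,                # number of tasks in the multi-task mixture
    task_embedding_dim=64,     # dimension of the task embedding fed to the gate
)

x = torch.randn(2, 128, 4096)              # (batch, seq_len, hidden)
task_id = torch.tensor([1])                # id of the task for this batch
y = layer(x, task_id=task_id)              # base output + gate-weighted expert LoRA outputs
print(y.shape)                             # torch.Size([2, 128, 4096])
```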