p-tuning算法介绍及其pytorch代码实现

P-tuning介绍

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

代码实现

import torch
from transformers import BertTokenizer, BertForSequenceClassification
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker

tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
model = BertForSequenceClassification.from_pretrained('bert-base-chinese', num_labels=2)


def train(tokenize, model, prompt_lenght, prompt, data):
    # 冻结Bert参数
    for param in model.bert.parameters():
        param.requires_grad = False

    # 优化器
    optimizer = torch.optim.Adam([prompt], lr=1e-3)

    # 训练循环
    num_epochs = 8
    losses = []
    for epoch in range(num_epochs):
        total_loss = 0.0
        for text, label in data:
            # 处理输入和标签
            inputs = tokenizer(text, return_tensors='pt')
            labels = torch.tensor([label])  # 标签,形状为 [batch_size]

            # 访问 BERT 的嵌入层
            bert_model = model.bert
            input_ids = inputs['input_ids']

            # 获取输入标记的嵌入表示
            with torch.no_grad():
                input_embeddings = bert_model.embeddings(input_ids)

            # 扩展和拼接提示向量和输入嵌入表示
            prompt_embeddings = prompt.unsqueeze(0).expand(input_ids.size(0), -1, -1)       # unsqueeze(0):新增第一个维度。expand(input_ids.size(0), -1,- 1):对第一个维度按照input_ids[0]的大小进行扩展,-1表示自动计算维度大小。
            prompted_input = torch.cat((prompt_embeddings, input_embeddings), dim=1)

            # 前向传播
            attention_mask = torch.cat((torch.ones(prompt_embeddings.size()[:2], dtype=torch.long), inputs['attention_mask']), dim=1)
            outputs = bert_model(inputs_embeds=prompted_input, attention_mask=attention_mask)
            sequence_output = outputs.last_hidden_state

            # 分类头
            logits = model.classifier(sequence_output[:, prompt_length:, :])  # 跳过提示向量部分

            # 确保logits的形状与labels匹配
            logits = logits[:, 0, :]  # 只取第一个token的logits(即[CLS] token)

            # 计算损失
            loss_fct = torch.nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, model.config.num_labels), labels.view(-1))  # 确保 logits 和 labels 的形状匹配

            # 反向传播和优化
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
        losses.append(total_loss)    
        print(f'Epoch {epoch + 1}/{num_epochs}, Loss: {total_loss/len(data)}')
    torch.save(prompt, 'path_to_trained_prompt.pt')    
    return losses

def plot_loss(losses):
    plt.figure()
    plt.plot(losses)

def predict_classify(tokenize, model, prompt_length, trained_prompt, data):
    predict_list = []
    for input_text in data:
        inputs = tokenizer(input_text, return_tensors='pt')

        # 访问 BERT 的嵌入层
        bert_model = model.bert
        input_ids = inputs['input_ids']

        # 获取输入标记的嵌入表示
        with torch.no_grad():
            input_embeddings = bert_model.embeddings(input_ids)

        # 扩展和拼接提示向量和输入嵌入表示
        prompt_embeddings = trained_prompt.unsqueeze(0).expand(input_ids.size(0), -1, -1)
        prompted_input = torch.cat((prompt_embeddings, input_embeddings), dim=1)

        # 构建新的注意力掩码
        attention_mask = torch.cat((torch.ones(prompt_embeddings.size()[:2], dtype=torch.long), inputs['attention_mask']), dim=1)

        # 前向传播进行推理
        with torch.no_grad():
            outputs = bert_model(inputs_embeds=prompted_input, attention_mask=attention_mask)
            sequence_output = outputs.last_hidden_state

            # 分类头
            logits = model.classifier(sequence_output[:, prompt_length:, :])  # 跳过提示向量部分
            logits = logits[:, 0, :]  # 只取第一个token的logits(即[CLS] token)

        # 获取预测结果
        predicted_label = torch.argmax(logits, dim=-1).item()
        print(f"Input data: {input_text}, Predicted label: {predicted_label}")
        predict_list.append(predicted_label)
    return predict_list

# p-tuning训练
# 定义可学习的提示向量
prompt_length = 5
prompt = torch.nn.Parameter(torch.randn(prompt_length, model.config.hidden_size))
# 训练集
data = [("This movie is great", 1), ("This movie is bad", 0)]
# 训练
losses = train(tokenizer, model, prompt_length, prompt, data)
# 绘制Loss曲线
plot_loss(losses)

# p-tuning预测
prompt_length = 5
trained_prompt = torch.load('path_to_trained_prompt.pt')  # 加载训练好的提示嵌入
input_text = ["This movie is good", "This movie is bad", "This movie is not good"]
predict_list = predict_classify(tokenizer, model, prompt_length, trained_prompt, input_text)

拓展文章:第7章 大模型之Adaptation

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/775594.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Games101学习笔记 Lecture 15: Ray Tracing 3 (Light Transport Global Illumination)

Lecture 15: Ray Tracing 3 (Light Transport & Global Illumination 一、BRDF 双向反射分布函数定义 二、反射方程 Reflection Equation三、渲染方程1.重写反射方程2.当其他的点反射的radiance作为入射 一、BRDF 双向反射分布函数 定义 计算不同的反射方向上会分布多少能…

竹云实力入选《现代企业零信任网络建设应用指南报告》代表性厂商

2024年7月3日,国内网络安全媒体安全牛正式发布《现代企业零信任网络建设应用指南报告(2024版)》。竹云凭借在零信任领域创新性的产品方案和优异的市场表现,实力入选代表性厂商。 伴随着云计算、AI、大数据等技术的发展,远程办公、业务协同、…

遗漏知识点

什么是RAII? RAII是Resource Acquisition Is Initialization(wiki上面翻译成 “资源获取就是初始化”)的简称,是C语言的一种管理资源、避免泄漏的惯用法。利用的就是C构造的对象最终会被销毁的原则。RAII的做法是使用一个对象&am…

西门子PLC1200--与电脑S7通讯

硬件构成 PLC为西门子1211DCDCDC 电脑上位机用PYTHON编写 二者通讯用网线,通讯协议用S7 PLC上的数据 PLC上的数据是2个uint,在DB1,地址偏移分别是0和2 需要注意的是DB块要关闭优化的块访问,否则是没有偏移地址的 PLC中的数据内…

VCS+Vivado联合仿真BUG

场景: 在vcsvivado联合仿真过程中,对vivado导出的shell脚本修改,修改某些source文件路径,vcs编译时会报Permission Denied。 问题描述 对shell脚本修改如下: 修改仅为注释掉某一行,下面变为source文件新…

【雷丰阳-谷粒商城 】【分布式高级篇-微服务架构篇】【20】认证服务04—SSO单点登录

持续学习&持续更新中… 守破离 【雷丰阳-谷粒商城 】【分布式高级篇-微服务架构篇】【20】认证服务04—SSO单点登录 xxl-sso多系统-单点登录单点登录流程原理图单点登录流程简单实现参考 xxl-sso https://gitee.com/xuxueli0323/xxl-sso xxl-sso是开源的一个单点登录框架 …

交换机基本原理

文章目录 前言一、交换机的转发行为二、交换机的转发原理1.MAC地址表2.交换机初始状态3.学习MAC地址4.ARP协议5.交换机转发数据帧6.目标主机回复 三、华为交换机基本命令1.VRP视图分层2.命令行补全3.命令行帮助4.配置设备名称5.命令等级6.用户界面7.配置console认证8.配置用户界…

Ubuntu系统复制文件到共享文件夹出错

1、问题描述 Ubuntu系统复制文件到共享文件夹时,出现拼接文件时出错:输入/输出错误。 使用cp命令: cp -Rf XXX YYY 也是出错: cp: 写入 xxx 出错: 输入/输出错误 2、查看磁盘空间 查看磁盘空间,显示空间还有剩余…

Day05-组织架构-角色管理

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 1.组织架构-编辑部门-弹出层获取数据2.组织架构-编辑部门-编辑表单校验3.组织架构-编辑部门-确认取消4.组织架构-删除部门5.角色管理-搭建页面结构6.角色管理-获取数…

第一次的pentest show总结

第一次的pentest show总结 前言 开始之前,我特别感谢TryHackMe(英)、HackTheBox(美)、zero-point security(英)、offsec(美)等平台,使我们能够通过网络以线上的方式学习与练习,打破传统线下各地区教育资源差异大的限制,对网络教…

探索Sui的面向对象模型和Move编程语言

Sui区块链作为一种新兴的一层协议(L1),采用先进技术来解决常见的一层协议权衡问题。Cointelegraph Research详细剖析了这一区块链新秀。 Sui使用Move编程语言,该语言专注于资产表示和访问控制。本文探讨了Sui的对象中心数据存储模…

Python从0到100(三十七):数据提取的概念和数据分类

1. 爬虫中数据的分类 在爬虫开发过程中,我们会遇到多种类型的数据。了解这些数据的类型对于有效地提取和解析信息至关重要。 结构化数据 结构化数据是指具有固定格式和模式的数据,常见的结构化数据格式包括JSON和XML。 处理方式:可以直接转换为Python的字典或列表等数据类…

【UML用户指南】-27-对体系结构建模-制品

目录 1、组成结构 2、制品的种类 2.1、部署制品 (deployment artifact) 2.2、工作产品制品 (work product artifact) 2.3、执行制品 (execution artifact) 3、标准元素 4、常用建模技术 4.1、对可执…

Redis 7.x 系列【17】四种持久化策略

有道无术,术尚可求,有术无道,止于术。 本系列Redis 版本 7.2.5 源码地址:https://gitee.com/pearl-organization/study-redis-demo 文章目录 1. 概述2. 案例演示2.1 无持久化2.2 RDB2.3 AOF2.4 混合模式2.4.1 方式一:…

LLM - 神经网络的组成

1. 一个神经元的结构:即接受多个输入X向量,在一个权重向量W和一个偏执标量b的作用下,经过激活函数后,产生一个输出。 2. 一层神经网络的结构:该层网络里的每个神经元并行计算,得到各自的输出;计算方式是输入…

CISAW证书考完有什么用?值得投资吗?

CISAW证书,在信息安全领域内被公认为具有高价值的一种职业资格认证,它象征着持有者在该领域的专业技能和知识水平。 因此,CISAW证书不仅具有实质性的价值,还能为持有者带来诸多益处。 首先,拥有CISAW证书的专业人士更…

简过网:教师编制报考要求和条件,都给你汇总好了!

如果你想要考教师编,那么在考试之前你先要明白这些知识! ​ 一、什么是教师编? 在编教师拥有的编制为事业编,即在编老师为事业单位工作人员 二、考教师编需要什么条件? 1、普通话 语文学科普通话要求达到二级甲等及…

5.基于SpringBoot的SSMP整合案例-数据层开发

目录 1.新建项目 2.实体类开发: 2.1在pom.xml中增加Lombok坐标: 2.2添加Book实体类 3.数据层开发: 3.1 配置MyBatisPlus与Druid 3.2创建数据层接口 3.3写测试类 3.4点击运行: 4.数据层快速开发: 4.1配置MyB…

Camera link(学习笔记)

Camera Link协议是一种专门针对机器视觉应用领域的串行通信协议,它使用低压差分信号(LVDS)进行数据的传输和通信。Camera Link标准是在ChannelLink标准的基础上多加了6对差分信号线,其中4对用于并行传输相机控制信号,另外2对用于相机和图像采…

植物学(书籍学习资料)

包含观赏植物学、植物学、植物学百科图鉴等多本植物学方面的书籍学习资料。 图2、3为观赏植物学截图; 图4、5为植物学百科图鉴截图; 图6、7为植物学学习指南截图。