clearwind

clearwind

首页
分类
登录 →
clearwind

clearwind

首页 分类
登录
  1. 首页
  2. 🚀AI
  3. 🔧LLM
  4. 数据集分类均衡问题及其解决方案

数据集分类均衡问题及其解决方案

0
  • 🔧LLM
  • 发布于 2025-01-03
  • 14 次阅读
clearwind
clearwind

数据集的类别均衡性对模型的性能有着至关重要的影响。当数据集中某些类别的样本数量远多于其他类别时,就会出现数据不均衡问题。这种不平衡可能导致模型在训练过程中偏向多数类,从而影响对少数类的预测性能。

问题描述

分类数据集统计

import pandas as pd

# 读取CSV文件
csv_file_path = "/Volumes/Date/huggingface/dataset/news/train.csv"
df = pd.read_csv(csv_file_path)

# 统计每个类别的数据量
category_counts = df["label"].value_counts()

# 计算每个类别的比值
total_data = len(df)
category_ratios = (category_counts / total_data) * 100

# 打印每个类别的数据量
print("每个类别的数据量:")
print(category_counts.to_string())

# 打印每个类别的比值
print("\n每个类别的比值 (%):")
print(category_ratios.to_string())

统计案例分析

在模型微调之前,需要对数据集进行一些评估,而在分类数据集中各分类标签下数据均衡是首选需要确认的,一下是比较好的数据集:

label
9    5045
6    5040
0    5040
7    5017
3    5000
4    4994
8    4983
1    4981
5    4950
2    4950
Name: count, dtype: int64
label
9    10.090
6    10.080
0    10.080
7    10.034
3    10.000
4     9.988
8     9.966
1     9.962
5     9.900
2     9.900
Name: count, dtype: float64

在处理微博情感分析数据集时,我们发现标签7的数量明显多于其他标签,具体统计如下


label
7    33788
0     4893
1     3214
2     2828
3     2298
4     1782
5      892
6      305
Name: count, dtype: int64
label
7    67.576
0     9.786
1     6.428
2     5.656
3     4.596
4     3.564
5     1.784
6     0.610
Name: count, dtype: float64

从上述统计数据可以看出,标签7占据了大约67.576%的比例,而标签6仅占0.610%。这种严重的类别不平衡会导致以下问题:

  • 评估指标偏差:常用的评估指标如准确率(accuracy)可能会被多数类主导,导致模型看似表现良好但实际上对少数类的预测效果差。

  • 模型泛化能力:模型可能会过度拟合多数类,而对少数类的泛化能力较差。

可能的影响

  • 评估指标偏差:由于多数类样本数量远超少数类,模型可能在训练过程中更倾向于正确分类多数类样本,从而使得整体准确率较高,但对少数类的召回率较低。

  • 模型泛化能力:模型可能会过度拟合多数类,导致其在测试集上的表现不佳,尤其是在少数类上。

解决方案

重采样技术

过采样(Oversampling)

过采样是通过增加少数类样本的数量来平衡数据集的一种方法。常用的技术包括:

  • 随机过采样:简单地复制少数类样本,直到其数量与多数类相当。

  • SMOTE(Synthetic Minority Over-sampling Technique):通过在少数类样本之间生成新的合成样本,以增加少数类样本的数量。

欠采样(Undersampling)

欠采样是通过减少多数类样本的数量来平衡数据集的一种方法。常用的技术包括:

  • 随机欠采样:随机删除部分多数类样本,直到其数量与少数类相当。

  • Tomek Links 和 Condensed Nearest Neighbors (CNN):这些方法通过删除边界样本或冗余样本来减少多数类样本的数量。


import pandas as pd
from imblearn.under_sampling import RandomUnderSampler

# 定义文件路径和数据集分割
split = "train"
csv_file_path = f"/Volumes/Date/huggingface/dataset/Weibo/{split}.csv"
resampled_csv_file_path = f"/Volumes/Date/huggingface/dataset/Weibo/resample/{split}.csv"

# 读取CSV文件
df = pd.read_csv(csv_file_path)

# 将特征和标签分开
X = df[["text"]]
Y = df[["label"]]

# 初始化随机欠采样器
rus = RandomUnderSampler(sampling_strategy="auto", random_state=42)

# 应用随机欠采样
X_resampled, Y_resampled = rus.fit_resample(X, Y)

# 合并特征和标签,创建新的DataFrame
df_resampled = pd.concat([X_resampled, Y_resampled], axis=1)

# 保存均衡数据到新的CSV文件
df_resampled.to_csv(resampled_csv_file_path, index=False)

# 读取重采样后的数据集
df_resampled = pd.read_csv(resampled_csv_file_path)

调整损失函数

通过调整损失函数,可以为不同类别的样本赋予不同的权重,从而提高少数类的重要性。例如,在交叉熵损失中,可以为少数类赋予更大的权重:

import torch.nn as nn

class_weight = [1, 2, 2, 2, 2, 3, 3, 4]  # 根据类别不平衡程度设置权重
loss_func = nn.CrossEntropyLoss(weight=torch.tensor(class_weight).float().to(DEVICE))

集成学习

集成学习方法如Bagging、Boosting等可以通过组合多个弱分类器来提高对少数类的预测性能。常见的集成学习算法包括:

  • Random Forest:通过构建多个决策树并进行投票,以提高模型的鲁棒性和泛化能力。

  • AdaBoost:通过逐步增加错误分类样本的权重,使后续模型更加关注这些样本。

评估指标选择

除了准确率外,还应关注其他评估指标如精确率(Precision)、召回率(Recall)、F1分数(F1-Score)和AUC-ROC曲线等,这些指标更能反映模型对少数类的预测能力。

验证

在微博情感分析数据集上进行了实验。首先,使用imbalanced-learn库中的RandomUnderSampler对数据集进行了欠采样处理:

统计每个类别的数据量:
label
0    305
1    305
2    305
3    305
4    305
5    305
6    305
7    305

统计每个类别的比值 (%):
label
0    12.5
1    12.5
2    12.5
3    12.5
4    12.5
5    12.5
6    12.5
7    12.5

使用BERT模型对重采样后的数据集进行了训练,并设置了加权损失函数:

class_weight = [1, 2, 2, 2, 2, 3, 3, 4]
loss_func = nn.CrossEntropyLoss(weight=torch.tensor(class_weight).float().to(DEVICE))

for epoch in range(EPOCH):
    for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(train_loader):
        input_ids, attention_mask, token_type_ids, labels = input_ids.to(DEVICE), attention_mask.to(DEVICE), token_type_ids.to(DEVICE), labels.to(DEVICE)
        out = model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        loss = loss_func(out, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

标签: #Transformer 7
相关文章
简述 Transformer 训练计算过程(刷新渲染)

简述 Transformer 训练计算过程(刷新渲染) 2025-01-11 23:54

Step1-定义数据集 用于创建 ChatGPT 的数据集为 570 GB。假设数据集为一下内容: 白日依山尽,黄河入海流。 马无夜草不肥,人无横财不富。 天行健,君子以自强不息,地势坤,君子以厚德载物。 Step2-计算词汇量

基于 internlm2 和 LangChain 搭建你的知识库

基于 internlm2 和 LangChain 搭建你的知识库 2025-02-27 14:25

环境配置 internlm2 模型部署 创建虚拟环境 conda create -n deepseek_rag python=3.10 -y conda activate deepseek_rag 并在环境中安装运行 demo 所需要的依赖 # 升级pip python -m pip install

Llama-Factory 微调全过程

Llama-Factory 微调全过程 2025-01-13 22:28

数据集 数据集下载:通过ModelScope获取原始数据集https://modelscope.cn/datasets/w10442005/ruozhiba_qa/summary git clone https://www.modelscope.cn/datasets/w10442005/ruozh

矩阵分解 2025-01-11 15:48

矩阵分解是一种通过将较大的矩阵分解为多个小矩阵已降低计算复杂度的技术,在模型训练微调上,通常用于简化模型、提高训练效率。矩阵分解有多种形式,一下是几种常见的模型微调权重分解方法: 奇异值分解 将矩阵分解为三个矩阵乘积的方法: W=U \Sigma V^{T} 其中: W是原始权重矩阵。 U和V是正交

LLM奥秘

LLM奥秘 2025-01-09 21:47

本文旨在通过最基础的数学内容,剔除机器学习中复杂的术语,从零描述LLM的工作原理。

GPT 模型微调

GPT 模型微调 2025-01-03 22:51

GPT-2 是一种基于 Transformer 的生成模型,专注于生成连贯的文本。在 Hugging Face 的Transformers 库中,GPT-2 已经被应用于多种中文文本生成任务,如古诗词、歌词和对联生成等。 GPT-2模型 from transformers import BertTo

目录
  • clearwind
  • 微信小程序

导航菜单

  • 首页
  • 分类
Copyright © 2024 your company All Rights Reserved. Powered by clearwind.
皖ICP备19023482号