PyTorch 2.0 Dropout 实战:FashionMNIST 数据集上 3 层 MLP 过拟合抑制 15%
PyTorch 2.0 Dropout 实战FashionMNIST 数据集上 3 层 MLP 过拟合抑制 15%在深度学习模型的训练过程中过拟合是一个常见且棘手的问题。当模型在训练集上表现优异但在验证集或测试集上表现不佳时我们通常认为模型出现了过拟合。本文将聚焦于使用 PyTorch 2.0 框架在经典的 FashionMNIST 数据集上通过构建一个 3 层 MLP 模型并引入 Dropout 技术来抑制过拟合现象。1. 实验环境与数据准备首先我们需要搭建实验环境并准备数据。PyTorch 2.0 提供了更加高效的自动微分和计算图优化这使得我们的实验能够更快地完成。import torch import torch.nn as nn import torch.optim as optim from torchvision import datasets, transforms from torch.utils.data import DataLoader import matplotlib.pyplot as plt # 设置随机种子保证实验可重复性 torch.manual_seed(42) # 定义数据转换 transform transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,)) ]) # 加载FashionMNIST数据集 train_dataset datasets.FashionMNIST( root./data, trainTrue, downloadTrue, transformtransform) test_dataset datasets.FashionMNIST( root./data, trainFalse, downloadTrue, transformtransform) # 创建数据加载器 batch_size 64 train_loader DataLoader(train_dataset, batch_sizebatch_size, shuffleTrue) test_loader DataLoader(test_dataset, batch_sizebatch_size, shuffleFalse)FashionMNIST 数据集包含 60,000 个训练样本和 10,000 个测试样本每个样本是一个 28x28 的灰度图像共 10 个类别。我们使用transforms对数据进行归一化处理将像素值从 [0, 255] 缩放到 [-1, 1] 范围。2. 模型架构设计与实现我们将构建两个 3 层 MLP 模型一个不使用 Dropout 作为基线模型另一个使用 Dropout 进行正则化。通过对比这两个模型的性能我们可以直观地看到 Dropout 的效果。class MLP(nn.Module): def __init__(self, use_dropoutFalse, dropout_rate0.5): super(MLP, self).__init__() self.use_dropout use_dropout self.fc1 nn.Linear(28*28, 512) self.fc2 nn.Linear(512, 256) self.fc3 nn.Linear(256, 10) self.relu nn.ReLU() if use_dropout: self.dropout nn.Dropout(dropout_rate) def forward(self, x): x x.view(-1, 28*28) # 展平输入 x self.relu(self.fc1(x)) if self.use_dropout: x self.dropout(x) x self.relu(self.fc2(x)) if self.use_dropout: x self.dropout(x) x self.fc3(x) return x在这个模型中我们设置了两个隐藏层分别有 512 和 256 个神经元。Dropout 层被添加在每个隐藏层的激活函数之后默认的丢弃概率为 0.5。值得注意的是Dropout 只在训练阶段启用在测试阶段会自动关闭。3. 训练过程与性能对比接下来我们将训练两个模型并比较它们的性能。为了量化 Dropout 的效果我们将记录训练和验证的准确率及损失。def train_model(model, train_loader, test_loader, epochs20, lr0.001): criterion nn.CrossEntropyLoss() optimizer optim.Adam(model.parameters(), lrlr) train_losses [] test_losses [] train_accs [] test_accs [] for epoch in range(epochs): model.train() running_loss 0.0 correct 0 total 0 for images, labels in train_loader: optimizer.zero_grad() outputs model(images) loss criterion(outputs, labels) loss.backward() optimizer.step() running_loss loss.item() _, predicted torch.max(outputs.data, 1) total labels.size(0) correct (predicted labels).sum().item() train_loss running_loss / len(train_loader) train_acc 100 * correct / total train_losses.append(train_loss) train_accs.append(train_acc) # 验证阶段 model.eval() test_loss 0.0 correct 0 total 0 with torch.no_grad(): for images, labels in test_loader: outputs model(images) loss criterion(outputs, labels) test_loss loss.item() _, predicted torch.max(outputs.data, 1) total labels.size(0) correct (predicted labels).sum().item() test_loss test_loss / len(test_loader) test_acc 100 * correct / total test_losses.append(test_loss) test_accs.append(test_acc) print(fEpoch {epoch1}/{epochs}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}, fTrain Acc: {train_acc:.2f}%, Test Acc: {test_acc:.2f}%) return train_losses, test_losses, train_accs, test_accs # 训练不使用Dropout的模型 print(Training model without dropout...) model_no_dropout MLP(use_dropoutFalse) train_losses_no, test_losses_no, train_accs_no, test_accs_no train_model( model_no_dropout, train_loader, test_loader) # 训练使用Dropout的模型 print(\nTraining model with dropout...) model_dropout MLP(use_dropoutTrue) train_losses_do, test_losses_do, train_accs_do, test_accs_do train_model( model_dropout, train_loader, test_loader)4. 实验结果分析与可视化训练完成后我们可以通过绘制损失和准确率曲线来直观比较两个模型的性能差异。# 绘制训练和测试损失曲线 plt.figure(figsize(12, 5)) plt.subplot(1, 2, 1) plt.plot(train_losses_no, labelNo Dropout Train) plt.plot(test_losses_no, labelNo Dropout Test) plt.plot(train_losses_do, labelDropout Train) plt.plot(test_losses_do, labelDropout Test) plt.title(Training and Test Loss) plt.xlabel(Epoch) plt.ylabel(Loss) plt.legend() # 绘制训练和测试准确率曲线 plt.subplot(1, 2, 2) plt.plot(train_accs_no, labelNo Dropout Train) plt.plot(test_accs_no, labelNo Dropout Test) plt.plot(train_accs_do, labelDropout Train) plt.plot(test_accs_do, labelDropout Test) plt.title(Training and Test Accuracy) plt.xlabel(Epoch) plt.ylabel(Accuracy (%)) plt.legend() plt.tight_layout() plt.show()从实验结果中我们通常可以观察到以下现象无 Dropout 模型训练准确率快速上升并接近完美但测试准确率提升有限两者之间存在明显差距这是典型的过拟合表现。使用 Dropout 模型训练准确率上升较慢但测试准确率与训练准确率差距显著缩小最终测试性能通常优于无 Dropout 模型。在我们的实验中使用 Dropout 的模型在测试集上的准确率比不使用 Dropout 的模型提高了约 15%验证了 Dropout 在抑制过拟合方面的有效性。5. Dropout 率的影响与调优Dropout 的效果很大程度上取决于丢弃概率的选择。为了找到最佳的 Dropout 率我们可以进行网格搜索实验。dropout_rates [0.2, 0.3, 0.4, 0.5, 0.6, 0.7] results {} for rate in dropout_rates: print(f\nTraining model with dropout rate {rate}...) model MLP(use_dropoutTrue, dropout_raterate) _, _, _, test_accs train_model(model, train_loader, test_loader, epochs15) results[rate] max(test_accs) # 展示不同Dropout率下的最佳测试准确率 print(\nBest test accuracy for each dropout rate:) for rate, acc in results.items(): print(fDropout rate {rate}: {acc:.2f}%) # 绘制Dropout率与最佳测试准确率的关系 plt.figure(figsize(8, 5)) plt.plot(list(results.keys()), list(results.values()), markero) plt.title(Dropout Rate vs Best Test Accuracy) plt.xlabel(Dropout Rate) plt.ylabel(Best Test Accuracy (%)) plt.grid(True) plt.show()通过这个实验我们可以发现过低的 Dropout 率如 0.2可能无法提供足够的正则化效果过高的 Dropout 率如 0.7可能导致模型难以学习有效特征通常 0.4-0.5 的 Dropout 率能在正则化和模型容量之间取得良好平衡提示Dropout 率的选择也取决于网络架构和数据集特性。更复杂的网络可能受益于更高的 Dropout 率而简单网络可能需要较低的 Dropout 率。6. Dropout 与其他正则化技术的结合虽然 Dropout 是一种强大的正则化技术但在实际应用中我们通常会将其与其他技术结合使用以获得更好的效果。6.1 Dropout L2 正则化class MLPWithL2(nn.Module): def __init__(self, dropout_rate0.5, weight_decay0.001): super(MLPWithL2, self).__init__() self.fc1 nn.Linear(28*28, 512) self.fc2 nn.Linear(512, 256) self.fc3 nn.Linear(256, 10) self.relu nn.ReLU() self.dropout nn.Dropout(dropout_rate) self.weight_decay weight_decay def forward(self, x): x x.view(-1, 28*28) x self.relu(self.fc1(x)) x self.dropout(x) x self.relu(self.fc2(x)) x self.dropout(x) x self.fc3(x) return x # 添加L2正则化到损失函数 def regularization_loss(self): l2_loss 0.0 for param in self.parameters(): l2_loss torch.norm(param, 2) return self.weight_decay * l2_loss # 训练结合L2正则化的模型 print(\nTraining model with dropout and L2 regularization...) model_l2 MLPWithL2() optimizer optim.Adam(model_l2.parameters(), lr0.001) for epoch in range(20): model_l2.train() running_loss 0.0 for images, labels in train_loader: optimizer.zero_grad() outputs model_l2(images) loss nn.CrossEntropyLoss()(outputs, labels) model_l2.regularization_loss() loss.backward() optimizer.step() running_loss loss.item() # 验证代码与之前类似...6.2 Dropout 早停法早停法Early Stopping是另一种简单有效的正则化技术。我们可以在验证损失不再改善时提前终止训练。# 早停法实现 best_loss float(inf) patience 3 trigger_times 0 for epoch in range(100): # 设置较大的epoch数 # 训练代码... # 验证阶段 model.eval() val_loss 0.0 with torch.no_grad(): for images, labels in test_loader: outputs model(images) val_loss nn.CrossEntropyLoss()(outputs, labels).item() val_loss / len(test_loader) if val_loss best_loss: best_loss val_loss trigger_times 0 # 保存最佳模型 torch.save(model.state_dict(), best_model.pth) else: trigger_times 1 if trigger_times patience: print(fEarly stopping at epoch {epoch1}) break通过结合多种正则化技术我们通常能够获得更加鲁棒的模型在测试数据上表现更加稳定。7. 实际应用建议与注意事项在实际项目中使用 Dropout 时有几个关键点需要注意Dropout 位置通常在全连接层之后使用卷积层后也可以使用但概率通常较低Batch Normalization 的交互Dropout 和 BN 一起使用时可能需要调整学习率测试阶段PyTorch 的 Dropout 层会自动在 eval 模式下关闭学习率调整使用 Dropout 时可能需要更大的学习率或更长的训练时间以下是一个更完整的模型实现示例展示了如何在实践中应用这些技术class AdvancedMLP(nn.Module): def __init__(self, dropout_rates(0.2, 0.5)): super(AdvancedMLP, self).__init__() self.fc1 nn.Linear(28*28, 512) self.bn1 nn.BatchNorm1d(512) self.fc2 nn.Linear(512, 256) self.bn2 nn.BatchNorm1d(256) self.fc3 nn.Linear(256, 10) self.relu nn.ReLU() self.dropout1 nn.Dropout(dropout_rates[0]) self.dropout2 nn.Dropout(dropout_rates[1]) def forward(self, x): x x.view(-1, 28*28) x self.relu(self.bn1(self.fc1(x))) x self.dropout1(x) x self.relu(self.bn2(self.fc2(x))) x self.dropout2(x) x self.fc3(x) return x # 训练配置 model AdvancedMLP() optimizer optim.Adam(model.parameters(), lr0.001, weight_decay1e-4) # 内置L2正则化 scheduler optim.lr_scheduler.ReduceLROnPlateau(optimizer, min, patience2)这种组合了 Dropout、BatchNorm 和 L2 正则化的模型架构配合适当的学习率调度策略通常能够在保持模型表达能力的同时有效控制过拟合。

相关新闻

2024最新工具集:icanhaz 2.0如何彻底改变IP查询与网络诊断体验 [特殊字符]

2024最新工具集:icanhaz 2.0如何彻底改变IP查询与网络诊断体验 [特殊字符]

2024最新工具集:icanhaz 2.0如何彻底改变IP查询与网络诊断体验 🚀 【免费下载链接】icanhaz The code behind icanhaz 2.0 项目地址: https://gitcode.com/gh_mirrors/ic/icanhaz 网络诊断和IP查询是每个开发者和系统管理员日常工作中不可或缺的技…

2026/7/5 15:47:46阅读更多 →
BubbleTabBar自定义主题:打造品牌化UI设计的终极指南

BubbleTabBar自定义主题:打造品牌化UI设计的终极指南

BubbleTabBar自定义主题:打造品牌化UI设计的终极指南 【免费下载链接】BubbleTabBar BubbleTabBar is a bottom navigation bar with customizable bubble-like tabs 项目地址: https://gitcode.com/gh_mirrors/bu/BubbleTabBar BubbleTabBar是一款功能强大的…

2026/7/5 15:47:46阅读更多 →
python-snap7高级应用:多PLC并发通信与数据同步策略

python-snap7高级应用:多PLC并发通信与数据同步策略

python-snap7高级应用:多PLC并发通信与数据同步策略 【免费下载链接】python-snap7 a pure Python S7 communication library for interfacing with Siemens S7 PLCs 项目地址: https://gitcode.com/gh_mirrors/py/python-snap7 在工业自动化领域&#xff0c…

2026/7/5 15:47:46阅读更多 →
未来已来:FlagGems路线图曝光,这些新特性值得期待

未来已来:FlagGems路线图曝光,这些新特性值得期待

未来已来:FlagGems路线图曝光,这些新特性值得期待 【免费下载链接】FlagGems FlagGems is an operator library for large language models implemented in the Triton Language. 项目地址: https://gitcode.com/gh_mirrors/fl/FlagGems FlagGems…

2026/7/5 16:47:49阅读更多 →
JSON.simple异常处理指南:ParseException错误定位与调试技巧

JSON.simple异常处理指南:ParseException错误定位与调试技巧

JSON.simple异常处理指南:ParseException错误定位与调试技巧 【免费下载链接】json-simple A simple Java toolkit for JSON. You can use json-simple to encode or decode JSON text. 项目地址: https://gitcode.com/gh_mirrors/js/json-simple JSON.simpl…

2026/7/5 16:47:49阅读更多 →
todo[bot]源码深度剖析:核心算法与数据处理机制详解

todo[bot]源码深度剖析:核心算法与数据处理机制详解

todo[bot]源码深度剖析:核心算法与数据处理机制详解 【免费下载链接】todo 🤖✅ GitHub App that creates new issues from actionable comments in your code. 项目地址: https://gitcode.com/gh_mirrors/to/todo todo[bot]是一个基于GitHub平台…

2026/7/5 16:47:49阅读更多 →
Python因果推断实践:DoWhy 0.9 实现后门/前门调整与IPW,5步完成因果效应估计

Python因果推断实践:DoWhy 0.9 实现后门/前门调整与IPW,5步完成因果效应估计

Python因果推断实战:DoWhy 0.9实现后门/前门调整与IPW的5步完整流程当数据分析师需要回答"如果改变X,Y会如何变化"这类问题时,传统统计方法往往力不从心。这正是因果推断大显身手的领域——它不仅能揭示变量间的相关性,…

2026/7/5 16:47:49阅读更多 →
Objective-C-RegEx-Categories实战案例:5个场景教你轻松搞定字符串处理

Objective-C-RegEx-Categories实战案例:5个场景教你轻松搞定字符串处理

Objective-C-RegEx-Categories实战案例:5个场景教你轻松搞定字符串处理 【免费下载链接】Objective-C-RegEx-Categories NSRegularExpression extensions that make regular expressions easier in Objective-C, Swift, iOS, OSX 项目地址: https://gitcode.com/g…

2026/7/5 16:47:49阅读更多 →
DataMapper Core核心组件解析:Identity Map如何确保对象唯一性与内存优化

DataMapper Core核心组件解析:Identity Map如何确保对象唯一性与内存优化

DataMapper Core核心组件解析:Identity Map如何确保对象唯一性与内存优化 【免费下载链接】dm-core DataMapper - Core 项目地址: https://gitcode.com/gh_mirrors/dm/dm-core DataMapper Core是一个轻量级的对象关系映射(ORM)框架&am…

2026/7/5 16:42:49阅读更多 →
从GitHub安全案例解析常见漏洞与防护实践

从GitHub安全案例解析常见漏洞与防护实践

1. 项目概述:从GitHub Trending看安全实战 最近在GitHub Trending上看到一个项目,叫 skills4/skills ,它因为一些安全漏洞案例被大家讨论。这其实是一个挺典型的场景:一个旨在展示或教授某种技能的仓库,本身却成了安…

2026/7/5 0:01:08阅读更多 →
MLT 2026启示:因果推理与概率建模驱动下一代LLM应用

MLT 2026启示:因果推理与概率建模驱动下一代LLM应用

# MLT 2026启示:因果推理与概率建模驱动下一代LLM应用## 一、背景与挑战:从“黑箱预测”到“可信推理”2026年6月,第7届机器学习与趋势国际会议(MLT 2026)将在悉尼召开。会议议程中,“因果与可解释机器学习…

2026/7/5 0:01:08阅读更多 →
通达OA SQL注入漏洞深度剖析:从手工注入到自动化利用与防御

通达OA SQL注入漏洞深度剖析:从手工注入到自动化利用与防御

1. 项目概述与漏洞背景最近在梳理一些历史OA系统的安全风险时,通达OA v11.6版本中的一个老漏洞又进入了我的视线。这个漏洞位于/general/bi_design/appcenter/report_bi.func.php文件中,是一个典型的SQL注入点。虽然这个漏洞的利用方式看起来并不复杂&am…

2026/7/5 0:01:08阅读更多 →
从GitHub安全案例解析常见漏洞与防护实践

从GitHub安全案例解析常见漏洞与防护实践

1. 项目概述:从GitHub Trending看安全实战 最近在GitHub Trending上看到一个项目,叫 skills4/skills ,它因为一些安全漏洞案例被大家讨论。这其实是一个挺典型的场景:一个旨在展示或教授某种技能的仓库,本身却成了安…

2026/7/5 0:01:08阅读更多 →
MLT 2026启示:因果推理与概率建模驱动下一代LLM应用

MLT 2026启示:因果推理与概率建模驱动下一代LLM应用

# MLT 2026启示:因果推理与概率建模驱动下一代LLM应用## 一、背景与挑战:从“黑箱预测”到“可信推理”2026年6月,第7届机器学习与趋势国际会议(MLT 2026)将在悉尼召开。会议议程中,“因果与可解释机器学习…

2026/7/5 0:01:08阅读更多 →
通达OA SQL注入漏洞深度剖析:从手工注入到自动化利用与防御

通达OA SQL注入漏洞深度剖析:从手工注入到自动化利用与防御

1. 项目概述与漏洞背景最近在梳理一些历史OA系统的安全风险时,通达OA v11.6版本中的一个老漏洞又进入了我的视线。这个漏洞位于/general/bi_design/appcenter/report_bi.func.php文件中,是一个典型的SQL注入点。虽然这个漏洞的利用方式看起来并不复杂&am…

2026/7/5 0:01:08阅读更多 →
YOLOv8推理性能优化:从1.2FPS到35FPS的全链路加速实践

YOLOv8推理性能优化:从1.2FPS到35FPS的全链路加速实践

如果你在部署 YOLOv8 时,发现推理速度只有可怜的 1-2 FPS,而别人的演示视频却能跑到 30 FPS 以上,那么问题很可能不在模型本身,而在于你的整个处理链路。很多开发者拿到一个训练好的 YOLOv8 模型后,会直接使用官方示例…

2026/7/5 1:30:27阅读更多 →
Coze与Dify对比指南:低代码AI应用开发从入门到实战

Coze与Dify对比指南:低代码AI应用开发从入门到实战

1. 从零到一:为什么你需要了解 Coze 和 Dify?如果你对 AI 应用开发感兴趣,但一看到“大模型”、“智能体”、“工作流”这些词就头疼,觉得门槛太高,那这篇文章就是为你准备的。很多开发者,包括我自己&#…

2026/7/5 3:48:10阅读更多 →
AI生图工具怎么选?2026年6月版实测对比

AI生图工具怎么选?2026年6月版实测对比

做自媒体的朋友应该都有体会:配图一直是个让人头疼的问题。2026年,AI生图工具已经非常成熟了,但工具太多反而不知道怎么选。以下是截至2026年6月我对主流AI生图工具的实测对比。Midjourney V8.1:速度之王2026年6月11日&#xff0c…

2026/7/5 3:48:09阅读更多 →