DQN 2015 Nature 论文复现:Atari Pong 游戏 84x84 像素输入实战(附 PyTorch 代码)
DQN 2015 Nature 论文复现Atari Pong 游戏 84x84 像素输入实战附 PyTorch 代码当DeepMind在2015年首次提出DQN算法并在Nature上发表时整个强化学习领域为之震动。这项研究首次证明一个单一的深度强化学习智能体能够在数十款Atari 2600游戏中达到人类水平的表现。本文将带您从零开始使用现代PyTorch框架完整复现这一里程碑式的工作特别聚焦于Atari Pong游戏的实现细节。1. 环境配置与预处理在开始构建DQN之前我们需要先搭建适合的训练环境。Atari游戏的原始输入为210×160像素的RGB图像这对计算资源提出了较高要求。遵循原论文的方法我们将进行以下预处理import gym import numpy as np from collections import deque import torch import torch.nn as nn import torch.optim as optim class AtariPreprocessor: def __init__(self, env_name, frame_skip4, history_length4): self.env gym.make(env_name) self.frame_skip frame_skip self.history deque(maxlenhistory_length) def reset(self): frame self.env.reset() processed self._process_frame(frame) for _ in range(self.history.maxlen): self.history.append(processed) return np.stack(self.history) def step(self, action): total_reward 0.0 for _ in range(self.frame_skip): frame, reward, done, info self.env.step(action) total_reward reward if done: break processed self._process_frame(frame) self.history.append(processed) return np.stack(self.history), total_reward, done, info def _process_frame(self, frame): # 转换为灰度图并调整大小 frame frame.mean(axis2) # RGB转灰度 frame frame[34:34160, :160] # 裁剪得分区域 frame frame[::2, ::2] # 下采样到80x80 return frame.astype(np.float32) / 255.0关键预处理步骤包括帧堆叠将连续的4帧堆叠作为网络输入提供时序信息帧跳过每4帧执行一次动作提高训练效率图像裁剪移除不相关的屏幕区域如得分显示灰度转换将RGB三通道简化为单通道归一化将像素值缩放到[0,1]范围注意原论文使用84×84分辨率但实际实现中80×80也是常见选择。确保测试时与训练分辨率一致。2. DQN网络架构设计DQN的核心是一个深度卷积神经网络其架构设计直接影响了特征提取能力。以下是PyTorch实现class DQN(nn.Module): def __init__(self, action_dim): super(DQN, self).__init__() self.conv1 nn.Conv2d(4, 32, kernel_size8, stride4) self.conv2 nn.Conv2d(32, 64, kernel_size4, stride2) self.conv3 nn.Conv2d(64, 64, kernel_size3, stride1) self.fc1 nn.Linear(7*7*64, 512) self.fc2 nn.Linear(512, action_dim) def forward(self, x): x x.float() / 255.0 # 确保输入归一化 x torch.relu(self.conv1(x)) x torch.relu(self.conv2(x)) x torch.relu(self.conv3(x)) x x.view(x.size(0), -1) # 展平 x torch.relu(self.fc1(x)) return self.fc2(x)网络结构参数对比如下层类型参数输出尺寸激活函数卷积层132个8×8滤波器步长420×20×32ReLU卷积层264个4×4滤波器步长29×9×64ReLU卷积层364个3×3滤波器步长17×7×64ReLU全连接层1512单元512ReLU输出层动作空间维度action_dim线性3. 经验回放与目标网络DQN的两个关键创新点需要特别实现class ReplayBuffer: def __init__(self, capacity): self.buffer deque(maxlencapacity) def push(self, state, action, reward, next_state, done): self.buffer.append((state, action, reward, next_state, done)) def sample(self, batch_size): indices np.random.choice(len(self.buffer), batch_size, replaceFalse) states, actions, rewards, next_states, dones zip(*[self.buffer[idx] for idx in indices]) return ( torch.FloatTensor(np.array(states)), torch.LongTensor(np.array(actions)), torch.FloatTensor(np.array(rewards)), torch.FloatTensor(np.array(next_states)), torch.FloatTensor(np.array(dones)) ) def __len__(self): return len(self.buffer) class DQNAgent: def __init__(self, action_dim, lr1e-4, gamma0.99, tau1e-3): self.policy_net DQN(action_dim) self.target_net DQN(action_dim) self.target_net.load_state_dict(self.policy_net.state_dict()) self.optimizer optim.Adam(self.policy_net.parameters(), lrlr) self.gamma gamma self.tau tau def update_target(self): # 软更新目标网络 for target_param, policy_param in zip(self.target_net.parameters(), self.policy_net.parameters()): target_param.data.copy_( self.tau * policy_param.data (1.0 - self.tau) * target_param.data ) def get_action(self, state, epsilon): if np.random.random() epsilon: return np.random.randint(self.policy_net.fc2.out_features) with torch.no_grad(): q_values self.policy_net(state.unsqueeze(0)) return q_values.argmax().item()经验回放和目标网络的作用经验回放打破数据相关性提高样本效率目标网络稳定训练过程防止Q值过高估计软更新缓慢更新目标网络参数τ通常取0.0014. 完整训练流程将上述组件整合为完整的训练系统def train_dqn(env_namePongNoFrameskip-v4, batch_size32, buffer_size100000, total_steps1000000, learning_starts10000, target_update1000, gamma0.99, epsilon_start1.0, epsilon_end0.1, epsilon_decay100000): env AtariPreprocessor(env_name) agent DQNAgent(env.env.action_space.n) buffer ReplayBuffer(buffer_size) state env.reset() episode_reward 0 total_rewards [] epsilon epsilon_start for step in range(1, total_steps 1): # ε-贪心策略 epsilon epsilon_end (epsilon_start - epsilon_end) * \ np.exp(-1. * step / epsilon_decay) # 选择并执行动作 action agent.get_action(torch.FloatTensor(state), epsilon) next_state, reward, done, _ env.step(action) episode_reward reward # 存储转移样本 buffer.push(state, action, reward, next_state, done) state next_state # 训练阶段 if len(buffer) learning_starts and step % 4 0: batch buffer.sample(batch_size) states, actions, rewards, next_states, dones batch # 计算当前Q值 current_q agent.policy_net(states).gather(1, actions.unsqueeze(1)) # 计算目标Q值 with torch.no_grad(): next_q agent.target_net(next_states).max(1)[0] target_q rewards (1 - dones) * gamma * next_q # 计算损失并更新 loss nn.MSELoss()(current_q.squeeze(), target_q) agent.optimizer.zero_grad() loss.backward() agent.optimizer.step() # 更新目标网络 if step % target_update 0: agent.update_target() # 回合结束处理 if done: total_rewards.append(episode_reward) print(fStep: {step}, Reward: {episode_reward}, Epsilon: {epsilon:.2f}) state env.reset() episode_reward 0 return total_rewards训练过程中的关键参数设置参数推荐值作用batch_size32每次更新的样本数量buffer_size100,000经验回放缓存大小learning_starts10,000开始学习前的随机探索步数target_update1,000目标网络更新频率gamma0.99未来奖励折扣因子epsilon_start1.0初始探索率epsilon_end0.1最终探索率epsilon_decay100,000探索率衰减步数5. 训练技巧与性能优化在实际训练中以下几个技巧可以显著提升性能奖励裁剪将正奖励设为1负奖励设为-1有助于不同游戏间的泛化reward np.clip(reward, -1, 1)帧差分处理取连续帧的最大值消除Atari游戏的闪烁效果frame np.maximum(frame, last_frame)梯度裁剪防止梯度爆炸稳定训练过程for param in agent.policy_net.parameters(): param.grad.data.clamp_(-1, 1)学习率调度随着训练进展降低学习率scheduler optim.lr_scheduler.StepLR(agent.optimizer, step_size250000, gamma0.1)在Pong游戏中典型的训练曲线会经历以下阶段随机探索期0-10k步智能体随机移动胜率约50%初步学习期10k-100k步开始学习基本击球策略策略优化期100k-500k步发展出位置控制和反击策略稳定表现期500k步达到人类水平胜率超过90%6. 结果评估与可视化训练完成后我们需要评估智能体的实际表现def evaluate(agent, env, episodes10): total_rewards [] for _ in range(episodes): state env.reset() episode_reward 0 done False while not done: action agent.get_action(torch.FloatTensor(state), epsilon0.05) state, reward, done, _ env.step(action) episode_reward reward env.render() # 可视化游戏过程 total_rewards.append(episode_reward) return np.mean(total_rewards)对于Pong游戏成功的训练应能达到以下指标指标预期值说明平均奖励18每局21分制达到人类水平胜率90%对阵内置AI的获胜概率训练时间8-12小时使用现代GPU如RTX 30807. 进阶改进方向原始DQN虽然强大但仍有改进空间。以下是几个值得尝试的扩展Double DQN减少Q值高估问题next_actions agent.policy_net(next_states).max(1)[1] next_q agent.target_net(next_states).gather(1, next_actions.unsqueeze(1))优先经验回放更高效地利用重要样本td_error (current_q - target_q).abs() priority (td_error 1e-5).pow(alpha)Dueling架构分离状态价值和优势函数class DuelingDQN(nn.Module): def __init__(self, action_dim): super().__init__() # 共享特征提取层 self.feature nn.Sequential(...) # 价值流 self.value nn.Linear(512, 1) # 优势流 self.advantage nn.Linear(512, action_dim) def forward(self, x): features self.feature(x) value self.value(features) advantage self.advantage(features) return value advantage - advantage.mean()在实际项目中我发现使用Dueling架构能显著提升Pong游戏的训练速度通常在200k步左右就能达到不错的表现。而优先回放则在更复杂的游戏中效果更为明显。

相关新闻

无刷直流电机 PWM 控制实战:50kHz 频率下电流纹波降低 70% 的 3 个关键参数

无刷直流电机 PWM 控制实战:50kHz 频率下电流纹波降低 70% 的 3 个关键参数

无刷直流电机 PWM 控制实战:50kHz 频率下电流纹波降低 70% 的 3 个关键参数在医疗机器人、精密仪器等高精度应用场景中,无刷直流电机的电流纹波控制直接关系到系统寿命和运行稳定性。Portescap 实验室数据显示,当 PWM 频率从 20kHz 提升至 50…

2026/7/6 0:38:41阅读更多 →
TensorFlow Datasets 加载 Omniglot:3分钟完成数据预处理与 50 种字母表可视化

TensorFlow Datasets 加载 Omniglot:3分钟完成数据预处理与 50 种字母表可视化

TensorFlow Datasets 高效加载 Omniglot:从数据预处理到多语言字符可视化实战在深度学习项目中,数据准备环节往往消耗开发者大量时间。本文将展示如何利用TensorFlow Datasets(TFDS)这一官方工具,快速完成Omniglot数据…

2026/7/6 0:38:41阅读更多 →
PyTorch 2.0+ Dataset 实战:3种常见数据源(CSV/文件夹/内存)的加载与性能对比

PyTorch 2.0+ Dataset 实战:3种常见数据源(CSV/文件夹/内存)的加载与性能对比

PyTorch 2.0 多源数据加载实战:从CSV到内存Tensor的高效处理方案1. 为什么需要关注数据加载性能?在深度学习项目生命周期中,数据准备和处理通常占据70%以上的时间成本。PyTorch 2.0 虽然大幅提升了模型训练效率,但数据加载环节的瓶…

2026/7/6 0:38:41阅读更多 →
高并发秒杀三大核心技术实战

高并发秒杀三大核心技术实战

在构建高并发秒杀系统时,确保系统在高流量冲击下仍能保持高性能、高可用和数据一致性是核心目标。经过对业界主流方案的梳理,可以提炼出三大核心技术支柱:原子性库存扣减、分布式锁防超卖、以及异步消息队列解耦。下面将结合具体技术实现和实…

2026/7/6 1:48:45阅读更多 →
2026国内企业级智能体推荐:6款主流产品功能、适用场景全对比

2026国内企业级智能体推荐:6款主流产品功能、适用场景全对比

一、赛道速览 企业级智能体按能力分为两类: 对话知识型:问答、文档总结、信息检索(多数产品止步于此)业务执行型:能操作系统、填表单、跨系统搬数据,完成端到端流程 本文聚焦业务执行型。当前实现路径主要有…

2026/7/6 1:48:45阅读更多 →
关于Matlab今天我只说三点

关于Matlab今天我只说三点

matlab coder 、matlab compiler 和matlab compiler SDKMATLAB Coder 代码转换:将MATLAB代码转换为可读的、可移植的C/C代码。C/C源文件、静态库、动态库或MEX文件。无需MATLAB运行时,可在任何支持ANSI/ISO C/C的平台上编译运行。MATLAB Compiler 应用打…

2026/7/6 1:48:45阅读更多 →
RTX 3060 深度学习环境:CUDA 11.1 vs 11.8 版本选择与性能实测对比

RTX 3060 深度学习环境:CUDA 11.1 vs 11.8 版本选择与性能实测对比

RTX 3060 深度学习环境:CUDA 11.1 vs 11.8 版本选择与性能实测对比1. 硬件与软件基础环境搭建RTX 3060作为NVIDIA Ampere架构的中端显卡,拥有3584个CUDA核心和12GB GDDR6显存,是性价比极高的深度学习开发选择。但在实际使用中,CUD…

2026/7/6 1:48:45阅读更多 →
认真聊聊并发编程的10个坑

认真聊聊并发编程的10个坑

对于从事后端开发的同学来说,并发编程肯定再熟悉不过了。 说实话,在java中并发编程是一大难点,至少我是这么认为的。不光理解起来比较费劲,使用起来更容易踩坑。 不信,让继续往下面看。 今天重点跟大家一起聊聊并发…

2026/7/6 1:48:45阅读更多 →
PPG vs PPO:3 大核心差异解析与 2 阶段训练机制对样本效率的影响

PPG vs PPO:3 大核心差异解析与 2 阶段训练机制对样本效率的影响

PPG vs PPO:3 大核心差异解析与 2 阶段训练机制对样本效率的影响深度强化学习领域近年来涌现出多种改进算法,其中PPG(Phasic Policy Gradient)作为PPO(Proximal Policy Optimization)的进阶版本&#xff0c…

2026/7/6 1:43:45阅读更多 →
从GitHub安全案例解析常见漏洞与防护实践

从GitHub安全案例解析常见漏洞与防护实践

1. 项目概述:从GitHub Trending看安全实战 最近在GitHub Trending上看到一个项目,叫 skills4/skills ,它因为一些安全漏洞案例被大家讨论。这其实是一个挺典型的场景:一个旨在展示或教授某种技能的仓库,本身却成了安…

2026/7/5 0:01:08阅读更多 →
MLT 2026启示:因果推理与概率建模驱动下一代LLM应用

MLT 2026启示:因果推理与概率建模驱动下一代LLM应用

# MLT 2026启示:因果推理与概率建模驱动下一代LLM应用## 一、背景与挑战:从“黑箱预测”到“可信推理”2026年6月,第7届机器学习与趋势国际会议(MLT 2026)将在悉尼召开。会议议程中,“因果与可解释机器学习…

2026/7/5 0:01:08阅读更多 →
通达OA SQL注入漏洞深度剖析:从手工注入到自动化利用与防御

通达OA SQL注入漏洞深度剖析:从手工注入到自动化利用与防御

1. 项目概述与漏洞背景最近在梳理一些历史OA系统的安全风险时,通达OA v11.6版本中的一个老漏洞又进入了我的视线。这个漏洞位于/general/bi_design/appcenter/report_bi.func.php文件中,是一个典型的SQL注入点。虽然这个漏洞的利用方式看起来并不复杂&am…

2026/7/6 0:10:35阅读更多 →
Seraphine:基于LCU API的英雄联盟智能游戏助手技术解析与应用指南

Seraphine:基于LCU API的英雄联盟智能游戏助手技术解析与应用指南

Seraphine:基于LCU API的英雄联盟智能游戏助手技术解析与应用指南 【免费下载链接】Seraphine 英雄联盟战绩查询工具 项目地址: https://gitcode.com/gh_mirrors/se/Seraphine 技术架构先行:官方接口的合规应用 你是否曾在BP阶段手忙脚乱&#x…

2026/7/6 0:03:39阅读更多 →
多协议远程连接管理工具mRemoteNG:告别混乱,统一你的远程桌面管理

多协议远程连接管理工具mRemoteNG:告别混乱,统一你的远程桌面管理

多协议远程连接管理工具mRemoteNG:告别混乱,统一你的远程桌面管理 【免费下载链接】mRemoteNG mRemoteNG is the next generation of mRemote, open source, tabbed, multi-protocol, remote connections manager. 项目地址: https://gitcode.com/gh_m…

2026/7/6 0:03:39阅读更多 →
COUNT(DISTINCT) 与 GROUP BY 去重统计:5 亿数据量下的性能实测与选型指南

COUNT(DISTINCT) 与 GROUP BY 去重统计:5 亿数据量下的性能实测与选型指南

COUNT(DISTINCT) 与 GROUP BY 去重统计:5 亿数据量下的性能实测与选型指南在数据分析和处理领域,去重统计是最基础也是最频繁使用的操作之一。当数据量达到亿级规模时,不同的去重统计方法在性能上可能产生天壤之别。本文将基于 5 亿行数据的实…

2026/7/6 0:03:39阅读更多 →
YOLOv8推理性能优化:从1.2FPS到35FPS的全链路加速实践

YOLOv8推理性能优化:从1.2FPS到35FPS的全链路加速实践

如果你在部署 YOLOv8 时,发现推理速度只有可怜的 1-2 FPS,而别人的演示视频却能跑到 30 FPS 以上,那么问题很可能不在模型本身,而在于你的整个处理链路。很多开发者拿到一个训练好的 YOLOv8 模型后,会直接使用官方示例…

2026/7/5 1:30:27阅读更多 →
Coze与Dify对比指南:低代码AI应用开发从入门到实战

Coze与Dify对比指南:低代码AI应用开发从入门到实战

1. 从零到一:为什么你需要了解 Coze 和 Dify?如果你对 AI 应用开发感兴趣,但一看到“大模型”、“智能体”、“工作流”这些词就头疼,觉得门槛太高,那这篇文章就是为你准备的。很多开发者,包括我自己&#…

2026/7/5 3:48:10阅读更多 →
AI生图工具怎么选?2026年6月版实测对比

AI生图工具怎么选?2026年6月版实测对比

做自媒体的朋友应该都有体会:配图一直是个让人头疼的问题。2026年,AI生图工具已经非常成熟了,但工具太多反而不知道怎么选。以下是截至2026年6月我对主流AI生图工具的实测对比。Midjourney V8.1:速度之王2026年6月11日&#xff0c…

2026/7/5 3:48:09阅读更多 →