(3) Gymnasium: Testing CartPole with DQN


1. DQN Implementation with PyTorch

1.1 Main References

(1) The official PyTorch tutorial (recommended):

Reinforcement Learning (DQN) Tutorial — PyTorch Tutorials 2.0.1+cu117 documentation

(2)

PyTorch Deep Reinforcement Learning – The CartPole Problem | 极客笔记

1.2 How the Official PyTorch Tutorial Works

To be continued. I am busy these days and will tidy this section up in a few days.
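Until that write-up is done, here is a brief sketch (my own summary, not the tutorial's wording) of the update rule that the code in 1.3 implements. For a sampled transition (s, a, r, s'), the temporal-difference target and the loss are

    y = r + GAMMA * max_a' Q_target(s', a')    (y = r if s' is terminal)
    loss = SmoothL1(Q_policy(s, a), y)

and after every optimization step the target network is moved slowly toward the policy network:

    θ′ ← TAU * θ + (1 − TAU) * θ′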

1.3 Code Implementation

import gymnasium as gym
import math
import random
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple, deque
from itertools import count

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

env = gym.make("CartPole-v1")

# set up matplotlib
# is_ipython = 'inline' in matplotlib.get_backend()
# if is_ipython:
#     from IPython import display

plt.ion()

# if GPU is to be used
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Transition = namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward'))


class ReplayMemory(object):

    def __init__(self, capacity):
        self.memory = deque([], maxlen=capacity)

    def push(self, *args):
        """Save a transition"""
        self.memory.append(Transition(*args))

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)


class DQN(nn.Module):

    def __init__(self, n_observations, n_actions):
        super(DQN, self).__init__()
        self.layer1 = nn.Linear(n_observations, 128)
        self.layer2 = nn.Linear(128, 128)
        self.layer3 = nn.Linear(128, n_actions)

    # Called with either one element to determine next action, or a batch
    # during optimization. Returns tensor([[left0exp,right0exp]...]).
    def forward(self, x):
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        return self.layer3(x)


# BATCH_SIZE is the number of transitions sampled from the replay buffer
# GAMMA is the discount factor as mentioned in the previous section
# EPS_START is the starting value of epsilon
# EPS_END is the final value of epsilon
# EPS_DECAY controls the rate of exponential decay of epsilon, higher means a slower decay
# TAU is the update rate of the target network
# LR is the learning rate of the ``AdamW`` optimizer
BATCH_SIZE = 128
GAMMA = 0.99
EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 1000
TAU = 0.005
LR = 1e-4

# Get number of actions from gym action space
n_actions = env.action_space.n
# Get the number of state observations
state, info = env.reset()
n_observations = len(state)

policy_net = DQN(n_observations, n_actions).to(device)
target_net = DQN(n_observations, n_actions).to(device)
target_net.load_state_dict(policy_net.state_dict())

optimizer = optim.AdamW(policy_net.parameters(), lr=LR, amsgrad=True)
memory = ReplayMemory(10000)

steps_done = 0


def select_action(state):
    global steps_done
    sample = random.random()
    eps_threshold = EPS_END + (EPS_START - EPS_END) * \
        math.exp(-1. * steps_done / EPS_DECAY)
    steps_done += 1
    if sample > eps_threshold:
        with torch.no_grad():
            # t.max(1) will return the largest column value of each row.
            # second column on max result is index of where max element was
            # found, so we pick action with the larger expected reward.
            return policy_net(state).max(1)[1].view(1, 1)
    else:
        return torch.tensor([[env.action_space.sample()]], device=device, dtype=torch.long)


episode_durations = []


def plot_durations(show_result=False):
    plt.figure(1)
    durations_t = torch.tensor(episode_durations, dtype=torch.float)
    if show_result:
        plt.title('Result')
    else:
        plt.clf()
        plt.title('Training...')
    plt.xlabel('Episode')
    plt.ylabel('Duration')
    plt.plot(durations_t.numpy())
    # Take 100 episode averages and plot them too
    if len(durations_t) >= 100:
        means = durations_t.unfold(0, 100, 1).mean(1).view(-1)
        means = torch.cat((torch.zeros(99), means))
        plt.plot(means.numpy())

    plt.pause(0.001)  # pause a bit so that plots are updated
    # if is_ipython:
    #     if not show_result:
    #         display.display(plt.gcf())
    #         display.clear_output(wait=True)
    #     else:
    #         display.display(plt.gcf())


def optimize_model():
    if len(memory) < BATCH_SIZE:
        return
    transitions = memory.sample(BATCH_SIZE)
    # Transpose the batch (see https://stackoverflow.com/a/19343/3343043 for
    # detailed explanation). This converts batch-array of Transitions
    # to Transition of batch-arrays.
    batch = Transition(*zip(*transitions))

    # Compute a mask of non-final states and concatenate the batch elements
    # (a final state would've been the one after which simulation ended)
    non_final_mask = torch.tensor(tuple(map(lambda s: s is not None,
                                            batch.next_state)), device=device, dtype=torch.bool)
    non_final_next_states = torch.cat([s for s in batch.next_state
                                       if s is not None])
    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)

    # Compute Q(s_t, a) - the model computes Q(s_t), then we select the
    # columns of actions taken. These are the actions which would've been taken
    # for each batch state according to policy_net
    state_action_values = policy_net(state_batch).gather(1, action_batch)

    # Compute V(s_{t+1}) for all next states.
    # Expected values of actions for non_final_next_states are computed based
    # on the "older" target_net; selecting their best reward with max(1)[0].
    # This is merged based on the mask, such that we'll have either the expected
    # state value or 0 in case the state was final.
    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    with torch.no_grad():
        next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0]
    # Compute the expected Q values
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch

    # Compute Huber loss
    criterion = nn.SmoothL1Loss()
    loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))

    # Optimize the model
    optimizer.zero_grad()
    loss.backward()
    # In-place gradient clipping
    torch.nn.utils.clip_grad_value_(policy_net.parameters(), 100)
    optimizer.step()


if torch.cuda.is_available():
    num_episodes = 600
else:
    # num_episodes = 50
    num_episodes = 600

for i_episode in range(num_episodes):
    # Initialize the environment and get its state
    state, info = env.reset()
    state = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
    for t in count():
        action = select_action(state)
        observation, reward, terminated, truncated, _ = env.step(action.item())
        reward = torch.tensor([reward], device=device)
        done = terminated or truncated

        if terminated:
            next_state = None
        else:
            next_state = torch.tensor(observation, dtype=torch.float32, device=device).unsqueeze(0)

        # Store the transition in memory
        memory.push(state, action, next_state, reward)

        # Move to the next state
        state = next_state

        # Perform one step of the optimization (on the policy network)
        optimize_model()

        # Soft update of the target network's weights
        # θ′ ← τ θ + (1 −τ )θ′
        target_net_state_dict = target_net.state_dict()
        policy_net_state_dict = policy_net.state_dict()
        for key in policy_net_state_dict:
            target_net_state_dict[key] = policy_net_state_dict[key]*TAU + target_net_state_dict[key]*(1-TAU)
        target_net.load_state_dict(target_net_state_dict)

        if done:
            episode_durations.append(t + 1)
            plot_durations()
            break

print('Complete')
plot_durations(show_result=True)
plt.ioff()
plt.show()
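The script above ends by showing the training curve. As a small optional add-on (not part of the referenced tutorial code; a sketch that assumes policy_net, device, and the imports above are still available in the current session), you can watch the trained policy run greedily with on-screen rendering:

# Optional: watch the trained policy (greedy, no exploration). Assumes the
# training script above has just been run in the same Python session.
eval_env = gym.make("CartPole-v1", render_mode="human")
state, info = eval_env.reset()
state = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
done = False
while not done:
    with torch.no_grad():
        action = policy_net(state).max(1)[1].view(1, 1)  # pick the action with the largest Q-value
    observation, reward, terminated, truncated, _ = eval_env.step(action.item())
    done = terminated or truncated
    state = torch.tensor(observation, dtype=torch.float32, device=device).unsqueeze(0)
eval_env.close()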

