• 欢迎访问开心洋葱网站,在线教程,推荐使用最新版火狐浏览器和Chrome浏览器访问本网站,欢迎加入开心洋葱 QQ群
  • 为方便开心洋葱网用户,开心洋葱官网已经开启复制功能!
  • 欢迎访问开心洋葱网站,手机也能访问哦~欢迎加入开心洋葱多维思维学习平台 QQ群
  • 如果您觉得本站非常有看点,那么赶紧使用Ctrl+D 收藏开心洋葱吧~~~~~~~~~~~~~!
  • 由于近期流量激增,小站的ECS没能经的起亲们的访问,本站依然没有盈利,如果各位看如果觉着文字不错,还请看官给小站打个赏~~~~~~~~~~~~~!

深度强化学习专栏 —— 3.实现一阶倒立摆

人工智能 bug404 1809次浏览 0个评论

深度强化学习专栏 —— 3.实现一阶倒立摆  


  • 这是今天我们要实现的目标。在上一篇文章深度强化学习专栏 —— 2.手撕DQN算法实现CartPole控制中,我们已经根据论文从头实现了一个DQN算法,准确的说是MlpDQN(另一种是CnnDQN),即多层感知DQN,因为在神经网络部分,我们使用的不是论文中描述的卷积网络,而是全连接的多层感知机。为什么使用多层感知机而不是卷积网络呢?很简单呢,我们是将CartPole的位置,速度,杆的角度,杆的角速度传入神经网络,神经网络输出动作的概率,这种向量输入,向量输出的模式,适合多层感知机而不是卷积网络,卷积网络一般是对图像信息进行操作。在后面,我们会直接获取cartpole的运行图像,经卷积网络输出动作的概率的尝试,敬请期待。
  • 我们继续第二篇中的内容。第二篇还留了两个问题: 五、不修改奖励函数,使用环境默认的奖励函数,如何达到较高的性能? 六、将编写成功的算法应用到倒立摆上。
  • 五、不修改奖励函数,使用环境默认的奖励函数,如何达到较高的性能? 经过尝试,我们发现我们手撕的算法并不适合在不修改奖励函数的情况下得到较高的reward,所以我们需要换个方法。在手撕算法的对立面,就是使用强化学习算法库(rl libraries),那么rl libraries、openai gym、environment之间是什么关系呢

下面这张图很好的总结了它们之间的关系。我们可用的强化学习算法库有很多,比如我喜欢用的stable baselines3(这是pytorch版本)、RLLIb(Ray)、baselines、spinningup等等,但是从易用性以及文档全面性来说,stable baselines都是不错的,虽然RLLib是与Ray深度绑定的,支持多机分布式大规模计算,但是对于单机计算来说,配置就显得很麻烦,而且文档对于我来说,显得较难理解,且Google上资料较少。再往后面就是OpenAI Gym,对于强化学习的训练,gym是一道跨不过去的坎,它提供了step()、reset()、render()等函数接口,强化学习算法直接调用这几个接口就可以完成训练过程,即不管对于什么样的环境、用的什么物理引擎,只要编写完成这几个接口,就可以实现算法应用到环境上进行训练,屏蔽了gym之后的不同。对于强化学习的训练来说,要么是游戏、要么是机器人,对于有物理接触的来说,通常都需要带有某个物理引擎的刚体动力学仿真软件,物理引擎负责计算刚体之间的运动学与动力学,虽然每个仿真软件的API都不同,即控制电机、获取图像等各个操作的函数都不同,但是有gym的存在,对于控制机器人的每个动作,都需要封装到gym的函数中,这样就屏蔽了不同仿真软件之间的差异。  
  说完了他们之间的关系,下面我们就是使用一个强化学习算法库尝试一下在不修改奖励函数的情况下,能不能达到很高的reward。其实在开始之前,我已经有信心会了,只是可能会经历非常多的训练过程。那么还是选我比较喜欢的stable baselines3吧。stable baselines3的文档在这里。 stable baselines3的安装非常简单,会自动同时安装pytorch>=1.4版本。(题外话:pip建议使用豆瓣源,清华源pip安装torch经常会报超时错误,而豆瓣源基本可以10M/s左右完成下载。换pip的源请点击这里,我都已经帮你们准备好了)  

pip install stable-baselines3

  安装其他组件,做数据处理这块,我喜欢用jupyter notebook、jupyter lab或者spyder,所以我选择安装了Spyder。  

pip install gym spyder spyder-terminal numpy

  从stable baselines3的DQN算法文档中,我们可以看到使用DQN训练一个智能体的代码是怎样的。我们将代码复制过来,稍作修改。  

import gym
import numpy as np

from stable_baselines3 import DQN
from stable_baselines3.dqn import MlpPolicy

env = gym.make('CartPole-v0')

model = DQN(MlpPolicy, env, verbose=1)
model.learn(total_timesteps=10000, log_interval=4)
model.save("dqn_pendulum")

obs = env.reset()
while True:
    action, _states = model.predict(obs, deterministic=True)
    obs, reward, done, info = env.step(action)
    env.render()
    if done:
      obs = env.reset()

  但是设置total_timesteps=10000 对于训练来说是无济于事的,我们将其设置为30万,结果如下面动画,看起来效果还不是很好。  
深度强化学习专栏 —— 3.实现一阶倒立摆 我们将其增加到100万,看一下效果。
深度强化学习专栏 —— 3.实现一阶倒立摆 卧槽,怎么还是这么差???什么原因???有小伙伴能告知一下不?? 算了,用PPO算法试一下。  

import gym
from stable_baselines3 import PPO

env=gym.make('CartPole-v0')
model=PPO('MlpPolicy',env,verbose=1)
model.learn(total_timesteps=500000)

obs=env.reset()
for i in range(9000):
    action,_state=model.predict(obs,deterministic=True)
    obs,reward,done,info=env.step(action)
    env.render()
    if done:
        obs=env.reset()

深度强化学习专栏 —— 3.实现一阶倒立摆 卧槽,效果怎么这么好!!! PPO牛逼就对了!!!


  • 六、将编写成功的算法应用到倒立摆上 解决了CartPole的问题,接下来就是让倒立摆也运行起来了。不管是训练CartPole、超级玛丽、机器人,环境是必要的。一个环境主要包含了智能体的动作选择、动力学、运动学、奖励函数的设计、图像的渲染等。CartPole的环境封装在gym的源代码cartpole.py中,这个环境OpenAI已经帮我们编写好了,所以我们得以直接训练。而倒立摆的环境没有官方的封装,所以我们要不自己编写,要不使用别人已经实现的开源,我在网上找了一个相对较好的封装,其介绍请戳这里,里面介绍了奖励函数是怎样设计的以及他们做的一些算法工作,源代码请戳这里。现在有了环境,而且环境都符合gym规范,所以可以无缝的使用cartpole的代码,只把环境名称换一下即可。 首先将环境安装一下(见环境的github说明)

 

git clone https://github.com/jfpettit/cartpole-swingup-envs.git
pip install -e cartpole-swingup-envs

  安装完成之后,我们就可以像使用gym的内置环境一样来使用这个倒立摆的环境。下面代码是其github主页的example。  

import gym
import cartpole_swingup_envs
continuous_env = gym.make('CartPoleSwingUpContinuous-v0')  #离散动作的环境
discrete_env = gym.make('CartPoleSwingUpDiscrete-v0')  #连续动作的环境

  我们将自己手撕的DQN算法应用到其上面。  

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import gym
import matplotlib.pyplot as plt
import cartpole_swingup_envs  # 为什么要import这个呢?浏览一下上文提到的cartpole-swingup-envs,即倒立摆的github说明,即可看到怎样使用这个 环境,首先是就是需要导入这个包

BATCH_SIZE = 32            
LR = 0.01                  
EPSILON = 0.9               
GAMMA = 0.9                
TARGET_REPLACE_ITER = 100  
MEMORY_CAPACITY = 2000     
EPISODE=2000                

#env = gym.make('CartPoleSwingUpDiscrete-v0')
env = gym.make('CartPoleSwingUpDiscrete-v0')
#env = env.unwrapped
#env = gym.wrappers.Monitor(env, './video/',video_callable=lambda episode_id: True,force = True)

N_ACTIONS = env.action_space.n
N_STATES = env.observation_space.shape[0]

device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
#device=torch.device("cpu")
torch.FloatTensor=torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor #如果有GPU和cuda,数据将转移到GPU执行
torch.LongTensor=torch.cuda.LongTensor if torch.cuda.is_available() else torch.LongTensor

class Net(nn.Module):
    def __init__(self, ):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(N_STATES, 50)
        self.fc1.weight.data.normal_(0, 0.1)   # initialization
        self.out = nn.Linear(50, N_ACTIONS)
        self.out.weight.data.normal_(0, 0.1)   # initialization

    def forward(self, x):
        x = self.fc1(x)
        x = F.relu(x)
        actions_value = self.out(x)
        return actions_value

class DQN:
    def __init__(self):
        self.net,self.target_net=Net().to(device),Net().to(device)

        self.learn_step_counter=0
        self.memory_counter=0
        self.memory = np.zeros((MEMORY_CAPACITY, N_STATES * 2 + 2))     # initialize memory
        self.optimizer = torch.optim.Adam(self.net.parameters(), lr=LR)
        self.loss_func = nn.MSELoss()


    def choose_action(self, x):
        x = torch.unsqueeze(torch.FloatTensor(x), 0)
        # input only one sample
        if np.random.uniform() < EPSILON:  
            actions_value = self.net.forward(x)
            action = torch.max(actions_value, 1)[1].data.cpu().numpy()
            action = action[0]
        else:   
            action = np.random.randint(0, N_ACTIONS)

        return action

    def store_transition(self, s, a, r, s_):
        transition = np.hstack((s, [a, r], s_))
        # replace the old memory with new memory
        index = self.memory_counter % MEMORY_CAPACITY
        self.memory[index, :] = transition
        self.memory_counter += 1

    def learn(self):
        # target parameter update
        if self.learn_step_counter % TARGET_REPLACE_ITER == 0:
            self.target_net.load_state_dict(self.net.state_dict())
        self.learn_step_counter += 1

        # sample batch transitions
        sample_index = np.random.choice(MEMORY_CAPACITY, BATCH_SIZE)
        batch_memory = self.memory[sample_index, :]
        batch_s = torch.FloatTensor(batch_memory[:, :N_STATES])
        batch_a = torch.LongTensor(batch_memory[:, N_STATES:N_STATES+1].astype(int))
        batch_r = torch.FloatTensor(batch_memory[:, N_STATES+1:N_STATES+2])
        batch_s_ = torch.FloatTensor(batch_memory[:, -N_STATES:])


        q = self.net(batch_s).gather(1, batch_a)  # shape (batch, 1)
        q_target = self.target_net(batch_s_).detach()     # detach from graph, don't backpropagate
        y = batch_r + GAMMA * q_target.max(1)[0].view(BATCH_SIZE, 1)   # shape (batch, 1)
        loss = self.loss_func(q, y)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

dqn = DQN()

plot_x_data,plot_y_data=[],[]
for i_episode in range(10000):
    s = env.reset()
    episode_reward = 0
    while True:
        env.render()
        a = dqn.choose_action(s)

        # take action
        s_, r, done, info = env.step(a)

        x, x_dot, theta_cos, theta_sin, theta_dot=s_   #这个地方需要做些修改,和cartpole的不相同。修改的依据来自CartPoleSwingUpDiscrete-v0环境step()函数的return,可以看到这个环境,返回值的obs包含了5个元素,而cartpole是4个元素,阅读源码即可找到。
# =============================================================================
#         r1 = (env.x_threshold - abs(x)) / env.x_threshold - 0.8
#         r2 = (env.theta_threshold_radians - abs(theta)) / env.theta_threshold_radians - 0.5
#         r = r1 + r2
# =============================================================================

        dqn.store_transition(s, a, r, s_)

        episode_reward += r
        if dqn.memory_counter > MEMORY_CAPACITY:
            dqn.learn()
            if done:
                print('Episode: ', i_episode,
                      '| Episode_reward: ', round(episode_reward, 2))

        if done:
            break
        s = s_
    plot_x_data.append(i_episode)
    plot_y_data.append(episode_reward)
    plt.plot(plot_x_data,plot_y_data)

  经过1万次训练,我们可以看到结果
深度强化学习专栏 —— 3.实现一阶倒立摆   什么东西???怎么回事???? 那好吧,接受现实,用stable baselines的DQN试下,stable baselines总可以作为一个基准吧。  

import gym
import numpy as np
import cartpole_swingup_envs

from stable_baselines3 import DQN
from stable_baselines3.dqn import MlpPolicy

env = gym.make('CartPoleSwingUpDiscrete-v0')

model = DQN(MlpPolicy, env, verbose=1)
model.learn(total_timesteps=1000000, log_interval=4)
model.save("dqn_pendulum")

obs = env.reset()
while True:
    action, _states = model.predict(obs, deterministic=True)
    obs, reward, done, info = env.step(action)
    env.render()
    if done:
      obs = env.reset()

 
深度强化学习专栏 —— 3.实现一阶倒立摆 什么情况,DQN难道不适合解决这个问题??? 那PPO试一下好了。  

import gym
import numpy as np
import cartpole_swingup_envs

from stable_baselines3 import DQN,PPO
from stable_baselines3.dqn import MlpPolicy

env = gym.make('CartPoleSwingUpDiscrete-v0')

#model = DQN(MlpPolicy, env, verbose=1)
model = PPO('MlpPolicy', env, verbose=1)
model.learn(total_timesteps=1000000, log_interval=4)
model.save("dqn_pendulum")

obs = env.reset()
while True:
    action, _states = model.predict(obs, deterministic=True)
    obs, reward, done, info = env.step(action)
    env.render()
    if done:
      obs = env.reset()

 
深度强化学习专栏 —— 3.实现一阶倒立摆 艾瑶瑶,PPO牛逼!!  


总结:  

  • 至此,我们经过重重险阻,终于让倒立摆立了起来,但是道路坎坷,自己手撕的DQN基本没用,stable baselines的DQN在解决cartpole问题还可以,但是倒立摆问题也扑街,可能与参数设置有关。不管是cartpole还是倒立摆,PPO都表现出了最好的性能,看起来OpenAI对PPO的情有独钟是有道理的,而且PPO可以改写为off-policy模式,提高了采样效率。
  • 虽然我们手撕的DQN在每个任务中的性能都是最弱的,但是并不代表我们手撕了就没用,它可以让我们加深对于算法的理解,真正体验从数学算法到写成程序的过程。

 


下一篇我们就要实现FlappyBird了。
深度强化学习专栏 —— 3.实现一阶倒立摆


注:第二张图片来自Part 1 – 自定义gym环境


开心洋葱 , 版权所有丨如未注明 , 均为原创丨未经授权请勿修改 , 转载请注明深度强化学习专栏 —— 3.实现一阶倒立摆
喜欢 (0)

您必须 登录 才能发表评论!

加载中……