This post covers the fifth document in the official MuJoCo tutorial series, "The MJX tutorial provides usage examples of MuJoCo XLA, a branch of MuJoCo written in JAX". I skipped the least-squares tutorial here because it is rather application-oriented; I plan to cover it two posts from now. This post also lands right after the first side post of the series: having just set up JAX, it is a good moment to look at the difference between using MJX and running purely on the CPU.

This post touches on some reinforcement learning material. If you cannot follow all of it yet, skip the details for now; the focus is on how to use the MJX library. Once the official tutorial series is finished I will write posts on reinforcement learning with mujoco, and will then cover the brax and jax libraries as well as the RL conveniences that mujoco provides.

Note: before starting this chapter, make sure the jax and jaxlib packages are installed correctly in your environment, and that running the following commands raises no errors.

import jax
print(jax.devices())
# [cuda(id=0)]

If, in addition to the output above, you also see a message like the one below, the versions do not match and jax cannot use the GPU. See the side post mentioned above to reinstall the matching packages:

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
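For reference, reinstalling a CUDA-enabled build usually comes down to a command along the lines of the one below (shown here for CUDA 12 wheels; this is a generic sketch, not the exact command from the side post, so check the JAX installation guide for the variant matching your CUDA version):

(mujoco) $ pip install --upgrade "jax[cuda12]"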

The official resources referenced in this post:

  • Official GitHub repository:
  • Official Colab notebook:

The official notebooks and my own notes are available at the link below. Files prefixed with [offical] are the official notebooks, and files prefixed with [note] are the notes corresponding to this blog series:

Link:  Extraction code: 83a4

1. Import the required packages

import distutils.util
import functools
import itertools
import os
import subprocess
import time
from datetime import datetime
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union

import matplotlib.pyplot as plt
import mediapy as media
import mujoco
import numpy as np
from etils import epath
from IPython.display import HTML
from ml_collections import config_dict

import jax
from jax import numpy as jp
from flax import struct
from flax.training import orbax_utils
from orbax import checkpoint as ocp

from mujoco import mjx

# Ask XLA to use Triton GEMM kernels, which can improve throughput on some GPUs.
xla_flags = os.environ.get('XLA_FLAGS', '')
xla_flags += ' --xla_gpu_triton_gemm_any=True'
os.environ['XLA_FLAGS'] = xla_flags

from brax import base
from brax import envs
from brax import math
from brax.base import Base, Motion, Transform
from brax.base import State as PipelineState
from brax.envs.base import Env, PipelineEnv, State
from brax.io import html, mjcf, model
from brax.mjx.base import State as MjxState
from brax.training.agents.ppo import networks as ppo_networks
from brax.training.agents.ppo import train as ppo  # provides ppo.train used for training below

2. First use and verification

2.1 Build the model and prepare CPU and GPU objects

Build a simple model:

xml ="""
<mujoco>
  <worldbody>
    <light name="top" pos="0 0 1"/>
    <body name="box_and_sphere" euler="0 0 -30">
      <joint name="swing" type="hinge" axis="1 -1 0" pos="-.2 -.2 -.2"/>
      <geom name="red_box" type="box" size=".2 .2 .2" rgba="1 0 0 1"/>
      <geom name="green_sphere" pos=".2 .2 .2" size=".1" rgba="0 1 0 1"/>
    </body>
  </worldbody>
</mujoco>
"""

Prepare a model and data that live on the CPU:

mj_model = mujoco.MjModel.from_xml_string(xml)
mj_data = mujoco.MjData(mj_model)
renderer = mujoco.Renderer(mj_model)

Then prepare the model and data for JAX, which essentially moves the data onto the GPU:

mjx_model = mjx.put_model(mj_model)
mjx_data = mjx.put_data(mj_model, mj_data)

Print both qpos arrays: mjData.qpos is a numpy array resident on the CPU, while mjx.Data.qpos is a JAX array resident on the GPU device.

print(mj_data.qpos, type(mj_data.qpos))
print(mjx_data.qpos, type(mjx_data.qpos))

Note: the output of the second print must show jaxlib.xla_extension.ArrayImpl.
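If you prefer a programmatic check rather than eyeballing the type, a minimal sketch like the following (assuming the mjx_data object created above and a recent JAX version) verifies that the MJX array really lives on a GPU device:

# mjx_data.qpos is a jax.Array; .devices() reports where it is placed.
print(mjx_data.qpos.devices())  # expected: something like {cuda(id=0)}
assert all(d.platform == 'gpu' for d in mjx_data.qpos.devices()), 'qpos is not on the GPU'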

2.2 Run the CPU example

scene_option = mujoco.MjvOption()
scene_option.flags[mujoco.mjtVisFlag.mjVIS_JOINT] = True

duration = 3.8   # seconds
framerate = 60   # Hz

frames = []
mujoco.mj_resetData(mj_model, mj_data)
start_time = time.time()
while mj_data.time < duration:
  mujoco.mj_step(mj_model, mj_data)
  if len(frames) < mj_data.time * framerate:
    renderer.update_scene(mj_data, scene_option=scene_option)
    pixels = renderer.render()
    frames.append(pixels)
print(f"Total cost {time.time() - start_time:.2f} seconds")

media.show_video(frames, fps=framerate)

2.3 Run the GPU example

Now run exactly the same simulation on the GPU with MJX, replacing mujoco.mj_step with mjx.step and wrapping mjx.step in jax.jit so that it runs efficiently on the GPU. For each rendered frame, mjx.Data is converted back to mjData so the MuJoCo renderer can be used.

jit_step = jax.jit(mjx.step)

frames = []
mujoco.mj_resetData(mj_model, mj_data)
mjx_data = mjx.put_data(mj_model, mj_data)
start_time = time.time()
while mjx_data.time < duration:
  mjx_data = jit_step(mjx_model, mjx_data)
  if len(frames) < mjx_data.time * framerate:
    mj_data = mjx.get_data(mj_model, mjx_data)
    renderer.update_scene(mj_data, scene_option=scene_option)
    pixels = renderer.render()
    frames.append(pixels)
print(f"Total cost {time.time() - start_time:.2f} seconds")

media.show_video(frames, fps=framerate)

You will notice that running a single-threaded physics simulation on the GPU like this is not efficient. The strength of MJX is that it can run many environments in parallel on accelerator hardware.

The example below creates 4096 copies of mjx.Data and runs mjx.step on the batched data. Since MJX is implemented in JAX, jax.vmap is used to run mjx.step over all mjx.Data instances in parallel.

rng = jax.random.PRNGKey(0)
rng = jax.random.split(rng, 4096)
batch = jax.vmap(lambda rng: mjx_data.replace(qpos=jax.random.uniform(rng, (1,))))(rng)

jit_step = jax.jit(jax.vmap(mjx.step, in_axes=(None, 0)))
batch = jit_step(mjx_model, batch)
print(batch.qpos)

batched_mj_data = mjx.get_data(mj_model, batch)
print([d.qpos for d in batched_mj_data])
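To see the parallelism pay off, you can time the batched step directly. The sketch below is my own addition rather than part of the official notebook; it reuses mjx_model, batch and jit_step from above, and block_until_ready is needed because JAX dispatches work asynchronously. The exact numbers depend entirely on your GPU.

batch = jit_step(mjx_model, batch)   # warm-up call: triggers JIT compilation
batch.qpos.block_until_ready()

n = 100
start_time = time.time()
for _ in range(n):
  batch = jit_step(mjx_model, batch)
batch.qpos.block_until_ready()       # wait for the async work to finish
elapsed = time.time() - start_time
print(f"{n * 4096 / elapsed:,.0f} environment steps per second")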


3. Training a reinforcement learning policy with MJX

Running physics simulation in large batches is very useful for training reinforcement learning policies. This section shows how to train an RL policy with MJX using Brax's reinforcement learning library.

The classic Humanoid environment is implemented with MJX and Brax by inheriting from Brax's PipelineEnv (with the mjx backend), so the physics is stepped with MJX while training uses Brax's RL implementation.

3.1 Define the environment and pull in assets

Because this part needs assets from the mujoco repository, we still clone the official mujoco repository from GitHub into the current directory:

(mujoco) $ git clone git@github.com:google-deepmind/mujoco.git

Then define the Humanoid environment:

HUMANOID_ROOT_PATH = epath.Path(epath.resource_path('mujoco')) / 'mjx/test_data/humanoid'

class Humanoid(PipelineEnv):

  def __init__(
      self,
      forward_reward_weight=1.25,
      ctrl_cost_weight=0.1,
      healthy_reward=5.0,
      terminate_when_unhealthy=True,
      healthy_z_range=(1.0, 2.0),
      reset_noise_scale=1e-2,
      exclude_current_positions_from_observation=True,
      **kwargs,
  ):
    mj_model = mujoco.MjModel.from_xml_path(
        (HUMANOID_ROOT_PATH / 'humanoid.xml').as_posix())
    mj_model.opt.solver = mujoco.mjtSolver.mjSOL_CG
    mj_model.opt.iterations = 6
    mj_model.opt.ls_iterations = 6

    sys = mjcf.load_model(mj_model)

    physics_steps_per_control_step = 5
    kwargs['n_frames'] = kwargs.get('n_frames', physics_steps_per_control_step)
    kwargs['backend'] = 'mjx'

    super().__init__(sys, **kwargs)

    self._forward_reward_weight = forward_reward_weight
    self._ctrl_cost_weight = ctrl_cost_weight
    self._healthy_reward = healthy_reward
    self._terminate_when_unhealthy = terminate_when_unhealthy
    self._healthy_z_range = healthy_z_range
    self._reset_noise_scale = reset_noise_scale
    self._exclude_current_positions_from_observation = (
        exclude_current_positions_from_observation
    )

  def reset(self, rng: jp.ndarray) -> State:
    """Resets the environment to an initial state."""
    rng, rng1, rng2 = jax.random.split(rng, 3)

    low, hi = -self._reset_noise_scale, self._reset_noise_scale
    qpos = self.sys.qpos0 + jax.random.uniform(
        rng1, (self.sys.nq,), minval=low, maxval=hi
    )
    qvel = jax.random.uniform(
        rng2, (self.sys.nv,), minval=low, maxval=hi
    )

    data = self.pipeline_init(qpos, qvel)

    obs = self._get_obs(data, jp.zeros(self.sys.nu))
    reward, done, zero = jp.zeros(3)
    metrics = {
        'forward_reward': zero,
        'reward_linvel': zero,
        'reward_quadctrl': zero,
        'reward_alive': zero,
        'x_position': zero,
        'y_position': zero,
        'distance_from_origin': zero,
        'x_velocity': zero,
        'y_velocity': zero,
    }
    return State(data, obs, reward, done, metrics)

  def step(self, state: State, action: jp.ndarray) -> State:
    """Runs one timestep of the environment's dynamics."""
    data0 = state.pipeline_state
    data = self.pipeline_step(data0, action)

    com_before = data0.subtree_com[1]
    com_after = data.subtree_com[1]
    velocity = (com_after - com_before) / self.dt
    forward_reward = self._forward_reward_weight * velocity[0]

    min_z, max_z = self._healthy_z_range
    is_healthy = jp.where(data.q[2] < min_z, 0.0, 1.0)
    is_healthy = jp.where(data.q[2] > max_z, 0.0, is_healthy)
    if self._terminate_when_unhealthy:
      healthy_reward = self._healthy_reward
    else:
      healthy_reward = self._healthy_reward * is_healthy

    ctrl_cost = self._ctrl_cost_weight * jp.sum(jp.square(action))

    obs = self._get_obs(data, action)
    reward = forward_reward + healthy_reward - ctrl_cost
    done = 1.0 - is_healthy if self._terminate_when_unhealthy else 0.0
    state.metrics.update(
        forward_reward=forward_reward,
        reward_linvel=forward_reward,
        reward_quadctrl=-ctrl_cost,
        reward_alive=healthy_reward,
        x_position=com_after[0],
        y_position=com_after[1],
        distance_from_origin=jp.linalg.norm(com_after),
        x_velocity=velocity[0],
        y_velocity=velocity[1],
    )

    return state.replace(
        pipeline_state=data, obs=obs, reward=reward, done=done
    )

  def _get_obs(
      self, data: mjx.Data, action: jp.ndarray
  ) -> jp.ndarray:
    """Observes humanoid body position, velocities, and angles."""
    position = data.qpos
    if self._exclude_current_positions_from_observation:
      position = position[2:]

    return jp.concatenate([
        position,
        data.qvel,
        data.cinert[1:].ravel(),
        data.cvel[1:].ravel(),
        data.qfrc_actuator,
    ])

envs.register_environment('humanoid', Humanoid)

The step method above defines an is_healthy flag: if the z-axis height of the center of mass, data.q[2], leaves the healthy range during simulation, the episode is terminated.
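As a toy illustration of how that flag behaves (standalone values, not part of the environment, using the jp alias from the imports above), jp.where keeps the episode alive only while the height is inside the range:

min_z, max_z = 1.0, 2.0   # the default healthy_z_range
for z in (0.8, 1.5, 2.3):
  is_healthy = jp.where(z < min_z, 0.0, 1.0)
  is_healthy = jp.where(z > max_z, 0.0, is_healthy)
  print(z, float(is_healthy))   # 0.8 -> 0.0, 1.5 -> 1.0, 2.3 -> 0.0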

3.2 Preview the environment

env_name = 'humanoid'
env = envs.get_environment(env_name)

jit_reset = jax.jit(env.reset)
jit_step = jax.jit(env.step)

state = jit_reset(jax.random.PRNGKey(0))
rollout = [state.pipeline_state]

for i in range(100):
  ctrl = -0.1 * jp.ones(env.sys.nu)
  state = jit_step(state, ctrl)
  rollout.append(state.pipeline_state)

media.show_video(env.render(rollout, camera='side'), fps=1.0 / env.dt)

3.3 Train a PPO policy

The official notebook reports about 6 minutes of training on an A100; on my single 3060 Super it also took roughly 6 minutes. You can adjust the parameters of train_fn to match your available compute; a downsized variant is sketched after the code block below.

train_fn = functools.partial(
    ppo.train,
    num_timesteps=20_000_000,
    num_evals=5, reward_scaling=0.1,
    episode_length=1000, normalize_observations=True, action_repeat=1,
    unroll_length=10, num_minibatches=24, num_updates_per_batch=8,
    discounting=0.97, learning_rate=3e-4, entropy_cost=1e-3, num_envs=3072,
    batch_size=512, seed=0)

x_data = []
y_data = []
ydataerr = []
times = [datetime.now()]

max_y, min_y = 13000, 0

def progress(num_steps, metrics):
  times.append(datetime.now())
  x_data.append(num_steps)
  y_data.append(metrics['eval/episode_reward'])
  ydataerr.append(metrics['eval/episode_reward_std'])

  plt.xlim([0, train_fn.keywords['num_timesteps'] * 1.25])
  plt.ylim([min_y, max_y])
  plt.xlabel('# environment steps')
  plt.ylabel('reward per episode')
  plt.title(f'y={y_data[-1]:.3f}')
  plt.errorbar(x_data, y_data, yerr=ydataerr)
  plt.show()

start_time = time.time()
make_inference_fn, params, _ = train_fn(environment=env, progress_fn=progress)
print(f'Total cost time {time.time() - start_time:.2f} seconds')
print(f'time to jit: {times[1] - times[0]}')
print(f'time to train: {times[-1] - times[1]}')
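If GPU memory or time is tight, a downsized configuration along the lines of the sketch below can be used instead. This is only an illustrative guess at smaller values (fewer timesteps and fewer parallel environments), not a tuned recipe from the official notebook, and the resulting policy will be correspondingly weaker.

small_train_fn = functools.partial(
    ppo.train,
    num_timesteps=5_000_000,   # fewer environment steps overall
    num_envs=1024,             # fewer parallel environments on the GPU
    batch_size=256, num_minibatches=16,
    num_evals=5, reward_scaling=0.1, episode_length=1000,
    normalize_observations=True, action_repeat=1, unroll_length=10,
    num_updates_per_batch=8, discounting=0.97, learning_rate=3e-4,
    entropy_cost=1e-3, seed=0)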

The trained policy can be saved and reloaded with the brax API, as the next two subsections show.


3.4 Save the model

model_path ='./mjx_brax_policy'
model.save_params(model_path, params)

3.5 Load the model and define the inference function

params = model.load_params(model_path)
inference_fn = make_inference_fn(params)
jit_inference_fn = jax.jit(inference_fn)

3.6 Record a video with the policy

You will notice that the inference workflow with JAX is quite different from running inference directly with mujoco as before.

rng = jax.random.PRNGKey(0)
state = jit_reset(rng)
rollout = [state.pipeline_state]

n_step = 500
render_every = 2

start_time = time.time()
for i in range(n_step):
  act_rng, rng = jax.random.split(rng)
  ctrl, _ = jit_inference_fn(state.obs, act_rng)
  state = jit_step(state, ctrl)
  rollout.append(state.pipeline_state)

  if state.done:
    break
print(f"Average evaluate step cost time: {(time.time() - start_time) / n_step:.5f}")

media.show_video(env.render(rollout[::render_every], camera='side'), fps=1.0 / env.dt / render_every)

The average inference time per step is about 0.007 s.


4. Using an MJX policy in MuJoCo

The physics can also be stepped with the native MuJoCo bindings, to show that a policy trained in MJX really does transfer to MuJoCo.

eval_env = envs.get_environment(env_name)  # fresh humanoid env; only its helpers are used here

mj_model = eval_env.sys.mj_model
mj_data = mujoco.MjData(mj_model)

renderer = mujoco.Renderer(mj_model)
ctrl = jp.zeros(mj_model.nu)

images = []
start_time = time.time()
for i in range(n_step):
  act_rng, rng = jax.random.split(rng)

  obs = eval_env._get_obs(mjx.put_data(mj_model, mj_data), ctrl)
  ctrl, _ = jit_inference_fn(obs, act_rng)

  mj_data.ctrl = ctrl
  for _ in range(eval_env._n_frames):
    mujoco.mj_step(mj_model, mj_data)

  if i % render_every == 0:
    renderer.update_scene(mj_data, camera='side')
    images.append(renderer.render())
print(f"Average evaluate step cost time: {(time.time() - start_time) / n_step:.5f}")

media.show_video(images, fps=1.0 / eval_env.dt / render_every)


5. Training a policy with domain randomization

Sometimes you may also want to randomize certain mjModel parameters while training a policy. With MJX it is easy to create a batch of environments whose mjx.Model fields are filled with random values. The code below shows a function that randomizes friction and actuator gain/bias.

5.1 Pull the model asset repository

This part needs some assets from mujoco_menagerie; clone the GitHub repository into the same directory:

(mujoco) $ git clone git@github.com:google-deepmind/mujoco_menagerie.git

5.2 Define the domain randomization function

Define the domain randomization function:

def domain_randomize(sys, rng):
  """Randomizes the mjx.Model."""
  @jax.vmap
  def rand(rng):
    _, key = jax.random.split(rng, 2)
    # friction
    friction = jax.random.uniform(key, (1,), minval=0.6, maxval=1.4)
    friction = sys.geom_friction.at[:, 0].set(friction)
    # actuator
    _, key = jax.random.split(key, 2)
    gain_range = (-5, 5)
    param = jax.random.uniform(
        key, (1,), minval=gain_range[0], maxval=gain_range[1]
    ) + sys.actuator_gainprm[:, 0]
    gain = sys.actuator_gainprm.at[:, 0].set(param)
    bias = sys.actuator_biasprm.at[:, 1].set(-param)
    return friction, gain, bias

  friction, gain, bias = rand(rng)

  in_axes = jax.tree_util.tree_map(lambda x: None, sys)
  in_axes = in_axes.tree_replace({
      'geom_friction': 0,
      'actuator_gainprm': 0,
      'actuator_biasprm': 0,
  })

  sys = sys.tree_replace({
      'geom_friction': friction,
      'actuator_gainprm': gain,
      'actuator_biasprm': bias,
  })

  return sys, in_axes

If you want 10 environments with random friction and actuator parameters, call domain_randomize; it returns a batched mjx.Model along with a dict specifying which axes are batched.

rng = jax.random.PRNGKey(0)
rng = jax.random.split(rng, 10)
batched_sys, _ = domain_randomize(env.sys, rng)

print('Single env friction shape: ', env.sys.geom_friction.shape)
print('Batched env friction shape: ', batched_sys.geom_friction.shape)
print('Friction on geom 0: ', env.sys.geom_friction[0, 0])
print('Random frictions on geom 0: ', batched_sys.geom_friction[:, 0, 0])
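The second return value is the in_axes tree that tells the training wrapper which model fields carry a batch dimension. A small sketch (my own addition, reusing env and rng from above) makes the convention visible: randomized fields map to 0, everything else to None.

_, in_axes = domain_randomize(env.sys, rng)
print(in_axes.geom_friction)   # 0    -> this field is batched along the leading axis
print(in_axes.dof_damping)     # None -> this field is shared by all randomized copies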

5.3 Define helper functions and classes

If you cannot follow all of this as a beginner, don't worry about the details; the main goal is to learn the overall MJX training workflow.

BARKOUR_ROOT_PATH = epath.Path('mujoco_menagerie/google_barkour_vb')

def get_config():
  """Returns reward config for barkour quadruped environment."""

  def get_default_rewards_config():
    default_config = config_dict.ConfigDict(
        dict(
            # The coefficients for all reward terms used for training. All
            # physical quantities are in SI units, if not otherwise specified,
            # i.e. joint positions are in rad, positions are measured in meters,
            # torques in Nm, and time in seconds, and forces in Newtons.
            scales=config_dict.ConfigDict(
                dict(
                    # Tracking rewards are computed using exp(-delta^2/sigma),
                    # sigma can be a hyperparameter to tune.
                    # Track the base x-y velocity (no z-velocity tracking.)
                    tracking_lin_vel=1.5,
                    # Track the angular velocity along z-axis, i.e. yaw rate.
                    tracking_ang_vel=0.8,
                    # Below are regularization terms, we roughly divide the
                    # terms to base state regularizations, joint
                    # regularizations, and other behavior regularizations.
                    # Penalize the base velocity in z direction, L2 penalty.
                    lin_vel_z=-2.0,
                    # Penalize the base roll and pitch rate. L2 penalty.
                    ang_vel_xy=-0.05,
                    # Penalize non-zero roll and pitch angles. L2 penalty.
                    orientation=-5.0,
                    # L2 regularization of joint torques, |tau|^2.
                    torques=-0.0002,
                    # Penalize the change in the action and encourage smooth
                    # actions. L2 regularization |action - last_action|^2
                    action_rate=-0.01,
                    # Encourage long swing steps.  However, it does not
                    # encourage high clearances.
                    feet_air_time=0.2,
                    # Encourage no motion at zero command, L2 regularization
                    # |q - q_default|^2.
                    stand_still=-0.5,
                    # Early termination penalty.
                    termination=-1.0,
                    # Penalizing foot slipping on the ground.
                    foot_slip=-0.1,
                )
            ),
            # Tracking reward = exp(-error^2/sigma).
            tracking_sigma=0.25,
        )
    )
    return default_config

  default_config = config_dict.ConfigDict(
      dict(
          rewards=get_default_rewards_config(),
      )
  )
  return default_config
class BarkourEnv(PipelineEnv):
  """Environment for training the barkour quadruped joystick policy in MJX."""

  def __init__(
      self,
      obs_noise: float = 0.05,
      action_scale: float = 0.3,
      kick_vel: float = 0.05,
      scene_file: str = 'scene_mjx.xml',
      **kwargs,
  ):
    path = BARKOUR_ROOT_PATH / scene_file
    sys = mjcf.load(path.as_posix())
    self._dt = 0.02  # this environment is 50 fps
    sys = sys.tree_replace({'opt.timestep': 0.004})

    # override menagerie params for smoother policy
    sys = sys.replace(
        dof_damping=sys.dof_damping.at[6:].set(0.5239),
        actuator_gainprm=sys.actuator_gainprm.at[:, 0].set(35.0),
        actuator_biasprm=sys.actuator_biasprm.at[:, 1].set(-35.0),
    )

    n_frames = kwargs.pop('n_frames', int(self._dt / sys.opt.timestep))
    super().__init__(sys, backend='mjx', n_frames=n_frames)

    self.reward_config = get_config()
    # set custom from kwargs
    for k, v in kwargs.items():
      if k.endswith('_scale'):
        self.reward_config.rewards.scales[k[:-6]] = v

    self._torso_idx = mujoco.mj_name2id(
        sys.mj_model, mujoco.mjtObj.mjOBJ_BODY.value, 'torso')
    self._action_scale = action_scale
    self._obs_noise = obs_noise
    self._kick_vel = kick_vel
    self._init_q = jp.array(sys.mj_model.keyframe('home').qpos)
    self._default_pose = sys.mj_model.keyframe('home').qpos[7:]
    self.lowers = jp.array([-0.7, -1.0, 0.05] * 4)
    self.uppers = jp.array([0.52, 2.1, 2.1] * 4)
    feet_site = [
        'foot_front_left',
        'foot_hind_left',
        'foot_front_right',
        'foot_hind_right',
    ]
    feet_site_id = [
        mujoco.mj_name2id(sys.mj_model, mujoco.mjtObj.mjOBJ_SITE.value, f)
        for f in feet_site
    ]
    assert not any(id_ == -1 for id_ in feet_site_id), 'Site not found.'
    self._feet_site_id = np.array(feet_site_id)
    lower_leg_body = [
        'lower_leg_front_left',
        'lower_leg_hind_left',
        'lower_leg_front_right',
        'lower_leg_hind_right',
    ]
    lower_leg_body_id = [
        mujoco.mj_name2id(sys.mj_model, mujoco.mjtObj.mjOBJ_BODY.value, l)
        for l in lower_leg_body
    ]
    assert not any(id_ == -1 for id_ in lower_leg_body_id), 'Body not found.'
    self._lower_leg_body_id = np.array(lower_leg_body_id)
    self._foot_radius = 0.0175
    self._nv = sys.nv
  def sample_command(self, rng: jax.Array) -> jax.Array:
    lin_vel_x = [-0.6, 1.5]    # min max [m/s]
    lin_vel_y = [-0.8, 0.8]    # min max [m/s]
    ang_vel_yaw = [-0.7, 0.7]  # min max [rad/s]

    _, key1, key2, key3 = jax.random.split(rng, 4)
    lin_vel_x = jax.random.uniform(
        key1, (1,), minval=lin_vel_x[0], maxval=lin_vel_x[1])
    lin_vel_y = jax.random.uniform(
        key2, (1,), minval=lin_vel_y[0], maxval=lin_vel_y[1])
    ang_vel_yaw = jax.random.uniform(
        key3, (1,), minval=ang_vel_yaw[0], maxval=ang_vel_yaw[1])
    new_cmd = jp.array([lin_vel_x[0], lin_vel_y[0], ang_vel_yaw[0]])
    return new_cmd
  def reset(self, rng: jax.Array) -> State:  # pytype: disable=signature-mismatch
    rng, key = jax.random.split(rng)

    pipeline_state = self.pipeline_init(self._init_q, jp.zeros(self._nv))

    state_info = {
        'rng': rng,
        'last_act': jp.zeros(12),
        'last_vel': jp.zeros(12),
        'command': self.sample_command(key),
        'last_contact': jp.zeros(4, dtype=bool),
        'feet_air_time': jp.zeros(4),
        'rewards': {k: 0.0 for k in self.reward_config.rewards.scales.keys()},
        'kick': jp.array([0.0, 0.0]),
        'step': 0,
    }

    obs_history = jp.zeros(15 * 31)  # store 15 steps of history
    obs = self._get_obs(pipeline_state, state_info, obs_history)
    reward, done = jp.zeros(2)
    metrics = {'total_dist': 0.0}
    for k in state_info['rewards']:
      metrics[k] = state_info['rewards'][k]
    state = State(pipeline_state, obs, reward, done, metrics, state_info)  # pytype: disable=wrong-arg-types
    return state
  def step(self, state: State, action: jax.Array) -> State:  # pytype: disable=signature-mismatch
    rng, cmd_rng, kick_noise_2 = jax.random.split(state.info['rng'], 3)

    # kick
    push_interval = 10
    kick_theta = jax.random.uniform(kick_noise_2, maxval=2 * jp.pi)
    kick = jp.array([jp.cos(kick_theta), jp.sin(kick_theta)])
    kick *= jp.mod(state.info['step'], push_interval) == 0
    qvel = state.pipeline_state.qvel  # pytype: disable=attribute-error
    qvel = qvel.at[:2].set(kick * self._kick_vel + qvel[:2])
    state = state.tree_replace({'pipeline_state.qvel': qvel})

    # physics step
    motor_targets = self._default_pose + action * self._action_scale
    motor_targets = jp.clip(motor_targets, self.lowers, self.uppers)
    pipeline_state = self.pipeline_step(state.pipeline_state, motor_targets)
    x, xd = pipeline_state.x, pipeline_state.xd

    # observation data
    obs = self._get_obs(pipeline_state, state.info, state.obs)
    joint_angles = pipeline_state.q[7:]
    joint_vel = pipeline_state.qd[6:]

    # foot contact data based on z-position
    foot_pos = pipeline_state.site_xpos[self._feet_site_id]  # pytype: disable=attribute-error
    foot_contact_z = foot_pos[:, 2] - self._foot_radius
    contact = foot_contact_z < 1e-3  # a mm or less off the floor
    contact_filt_mm = contact | state.info['last_contact']
    contact_filt_cm = (foot_contact_z < 3e-2) | state.info['last_contact']
    first_contact = (state.info['feet_air_time'] > 0) * contact_filt_mm
    state.info['feet_air_time'] += self.dt

    # done if joint limits are reached or robot is falling
    up = jp.array([0.0, 0.0, 1.0])
    done = jp.dot(math.rotate(up, x.rot[self._torso_idx - 1]), up) < 0
    done |= jp.any(joint_angles < self.lowers)
    done |= jp.any(joint_angles > self.uppers)
    done |= pipeline_state.x.pos[self._torso_idx - 1, 2] < 0.18

    # reward
    rewards = {
        'tracking_lin_vel': (
            self._reward_tracking_lin_vel(state.info['command'], x, xd)
        ),
        'tracking_ang_vel': (
            self._reward_tracking_ang_vel(state.info['command'], x, xd)
        ),
        'lin_vel_z': self._reward_lin_vel_z(xd),
        'ang_vel_xy': self._reward_ang_vel_xy(xd),
        'orientation': self._reward_orientation(x),
        'torques': self._reward_torques(pipeline_state.qfrc_actuator),  # pytype: disable=attribute-error
        'action_rate': self._reward_action_rate(action, state.info['last_act']),
        'stand_still': self._reward_stand_still(
            state.info['command'], joint_angles,
        ),
        'feet_air_time': self._reward_feet_air_time(
            state.info['feet_air_time'],
            first_contact,
            state.info['command'],
        ),
        'foot_slip': self._reward_foot_slip(pipeline_state, contact_filt_cm),
        'termination': self._reward_termination(done, state.info['step']),
    }
    rewards = {
        k: v * self.reward_config.rewards.scales[k] for k, v in rewards.items()
    }
    reward = jp.clip(sum(rewards.values()) * self.dt, 0.0, 10000.0)

    # state management
    state.info['kick'] = kick
    state.info['last_act'] = action
    state.info['last_vel'] = joint_vel
    state.info['feet_air_time'] *= ~contact_filt_mm
    state.info['last_contact'] = contact
    state.info['rewards'] = rewards
    state.info['step'] += 1
    state.info['rng'] = rng

    # sample new command if more than 500 timesteps achieved
    state.info['command'] = jp.where(
        state.info['step'] > 500,
        self.sample_command(cmd_rng),
        state.info['command'],
    )
    # reset the step counter when done
    state.info['step'] = jp.where(
        done | (state.info['step'] > 500), 0, state.info['step']
    )

    # log total displacement as a proxy metric
    state.metrics['total_dist'] = math.normalize(x.pos[self._torso_idx - 1])[1]
    state.metrics.update(state.info['rewards'])

    done = jp.float32(done)
    state = state.replace(
        pipeline_state=pipeline_state, obs=obs, reward=reward, done=done
    )
    return state
  def _get_obs(
      self,
      pipeline_state: base.State,
      state_info: dict[str, Any],
      obs_history: jax.Array,
  ) -> jax.Array:
    inv_torso_rot = math.quat_inv(pipeline_state.x.rot[0])
    local_rpyrate = math.rotate(pipeline_state.xd.ang[0], inv_torso_rot)

    obs = jp.concatenate([
        jp.array([local_rpyrate[2]]) * 0.25,                  # yaw rate
        math.rotate(jp.array([0, 0, -1]), inv_torso_rot),     # projected gravity
        state_info['command'] * jp.array([2.0, 2.0, 0.25]),   # command
        pipeline_state.q[7:] - self._default_pose,            # motor angles
        state_info['last_act'],                               # last action
    ])

    # clip, noise
    obs = jp.clip(obs, -100.0, 100.0) + self._obs_noise * jax.random.uniform(
        state_info['rng'], obs.shape, minval=-1, maxval=1)
    # stack observations through time
    obs = jp.roll(obs_history, obs.size).at[:obs.size].set(obs)

    return obs
  # ------------ reward functions ----------------
  def _reward_lin_vel_z(self, xd: Motion) -> jax.Array:
    # Penalize z axis base linear velocity
    return jp.square(xd.vel[0, 2])

  def _reward_ang_vel_xy(self, xd: Motion) -> jax.Array:
    # Penalize xy axes base angular velocity
    return jp.sum(jp.square(xd.ang[0, :2]))

  def _reward_orientation(self, x: Transform) -> jax.Array:
    # Penalize non flat base orientation
    up = jp.array([0.0, 0.0, 1.0])
    rot_up = math.rotate(up, x.rot[0])
    return jp.sum(jp.square(rot_up[:2]))

  def _reward_torques(self, torques: jax.Array) -> jax.Array:
    # Penalize torques
    return jp.sqrt(jp.sum(jp.square(torques))) + jp.sum(jp.abs(torques))

  def _reward_action_rate(
      self, act: jax.Array, last_act: jax.Array
  ) -> jax.Array:
    # Penalize changes in actions
    return jp.sum(jp.square(act - last_act))

  def _reward_tracking_lin_vel(
      self, commands: jax.Array, x: Transform, xd: Motion
  ) -> jax.Array:
    # Tracking of linear velocity commands (xy axes)
    local_vel = math.rotate(xd.vel[0], math.quat_inv(x.rot[0]))
    lin_vel_error = jp.sum(jp.square(commands[:2] - local_vel[:2]))
    lin_vel_reward = jp.exp(
        -lin_vel_error / self.reward_config.rewards.tracking_sigma
    )
    return lin_vel_reward

  def _reward_tracking_ang_vel(
      self, commands: jax.Array, x: Transform, xd: Motion
  ) -> jax.Array:
    # Tracking of angular velocity commands (yaw)
    base_ang_vel = math.rotate(xd.ang[0], math.quat_inv(x.rot[0]))
    ang_vel_error = jp.square(commands[2] - base_ang_vel[2])
    return jp.exp(-ang_vel_error / self.reward_config.rewards.tracking_sigma)

  def _reward_feet_air_time(
      self, air_time: jax.Array, first_contact: jax.Array, commands: jax.Array
  ) -> jax.Array:
    # Reward air time.
    rew_air_time = jp.sum((air_time - 0.1) * first_contact)
    rew_air_time *= (
        math.normalize(commands[:2])[1] > 0.05  # no reward for zero command
    )
    return rew_air_time

  def _reward_stand_still(
      self,
      commands: jax.Array,
      joint_angles: jax.Array,
  ) -> jax.Array:
    # Penalize motion at zero commands
    return jp.sum(jp.abs(joint_angles - self._default_pose)) * (
        math.normalize(commands[:2])[1] < 0.1
    )

  def _reward_foot_slip(
      self, pipeline_state: base.State, contact_filt: jax.Array
  ) -> jax.Array:
    # get velocities at feet which are offset from lower legs
    # pytype: disable=attribute-error
    pos = pipeline_state.site_xpos[self._feet_site_id]  # feet position
    feet_offset = pos - pipeline_state.xpos[self._lower_leg_body_id]
    # pytype: enable=attribute-error
    offset = base.Transform.create(pos=feet_offset)
    foot_indices = self._lower_leg_body_id - 1  # we got rid of the world body
    foot_vel = offset.vmap().do(pipeline_state.xd.take(foot_indices)).vel

    # Penalize large feet velocity for feet that are in contact with the ground.
    return jp.sum(jp.square(foot_vel[:, :2]) * contact_filt.reshape((-1, 1)))

  def _reward_termination(self, done: jax.Array, step: jax.Array) -> jax.Array:
    return done & (step < 500)

  def render(
      self, trajectory: List[base.State], camera: str | None = None,
      width: int = 240, height: int = 320,
  ) -> Sequence[np.ndarray]:
    camera = camera or 'track'
    return super().render(trajectory, camera=camera, width=width, height=height)

envs.register_environment('barkour', BarkourEnv)

Initialize the environment:

env_name = 'barkour'
env = envs.get_environment(env_name)

To train a policy with domain randomization, pass the domain randomization function into the brax training function; brax will invoke it while rolling out the simulation.

ckpt_path = epath.Path('/tmp/quadrupred_joystick/ckpts')
ckpt_path.mkdir(parents=True, exist_ok=True)

def policy_params_fn(current_step, make_policy, params):
  # save checkpoints with orbax
  orbax_checkpointer = ocp.PyTreeCheckpointer()
  save_args = orbax_utils.save_args_from_target(params)
  path = ckpt_path / f'{current_step}'
  orbax_checkpointer.save(path, params, force=True, save_args=save_args)
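For completeness, a checkpoint written this way can be read back with orbax as sketched below; the step number is hypothetical, so pick one that actually exists under ckpt_path after training:

restored_params = ocp.PyTreeCheckpointer().restore(
    (ckpt_path / '20000000').as_posix())   # '20000000' is a hypothetical step number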

Define the training function:

make_networks_factory = functools.partial(
    ppo_networks.make_ppo_networks,
    policy_hidden_layer_sizes=(128, 128, 128, 128))

train_fn = functools.partial(
    ppo.train, num_timesteps=100_000_000, num_evals=10,
    reward_scaling=1, episode_length=1000, normalize_observations=True,
    action_repeat=1, unroll_length=20, num_minibatches=32,
    num_updates_per_batch=4, discounting=0.97, learning_rate=3.0e-4,
    entropy_cost=1e-2, num_envs=8192, batch_size=256,
    network_factory=make_networks_factory,
    randomization_fn=domain_randomize,
    policy_params_fn=policy_params_fn,
    seed=0)

5.4 Run the training

The official notebook takes about 6 minutes to train the quadruped on a Tesla A100 GPU; on my single 3060 Super it took about 14 minutes.

x_data = []
y_data = []
ydataerr = []
times = [datetime.now()]
max_y, min_y = 40, 0

env = envs.get_environment(env_name)
eval_env = envs.get_environment(env_name)

start_time = time.time()
make_inference_fn, params, _ = train_fn(environment=env,
                                        progress_fn=progress,
                                        eval_env=eval_env)
print(f'Total cost time {time.time() - start_time:.2f}')
print(f'time to jit: {times[1] - times[0]}')
print(f'time to train: {times[-1] - times[1]}')

5.5 Save and load the model

Save the model:

model_path = './mjx_brax_quadruped_policy'
model.save_params(model_path, params)

Load the model and define the inference function:

params = model.load_params(model_path)
inference_fn = make_inference_fn(params)
jit_inference_fn = jax.jit(inference_fn)

5.6 Visualize the results

For the Barkour quadruped, the joystick command is given by x_vel, y_vel, and ang_vel. x_vel and y_vel define the forward and sideways linear velocity relative to the quadruped's torso, and ang_vel defines the angular velocity of the torso about the z-axis.

eval_env = envs.get_environment(env_name)
jit_reset = jax.jit(eval_env.reset)
jit_step = jax.jit(eval_env.step)

Try changing the three variables below to see different rollout behaviors:

x_vel = 1.0
y_vel = 0.0
ang_vel = -0.5

Run inference with the chosen command:

the_command = jp.array([x_vel, y_vel, ang_vel])

# initialize the state
rng = jax.random.PRNGKey(0)
state = jit_reset(rng)
state.info['command'] = the_command
rollout = [state.pipeline_state]

n_steps = 500
render_every = 2

start_time = time.time()
for i in range(n_steps):
  act_rng, rng = jax.random.split(rng)
  ctrl, _ = jit_inference_fn(state.obs, act_rng)
  state = jit_step(state, ctrl)
  rollout.append(state.pipeline_state)
print(f"Total cost {time.time() - start_time:.2f} seconds")

media.show_video(
    eval_env.render(rollout[::render_every], camera='track'),
    fps=1.0 / eval_env.dt / render_every)


6. Fine-tuning the policy on a height field

Sometimes you also want the quadruped to learn to walk over rough terrain. Take the latest checkpoint from the joystick policy above and fine-tune it on a height-field terrain.

6.1 Generate an irregular height field

scene_file = 'scene_hfield_mjx.xml'
env = envs.get_environment(env_name, scene_file=scene_file)

jit_reset = jax.jit(env.reset)
state = jit_reset(jax.random.PRNGKey(0))

plt.imshow(env.render([state.pipeline_state], camera='track')[0])

6.2 Define the training function

latest_ckpts = list(ckpt_path.glob('*'))
latest_ckpts.sort(key=lambda x: int(x.as_posix().split('/')[-1]))
latest_ckpt = latest_ckpts[-1]

train_fn = functools.partial(
    ppo.train, num_timesteps=40_000_000, num_evals=5,
    reward_scaling=1, episode_length=1000, normalize_observations=True,
    action_repeat=1, unroll_length=20, num_minibatches=32,
    num_updates_per_batch=4, discounting=0.97, learning_rate=3.0e-4,
    entropy_cost=1e-2, num_envs=8192, batch_size=256,
    network_factory=make_networks_factory,
    randomization_fn=domain_randomize, seed=0,
    restore_checkpoint_path=latest_ckpt)

6.3 Fine-tune the model

This took about 9 minutes on my single 3060 Super.

x_data = []
y_data = []
ydataerr = []
times = [datetime.now()]
max_y, min_y = 40, 0

env = envs.get_environment(env_name, scene_file=scene_file)
eval_env = envs.get_environment(env_name, scene_file=scene_file)

start_time = time.time()
make_inference_fn, params, _ = train_fn(environment=env,
                                        progress_fn=progress,
                                        eval_env=eval_env)
print(f'Total cost {time.time() - start_time:.2f} seconds')
print(f'time to jit: {times[1] - times[0]}')
print(f'time to train: {times[-1] - times[1]}')

6.4 Visualize the fine-tuned results

eval_env = envs.get_environment(env_name, scene_file=scene_file)
jit_reset = jax.jit(eval_env.reset)
jit_step = jax.jit(eval_env.step)
inference_fn = make_inference_fn(params)
jit_inference_fn = jax.jit(inference_fn)

As before, you can change these three variables and inspect the resulting behavior:

x_vel = 1.0
y_vel = 0.0
ang_vel = -0.5

Run inference:

the_command = jp.array([x_vel, y_vel, ang_vel])

# initialize the state
rng = jax.random.PRNGKey(0)
state = jit_reset(rng)
state.info['command'] = the_command
rollout = [state.pipeline_state]

n_steps = 500
render_every = 2

start_time = time.time()
for i in range(n_steps):
  act_rng, rng = jax.random.split(rng)
  ctrl, _ = jit_inference_fn(state.obs, act_rng)
  state = jit_step(state, ctrl)
  rollout.append(state.pipeline_state)
print(f'Total cost {time.time() - start_time:.2f} seconds')

media.show_video(
    eval_env.render(rollout[::render_every], camera='track'),
    fps=1.0 / eval_env.dt / render_every)
