This post covers the fifth entry in the official MuJoCo tutorial series, "The MJX tutorial provides usage examples of MuJoCo XLA, a branch of MuJoCo written in JAX." I am skipping the least-squares tutorial here because it is rather application-specific; I plan to cover it two posts from now. This post also follows naturally from the series' first side article: with JAX freshly set up, it is a good moment to look at the difference between using MJX and running purely on the CPU.
This post touches on some reinforcement learning material. If parts of it are hard to follow, feel free to skip the details for now; the main point is how to use the MJX library. After the official tutorials are done I will start writing about reinforcement learning with MuJoCo, and will then cover the brax and jax libraries as well as MuJoCo's RL conveniences in more depth.
【Note】: Before starting this chapter, make sure that both jax and jaxlib are correctly installed in your environment, and that running the following produces no errors:

import jax
print(jax.devices())
# [cuda(id=0)]

If, in addition to the output above, you see a message like the one below, the versions do not match and JAX cannot use the GPU. See the earlier side post on reinstalling the matching packages:

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
The official resources for this post:
- Official GitHub repository:
- Official Colab link:

The official notebooks and my own notes are available at the link below. Every file whose name starts with [offical] is an official notebook, and every file starting with [note] is a note accompanying this blog:

链接: 提取码: 83a4
1. Import the required packages
import os
import time
import functools
import itertools
import subprocess
import distutils.util
from datetime import datetime
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union

import numpy as np
import mediapy as media
import matplotlib.pyplot as plt
from etils import epath
from IPython.display import HTML
from ml_collections import config_dict

import jax
from jax import numpy as jp
from flax.training import orbax_utils
from flax import struct
from orbax import checkpoint as ocp

import mujoco
from mujoco import mjx

# Enable Triton GEMM for better MJX performance on GPU.
xla_flags = os.environ.get('XLA_FLAGS', '')
xla_flags += ' --xla_gpu_triton_gemm_any=True'
os.environ['XLA_FLAGS'] = xla_flags

from brax import base
from brax import envs
from brax import math
from brax.base import Base, Motion, Transform
from brax.base import State as PipelineState
from brax.envs.base import Env, PipelineEnv, State
from brax.mjx.base import State as MjxState
from brax.training.agents.ppo import networks as ppo_networks
from brax.training.agents.ppo import train as ppo  # needed for ppo.train below
from brax.io import html, mjcf, model
2. First use and verification
2.1 Build the model and prepare CPU and GPU objects
Build a simple model:
xml = """
<mujoco>
  <worldbody>
    <light name="top" pos="0 0 1"/>
    <body name="box_and_sphere" euler="0 0 -30">
      <joint name="swing" type="hinge" axis="1 -1 0" pos="-.2 -.2 -.2"/>
      <geom name="red_box" type="box" size=".2 .2 .2" rgba="1 0 0 1"/>
      <geom name="green_sphere" pos=".2 .2 .2" size=".1" rgba="0 1 0 1"/>
    </body>
  </worldbody>
</mujoco>
"""

Prepare a model and data object in the CPU environment:

mj_model = mujoco.MjModel.from_xml_string(xml)
mj_data = mujoco.MjData(mj_model)
renderer = mujoco.Renderer(mj_model)

Then prepare a JAX model and data object, which simply moves the data onto the GPU:

mjx_model = mjx.put_model(mj_model)
mjx_data = mjx.put_data(mj_model, mj_data)
Print both qpos arrays. mjData's qpos is a numpy array resident on the CPU, while mjx.Data's qpos is a JAX array resident on the GPU device.

print(mj_data.qpos, type(mj_data.qpos))
print(mjx_data.qpos, type(mjx_data.qpos))

【Note】: The second line's output must show jaxlib.xla_extension.ArrayImpl.
2.2 Running the CPU example
scene_option = mujoco.MjvOption()
scene_option.flags[mujoco.mjtVisFlag.mjVIS_JOINT] = True

duration = 3.8   # seconds
framerate = 60   # Hz
frames = []

mujoco.mj_resetData(mj_model, mj_data)
start_time = time.time()
while mj_data.time < duration:
    mujoco.mj_step(mj_model, mj_data)
    if len(frames) < mj_data.time * framerate:
        renderer.update_scene(mj_data, scene_option=scene_option)
        pixels = renderer.render()
        frames.append(pixels)
print(f"Total cost {time.time()-start_time:.2f} seconds")

media.show_video(frames, fps=framerate)
2.3 Running the GPU example
Now run exactly the same simulation on the GPU with MJX, replacing mujoco.mj_step with mjx.step and wrapping mjx.step in jax.jit so that it runs efficiently on the GPU. For each rendered frame, the mjx.Data is converted back to mjData so that the MuJoCo renderer can be used.

jit_step = jax.jit(mjx.step)

frames = []
mujoco.mj_resetData(mj_model, mj_data)
mjx_data = mjx.put_data(mj_model, mj_data)
start_time = time.time()
while mjx_data.time < duration:
    mjx_data = jit_step(mjx_model, mjx_data)
    if len(frames) < mjx_data.time * framerate:
        mj_data = mjx.get_data(mj_model, mjx_data)
        renderer.update_scene(mj_data, scene_option=scene_option)
        pixels = renderer.render()
        frames.append(pixels)
print(f"Total cost {time.time()-start_time:.2f} seconds")

media.show_video(frames, fps=framerate)

You will notice that running a single-threaded physics simulation on the GPU is not efficient. The advantage of MJX is that it can run environments in parallel on hardware accelerators.
The example below creates 4096 copies of mjx.Data and runs mjx.step on the batched data. Because MJX is implemented in JAX, jax.vmap can run mjx.step over all of the mjx.Data copies in parallel.
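To build intuition for what `jax.vmap(mjx.step, in_axes=(None, 0))` does, here is a minimal pure-Python sketch (a toy `step`, not MJX itself): the model argument is shared across the whole batch, while the data argument is mapped over its leading axis.

```python
def step(model, data):
    # toy "dynamics": advance a scalar state by a model-defined increment
    return data + model

def vmapped_step(model, batch):
    # what jax.vmap(step, in_axes=(None, 0)) does conceptually:
    # model is broadcast (None axis), batch is mapped over axis 0
    return [step(model, d) for d in batch]

print(vmapped_step(1, [0, 1, 2]))  # → [1, 2, 3]
```

In real MJX the loop is not a Python loop: vmap traces `step` once and executes it over the whole batch on the accelerator.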
rng = jax.random.PRNGKey(0)
rng = jax.random.split(rng, 4096)
batch = jax.vmap(lambda rng: mjx_data.replace(qpos=jax.random.uniform(rng, (1,))))(rng)

jit_step = jax.jit(jax.vmap(mjx.step, in_axes=(None, 0)))
batch = jit_step(mjx_model, batch)
print(batch.qpos)

batched_mj_data = mjx.get_data(mj_model, batch)
print([d.qpos for d in batched_mj_data])
3. Training a reinforcement learning policy with MJX
Running physics simulations in large batches is very useful for training reinforcement learning policies. This section shows how to train an RL policy with MJX using Brax's reinforcement learning library.

The classic Humanoid environment is implemented with MJX and Brax by subclassing Brax's PipelineEnv with the MJX backend, so that MJX steps the physics while training uses Brax's RL implementations.
3.1 Define the environment and pull in resources
Because some assets from the mujoco repository are needed, clone the official repository from GitHub into the current directory:
(mujoco) $ git clone git@github.com:google-deepmind/mujoco.git
Then define the Humanoid environment:
HUMANOID_ROOT_PATH = epath.Path(epath.resource_path('mujoco')) / 'mjx/test_data/humanoid'


class Humanoid(PipelineEnv):

    def __init__(
        self,
        forward_reward_weight=1.25,
        ctrl_cost_weight=0.1,
        healthy_reward=5.0,
        terminate_when_unhealthy=True,
        healthy_z_range=(1.0, 2.0),
        reset_noise_scale=1e-2,
        exclude_current_positions_from_observation=True,
        **kwargs,
    ):
        mj_model = mujoco.MjModel.from_xml_path(
            (HUMANOID_ROOT_PATH / 'humanoid.xml').as_posix())
        mj_model.opt.solver = mujoco.mjtSolver.mjSOL_CG
        mj_model.opt.iterations = 6
        mj_model.opt.ls_iterations = 6

        sys = mjcf.load_model(mj_model)

        physics_steps_per_control_step = 5
        kwargs['n_frames'] = kwargs.get('n_frames', physics_steps_per_control_step)
        kwargs['backend'] = 'mjx'

        super().__init__(sys, **kwargs)

        self._forward_reward_weight = forward_reward_weight
        self._ctrl_cost_weight = ctrl_cost_weight
        self._healthy_reward = healthy_reward
        self._terminate_when_unhealthy = terminate_when_unhealthy
        self._healthy_z_range = healthy_z_range
        self._reset_noise_scale = reset_noise_scale
        self._exclude_current_positions_from_observation = (
            exclude_current_positions_from_observation
        )

    def reset(self, rng: jp.ndarray) -> State:
        """Resets the environment to an initial state."""
        rng, rng1, rng2 = jax.random.split(rng, 3)

        low, hi = -self._reset_noise_scale, self._reset_noise_scale
        qpos = self.sys.qpos0 + jax.random.uniform(
            rng1, (self.sys.nq,), minval=low, maxval=hi
        )
        qvel = jax.random.uniform(
            rng2, (self.sys.nv,), minval=low, maxval=hi
        )

        data = self.pipeline_init(qpos, qvel)

        obs = self._get_obs(data, jp.zeros(self.sys.nu))
        reward, done, zero = jp.zeros(3)
        metrics = {
            'forward_reward': zero,
            'reward_linvel': zero,
            'reward_quadctrl': zero,
            'reward_alive': zero,
            'x_position': zero,
            'y_position': zero,
            'distance_from_origin': zero,
            'x_velocity': zero,
            'y_velocity': zero,
        }
        return State(data, obs, reward, done, metrics)

    def step(self, state: State, action: jp.ndarray) -> State:
        """Runs one timestep of the environment's dynamics."""
        data0 = state.pipeline_state
        data = self.pipeline_step(data0, action)

        com_before = data0.subtree_com[1]
        com_after = data.subtree_com[1]
        velocity = (com_after - com_before) / self.dt
        forward_reward = self._forward_reward_weight * velocity[0]

        min_z, max_z = self._healthy_z_range
        is_healthy = jp.where(data.q[2] < min_z, 0.0, 1.0)
        is_healthy = jp.where(data.q[2] > max_z, 0.0, is_healthy)
        if self._terminate_when_unhealthy:
            healthy_reward = self._healthy_reward
        else:
            healthy_reward = self._healthy_reward * is_healthy

        ctrl_cost = self._ctrl_cost_weight * jp.sum(jp.square(action))

        obs = self._get_obs(data, action)
        reward = forward_reward + healthy_reward - ctrl_cost
        done = 1.0 - is_healthy if self._terminate_when_unhealthy else 0.0
        state.metrics.update(
            forward_reward=forward_reward,
            reward_linvel=forward_reward,
            reward_quadctrl=-ctrl_cost,
            reward_alive=healthy_reward,
            x_position=com_after[0],
            y_position=com_after[1],
            distance_from_origin=jp.linalg.norm(com_after),
            x_velocity=velocity[0],
            y_velocity=velocity[1],
        )

        return state.replace(
            pipeline_state=data, obs=obs, reward=reward, done=done
        )

    def _get_obs(
        self, data: mjx.Data, action: jp.ndarray
    ) -> jp.ndarray:
        """Observes humanoid body position, velocities, and angles."""
        position = data.qpos
        if self._exclude_current_positions_from_observation:
            position = position[2:]

        return jp.concatenate([
            position,
            data.qvel,
            data.cinert[1:].ravel(),
            data.cvel[1:].ravel(),
            data.qfrc_actuator,
        ])


envs.register_environment('humanoid', Humanoid)
The step method above computes the is_healthy flag: if during simulation the root height data.q[2] leaves the healthy z-range, the episode is terminated.
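The same health check can be sketched in plain Python (a hypothetical helper mirroring the pair of jp.where calls above, with the default healthy_z_range=(1.0, 2.0)):

```python
def is_healthy(z, healthy_z_range=(1.0, 2.0)):
    # 0.0 outside the healthy band, 1.0 inside, like the jp.where pair above
    min_z, max_z = healthy_z_range
    return 1.0 if min_z <= z <= max_z else 0.0

print(is_healthy(1.4), is_healthy(0.5), is_healthy(2.3))  # → 1.0 0.0 0.0
```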
3.2 Preview the environment
env_name = 'humanoid'
env = envs.get_environment(env_name)

jit_reset = jax.jit(env.reset)
jit_step = jax.jit(env.step)

state = jit_reset(jax.random.PRNGKey(0))
rollout = [state.pipeline_state]

for i in range(100):
    ctrl = -0.1 * jp.ones(env.sys.nu)
    state = jit_step(state, ctrl)
    rollout.append(state.pipeline_state)

media.show_video(env.render(rollout, camera='side'), fps=1.0 / env.dt)
3.3 Training a PPO policy
The official run of the code below took 6 minutes on an A100; on my single 3060 Super it also took about 6 minutes. You can adjust the parameters in train_fn to match your available compute:
train_fn = functools.partial(
    ppo.train,
    num_timesteps=20_000_000,
    num_evals=5, reward_scaling=0.1,
    episode_length=1000, normalize_observations=True, action_repeat=1,
    unroll_length=10, num_minibatches=24, num_updates_per_batch=8,
    discounting=0.97, learning_rate=3e-4, entropy_cost=1e-3, num_envs=3072,
    batch_size=512, seed=0)

x_data = []
y_data = []
ydataerr = []
times = [datetime.now()]
max_y, min_y = 13000, 0

def progress(num_steps, metrics):
    times.append(datetime.now())
    x_data.append(num_steps)
    y_data.append(metrics['eval/episode_reward'])
    ydataerr.append(metrics['eval/episode_reward_std'])

    plt.xlim([0, train_fn.keywords['num_timesteps'] * 1.25])
    plt.ylim([min_y, max_y])
    plt.xlabel('# environment steps')
    plt.ylabel('reward per episode')
    plt.title(f'y={y_data[-1]:.3f}')
    plt.errorbar(x_data, y_data, yerr=ydataerr)
    plt.show()

start_time = time.time()
make_inference_fn, params, _ = train_fn(environment=env, progress_fn=progress)
print(f'Total cost time {time.time()-start_time:.2f} seconds')
print(f'time to jit: {times[1] - times[0]}')
print(f'time to train: {times[-1] - times[1]}')

Use the brax API to save and load the trained policy:
3.4 Saving the model
model_path = './mjx_brax_policy'
model.save_params(model_path, params)
3.5 Loading the model and defining the inference function
params = model.load_params(model_path)
inference_fn = make_inference_fn(params)
jit_inference_fn = jax.jit(inference_fn)
3.6 Recording a video with the policy
You will notice that the inference workflow with JAX differs considerably from the earlier workflow that used MuJoCo directly.
rng = jax.random.PRNGKey(0)
state = jit_reset(rng)
rollout = [state.pipeline_state]

n_step = 500
render_every = 2

start_time = time.time()
for i in range(n_step):
    act_rng, rng = jax.random.split(rng)
    ctrl, _ = jit_inference_fn(state.obs, act_rng)
    state = jit_step(state, ctrl)
    rollout.append(state.pipeline_state)
    if state.done:
        break
print(f"Average evaluate step cost time: {(time.time()-start_time)/n_step:.5f}")

media.show_video(env.render(rollout[::render_every], camera='side'), fps=1.0 / env.dt / render_every)

Each inference step takes about 0.007 s on average.
4. Running the MJX policy in MuJoCo
The physics can also be stepped with the native MuJoCo bindings, to demonstrate that a policy trained in MJX really does work in MuJoCo.
eval_env = envs.get_environment(env_name)  # evaluation environment used below

mj_model = eval_env.sys.mj_model
mj_data = mujoco.MjData(mj_model)
renderer = mujoco.Renderer(mj_model)
ctrl = jp.zeros(mj_model.nu)

images = []
start_time = time.time()
for i in range(n_step):
    act_rng, rng = jax.random.split(rng)
    obs = eval_env._get_obs(mjx.put_data(mj_model, mj_data), ctrl)
    ctrl, _ = jit_inference_fn(obs, act_rng)
    mj_data.ctrl = ctrl
    for _ in range(eval_env._n_frames):
        mujoco.mj_step(mj_model, mj_data)
    if i % render_every == 0:
        renderer.update_scene(mj_data, camera='side')
        images.append(renderer.render())
print(f"Average evaluate step cost time: {(time.time()-start_time)/n_step:.5f}")

media.show_video(images, fps=1.0 / eval_env.dt / render_every)
5. Training a policy with domain randomization
Sometimes you may also want to randomize certain mjModel parameters while training a policy. With MJX it is easy to create a batch of environments with random values filled into mjx.Model. The code below shows a function that randomizes friction and the actuator gain/bias.
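As a standalone illustration of the friction randomization used below (a hedged sketch in pure Python rather than jax.random), each environment draws its sliding friction uniformly from [0.6, 1.4]:

```python
import random

def sample_friction(rng, minval=0.6, maxval=1.4):
    # mirrors jax.random.uniform(key, (1,), minval=0.6, maxval=1.4)
    return rng.uniform(minval, maxval)

rng = random.Random(0)
samples = [sample_friction(rng) for _ in range(4)]
print(all(0.6 <= f <= 1.4 for f in samples))  # → True
```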
5.1 Pull the model asset repository
This part needs some assets from mujoco_menagerie; clone it from GitHub into the same directory:
(mujoco) $ git clone git@github.com:google-deepmind/mujoco_menagerie.git
5.2 Define the domain randomization function
def domain_randomize(sys, rng):
    """Randomizes the mjx.Model."""
    @jax.vmap
    def rand(rng):
        _, key = jax.random.split(rng, 2)
        # friction
        friction = jax.random.uniform(key, (1,), minval=0.6, maxval=1.4)
        friction = sys.geom_friction.at[:, 0].set(friction)
        # actuator
        _, key = jax.random.split(key, 2)
        gain_range = (-5, 5)
        param = jax.random.uniform(
            key, (1,), minval=gain_range[0], maxval=gain_range[1]
        ) + sys.actuator_gainprm[:, 0]
        gain = sys.actuator_gainprm.at[:, 0].set(param)
        bias = sys.actuator_biasprm.at[:, 1].set(-param)
        return friction, gain, bias

    friction, gain, bias = rand(rng)

    in_axes = jax.tree_util.tree_map(lambda x: None, sys)
    in_axes = in_axes.tree_replace({
        'geom_friction': 0,
        'actuator_gainprm': 0,
        'actuator_biasprm': 0,
    })

    sys = sys.tree_replace({
        'geom_friction': friction,
        'actuator_gainprm': gain,
        'actuator_biasprm': bias,
    })

    return sys, in_axes
If you want 10 environments with random friction and actuator parameters, call domain_randomize; it returns a batched mjx.Model together with a dict specifying the batched axes.

rng = jax.random.PRNGKey(0)
rng = jax.random.split(rng, 10)
batched_sys, _ = domain_randomize(env.sys, rng)

print('Single env friction shape: ', env.sys.geom_friction.shape)
print('Batched env friction shape: ', batched_sys.geom_friction.shape)
print('Friction on geom 0: ', env.sys.geom_friction[0, 0])
print('Random frictions on geom 0: ', batched_sys.geom_friction[:, 0, 0])
5.3 Define helper functions and classes
If you cannot fully follow this part as a beginner, don't worry; the main goal is to learn the overall MJX training workflow.
BARKOUR_ROOT_PATH = epath.Path('mujoco_menagerie/google_barkour_vb')


def get_config():
    """Returns reward config for barkour quadruped environment."""

    def get_default_rewards_config():
        default_config = config_dict.ConfigDict(
            dict(
                # The coefficients for all reward terms used for training. All
                # physical quantities are in SI units, if not otherwise specified,
                # i.e. joint positions are in rad, positions are measured in meters,
                # torques in Nm, and time in seconds, and forces in Newtons.
                scales=config_dict.ConfigDict(
                    dict(
                        # Tracking rewards are computed using exp(-delta^2/sigma)
                        # sigma can be a hyperparameter to tune.
                        # Track the base x-y velocity (no z-velocity tracking.)
                        tracking_lin_vel=1.5,
                        # Track the angular velocity along z-axis, i.e. yaw rate.
                        tracking_ang_vel=0.8,
                        # Below are regularization terms, we roughly divide the
                        # terms to base state regularizations, joint
                        # regularizations, and other behavior regularizations.
                        # Penalize the base velocity in z direction, L2 penalty.
                        lin_vel_z=-2.0,
                        # Penalize the base roll and pitch rate. L2 penalty.
                        ang_vel_xy=-0.05,
                        # Penalize non-zero roll and pitch angles. L2 penalty.
                        orientation=-5.0,
                        # L2 regularization of joint torques, |tau|^2.
                        torques=-0.0002,
                        # Penalize the change in the action and encourage smooth
                        # actions. L2 regularization |action - last_action|^2
                        action_rate=-0.01,
                        # Encourage long swing steps. However, it does not
                        # encourage high clearances.
                        feet_air_time=0.2,
                        # Encourage no motion at zero command, L2 regularization
                        # |q - q_default|^2.
                        stand_still=-0.5,
                        # Early termination penalty.
                        termination=-1.0,
                        # Penalizing foot slipping on the ground.
                        foot_slip=-0.1,
                    )
                ),
                # Tracking reward = exp(-error^2/sigma).
                tracking_sigma=0.25,
            )
        )
        return default_config

    default_config = config_dict.ConfigDict(
        dict(rewards=get_default_rewards_config(),)
    )
    return default_config


class BarkourEnv(PipelineEnv):
    """Environment for training the barkour quadruped joystick policy in MJX."""

    def __init__(
        self,
        obs_noise: float = 0.05,
        action_scale: float = 0.3,
        kick_vel: float = 0.05,
        scene_file: str = 'scene_mjx.xml',
        **kwargs,
    ):
        path = BARKOUR_ROOT_PATH / scene_file
        sys = mjcf.load(path.as_posix())
        self._dt = 0.02  # this environment is 50 fps
        sys = sys.tree_replace({'opt.timestep': 0.004})

        # override menagerie params for smoother policy
        sys = sys.replace(
            dof_damping=sys.dof_damping.at[6:].set(0.5239),
            actuator_gainprm=sys.actuator_gainprm.at[:, 0].set(35.0),
            actuator_biasprm=sys.actuator_biasprm.at[:, 1].set(-35.0),
        )

        n_frames = kwargs.pop('n_frames', int(self._dt / sys.opt.timestep))
        super().__init__(sys, backend='mjx', n_frames=n_frames)

        self.reward_config = get_config()
        # set custom from kwargs
        for k, v in kwargs.items():
            if k.endswith('_scale'):
                self.reward_config.rewards.scales[k[:-6]] = v

        self._torso_idx = mujoco.mj_name2id(
            sys.mj_model, mujoco.mjtObj.mjOBJ_BODY.value, 'torso')
        self._action_scale = action_scale
        self._obs_noise = obs_noise
        self._kick_vel = kick_vel
        self._init_q = jp.array(sys.mj_model.keyframe('home').qpos)
        self._default_pose = sys.mj_model.keyframe('home').qpos[7:]
        self.lowers = jp.array([-0.7, -1.0, 0.05] * 4)
        self.uppers = jp.array([0.52, 2.1, 2.1] * 4)
        feet_site = [
            'foot_front_left',
            'foot_hind_left',
            'foot_front_right',
            'foot_hind_right',
        ]
        feet_site_id = [
            mujoco.mj_name2id(sys.mj_model, mujoco.mjtObj.mjOBJ_SITE.value, f)
            for f in feet_site
        ]
        assert not any(id_ == -1 for id_ in feet_site_id), 'Site not found.'
        self._feet_site_id = np.array(feet_site_id)
        lower_leg_body = [
            'lower_leg_front_left',
            'lower_leg_hind_left',
            'lower_leg_front_right',
            'lower_leg_hind_right',
        ]
        lower_leg_body_id = [
            mujoco.mj_name2id(sys.mj_model, mujoco.mjtObj.mjOBJ_BODY.value, l)
            for l in lower_leg_body
        ]
        assert not any(id_ == -1 for id_ in lower_leg_body_id), 'Body not found.'
        self._lower_leg_body_id = np.array(lower_leg_body_id)
        self._foot_radius = 0.0175
        self._nv = sys.nv

    def sample_command(self, rng: jax.Array) -> jax.Array:
        lin_vel_x = [-0.6, 1.5]    # min max [m/s]
        lin_vel_y = [-0.8, 0.8]    # min max [m/s]
        ang_vel_yaw = [-0.7, 0.7]  # min max [rad/s]

        _, key1, key2, key3 = jax.random.split(rng, 4)
        lin_vel_x = jax.random.uniform(
            key1, (1,), minval=lin_vel_x[0], maxval=lin_vel_x[1])
        lin_vel_y = jax.random.uniform(
            key2, (1,), minval=lin_vel_y[0], maxval=lin_vel_y[1])
        ang_vel_yaw = jax.random.uniform(
            key3, (1,), minval=ang_vel_yaw[0], maxval=ang_vel_yaw[1])
        new_cmd = jp.array([lin_vel_x[0], lin_vel_y[0], ang_vel_yaw[0]])
        return new_cmd

    def reset(self, rng: jax.Array) -> State:  # pytype: disable=signature-mismatch
        rng, key = jax.random.split(rng)

        pipeline_state = self.pipeline_init(self._init_q, jp.zeros(self._nv))

        state_info = {
            'rng': rng,
            'last_act': jp.zeros(12),
            'last_vel': jp.zeros(12),
            'command': self.sample_command(key),
            'last_contact': jp.zeros(4, dtype=bool),
            'feet_air_time': jp.zeros(4),
            'rewards': {k: 0.0 for k in self.reward_config.rewards.scales.keys()},
            'kick': jp.array([0.0, 0.0]),
            'step': 0,
        }

        obs_history = jp.zeros(15 * 31)  # store 15 steps of history
        obs = self._get_obs(pipeline_state, state_info, obs_history)
        reward, done = jp.zeros(2)
        metrics = {'total_dist': 0.0}
        for k in state_info['rewards']:
            metrics[k] = state_info['rewards'][k]
        state = State(pipeline_state, obs, reward, done, metrics, state_info)  # pytype: disable=wrong-arg-types
        return state

    def step(self, state: State, action: jax.Array) -> State:  # pytype: disable=signature-mismatch
        rng, cmd_rng, kick_noise_2 = jax.random.split(state.info['rng'], 3)

        # kick
        push_interval = 10
        kick_theta = jax.random.uniform(kick_noise_2, maxval=2 * jp.pi)
        kick = jp.array([jp.cos(kick_theta), jp.sin(kick_theta)])
        kick *= jp.mod(state.info['step'], push_interval) == 0
        qvel = state.pipeline_state.qvel  # pytype: disable=attribute-error
        qvel = qvel.at[:2].set(kick * self._kick_vel + qvel[:2])
        state = state.tree_replace({'pipeline_state.qvel': qvel})

        # physics step
        motor_targets = self._default_pose + action * self._action_scale
        motor_targets = jp.clip(motor_targets, self.lowers, self.uppers)
        pipeline_state = self.pipeline_step(state.pipeline_state, motor_targets)
        x, xd = pipeline_state.x, pipeline_state.xd

        # observation data
        obs = self._get_obs(pipeline_state, state.info, state.obs)
        joint_angles = pipeline_state.q[7:]
        joint_vel = pipeline_state.qd[6:]

        # foot contact data based on z-position
        foot_pos = pipeline_state.site_xpos[self._feet_site_id]  # pytype: disable=attribute-error
        foot_contact_z = foot_pos[:, 2] - self._foot_radius
        contact = foot_contact_z < 1e-3  # a mm or less off the floor
        contact_filt_mm = contact | state.info['last_contact']
        contact_filt_cm = (foot_contact_z < 3e-2) | state.info['last_contact']
        first_contact = (state.info['feet_air_time'] > 0) * contact_filt_mm
        state.info['feet_air_time'] += self.dt

        # done if joint limits are reached or robot is falling
        up = jp.array([0.0, 0.0, 1.0])
        done = jp.dot(math.rotate(up, x.rot[self._torso_idx - 1]), up) < 0
        done |= jp.any(joint_angles < self.lowers)
        done |= jp.any(joint_angles > self.uppers)
        done |= pipeline_state.x.pos[self._torso_idx - 1, 2] < 0.18

        # reward
        rewards = {
            'tracking_lin_vel': (
                self._reward_tracking_lin_vel(state.info['command'], x, xd)
            ),
            'tracking_ang_vel': (
                self._reward_tracking_ang_vel(state.info['command'], x, xd)
            ),
            'lin_vel_z': self._reward_lin_vel_z(xd),
            'ang_vel_xy': self._reward_ang_vel_xy(xd),
            'orientation': self._reward_orientation(x),
            'torques': self._reward_torques(pipeline_state.qfrc_actuator),  # pytype: disable=attribute-error
            'action_rate': self._reward_action_rate(action, state.info['last_act']),
            'stand_still': self._reward_stand_still(
                state.info['command'], joint_angles,
            ),
            'feet_air_time': self._reward_feet_air_time(
                state.info['feet_air_time'],
                first_contact,
                state.info['command'],
            ),
            'foot_slip': self._reward_foot_slip(pipeline_state, contact_filt_cm),
            'termination': self._reward_termination(done, state.info['step']),
        }
        rewards = {
            k: v * self.reward_config.rewards.scales[k] for k, v in rewards.items()
        }
        reward = jp.clip(sum(rewards.values()) * self.dt, 0.0, 10000.0)

        # state management
        state.info['kick'] = kick
        state.info['last_act'] = action
        state.info['last_vel'] = joint_vel
        state.info['feet_air_time'] *= ~contact_filt_mm
        state.info['last_contact'] = contact
        state.info['rewards'] = rewards
        state.info['step'] += 1
        state.info['rng'] = rng

        # sample new command if more than 500 timesteps achieved
        state.info['command'] = jp.where(
            state.info['step'] > 500,
            self.sample_command(cmd_rng),
            state.info['command'],
        )
        # reset the step counter when done
        state.info['step'] = jp.where(
            done | (state.info['step'] > 500), 0, state.info['step']
        )

        # log total displacement as a proxy metric
        state.metrics['total_dist'] = math.normalize(x.pos[self._torso_idx - 1])[1]
        state.metrics.update(state.info['rewards'])

        done = jp.float32(done)
        state = state.replace(
            pipeline_state=pipeline_state, obs=obs, reward=reward, done=done
        )
        return state

    def _get_obs(
        self,
        pipeline_state: base.State,
        state_info: dict[str, Any],
        obs_history: jax.Array,
    ) -> jax.Array:
        inv_torso_rot = math.quat_inv(pipeline_state.x.rot[0])
        local_rpyrate = math.rotate(pipeline_state.xd.ang[0], inv_torso_rot)

        obs = jp.concatenate([
            jp.array([local_rpyrate[2]]) * 0.25,                  # yaw rate
            math.rotate(jp.array([0, 0, -1]), inv_torso_rot),     # projected gravity
            state_info['command'] * jp.array([2.0, 2.0, 0.25]),   # command
            pipeline_state.q[7:] - self._default_pose,            # motor angles
            state_info['last_act'],                               # last action
        ])

        # clip, noise
        obs = jp.clip(obs, -100.0, 100.0) + self._obs_noise * jax.random.uniform(
            state_info['rng'], obs.shape, minval=-1, maxval=1
        )
        # stack observations through time
        obs = jp.roll(obs_history, obs.size).at[:obs.size].set(obs)

        return obs

    # ------------ reward functions ----------------
    def _reward_lin_vel_z(self, xd: Motion) -> jax.Array:
        # Penalize z axis base linear velocity
        return jp.square(xd.vel[0, 2])

    def _reward_ang_vel_xy(self, xd: Motion) -> jax.Array:
        # Penalize xy axes base angular velocity
        return jp.sum(jp.square(xd.ang[0, :2]))

    def _reward_orientation(self, x: Transform) -> jax.Array:
        # Penalize non flat base orientation
        up = jp.array([0.0, 0.0, 1.0])
        rot_up = math.rotate(up, x.rot[0])
        return jp.sum(jp.square(rot_up[:2]))

    def _reward_torques(self, torques: jax.Array) -> jax.Array:
        # Penalize torques
        return jp.sqrt(jp.sum(jp.square(torques))) + jp.sum(jp.abs(torques))

    def _reward_action_rate(
        self, act: jax.Array, last_act: jax.Array
    ) -> jax.Array:
        # Penalize changes in actions
        return jp.sum(jp.square(act - last_act))

    def _reward_tracking_lin_vel(
        self, commands: jax.Array, x: Transform, xd: Motion
    ) -> jax.Array:
        # Tracking of linear velocity commands (xy axes)
        local_vel = math.rotate(xd.vel[0], math.quat_inv(x.rot[0]))
        lin_vel_error = jp.sum(jp.square(commands[:2] - local_vel[:2]))
        lin_vel_reward = jp.exp(
            -lin_vel_error / self.reward_config.rewards.tracking_sigma
        )
        return lin_vel_reward

    def _reward_tracking_ang_vel(
        self, commands: jax.Array, x: Transform, xd: Motion
    ) -> jax.Array:
        # Tracking of angular velocity commands (yaw)
        base_ang_vel = math.rotate(xd.ang[0], math.quat_inv(x.rot[0]))
        ang_vel_error = jp.square(commands[2] - base_ang_vel[2])
        return jp.exp(-ang_vel_error / self.reward_config.rewards.tracking_sigma)

    def _reward_feet_air_time(
        self, air_time: jax.Array, first_contact: jax.Array, commands: jax.Array
    ) -> jax.Array:
        # Reward air time.
        rew_air_time = jp.sum((air_time - 0.1) * first_contact)
        rew_air_time *= (
            math.normalize(commands[:2])[1] > 0.05
        )  # no reward for zero command
        return rew_air_time

    def _reward_stand_still(
        self,
        commands: jax.Array,
        joint_angles: jax.Array,
    ) -> jax.Array:
        # Penalize motion at zero commands
        return jp.sum(jp.abs(joint_angles - self._default_pose)) * (
            math.normalize(commands[:2])[1] < 0.1
        )

    def _reward_foot_slip(
        self, pipeline_state: base.State, contact_filt: jax.Array
    ) -> jax.Array:
        # get velocities at feet which are offset from lower legs
        # pytype: disable=attribute-error
        pos = pipeline_state.site_xpos[self._feet_site_id]  # feet position
        feet_offset = pos - pipeline_state.xpos[self._lower_leg_body_id]
        # pytype: enable=attribute-error
        offset = base.Transform.create(pos=feet_offset)
        foot_indices = self._lower_leg_body_id - 1  # we got rid of the world body
        foot_vel = offset.vmap().do(pipeline_state.xd.take(foot_indices)).vel

        # Penalize large feet velocity for feet that are in contact with the ground.
        return jp.sum(jp.square(foot_vel[:, :2]) * contact_filt.reshape((-1, 1)))

    def _reward_termination(self, done: jax.Array, step: jax.Array) -> jax.Array:
        return done & (step < 500)

    def render(
        self, trajectory: List[base.State], camera: str | None = None,
        width: int = 240, height: int = 320,
    ) -> Sequence[np.ndarray]:
        camera = camera or 'track'
        return super().render(trajectory, camera=camera, width=width, height=height)


envs.register_environment('barkour', BarkourEnv)

Initialize the environment:

env_name = 'barkour'
env = envs.get_environment(env_name)
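The two tracking rewards in the class above share the kernel exp(-error²/σ) with tracking_sigma = 0.25; a minimal pure-Python sketch of that kernel (the helper name is mine, not from the source):

```python
import math

def tracking_reward(error_sq, sigma=0.25):
    # exp(-error^2 / sigma): 1.0 at zero error, decaying as the error grows
    return math.exp(-error_sq / sigma)

print(tracking_reward(0.0))  # → 1.0
```

The kernel gives the maximum reward when the commanded velocity is tracked exactly and falls off smoothly with the squared tracking error, with σ controlling how forgiving the falloff is.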
To train a policy with domain randomization, pass the domain randomization function into the brax training function; brax calls it when it builds the batched simulation.

ckpt_path = epath.Path('/tmp/quadrupred_joystick/ckpts')
ckpt_path.mkdir(parents=True, exist_ok=True)

def policy_params_fn(current_step, make_policy, params):
    # save checkpoints with orbax
    orbax_checkpointer = ocp.PyTreeCheckpointer()
    save_args = orbax_utils.save_args_from_target(params)
    path = ckpt_path / f'{current_step}'
    orbax_checkpointer.save(path, params, force=True, save_args=save_args)

Define the training function:
make_networks_factory = functools.partial(
    ppo_networks.make_ppo_networks,
    policy_hidden_layer_sizes=(128, 128, 128, 128))

train_fn = functools.partial(
    ppo.train, num_timesteps=100_000_000, num_evals=10,
    reward_scaling=1, episode_length=1000, normalize_observations=True,
    action_repeat=1, unroll_length=20, num_minibatches=32,
    num_updates_per_batch=4, discounting=0.97, learning_rate=3.0e-4,
    entropy_cost=1e-2, num_envs=8192, batch_size=256,
    network_factory=make_networks_factory,
    randomization_fn=domain_randomize,
    policy_params_fn=policy_params_fn,
    seed=0)
5.4 Run the training
According to the official docs, training the quadruped takes about 6 minutes on a Tesla A100 GPU; on my single 3060 Super it took about 14 minutes.
x_data = []
y_data = []
ydataerr = []
times = [datetime.now()]
max_y, min_y = 40, 0

env = envs.get_environment(env_name)
eval_env = envs.get_environment(env_name)

start_time = time.time()
make_inference_fn, params, _ = train_fn(environment=env,
                                        progress_fn=progress,
                                        eval_env=eval_env)
print(f'Total cost time {time.time()-start_time:.2f}')
print(f'time to jit: {times[1] - times[0]}')
print(f'time to train: {times[-1] - times[1]}')
5.5 Saving and loading the model
Save the model:

model_path = './mjx_brax_quadruped_policy'
model.save_params(model_path, params)

Load the model and define the inference function:

params = model.load_params(model_path)
inference_fn = make_inference_fn(params)
jit_inference_fn = jax.jit(inference_fn)
5.6 Visualizing the results
For the Barkour Quadruped, joystick commands can be set through x_vel, y_vel, and ang_vel. x_vel and y_vel define the linear forward and sideways velocities relative to the quadruped's torso, and ang_vel defines the angular velocity of the torso about the z-axis.
eval_env = envs.get_environment(env_name)

jit_reset = jax.jit(eval_env.reset)
jit_step = jax.jit(eval_env.step)

Try changing the three variables below to see different behaviors:

x_vel = 1.0
y_vel = 0.0
ang_vel = -0.5

Run inference with the commanded input:

the_command = jp.array([x_vel, y_vel, ang_vel])

# initialize the state
rng = jax.random.PRNGKey(0)
state = jit_reset(rng)
state.info['command'] = the_command
rollout = [state.pipeline_state]

n_steps = 500
render_every = 2

start_time = time.time()
for i in range(n_steps):
    act_rng, rng = jax.random.split(rng)
    ctrl, _ = jit_inference_fn(state.obs, act_rng)
    state = jit_step(state, ctrl)
    rollout.append(state.pipeline_state)
print(f"Total cost {time.time()-start_time:.2f} seconds")

media.show_video(
    eval_env.render(rollout[::render_every], camera='track'),
    fps=1.0 / eval_env.dt / render_every)
6. Fine-tuning the model on a height field
Sometimes you also want the quadruped to learn to walk on rough terrain. Take the latest checkpoint from the joystick policy above and fine-tune it on height-field terrain.
6.1 Generate an irregular height field
scene_file = 'scene_hfield_mjx.xml'
env = envs.get_environment(env_name, scene_file=scene_file)

jit_reset = jax.jit(env.reset)
state = jit_reset(jax.random.PRNGKey(0))
plt.imshow(env.render([state.pipeline_state], camera='track')[0])
6.2 Define the training function
latest_ckpts = list(ckpt_path.glob('*'))
latest_ckpts.sort(key=lambda x: int(x.as_posix().split('/')[-1]))
latest_ckpt = latest_ckpts[-1]
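The checkpoint selection above sorts directories by their numeric step name rather than lexicographically; a pure-Python sketch of that logic (toy paths, not real checkpoints):

```python
# checkpoint directories are named by training step
ckpts = ['/tmp/ckpts/1000', '/tmp/ckpts/200', '/tmp/ckpts/30000']

# lexicographic order would wrongly place '200' after '1000';
# sorting by the integer step picks the true latest checkpoint
ckpts.sort(key=lambda p: int(p.split('/')[-1]))
latest = ckpts[-1]
print(latest)  # → /tmp/ckpts/30000
```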
train_fn = functools.partial(
    ppo.train, num_timesteps=40_000_000, num_evals=5,
    reward_scaling=1, episode_length=1000, normalize_observations=True,
    action_repeat=1, unroll_length=20, num_minibatches=32,
    num_updates_per_batch=4, discounting=0.97, learning_rate=3.0e-4,
    entropy_cost=1e-2, num_envs=8192, batch_size=256,
    network_factory=make_networks_factory,
    randomization_fn=domain_randomize, seed=0,
    restore_checkpoint_path=latest_ckpt)
6.3 Fine-tune the model
Fine-tuning took about 9 minutes on my single 3060 Super.
x_data = []
y_data = []
ydataerr = []
times = [datetime.now()]
max_y, min_y = 40, 0

env = envs.get_environment(env_name, scene_file=scene_file)
eval_env = envs.get_environment(env_name, scene_file=scene_file)

start_time = time.time()
make_inference_fn, params, _ = train_fn(environment=env,
                                        progress_fn=progress,
                                        eval_env=eval_env)
print(f'Total cost {time.time()-start_time:.2f} seconds')
print(f'time to jit: {times[1] - times[0]}')
print(f'time to train: {times[-1] - times[1]}')
6.4 Visualizing the training results
eval_env = envs.get_environment(env_name, scene_file=scene_file)

jit_reset = jax.jit(eval_env.reset)
jit_step = jax.jit(eval_env.step)

inference_fn = make_inference_fn(params)
jit_inference_fn = jax.jit(inference_fn)

As before, you can change the three variables below and inspect the model's behavior:

x_vel = 1.0
y_vel = 0.0
ang_vel = -0.5

Run inference:

the_command = jp.array([x_vel, y_vel, ang_vel])

# initialize the state
rng = jax.random.PRNGKey(0)
state = jit_reset(rng)
state.info['command'] = the_command
rollout = [state.pipeline_state]

n_steps = 500
render_every = 2

start_time = time.time()
for i in range(n_steps):
    act_rng, rng = jax.random.split(rng)
    ctrl, _ = jit_inference_fn(state.obs, act_rng)
    state = jit_step(state, ctrl)
    rollout.append(state.pipeline_state)
print(f'Total cost {time.time()-start_time:.2f} seconds')

media.show_video(
    eval_env.render(rollout[::render_every], camera='track'),
    fps=1.0 / eval_env.dt / render_every)