首页 最新 热门 推荐

  • 首页
  • 最新
  • 热门
  • 推荐

机器人模仿学习之动作分块ACT算法的代码剖析、部署训练

  • 25-03-02 12:02
  • 3906
  • 12632
blog.csdn.net

前言

本文最早是属于《斯坦福Mobile ALOHA背后的关键技术:动作分块ACT算法的原理解析》的第二、第三部分,涉及到动作分块ACT的代码剖析与部署训练

但因为想把ACT的代码逐行剖析的更细致些,加之为避免上一篇文章太过于长,故把动作分块ACT的代码剖析与部署实践这块独立出来成本文

第一部分 动作分块算法ACT的代码剖析

关于ACT的代码,我们可以重点研究下这个仓库:GitHub - tonyzhaozh/act,我司同事杜老师也于24年1.10日跑通了这份代码(如何跑通的教程见下文第二部分)

  • imitate_episodes.py,训练和评估 ACT
  • policy.py,An adaptor for ACT policy
  • detr,ACT 的模型定义 修改自 DETR
  • sim_env.py,具有 joint space control的 Mujoco + DM_Control 环境
  • ee_sim_env.py,具有EE space control的 Mujoco + DM_Control 环境
  • scripted_policy.py,模拟环境的脚本化策略
  • constants.py,跨文件共享的常量
  • utils.py,数据加载和辅助函数等实用程序
  • visualize_episodes.py,保存 .hdf5 数据集中的视频

1.1 ACT的训练与评估imitate_episodes.py

1.1.1 主程序

  1. 从命令行参数中获取模型训练和评估的相关配置
    1. def main(args):
    2. set_seed(1) # 设置随机种子以保证结果可重现
    3. # 解析命令行参数
    4. is_eval = args["eval"] # 是否为评估模式的布尔标志
    5. ckpt_dir = args["ckpt_dir"] # 保存/加载检查点的目录
    6. policy_class = args["policy_class"] # 使用的策略类
    7. onscreen_render = args["onscreen_render"] # 是否进行屏幕渲染的标志
    8. task_name = args["task_name"] # 任务名称
    9. batch_size_train = args["batch_size"] # 训练批大小
    10. batch_size_val = args["batch_size"] # 验证批大小
    11. num_epochs = args["num_epochs"] # 训练的总周期数
    12. use_waypoint = args["use_waypoint"] # 是否使用航点
    13. constant_waypoint = args["constant_waypoint"] # 持续航点的设置
    14. # 根据是否使用航点打印相应信息
    15. if use_waypoint:
    16. print("Using waypoint") # 使用航点
    17. if constant_waypoint is not None:
    18. print(f"Constant waypoint: {constant_waypoint}") # 持续航点
  2. 根据任务名称和配置获取任务参数,例如数据集目录、任务类型等
    1. # 获取任务参数
    2. is_sim = True # 硬编码为True以避免从aloha中查找常量
    3. # 如果是模拟任务,从constants导入SIM_TASK_CONFIGS
    4. if is_sim:
    5. from constants import SIM_TASK_CONFIGS
    6. task_config = SIM_TASK_CONFIGS[task_name]
    7. else:
    8. from aloha_scripts.constants import TASK_CONFIGS
    9. task_config = TASK_CONFIGS[task_name]
    10. # 从任务配置中获取相关参数
    11. dataset_dir = task_config["dataset_dir"]
    12. num_episodes = task_config["num_episodes"]
    13. episode_len = task_config["episode_len"]
    14. camera_names = task_config["camera_names"]
  3. 定义模型的架构和超参数,包括学习率、网络结构、层数等
    1. # 固定参数
    2. state_dim = 14 # 状态维度
    3. lr_backbone = 1e-5 # 主干网络的学习率
    4. backbone = "resnet18" # 使用的主干网络类型
  4. 根据策略类别设置策略配置
    1. # 根据策略类别设置策略配置
    2. if policy_class == "ACT":
    3. # ACT策略的特定参数
    4. enc_layers = 4
    5. dec_layers = 7
    6. nheads = 8
    7. policy_config = {
    8. "lr": args["lr"],
    9. "num_queries": args["chunk_size"],
    10. "kl_weight": args["kl_weight"],
    11. "hidden_dim": args["hidden_dim"],
    12. "dim_feedforward": args["dim_feedforward"],
    13. "lr_backbone": lr_backbone,
    14. "backbone": backbone,
    15. "enc_layers": enc_layers,
    16. "dec_layers": dec_layers,
    17. "nheads": nheads,
    18. "camera_names": camera_names,
    19. }
    20. elif policy_class == "CNNMLP":
    21. # CNNMLP策略的特定参数
    22. policy_config = {
    23. "lr": args["lr"],
    24. "lr_backbone": lr_backbone,
    25. "backbone": backbone,
    26. "num_queries": 1,
    27. "camera_names": camera_names,
    28. }
    29. else:
    30. raise NotImplementedError
  5. 配置训练参数
    1. # 配置训练参数
    2. config = {
    3. "num_epochs": num_epochs,
    4. "ckpt_dir": ckpt_dir,
    5. "episode_len": episode_len,
    6. "state_dim": state_dim,
    7. "lr": args["lr"],
    8. "policy_class": policy_class,
    9. "onscreen_render": onscreen_render,
    10. "policy_config": policy_config,
    11. "task_name": task_name,
    12. "seed": args["seed"],
    13. "temporal_agg": args["temporal_agg"],
    14. "camera_names": camera_names,
    15. "real_robot": not is_sim,
    16. }
  6. 如果设置为评估模式,加载保存的模型权重并在验证集上评估模型性能,计算成功率和平均回报
    1. # 如果为评估模式,执行评估流程
    2. if is_eval:
    3. ckpt_names = [f"policy_best.ckpt"]
    4. results = []
    5. for ckpt_name in ckpt_names:
    6. success_rate, avg_return = eval_bc(config, ckpt_name, save_episode=True)
    7. results.append([ckpt_name, success_rate, avg_return])
    8. for ckpt_name, success_rate, avg_return in results:
    9. print(f"{ckpt_name}: {success_rate=} {avg_return=}")
    10. print()
    11. exit()
    12. # 加载数据
    13. train_dataloader, val_dataloader, stats, _ = load_data(
    14. dataset_dir,
    15. num_episodes,
    16. camera_names,
    17. batch_size_train,
    18. batch_size_val,
    19. use_waypoint,
    20. constant_waypoint,
    21. )
    22. # 保存数据集统计信息
    23. if not os.path.isdir(ckpt_dir):
    24. os.makedirs(ckpt_dir)
    25. stats_path = os.path.join(ckpt_dir, f"dataset_stats.pkl")
    26. with open(stats_path, "wb") as f:
    27. pickle.dump(stats, f)
  7. 最后,将结果打印出来
    1. # 训练并获取最佳检查点信息
    2. best_ckpt_info = train_bc(train_dataloader, val_dataloader, config)
    3. best_epoch, min_val_loss, best_state_dict = best_ckpt_info
    4. # 保存最佳检查点
    5. ckpt_path = os.path.join(ckpt_dir, f"policy_best.ckpt")
    6. torch.save(best_state_dict, ckpt_path)
    7. print(f"Best ckpt, val loss {min_val_loss:.6f} @ epoch{best_epoch}")

1.1.2 make_policy、make_optimizer、get_image

根据指定的policy_class(策略类别,目前支持两种类型:"ACT"和"CNNMLP"),和policy_config(策略配置)创建一个策略模型对象

  1. def make_policy(policy_class, policy_config):
  2. if policy_class == 'ACT':
  3. policy = ACTPolicy(policy_config) # 如果策略类是 ACT,创建 ACTPolicy
  4. elif policy_class == 'CNNMLP':
  5. policy = CNNMLPPolicy(policy_config) # 如果策略类是 CNNMLP,创建 CNNMLPPolicy
  6. else:
  7. raise NotImplementedError # 如果不是以上两种类型,则抛出未实现错误
  8. return policy # 返回创建的策略对象

make_optimizer用于创建策略模型的优化器(optimizer),并返回创建的优化器对象。优化器的作用是根据策略模型的损失函数来更新模型的参数,以使损失函数尽量减小

  1. def make_optimizer(policy_class, policy):
  2. if policy_class == 'ACT':
  3. optimizer = policy.configure_optimizers() # 如果策略类是 ACT,配置优化器
  4. elif policy_class == 'CNNMLP':
  5. optimizer = policy.configure_optimizers() # 如果策略类是 CNNMLP,配置优化器
  6. else:
  7. raise NotImplementedError # 如果不是以上两种类型,则抛出未实现错误
  8. return optimizer # 返回配置的优化器

get_image的作用是获取一个时间步(ts)的图像数据。函数接受两个参数:ts和camera_names

  1. def get_image(ts, camera_names):
  2. curr_images = []
  3. for cam_name in camera_names:
  4. curr_image = rearrange(ts.observation['images'][cam_name], 'h w c -> c h w') # 重排图像数组
  5. curr_images.append(curr_image) # 将处理后的图像添加到列表中
  6. curr_image = np.stack(curr_images, axis=0) # 将图像列表堆叠成数组
  7. curr_image = torch.from_numpy(curr_image / 255.0).float().cuda().unsqueeze(0) # 将数组转换为 PyTorch 张量
  8. return curr_image # 返回处理后的图像张量
  1. ts是一个时间步的数据,包含了多个相机(摄像头)拍摄的图像
    ts.observation["images"]包含了各个相机拍摄的图像数据,而camera_names是一个列表,包含了要获取的相机的名称
  2. 函数通过循环遍历camera_names中的相机名称,从ts.observation["images"]中获取对应相机的图像数据
    这些图像数据首先通过rearrange函数重新排列维度,将"height-width-channels"的顺序变为"channels-height-width",以适应PyTorch的数据格式
  3. 获取的图像数据被放入curr_images列表中
  4. 接下来,函数将curr_images列表中的所有图像数据堆叠成一个张量(tensor),np.stack(curr_images, axis=0)这一行代码实现了这个操作
  5. 接着,图像数据被归一化到[0, 1]的范围,然后转换为PyTorch的float类型,并移到GPU上(如果可用)。最后,图像数据被增加了一个额外的维度(unsqueeze(0)),以适应模型的输入要求

最终,函数返回包含时间步图像数据的PyTorch张量。这个图像数据可以被用于输入到神经网络模型中进行处理

1.1.3 eval_bc:评估一个行为克隆(behavior cloning)模型

  1. 的
    1. def eval_bc(config, ckpt_name, save_episode=True):
    2. set_seed(1000) # 设置随机种子为 1000
    3. # 从配置中获取参数
    4. ckpt_dir = config['ckpt_dir']
    5. state_dim = config['state_dim']
    6. real_robot = config['real_robot']
    7. policy_class = config['policy_class']
    8. onscreen_render = config['onscreen_render']
    9. policy_config = config['policy_config']
    10. camera_names = config['camera_names']
    11. max_timesteps = config['episode_len']
    12. task_name = config['task_name']
    13. temporal_agg = config['temporal_agg']
    14. onscreen_cam = 'angle'
    15. # 加载策略和统计信息
    16. ckpt_path = os.path.join(ckpt_dir, ckpt_name)
    17. policy = make_policy(policy_class, policy_config)
    18. loading_status = policy.load_state_dict(torch.load(ckpt_path))
    19. print(loading_status)
    20. policy.cuda()
    21. policy.eval()
    22. print(f'Loaded: {ckpt_path}')
    23. stats_path = os.path.join(ckpt_dir, f'dataset_stats.pkl')
    24. with open(stats_path, 'rb') as f:
    25. stats = pickle.load(f)
    26. # 定义预处理和后处理函数
    27. pre_process = lambda s_qpos: (s_qpos - stats['qpos_mean']) / stats['qpos_std']
    28. post_process = lambda a: a * stats['action_std'] + stats['action_mean']
  2. 的
    1. # 加载环境
    2. if real_robot:
    3. from aloha_scripts.robot_utils import move_grippers # 从 aloha_scripts.robot_utils 导入 move_grippers
    4. from aloha_scripts.real_env import make_real_env # 从 aloha_scripts.real_env 导入 make_real_env
    5. env = make_real_env(init_node=True) # 创建真实机器人环境
    6. env_max_reward = 0
    7. else:
    8. from sim_env import make_sim_env # 从 sim_env 导入 make_sim_env
    9. env = make_sim_env(task_name) # 创建模拟环境
    10. env_max_reward = env.task.max_reward
    11. # 设置查询频率和时间聚合参数
    12. query_frequency = policy_config['num_queries']
    13. if temporal_agg:
    14. query_frequency = 1
    15. num_queries = policy_config['num_queries']
    16. # 设置最大时间步数
    17. max_timesteps = int(max_timesteps * 1) # 可以根据实际任务调整最大时间步数
  3. 设置评估的循环次数(num_rollouts),每次循环都会进行一次评估
    在每次循环中,初始化环境,执行模型生成的动作并观测环境的响应
    将每个时间步的观测数据(包括图像、关节位置等)存储在相应的列表中
    1. # 设置回放次数和初始化结果列表
    2. num_rollouts = 50
    3. episode_returns = []
    4. highest_rewards = []
    5. # 回放循环
    6. for rollout_id in range(num_rollouts):
    7. rollout_id += 0
    8. # 设置任务
    9. if 'sim_transfer_cube' in task_name:
    10. BOX_POSE[0] = sample_box_pose() # 在模拟重置中使用的 BOX_POSE
    11. elif 'sim_insertion' in task_name:
    12. BOX_POSE[0] = np.concatenate(sample_insertion_pose()) # 在模拟重置中使用的 BOX_POSE
    13. ts = env.reset() # 重置环境
    14. # 处理屏幕渲染
    15. if onscreen_render:
    16. ax = plt.subplot()
    17. plt_img = ax.imshow(env._physics.render(height=480, width=640, camera_id=onscreen_cam))
    18. plt.ion()
    19. # 评估循环
    20. if temporal_agg:
    21. all_time_actions = torch.zeros([max_timesteps, max_timesteps+num_queries, state_dim]).cuda()
    22. qpos_history = torch.zeros((1, max_timesteps, state_dim)).cuda()
    23. image_list = [] # 用于可视化的图像列表
    24. qpos_list = []
    25. target_qpos_list = []
    26. rewards = []
    27. # 在不计算梯度的模式下执行
    28. with torch.inference_mode():
    29. for t in range(max_timesteps):
    30. # 更新屏幕渲染和等待时间
    31. if onscreen_render:
    32. image = env._physics.render(height=480, width=640, camera_id=onscreen_cam)
    33. plt_img.set_data(image)
    34. plt.pause(DT)
    35. # 处理上一时间步的观测值以获取 qpos 和图像列表
    36. obs = ts.observation
    37. if 'images' in obs:
    38. image_list.append(obs['images'])
    39. else:
    40. image_list.append({'main': obs['image']})
    41. qpos_numpy = np.array(obs['qpos'])
    42. qpos = pre_process(qpos_numpy)
    43. qpos = torch.from_numpy(qpos).float().cuda().unsqueeze(0)
    44. qpos_history[:, t] = qpos
    45. curr_image = get_image(ts, camera_names)
    46. # 查询策略
    47. if config['policy_class'] == "ACT":
    48. if t % query_frequency == 0:
    49. all_actions = policy(qpos, curr_image)
    50. if temporal_agg:
    51. all_time_actions[[t], t:t+num_queries] = all_actions
    52. actions_for_curr_step = all_time_actions[:, t]
    53. actions_populated = torch.all(actions_for_curr_step != 0, axis=1)
    54. actions_for_curr_step = actions_for_curr_step[actions_populated]
    55. k = 0.01
    56. exp_weights = np.exp(-k * np.arange(len(actions_for_curr_step)))
    57. exp_weights = exp_weights / exp_weights.sum()
    58. exp_weights = torch.from_numpy(exp_weights).cuda().unsqueeze(dim=1)
    59. raw_action = (actions_for_curr_step * exp_weights).sum(dim=0, keepdim=True)
    60. else:
    61. raw_action = all_actions[:, t % query_frequency]
    62. elif config['policy_class'] == "CNNMLP":
    63. raw_action = policy(qpos, curr_image)
    64. else:
    65. raise NotImplementedError
    66. # 后处理动作
    67. raw_action = raw_action.squeeze(0).cpu().numpy()
    68. action = post_process(raw_action)
    69. target_qpos = action
    70. # 步进环境
    71. ts = env.step(target_qpos)
    72. # 用于可视化的列表
    73. qpos_list.append(qpos_numpy)
    74. target_qpos_list.append(target_qpos)
    75. rewards.append(ts.reward)
    76. plt.close() # 关闭绘图窗口
    77. if real_robot:
    78. move_grippers([env.puppet_bot_left, env.puppet_bot_right], [PUPPET_GRIPPER_JOINT_OPEN] * 2, move_time=0.5) # 打开夹持器
    79. pass
    计算每次评估的总回报,以及每次评估的最高回报,并记录成功率
    1. # 计算回报和奖励
    2. rewards = np.array(rewards)
    3. episode_return = np.sum(rewards[rewards != None])
    4. episode_returns.append(episode_return)
    5. episode_highest_reward = np.max(rewards)
    6. highest_rewards.append(episode_highest_reward)
    7. print(f'Rollout {rollout_id}\n{episode_return=}, {episode_highest_reward=}, {env_max_reward=}, Success: {episode_highest_reward == env_max_reward}')
    如果指定了保存评估过程中的图像数据,将每次评估的图像数据保存为视频
    1. # 保存视频
    2. if save_episode:
    3. save_videos(image_list, DT, video_path=os.path.join(ckpt_dir, f'video{rollout_id}.mp4'))
  4. 输出评估结果,包括成功率、平均回报以及回报分布
    将评估结果保存到文本文件中
    1. # 计算成功率和平均回报
    2. # 计算成功率,即最高奖励的次数与环境最大奖励相等的比率
    3. success_rate = np.mean(np.array(highest_rewards) == env_max_reward)
    4. # 计算平均回报
    5. avg_return = np.mean(episode_returns)
    6. # 创建一个包含成功率和平均回报的摘要字符串
    7. summary_str = f'\n成功率: {success_rate}\n平均回报: {avg_return}\n\n'
    8. # 遍历奖励范围,计算每个奖励范围内的成功率
    9. for r in range(env_max_reward + 1):
    10. # 统计最高奖励大于等于 r 的次数
    11. more_or_equal_r = (np.array(highest_rewards) >= r).sum()
    12. # 计算成功率
    13. more_or_equal_r_rate = more_or_equal_r / num_rollouts
    14. # 将结果添加到摘要字符串中
    15. summary_str += f'奖励 >= {r}: {more_or_equal_r}/{num_rollouts} = {more_or_equal_r_rate*100}%\n'
    16. # 打印摘要字符串
    17. print(summary_str)
    18. # 将成功率保存到文本文件
    19. result_file_name = 'result_' + ckpt_name.split('.')[0] + '.txt'
    20. with open(os.path.join(ckpt_dir, result_file_name), 'w') as f:
    21. f.write(summary_str) # 写入摘要字符串
    22. f.write(repr(episode_returns)) # 写入回报数据
    23. f.write('\n\n')
    24. f.write(repr(highest_rewards)) # 写入最高奖励数据
    25. # 返回成功率和平均回报
    26. return success_rate, avg_return

最终,函数返回成功率和平均回报。这些结果可以用于评估模型的性能

1.1.4 forward_pass

  1. def forward_pass(data, policy):
  2. image_data, qpos_data, action_data, is_pad = data
  3. image_data, qpos_data, action_data, is_pad = (
  4. image_data.cuda(),
  5. qpos_data.cuda(),
  6. action_data.cuda(),
  7. is_pad.cuda(),
  8. )
  9. return policy(qpos_data, image_data, action_data, is_pad)

这个函数用于执行前向传播(forward pass)操作,以生成模型的输出。它接受以下参数:

  • data:包含输入数据的元组,其中包括图像数据、关节位置数据、动作数据以及填充标志
  • policy:行为克隆模型

函数的主要步骤如下:

  1. 将输入数据转移到GPU上,以便在GPU上进行计算。
  2. 调用行为克隆模型的前向传播方法(policy),将关节位置数据、图像数据、动作数据和填充标志传递给模型
  3. 返回模型的输出,这可能是模型对动作数据的预测结果

1.1.5 train_bc

这个函数用于训练行为克隆(Behavior Cloning)模型。它接受以下参数:

  1. train_dataloader:训练数据的数据加载器,用于从训练集中获取批次的数据。
  2. val_dataloader:验证数据的数据加载器,用于从验证集中获取批次的数据。
  3. config:包含训练配置信息的字典

函数的主要步骤如下

  1. 初始化训练过程所需的各种参数和配置
    1. def train_bc(train_dataloader, val_dataloader, config):
    2. num_epochs = config["num_epochs"]
    3. ckpt_dir = config["ckpt_dir"]
    4. seed = config["seed"]
    5. policy_class = config["policy_class"]
    6. policy_config = config["policy_config"]
    7. set_seed(seed)
  2. 创建行为克隆模型,并根据是否存在之前的训练检查点来加载模型权重
    1. policy = make_policy(policy_class, policy_config)
    2. # if ckpt_dir is not empty, prompt the user to load the checkpoint
    3. if os.path.isdir(ckpt_dir) and len(os.listdir(ckpt_dir)) > 1:
    4. print(f"Checkpoint directory {ckpt_dir} is not empty. Load checkpoint? (y/n)")
    5. load_ckpt = input()
    6. if load_ckpt == "y":
    7. # load the latest checkpoint
    8. latest_idx = max(
    9. [
    10. int(f.split("_")[2])
    11. for f in os.listdir(ckpt_dir)
    12. if f.startswith("policy_epoch_")
    13. ]
    14. )
    15. ckpt_path = os.path.join(
    16. ckpt_dir, f"policy_epoch_{latest_idx}_seed_{seed}.ckpt"
    17. )
    18. print(f"Loading checkpoint from {ckpt_path}")
    19. loading_status = policy.load_state_dict(torch.load(ckpt_path))
    20. print(loading_status)
    21. else:
    22. print("Not loading checkpoint")
    23. latest_idx = 0
    24. else:
    25. latest_idx = 0
  3. 定义优化器,用于更新模型的权重
    1. policy.cuda()
    2. optimizer = make_optimizer(policy_class, policy)
  4. 进行训练循环,每个循环迭代一个 epoch,包括以下步骤:
    验证:在验证集上计算模型的性能,并记录验证结果。如果当前模型的验证性能优于历史最佳模型,则保存当前模型的权重。
    训练:在训练集上进行模型的训练,计算损失并执行反向传播来更新模型的权重
    每隔一定周期,保存当前模型的权重和绘制训练曲线图
    1. train_history = []
    2. validation_history = []
    3. min_val_loss = np.inf
    4. best_ckpt_info = None
    5. for epoch in tqdm(range(latest_idx, num_epochs)):
    6. print(f"\nEpoch {epoch}")
    7. # validation
    8. with torch.inference_mode():
    9. policy.eval()
    10. epoch_dicts = []
    11. for batch_idx, data in enumerate(val_dataloader):
    12. forward_dict = forward_pass(data, policy)
    13. epoch_dicts.append(forward_dict)
    14. epoch_summary = compute_dict_mean(epoch_dicts)
    15. validation_history.append(epoch_summary)
    16. epoch_val_loss = epoch_summary["loss"]
    17. if epoch_val_loss < min_val_loss:
    18. min_val_loss = epoch_val_loss
    19. best_ckpt_info = (epoch, min_val_loss, deepcopy(policy.state_dict()))
    20. print(f"Val loss: {epoch_val_loss:.5f}")
    21. summary_string = ""
    22. for k, v in epoch_summary.items():
    23. summary_string += f"{k}: {v.item():.3f} "
    24. print(summary_string)
    25. # training
    26. policy.train()
    27. optimizer.zero_grad()
    28. for batch_idx, data in enumerate(train_dataloader):
    29. forward_dict = forward_pass(data, policy)
    30. # backward
    31. loss = forward_dict["loss"]
    32. loss.backward()
    33. optimizer.step()
    34. optimizer.zero_grad()
    35. train_history.append(detach_dict(forward_dict))
    36. e = epoch - latest_idx
    37. epoch_summary = compute_dict_mean(
    38. train_history[(batch_idx + 1) * e : (batch_idx + 1) * (epoch + 1)]
    39. )
    40. epoch_train_loss = epoch_summary["loss"]
    41. print(f"Train loss: {epoch_train_loss:.5f}")
    42. summary_string = ""
    43. for k, v in epoch_summary.items():
    44. summary_string += f"{k}: {v.item():.3f} "
    45. print(summary_string)
    46. if epoch % 100 == 0:
    47. ckpt_path = os.path.join(ckpt_dir, f"policy_epoch_{epoch}_seed_{seed}.ckpt")
    48. torch.save(policy.state_dict(), ckpt_path)
    49. plot_history(train_history, validation_history, epoch, ckpt_dir, seed)
    50. ckpt_path = os.path.join(ckpt_dir, f"policy_last.ckpt")
    51. torch.save(policy.state_dict(), ckpt_path)
  5. 训练完成后,保存最佳模型的权重和绘制训练曲线图
    1. best_epoch, min_val_loss, best_state_dict = best_ckpt_info
    2. ckpt_path = os.path.join(ckpt_dir, f"policy_epoch_{best_epoch}_seed_{seed}.ckpt")
    3. torch.save(best_state_dict, ckpt_path)
    4. print(
    5. f"Training finished:\nSeed {seed}, val loss {min_val_loss:.6f} at epoch {best_epoch}"
    6. )
    7. # save training curves
    8. plot_history(train_history, validation_history, num_epochs, ckpt_dir, seed)
    9. return best_ckpt_info

1.1.6 plot_history

// 待更

第二部分 Mobile Aloha或Aloha软件层面代码的跑通与部署

// 待更

参考文献与推荐阅读

  1. Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware
  2. Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware(阅读笔记)
  3. Aloha 机械臂的学习记录2——AWE:AWE + ACT
  4. ..
文章知识点与官方知识档案匹配,可进一步学习相关知识
算法技能树首页概览57287 人正在系统学习中
注:本文转载自blog.csdn.net的v_JULY_v的文章"https://blog.csdn.net/v_JULY_v/article/details/135566948"。版权归原作者所有,此博客不拥有其著作权,亦不承担相应法律责任。如有侵权,请联系我们删除。
复制链接
复制链接
相关推荐
发表评论
登录后才能发表评论和回复 注册

/ 登录

评论记录:

未查询到任何数据!
回复评论:

分类栏目

后端 (14832) 前端 (14280) 移动开发 (3760) 编程语言 (3851) Java (3904) Python (3298) 人工智能 (10119) AIGC (2810) 大数据 (3499) 数据库 (3945) 数据结构与算法 (3757) 音视频 (2669) 云原生 (3145) 云平台 (2965) 前沿技术 (2993) 开源 (2160) 小程序 (2860) 运维 (2533) 服务器 (2698) 操作系统 (2325) 硬件开发 (2492) 嵌入式 (2955) 微软技术 (2769) 软件工程 (2056) 测试 (2865) 网络空间安全 (2948) 网络与通信 (2797) 用户体验设计 (2592) 学习和成长 (2593) 搜索 (2744) 开发工具 (7108) 游戏 (2829) HarmonyOS (2935) 区块链 (2782) 数学 (3112) 3C硬件 (2759) 资讯 (2909) Android (4709) iOS (1850) 代码人生 (3043) 阅读 (2841)

热门文章

101
推荐
关于我们 隐私政策 免责声明 联系我们
Copyright © 2020-2024 蚁人论坛 (iYenn.com) All Rights Reserved.
Scroll to Top