def make_agent(obs_shape, action_shape, args, device):
    """Build the agent selected by ``args.agent``.

    Args:
        obs_shape: observation shape, passed straight through to the agent.
        action_shape: action shape, passed straight through to the agent.
        args: parsed command-line namespace; every hyper-parameter below is
            read from it.
        device: device, passed straight through to the agent.

    Returns:
        A ``RadSacAgent`` when ``args.agent == 'rad_sac'``.

    Raises:
        ValueError: if ``args.agent`` names an unsupported agent.
    """
    # Raise explicitly: `assert '<non-empty string>'` is always true (so it
    # would never fire), and asserts are stripped entirely under `python -O`.
    if args.agent != 'rad_sac':
        raise ValueError('agent is not supported: %s' % args.agent)
    return RadSacAgent(
        obs_shape=obs_shape,
        action_shape=action_shape,
        device=device,
        hidden_dim=args.hidden_dim,
        discount=args.discount,
        init_temperature=args.init_temperature,
        alpha_lr=args.alpha_lr,
        alpha_beta=args.alpha_beta,
        actor_lr=args.actor_lr,
        actor_beta=args.actor_beta,
        actor_log_std_min=args.actor_log_std_min,
        actor_log_std_max=args.actor_log_std_max,
        actor_update_freq=args.actor_update_freq,
        critic_lr=args.critic_lr,
        critic_beta=args.critic_beta,
        critic_tau=args.critic_tau,
        critic_target_update_freq=args.critic_target_update_freq,
        encoder_type=args.encoder_type,
        encoder_feature_dim=args.encoder_feature_dim,
        encoder_lr=args.encoder_lr,
        encoder_tau=args.encoder_tau,
        num_layers=args.num_layers,
        num_filters=args.num_filters,
        log_interval=args.log_interval,
        detach_encoder=args.detach_encoder,
        latent_dim=args.latent_dim,
        data_augs=args.data_augs,
    )
def main():
    """Load a trained RAD-SAC actor and either view it or record rollouts.

    Exactly one of ``args.viewer`` / ``args.save_path`` must be set:

    * viewer mode: launch ``viewer`` on the unwrapped env, driving it with the
      loaded policy on frame-stacked pixel observations.
    * save mode: sample ``args.ntraj`` trajectories with
      ``sample_traj_stacked`` and write them all to ``args.save_path`` with
      ``save_compressed_pickle``.

    The environment is closed on every exit path.

    Raises:
        Exception: if neither or both of the viewer / save-path options are set.
    """
    args = parse_args()
    # NOTE(review): the message says --save-dir but the attribute read is
    # `save_path`; confirm the actual flag name in parse_args.
    if not (bool(args.viewer) ^ bool(args.save_path)):
        raise Exception("you need to provide --viewer xor --save-dir "
                        "arguments for this to do anything useful :)")
    if args.threads is not None:
        torch.set_num_threads(args.threads)

    # TODO: The next few calls are copy-pasted out of train.py. Consider
    # refactoring so that you don't have to copy-paste (otoh not very important
    # since this code only needs to be run once)
    dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Render at the larger pre-crop size only when 'crop' augmentation is on.
    pre_transform_image_size = (args.pre_transform_image_size
                                if 'crop' in args.data_augs
                                else args.image_size)
    env = dmc2gym.make(
        domain_name=args.domain_name,
        task_name=args.task_name,
        seed=args.seed,
        visualize_reward=False,
        from_pixels=(args.encoder_type == 'pixel'),
        height=pre_transform_image_size,
        width=pre_transform_image_size,
        frame_skip=args.action_repeat)
    # try/finally so the env is closed on the viewer path and on errors too,
    # not only after trajectories have been saved.
    try:
        env.seed(args.seed)

        action_shape = env.action_space.shape
        obs_shape = (3 * args.frame_stack, args.image_size, args.image_size)
        agent = RadSacAgent(
            obs_shape=obs_shape,
            action_shape=action_shape,
            device=dev,
            hidden_dim=args.hidden_dim,
            encoder_type=args.encoder_type,
            encoder_feature_dim=args.encoder_feature_dim,
            num_layers=args.num_layers,
            num_filters=args.num_filters,
            latent_dim=args.latent_dim,
            data_augs=args.data_augs,
        )
        agent.load_ac(actor_path=args.actor_path)

        if args.viewer:
            dmc_env = unwrap(env)
            # Rolling window of the most recent frames (at least one).
            # NOTE(review): this deque is not cleared between viewer episodes,
            # so the first stacked obs after a reset may mix in old frames.
            frames = collections.deque(maxlen=args.frame_stack or 1)

            def loaded_policy(time_step):
                # time_step just contains joint angles; we want image observation
                obs = env.env._get_obs(time_step)
                frames.append(obs)
                while len(frames) < frames.maxlen:
                    # for init: pad the stack by repeating the first frame
                    frames.append(obs)
                stacked_obs = np.concatenate(frames, axis=0) / 255.
                return agent.sample_action(stacked_obs)

            viewer.launch(dmc_env, policy=loaded_policy)
            return  # done

        # Otherwise, save a batch of imitation.data.TrajectoryWithRew
        # instances; for now all trajectories go into a single file.
        all_traj = [
            sample_traj_stacked(env, agent, frame_stack=args.frame_stack or 1)
            for _ in range(args.ntraj)
        ]
        print(f"Saving to '{args.save_path}'")
        save_compressed_pickle(all_traj, args.save_path)
    finally:
        env.close()