Ejemplo n.º 1
0
def make_agent(obs_shape, action_shape, args, device):
    """Build the agent selected by ``args.agent``.

    Args:
        obs_shape: Observation shape, forwarded unchanged to the agent.
        action_shape: Action shape, forwarded unchanged to the agent.
        args: Parsed configuration object. ``args.agent`` selects the agent
            type; every other hyperparameter is read from the attribute of
            the same name (``hidden_dim``, ``discount``, ``actor_lr``, ...).
        device: Device handle, forwarded unchanged to the agent.

    Returns:
        A ``RadSacAgent`` instance when ``args.agent == 'rad_sac'``.

    Raises:
        ValueError: If ``args.agent`` names an unsupported agent type.
    """
    if args.agent != 'rad_sac':
        # The previous `assert '<message>'` asserted a non-empty string,
        # which is always truthy, so an unknown agent silently returned
        # None. Raise explicitly instead (also immune to `python -O`).
        raise ValueError('agent is not supported: %s' % args.agent)

    return RadSacAgent(
        obs_shape=obs_shape,
        action_shape=action_shape,
        device=device,
        hidden_dim=args.hidden_dim,
        discount=args.discount,
        init_temperature=args.init_temperature,
        alpha_lr=args.alpha_lr,
        alpha_beta=args.alpha_beta,
        actor_lr=args.actor_lr,
        actor_beta=args.actor_beta,
        actor_log_std_min=args.actor_log_std_min,
        actor_log_std_max=args.actor_log_std_max,
        actor_update_freq=args.actor_update_freq,
        critic_lr=args.critic_lr,
        critic_beta=args.critic_beta,
        critic_tau=args.critic_tau,
        critic_target_update_freq=args.critic_target_update_freq,
        encoder_type=args.encoder_type,
        encoder_feature_dim=args.encoder_feature_dim,
        encoder_lr=args.encoder_lr,
        encoder_tau=args.encoder_tau,
        num_layers=args.num_layers,
        num_filters=args.num_filters,
        log_interval=args.log_interval,
        detach_encoder=args.detach_encoder,
        latent_dim=args.latent_dim,
        data_augs=args.data_augs)
Ejemplo n.º 2
0
def main():
    """Load a trained RAD-SAC actor and either view it or dump rollouts.

    Exactly one of ``args.viewer`` / ``args.save_path`` must be set:

    * viewer mode launches ``viewer.launch`` on the unwrapped environment,
      driven by the loaded policy;
    * save mode samples ``args.ntraj`` trajectories with
      ``sample_traj_stacked`` and writes them all to ``args.save_path`` via
      ``save_compressed_pickle``.

    The environment is closed on every exit path, including the early return
    from viewer mode and any exception raised after it was created.

    Raises:
        Exception: If both or neither of ``args.viewer`` and
            ``args.save_path`` are provided.
    """
    args = parse_args()

    # Exactly one of the two output modes must be selected.
    if bool(args.viewer) == bool(args.save_path):
        raise Exception("you need to provide --viewer xor --save-dir "
                        "arguments for this to do anything useful :)")

    if args.threads is not None:
        torch.set_num_threads(args.threads)

    # TODO: The next few calls are copy-pasted out of train.py. Consider
    # refactoring so that you don't have to copy-paste (otoh not very important
    # since this code only needs to be run once)
    dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # When the 'crop' augmentation is configured the environment renders at
    # the pre-transform size; otherwise it renders at the final image size.
    if 'crop' in args.data_augs:
        pre_transform_image_size = args.pre_transform_image_size
    else:
        pre_transform_image_size = args.image_size

    # `frame_stack` may be unset/zero; fall back to a single frame.
    frame_stack = args.frame_stack or 1

    env = dmc2gym.make(
        domain_name=args.domain_name,
        task_name=args.task_name,
        seed=args.seed,
        visualize_reward=False,
        from_pixels=(args.encoder_type == 'pixel'),
        height=pre_transform_image_size,
        width=pre_transform_image_size,
        frame_skip=args.action_repeat)
    try:
        env.seed(args.seed)
        action_shape = env.action_space.shape
        # presumably 3 colour channels per stacked frame -- TODO confirm
        obs_shape = (3 * args.frame_stack, args.image_size, args.image_size)
        agent = RadSacAgent(
            obs_shape=obs_shape,
            action_shape=action_shape,
            device=dev,
            hidden_dim=args.hidden_dim,
            encoder_type=args.encoder_type,
            encoder_feature_dim=args.encoder_feature_dim,
            num_layers=args.num_layers,
            num_filters=args.num_filters,
            latent_dim=args.latent_dim,
            data_augs=args.data_augs, )
        agent.load_ac(actor_path=args.actor_path)

        if args.viewer:
            dmc_env = unwrap(env)
            frames = collections.deque(maxlen=frame_stack)

            def loaded_policy(time_step):
                # time_step just contains joint angles; we want image
                # observation
                obs = env.env._get_obs(time_step)
                frames.append(obs)
                # On the first call, pad the stack with copies of the
                # initial observation until it is full.
                while len(frames) < frames.maxlen:
                    frames.append(obs)
                # Stack along axis 0 and divide by 255 (presumably scaling
                # 8-bit pixel values to [0, 1] -- TODO confirm).
                stacked_obs = np.concatenate(frames, axis=0) / 255.
                return agent.sample_action(stacked_obs)

            viewer.launch(dmc_env, policy=loaded_policy)
            return  # done

        # otherwise, we need to save a bunch of imitation.data.TrajectoryWithRew
        # instance to some directory somewhere…
        all_traj = [
            sample_traj_stacked(env, agent, frame_stack=frame_stack)
            for _ in range(args.ntraj)
        ]
        # for now I'm just going to save all trajectories in one file
        print(f"Saving to '{args.save_path}'")
        save_compressed_pickle(all_traj, args.save_path)
    finally:
        # Previously only reached on the save path; the viewer branch and any
        # exception leaked the environment.
        env.close()