Example No. 1
 def obtain_samples(self,
                    deterministic=False,
                    max_samples=np.inf,
                    max_trajs=np.inf,
                    accum_context=True,
                    resample=1):
     """
     Obtains samples in the environment until we reach either max_samples transitions or
     max_trajs trajectories.
     The resample argument specifies how often (in trajectories) the agent will resample its context.
     """
     assert max_samples < np.inf or max_trajs < np.inf, "either max_samples or max_trajs must be finite"
     policy = MakeDeterministic(
         self.policy) if deterministic else self.policy
     paths = []
     n_steps_total = 0
     n_trajs = 0
     while n_steps_total < max_samples and n_trajs < max_trajs:
         path = rollout(self.env,
                        policy,
                        max_path_length=self.max_path_length,
                        accum_context=accum_context)
         # save the latent context that generated this trajectory
         path['context'] = policy.z.detach().cpu().numpy()
         paths.append(path)
         n_steps_total += len(path['observations'])
         n_trajs += 1
         # don't we also want the option to resample z every transition?
         if n_trajs % resample == 0:
             policy.sample_z()
     return paths, n_steps_total
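All of these examples share the same wrapping pattern. For reference, a minimal self-contained sketch of it, assuming rlkit-style imports (the exact import path can vary between rlkit versions, and the dimensions and hidden sizes below are illustrative, not taken from any snippet):

import numpy as np
from rlkit.torch.sac.policies import TanhGaussianPolicy, MakeDeterministic

# A stochastic tanh-Gaussian actor, as used throughout these examples.
stochastic_policy = TanhGaussianPolicy(obs_dim=4, action_dim=2, hidden_sizes=[64, 64])

# MakeDeterministic wraps the stochastic policy so that get_action() returns the
# distribution's mean action instead of a sample -- the usual choice at evaluation time.
eval_policy = MakeDeterministic(stochastic_policy)

action, _ = eval_policy.get_action(np.zeros(4, dtype=np.float32))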
Example No. 2
def all_steps(policy,
              qf1,
              qf2,
              env,
              num_curriculum_eps=1,
              curr_k=10,
              curr_beta=0.9,
              curr_ep_length=300,
              bootstrap_value=True,
              use_cuda=True):
    if use_cuda:
        device = 'cuda'
    else:
        device = 'cpu'
    logger.log('\nStarting adaptive curriculum.\n')
    p = {}
    det_policy = MakeDeterministic(policy)
    for k, v in env.curr_grid.items():
        capabilities = []
        for i, init in enumerate(v):
            logger.log('Initialization - Curriculum {}. Iter {} -> {}'.format(
                k, i, init))
            for e in range(num_curriculum_eps):
                accum_c = []
                o, d, ep_ret, ep_len = env.reset(curr_init=init,
                                                 init_strategy='adaptative',
                                                 curr_var=k), False, 0, 0
                c = 0
                while not (d or (ep_len == curr_ep_length)):
                    # Take deterministic actions at test time
                    a, _ = det_policy.get_action(o)
                    o, r, d, _ = env.step(a)
                    if bootstrap_value:
                        o = torch.Tensor(o).to(device)
                        dist = policy(o.view(1, -1))
                        new_obs_actions, _ = dist.rsample_and_logprob()
                        q_new_actions = torch.min(
                            qf1(o.view(1, -1), new_obs_actions),
                            qf2(o.view(1, -1), new_obs_actions),
                        )
                        # Estimates value
                        v = q_new_actions.mean()
                        if not use_cuda:
                            c += v.detach().numpy()
                        else:
                            c += v.detach().cpu().numpy()
                    else:
                        # Uses returns instead
                        ep_ret += r
                        c = ep_ret
                    ep_len += 1
                accum_c.append(c)
            capabilities.append(np.mean(accum_c))
        max_capability = np.max(capabilities)
        f = np.exp(-curr_k *
                   np.abs(np.array(capabilities) / max_capability - curr_beta))
        p[k] = f / f.sum()
    return p
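For intuition, the capability-to-probability weighting at the end of all_steps can be reproduced in isolation. A small standalone sketch with illustrative numbers (not taken from any actual run):

import numpy as np

capabilities = np.array([2.0, 5.0, 10.0])        # illustrative mean capabilities per initialization
curr_k, curr_beta = 10, 0.9

ratios = capabilities / capabilities.max()       # [0.2, 0.5, 1.0]
f = np.exp(-curr_k * np.abs(ratios - curr_beta)) # peaks where the ratio is closest to curr_beta
p = f / f.sum()                                  # approx. [0.002, 0.047, 0.950]

Initializations whose relative capability sits nearest curr_beta receive most of the sampling probability, which appears to be the intended adaptive-curriculum behaviour.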
Example No. 3
def simulate_policy(args):
    data = pickle.load(open(args.file, "rb"))
    policy_key = args.policy_type + '/policy'
    if policy_key in data:
        policy = data[policy_key]
    else:
        raise Exception("No policy found in loaded dict. Keys: {}".format(
            data.keys()))

    env_key = args.env_type + '/env'
    if env_key in data:
        env = data[env_key]
    else:
        raise Exception("No environment found in loaded dict. Keys: {}".format(
            data.keys()))

    if isinstance(env, RemoteRolloutEnv):
        env = env._wrapped_env
    print("Policy loaded")

    if args.enable_render:
        # some environments need to be reconfigured for visualization
        env.enable_render()
    if args.gpu:
        ptu.set_gpu_mode(True)
    if hasattr(policy, "to"):
        policy.to(ptu.device)
    if hasattr(env, "vae"):
        env.vae.to(ptu.device)

    if args.deterministic:
        policy = MakeDeterministic(policy)

    if args.pause:
        import ipdb
        ipdb.set_trace()
    if isinstance(policy, PyTorchModule):
        policy.train(False)
    paths = []
    while True:
        paths.append(
            rollout(
                env,
                policy,
                max_path_length=args.H,
                render=not args.hide,
            ))
        if args.log_diagnostics:
            if hasattr(env, "log_diagnostics"):
                env.log_diagnostics(paths, logger)
            for k, v in eval_util.get_generic_path_information(paths).items():
                logger.record_tabular(k, v)
            logger.dump_tabular()
Example No. 4
def experiment(variant):
    with open('expert_demos_listing.yaml', 'r') as f:
        listings = yaml.load(f.read())
    demos_path = listings[variant['expert_name']]['file_paths'][
        variant['expert_idx']]
    print(demos_path)
    buffer_save_dict = joblib.load(demos_path)
    target_state_buffer = buffer_save_dict['data']
    # target_state_buffer /= variant['rescale']
    state_indices = torch.LongTensor(variant['state_indices'])

    env_specs = variant['env_specs']
    env = get_env(env_specs)
    env.seed(env_specs['eval_env_seed'])

    print('\n\nEnv: {}'.format(env_specs['env_name']))
    print('kwargs: {}'.format(env_specs['env_kwargs']))
    print('Obs Space: {}'.format(env.observation_space))
    print('Act Space: {}\n\n'.format(env.action_space))

    policy = joblib.load(variant['policy_checkpoint'])['exploration_policy']
    if variant['eval_deterministic']:
        policy = MakeDeterministic(policy)
    policy.to(ptu.device)

    eval_sampler = PathSampler(env,
                               policy,
                               variant['num_eval_steps'],
                               variant['max_path_length'],
                               no_terminal=variant['no_terminal'],
                               render=variant['render'],
                               render_kwargs=variant['render_kwargs'])
    test_paths = eval_sampler.obtain_samples()
    obs = []
    for path in test_paths:
        obs += path['observations']
    x = [o[0] for o in obs]
    y = [o[1] for o in obs]

    fig, ax = plt.subplots(figsize=(6, 6))
    plt.scatter(x, y)
    plt.xlim(-1.25, 20)
    plt.ylim(-1.25, 10)
    ax.set_yticks([0, 5, 10])
    ax.set_xticks([0, 5, 10, 15, 20])
    plt.savefig('./figs/' + variant['env_specs']['task_name'] + '.pdf',
                bbox_inches='tight')

    return 1
Example No. 5
    def obtain_samples(self,
                       deterministic=False,
                       num_samples=None,
                       num_rollouts=None,
                       is_online=False):
        policy = MakeDeterministic(
            self.policy) if deterministic else self.policy
        paths = []
        n_steps_total = 0
        max_samp = self.max_samples
        if num_samples is not None:
            max_samp = num_samples

        # import pdb; pdb.set_trace()
        while n_steps_total + self.max_path_length <= max_samp:
            if num_rollouts is not None and num_rollouts <= len(paths):
                break

            path = rollout(self.env,
                           policy,
                           max_path_length=self.max_path_length,
                           is_online=is_online)
            paths.append(path)
            n_steps_total += len(path['observations'])
        return paths
Example No. 6
def experiment(checkpoint, deterministic=False):
    d = joblib.load(checkpoint)
    print('Epoch = %d' % d['epoch'])

    print(d)

    algorithm = d['algorithm']
    algorithm.render = True
    print(algorithm.discriminator)

    if deterministic:
        algorithm.exploration_policy = MakeDeterministic(
            algorithm.exploration_policy)

    # print(algorithm.grad_pen_weight)

    # algorithm.do_not_train = True
    # algorithm.do_not_eval = True
    # for i in range(100):
    #     algorithm.generate_exploration_rollout()

    algorithm.num_steps_between_updates = 1000000
    algorithm.train_online()

    return 1
Example No. 7
    def __init__(self,
                 latent_dim,
                 context_encoder,
                 policy,
                 reward_predictor,
                 use_next_obs_in_context=False,
                 _debug_ignore_context=False,
                 _debug_do_not_sqrt=False,
                 _debug_use_ground_truth_context=False):
        super().__init__()
        self.latent_dim = latent_dim

        self.context_encoder = context_encoder
        self.policy = policy
        self.reward_predictor = reward_predictor
        self.deterministic_policy = MakeDeterministic(self.policy)
        self._debug_ignore_context = _debug_ignore_context
        self._debug_use_ground_truth_context = _debug_use_ground_truth_context

        # self.recurrent = kwargs['recurrent']
        # self.use_ib = kwargs['use_information_bottleneck']
        # self.sparse_rewards = kwargs['sparse_rewards']
        self.use_next_obs_in_context = use_next_obs_in_context

        # initialize buffers for z dist and z
        # use buffers so latent context can be saved along with model weights
        self.register_buffer('z', torch.zeros(1, latent_dim))
        self.register_buffer('z_means', torch.zeros(1, latent_dim))
        self.register_buffer('z_vars', torch.zeros(1, latent_dim))

        self.z_means = None
        self.z_vars = None
        self.context = None
        self.z = None

        # rp = reward predictor
        # TODO: add back in reward predictor code
        self.z_means_rp = None
        self.z_vars_rp = None
        self.z_rp = None
        self.context_encoder_rp = context_encoder
        self._use_context_encoder_snapshot_for_reward_pred = False

        self.latent_prior = torch.distributions.Normal(
            ptu.zeros(self.latent_dim), ptu.ones(self.latent_dim))

        self._debug_do_not_sqrt = _debug_do_not_sqrt
Example No. 8
    def obtain_samples(self,
                       deterministic=False,
                       max_samples=np.inf,
                       max_trajs=np.inf,
                       accum_context_for_agent=False,
                       resample=1,
                       context_agent=None,
                       split=False):
        """
        Obtains samples in the environment until we reach either max_samples transitions or
        max_trajs trajectories.
        The resample argument specifies how often (in trajectories) the agent will resample its context.
        """
        assert max_samples < np.inf or max_trajs < np.inf, "either max_samples or max_trajs must be finite"
        policy = MakeDeterministic(
            self.policy) if deterministic else self.policy
        paths = []
        n_steps_total = 0
        n_trajs = 0
        policy.reset_RNN()
        self.env.reset_mask()
        if not split:
            path = exprolloutsimple(
                self.env,
                policy,
                max_path_length=self.max_path_length,
                max_trajs=max_trajs,
                accum_context_for_agent=accum_context_for_agent,
                context_agent=context_agent)
            paths.append(path)
            n_steps_total += len(path['observations'])
            n_trajs += 1

            return paths, n_steps_total
        else:
            path = exprollout_splitsimple(
                self.env,
                policy,
                max_path_length=self.max_path_length,
                max_trajs=max_trajs,
                accum_context_for_agent=accum_context_for_agent,
                context_agent=context_agent)
            n_steps_total += self.max_path_length * max_trajs
            n_trajs += max_trajs

            return path, n_steps_total
Example No. 9
    def obtain_samples(self,
                       deterministic=False,
                       max_samples=np.inf,
                       max_trajs=np.inf,
                       accum_context=True,
                       resample=1,
                       testing=False):
        assert max_samples < np.inf or max_trajs < np.inf, "either max_samples or max_trajs must be finite"
        policy = MakeDeterministic(
            self.policy) if deterministic else self.policy
        paths = []
        n_steps_total = 0
        n_trajs = 0

        if self.itr <= self.num_train_itr:
            if self.tandem_train:
                self._train(policy, accum_context)
                self.itr += 1
            else:
                for _ in range(self.num_train_itr):
                    self._train(policy, accum_context)
                    self.itr += 1

        while n_steps_total < max_samples and n_trajs < max_trajs:
            if testing:
                path = rollout(self.env,
                               policy,
                               max_path_length=self.max_path_length,
                               accum_context=accum_context)
            else:
                path = rollout(self.model,
                               policy,
                               max_path_length=self.max_path_length,
                               accum_context=accum_context)

            # save the latent context that generated this trajectory
            path['context'] = policy.z.detach().cpu().numpy()
            paths.append(path)
            n_steps_total += len(path['observations'])
            n_trajs += 1
            # don't we also want the option to resample z every transition?
            if n_trajs % resample == 0:
                policy.sample_z()

        return paths, n_steps_total
Example No. 10
def experiment(variant):
    env = Point2DEnv(**variant['env_kwargs'])
    env = FlatGoalEnv(env)
    env = NormalizedBoxEnv(env)

    action_dim = int(np.prod(env.action_space.shape))
    obs_dim = int(np.prod(env.observation_space.shape))

    qf1 = ConcatMlp(input_size=obs_dim + action_dim,
                    output_size=1,
                    **variant['qf_kwargs'])
    qf2 = ConcatMlp(input_size=obs_dim + action_dim,
                    output_size=1,
                    **variant['qf_kwargs'])
    target_qf1 = ConcatMlp(input_size=obs_dim + action_dim,
                           output_size=1,
                           **variant['qf_kwargs'])
    target_qf2 = ConcatMlp(input_size=obs_dim + action_dim,
                           output_size=1,
                           **variant['qf_kwargs'])
    policy = TanhGaussianPolicy(obs_dim=obs_dim,
                                action_dim=action_dim,
                                **variant['policy_kwargs'])
    eval_env = expl_env = env

    eval_policy = MakeDeterministic(policy)
    eval_path_collector = MdpPathCollector(
        eval_env,
        eval_policy,
    )
    expl_path_collector = MdpPathCollector(
        expl_env,
        policy,
    )
    replay_buffer = EnvReplayBuffer(
        variant['replay_buffer_size'],
        expl_env,
    )
    trainer = TwinSACTrainer(env=eval_env,
                             policy=policy,
                             qf1=qf1,
                             qf2=qf2,
                             target_qf1=target_qf1,
                             target_qf2=target_qf2,
                             **variant['trainer_kwargs'])
    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        data_buffer=replay_buffer,
        **variant['algo_kwargs'])
    algorithm.to(ptu.device)
    algorithm.train()
Example No. 11
def get_sac(evaluation_environment, parameters):
    """
    :param env - environment to get action shape
    :param parameters: dict with keys -
    hidden_sizes,
    sac_trainer_parameters
    :return: sac_policy, eval_policy, trainer
    """
    obs_dim = evaluation_environment.observation_space.low.size
    action_dim = evaluation_environment.action_space.low.size

    hidden_sizes_qf = parameters['hidden_sizes_qf']
    hidden_sizes_policy = parameters['hidden_sizes_policy']

    qf1 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=hidden_sizes_qf,
    )

    qf2 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=hidden_sizes_qf,
    )

    target_qf1 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=hidden_sizes_qf,
    )

    target_qf2 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=hidden_sizes_qf,
    )

    sac_policy = TanhGaussianPolicy(
        obs_dim=obs_dim,
        action_dim=action_dim,
        hidden_sizes=hidden_sizes_policy,
    )

    eval_policy = MakeDeterministic(sac_policy)

    trainer = SACTrainer(env=evaluation_environment,
                         policy=sac_policy,
                         qf1=qf1,
                         qf2=qf2,
                         target_qf1=target_qf1,
                         target_qf2=target_qf2,
                         **parameters['trainer_params'])

    return sac_policy, eval_policy, trainer
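A hedged usage sketch for get_sac: the top-level keys follow its docstring and the **parameters['trainer_params'] expansion above, while the names inside trainer_params assume rlkit's standard SACTrainer arguments; all values are illustrative and evaluation_environment is assumed to be a gym-style environment created elsewhere.

parameters = dict(
    hidden_sizes_qf=[256, 256],
    hidden_sizes_policy=[256, 256],
    # Assumed SACTrainer kwargs; adjust to the trainer actually in use.
    trainer_params=dict(discount=0.99, soft_target_tau=5e-3, policy_lr=3e-4, qf_lr=3e-4),
)
sac_policy, eval_policy, trainer = get_sac(evaluation_environment, parameters)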
Example No. 12
def experiment(variant):
    num_agent = variant['num_agent']
    from cartpole import CartPoleEnv
    expl_env = CartPoleEnv(mode=3)
    eval_env = CartPoleEnv(mode=3)
    obs_dim = eval_env.observation_space.low.size
    action_dim = eval_env.action_space.low.size

    policy_n, eval_policy_n, qf1_n, target_qf1_n, qf2_n, target_qf2_n = \
        [], [], [], [], [], []
    for i in range(num_agent):
        policy = TanhGaussianPolicy(obs_dim=obs_dim,
                                    action_dim=action_dim,
                                    **variant['policy_kwargs'])
        eval_policy = MakeDeterministic(policy)
        qf1 = FlattenMlp(input_size=(obs_dim * num_agent +
                                     action_dim * num_agent),
                         output_size=1,
                         **variant['qf_kwargs'])
        target_qf1 = copy.deepcopy(qf1)
        qf2 = FlattenMlp(input_size=(obs_dim * num_agent +
                                     action_dim * num_agent),
                         output_size=1,
                         **variant['qf_kwargs'])
        target_qf2 = copy.deepcopy(qf2)
        policy_n.append(policy)
        eval_policy_n.append(eval_policy)
        qf1_n.append(qf1)
        target_qf1_n.append(target_qf1)
        qf2_n.append(qf2)
        target_qf2_n.append(target_qf2)

    eval_path_collector = MAMdpPathCollector(eval_env, eval_policy_n)
    expl_path_collector = MAMdpPathCollector(expl_env, policy_n)
    replay_buffer = MAEnvReplayBuffer(variant['replay_buffer_size'],
                                      expl_env,
                                      num_agent=num_agent)
    trainer = MASACTrainer(env=expl_env,
                           qf1_n=qf1_n,
                           target_qf1_n=target_qf1_n,
                           qf2_n=qf2_n,
                           target_qf2_n=target_qf2_n,
                           policy_n=policy_n,
                           **variant['trainer_kwargs'])
    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        log_path_function=get_generic_ma_path_information,
        **variant['algorithm_kwargs'])
    algorithm.to(ptu.device)
    algorithm.train()
Example No. 13
    def __init__(self,
                 env,
                 policy,
                 qf,
                 vf,
                 policy_lr=1e-3,
                 qf_lr=1e-3,
                 vf_lr=1e-3,
                 policy_mean_reg_weight=1e-3,
                 policy_std_reg_weight=1e-3,
                 policy_pre_activation_weight=0.,
                 optimizer_class=optim.Adam,
                 train_policy_with_reparameterization=True,
                 soft_target_tau=1e-2,
                 plotter=None,
                 render_eval_paths=False,
                 eval_deterministic=True,
                 **kwargs):
        if eval_deterministic:
            eval_policy = MakeDeterministic(policy)
        else:
            eval_policy = policy
        super().__init__(env=env,
                         exploration_policy=policy,
                         eval_policy=eval_policy,
                         **kwargs)
        self.policy = policy
        self.qf = qf
        self.vf = vf
        self.train_policy_with_reparameterization = (
            train_policy_with_reparameterization)
        self.soft_target_tau = soft_target_tau
        self.policy_mean_reg_weight = policy_mean_reg_weight
        self.policy_std_reg_weight = policy_std_reg_weight
        self.policy_pre_activation_weight = policy_pre_activation_weight
        self.plotter = plotter
        self.render_eval_paths = render_eval_paths

        self.target_vf = vf.copy()
        self.qf_criterion = nn.MSELoss()
        self.vf_criterion = nn.MSELoss()

        self.policy_optimizer = optimizer_class(
            self.policy.parameters(),
            lr=policy_lr,
        )
        self.qf_optimizer = optimizer_class(
            self.qf.parameters(),
            lr=qf_lr,
        )
        self.vf_optimizer = optimizer_class(
            self.vf.parameters(),
            lr=vf_lr,
        )
Example No. 14
def experiment(variant):
    env_specs = variant['env_specs']
    env = get_env(env_specs)
    env.seed(env_specs['eval_env_seed'])

    print('\n\nEnv: {}'.format(env_specs['env_name']))
    print('kwargs: {}'.format(env_specs['env_kwargs']))
    print('Obs Space: {}'.format(env.observation_space))
    print('Act Space: {}\n\n'.format(env.action_space))

    if variant['scale_env_with_demo_stats']:
        with open('expert_demos_listing.yaml', 'r') as f:
            listings = yaml.load(f.read())
        expert_demos_path = listings[variant['expert_name']]['file_paths'][
            variant['expert_idx']]
        buffer_save_dict = joblib.load(expert_demos_path)
        env = ScaledEnv(
            env,
            obs_mean=buffer_save_dict['obs_mean'],
            obs_std=buffer_save_dict['obs_std'],
            acts_mean=buffer_save_dict['acts_mean'],
            acts_std=buffer_save_dict['acts_std'],
        )

    policy = joblib.load(variant['policy_checkpoint'])['exploration_policy']
    if variant['eval_deterministic']:
        policy = MakeDeterministic(policy)
    policy.to(ptu.device)

    eval_sampler = PathSampler(env,
                               policy,
                               variant['num_eval_steps'],
                               variant['max_path_length'],
                               no_terminal=variant['no_terminal'],
                               render=variant['render'],
                               render_kwargs=variant['render_kwargs'])
    test_paths = eval_sampler.obtain_samples()
    average_returns = eval_util.get_average_returns(test_paths)
    print(average_returns)

    return 1
Example No. 15
def experiment(variant):

    expl_env = get_env()
    eval_env = get_env()

    post_epoch_funcs = []
    M = variant['layer_size']
    trainer = get_sac_model(env=eval_env, hidden_sizes=[M, M])
    policy = trainer.policy
    eval_policy = MakeDeterministic(policy)
    eval_path_collector = MdpPathCollector(
        eval_env,
        eval_policy,
    )
    expl_path_collector = MdpPathCollector(
        expl_env,
        policy,
    )
    replay_buffer = EnvReplayBuffer(
        variant['replay_buffer_size'],
        expl_env,
    )

    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        **variant['algorithm_kwargs'])
    columns = ['Epoch', 'mean', 'std']
    eval_result = pd.DataFrame(columns=columns)
    eval_output_csv = os.path.join(variant['log_dir'], 'eval_result.csv')

    def post_epoch_func(self, epoch):
        nonlocal eval_result
        nonlocal policy
        print(f'-------------post_epoch_func start-------------')
        eval_result = my_eval_policy(
            env=get_env(),
            algorithm=self,
            epoch=epoch,
            eval_result=eval_result,
            output_csv=eval_output_csv,
        )
        print(f'-------------post_epoch_func done-------------')

    algorithm.post_epoch_funcs = [
        post_epoch_func,
    ]
    algorithm.to(ptu.device)
    algorithm.train()
Example No. 16
def run_sac(base_expl_env, base_eval_env, variant):
    expl_env = FlatGoalEnv(base_expl_env, append_goal_to_obs=True)
    eval_env = FlatGoalEnv(base_eval_env, append_goal_to_obs=True)
    obs_dim = expl_env.observation_space.low.size
    action_dim = eval_env.action_space.low.size

    M = variant["layer_size"]
    num_hidden = variant["num_hidden_layers"]
    qf1 = FlattenMlp(input_size=obs_dim + action_dim,
                     output_size=1,
                     hidden_sizes=[M] * num_hidden)
    qf2 = FlattenMlp(input_size=obs_dim + action_dim,
                     output_size=1,
                     hidden_sizes=[M] * num_hidden)
    target_qf1 = FlattenMlp(input_size=obs_dim + action_dim,
                            output_size=1,
                            hidden_sizes=[M] * num_hidden)
    target_qf2 = FlattenMlp(input_size=obs_dim + action_dim,
                            output_size=1,
                            hidden_sizes=[M] * num_hidden)
    policy = TanhGaussianPolicy(obs_dim=obs_dim,
                                action_dim=action_dim,
                                hidden_sizes=[M] * num_hidden)
    eval_policy = MakeDeterministic(policy)
    eval_path_collector = MdpPathCollector(
        eval_env,
        eval_policy,
    )
    expl_path_collector = MdpPathCollector(
        expl_env,
        policy,
    )
    replay_buffer = EnvReplayBuffer(
        variant["replay_buffer_size"],
        expl_env,
    )
    trainer = SACTrainer(env=eval_env,
                         policy=policy,
                         qf1=qf1,
                         qf2=qf2,
                         target_qf1=target_qf1,
                         target_qf2=target_qf2,
                         **variant["trainer_kwargs"])
    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        **variant["algorithm_kwargs"])
    algorithm.train()
Example No. 17
 def obtain_samples(self, deterministic=False, num_samples=None, is_online=False):
     policy = MakeDeterministic(self.policy) if deterministic else self.policy
     paths = []
     n_steps_total = 0
     max_samp = self.max_samples
     if num_samples is not None:
         max_samp = num_samples
     while n_steps_total + self.max_path_length < max_samp:
         path = rollout(
             self.env, policy, max_path_length=self.max_path_length, is_online=is_online)
         paths.append(path)
         n_steps_total += len(path['observations'])
     return paths
Example No. 18
def experiment(specs):
    with open(path.join(specs['specific_exp_dir'], 'variant.json'), 'r') as f:
        variant = json.load(f)
    variant['algo_params']['do_not_train'] = True
    variant['seed'] = specs['seed']
    policy = joblib.load(path.join(specs['specific_exp_dir'],
                                   'params.pkl'))['exploration_policy']

    assert False, 'Do you really wanna make it deterministic?'
    policy = MakeDeterministic(policy)

    env_specs = variant['env_specs']
    env, _ = get_env(env_specs)
    training_env, _ = get_env(env_specs)

    variant['algo_params']['replay_buffer_size'] = int(
        np.floor(specs['num_episodes'] *
                 variant['algo_params']['max_path_length'] /
                 specs['subsampling']))
    # Hack until I figure out how things are gonna be in general then I'll clean it up
    if 'policy_uses_pixels' not in variant['algo_params']:
        variant['algo_params']['policy_uses_pixels'] = False
    if 'policy_uses_task_params' not in variant['algo_params']:
        variant['algo_params']['policy_uses_task_params'] = False
    if 'concat_task_params_to_policy_obs' not in variant['algo_params']:
        variant['algo_params']['concat_task_params_to_policy_obs'] = False
    replay_buffer = ExpertReplayBuffer(
        variant['algo_params']['replay_buffer_size'],
        env,
        subsampling=specs['subsampling'],
        policy_uses_pixels=variant['algo_params']['policy_uses_pixels'],
        policy_uses_task_params=variant['algo_params']
        ['policy_uses_task_params'],
        concat_task_params_to_policy_obs=variant['algo_params']
        ['concat_task_params_to_policy_obs'],
    )
    variant['algo_params']['freq_saving'] = 1

    algorithm = ExpertTrajGeneratorAlgorithm(
        env=env,
        training_env=training_env,
        exploration_policy=policy,
        replay_buffer=replay_buffer,
        max_num_episodes=specs['num_episodes'],
        **variant['algo_params'])

    if ptu.gpu_enabled():
        algorithm.cuda()
    algorithm.train()

    return 1
Example No. 19
    def __init__(self,
                 env,
                 policy,
                 discriminator,
                 policy_optimizer,
                 expert_replay_buffer,
                 disc_optim_batch_size=32,
                 policy_optim_batch_size=1000,
                 disc_lr=1e-3,
                 disc_optimizer_class=optim.Adam,
                 use_grad_pen=True,
                 grad_pen_weight=10,
                 plotter=None,
                 render_eval_paths=False,
                 eval_deterministic=True,
                 **kwargs):
        assert disc_lr != 1e-3, 'Just checking that this is being taken from the spec file'
        if eval_deterministic:
            eval_policy = MakeDeterministic(policy)
        else:
            eval_policy = policy
        super().__init__(env=env,
                         exploration_policy=policy,
                         eval_policy=eval_policy,
                         expert_replay_buffer=expert_replay_buffer,
                         policy_optimizer=policy_optimizer,
                         **kwargs)

        self.discriminator = discriminator
        self.rewardf_eval_statistics = None
        self.disc_optimizer = disc_optimizer_class(
            self.discriminator.parameters(),
            lr=disc_lr,
        )

        self.disc_optim_batch_size = disc_optim_batch_size
        self.policy_optim_batch_size = policy_optim_batch_size

        self.bce = nn.BCEWithLogitsLoss()
        self.bce_targets = torch.cat([
            torch.ones(self.disc_optim_batch_size, 1),
            torch.zeros(self.disc_optim_batch_size, 1)
        ],
                                     dim=0)
        self.bce_targets = Variable(self.bce_targets)
        if ptu.gpu_enabled():
            self.bce.cuda()
            self.bce_targets = self.bce_targets.cuda()

        self.use_grad_pen = use_grad_pen
        self.grad_pen_weight = grad_pen_weight
Example No. 20
def sac(variant):
    expl_env = gym.make(variant["env_name"])
    eval_env = gym.make(variant["env_name"])
    expl_env.seed(variant["seed"])
    eval_env.set_eval()

    mode = variant["mode"]
    archi = variant["archi"]
    if mode == "her":
        variant["her"] = dict(
            observation_key="observation",
            desired_goal_key="desired_goal",
            achieved_goal_key="achieved_goal",
            representation_goal_key="representation_goal",
        )

    replay_buffer = get_replay_buffer(variant, expl_env)
    qf1, qf2, target_qf1, target_qf2, policy, shared_base = get_networks(
        variant, expl_env)
    expl_policy = policy
    eval_policy = MakeDeterministic(policy)

    expl_path_collector, eval_path_collector = get_path_collector(
        variant, expl_env, eval_env, expl_policy, eval_policy)

    mode = variant["mode"]
    trainer = SACTrainer(
        env=eval_env,
        policy=policy,
        qf1=qf1,
        qf2=qf2,
        target_qf1=target_qf1,
        target_qf2=target_qf2,
        **variant["trainer_kwargs"],
    )
    if mode == "her":
        trainer = HERTrainer(trainer)
    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        **variant["algorithm_kwargs"],
    )

    algorithm.to(ptu.device)
    algorithm.train()
Example No. 21
def main():
    arguments = sys.argv[1:]
    options = parse_arguments(arguments)
    config = load_config(options.config)
    checkpoint_directory = os.path.join(config.checkpoint_dir, options.name)
    checkpoint_iteration_directory = os.path.join(checkpoint_directory, f'iteration_{options.load}')
    checkpoint = torch.load(os.path.join(checkpoint_iteration_directory, 'model.pt'))
    environment = make_environment(config, options)
    policy = get_policy(config.policy_type, config.policy, environment)
    policy.load_state_dict(checkpoint['policy'])
    policy = MakeDeterministic(policy)
    result_directory = os.path.join('results', options.name)
    os.makedirs(result_directory, exist_ok=True)
    evaluate(environment, policy, result_directory, options)
    save_options(options)
Example No. 22
    def __init__(
            self,
            env_sampler,
            policy,
            qf,
            discrete_policy=True,

            qf_lr=1e-3,
            optimizer_class=optim.Adam,

            sql_one_over_alpha=1.,

            soft_target_tau=1e-2,
            plotter=None,
            render_eval_paths=False,
            eval_deterministic=True,
            **kwargs
    ):
        if eval_deterministic:
            eval_policy = MakeDeterministic(policy)
        else:
            eval_policy = policy
        super().__init__(
            env_sampler=env_sampler,
            exploration_policy=policy,
            eval_policy=eval_policy,
            **kwargs
        )
        self.qf = qf
        self.policy = policy
        self.soft_target_tau = soft_target_tau
        # self.policy_mean_reg_weight = policy_mean_reg_weight
        # self.policy_std_reg_weight = policy_std_reg_weight
        # self.policy_pre_activation_weight = policy_pre_activation_weight
        self.plotter = plotter
        self.render_eval_paths = render_eval_paths
        # self.discrete_policy = discrete_policy

        self.sql_one_over_alpha = sql_one_over_alpha

        self.target_qf = qf.copy()
        self.eval_statistics = None

        self.qf_optimizer = optimizer_class(
            self.qf.parameters(),
            lr=qf_lr,
        )
Example No. 23
def eval_alg(policy, env, max_path_length, num_eval_rollouts, env_seed, eval_deterministic=False):
    if eval_deterministic:
        policy = MakeDeterministic(policy)
    
    env.seed(env_seed)

    eval_sampler = InPlacePathSampler(
        env=env,
        policy=policy,
        max_samples=max_path_length * (num_eval_rollouts + 1),
        max_path_length=max_path_length, policy_uses_pixels=False,
        policy_uses_task_params=False,
        concat_task_params_to_policy_obs=False
    )
    test_paths = eval_sampler.obtain_samples()
    path_trajs = [np.array([d['xy_pos'] for d in path["env_infos"]]) for path in test_paths]
    return {'path_trajs': path_trajs}
Example No. 24
def eval_alg(policy,
             env,
             num_eval_rollouts,
             eval_deterministic=False,
             max_path_length=1000):
    if eval_deterministic:
        policy = MakeDeterministic(policy)

    eval_sampler = InPlacePathSampler(env=env,
                                      policy=policy,
                                      max_samples=max_path_length *
                                      (num_eval_rollouts + 1),
                                      max_path_length=max_path_length,
                                      policy_uses_pixels=False,
                                      policy_uses_task_params=False,
                                      concat_task_params_to_policy_obs=False)
    test_paths = eval_sampler.obtain_samples()
    average_returns = get_average_returns(test_paths)
    return average_returns
Example No. 25
def experiment(variant):
    expl_env = NormalizedBoxEnv(HalfCheetahEnv())
    eval_env = NormalizedBoxEnv(HalfCheetahEnv())
    obs_dim = expl_env.observation_space.low.size
    action_dim = eval_env.action_space.low.size

    M = variant["layer_size"]
    qf1 = FlattenMlp(input_size=obs_dim + action_dim,
                     output_size=1,
                     hidden_sizes=[M, M])
    qf2 = FlattenMlp(input_size=obs_dim + action_dim,
                     output_size=1,
                     hidden_sizes=[M, M])
    target_qf1 = FlattenMlp(input_size=obs_dim + action_dim,
                            output_size=1,
                            hidden_sizes=[M, M])
    target_qf2 = FlattenMlp(input_size=obs_dim + action_dim,
                            output_size=1,
                            hidden_sizes=[M, M])
    policy = TanhGaussianPolicy(obs_dim=obs_dim,
                                action_dim=action_dim,
                                hidden_sizes=[M, M])
    eval_policy = MakeDeterministic(policy)
    eval_path_collector = MdpPathCollector(eval_env, eval_policy)
    expl_path_collector = MdpPathCollector(expl_env, policy)
    replay_buffer = EnvReplayBuffer(variant["replay_buffer_size"], expl_env)
    trainer = SACTrainer(env=eval_env,
                         policy=policy,
                         qf1=qf1,
                         qf2=qf2,
                         target_qf1=target_qf1,
                         target_qf2=target_qf2,
                         **variant["trainer_kwargs"])
    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        **variant["algorithm_kwargs"])
    algorithm.to(ptu.device)
    algorithm.train()
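A minimal sketch of a variant dict that satisfies the keys this script reads. The hyperparameter names inside trainer_kwargs and algorithm_kwargs follow the standard rlkit SAC example and should be treated as assumptions about the surrounding code; the values are illustrative.

variant = dict(
    layer_size=256,
    replay_buffer_size=int(1e6),
    trainer_kwargs=dict(
        discount=0.99,
        soft_target_tau=5e-3,
        policy_lr=3e-4,
        qf_lr=3e-4,
        reward_scale=1,
        use_automatic_entropy_tuning=True,
    ),
    algorithm_kwargs=dict(
        num_epochs=100,
        num_eval_steps_per_epoch=5000,
        num_trains_per_train_loop=1000,
        num_expl_steps_per_train_loop=1000,
        min_num_steps_before_training=1000,
        max_path_length=1000,
        batch_size=256,
    ),
)
experiment(variant)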
Example No. 26
File: sac.py Project: xtma/dsac
def experiment(variant):
    dummy_env = make_env(variant['env'])
    obs_dim = dummy_env.observation_space.low.size
    action_dim = dummy_env.action_space.low.size
    expl_env = VectorEnv([
        lambda: make_env(variant['env'])
        for _ in range(variant['expl_env_num'])
    ])
    expl_env.seed(variant["seed"])
    expl_env.action_space.seed(variant["seed"])
    eval_env = SubprocVectorEnv([
        lambda: make_env(variant['env'])
        for _ in range(variant['eval_env_num'])
    ])
    eval_env.seed(variant["seed"])

    M = variant['layer_size']
    qf1 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=[M, M],
    )
    qf2 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=[M, M],
    )
    target_qf1 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=[M, M],
    )
    target_qf2 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=[M, M],
    )
    policy = TanhGaussianPolicy(
        obs_dim=obs_dim,
        action_dim=action_dim,
        hidden_sizes=[M, M],
    )
    eval_policy = MakeDeterministic(policy)
    eval_path_collector = VecMdpPathCollector(
        eval_env,
        eval_policy,
    )
    expl_path_collector = VecMdpStepCollector(
        expl_env,
        policy,
    )
    replay_buffer = TorchReplayBuffer(
        variant['replay_buffer_size'],
        dummy_env,
    )
    trainer = SACTrainer(
        env=eval_env,
        policy=policy,
        qf1=qf1,
        qf2=qf2,
        target_qf1=target_qf1,
        target_qf2=target_qf2,
        **variant['trainer_kwargs'],
    )
    algorithm = TorchVecOnlineRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        **variant['algorithm_kwargs'],
    )
    algorithm.to(ptu.device)
    algorithm.train()
Example No. 27
    def __init__(
            self,
            env,
            policy,
            discriminator,
            policy_optimizer,
            expert_replay_buffer,
            state_only=False,
            traj_based=False,
            disc_num_trajs_per_batch=128,
            disc_samples_per_traj=8,
            disc_optim_batch_size=1024,
            policy_optim_batch_size=1024,
            policy_optim_batch_size_from_expert=0,
            num_update_loops_per_train_call=1000,
            num_disc_updates_per_loop_iter=1,
            num_policy_updates_per_loop_iter=1,
            disc_lr=1e-3,
            disc_momentum=0.0,
            disc_optimizer_class=optim.Adam,
            use_grad_pen=True,
            grad_pen_weight=10,
            disc_ce_grad_clip=0.5,
            disc_gp_grad_clip=10.0,
            use_target_disc=False,
            target_disc=None,
            soft_target_disc_tau=0.005,
            rew_clip_min=None,
            rew_clip_max=None,
            plotter=None,
            render_eval_paths=False,
            eval_deterministic=True,
            use_disc_input_noise=False,
            disc_input_noise_scale_start=0.1,
            disc_input_noise_scale_end=0.0,
            epochs_till_end_scale=50.0,

            # both false is airl, if first one true fairl, else gail, both true is error
            use_exp_rewards=False,
            gail_mode=False,
            **kwargs):
        assert disc_lr != 1e-3, 'Just checking that this is being taken from the spec file'
        if eval_deterministic:
            eval_policy = MakeDeterministic(policy)
        else:
            eval_policy = policy
        assert state_only
        assert not traj_based

        assert not (use_exp_rewards and gail_mode), 'Only one or neither'

        super().__init__(env=env,
                         exploration_policy=policy,
                         eval_policy=eval_policy,
                         expert_replay_buffer=expert_replay_buffer,
                         **kwargs)

        self.state_only = state_only

        self.traj_based = traj_based
        self.disc_num_trajs_per_batch = disc_num_trajs_per_batch
        self.disc_samples_per_traj = disc_samples_per_traj

        self.policy_optimizer = policy_optimizer

        self.discriminator = discriminator
        self.rewardf_eval_statistics = None
        self.disc_optimizer = disc_optimizer_class(
            self.discriminator.parameters(),
            lr=disc_lr,
            betas=(disc_momentum, 0.999))
        print('\n\nDISC MOMENTUM: %f\n\n' % disc_momentum)

        self.disc_optim_batch_size = disc_optim_batch_size
        self.policy_optim_batch_size = policy_optim_batch_size
        self.policy_optim_batch_size_from_expert = policy_optim_batch_size_from_expert

        self.bce = nn.BCEWithLogitsLoss()
        if self.traj_based:
            target_batch_size = self.disc_num_trajs_per_batch * self.disc_samples_per_traj
        else:
            target_batch_size = self.disc_optim_batch_size
        self.bce_targets = torch.cat([
            torch.ones(target_batch_size, 1),
            torch.zeros(target_batch_size, 1)
        ],
                                     dim=0)
        self.bce_targets = Variable(self.bce_targets)
        if ptu.gpu_enabled():
            self.bce.cuda()
            self.bce_targets = self.bce_targets.cuda()

        self.use_grad_pen = use_grad_pen
        self.grad_pen_weight = grad_pen_weight

        self.disc_ce_grad_clip = disc_ce_grad_clip
        self.disc_gp_grad_clip = disc_gp_grad_clip
        self.disc_grad_buffer = {}
        self.disc_grad_buffer_is_empty = True

        self.use_target_disc = use_target_disc
        self.soft_target_disc_tau = soft_target_disc_tau

        if use_target_disc:
            if target_disc is None:
                print('\n\nMAKING TARGET DISC\n\n')
                self.target_disc = deepcopy(self.discriminator)
            else:
                print('\n\nUSING GIVEN TARGET DISC\n\n')
                self.target_disc = target_disc

        self.disc_ce_grad_norm = 0.0
        self.disc_ce_grad_norm_counter = 0.0
        self.max_disc_ce_grad = 0.0

        self.disc_gp_grad_norm = 0.0
        self.disc_gp_grad_norm_counter = 0.0
        self.max_disc_gp_grad = 0.0

        self.use_disc_input_noise = use_disc_input_noise
        self.disc_input_noise_scale_start = disc_input_noise_scale_start
        self.disc_input_noise_scale_end = disc_input_noise_scale_end
        self.epochs_till_end_scale = epochs_till_end_scale

        self.num_update_loops_per_train_call = num_update_loops_per_train_call
        self.num_disc_updates_per_loop_iter = num_disc_updates_per_loop_iter
        self.num_policy_updates_per_loop_iter = num_policy_updates_per_loop_iter

        self.use_exp_rewards = use_exp_rewards
        self.gail_mode = gail_mode
        self.rew_clip_min = rew_clip_min
        self.rew_clip_max = rew_clip_max
        self.clip_min_rews = rew_clip_min is not None
        self.clip_max_rews = rew_clip_max is not None

        d = 8.0
        self._d = d
        self._d_len = np.arange(-d, d + 0.25, 0.25).shape[0]
        self.xy_var = []
        for i in np.arange(-d, d + 0.25, 0.25):
            for j in np.arange(-d, d + 0.25, 0.25):
                self.xy_var.append([float(i), float(j)])
        self.xy_var = np.array(self.xy_var)
        self.xy_var = Variable(ptu.from_numpy(self.xy_var),
                               requires_grad=False)
Example No. 28
def experiment(variant):
    expl_env = make_env()
    eval_env = make_env()
    obs_dim = expl_env.observation_space.low.size
    action_dim = eval_env.action_space.low.size

    M = variant['layer_size']
    qf1 = ConcatMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=[M, M],
    )
    qf2 = ConcatMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=[M, M],
    )
    target_qf1 = ConcatMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=[M, M],
    )
    target_qf2 = ConcatMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=[M, M],
    )
    policy = TanhGaussianPolicy(
        obs_dim=obs_dim,
        action_dim=action_dim,
        hidden_sizes=[M, M],
    )
    eval_policy = MakeDeterministic(policy)
    eval_path_collector = MdpPathCollector(
        eval_env,
        eval_policy,
    )
    expl_path_collector = MdpPathCollector(
        expl_env,
        policy,
    )
    replay_buffer = EnvReplayBuffer(
        variant['replay_buffer_size'],
        expl_env,
    )
    trainer = SACTrainer(env=eval_env,
                         policy=policy,
                         qf1=qf1,
                         qf2=qf2,
                         target_qf1=target_qf1,
                         target_qf2=target_qf2,
                         **variant['trainer_kwargs'])
    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        **variant['algorithm_kwargs'])
    algorithm.to(ptu.device)
    algorithm.train()
Example No. 29
def experiment(variant):
    env = gym.make('RLkitGoalUR-v0')._start_ros_services()
    eval_env = gym.make('RLkitGoalUR-v0')
    expl_env = gym.make('RLkitGoalUR-v0')

    observation_key = 'observation'
    desired_goal_key = 'desired_goal'

    achieved_goal_key = desired_goal_key.replace("desired", "achieved")
    replay_buffer = ObsDictRelabelingBuffer(
        env=eval_env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_buffer_kwargs'])
    obs_dim = eval_env.observation_space.spaces['observation'].low.size
    action_dim = eval_env.action_space.low.size
    goal_dim = eval_env.observation_space.spaces['desired_goal'].low.size
    qf1 = FlattenMlp(input_size=obs_dim + action_dim + goal_dim,
                     output_size=1,
                     **variant['qf_kwargs'])
    qf2 = FlattenMlp(input_size=obs_dim + action_dim + goal_dim,
                     output_size=1,
                     **variant['qf_kwargs'])
    target_qf1 = FlattenMlp(input_size=obs_dim + action_dim + goal_dim,
                            output_size=1,
                            **variant['qf_kwargs'])
    target_qf2 = FlattenMlp(input_size=obs_dim + action_dim + goal_dim,
                            output_size=1,
                            **variant['qf_kwargs'])
    policy = TanhGaussianPolicy(obs_dim=obs_dim + goal_dim,
                                action_dim=action_dim,
                                **variant['policy_kwargs'])
    eval_policy = MakeDeterministic(policy)
    trainer = SACTrainer(env=eval_env,
                         policy=policy,
                         qf1=qf1,
                         qf2=qf2,
                         target_qf1=target_qf1,
                         target_qf2=target_qf2,
                         **variant['sac_trainer_kwargs'])
    trainer = HERTrainer(trainer)
    eval_path_collector = GoalConditionedPathCollector(
        eval_env,
        eval_policy,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )
    expl_path_collector = GoalConditionedPathCollector(
        expl_env,
        policy,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )
    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        **variant['algo_kwargs'])
    algorithm.to(ptu.device)
    algorithm.train()
Example No. 30
def active_representation_learning_experiment(variant):
    import rlkit.torch.pytorch_util as ptu
    from rlkit.data_management.obs_dict_replay_buffer import ObsDictReplayBuffer
    from rlkit.torch.networks import ConcatMlp
    from rlkit.torch.sac.policies import TanhGaussianPolicy
    from rlkit.torch.arl.active_representation_learning_algorithm import \
        ActiveRepresentationLearningAlgorithm
    from rlkit.torch.arl.representation_wrappers import RepresentationWrappedEnv
    from multiworld.core.image_env import ImageEnv
    from rlkit.samplers.data_collector import MdpPathCollector

    preprocess_rl_variant(variant)

    model_class = variant.get('model_class')
    model_kwargs = variant.get('model_kwargs')

    model = model_class(**model_kwargs)
    model.representation_size = 4
    model.imsize = 48
    variant["vae_path"] = model

    reward_params = variant.get("reward_params", dict())
    init_camera = variant.get("init_camera", None)
    env = variant["env_class"](**variant['env_kwargs'])
    image_env = ImageEnv(
        env,
        variant.get('imsize'),
        init_camera=init_camera,
        transpose=True,
        normalize=True,
    )
    env = RepresentationWrappedEnv(
        image_env,
        model,
    )

    uniform_dataset_fn = variant.get('generate_uniform_dataset_fn', None)
    if uniform_dataset_fn:
        uniform_dataset = uniform_dataset_fn(
            **variant['generate_uniform_dataset_kwargs'])
    else:
        uniform_dataset = None

    observation_key = variant.get('observation_key', 'latent_observation')
    desired_goal_key = variant.get('desired_goal_key', 'latent_desired_goal')
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")
    obs_dim = env.observation_space.spaces[observation_key].low.size
    action_dim = env.action_space.low.size
    hidden_sizes = variant.get('hidden_sizes', [400, 300])
    qf1 = ConcatMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=hidden_sizes,
    )
    qf2 = ConcatMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=hidden_sizes,
    )
    target_qf1 = ConcatMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=hidden_sizes,
    )
    target_qf2 = ConcatMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=hidden_sizes,
    )
    policy = TanhGaussianPolicy(
        obs_dim=obs_dim,
        action_dim=action_dim,
        hidden_sizes=hidden_sizes,
    )

    vae = env.vae

    replay_buffer = ObsDictReplayBuffer(env=env,
                                        **variant['replay_buffer_kwargs'])

    model_trainer_class = variant.get('model_trainer_class')
    model_trainer_kwargs = variant.get('model_trainer_kwargs')
    model_trainer = model_trainer_class(
        model,
        **model_trainer_kwargs,
    )
    # vae_trainer = ConvVAETrainer(
    #     env.vae,
    #     **variant['online_vae_trainer_kwargs']
    # )
    assert 'vae_training_schedule' not in variant, "Just put it in algo_kwargs"
    max_path_length = variant['max_path_length']

    trainer = SACTrainer(env=env,
                         policy=policy,
                         qf1=qf1,
                         qf2=qf2,
                         target_qf1=target_qf1,
                         target_qf2=target_qf2,
                         **variant['twin_sac_trainer_kwargs'])
    # trainer = HERTrainer(trainer)
    eval_path_collector = MdpPathCollector(
        env,
        MakeDeterministic(policy),
        # max_path_length,
        # observation_key=observation_key,
        # desired_goal_key=desired_goal_key,
    )
    expl_path_collector = MdpPathCollector(
        env,
        policy,
        # max_path_length,
        # observation_key=observation_key,
        # desired_goal_key=desired_goal_key,
    )

    algorithm = ActiveRepresentationLearningAlgorithm(
        trainer=trainer,
        exploration_env=env,
        evaluation_env=env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        model=model,
        model_trainer=model_trainer,
        uniform_dataset=uniform_dataset,
        max_path_length=max_path_length,
        **variant['algo_kwargs'])

    algorithm.to(ptu.device)
    vae.to(ptu.device)
    algorithm.train()