def generate_transitions(policy, env, num_timesteps_total,
                         max_steps_per_episode, save_path):
    buff = SimpleReplayBuffer(num_timesteps_total,
                              env.observation_space.shape,
                              gym_get_dim(env.action_space),
                              discrete_action_dim=True)

    cur_total = 0
    steps_left_in_episode = 0
    while cur_total < num_timesteps_total:
        # episodes are reset on a fixed schedule of max_steps_per_episode steps;
        # the env's `done` flag is stored in the buffer but does not trigger an early reset
        if steps_left_in_episode == 0:
            steps_left_in_episode = max_steps_per_episode
            obs = env.reset()

        act = policy.get_action(obs)
        next_obs, rew, done, _ = env.step(act)
        buff.add_sample(obs, act, rew, done, next_obs)

        obs = next_obs
        cur_total += 1
        steps_left_in_episode -= 1

    save_dict = dict(
        observations=buff._observations,
        actions=buff._actions,
        rewards=buff._rewards,
        terminals=buff._terminals,
        next_observations=buff._next_obs,
    )
    joblib.dump(save_dict, save_path)

    # debug: scipy.misc.imsave was removed in SciPy >= 1.2; imageio.imwrite is a
    # drop-in replacement for saving the observation images here
    from imageio import imwrite as imsave
    actions = buff._actions
    observations = buff._observations
    for i in range(1000):
        a = actions[i]
        obs = observations[i]
        print(a)
        imsave('junk_vis/tiny/mem_grid_{}.png'.format(i),
               np.transpose(obs, (1, 2, 0)))
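A minimal usage sketch for the helper above (a hypothetical driver, not from the original repo), assuming the classic gym API the function relies on (reset returns only the observation, step returns a 4-tuple); the environment name and random policy are placeholders:

import gym

class RandomPolicy:
    """Placeholder policy exposing the get_action(obs) interface used above."""

    def __init__(self, action_space):
        self.action_space = action_space

    def get_action(self, obs):
        # generate_transitions feeds the return value straight into env.step
        return self.action_space.sample()

env = gym.make('CartPole-v1')
generate_transitions(RandomPolicy(env.action_space), env,
                     num_timesteps_total=10000,
                     max_steps_per_episode=200,
                     save_path='cartpole_transitions.pkl')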
Example #2
 def __init__(
         self,
         max_replay_buffer_size,
         env,
         tasks,
 ):
     """
     :param max_replay_buffer_size:
     :param env:
     :param tasks: for multi-task setting
     """
     self.env = env
     self._ob_space = env.observation_space
     self._action_space = env.action_space
     self.task_buffers = dict([(idx, SimpleReplayBuffer(
         max_replay_buffer_size=max_replay_buffer_size,
         observation_dim=get_dim(self._ob_space),
         action_dim=get_dim(self._action_space),
     )) for idx in tasks])
Example #3
    def __init__(
            self,
            max_replay_buffer_size,
            env,
            tasks,
    ):
        """
        :param max_replay_buffer_size:
        :param env:
        :param tasks: for multi-task setting
        """
        self.env = env
        self._ob_space = env.observation_space
        self._action_space = env.action_space

        self.d = [self.env.d0, self.env.d1]
        self.d_len = [0, 0]
        for idx in tasks:
            self.d_len[idx] = len(self.d[idx]['observations'])
        # d_len = [998999, 998999]

        self.task_buffers = dict([(idx, SimpleReplayBuffer(
            max_replay_buffer_size=self.d_len[idx],
            observation_dim=get_dim(self._ob_space),
            action_dim=get_dim(self._action_space),
        )) for idx in tasks])

        # Fill empty replay buffer for each task with offline data
        for idx in tasks:
            for i in range(self.d_len[idx]):
                d_task = self.d[idx]
                self.task_buffers[idx].add_sample(d_task['observations'][i],
                                                  d_task['actions'][i],
                                                  d_task['rewards'][i],
                                                  d_task['terminals'][i],
                                                  d_task['next_observations'][i])
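    # Hypothetical convenience method (not in the original snippet): sampling
    # presumably just delegates to the chosen task's SimpleReplayBuffer, whose
    # random_batch interface is the one the other examples on this page use.
    def random_batch(self, task, batch_size):
        return self.task_buffers[task].random_batch(batch_size)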
Example #4
 def test_num_steps_can_sample(self):
     buffer = SimpleReplayBuffer(10000, 1, 1)
     buffer.add_sample(1, 1, 1, False, 1)
     buffer.add_sample(1, 1, 1, True, 1)
     buffer.terminate_episode()
     buffer.add_sample(1, 1, 1, False, 1)
     self.assertEqual(buffer.num_steps_can_sample(), 3)
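 # A hypothetical companion test (not from the original suite), assuming random_batch
 # returns a dict keyed the same way the experiment scripts below index their batches:
 def test_random_batch_keys(self):
     buffer = SimpleReplayBuffer(10000, 1, 1)
     for _ in range(5):
         buffer.add_sample(1, 1, 1, False, 1)
     batch = buffer.random_batch(3)
     for key in ('observations', 'actions', 'rewards',
                 'terminals', 'next_observations'):
         self.assertIn(key, batch)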
Example #5
def experiment(exp_specs):
    # Load the data -----------------------------------------------------------
    extra_data_path = exp_specs['extra_data_path']
    train_replay_buffer = joblib.load(extra_data_path)['replay_buffer']
    train_replay_buffer.change_max_size_to_cur_size()
    train_replay_buffer._next_obs = train_replay_buffer._next_obs[:, exp_specs[
        'extra_obs_dim']:]

    print('\nRewards: {} +/- {}'.format(np.mean(train_replay_buffer._rewards),
                                        np.std(train_replay_buffer._rewards)))

    next_obs_mean = np.mean(train_replay_buffer._next_obs, 0)
    next_obs_std = np.std(train_replay_buffer._next_obs, 0)
    print('\nNext Obs:\n{}\n+/-\n{}'.format(next_obs_mean, next_obs_std))

    print('\nAvg Next Obs Square Norm: {}'.format(
        np.mean(np.linalg.norm(train_replay_buffer._next_obs, axis=1)**2)))

    sample_batch = train_replay_buffer.random_batch(
        exp_specs['train_batch_size'])
    obs_dim = sample_batch['observations'].shape[-1]
    act_dim = sample_batch['actions'].shape[-1]

    val_replay_buffer = SimpleReplayBuffer(exp_specs['val_set_size'], obs_dim,
                                           act_dim)
    val_replay_buffer.set_buffer_from_dict(
        train_replay_buffer.sample_and_remove(exp_specs['val_set_size']))
    train_replay_buffer.set_buffer_from_dict(
        train_replay_buffer.sample_and_remove(exp_specs['train_set_size']))

    # Model Definitions -------------------------------------------------------
    model = GenericMap([obs_dim + act_dim],
                       [obs_dim - exp_specs['extra_obs_dim'] + 1],
                       siamese_input=False,
                       siamese_output=False,
                       num_hidden_layers=exp_specs['num_hidden_layers'],
                       hidden_dim=exp_specs['hidden_dim'],
                       act='relu',
                       use_bn=True,
                       deterministic=True)

    gap_model = GenericMap([obs_dim + act_dim], [
        obs_dim - exp_specs['extra_obs_dim'],
        obs_dim - exp_specs['extra_obs_dim']
    ],
                           siamese_input=False,
                           siamese_output=True,
                           num_hidden_layers=exp_specs['num_hidden_layers'],
                           hidden_dim=exp_specs['hidden_dim'],
                           act='relu',
                           use_bn=True,
                           deterministic=True)

    model_optim = Adam(model.parameters(), lr=float(exp_specs['lr']))
    gap_model_optim = Adam(gap_model.parameters(),
                           lr=float(exp_specs['gap_lr']))

    # Train -------------------------------------------------------------------
    model.train()
    for iter_num in range(exp_specs['max_iters']):
        model_optim.zero_grad()
        gap_model_optim.zero_grad()

        batch = train_replay_buffer.random_batch(exp_specs['train_batch_size'])
        batch = convert_numpy_dict_to_pytorch(batch)
        inputs = Variable(
            torch.cat([batch['observations'], batch['actions']], -1))
        outputs = Variable(
            torch.cat([batch['next_observations'], batch['rewards']], -1))
        true_next_obs = Variable(batch['next_observations'])

        preds = model([inputs])[0]
        gap_preds = gap_model([inputs])
        lower, upper = gap_preds[0], gap_preds[1]
        # residual for observations
        # preds = preds + Variable(torch.cat([batch['observations'], torch.zeros(exp_specs['train_batch_size'], 1)], 1))

        loss = torch.mean(torch.sum((outputs - preds)**2, -1))

        lower_loss = torch.mean(torch.sum(F.relu(lower - true_next_obs), -1))
        upper_loss = torch.mean(torch.sum(F.relu(true_next_obs - upper), -1))
        upper_lower_gap_loss = torch.mean(
            torch.sum(torch.abs(upper - lower), -1))

        total_loss = loss + upper_loss + lower_loss + float(
            exp_specs['upper_lower_gap_loss_weight']) * upper_lower_gap_loss

        total_loss.backward()
        model_optim.step()
        gap_model_optim.step()

        if iter_num % exp_specs['freq_val'] == 0:
            model.eval()

            val_batch = val_replay_buffer.random_batch(
                exp_specs['val_batch_size'])
            val_batch = convert_numpy_dict_to_pytorch(val_batch)
            inputs = Variable(
                torch.cat([val_batch['observations'], val_batch['actions']],
                          -1))
            outputs = Variable(
                torch.cat(
                    [val_batch['next_observations'], val_batch['rewards']],
                    -1))
            true_next_obs = Variable(val_batch['next_observations'])

            preds = model([inputs])[0]
            gap_preds = gap_model([inputs])
            lower, upper = gap_preds[0], gap_preds[1]
            # residual for observations
            # pred = preds + Variable(torch.cat([val_batch['observations'], torch.zeros(exp_specs['train_batch_size'], 1)], 1))

            loss = torch.mean(torch.sum((outputs - preds)**2, -1))
            next_obs_loss = torch.mean(
                torch.sum((outputs[:, :-1] - preds[:, :-1])**2, -1))
            rew_loss = torch.mean(
                torch.sum((outputs[:, -1:] - preds[:, -1:])**2, -1))

            lower_loss = torch.mean(
                torch.sum(F.relu(lower - true_next_obs), -1))
            upper_loss = torch.mean(
                torch.sum(F.relu(true_next_obs - upper), -1))
            upper_lower_gap_loss = torch.mean(
                torch.sum(torch.abs(upper - lower), -1))

            pred_over_upper = torch.mean(
                torch.sum(F.relu(preds[:, :-1] - upper), -1))
            pred_under_lower = torch.mean(
                torch.sum(F.relu(lower - preds[:, :-1]), -1))

            adj_next_obs_pred = torch.max(torch.min(preds[:, :-1], upper),
                                          lower)
            adj_next_obs_loss = torch.mean(
                torch.sum((outputs[:, :-1] - adj_next_obs_pred)**2, -1))

            ul_mean = (upper + lower) / 2.0
            ul_mean_as_obs_loss = torch.mean(
                torch.sum((outputs[:, :-1] - ul_mean)**2, -1))

            print('\n')
            print('-' * 20)
            print('Iter %d' % iter_num)
            print('Loss: %.4f' % loss)
            print('Obs Loss: %.4f' % next_obs_loss)
            print('Rew Loss: %.4f' % rew_loss)
            print('\nUpper Loss: %.4f' % upper_loss)
            print('Lower Loss: %.4f' % lower_loss)
            print('UL-Gap Loss: %.4f' % upper_lower_gap_loss)
            print('\nPred Over Upper: %.4f' % pred_over_upper)
            print('Pred Under Lower: %.4f' % pred_under_lower)
            print('\nAdj Obs Loss: %.4f' % adj_next_obs_loss)
            print('\nUL Mean as Obs Loss: %.4f' % ul_mean_as_obs_loss)

            model.train()
Example #6
import os
import numpy as np
from rlkit.data_management.simple_replay_buffer import SimpleReplayBuffer

# get the original
her_demos_path = '/scratch/gobi2/kamyar/oorl_rlkit/expert_demos/larger_object_range_fetch_pick_and_place/larger_object_range_easy_in_the_air_fetch_data_random_1000.npz'
rlkit_buffer_save_dir = '/scratch/gobi2/kamyar/oorl_rlkit/expert_demos/larger_object_range_fetch_pick_and_place'
d = np.load(her_demos_path)

# make the buffer
buffer_size = sum(len(path) for path in d['obs'])
obs_dim = {
    'obs': d['obs'][0][0]['observation'].shape[0],
    'obs_task_params': d['obs'][0][0]['desired_goal'].shape[0]
}
action_dim = len(d['acs'][0][0])
buffer = SimpleReplayBuffer(buffer_size, obs_dim, action_dim)

# fill the buffer
for path_num in range(len(d['obs'])):
    obs = d['obs'][path_num]
    acs = d['acs'][path_num]
    env_infos = d['info'][path_num]

    ep_len = len(obs)
    for j in range(ep_len - 1):
        o = {
            'obs': obs[j]['observation'],
            'obs_task_params': obs[j]['desired_goal']
        }
        a = acs[j]
        r = 0.  # the demos don't come with reward
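        # NOTE: the snippet is truncated here. A hypothetical completion of the inner
        # loop (an assumption, not the original code): the next observation comes from
        # the same path, and the transition is written with the add_sample ordering
        # (obs, action, reward, terminal, next_obs) used elsewhere on this page.
        next_o = {
            'obs': obs[j + 1]['observation'],
            'obs_task_params': obs[j + 1]['desired_goal']
        }
        terminal = bool(j == ep_len - 2)  # assumption: only the last transition is terminal
        buffer.add_sample(o, a, r, terminal, next_o)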
Example #7
    def __init__(
            self,
            env_sampler,
            exploration_policy: ExplorationPolicy,
            neural_process,
            train_neural_process=False,
            latent_repr_mode='concat_params',  # OR concat_samples
            num_latent_samples=5,
            num_epochs=100,
            num_steps_per_epoch=10000,
            num_steps_per_eval=1000,
            num_updates_per_env_step=1,
            batch_size=1024,
            max_path_length=1000,
            discount=0.99,
            replay_buffer_size=1000000,
            reward_scale=1,
            render=False,
            save_replay_buffer=False,
            save_algorithm=False,
            save_environment=False,
            eval_sampler=None,
            eval_policy=None,
            replay_buffer=None,
            epoch_to_start_training=0):
        """
        Base class for RL Algorithms
        :param env_sampler: callable returning (env, env_specs); used to build both
        the evaluation env and a separate training env
        :param exploration_policy: Policy used to explore
        :param num_epochs:
        :param num_steps_per_epoch:
        :param num_steps_per_eval:
        :param num_updates_per_env_step: Used by online training mode.
        :param num_updates_per_epoch: Used by batch training mode.
        :param batch_size:
        :param max_path_length:
        :param discount:
        :param replay_buffer_size:
        :param reward_scale:
        :param render:
        :param save_replay_buffer:
        :param save_algorithm:
        :param save_environment:
        :param eval_sampler:
        :param eval_policy: Policy to evaluate with.
        :param replay_buffer:
        """
        assert not train_neural_process, 'Have not implemented it yet! Remember to set it to train mode when training'
        self.neural_process = neural_process
        self.neural_process.set_mode('eval')
        self.latent_repr_mode = latent_repr_mode
        self.num_latent_samples = num_latent_samples
        self.env_sampler = env_sampler
        env, env_specs = env_sampler()
        self.training_env, _ = env_sampler(env_specs)
        # self.training_env = training_env or pickle.loads(pickle.dumps(env))
        # self.training_env = training_env or deepcopy(env)
        self.exploration_policy = exploration_policy
        self.num_epochs = num_epochs
        self.num_env_steps_per_epoch = num_steps_per_epoch
        self.num_steps_per_eval = num_steps_per_eval
        self.num_updates_per_train_call = num_updates_per_env_step
        self.batch_size = batch_size
        self.max_path_length = max_path_length
        self.discount = discount
        self.replay_buffer_size = replay_buffer_size
        self.reward_scale = reward_scale
        self.render = render
        self.save_replay_buffer = save_replay_buffer
        self.save_algorithm = save_algorithm
        self.save_environment = save_environment
        self.epoch_to_start_training = epoch_to_start_training

        if self.latent_repr_mode == 'concat_params':

            def get_latent_repr(posterior_state):
                z_mean, z_cov = self.neural_process.get_posterior_params(
                    posterior_state)
                return np.concatenate([z_mean, z_cov])

            self.extra_obs_dim = 2 * self.neural_process.z_dim
        else:

            def get_latent_repr(posterior_state):
                z_mean, z_cov = self.neural_process.get_posterior_params(
                    posterior_state)
                samples = np.random.multivariate_normal(
                    z_mean, np.diag(z_cov), self.num_latent_samples)
                samples = samples.flatten()
                return samples

            self.extra_obs_dim = self.num_latent_samples * self.neural_process.z_dim
        self.get_latent_repr = get_latent_repr

        if eval_sampler is None:
            if eval_policy is None:
                eval_policy = exploration_policy
            eval_sampler = InPlacePathSampler(
                env=env,
                policy=eval_policy,
                max_samples=self.num_steps_per_eval + self.max_path_length,
                max_path_length=self.max_path_length,
                neural_process=neural_process,
                latent_repr_fn=get_latent_repr,
                reward_scale=reward_scale)
        self.eval_policy = eval_policy
        self.eval_sampler = eval_sampler

        self.action_space = env.action_space
        self.obs_space = env.observation_space

        self.env = env
        obs_space_dim = gym_get_dim(self.obs_space)
        act_space_dim = gym_get_dim(self.action_space)
        if replay_buffer is None:
            replay_buffer = SimpleReplayBuffer(
                self.replay_buffer_size,
                obs_space_dim + self.extra_obs_dim,
                act_space_dim,
                discrete_action_dim=isinstance(self.action_space, Discrete))
        self.replay_buffer = replay_buffer

        self._n_env_steps_total = 0
        self._n_train_steps_total = 0
        self._n_rollouts_total = 0
        self._do_train_time = 0
        self._epoch_start_time = None
        self._algo_start_time = None
        self._old_table_keys = None
        self._current_path_builder = PathBuilder()
        self._exploration_paths = []
Example #8
    def __init__(
            self,
            env,
            agent,
            train_tasks,
            eval_tasks,
            meta_batch=64,
            num_iterations=100,
            num_train_steps_per_itr=1000,
            num_initial_steps=100,
            num_tasks_sample=100,
            num_steps_prior=100,
            num_steps_posterior=100,
            num_extra_rl_steps_posterior=100,
            num_evals=10,
            num_steps_per_eval=1000,
            batch_size=1024,
            low_batch_size=2048,  #TODO: Tune this batch size
            embedding_batch_size=1024,
            embedding_mini_batch_size=1024,
            max_path_length=1000,
            discount=0.99,
            replay_buffer_size=1000000,
            reward_scale=1,
            num_exp_traj_eval=1,
            update_post_train=1,
            eval_deterministic=True,
            render=False,
            save_replay_buffer=False,
            save_algorithm=False,
            save_environment=False,
            render_eval_paths=False,
            dump_eval_paths=False,
            plotter=None,
            use_goals=False):
        """
        :param env: training env
        :param agent: agent conditioned on a latent variable z, which this algorithm is responsible for feeding in
        :param train_tasks: list of tasks used for training
        :param eval_tasks: list of tasks used for eval

        see default experiment config file for descriptions of the rest of the arguments
        """
        self.env = env
        self.agent = agent
        self.use_goals = use_goals
        assert (agent.use_goals == self.use_goals)
        self.exploration_agent = agent  # Can potentially use a different policy purely for exploration rather than also solving tasks, currently not being used
        self.train_tasks = train_tasks
        self.eval_tasks = eval_tasks
        self.meta_batch = meta_batch
        self.num_iterations = num_iterations
        self.num_train_steps_per_itr = num_train_steps_per_itr
        self.num_initial_steps = num_initial_steps
        self.num_tasks_sample = num_tasks_sample
        self.num_steps_prior = num_steps_prior
        self.num_steps_posterior = num_steps_posterior
        self.num_extra_rl_steps_posterior = num_extra_rl_steps_posterior
        self.num_evals = num_evals
        self.num_steps_per_eval = num_steps_per_eval
        self.batch_size = batch_size
        self.embedding_batch_size = embedding_batch_size
        self.embedding_mini_batch_size = embedding_mini_batch_size
        self.low_batch_size = low_batch_size
        self.max_path_length = max_path_length
        self.discount = discount
        self.replay_buffer_size = replay_buffer_size
        self.reward_scale = reward_scale
        self.update_post_train = update_post_train
        self.num_exp_traj_eval = num_exp_traj_eval
        self.eval_deterministic = eval_deterministic
        self.render = render
        self.save_replay_buffer = save_replay_buffer
        self.save_algorithm = save_algorithm
        self.save_environment = save_environment

        self.eval_statistics = None
        self.render_eval_paths = render_eval_paths
        self.dump_eval_paths = dump_eval_paths
        self.plotter = plotter

        obs_dim = int(np.prod(env.observation_space.shape))
        action_dim = int(np.prod(env.action_space.shape))

        self.sampler = InPlacePathSampler(
            env=env,
            policy=agent,
            max_path_length=self.max_path_length,
        )

        # separate replay buffers for
        # - training RL update
        # - training encoder update

        self.enc_replay_buffer = MultiTaskReplayBuffer(
            self.replay_buffer_size,
            env,
            self.train_tasks,
        )
        if self.use_goals:
            self.high_buffer = MultiTaskReplayBuffer(self.replay_buffer_size,
                                                     env, self.train_tasks)
            #Hacky method for changing the obs and action dimensions for the internal
            #buffers since they're not the same as the original environment
            internal_buffers = dict([
                (idx,
                 SimpleReplayBuffer(
                     max_replay_buffer_size=self.replay_buffer_size,
                     observation_dim=obs_dim,
                     action_dim=obs_dim,
                 )) for idx in self.train_tasks
            ])
            self.high_buffer.task_buffers = internal_buffers

            self.low_buffer = SimpleReplayBuffer(
                max_replay_buffer_size=replay_buffer_size,
                observation_dim=2 * obs_dim,
                action_dim=action_dim,
            )
        else:
            self.replay_buffer = MultiTaskReplayBuffer(
                self.replay_buffer_size,
                env,
                self.train_tasks,
            )

        self._n_env_steps_total = 0
        self._n_train_steps_total = 0
        self._n_rollouts_total = 0
        self._do_train_time = 0
        self._epoch_start_time = None
        self._algo_start_time = None
        self._old_table_keys = None
        self._current_path_builder = PathBuilder()
        self._exploration_paths = []
Example #9
    def __init__(
            self,
            env,
            agent,
            qf1,
            qf2,
            target_qf1,
            target_qf2,
            total_steps=int(1e6),
            max_path_length=200,
            num_exp_traj_eval=1,
            start_fine_tuning=10,
            fine_tuning_steps=1,
            should_freeze_z=True,
            replay_buffer_size=int(1e6),
            batch_size=256,
            discount=0.99,
            policy_lr=1e-4,
            qf_lr=1e-4,
            temp_lr=1e-4,
            target_entropy=None,
            optimizer_class=torch.optim.Adam,
            soft_target_tau=1e-2
    ):
        self.env = env
        self.agent = agent

        # Critic networks
        self.qf1 = qf1
        self.qf2 = qf2
        self.target_qf1 = target_qf1
        self.target_qf2 = target_qf2
        self.log_alpha = torch.zeros(1, requires_grad=True, device='cuda')
        self.log_alpha.to(device)  # note: Tensor.to is not in-place; log_alpha is already on 'cuda' (assumes a module-level `device`)
        self.target_entropy = target_entropy

        # Experimental setting
        self.total_steps = total_steps
        self.max_path_length = max_path_length
        self.num_exp_traj_eval = num_exp_traj_eval
        self.start_fine_tuning = start_fine_tuning
        self.fine_tuning_steps = fine_tuning_steps
        self.should_freeze_z = should_freeze_z

        # Hyperparams
        self.batch_size = batch_size
        self.discount = discount
        self.soft_target_tau = soft_target_tau

        self.replay_buffer = SimpleReplayBuffer(
            max_replay_buffer_size=replay_buffer_size,
            observation_dim=int(np.prod(env.observation_space.shape)),
            action_dim=int(np.prod(env.action_space.shape)),
        )

        self.q_losses = []
        self.temp_losses = []
        self.policy_losses = []
        self.temp_vals = []

        self.qf_criterion = nn.MSELoss()
        self.vf_criterion = nn.MSELoss()

        self.policy_optimizer = optimizer_class(
            self.agent.policy.parameters(),
            lr=policy_lr,
        )
        self.qf1_optimizer = optimizer_class(
            self.qf1.parameters(),
            lr=qf_lr,
        )
        self.qf2_optimizer = optimizer_class(
            self.qf2.parameters(),
            lr=qf_lr,
        )
        self.temp_optimizer = optimizer_class(
            [self.log_alpha],
            lr=temp_lr,
        )

        self.print_experiment_description()
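For reference, soft_target_tau above parameterizes a Polyak-style target update; the ptu.soft_update_from_to call used in Example #10 below presumably does something along these lines (a sketch of the standard update, not rlkit's actual source):

import torch

def soft_update(source, target, tau):
    # Polyak averaging: target <- tau * source + (1 - tau) * target
    with torch.no_grad():
        for src_p, tgt_p in zip(source.parameters(), target.parameters()):
            tgt_p.data.mul_(1.0 - tau).add_(tau * src_p.data)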
Example #10
class PEARLFineTuningHelper:

    def __init__(
            self,
            env,
            agent,
            qf1,
            qf2,
            target_qf1,
            target_qf2,
            total_steps=int(1e6),
            max_path_length=200,
            num_exp_traj_eval=1,
            start_fine_tuning=10,
            fine_tuning_steps=1,
            should_freeze_z=True,
            replay_buffer_size=int(1e6),
            batch_size=256,
            discount=0.99,
            policy_lr=1e-4,
            qf_lr=1e-4,
            temp_lr=1e-4,
            target_entropy=None,
            optimizer_class=torch.optim.Adam,
            soft_target_tau=1e-2
    ):
        self.env = env
        self.agent = agent

        # Critic networks
        self.qf1 = qf1
        self.qf2 = qf2
        self.target_qf1 = target_qf1
        self.target_qf2 = target_qf2
        self.log_alpha = torch.zeros(1, requires_grad=True, device='cuda')
        self.log_alpha.to(device)  # note: Tensor.to is not in-place; log_alpha is already on 'cuda' (assumes a module-level `device`)
        self.target_entropy = target_entropy

        # Experimental setting
        self.total_steps = total_steps
        self.max_path_length = max_path_length
        self.num_exp_traj_eval = num_exp_traj_eval
        self.start_fine_tuning = start_fine_tuning
        self.fine_tuning_steps = fine_tuning_steps
        self.should_freeze_z = should_freeze_z

        # Hyperparams
        self.batch_size = batch_size
        self.discount = discount
        self.soft_target_tau = soft_target_tau

        self.replay_buffer = SimpleReplayBuffer(
            max_replay_buffer_size=replay_buffer_size,
            observation_dim=int(np.prod(env.observation_space.shape)),
            action_dim=int(np.prod(env.action_space.shape)),
        )

        self.q_losses = []
        self.temp_losses = []
        self.policy_losses = []
        self.temp_vals = []

        self.qf_criterion = nn.MSELoss()
        self.vf_criterion = nn.MSELoss()

        self.policy_optimizer = optimizer_class(
            self.agent.policy.parameters(),
            lr=policy_lr,
        )
        self.qf1_optimizer = optimizer_class(
            self.qf1.parameters(),
            lr=qf_lr,
        )
        self.qf2_optimizer = optimizer_class(
            self.qf2.parameters(),
            lr=qf_lr,
        )
        self.temp_optimizer = optimizer_class(
            [self.log_alpha],
            lr=temp_lr,
        )

        self.print_experiment_description()

    def get_mean(self, losses):
        if not losses:
            return None
        tot = 0
        for tensor in losses:
            tot += np.mean(tensor.to('cpu').detach().numpy())
        return tot / len(losses)

    def collect_samples(self, should_accum_context):
        path = self.rollout(should_accum_context)
        self.replay_buffer.add_path(path)
        steps = path['rewards'].shape[0]
        ret = sum(path['rewards'])[0]
        return ret, steps

    def rollout(self, should_accum_context):
        should_fine_tune = not should_accum_context
        observations = []
        actions = []
        rewards = []
        terminals = []
        agent_infos = []
        env_infos = []
        o = self.env.reset()
        next_o = None
        path_length = 0
        done = False

        while (not done):
            a, agent_info = self.agent.get_action(o)
            next_o, r, d, env_info = self.env.step(a)
            # don't record a timeout on the final step as a true terminal
            # (path_length is incremented below, so the last step is path_length + 1)
            real_done = False if path_length + 1 == self.max_path_length else d
            observations.append(o)
            rewards.append(r)
            terminals.append(real_done)
            actions.append(a)
            agent_infos.append(agent_info)
            path_length += 1
            env_infos.append(env_info)
            if should_accum_context:
                # accumulate context with the pre-step observation before advancing o
                self.agent.update_context([o, a, r, next_o, d, env_info])
            o = next_o
            if should_fine_tune:
                for j in range(self.fine_tuning_steps):
                    self.fine_tuning_step()
            if d or path_length >= self.max_path_length:
                done = True

        actions = np.array(actions)
        if len(actions.shape) == 1:
            actions = np.expand_dims(actions, 1)
        observations = np.array(observations)
        if len(observations.shape) == 1:
            observations = np.expand_dims(observations, 1)
            next_o = np.array([next_o])
        next_observations = np.vstack(
            (
                observations[1:, :],
                np.expand_dims(next_o, 0)
            )
        )

        if should_accum_context:
            self.agent.sample_z()

        return dict(
            observations=observations,
            actions=actions,
            rewards=np.array(rewards).reshape(-1, 1),
            next_observations=next_observations,
            terminals=np.array(terminals).reshape(-1, 1),
            agent_infos=agent_infos,
            env_infos=env_infos,
        )

    def get_samples(self):
        batch = ptu.np_to_pytorch_batch(self.replay_buffer.random_batch(self.batch_size))
        o = batch['observations'][None, ...]
        a = batch['actions'][None, ...]
        r = batch['rewards'][None, ...]
        no = batch['next_observations'][None, ...]
        t = batch['terminals'][None, ...]
        return o, a, r, no, t

    def _min_q(self, obs, actions, task_z):
        q1 = self.qf1(obs, actions, task_z.detach())
        q2 = self.qf2(obs, actions, task_z.detach())
        min_q = torch.min(q1, q2)
        return min_q

    def _update_target_networks(self):
        ptu.soft_update_from_to(self.qf1, self.target_qf1, self.soft_target_tau)
        ptu.soft_update_from_to(self.qf2, self.target_qf2, self.soft_target_tau)

    def fine_tuning_step(self):
        obs, actions, rewards, next_obs, terms = self.get_samples()

        # flattens out the task dimension
        t, b, _ = obs.size()
        obs_flat = obs.view(t * b, -1)
        actions_flat = actions.view(t * b, -1)
        next_obs_flat = next_obs.view(t * b, -1)
        rewards_flat = rewards.view(self.batch_size, -1)
        terms_flat = terms.view(self.batch_size, -1)

        """
        QF Loss
        """
        with torch.no_grad():
            next_policy_outputs, task_z = self.agent(next_obs, self.agent.context)
            next_new_actions, _, _, next_log_prob = next_policy_outputs[:4]
            t_q1_pred = self.target_qf1(next_obs_flat, next_new_actions, task_z.detach())  # TODO: Remove .detach() if redundant
            t_q2_pred = self.target_qf2(next_obs_flat, next_new_actions, task_z.detach())
            t_q_min = torch.min(t_q1_pred, t_q2_pred)
            q_target = rewards_flat + (1. - terms_flat) * self.discount * (t_q_min - self.alpha * next_log_prob)
        q1_pred = self.qf1(obs_flat, actions_flat, task_z.detach())                    # TODO: Remove .detach() if redundant
        q2_pred = self.qf2(obs_flat, actions_flat, task_z.detach())
        qf_loss = torch.mean((q1_pred - q_target.detach()) ** 2) + torch.mean((q2_pred - q_target.detach()) ** 2)

        self.qf1_optimizer.zero_grad()
        self.qf2_optimizer.zero_grad()
        qf_loss.backward()
        self.qf1_optimizer.step()
        self.qf2_optimizer.step()

        """
        Policy and Temp Loss
        """
        for p in self.qf1.parameters():
            p.requires_grad = False
        for p in self.qf2.parameters():
            p.requires_grad = False

        policy_outputs, task_z = self.agent(obs, self.agent.context)
        new_actions, policy_mean, policy_log_std, log_prob = policy_outputs[:4]
        min_q_new_actions = self._min_q(obs_flat, new_actions, task_z)

        policy_loss = (self.alpha * log_prob - min_q_new_actions).mean()
        temp_loss = -self.alpha * (log_prob.detach() + self.target_entropy).mean()

        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()
        self.temp_optimizer.zero_grad()
        temp_loss.backward()
        self.temp_optimizer.step()

        for p in self.qf1.parameters():
            p.requires_grad = True
        for p in self.qf2.parameters():
            p.requires_grad = True

        """
        Update Target Networks
        """
        self._update_target_networks()

        self.q_losses.append(qf_loss.detach())
        self.temp_losses.append(temp_loss.detach())
        self.policy_losses.append(policy_loss.detach())
        self.temp_vals.append(self.alpha.detach())

    def evaluate_agent(self, n_starts=10):
        reward_sum = 0
        for _ in range(n_starts):
            path = rollout(self.env, self.agent, max_path_length=self.max_path_length, accum_context=False)
            reward_sum += sum(path['rewards'])[0]
        return reward_sum / n_starts

    @property
    def alpha(self):
        return self.log_alpha.exp()

    def fine_tune(self, variant, seed):
        random.seed(seed)
        torch.manual_seed(seed)
        np.random.seed(seed)
        self.env.seed(seed)

        cumulative_timestep = 0
        i_episode = 0

        df = pd.DataFrame(
            columns=[
                'step',
                'real_step',
                'train_reward',
                'eval_reward',
                'loss/q-f1',
                'loss/alpha',
                'loss/policy',
                'val/alpha'
            ]
        )

        # For this experiment, we are evaluating in just one sampled task from the meta-test set
        tasks = self.env.get_all_task_idx()
        eval_tasks = list(tasks[-variant['n_eval_tasks']:])
        idx = eval_tasks[0]
        self.env.reset_task(idx)

        self.agent.clear_z()
        while cumulative_timestep < self.total_steps:
            i_episode += 1
            should_infer_posterior = self.num_exp_traj_eval <= i_episode < self.start_fine_tuning
            should_fine_tune = self.start_fine_tuning <= i_episode
            should_accum_context = not should_fine_tune
            if should_fine_tune and self.should_freeze_z and (not self.agent.freeze_z):
                self.agent.freeze_z = True
            train_reward, episode_steps = self.collect_samples(should_accum_context)
            cumulative_timestep += episode_steps
            if should_infer_posterior:
                self.agent.infer_posterior(self.agent.context)
            eval_reward = self.evaluate_agent()
            message = 'Episode {} \t\t Samples {} \t\t Real samples {} \t\t Train reward: {} \t\t Eval reward: {}'
            print(message.format(i_episode, i_episode * self.max_path_length, cumulative_timestep, train_reward,
                                 eval_reward))
            new_df_row = {
                'step': int(i_episode * self.max_path_length),
                'real_step': int(cumulative_timestep),
                'train_reward': train_reward,
                'eval_reward': eval_reward,
                'loss/q-f1': self.get_mean(self.q_losses),
                'loss/alpha': self.get_mean(self.temp_losses),
                'loss/policy': self.get_mean(self.policy_losses),
                'val/alpha': self.get_mean(self.temp_vals)
            }
            self.q_losses = []
            self.temp_losses = []
            self.policy_losses = []
            self.temp_vals = []
            # DataFrame.append was removed in pandas 2.0; pd.concat is the replacement
            df = pd.concat([df, pd.DataFrame([new_df_row])], ignore_index=True)
            results_path = "results_ft/{}/ft{}".format(variant['env_name'], "_{}".format(seed - 1))
            if not os.path.isdir(results_path):
                os.makedirs(results_path)
            df.to_csv("{}/progress.csv".format(results_path))

    def print_experiment_description(self):
        print("\n\n", " -" * 15, "\n")
        print("Total steps:  \t\t\t", self.total_steps)
        print("Max path length:  \t\t\t", self.max_path_length)
        print("Trajectory length with prior:  \t\t\t", self.num_exp_traj_eval)
        print("Start fine tuning after:  \t\t\t", self.start_fine_tuning)
        print("Number of fine-tuning steps:  \t\t\t", self.fine_tuning_steps)
        print("Should freeze Z during fine-tuning?  \t\t\t", self.should_freeze_z)
        print("Batch size:  \t\t\t", self.batch_size)
        print("Gamma:  \t\t\t", self.discount)
        print("Tau:  \t\t\t", self.soft_target_tau)
Example #11
demos_to_convert = [
    'norm_halfcheetah_16_demos_20_subsampling',
    'norm_halfcheetah_4_demos_20_subsampling',
    'norm_ant_32_demos_20_subsampling', 'norm_ant_16_demos_20_subsampling',
    'norm_ant_4_demos_20_subsampling', 'norm_walker_32_demos_20_subsampling',
    'norm_walker_16_demos_20_subsampling',
    'norm_walker_4_demos_20_subsampling',
    'norm_hopper_32_demos_20_subsampling',
    'norm_hopper_16_demos_20_subsampling', 'norm_hopper_4_demos_20_subsampling'
]

for exp_name in demos_to_convert:
    d = joblib.load(
        osp.join(old_listings[exp_name]['exp_dir'],
                 old_listings[exp_name]['seed_runs'][0], 'extra_data.pkl'))
    size = d['train']._size
    obs_dim = d['train']._observation_dim
    act_dim = d['train']._action_dim

    train_rb = SimpleReplayBuffer(size + 1, obs_dim, act_dim)
    copy_over(d['train'], train_rb)
    d['train'] = train_rb

    test_rb = SimpleReplayBuffer(size + 1, obs_dim, act_dim)
    copy_over(d['test'], test_rb)
    d['test'] = test_rb

    joblib.dump(d,
                osp.join('/scratch/hdd001/home/kamyar/output/fmax_demos',
                         exp_name + '.pkl'),
                compress=3)
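copy_over is not defined in this snippet; a minimal sketch of what such a helper could look like, assuming it simply replays every stored transition from the loaded buffer into the freshly sized one through the add_sample interface used throughout this page:

def copy_over(src_buffer, dst_buffer):
    # replay each stored transition from the source buffer into the destination buffer
    for i in range(src_buffer._size):
        dst_buffer.add_sample(src_buffer._observations[i],
                              src_buffer._actions[i],
                              src_buffer._rewards[i],
                              src_buffer._terminals[i],
                              src_buffer._next_obs[i])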
Example #12
def experiment(exp_specs):
    # Set up logging ----------------------------------------------------------
    exp_id = exp_specs['exp_id']
    exp_prefix = exp_specs['exp_name']
    seed = exp_specs['seed']
    set_seed(seed)
    setup_logger(exp_prefix=exp_prefix, exp_id=exp_id, variant=exp_specs)

    # Load the data -----------------------------------------------------------
    extra_data_path = exp_specs['extra_data_path']
    train_replay_buffer = joblib.load(extra_data_path)['replay_buffer']
    train_replay_buffer.change_max_size_to_cur_size()
    train_replay_buffer._next_obs = train_replay_buffer._next_obs[:,exp_specs['extra_obs_dim']:]
    if exp_specs['remove_env_info']:
        train_replay_buffer._observations = train_replay_buffer._observations[:,exp_specs['extra_obs_dim']:]
    else:
        if exp_specs['normalize_env_info']:
            low, high = exp_specs['env_info_range'][0], exp_specs['env_info_range'][1]
            train_replay_buffer._observations[:,:exp_specs['extra_obs_dim']] -= (low + high)/2.0
            train_replay_buffer._observations[:,:exp_specs['extra_obs_dim']] /= (high - low)/2.0

    print('\nRewards: {} +/- {}'.format(
        np.mean(train_replay_buffer._rewards),
        np.std(train_replay_buffer._rewards)
    ))

    next_obs_mean = np.mean(train_replay_buffer._next_obs, 0)
    next_obs_std = np.std(train_replay_buffer._next_obs, 0)
    print('\nNext Obs:\n{}\n+/-\n{}'.format(
        next_obs_mean,
        next_obs_std
    ))

    print('\nAvg Next Obs Square Norm: {}'.format(
        np.mean(np.linalg.norm(train_replay_buffer._next_obs, axis=1)**2)
    ))

    sample_batch = train_replay_buffer.random_batch(exp_specs['train_batch_size'])
    obs_dim = sample_batch['observations'].shape[-1]
    act_dim = sample_batch['actions'].shape[-1]

    val_replay_buffer = SimpleReplayBuffer(exp_specs['val_set_size'], obs_dim, act_dim)
    val_replay_buffer.set_buffer_from_dict(
        train_replay_buffer.sample_and_remove(exp_specs['val_set_size'])
    )
    if exp_specs['train_from_beginning_transitions']:
        trans_dict = dict(
            observations=train_replay_buffer._observations[:exp_specs['train_set_size']],
            actions=train_replay_buffer._actions[:exp_specs['train_set_size']],
            rewards=train_replay_buffer._rewards[:exp_specs['train_set_size']],
            terminals=train_replay_buffer._terminals[:exp_specs['train_set_size']],
            next_observations=train_replay_buffer._next_obs[:exp_specs['train_set_size']],
        )
        train_replay_buffer.set_buffer_from_dict(trans_dict)
    else:
        train_replay_buffer.set_buffer_from_dict(
            train_replay_buffer.sample_and_remove(exp_specs['train_set_size'])
        )

    # Model Definitions -------------------------------------------------------
    if exp_specs['remove_env_info']:
        output_dim = [obs_dim + 1]
    else:
        output_dim = [obs_dim - exp_specs['extra_obs_dim'] + 1]
    model = GenericMap(
        [obs_dim + act_dim],
        output_dim,
        siamese_input=False,
        siamese_output=False,
        num_hidden_layers=exp_specs['num_hidden_layers'],
        hidden_dim=exp_specs['hidden_dim'],
        act='relu',
        use_bn=True,
        deterministic=True
    )

    model_optim = Adam(model.parameters(), lr=float(exp_specs['lr']))

    # Train -------------------------------------------------------------------
    model.train()
    for iter_num in range(exp_specs['max_iters']):
        model_optim.zero_grad()

        batch = train_replay_buffer.random_batch(exp_specs['train_batch_size'])
        batch = convert_numpy_dict_to_pytorch(batch)
        inputs = Variable(torch.cat([batch['observations'], batch['actions']], -1))
        outputs = Variable(torch.cat([batch['next_observations'], batch['rewards']], -1))

        preds = model([inputs])[0]
        if exp_specs['residual']:
            # residual for observations
            preds = preds + Variable(
                        torch.cat(
                            [
                                batch['observations'][:,exp_specs['extra_obs_dim']:],
                                torch.zeros(exp_specs['train_batch_size'], 1)
                            ],
                        1)
                    )
        
        loss = torch.mean(torch.sum((outputs - preds)**2, -1))

        loss.backward()
        model_optim.step()

        if iter_num % exp_specs['freq_val'] == 0:
            model.eval()

            val_batch = val_replay_buffer.random_batch(exp_specs['val_batch_size'])
            val_batch = convert_numpy_dict_to_pytorch(val_batch)
            inputs = Variable(torch.cat([val_batch['observations'], val_batch['actions']], -1))
            outputs = Variable(torch.cat([val_batch['next_observations'], val_batch['rewards']], -1))

            # print(exp_specs['remove_env_info'])
            # print(inputs)
            # print(outputs)
            # sleep(5)
            
            preds = model([inputs])[0]
            if exp_specs['residual']:
                # residual for observations
                preds = preds + Variable(
                            torch.cat(
                                [
                                    val_batch['observations'][:,exp_specs['extra_obs_dim']:],
                                    torch.zeros(exp_specs['val_batch_size'], 1)  # match the val batch size
                                ],
                            1)
                        )

            loss = torch.mean(torch.sum((outputs - preds)**2, -1))
            next_obs_loss = torch.mean(torch.sum((outputs[:,:-1] - preds[:,:-1])**2, -1))
            rew_loss = torch.mean(torch.sum((outputs[:,-1:] - preds[:,-1:])**2, -1))

            print('\n')
            print('-'*20)
            logger.record_tabular('Iter', iter_num)
            # .data[0] fails on 0-dim tensors in recent PyTorch; .item() is the portable accessor
            logger.record_tabular('Loss', loss.item())
            logger.record_tabular('Obs Loss', next_obs_loss.item())
            logger.record_tabular('Rew Loss', rew_loss.item())
            logger.dump_tabular(with_prefix=False, with_timestamp=False)

            model.train()