# Example 3

def experiment(variant, prev_exp_state=None):

    domain = variant['domain']
    seed = variant['seed']
    goal = variant['goal']

    expl_env = env_producer(domain, seed, goal)

    env_max_action = float(expl_env.action_space.high[0])
    obs_dim = expl_env.observation_space.low.size
    action_dim = expl_env.action_space.low.size
    vae_latent_dim = 2 * action_dim
    mlp_enconder_input_size = 2 * obs_dim + action_dim + 1

    print('------------------------------------------------')
    print('obs_dim', obs_dim)
    print('action_dim', action_dim)
    print('------------------------------------------------')

    # Network modules from tiMe

    mlp_enconder = MlpEncoder(hidden_sizes=[200, 200, 200],
                              input_size=mlp_enconder_input_size,
                              output_size=2 * variant['latent_dim'])

    context_encoder = ProbabilisticContextEncoder(mlp_enconder,
                                                  variant['latent_dim'])

    qf1 = FlattenMlp(
        hidden_sizes=variant['Qs_hidden_sizes'],
        input_size=obs_dim + action_dim + variant['latent_dim'],
        output_size=1,
    )
    target_qf1 = FlattenMlp(
        hidden_sizes=variant['Qs_hidden_sizes'],
        input_size=obs_dim + action_dim + variant['latent_dim'],
        output_size=1,
    )
    qf2 = FlattenMlp(
        hidden_sizes=variant['Qs_hidden_sizes'],
        input_size=obs_dim + action_dim + variant['latent_dim'],
        output_size=1,
    )
    target_qf2 = FlattenMlp(
        hidden_sizes=variant['Qs_hidden_sizes'],
        input_size=obs_dim + action_dim + variant['latent_dim'],
        output_size=1,
    )
    vae_decoder = VaeDecoder(
        max_action=env_max_action,
        hidden_sizes=variant['vae_hidden_sizes'],
        input_size=obs_dim + vae_latent_dim + variant['latent_dim'],
        output_size=action_dim,
    )
    perturbation_generator = PerturbationGenerator(
        max_action=env_max_action,
        hidden_sizes=variant['perturbation_hidden_sizes'],
        input_size=obs_dim + action_dim + variant['latent_dim'],
        output_size=action_dim,
    )

    # Load the params obtained by tiMe
    ss = load_gzip_pickle(variant['path_to_snapshot'])
    ss = ss['trainer']

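    # The snapshot stores the MLP encoder's weights under the 'mlp_encoder.'
    # prefix; strip it so the keys match MlpEncoder's own parameter names.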
    encoder_state_dict = OrderedDict()
    for key, value in ss['context_encoder_state_dict'].items():
        if 'mlp_encoder' in key:
            encoder_state_dict[key.replace('mlp_encoder.', '')] = value

    mlp_enconder.load_state_dict(encoder_state_dict)

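    # Both Q-functions and both targets are initialized from the same
    # pretrained tiMe Q-network weights.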
    qf1.load_state_dict(ss['Qs_state_dict'])

    target_qf1.load_state_dict(ss['Qs_state_dict'])

    qf2.load_state_dict(ss['Qs_state_dict'])

    target_qf2.load_state_dict(ss['Qs_state_dict'])

    vae_decoder.load_state_dict(ss['vae_decoder_state_dict'])

    perturbation_generator.load_state_dict(ss['perturbation_generator_dict'])

    tiMe_path_collector = tiMeSampler(
        expl_env,
        context_encoder,
        qf1,
        vae_decoder,
        perturbation_generator,
        vae_latent_dim=vae_latent_dim,
        candidate_size=variant['candidate_size'],
    )
    tiMe_path_collector.to(ptu.device)

    # Get producer function for policy
    policy_producer = get_policy_producer(
        obs_dim, action_dim, hidden_sizes=variant['policy_hidden_sizes'])
    # Finished getting producer

    remote_eval_path_collector = RemoteMdpPathCollector.remote(
        domain, seed * 10 + 1, goal, policy_producer)
    expl_path_collector = MdpPathCollector(expl_env)
    replay_buffer = ReplayBuffer(variant['replay_buffer_size'],
                                 ob_space=expl_env.observation_space,
                                 action_space=expl_env.action_space)
    trainer = SACTrainer(policy_producer,
                         qf1=qf1,
                         target_qf1=target_qf1,
                         qf2=qf2,
                         target_qf2=target_qf2,
                         action_space=expl_env.action_space,
                         **variant['trainer_kwargs'])

    algorithm = BatchRLAlgorithm(
        trainer=trainer,
        exploration_data_collector=expl_path_collector,
        remote_eval_data_collector=remote_eval_path_collector,
        tiMe_data_collector=tiMe_path_collector,
        replay_buffer=replay_buffer,
        optimistic_exp_hp=variant['optimistic_exp'],
        **variant['algorithm_kwargs'])

    algorithm.to(ptu.device)

    start_epoch = (prev_exp_state['epoch'] + 1
                   if prev_exp_state is not None else 0)

    algorithm.train(start_epoch)
class RemotePathCollectorSingleMdp(object):
    def __init__(self, index, variant, candidate_size=10):
        ptu.set_gpu_mode(True)
        torch.set_num_threads(1)

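        # Blank out sys.argv in this worker process; libraries imported here
        # may otherwise try to parse the driver's command-line arguments.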
        import sys
        sys.argv = ['']
        del sys

        env_max_action = variant['env_max_action']
        obs_dim = variant['obs_dim']
        action_dim = variant['action_dim']
        latent_dim = variant['latent_dim']
        vae_latent_dim = 2 * action_dim
        mlp_enconder_input_size = (2 * obs_dim + action_dim + 1
                                   if variant['use_next_obs_in_context']
                                   else obs_dim + action_dim + 1)

        mlp_enconder = MlpEncoder(hidden_sizes=[200, 200, 200],
                                  input_size=mlp_enconder_input_size,
                                  output_size=2 * variant['latent_dim'])
        self.context_encoder = ProbabilisticContextEncoder(
            mlp_enconder, variant['latent_dim'])
        self.Qs = FlattenMlp(
            hidden_sizes=variant['Qs_hidden_sizes'],
            input_size=obs_dim + action_dim + latent_dim,
            output_size=1,
        )
        self.vae_decoder = VaeDecoder(
            max_action=variant['env_max_action'],
            hidden_sizes=variant['vae_hidden_sizes'],
            input_size=obs_dim + vae_latent_dim + latent_dim,
            output_size=action_dim,
        )
        self.perturbation_generator = PerturbationGenerator(
            max_action=env_max_action,
            hidden_sizes=variant['perturbation_hidden_sizes'],
            input_size=obs_dim + action_dim + latent_dim,
            output_size=action_dim,
        )

        self.use_next_obs_in_context = variant['use_next_obs_in_context']

        self.env = env_producer(variant['domain'], variant['seed'])
        self.num_evals = variant['num_evals']
        self.max_path_length = variant['max_path_length']

        self.vae_latent_dim = vae_latent_dim
        self.candidate_size = variant['candidate_size']

        self.env.seed(10 * variant['seed'] + 1234 + index)
        set_seed(10 * variant['seed'] + 1234 + index)

        self.env.action_space.np_random.seed(123 + index)

    def async_evaluate(self, goal):
        self.env.set_goal(goal)

        self.context_encoder.clear_z()

        avg_reward = 0.
        avg_achieved = []
        final_achieved = []

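        # Context transitions accumulate in `raw_context` across episodes and
        # the task posterior is re-inferred after each episode; the first two
        # episodes act as burn-in, so their rewards and final achievements are
        # excluded from the averages below (hence the `i > 1` checks).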
        raw_context = deque()
        for i in range(self.num_evals):
            # Sample an MDP identity (task latent) from the current posterior
            self.context_encoder.sample_z()
            inferred_mdp = self.context_encoder.z

            obs = self.env.reset()
            done = False
            path_length = 0

            while not done and path_length < self.max_path_length:
                action = self.select_actions(np.array(obs), inferred_mdp)
                next_obs, reward, done, env_info = self.env.step(action)
                avg_achieved.append(env_info['achieved'])
                if self.use_next_obs_in_context:
                    new_context = np.concatenate([
                        obs.reshape(1, -1),
                        action.reshape(1, -1),
                        next_obs.reshape(1, -1),
                        np.array(reward).reshape(1, -1)
                    ], axis=1)
                else:
                    # Contexts without next_obs are not supported by this
                    # evaluation sampler.
                    assert False
                    new_context = np.concatenate([
                        obs.reshape(1, -1),
                        action.reshape(1, -1),
                        np.array(reward).reshape(1, -1)
                    ], axis=1)
                raw_context.append(new_context)
                obs = next_obs.copy()
                if i > 1:
                    avg_reward += reward
                path_length += 1

            context = from_numpy(np.concatenate(raw_context, axis=0))[None]
            self.context_encoder.infer_posterior(context)

            if i > 1:
                final_achieved.append(env_info['achieved'])

        avg_reward /= (self.num_evals - 2)
        if np.isscalar(env_info['achieved']):
            avg_achieved = np.mean(avg_achieved)
            final_achieved = np.mean(final_achieved)
        else:
            avg_achieved = np.stack(avg_achieved)
            avg_achieved = np.mean(avg_achieved, axis=0)
            final_achieved = np.stack(final_achieved)
            final_achieved = np.mean(final_achieved, axis=0)
        print(avg_reward)
        return avg_reward, (final_achieved.tolist(), self.env._goal.tolist())

    def async_evaluate_test(self, goal):
        self.env.set_goal(goal)
        self.context_encoder.clear_z()

        avg_reward_list = []
        online_achieved_list = []

        raw_context = deque()
        for _ in range(self.num_evals):
            # Sample an MDP identity (task latent) from the current posterior
            self.context_encoder.sample_z()
            inferred_mdp = self.context_encoder.z

            obs = self.env.reset()
            done = False
            path_length = 0
            avg_reward = 0.
            online_achieved = []
            while not done and path_length < self.max_path_length:
                action = self.select_actions(np.array(obs), inferred_mdp)
                next_obs, reward, done, env_info = self.env.step(action)
                achieved = env_info['achieved']
                online_achieved.append(np.arctan(achieved[1] / achieved[0]))
                if self.use_next_obs_in_context:
                    new_context = np.concatenate([
                        obs.reshape(1, -1),
                        action.reshape(1, -1),
                        next_obs.reshape(1, -1),
                        np.array(reward).reshape(1, -1)
                    ], axis=1)
                else:
                    new_context = np.concatenate([
                        obs.reshape(1, -1),
                        action.reshape(1, -1),
                        np.array(reward).reshape(1, -1)
                    ], axis=1)
                raw_context.append(new_context)
                obs = next_obs.copy()
                avg_reward += reward
                path_length += 1

            avg_reward_list.append(avg_reward)
            online_achieved = np.array(online_achieved)
            online_achieved_list.append([
                online_achieved.mean(),
                online_achieved.std(), self.env._goal
            ])

            context = from_numpy(np.concatenate(raw_context, axis=0))[None]
            self.context_encoder.infer_posterior(context)

        return online_achieved_list

    def set_network_params(self, params_list):
        '''
        The shipped params arrive on the CPU. Set the parameters of the
        sampler's networks from params_list and move the networks to the GPU.
        '''
        context_encoder_params, Qs_params, vae_params, perturbation_params = params_list

        self.context_encoder.mlp_encoder.set_param_values(
            context_encoder_params)
        self.context_encoder.mlp_encoder.to(ptu.device)

        self.Qs.set_param_values(Qs_params)
        self.Qs.to(ptu.device)

        self.vae_decoder.set_param_values(vae_params)
        self.vae_decoder.to(ptu.device)

        self.perturbation_generator.set_param_values(perturbation_params)
        self.perturbation_generator.to(ptu.device)

    def select_actions(self, obs, inferred_mdp):

        # Following BCQ, tile the observation `candidate_size` times so the
        # VAE decoder can propose that many candidate actions in one pass.
        obs = from_numpy(np.tile(obs.reshape(1, -1), (self.candidate_size, 1)))
        with torch.no_grad():
            inferred_mdp = inferred_mdp.repeat(self.candidate_size, 1)
            z = from_numpy(
                np.random.normal(0, 1, size=(obs.size(0), self.vae_latent_dim))
            ).clamp(-0.5, 0.5).to(ptu.device)
            candidate_actions = self.vae_decoder(obs, z, inferred_mdp)
            perturbed_actions = self.perturbation_generator.get_perturbed_actions(
                obs, candidate_actions, inferred_mdp)
            qv = self.Qs(obs, perturbed_actions, inferred_mdp)
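            # qv.max(0)[1] is the index of the candidate with the highest Q-value.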
            ind = qv.max(0)[1]
        return ptu.get_numpy(perturbed_actions[ind])

# Example 5

def experiment(variant,
               bcq_policies,
               bcq_buffers,
               ensemble_params_list,
               prev_exp_state=None):
    # Create the multitask replay buffer based on the buffer list
    train_buffer = MultiTaskReplayBuffer(bcq_buffers_list=bcq_buffers)
    # create multi-task environment and sample tasks
    env = env_producer(variant['domain'], variant['seed'])

    env_max_action = float(env.action_space.high[0])
    obs_dim = int(np.prod(env.observation_space.shape))
    action_dim = int(np.prod(env.action_space.shape))
    vae_latent_dim = 2 * action_dim
    mlp_enconder_input_size = (2 * obs_dim + action_dim + 1
                               if variant['use_next_obs_in_context']
                               else obs_dim + action_dim + 1)

    variant['env_max_action'] = env_max_action
    variant['obs_dim'] = obs_dim
    variant['action_dim'] = action_dim

    variant['mlp_enconder_input_size'] = mlp_enconder_input_size

    # instantiate networks

    mlp_enconder = MlpEncoder(hidden_sizes=[200, 200, 200],
                              input_size=mlp_enconder_input_size,
                              output_size=2 * variant['latent_dim'])
    context_encoder = ProbabilisticContextEncoder(mlp_enconder,
                                                  variant['latent_dim'])

    ensemble_predictor = EnsemblePredictor(ensemble_params_list)

    Qs = FlattenMlp(
        hidden_sizes=variant['Qs_hidden_sizes'],
        input_size=obs_dim + action_dim + variant['latent_dim'],
        output_size=1,
    )
    vae_decoder = VaeDecoder(
        max_action=env_max_action,
        hidden_sizes=variant['vae_hidden_sizes'],
        input_size=obs_dim + vae_latent_dim + variant['latent_dim'],
        output_size=action_dim,
    )
    perturbation_generator = PerturbationGenerator(
        max_action=env_max_action,
        hidden_sizes=variant['perturbation_hidden_sizes'],
        input_size=obs_dim + action_dim + variant['latent_dim'],
        output_size=action_dim,
    )
    trainer = SuperQTrainer(
        ensemble_predictor=ensemble_predictor,
        num_network_ensemble=variant['num_network_ensemble'],
        bcq_policies=bcq_policies,
        std_threshold=variant['std_threshold'],
        is_combine=variant['is_combine'],
        nets=[context_encoder, Qs, vae_decoder, perturbation_generator])

    path_collector = RemotePathCollector(variant)

    algorithm = BatchMetaRLAlgorithm(
        trainer,
        path_collector,
        train_buffer,
        **variant['algo_params'],
    )

    algorithm.to(ptu.device)

    start_epoch = (prev_exp_state['epoch'] + 1
                   if prev_exp_state is not None else 0)

    # Log the variant
    logger.log("Variant:")
    logger.log(json.dumps(dict_to_safe_json(variant), indent=2))

    algorithm.train(start_epoch)
class BCQ(object):
    def __init__(
        self,
        state_dim,
        action_dim,
        max_action,
        vae_latent_dim_multiplicity,
        target_q_coef,
        actor_hid_sizes,
        critic_hid_sizes,
        vae_e_hid_sizes,
        vae_d_hid_sizes,
        encoder_latent_dim,
    ):

        vae_latent_dim = vae_latent_dim_multiplicity * action_dim
        self.actor = Actor(state_dim, action_dim, encoder_latent_dim,
                           actor_hid_sizes, max_action).to(device)
        self.actor_target = Actor(state_dim, action_dim, encoder_latent_dim,
                                  actor_hid_sizes, max_action).to(device)
        self.actor_target.load_state_dict(self.actor.state_dict())
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(),
                                                lr=3e-4)

        self.critic = Critic(state_dim, action_dim, encoder_latent_dim,
                             critic_hid_sizes).to(device)
        self.critic_target = Critic(state_dim, action_dim, encoder_latent_dim,
                                    critic_hid_sizes).to(device)
        self.critic_target.load_state_dict(self.critic.state_dict())
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(),
                                                 lr=3e-4)

        self.vae = VAE(state_dim, action_dim, encoder_latent_dim,
                       vae_latent_dim, vae_e_hid_sizes, vae_d_hid_sizes,
                       max_action).to(device)
        self.vae_optimizer = torch.optim.Adam(self.vae.parameters(), lr=3e-4)

        mlp_enconder_input_size = 2 * state_dim + action_dim + 1

        mlp_enconder = MlpEncoder(hidden_sizes=[200, 200, 200],
                                  input_size=mlp_enconder_input_size,
                                  output_size=2 * encoder_latent_dim)
        self.context_encoder = ProbabilisticContextEncoder(
            mlp_enconder, encoder_latent_dim)
        self.context_encoder_optimizer = torch.optim.Adam(
            self.context_encoder.parameters(), lr=3e-4)

        self.max_action = max_action
        self.action_dim = action_dim
        self.target_q_coef = target_q_coef

        self._need_to_update_eval_statistics = True
        self.eval_statistics = OrderedDict()

    def get_perturbation(self, state, action, inferred_mdp):
        perturbation = self.actor.get_perturbation(state, action, inferred_mdp)
        return perturbation

    def select_action(self, state, inferred_mdp):
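        # BCQ-style action selection: tile the state 10 times, decode 10
        # candidate actions from the VAE, perturb them with the actor, and
        # return the candidate with the highest Q1 value.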
        with torch.no_grad():
            state = torch.FloatTensor(
                state.reshape(1, -1)).repeat(10, 1).to(device)
            inferred_mdp = torch.FloatTensor(
                inferred_mdp.reshape(1, -1)).repeat(10, 1).to(device)
            action = self.actor(
                state, self.vae.decode(state, inferred_mdp=inferred_mdp),
                inferred_mdp)
            q1 = self.critic.q1(state, action, inferred_mdp)
            ind = q1.max(0)[1]
        return action[ind].cpu().data.numpy().flatten()

    def train(self, train_data, discount=0.99, tau=0.005):
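        # One BCQ training step: infer the task latent from the context,
        # update the VAE, then the critics (the context encoder is updated
        # with the gradients accumulated from the VAE and critic losses),
        # then the perturbation actor via DPG, and finally soft-update the
        # target networks.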
        state_np, next_state_np, action, reward, done, context = train_data
        state = torch.FloatTensor(state_np).to(device)
        action = torch.FloatTensor(action).to(device)
        next_state = torch.FloatTensor(next_state_np).to(device)
        reward = torch.FloatTensor(reward).to(device)
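        # Note: `done` is stored as a "not done" mask (1 - done) so it can be
        # multiplied directly into the bootstrapped target below.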
        done = torch.FloatTensor(1 - done).to(device)
        context = torch.FloatTensor(context).to(device)

        gt.stamp('unpack_data', unique=False)

        # Infer the MDP identity from the context

        self.context_encoder_optimizer.zero_grad()

        inferred_mdp = self.context_encoder(context)
        in_mdp_batch_size = state.shape[0] // context.shape[0]
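        # Each row of `context` describes one task; repeat its latent so every
        # transition in that task's sub-batch is paired with the same inferred
        # MDP embedding.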
        inferred_mdp = torch.repeat_interleave(inferred_mdp,
                                               in_mdp_batch_size,
                                               dim=0)

        gt.stamp('infer_mdp_identity', unique=False)

        # Variational Auto-Encoder Training
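        # VAE loss = action reconstruction MSE + 0.5 * KL(N(mean, std) || N(0, I))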
        recon, mean, std = self.vae(state, action, inferred_mdp)
        recon_loss = F.mse_loss(recon, action)
        KL_loss = -0.5 * (1 + torch.log(std.pow(2)) - mean.pow(2) -
                          std.pow(2)).mean()
        vae_loss = recon_loss + 0.5 * KL_loss

        gt.stamp('get_vae_loss', unique=False)

        self.vae_optimizer.zero_grad()
        vae_loss.backward(retain_graph=True)
        self.vae_optimizer.step()

        gt.stamp('update_vae', unique=False)

        # Critic Training
        self.critic_optimizer.zero_grad()

        with torch.no_grad():

            # Duplicate each next state 10 times so 10 candidate actions are
            # scored per transition
            state_rep = next_state.repeat_interleave(10, dim=0)
            inferred_mdp_rep = inferred_mdp.repeat_interleave(10, dim=0)

            target_Q1, target_Q2 = self.critic_target(
                state_rep,
                self.actor_target(
                    state_rep,
                    self.vae.decode(state_rep, inferred_mdp=inferred_mdp_rep),
                    inferred_mdp_rep), inferred_mdp_rep)

            # Soft Clipped Double Q-learning
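            # Convex combination of the two target critics: target_q_coef -> 1
            # recovers the standard clipped double-Q minimum, smaller values
            # soften the pessimism.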
            target_Q = self.target_q_coef * torch.min(target_Q1, target_Q2) + (
                1 - self.target_q_coef) * torch.max(target_Q1, target_Q2)
            target_Q = target_Q.view(state.shape[0], -1).max(1)[0].view(-1, 1)

            target_Q = reward + done * discount * target_Q

        current_Q1, current_Q2 = self.critic(state, action, inferred_mdp)
        critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(
            current_Q2, target_Q)

        gt.stamp('get_critic_loss', unique=False)

        self.critic_optimizer.zero_grad()
        critic_loss.backward(retain_graph=True)
        self.critic_optimizer.step()

        gt.stamp('update_critic', unique=False)

        self.context_encoder_optimizer.step()

        # Perturbation Model / Action Training
        sampled_actions = self.vae.decode(state,
                                          inferred_mdp=inferred_mdp.detach())
        perturbed_actions = self.actor(state, sampled_actions,
                                       inferred_mdp.detach())

        # Update through DPG
        actor_loss = -self.critic.q1(state, perturbed_actions,
                                     inferred_mdp.detach()).mean()

        gt.stamp('get_actor_loss', unique=False)

        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        gt.stamp('update_actor', unique=False)

        # Update Target Networks
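        # Polyak averaging: target <- tau * online + (1 - tau) * target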
        for param, target_param in zip(self.critic.parameters(),
                                       self.critic_target.parameters()):
            target_param.data.copy_(tau * param.data +
                                    (1 - tau) * target_param.data)

        for param, target_param in zip(self.actor.parameters(),
                                       self.actor_target.parameters()):
            target_param.data.copy_(tau * param.data +
                                    (1 - tau) * target_param.data)
        """
        Save some statistics for eval   
        """
        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False
            """
            Eval should set this to None.
            This way, these statistics are only computed for one batch.
            """
            self.eval_statistics['actor_loss'] = np.mean(get_numpy(actor_loss))
            self.eval_statistics['critic_loss'] = np.mean(
                get_numpy(critic_loss))
            self.eval_statistics['vae_loss'] = np.mean(get_numpy(vae_loss))

    def end_epoch(self, epoch):
        self._need_to_update_eval_statistics = True

    @property
    def networks(self):
        return [
            self.actor, self.critic, self.vae, self.context_encoder.mlp_encoder
        ]

    def get_diagnostics(self):
        return self.eval_statistics

    def get_snapshot(self):
        return dict(
            actor_dict=self.actor.state_dict(),
            critic_dict=self.critic.state_dict(),
            vae_dict=self.vae.state_dict(),
            context_encoder_dict=self.context_encoder.state_dict(),
            eval_statistics=self.eval_statistics,
            _need_to_update_eval_statistics=self._need_to_update_eval_statistics,
        )

    def restore_from_snapshot(self, ss):
        ss = ss['trainer']

        self.actor.load_state_dict(ss['actor_dict'])
        self.actor.to(device)

        self.critic.load_state_dict(ss['critic_dict'])
        self.critic.to(device)

        self.vae.load_state_dict(ss['vae_dict'])
        self.vae.to(device)

        self.context_encoder.load_state_dict(ss['context_encoder_dict'])
        self.context_encoder.to(device)

        self.eval_statistics = ss['eval_statistics']
        self._need_to_update_eval_statistics = ss[
            '_need_to_update_eval_statistics']