Example #1
    def __init__(self,
                 memory_length,
                 input_dim=1,
                 output_dim=1,
                 hidden_sizes=[32],
                 kernel_initializer='glorot_uniform',
                 bias_initializer='zeros',
                 hidden_activation=tf.keras.activations.relu,
                 output_activation=tf.keras.activations.linear,
                 logger_kwargs=None,
                 logger_file_name='learning_progress_log.txt'):
        self.input_dim = input_dim
        self.output_dim = output_dim

        self.memory_length = memory_length
        self.memory_track_models = deque(maxlen=self.memory_length)
        self.memory_track_outputs = deque(maxlen=self.memory_length)
        # Define model holders
        self.input_ph = tf.placeholder(dtype=tf.float32,
                                       shape=(None, self.input_dim))
        for m_i in range(self.memory_length):
            self.memory_track_models.append(
                MLP(hidden_sizes + [output_dim],
                    hidden_activation=hidden_activation,
                    output_activation=output_activation))
            self.memory_track_outputs.append(self.memory_track_models[m_i](
                self.input_ph))
        # Define logger (treat logger_kwargs=None as an empty dict)
        self.lp_logger = Logger(output_fname=logger_file_name,
                                **(logger_kwargs or {}))
Example #2
    def __init__(self, act_dim, obs_dim, n_post_action,
                 obs_set_size, track_obs_set_unc_frequency,
                 x_ph, a_ph, ac_kwargs, dropout_rate,
                 logger_kwargs,
                 tf_var_scope_main='main', tf_var_scope_target='target',
                 tf_var_scope_rnd='random_net_distill'):

        self.act_dim = act_dim
        self.obs_dim = obs_dim
        self.n_post_action = n_post_action

        self.obs_set_size = obs_set_size
        self.obs_set_is_empty = True
        self.track_obs_set_unc_frequency = track_obs_set_unc_frequency

        self.tf_var_scope_main = tf_var_scope_main
        self.tf_var_scope_target = tf_var_scope_target
        self.tf_var_scope_rnd = tf_var_scope_rnd
        self.tf_var_scope_main_unc = 'main_uncertainty'
        self.tf_var_scope_target_unc = 'target_uncertainty'
        self.tf_var_scope_rnd_unc = 'rnd_uncertainty'

        # Create Actor-critic and RND to load weights for post sampling
        with tf.variable_scope(self.tf_var_scope_main_unc):
            self.x_ph = x_ph
            self.a_ph = a_ph
            # Actor-critic
            self.pi, _, self.pi_dropout_mask_generator, self.pi_dropout_mask_phs, \
            self.q1, _, self.q1_dropout_mask_generator, self.q1_dropout_mask_phs, self.q1_pi, _, \
            self.q2, _, self.q2_dropout_mask_generator, self.q2_dropout_mask_phs = mlp_actor_critic(x_ph, a_ph, **ac_kwargs,
                                                                                                    dropout_rate=dropout_rate)
        with tf.variable_scope(self.tf_var_scope_rnd_unc):
            # import pdb; pdb.set_trace()
            # Random Network Distillation
            self.rnd_targ_act, \
            self.rnd_pred_act, _, \
            self.rnd_pred_act_dropout_mask_generator, self.rnd_pred_act_dropout_mask_phs, \
            self.rnd_targ_cri, \
            self.rnd_pred_cri, _, \
            self.rnd_pred_cri_dropout_mask_generator, self.rnd_pred_cri_dropout_mask_phs = random_net_distill(x_ph, a_ph,
                                                                                                              **ac_kwargs,
                                                                                                              dropout_rate=dropout_rate)
        self.dropout_masks_set_pi = self.pi_dropout_mask_generator.generate_dropout_mask(n_post_action)
        self.dropout_masks_set_q1 = self.q1_dropout_mask_generator.generate_dropout_mask(n_post_action)
        self.dropout_masks_set_q2 = self.q2_dropout_mask_generator.generate_dropout_mask(n_post_action)
        self.dropout_masks_set_rnd_act = self.rnd_pred_act_dropout_mask_generator.generate_dropout_mask(n_post_action)
        self.dropout_masks_set_rnd_cri = self.rnd_pred_cri_dropout_mask_generator.generate_dropout_mask(n_post_action)

        self.uncertainty_logger = Logger(output_fname='dropout_uncertainty.txt',
                                         **logger_kwargs)
        self.sample_logger = Logger(output_fname='dropout_sample_observation.txt',
                                    **logger_kwargs)

        self.delayed_dropout_masks_update = False
        self.delayed_dropout_masks_update_freq = 1000
Example #3
    def __init__(self, obs_dim, act_dim, size,
                 logger_fname='experiences_log.txt', **logger_kwargs):
        # ExperienceLogger: save experiences for supervised learning
        logger_kwargs['output_fname'] = logger_fname
        self.experience_logger = Logger(**logger_kwargs)
        self.obs1_buf = np.zeros([size, obs_dim], dtype=np.float32)
        self.obs2_buf = np.zeros([size, obs_dim], dtype=np.float32)
        self.acts_buf = np.zeros([size, act_dim], dtype=np.float32)
        self.rews_buf = np.zeros(size, dtype=np.float32)
        self.done_buf = np.zeros(size, dtype=np.float32)
        self.ptr, self.size, self.max_size = 0, 0, size
Example #4
def play_game(env,
              torch_load_kwargs={},
              actor_critic=CNNCritic,
              episodes=10,
              render=False,
              logger_kwargs={}):

    logger = Logger(**logger_kwargs)
    logger.save_config(locals())

    ac = actor_critic(env.observation_space, env.action_space)

    # model saved on GPU, load on CPU: https://pytorch.org/tutorials/beginner/saving_loading_models.html#saving-loading-model-across-devices
    ac_saved = torch.load(**torch_load_kwargs)
    ac_saved = ac_saved.to(device)
    ac.q.load_state_dict(ac_saved.q.module.state_dict())
    ac.q.to(device)

    avg_ret = 0
    avg_raw_ret = 0
    game = 0

    for ep in range(episodes):
        o, ep_ret, ep_len, d, raw_ret = env.reset(), 0, 0, False, 0
        while not d:
            if render:
                env.render()
            o = torch.as_tensor(o, dtype=torch.float32, device=device)
            o2, r, d, info = env.step(ac.act(o))
            ep_ret += r
            ep_len += 1
            o = o2

        print(f'Returns for episode {ep}: {ep_ret}')
        avg_ret += (1. / (ep + 1)) * (ep_ret - avg_ret)

        lives = info.get('ale.lives')
        if lives is not None and lives == 0:
            raw_rew = env.get_episode_rewards()[-1]
            raw_len = env.get_episode_lengths()[-1]
            logger.log_tabular('RawRet', raw_rew)
            logger.log_tabular('RawLen', raw_len)
            logger.log_tabular('GameId', game)
            wandb.log(logger.log_current_row)
            logger.dump_tabular()
            game += 1

    print('Average raw returns:', np.mean(env.get_episode_rewards()))
    print(f'Avg returns={avg_ret} over {episodes} episodes')
    env.close()
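A minimal usage sketch for play_game (not part of the original listing). It assumes an Atari environment wrapped in gym.wrappers.Monitor, which supplies get_episode_rewards() and get_episode_lengths(), a checkpoint saved earlier with torch.save, and a prior wandb.init() call; the paths below are hypothetical.

# Hedged usage sketch; paths and environment id are placeholders.
import gym
import torch

env = gym.wrappers.Monitor(gym.make('PongNoFrameskip-v4'),
                           directory='/tmp/play_game_monitor', force=True)
play_game(env,
          torch_load_kwargs=dict(f='checkpoints/ac.pt',          # hypothetical checkpoint path
                                 map_location=torch.device('cpu')),
          episodes=5,
          render=False,
          logger_kwargs=dict(output_dir='/tmp/play_game_logs'))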
Example #5
    def __init__(self,
                 act_dim,
                 obs_dim,
                 n_post_action,
                 obs_set_size,
                 track_obs_set_unc_frequency,
                 pi,
                 x_ph,
                 a_ph,
                 pi_dropout_mask_phs,
                 pi_dropout_mask_generator,
                 rnd_targ_act,
                 rnd_pred_act,
                 rnd_targ_cri,
                 rnd_pred_cri,
                 logger_kwargs,
                 tf_var_scope_main='main',
                 tf_var_scope_target='target',
                 tf_var_scope_unc='uncertainty',
                 uncertainty_type='dropout'):
        self.act_dim = act_dim
        self.obs_dim = obs_dim
        self.n_post_action = n_post_action
        # policy
        self.pi = pi
        self.x_ph = x_ph
        self.a_ph = a_ph
        # dropout
        self.pi_dropout_mask_phs = pi_dropout_mask_phs
        self.pi_dropout_mask_generator = pi_dropout_mask_generator
        # rnd
        self.rnd_targ_act = rnd_targ_act
        self.rnd_pred_act = rnd_pred_act
        self.rnd_targ_cri = rnd_targ_cri
        self.rnd_pred_cri = rnd_pred_cri

        self.obs_set_size = obs_set_size
        self.obs_set_is_empty = True
        self.track_obs_set_unc_frequency = track_obs_set_unc_frequency

        self.tf_var_scope_main = tf_var_scope_main
        self.tf_var_scope_target = tf_var_scope_target
        self.tf_var_scope_unc = tf_var_scope_unc

        self.uncertainty_logger = Logger(
            output_fname='{}_uncertainty.txt'.format(uncertainty_type),
            **logger_kwargs)
        self.sample_logger = Logger(
            output_fname='{}_sample_observation.txt'.format(uncertainty_type),
            **logger_kwargs)
Example #6
class ReplayBuffer:
    """
    A simple FIFO experience replay buffer for DDPG agents.
    """
    def __init__(self,
                 obs_dim,
                 act_dim,
                 size,
                 logger_fname='experiences_log.txt',
                 **logger_kwargs):
        # ExperienceLogger: save experiences for supervised learning
        logger_kwargs['output_fname'] = logger_fname
        self.experience_logger = Logger(**logger_kwargs)

        self.obs1_buf = np.zeros([size, obs_dim], dtype=np.float32)
        self.obs2_buf = np.zeros([size, obs_dim], dtype=np.float32)
        self.acts_buf = np.zeros([size, act_dim], dtype=np.float32)
        self.acts_mu_buf = np.zeros([size, act_dim], dtype=np.float32)
        self.acts_alpha_buf = np.zeros([size, act_dim], dtype=np.float32)
        self.acts_beta_buf = np.zeros(
            [size, int(act_dim * (act_dim - 1) / 2)], dtype=np.float32)
        self.rews_buf = np.zeros(size, dtype=np.float32)
        self.done_buf = np.zeros(size, dtype=np.float32)
        self.ptr, self.size, self.max_size = 0, 0, size

    def store(self, obs, act, act_mu, act_alpha, act_beta, rew, next_obs, done,
              step_index, steps_per_epoch, start_time):
        # Save experiences to disk
        self.log_experiences(obs, act, act_mu, act_alpha, act_beta, rew,
                             next_obs, done, step_index, steps_per_epoch,
                             start_time)
        # Save experiences in memory
        self.obs1_buf[self.ptr] = obs
        self.obs2_buf[self.ptr] = next_obs
        self.acts_buf[self.ptr] = act
        self.acts_mu_buf[self.ptr] = act_mu
        self.acts_alpha_buf[self.ptr] = act_alpha
        self.acts_beta_buf[self.ptr] = act_beta
        self.rews_buf[self.ptr] = rew
        self.done_buf[self.ptr] = done
        self.ptr = (self.ptr + 1) % self.max_size
        self.size = min(self.size + 1, self.max_size)

    def sample_batch(self, batch_size=32):
        idxs = np.random.randint(0, self.size, size=batch_size)
        return dict(obs1=self.obs1_buf[idxs],
                    obs2=self.obs2_buf[idxs],
                    acts=self.acts_buf[idxs],
                    acts_mu=self.acts_mu_buf[idxs],
                    acts_alpha=self.acts_alpha_buf[idxs],
                    acts_beta=self.acts_beta_buf[idxs],
                    rews=self.rews_buf[idxs],
                    done=self.done_buf[idxs])

    def log_experiences(self, obs, act, act_mu, act_alpha, act_beta, rew,
                        next_obs, done, step_index, steps_per_epoch,
                        start_time):
        self.experience_logger.log_tabular('Epoch',
                                           step_index // steps_per_epoch)
        self.experience_logger.log_tabular('Step', step_index)
        # Log observation
        for i, o_i in enumerate(obs):
            self.experience_logger.log_tabular('o_{}'.format(i), o_i)
        # Log action
        for i, a_i in enumerate(act):
            self.experience_logger.log_tabular('a_{}'.format(i), a_i)
        for i, a_i in enumerate(act_mu):
            self.experience_logger.log_tabular('a_mu_{}'.format(i), a_i)
        for i, a_i in enumerate(act_alpha):
            self.experience_logger.log_tabular('a_alpha_{}'.format(i), a_i)
        for i, a_i in enumerate(act_beta):
            self.experience_logger.log_tabular('a_beta_{}'.format(i), a_i)

        # Log reward
        self.experience_logger.log_tabular('r', rew)
        # Log next observation
        for i, o2_i in enumerate(next_obs):
            self.experience_logger.log_tabular('o2_{}'.format(i), o2_i)
        # Log done
        self.experience_logger.log_tabular('d', done)
        self.experience_logger.log_tabular('Time', time.time() - start_time)
        self.experience_logger.dump_tabular(print_data=False)
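A minimal usage sketch for this buffer (not from the original source); dimensions and values are invented. Extra keyword arguments are forwarded to the Logger, so output_dir below is assumed to be a valid Logger argument.

# Hedged usage sketch with made-up dimensions and values.
import time
import numpy as np

obs_dim, act_dim = 3, 2
buf = ReplayBuffer(obs_dim, act_dim, size=10000, output_dir='/tmp/ddpg_buffer_logs')

start = time.time()
o = np.zeros(obs_dim, dtype=np.float32)
a = np.zeros(act_dim, dtype=np.float32)
a_mu = np.zeros(act_dim, dtype=np.float32)
a_alpha = np.ones(act_dim, dtype=np.float32)
a_beta = np.ones(act_dim * (act_dim - 1) // 2, dtype=np.float32)

# Store one transition (logged to disk and kept in memory), then sample it back.
buf.store(o, a, a_mu, a_alpha, a_beta, rew=1.0, next_obs=o, done=False,
          step_index=0, steps_per_epoch=5000, start_time=start)
batch = buf.sample_batch(batch_size=1)   # dict with obs1, obs2, acts, acts_mu, ...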
Example #7
class ReplayBuffer:
    """
    A simple FIFO experience replay buffer for TD3 agents.
    """

    def __init__(self, obs_dim, act_dim, size,
                 logger_fname='experiences_log.txt', **logger_kwargs):
        # ExperienceLogger: save experiences for supervised learning
        logger_kwargs['output_fname'] = logger_fname
        self.experience_logger = Logger(**logger_kwargs)
        self.obs1_buf = np.zeros([size, obs_dim], dtype=np.float32)
        self.obs2_buf = np.zeros([size, obs_dim], dtype=np.float32)
        self.acts_buf = np.zeros([size, act_dim], dtype=np.float32)
        self.rews_buf = np.zeros(size, dtype=np.float32)
        self.done_buf = np.zeros(size, dtype=np.float32)
        self.ptr, self.size, self.max_size = 0, 0, size

    def store(self, obs, act, rew, next_obs, done,
              uncertainty,
              q1_pred, q2_pred, q1_post, q2_post,
              rnd_e_act, rnd_e_cri,
              step_index, steps_per_epoch, start_time):
        # Save experiences to disk
        self.log_experiences(obs, act, rew, next_obs, done,
                             uncertainty,
                             q1_pred, q2_pred, q1_post, q2_post,
                             rnd_e_act, rnd_e_cri,
                             step_index, steps_per_epoch, start_time)
        # Save experiences in memory
        self.obs1_buf[self.ptr] = obs
        self.obs2_buf[self.ptr] = next_obs
        self.acts_buf[self.ptr] = act
        self.rews_buf[self.ptr] = rew
        self.done_buf[self.ptr] = done
        self.ptr = (self.ptr+1) % self.max_size
        self.size = min(self.size+1, self.max_size)

    def sample_batch(self, batch_size=32):
        idxs = np.random.randint(0, self.size, size=batch_size)
        return dict(obs1=self.obs1_buf[idxs],
                    obs2=self.obs2_buf[idxs],
                    acts=self.acts_buf[idxs],
                    rews=self.rews_buf[idxs],
                    done=self.done_buf[idxs])

    def log_experiences(self, obs, act, rew, next_obs, done,
                        uncertainty,
                        q1_pred, q2_pred, q1_post, q2_post,
                        rnd_e_act, rnd_e_cri,
                        step_index, steps_per_epoch, start_time):
        self.experience_logger.log_tabular('Epoch', step_index // steps_per_epoch)
        self.experience_logger.log_tabular('Step', step_index)
        # Log observation
        for i, o_i in enumerate(obs):
            self.experience_logger.log_tabular('o_{}'.format(i), o_i)
        # Log action
        for i, a_i in enumerate(act):
            self.experience_logger.log_tabular('a_{}'.format(i), a_i)
        # Log reward
        self.experience_logger.log_tabular('r', rew)
        # Log next observation
        for i, o2_i in enumerate(next_obs):
            self.experience_logger.log_tabular('o2_{}'.format(i), o2_i)
        # Log uncertainty: flatten in row-major order
        for i, unc_i in enumerate(np.array(uncertainty).flatten(order='C')):
            self.experience_logger.log_tabular('unc_{}'.format(i), unc_i)
        # Log q1_pred, q2_pred
        self.experience_logger.log_tabular('q1_pred', q1_pred)
        self.experience_logger.log_tabular('q2_pred', q2_pred)
        # Log q1_post, q2_post
        for i in range(len(q1_post)):
            self.experience_logger.log_tabular('q1_post_{}'.format(i), q1_post[i])
            self.experience_logger.log_tabular('q2_post_{}'.format(i), q2_post[i])
        # Log RND actor prediction error
        self.experience_logger.log_tabular('rnd_e_act', rnd_e_act)
        # Log RND critic prediction error
        self.experience_logger.log_tabular('rnd_e_cri', rnd_e_cri)
        # Log done
        self.experience_logger.log_tabular('d', done)
        self.experience_logger.log_tabular('Time', time.time() - start_time)
        self.experience_logger.dump_tabular(print_data=False)
Example #8
class ReplayBuffer:
    """
    A simple FIFO experience replay buffer for TD3 agents.
    """
    def __init__(self,
                 obs_dim,
                 act_dim,
                 size,
                 logger_fname='experiences_log.txt',
                 **logger_kwargs):
        # ExperienceLogger: save experiences for supervised learning
        logger_kwargs['output_fname'] = logger_fname
        self.experience_logger = Logger(**logger_kwargs)
        self.obs1_buf = np.zeros([size, obs_dim], dtype=np.float32)
        self.obs2_buf = np.zeros([size, obs_dim], dtype=np.float32)
        self.acts_buf = np.zeros([size, act_dim], dtype=np.float32)
        self.rews_buf = np.zeros(size, dtype=np.float32)
        self.done_buf = np.zeros(size, dtype=np.float32)
        self.ptr, self.size, self.max_size = 0, 0, size

    def store(self, obs, act, rew, next_obs, done, step_index, epoch_index,
              time, **kwargs):
        # Save experiences to disk
        self.log_experiences(obs, act, rew, next_obs, done, step_index,
                             epoch_index, time, **kwargs)
        self.obs1_buf[self.ptr] = obs
        self.obs2_buf[self.ptr] = next_obs
        self.acts_buf[self.ptr] = act
        self.rews_buf[self.ptr] = rew
        self.done_buf[self.ptr] = done
        self.ptr = (self.ptr + 1) % self.max_size
        self.size = min(self.size + 1, self.max_size)

    def sample_batch(self, batch_size=32):
        idxs = np.random.randint(0, self.size, size=batch_size)
        return dict(obs1=self.obs1_buf[idxs],
                    obs2=self.obs2_buf[idxs],
                    acts=self.acts_buf[idxs],
                    rews=self.rews_buf[idxs],
                    done=self.done_buf[idxs])

    def log_experiences(self, obs, act, rew, next_obs, done, step_index,
                        epoch_index, time, **kwargs):
        self.experience_logger.log_tabular('Epoch', epoch_index)
        self.experience_logger.log_tabular('Step', step_index)
        # Log observation
        for i, o_i in enumerate(obs):
            self.experience_logger.log_tabular('o_{}'.format(i), o_i)
        # Log action
        for i, a_i in enumerate(act):
            self.experience_logger.log_tabular('a_{}'.format(i), a_i)
        # Log reward
        self.experience_logger.log_tabular('r', rew)
        # Log next observation
        for i, o2_i in enumerate(next_obs):
            self.experience_logger.log_tabular('o2_{}'.format(i), o2_i)
        # Log other data
        for key, value in kwargs.items():
            for i, v in enumerate(np.array(value).flatten(order='C')):
                self.experience_logger.log_tabular('{}_{}'.format(key, i), v)
        # Log done
        self.experience_logger.log_tabular('d', done)
        self.experience_logger.log_tabular('Time', time)
        self.experience_logger.dump_tabular(print_data=False)
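A hedged usage sketch (values invented): any extra keyword passed to store() is flattened in row-major order and written as <key>_0, <key>_1, ... columns in the experience log, which is how posterior Q samples or covariance matrices could be recorded alongside the transition.

# Hedged usage sketch; output_dir is assumed to be a valid Logger argument.
import time
import numpy as np

buf = ReplayBuffer(obs_dim=3, act_dim=2, size=1000,
                   output_dir='/tmp/td3_buffer_logs')

o = np.zeros(3, dtype=np.float32)
a = np.zeros(2, dtype=np.float32)
buf.store(o, a, rew=0.0, next_obs=o, done=False,
          step_index=10, epoch_index=0, time=time.time(),
          q_post=np.array([0.1, 0.2, 0.3]),   # logged as q_post_0 .. q_post_2
          unc=np.eye(2))                      # logged as unc_0 .. unc_3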
Example #9
class LearningProgress(object):
    """
    LearningProgress keeps a time-ordered window of past versions of a learned policy or value function.
    Predictions from these versions are compared to estimate learning progress.
    """
    def __init__(self,
                 memory_length,
                 input_dim=1,
                 output_dim=1,
                 hidden_sizes=[32],
                 kernel_initializer='glorot_uniform',
                 bias_initializer='zeros',
                 hidden_activation=tf.keras.activations.relu,
                 output_activation=tf.keras.activations.linear,
                 logger_kwargs=None,
                 logger_file_name='learning_progress_log.txt'):
        self.input_dim = input_dim
        self.output_dim = output_dim

        self.memory_length = memory_length
        self.memory_track_models = deque(maxlen=self.memory_length)
        self.memory_track_outputs = deque(maxlen=self.memory_length)
        # Define model holders
        self.input_ph = tf.placeholder(dtype=tf.float32,
                                       shape=(None, self.input_dim))
        for m_i in range(self.memory_length):
            self.memory_track_models.append(
                MLP(hidden_sizes + [output_dim],
                    hidden_activation=hidden_activation,
                    output_activation=output_activation))
            self.memory_track_outputs.append(self.memory_track_models[m_i](
                self.input_ph))
        # Define logger (treat logger_kwargs=None as an empty dict)
        self.lp_logger = Logger(output_fname=logger_file_name,
                                **(logger_kwargs or {}))

    def compute_outputs(self, input, sess):
        outputs = np.zeros((self.memory_length, self.output_dim))
        for o_i in range(self.memory_length):
            outputs[o_i, :] = sess.run(
                self.memory_track_outputs[o_i],
                feed_dict={self.input_ph: input.reshape(1, -1)})
        return outputs

    def compute_learning_progress(self, input, sess, t=0, start_time=0):
        outputs = self.compute_outputs(input, sess)
        # First half part of memory track window
        first_half_outputs = outputs[0:int(np.floor(self.memory_length /
                                                    2)), :]
        # Second half part of memory track window
        second_half_outputs = outputs[int(np.floor(self.memory_length /
                                                   2)):, :]
        # L2 Norm i.e. Euclidean Distance
        lp_norm = np.linalg.norm(np.mean(second_half_outputs, axis=0) -
                                 np.mean(first_half_outputs, axis=0),
                                 ord=2)

        # Measure sum of variance
        var_outputs = np.sum(np.var(outputs, axis=0))
        var_first_half_outputs = np.sum(np.var(first_half_outputs, axis=0))
        var_second_half_outputs = np.sum(np.var(second_half_outputs, axis=0))
        self.lp_logger.log_tabular('VarAll', var_outputs)
        self.lp_logger.log_tabular('VarFirstHalf', var_first_half_outputs)
        self.lp_logger.log_tabular('VarSecondHalf', var_second_half_outputs)
        self.lp_logger.log_tabular(
            'VarChange', var_second_half_outputs - var_first_half_outputs)
        # Log
        self.lp_logger.log_tabular('Step', t)
        self.lp_logger.log_tabular('L2LP', lp_norm)
        self.lp_logger.log_tabular('Time', time.time() - start_time)
        self.lp_logger.dump_tabular(print_data=False)
        return lp_norm, var_outputs, var_first_half_outputs, var_second_half_outputs

    def update_latest_memory(self, weights):
        """Update oldest model to latest weights, then append the latest model and output_placeholder
        to the top of queue."""
        # Set oldest model to latest weights
        oldest_model = self.memory_track_models.popleft()
        oldest_model.set_weights(weights)
        self.memory_track_models.append(oldest_model)
        # Move the corresponding output_placeholder to the top of queue
        self.memory_track_outputs.append(self.memory_track_outputs.popleft())
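A hedged usage sketch (TensorFlow 1.x style, matching the placeholders above). The update cadence and inputs are invented; in a real run the weights pushed into the window would come from the network being trained, assuming it shares the tracked MLPs' architecture.

# Hedged usage sketch; the weights below are a stand-in for the trained network's.
import time
import numpy as np
import tensorflow as tf

lp = LearningProgress(memory_length=10, input_dim=4, output_dim=2,
                      logger_kwargs=dict(output_dir='/tmp/lp_logs'))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    start_time = time.time()
    for t in range(100):
        # Stand-in for the trained network's current weights.
        latest_weights = lp.memory_track_models[-1].get_weights()
        lp.update_latest_memory(latest_weights)
        x = np.random.randn(4).astype(np.float32)
        lp_norm, var_all, var_first, var_second = lp.compute_learning_progress(
            x, sess, t=t, start_time=start_time)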
class DropoutUncertaintyModule:
    """This class is to provide functions to investigate dropout-based uncertainty change trajectories."""
    def __init__(self, act_dim, obs_dim, n_post_action,
                 obs_set_size, track_obs_set_unc_frequency,
                 x_ph, a_ph, ac_kwargs, dropout_rate,
                 logger_kwargs,
                 tf_var_scope_main='main', tf_var_scope_target='target',
                 tf_var_scope_rnd='random_net_distill'):

        self.act_dim = act_dim
        self.obs_dim = obs_dim
        self.n_post_action = n_post_action

        self.obs_set_size = obs_set_size
        self.obs_set_is_empty = True
        self.track_obs_set_unc_frequency = track_obs_set_unc_frequency

        self.tf_var_scope_main = tf_var_scope_main
        self.tf_var_scope_target = tf_var_scope_target
        self.tf_var_scope_rnd = tf_var_scope_rnd
        self.tf_var_scope_main_unc = 'main_uncertainty'
        self.tf_var_scope_target_unc = 'target_uncertainty'
        self.tf_var_scope_rnd_unc = 'rnd_uncertainty'

        # Create Actor-critic and RND to load weights for post sampling
        with tf.variable_scope(self.tf_var_scope_main_unc):
            self.x_ph = x_ph
            self.a_ph = a_ph
            # Actor-critic
            self.pi, _, self.pi_dropout_mask_generator, self.pi_dropout_mask_phs, \
            self.q1, _, self.q1_dropout_mask_generator, self.q1_dropout_mask_phs, self.q1_pi, _, \
            self.q2, _, self.q2_dropout_mask_generator, self.q2_dropout_mask_phs = mlp_actor_critic(x_ph, a_ph, **ac_kwargs,
                                                                                                    dropout_rate=dropout_rate)
        with tf.variable_scope(self.tf_var_scope_rnd_unc):
            # import pdb; pdb.set_trace()
            # Random Network Distillation
            self.rnd_targ_act, \
            self.rnd_pred_act, _, \
            self.rnd_pred_act_dropout_mask_generator, self.rnd_pred_act_dropout_mask_phs, \
            self.rnd_targ_cri, \
            self.rnd_pred_cri, _, \
            self.rnd_pred_cri_dropout_mask_generator, self.rnd_pred_cri_dropout_mask_phs = random_net_distill(x_ph, a_ph,
                                                                                                              **ac_kwargs,
                                                                                                              dropout_rate=dropout_rate)
        self.dropout_masks_set_pi = self.pi_dropout_mask_generator.generate_dropout_mask(n_post_action)
        self.dropout_masks_set_q1 = self.q1_dropout_mask_generator.generate_dropout_mask(n_post_action)
        self.dropout_masks_set_q2 = self.q2_dropout_mask_generator.generate_dropout_mask(n_post_action)
        self.dropout_masks_set_rnd_act = self.rnd_pred_act_dropout_mask_generator.generate_dropout_mask(n_post_action)
        self.dropout_masks_set_rnd_cri = self.rnd_pred_cri_dropout_mask_generator.generate_dropout_mask(n_post_action)

        self.uncertainty_logger = Logger(output_fname='dropout_uncertainty.txt',
                                         **logger_kwargs)
        self.sample_logger = Logger(output_fname='dropout_sample_observation.txt',
                                    **logger_kwargs)

        self.delayed_dropout_masks_update = False
        self.delayed_dropout_masks_update_freq = 1000

    # Load weights
    def update_weights_of_target_unc(self, sess):
        """Update uncertainty policy to current policy"""
        sess.run(tf.group([tf.assign(v_unc, v_targ)
                           for v_targ, v_unc in
                           zip(get_vars(self.tf_var_scope_target), get_vars(self.tf_var_scope_target_unc))]))

    def update_weights_of_main_unc(self, sess):
        """Update uncertainty policy to current policy"""
        sess.run(tf.group([tf.assign(v_unc, v_main)
                           for v_main, v_unc in
                           zip(get_vars(self.tf_var_scope_main), get_vars(self.tf_var_scope_main_unc))]))

    def update_weights_of_rnd_unc(self, sess):
        """Update uncertainty policy to current policy"""
        sess.run(tf.group([tf.assign(v_unc, v_rnd)
                           for v_rnd, v_unc in
                           zip(get_vars(self.tf_var_scope_rnd), get_vars(self.tf_var_scope_rnd_unc))]))

    # Update dropout masks
    def uncertainty_pi_dropout_masks_update(self):
        """Update uncertainty dropout_masks."""
        self.dropout_masks_set_pi = self.pi_dropout_mask_generator.generate_dropout_mask(self.n_post_action)

    def uncertainty_q_dropout_masks_update(self):
        """Update uncertainty dropout_masks."""
        self.dropout_masks_set_q1 = self.q1_dropout_mask_generator.generate_dropout_mask(self.n_post_action)
        self.dropout_masks_set_q2 = self.q2_dropout_mask_generator.generate_dropout_mask(self.n_post_action)

    def uncertainty_rnd_dropout_masks_update(self):
        """Update uncertainty dropout_masks."""
        self.dropout_masks_set_rnd_act = self.rnd_pred_act_dropout_mask_generator.generate_dropout_mask(self.n_post_action)
        self.dropout_masks_set_rnd_cri = self.rnd_pred_cri_dropout_mask_generator.generate_dropout_mask(self.n_post_action)

    # Get post samples
    def get_post_samples_act(self, obs, sess, step_index):
        """Return a post sample of actions for an observation."""
        feed_dictionary = {self.x_ph: obs.reshape(1, -1)}
        a_post = np.zeros((self.n_post_action, self.act_dim))
        # Refresh masks every step, or only every `delayed_dropout_masks_update_freq`
        # steps when delayed updates are enabled.
        if (not self.delayed_dropout_masks_update
                or step_index % self.delayed_dropout_masks_update_freq == 0):
            self.uncertainty_pi_dropout_masks_update()

        for mask_i in range(len(self.pi_dropout_mask_phs)):
            feed_dictionary[self.pi_dropout_mask_phs[mask_i]] = self.dropout_masks_set_pi[mask_i]
        # import pdb; pdb.set_trace()
        a_post = sess.run(self.pi, feed_dict=feed_dictionary)[:,0,:]
        return a_post
    
    def get_post_samples_q(self, obs, act, sess, step_index):
        """Return a post sample for a (observation, action) pair."""
        feed_dictionary = {self.x_ph: obs.reshape(1, -1), self.a_ph: act.reshape(1, -1)}
        q1_post = np.zeros((self.n_post_action, ))
        q2_post = np.zeros((self.n_post_action, ))
        if not self.delayed_dropout_masks_update:
            self.uncertainty_q_dropout_masks_update()
        elif step_index % self.delayed_dropout_masks_update_freq:
            self.uncertainty_q_dropout_masks_update()
        else:
            pass

        for mask_i in range(len(self.q1_dropout_mask_phs)):
            feed_dictionary[self.q1_dropout_mask_phs[mask_i]] = self.dropout_masks_set_q1[mask_i]
            feed_dictionary[self.q2_dropout_mask_phs[mask_i]] = self.dropout_masks_set_q2[mask_i]
        # import pdb; pdb.set_trace()
        q1_post = sess.run(self.q1, feed_dict=feed_dictionary)[:,0]
        q2_post = sess.run(self.q2, feed_dict=feed_dictionary)[:,0]
        return q1_post, q2_post

    def get_post_samples_rnd_act(self, obs, sess, step_index):
        """Return a post sample of action predictions for an observation."""
        feed_dictionary = {self.x_ph: obs.reshape(1, -1)}
        rnd_a_post = np.zeros((self.n_post_action, self.act_dim))
        # Refresh masks every step, or only every `delayed_dropout_masks_update_freq`
        # steps when delayed updates are enabled (refresh the RND masks here).
        if (not self.delayed_dropout_masks_update
                or step_index % self.delayed_dropout_masks_update_freq == 0):
            self.uncertainty_rnd_dropout_masks_update()

        for mask_i in range(len(self.rnd_pred_act_dropout_mask_phs)):
            feed_dictionary[self.rnd_pred_act_dropout_mask_phs[mask_i]] = self.dropout_masks_set_rnd_act[mask_i]
        rnd_a_post = sess.run(self.rnd_pred_act, feed_dict=feed_dictionary)[:, 0, :]
        return rnd_a_post

    def get_post_samples_rnd_cri(self, obs, act, sess, step_index):
        """Return a post sample of q predictions for a (observation, action) pair."""
        feed_dictionary = {self.x_ph: obs.reshape(1, -1), self.a_ph: act.reshape(1, -1)}
        rnd_q_post = np.zeros((self.n_post_action,))
        if not self.delayed_dropout_masks_update:
            self.uncertainty_rnd_dropout_masks_update()
        elif step_index % self.delayed_dropout_masks_update_freq:
            self.uncertainty_rnd_dropout_masks_update()
        else:
            pass

        for mask_i in range(len(self.rnd_pred_cri_dropout_mask_phs)):
            feed_dictionary[self.rnd_pred_cri_dropout_mask_phs[mask_i]] = self.dropout_masks_set_rnd_cri[mask_i]
        # import pdb; pdb.set_trace()
        rnd_q_post = sess.run(self.rnd_pred_cri, feed_dict=feed_dictionary)[:, 0]
        return rnd_q_post

    # Sample observations to track
    def sample_obs_set_from_replay_buffer(self, replay_buffer):
        """Sample an obs set from replay buffer."""
        self.obs_set = replay_buffer.sample_batch(self.obs_set_size)['obs1']
        self.obs_set_is_empty = False
        # Save sampled observations
        for i, o in enumerate(self.obs_set):
            self.sample_logger.log_tabular('Observation', i)
            # import pdb; pdb.set_trace()
            for dim, o_i in enumerate(o):
                self.sample_logger.log_tabular('o_{}'.format(dim), o_i)
            self.sample_logger.dump_tabular(print_data=False)

    # Calculate uncertainty of tracked observations
    def calculate_obs_set_uncertainty(self, sess, epoch, step):
        self.uncertainty_logger.log_tabular('Epoch', epoch)
        self.uncertainty_logger.log_tabular('Step', step)
        for obs_i, obs in enumerate(self.obs_set):
            # Calculate uncertainty
            a_post = self.get_post_samples_act(obs, sess, step)
            a_cov = np.cov(a_post, rowvar=False)
            for unc_i, unc_v in enumerate(np.array(a_cov).flatten(order='C')):
                self.uncertainty_logger.log_tabular('Obs{}_unc_{}'.format(obs_i, unc_i), unc_v)
            # Calculate RND prediction error
            rnd_targ, rnd_pred, rnd_pred_error = self.calculate_actor_RND_pred_error(obs, sess)
            # import pdb; pdb.set_trace()
            for rnd_i in range(self.act_dim):
                self.uncertainty_logger.log_tabular('Obs{}_rnd_t_{}'.format(obs_i, rnd_i), rnd_targ[rnd_i])
                self.uncertainty_logger.log_tabular('Obs{}_rnd_p_{}'.format(obs_i, rnd_i), rnd_pred[rnd_i])
            self.uncertainty_logger.log_tabular('Obs{}_rnd_error'.format(obs_i), rnd_pred_error)
        self.uncertainty_logger.dump_tabular(print_data=False)

    # Calculate RND prediction errors
    def calculate_actor_RND_pred_error(self, obs, sess):
        """Calculate prediction error without dropout"""
        feed_dictionary = {self.x_ph: obs.reshape(1, -1)}
        for mask_i in range(len(self.rnd_pred_act_dropout_mask_phs)):
            feed_dictionary[self.rnd_pred_act_dropout_mask_phs[mask_i]] = np.ones([1, self.rnd_pred_act_dropout_mask_phs[mask_i].shape.as_list()[1]])
        targ, pred = sess.run([self.rnd_targ_act, self.rnd_pred_act], feed_dict=feed_dictionary)
        pred_error = np.sqrt(np.sum((pred[0]-targ)**2))
        return targ[0], pred[0][0], pred_error

    def calculate_critic_RND_pred_error(self, obs, act, sess):
        """Calculate prediction error without dropout"""
        feed_dictionary = {self.x_ph: obs.reshape(1, -1), self.a_ph: act.reshape(1, -1)}
        for mask_i in range(len(self.rnd_pred_cri_dropout_mask_phs)):
            feed_dictionary[self.rnd_pred_cri_dropout_mask_phs[mask_i]] = np.ones([1, self.rnd_pred_cri_dropout_mask_phs[mask_i].shape.as_list()[1]])
        targ, pred = sess.run([self.rnd_targ_cri, self.rnd_pred_cri], feed_dict=feed_dictionary)
        pred_error = np.sqrt(np.sum((pred[0] - targ) ** 2))
        return targ[0], pred[0], pred_error
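A hedged sketch (not from the original source) of how this module's hooks could be driven once per environment step from the surrounding TD3 training loop; `unc_module`, `replay_buffer`, `sess` and `steps_per_epoch` are assumed to be provided by that loop.

def track_dropout_uncertainty(unc_module, replay_buffer, sess, t, steps_per_epoch):
    """Hedged sketch: call once per environment step from the training loop."""
    # Sample a fixed observation set once the buffer holds enough transitions.
    if unc_module.obs_set_is_empty and replay_buffer.size >= unc_module.obs_set_size:
        unc_module.sample_obs_set_from_replay_buffer(replay_buffer)
    # Periodically refresh the uncertainty copies and measure uncertainty.
    if (not unc_module.obs_set_is_empty
            and t % unc_module.track_obs_set_unc_frequency == 0):
        unc_module.update_weights_of_main_unc(sess)
        unc_module.update_weights_of_rnd_unc(sess)
        unc_module.calculate_obs_set_uncertainty(
            sess, epoch=t // steps_per_epoch, step=t)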
class UncertaintyModule(object):
    """This class is to provide functions to investigate dropout-based uncertainty change trajectories."""
    def __init__(self,
                 act_dim,
                 obs_dim,
                 n_post_action,
                 obs_set_size,
                 track_obs_set_unc_frequency,
                 pi,
                 x_ph,
                 a_ph,
                 pi_dropout_mask_phs,
                 pi_dropout_mask_generator,
                 rnd_targ_act,
                 rnd_pred_act,
                 rnd_targ_cri,
                 rnd_pred_cri,
                 logger_kwargs,
                 tf_var_scope_main='main',
                 tf_var_scope_target='target',
                 tf_var_scope_unc='uncertainty',
                 uncertainty_type='dropout'):
        self.act_dim = act_dim
        self.obs_dim = obs_dim
        self.n_post_action = n_post_action
        # policy
        self.pi = pi
        self.x_ph = x_ph
        self.a_ph = a_ph
        # dropout
        self.pi_dropout_mask_phs = pi_dropout_mask_phs
        self.pi_dropout_mask_generator = pi_dropout_mask_generator
        # rnd
        self.rnd_targ_act = rnd_targ_act
        self.rnd_pred_act = rnd_pred_act
        self.rnd_targ_cri = rnd_targ_cri
        self.rnd_pred_cri = rnd_pred_cri

        self.obs_set_size = obs_set_size
        self.obs_set_is_empty = True
        self.track_obs_set_unc_frequency = track_obs_set_unc_frequency

        self.tf_var_scope_main = tf_var_scope_main
        self.tf_var_scope_target = tf_var_scope_target
        self.tf_var_scope_unc = tf_var_scope_unc

        self.uncertainty_logger = Logger(
            output_fname='{}_uncertainty.txt'.format(uncertainty_type),
            **logger_kwargs)
        self.sample_logger = Logger(
            output_fname='{}_sample_observation.txt'.format(uncertainty_type),
            **logger_kwargs)

    # TODO: target policy
    def uncertainty_policy_update_targ(self, sess):
        """Update uncertainty policy to current policy"""
        sess.run(
            tf.group([
                tf.assign(v_unc, v_main)
                for v_main, v_unc in zip(get_vars(self.tf_var_scope_target),
                                         get_vars(self.tf_var_scope_unc))
            ]))

    def uncertainty_policy_update(self, sess):
        """Update uncertainty policy to current policy"""
        sess.run(
            tf.group([
                tf.assign(v_unc, v_main)
                for v_main, v_unc in zip(get_vars(self.tf_var_scope_main),
                                         get_vars(self.tf_var_scope_unc))
            ]))

    def sample_obs_set_from_replay_buffer(self, replay_buffer):
        """Sample an obs set from replay buffer."""
        self.obs_set = replay_buffer.sample_batch(self.obs_set_size)['obs1']
        self.obs_set_is_empty = False
        # Save sampled observations
        for i, o in enumerate(self.obs_set):
            self.sample_logger.log_tabular('Observation', i)
            # import pdb; pdb.set_trace()
            for dim, o_i in enumerate(o):
                self.sample_logger.log_tabular('o_{}'.format(dim), o_i)
            self.sample_logger.dump_tabular(print_data=False)

    def calculate_obs_set_uncertainty(self, sess, epoch, step):
        self.uncertainty_logger.log_tabular('Epoch', epoch)
        self.uncertainty_logger.log_tabular('Step', step)
        for obs_i, obs in enumerate(self.obs_set):
            # Calculate uncertainty
            a_post = self.get_post_samples(obs, sess, step)
            a_cov = np.cov(a_post, rowvar=False)
            for unc_i, unc_v in enumerate(np.array(a_cov).flatten(order='C')):
                self.uncertainty_logger.log_tabular(
                    'Obs{}_unc_{}'.format(obs_i, unc_i), unc_v)
            # Calculate RND prediction error
            rnd_targ, rnd_pred, rnd_pred_error = self.calculate_actor_RND_pred_error(
                obs, sess)
            for rnd_i in range(self.act_dim):
                self.uncertainty_logger.log_tabular(
                    'Obs{}_rnd_t_{}'.format(obs_i, rnd_i), rnd_targ[rnd_i])
                self.uncertainty_logger.log_tabular(
                    'Obs{}_rnd_p_{}'.format(obs_i, rnd_i), rnd_pred[rnd_i])
            self.uncertainty_logger.log_tabular(
                'Obs{}_rnd_error'.format(obs_i), rnd_pred_error)
        self.uncertainty_logger.dump_tabular(print_data=False)

    def calculate_actor_RND_pred_error(self, obs, sess):
        feed_dictionary = {self.x_ph: obs.reshape(1, -1)}
        targ, pred = sess.run([self.rnd_targ_act, self.rnd_pred_act],
                              feed_dict=feed_dictionary)
        pred_error = np.sqrt(np.sum((pred - targ)**2))
        return targ[0], pred[0], pred_error

    def calculate_critic_RND_pred_error(self, obs, act, sess):
        feed_dictionary = {
            self.x_ph: obs.reshape(1, -1),
            self.a_ph: act.reshape(1, -1)
        }
        targ, pred = sess.run([self.rnd_targ_cri, self.rnd_pred_cri],
                              feed_dict=feed_dictionary)
        pred_error = np.sqrt(np.sum((pred - targ) ** 2))
        return targ[0], pred[0], pred_error

    def get_post_samples(self, obs, sess, step_index=0):
        """Return a post sample matrix for an observation (placeholder; subclasses override this)."""
        return np.zeros((self.n_post_action, self.act_dim))
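A hedged sketch (not in the original source) of a subclass that fills in the get_post_samples placeholder with dropout-based posterior sampling, mirroring DropoutUncertaintyModule.get_post_samples_act above. It assumes self.pi was built with the dropout-mask placeholders listed in self.pi_dropout_mask_phs.

class DropoutPostSampleModule(UncertaintyModule):
    def get_post_samples(self, obs, sess, step_index=0):
        """Return an (n_post_action, act_dim) matrix of dropout-sampled actions."""
        feed_dict = {self.x_ph: obs.reshape(1, -1)}
        # Draw a fresh set of dropout masks, one per posterior sample.
        masks = self.pi_dropout_mask_generator.generate_dropout_mask(self.n_post_action)
        for ph, mask in zip(self.pi_dropout_mask_phs, masks):
            feed_dict[ph] = mask
        return sess.run(self.pi, feed_dict=feed_dict)[:, 0, :]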