def initialize_dynamics_model(self):
    """Create the forward-dynamics model together with its optimizer and loss.

    The model is an ``Mlp`` with two hidden layers of 128 units. Its input
    width is the observation width plus ``self._action_dim``, and its output
    width equals the observation width
    (``self._obs[self.observation_key].shape[1]``). The model is moved to
    ``ptu.device``, optimized with ``Adam`` and scored with ``MSELoss``.
    """
    latent_dim = self._obs[self.observation_key].shape[1]
    model = Mlp(
        input_size=latent_dim + self._action_dim,
        output_size=latent_dim,
        hidden_sizes=[128, 128],
    )
    model.to(ptu.device)
    self.dynamics_model = model
    self.dynamics_optimizer = Adam(model.parameters())
    self.dynamics_loss = MSELoss()
Exemplo n.º 2
0
 def __init__(
         self,
         vae,
         hidden_sizes=list([64, 128, 64]),
         init_w=1e-3,
         hidden_init=ptu.fanin_init,
         output_activation=identity,
         layer_norm=False,
         **kwargs
 ):
     """Build an MLP whose input and output sizes are both ``vae.representation_size``.

     The VAE's current ``dist_mu`` / ``dist_std`` are copied onto this module.

     :param vae: model providing ``representation_size``, ``dist_mu`` and
         ``dist_std``; it is stored as ``self.vae``.
     :param hidden_sizes: hidden layer sizes for the ``Mlp``. The list is copied
         before use, so the shared default list is not mutated here.
     :param init_w: forwarded to ``Mlp`` and stored on ``self``.
     :param hidden_init: forwarded to ``Mlp`` and stored on ``self``.
     :param output_activation: forwarded to ``Mlp`` and stored on ``self``.
     :param layer_norm: forwarded to ``Mlp``.
     :param kwargs: accepted but not otherwise used by this constructor.
     """
     # Keep this as the first statement: locals() should contain only the
     # constructor arguments (presumably recorded for re-instantiation).
     self.save_init_params(locals())
     super().__init__()
     self.vae = vae
     self.representation_size = self.vae.representation_size
     self.hidden_init = hidden_init
     self.output_activation = output_activation
     # Normalization statistics are taken from the VAE rather than
     # initialized to zeros / ones.
     self.dist_mu = self.vae.dist_mu
     self.dist_std = self.vae.dist_std
     self.relu = nn.ReLU()
     self.init_w = init_w
     hidden_sizes = list(hidden_sizes)
     self.network=Mlp(hidden_sizes,
                      self.representation_size,
                      self.representation_size,
                      layer_norm=layer_norm,
                      hidden_init=hidden_init,
                      output_activation=output_activation,
                      init_w=init_w)
Exemplo n.º 3
0
 def __init__(
     self,
     representation_size,
     input_size,
     hidden_sizes,
     init_w=1e-3,
     hidden_init=ptu.fanin_init,
     output_activation=identity,
     output_scale=1,
     layer_norm=False,
 ):
     """Build a two-headed MLP encoder and a mirrored MLP decoder.

     :param representation_size: size of each encoder head and of the
         decoder input.
     :param input_size: size of the data vector; used as the encoder input
         size and the decoder output size.
     :param hidden_sizes: encoder hidden layer sizes; the decoder uses the
         same sizes in reverse order.
     :param init_w: stored on ``self.init_w``.
     :param hidden_init: stored on ``self.hidden_init``.
     :param output_activation: forwarded to the decoder and stored on ``self``.
     :param output_scale: stored on ``self.output_scale``.
     :param layer_norm: forwarded to both encoder and decoder.
     """
     super().__init__()
     self.representation_size = representation_size
     self.hidden_init = hidden_init
     self.output_activation = output_activation
     self.dist_mu = np.zeros(self.representation_size)
     self.dist_std = np.ones(self.representation_size)
     self.relu = nn.ReLU()
     self.sigmoid = nn.Sigmoid()
     self.init_w = init_w
     # NOTE(review): init_w / hidden_init are stored but not forwarded to the
     # encoder or decoder here — confirm that is intended.
     hidden_sizes = list(hidden_sizes)
     self.encoder = TwoHeadMlp(hidden_sizes,
                               representation_size,
                               representation_size,
                               input_size,
                               layer_norm=layer_norm)
     # Build the decoder from a reversed *copy*. Reversing in place would
     # also mutate the list object that was just handed to the encoder.
     decoder_hidden_sizes = hidden_sizes[::-1]
     self.decoder = Mlp(decoder_hidden_sizes,
                        input_size,
                        representation_size,
                        layer_norm=layer_norm,
                        output_activation=output_activation,
                        output_bias=None)
     self.output_scale = output_scale
Exemplo n.º 4
0
def experiment(variant):
    """Train DQN on a ``DiscreteSwimmerEnv``.

    Reads ``variant['env_params']`` (env kwargs), ``variant['qf_kwargs']``
    (extra ``Mlp`` kwargs) and ``variant['algo_params']`` (DQN kwargs).
    """
    env = DiscreteSwimmerEnv(**variant['env_params'])
    flat_obs_size = int(np.prod(env.observation_space.shape))
    qf = Mlp(
        input_size=flat_obs_size,
        output_size=env.action_space.n,
        **variant['qf_kwargs']
    )
    algorithm = DQN(env, qf=qf, **variant['algo_params'])
    algorithm.to(ptu.device)
    algorithm.train()
Exemplo n.º 5
0
def experiment(variant):
    """Train FiniteHorizonDQN on a discretized version of ``variant['env_class']``.

    The env is built from ``variant['env_kwargs']`` and wrapped in
    ``DiscretizeEnv`` with ``variant['num_bins']`` bins.
    """
    base_env = variant['env_class'](**variant['env_kwargs'])
    env = DiscretizeEnv(base_env, variant['num_bins'])

    flat_obs_size = int(np.prod(env.observation_space.shape))
    qf = Mlp(
        input_size=flat_obs_size,
        output_size=env.action_space.n,
        **variant['qf_kwargs']
    )
    algorithm = FiniteHorizonDQN(env, qf, **variant['algo_kwargs'])
    algorithm.to(ptu.device)
    algorithm.train()
Exemplo n.º 6
0
def experiment(variant):
    """Train DQN on a ``DiscreteReacherEnv`` built from ``variant['env_params']``.

    The Q-function is an ``Mlp`` with two hidden layers of 32 units. It maps
    the flattened observation to one output per discrete action.
    ``variant['algo_params']`` is forwarded to ``DQN``.
    """
    env = DiscreteReacherEnv(**variant['env_params'])
    qf = Mlp(
        hidden_sizes=[32, 32],
        input_size=int(np.prod(env.observation_space.shape)),
        output_size=env.action_space.n,
    )
    algorithm = DQN(env, qf=qf, **variant['algo_params'])
    # Move to the configured device, as the other experiments do, instead
    # of calling .cuda(). With no argument, .cuda() always uses the current
    # default CUDA device and ignores the GPU chosen in ptu.
    algorithm.to(ptu.device)
    algorithm.train()
Exemplo n.º 7
0
    def __init__(self,
                 representation_size,
                 input_size,
                 hidden_sizes=list([64, 128, 64]),
                 init_w=1e-3,
                 hidden_init=ptu.fanin_init,
                 output_activation=identity,
                 output_scale=1,
                 layer_norm=False,
                 normalize=True,
                 train_data_mean=None,
                 train_data_std=None,
                 **kwargs):
        """Build a two-headed MLP encoder, a mirrored MLP decoder and data stats.

        :param representation_size: size of each encoder head and of the
            decoder input.
        :param input_size: size of the data vector (encoder input, decoder output).
        :param hidden_sizes: encoder hidden layer sizes; the decoder uses them
            reversed. The list is copied, so the shared default is not mutated.
        :param init_w, hidden_init, output_activation, layer_norm: forwarded
            to both the encoder and the decoder.
        :param output_scale: stored on ``self.output_scale``.
        :param normalize: stored on ``self.normalize``.
        :param train_data_mean: data mean; defaults to zeros of ``input_size``.
        :param train_data_std: data std; defaults to ones of ``input_size``.
        :param kwargs: accepted but not otherwise used by this constructor.
        """
        # Keep this first: locals() should hold only the constructor arguments.
        self.save_init_params(locals())
        super().__init__()
        self.representation_size = representation_size
        self.hidden_init = hidden_init
        self.output_activation = output_activation
        self.dist_mu = np.zeros(self.representation_size)
        self.dist_std = np.ones(self.representation_size)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        self.init_w = init_w
        hidden_sizes = list(hidden_sizes)
        self.input_size = input_size
        self.encoder = TwoHeadMlp(hidden_sizes,
                                  representation_size,
                                  representation_size,
                                  input_size,
                                  layer_norm=layer_norm,
                                  hidden_init=hidden_init,
                                  output_activation=output_activation,
                                  init_w=init_w)
        # Mirror the encoder using a reversed *copy*. Reversing in place would
        # also mutate the list object already handed to the encoder.
        decoder_hidden_sizes = hidden_sizes[::-1]
        self.decoder = Mlp(decoder_hidden_sizes,
                           input_size,
                           representation_size,
                           layer_norm=layer_norm,
                           hidden_init=hidden_init,
                           output_activation=output_activation,
                           init_w=init_w)
        self.output_scale = output_scale

        self.normalize = normalize
        if train_data_mean is None:
            self.train_data_mean = ptu.np_to_var(np.zeros(input_size))
        else:
            self.train_data_mean = train_data_mean
        if train_data_std is None:
            self.train_data_std = ptu.np_to_var(np.ones(input_size))
        else:
            self.train_data_std = train_data_std
Exemplo n.º 8
0
def experiment(variant):
    """Train DQN on the ``Pong-ram-v0`` gym environment.

    The Q-function has two hidden layers of 100 units;
    ``variant['algo_params']`` is forwarded to ``DQN``.
    """
    env = gym.make("Pong-ram-v0")

    flat_obs_size = int(np.prod(env.observation_space.shape))
    qf = Mlp(
        input_size=flat_obs_size,
        output_size=env.action_space.n,
        hidden_sizes=[100, 100],
    )
    algorithm = DQN(env, qf=qf, **variant['algo_params'])
    algorithm.to(ptu.device)
    algorithm.train()
Exemplo n.º 9
0
 def __init__(
     self,
     obs_processor,
     obs_processor_output_dim,
     action_dim,
     hidden_sizes,
 ):
     """Pair an observation processor with an MLP head.

     The head maps ``obs_processor_output_dim`` features to ``2 * action_dim``
     outputs. Going by the attribute name, these are presumably a mean and a
     log-std per action dimension — confirm against ``forward``.
     """
     super().__init__()
     self.obs_processor = obs_processor
     self.obs_processor_output_dim = obs_processor_output_dim
     self.action_dim = action_dim
     head_output_size = 2 * action_dim
     self.mean_and_log_std_net = Mlp(
         input_size=obs_processor_output_dim,
         hidden_sizes=hidden_sizes,
         output_size=head_output_size,
     )
Exemplo n.º 10
0
def experiment(variant):
    """Train ``variant['algo_class']`` on ``variant['env_class']()``.

    When ``variant['multitask']`` is truthy, the env is wrapped in
    ``MultitaskToFlatEnv``. The Q-function has two hidden layers of 32 units,
    and its loss is an instance of ``variant['qf_criterion_class']``.
    """
    env = variant['env_class']()
    if variant['multitask']:
        env = MultitaskToFlatEnv(env)

    flat_obs_size = int(np.prod(env.observation_space.shape))
    qf = Mlp(
        input_size=flat_obs_size,
        output_size=env.action_space.n,
        hidden_sizes=[32, 32],
    )
    criterion = variant['qf_criterion_class']()
    algo_class = variant['algo_class']
    algorithm = algo_class(
        env,
        qf=qf,
        qf_criterion=criterion,
        **variant['algo_params']
    )
    algorithm.to(ptu.device)
    algorithm.train()
Exemplo n.º 11
0
def experiment(variant):
    """Train ``variant['algo_class']`` on the gym env ``variant['env_id']``.

    Two separate env instances are created: one passed positionally and one
    passed as ``training_env``.
    """
    env_id = variant['env_id']
    env = gym.make(env_id)
    training_env = gym.make(env_id)

    flat_obs_size = int(np.prod(env.observation_space.shape))
    qf = Mlp(
        input_size=flat_obs_size,
        output_size=env.action_space.n,
        hidden_sizes=[32, 32],
    )
    criterion = variant['qf_criterion_class']()
    algo_class = variant['algo_class']
    algorithm = algo_class(
        env,
        training_env=training_env,
        qf=qf,
        qf_criterion=criterion,
        **variant['algo_params']
    )
    algorithm.to(ptu.device)
    algorithm.train()
def experiment(variant):
    """Train ``variant['algo_class']`` (e.g. DQN or DoubleDQN) on a discretized env.

    The env is ``variant['env_class'](**variant['env_kwargs'])``, wrapped in
    ``DiscretizeEnv`` with ``variant['num_bins']`` bins. The Q-function loss
    is ``nn.MSELoss()``.
    """
    base_env = variant['env_class'](**variant['env_kwargs'])
    env = DiscretizeEnv(base_env, variant['num_bins'])

    flat_obs_size = int(np.prod(env.observation_space.shape))
    qf = Mlp(
        input_size=flat_obs_size,
        output_size=env.action_space.n,
        **variant['qf_kwargs']
    )
    criterion = nn.MSELoss()
    algo_class = variant['algo_class']
    algorithm = algo_class(
        env,
        qf=qf,
        qf_criterion=criterion,
        **variant['algo_kwargs']
    )
    algorithm.to(ptu.device)
    algorithm.train()
Exemplo n.º 13
0
    def __init__(
        self,
        representation_size,
        architecture,
        normalize=True,
        output_classes=100,
        encoder_class=CNN,
        decoder_class=DCNN,
        decoder_output_activation=identity,
        decoder_distribution='bernoulli',
        input_channels=1,
        imsize=224,
        init_w=1e-3,
        min_variance=1e-3,
        hidden_init=ptu.fanin_init,
        delta_features=False,
        pretrained_features=False,
    ):
        """Build a ResNet-18 frame encoder and an MLP predictor head.

        :param representation_size: number of features the encoder produces
            per frame (passed as the ResNet's ``num_classes``).
        :param architecture: extra keyword arguments for the predictor
            ``Mlp``. ``input_size`` and ``output_size`` are set here.
        :param normalize: stored on ``self.normalize``; used to decide whether
            images are standardized with ``img_mean`` / ``img_std``.
        :param output_classes: number of predictor outputs.
        :param input_channels: used only to compute ``self.imlength``.
        :param imsize: used only to compute ``self.imlength``.
        :param min_variance: stored as ``self.log_min_variance`` (its log),
            or None.
        :param delta_features: if True, the predictor input is
            ``2 * representation_size`` wide; otherwise it is
            ``3 * representation_size`` wide.
        :param pretrained_features: if True, load torchvision's ResNet-18
            weights into the encoder, skipping the final ``fc`` layer.
        :param encoder_class, decoder_class, decoder_output_activation,
            decoder_distribution, init_w, hidden_init: accepted for signature
            compatibility; not used by this constructor.
        """
        super().__init__()
        if min_variance is None:
            self.log_min_variance = None
        else:
            self.log_min_variance = float(np.log(min_variance))
        self.input_channels = input_channels
        self.imsize = imsize
        self.imlength = self.imsize * self.imsize * self.input_channels
        self.representation_size = representation_size
        self.output_classes = output_classes

        self.normalize = normalize

        # Per-channel statistics expanded to image shape. repeat() yields
        # (CROP_WIDTH, CROP_HEIGHT, 3), and transpose(0, 2) turns that into
        # (3, CROP_HEIGHT, CROP_WIDTH).
        # NOTE(review): these are plain tensor attributes, not registered
        # buffers, so Module.to() will not move them.
        self.img_mean = torch.tensor([0.485, 0.456, 0.406])
        self.img_std = torch.tensor([0.229, 0.224, 0.225])
        self.img_mean = self.img_mean.repeat(epic.CROP_WIDTH, epic.CROP_HEIGHT,
                                             1).transpose(0, 2).to(ptu.device)
        self.img_std = self.img_std.repeat(epic.CROP_WIDTH, epic.CROP_HEIGHT,
                                           1).transpose(0, 2).to(ptu.device)

        # ResNet-18 layout: BasicBlock with [2, 2, 2, 2] layers.
        self.encoder = torchvision.models.resnet.ResNet(
            torchvision.models.resnet.BasicBlock,
            [2, 2, 2, 2],
            num_classes=representation_size,
        )
        self.encoder.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        if pretrained_features:
            # The final "fc" layer is skipped, and strict=False lets the
            # encoder keep its own freshly initialized fc.
            exclude_names = ["fc"]
            state_dict = load_state_dict_from_url(
                "https://download.pytorch.org/models/resnet18-5c106cde.pth",
                progress=True,
            )
            new_state_dict = {
                key: value
                for key, value in state_dict.items()
                if not any(name in key for name in exclude_names)
            }
            self.encoder.load_state_dict(new_state_dict, strict=False)

        self.delta_features = delta_features
        input_size = representation_size * 2 if delta_features else representation_size * 3

        self.predictor = Mlp(output_size=output_classes,
                             input_size=input_size,
                             **architecture)
        # Use the configured device, as img_mean/img_std above do. A
        # hard-coded "cuda:0" fails on CPU-only machines and can disagree
        # with ptu.device.
        self.predictor = self.predictor.to(ptu.device)

        self.epoch = 0
Exemplo n.º 14
0
class TimestepPredictionModel(torch.nn.Module):
    """Predict ``output_classes`` outputs from the encodings of three frames.

    ``x0``, ``xt`` and ``xT`` are encoded by a shared ResNet-18 trunk. The
    predictor ``Mlp`` receives either the three latents concatenated, or,
    with ``delta_features``, the differences ``zt - z0`` and ``zT - z0``.
    """

    def __init__(
        self,
        representation_size,
        architecture,
        normalize=True,
        output_classes=100,
        encoder_class=CNN,
        decoder_class=DCNN,
        decoder_output_activation=identity,
        decoder_distribution='bernoulli',
        input_channels=1,
        imsize=224,
        init_w=1e-3,
        min_variance=1e-3,
        hidden_init=ptu.fanin_init,
        delta_features=False,
        pretrained_features=False,
    ):
        """Build the ResNet-18 frame encoder and the MLP predictor head.

        :param representation_size: number of features the encoder produces
            per frame (passed as the ResNet's ``num_classes``).
        :param architecture: extra keyword arguments for the predictor
            ``Mlp``. ``input_size`` and ``output_size`` are set here.
        :param normalize: if True, ``encode`` standardizes images with
            ``img_mean`` / ``img_std`` before encoding.
        :param output_classes: number of predictor outputs.
        :param input_channels: used only to compute ``self.imlength``.
        :param imsize: used only to compute ``self.imlength``.
        :param min_variance: stored as ``self.log_min_variance`` (its log),
            or None.
        :param delta_features: see the class docstring.
        :param pretrained_features: if True, load torchvision's ResNet-18
            weights into the encoder, skipping the final ``fc`` layer.
        :param encoder_class, decoder_class, decoder_output_activation,
            decoder_distribution, init_w, hidden_init: accepted for signature
            compatibility; not used by this class.
        """
        super().__init__()
        if min_variance is None:
            self.log_min_variance = None
        else:
            self.log_min_variance = float(np.log(min_variance))
        self.input_channels = input_channels
        self.imsize = imsize
        self.imlength = self.imsize * self.imsize * self.input_channels
        self.representation_size = representation_size
        self.output_classes = output_classes

        self.normalize = normalize

        # Per-channel statistics expanded to image shape. repeat() yields
        # (CROP_WIDTH, CROP_HEIGHT, 3), and transpose(0, 2) turns that into
        # (3, CROP_HEIGHT, CROP_WIDTH), matching the view in get_latents.
        # NOTE(review): these are plain tensor attributes, not registered
        # buffers, so Module.to() will not move them.
        self.img_mean = torch.tensor([0.485, 0.456, 0.406])
        self.img_std = torch.tensor([0.229, 0.224, 0.225])
        self.img_mean = self.img_mean.repeat(epic.CROP_WIDTH, epic.CROP_HEIGHT,
                                             1).transpose(0, 2).to(ptu.device)
        self.img_std = self.img_std.repeat(epic.CROP_WIDTH, epic.CROP_HEIGHT,
                                           1).transpose(0, 2).to(ptu.device)

        # ResNet-18 layout: BasicBlock with [2, 2, 2, 2] layers.
        self.encoder = torchvision.models.resnet.ResNet(
            torchvision.models.resnet.BasicBlock,
            [2, 2, 2, 2],
            num_classes=representation_size,
        )
        self.encoder.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        if pretrained_features:
            # The final "fc" layer is skipped, and strict=False lets the
            # encoder keep its own freshly initialized fc.
            exclude_names = ["fc"]
            state_dict = load_state_dict_from_url(
                "https://download.pytorch.org/models/resnet18-5c106cde.pth",
                progress=True,
            )
            new_state_dict = {
                key: value
                for key, value in state_dict.items()
                if not any(name in key for name in exclude_names)
            }
            self.encoder.load_state_dict(new_state_dict, strict=False)

        self.delta_features = delta_features
        input_size = representation_size * 2 if delta_features else representation_size * 3

        self.predictor = Mlp(output_size=output_classes,
                             input_size=input_size,
                             **architecture)
        # Use the configured device, as img_mean/img_std above do. A
        # hard-coded "cuda:0" fails on CPU-only machines and can disagree
        # with ptu.device.
        self.predictor = self.predictor.to(ptu.device)

        self.epoch = 0

    def get_latents(self, x0, xt, xT):
        """Encode the three frame batches in one pass and split the latents.

        Assumes ``x0``, ``xt`` and ``xT`` share the batch size ``x0.shape[0]``
        and can each be viewed as ``(-1, 3, CROP_HEIGHT, CROP_WIDTH)``.
        Returns ``(z0, zt, zT)``.
        """
        bz = x0.shape[0]
        x = torch.cat([x0, xt, xT], dim=0).view(
            -1,
            3,
            epic.CROP_HEIGHT,
            epic.CROP_WIDTH,
        )
        z = self.encode(x)
        return z[:bz, :], z[bz:2 * bz, :], z[2 * bz:3 * bz, :]

    def forward(self, x0, xt, xT):
        """Return the predictor output for the frame triple ``(x0, xt, xT)``."""
        z0, zt, zT = self.get_latents(x0, xt, xT)
        if self.delta_features:
            z = torch.cat([zt - z0, zT - z0], dim=1)
        else:
            z = torch.cat([z0, zt, zT], dim=1)
        return self.predictor(z)

    def encode(self, x):
        """Encode a batch of images with the ResNet trunk.

        When ``self.normalize`` is set, ``x`` is standardized first. The
        encoder then runs on chunks of at most ``MAX_BATCH_SIZE`` images,
        which bounds peak activation memory, and the per-chunk results are
        concatenated along dim 0.
        """
        if self.normalize:
            x = x - self.img_mean
            x = x / self.img_std

        zs = [
            self.encoder(x[i:i + MAX_BATCH_SIZE, :, :, :])
            for i in range(0, x.shape[0], MAX_BATCH_SIZE)
        ]
        return torch.cat(zs)


def num_params(net):
    """Return the total number of scalar parameters in ``net``.

    Counts elements directly rather than concatenating every parameter into
    one vector. This avoids a full copy, works when parameters live on
    different devices, and returns 0 for a module without parameters
    (``parameters_to_vector`` would fail on an empty list).
    """
    return sum(p.numel() for p in net.parameters())

# Baseline: the error obtained by always predicting the mean test target.
mean_y = np.mean(test_y_np, axis=0)
print("Mean y", mean_y)
print("Constant error", np.mean((test_y_np - mean_y) ** 2))

plt.figure()
mlp = Mlp(hidden_sizes=[100, 100], output_size=out_dim, input_size=in_dim)

mlp_n_params = num_params(mlp)

# Shrink the JacobianNet's hidden width in steps of 5, starting from 100,
# until its parameter count no longer exceeds the MLP's. This keeps the two
# models at roughly the same capacity.
h_size = 100
while True:
    jac_net = JacobianNet(hidden_sizes=[h_size] * 2, output_size=out_dim,
                          input_size=in_dim)
    if num_params(jac_net) <= mlp_n_params:
        break
    h_size -= 5
print("jac_net h_size:", h_size)

# Baseline with no hidden layers.
linear_net = Mlp(hidden_sizes=[], output_size=out_dim, input_size=in_dim)
class OnlineVaeRelabelingBuffer(ObsDictRelabelingBuffer):

    def __init__(
            self,
            vae,
            *args,
            decoded_obs_key='image_observation',
            decoded_achieved_goal_key='image_achieved_goal',
            decoded_desired_goal_key='image_desired_goal',
            exploration_rewards_type='None',
            exploration_rewards_scale=1.0,
            vae_priority_type='None',
            start_skew_epoch=0,
            power=1.0,
            internal_keys=None,
            exploration_schedule_kwargs=None,
            priority_function_kwargs=None,
            exploration_counter_kwargs=None,
            relabeling_goal_sampling_mode='vae_prior',
            decode_vae_goals=False,
            **kwargs
    ):
        """Build the buffer and select its exploration-reward / priority functions.

        :param vae: model whose ``representation_size`` sizes the latent
            statistics, and whose ``dist_mu`` / ``dist_std`` are rewritten by
            ``refresh_latents``.
        :param decoded_obs_key, decoded_achieved_goal_key, decoded_desired_goal_key:
            observation keys holding image data; always added to ``internal_keys``.
        :param exploration_rewards_type: name in the reward-function table
            below. ``'None'`` disables the exploration bonus.
        :param exploration_rewards_scale: bonus scale when no schedule is
            given. A value of 0 disables the bonus.
        :param vae_priority_type: name in the same table, used for prioritized
            VAE sampling. ``'None'`` disables it.
        :param start_skew_epoch: skewed sampling is used only for epochs
            strictly greater than this.
        :param power: exponent for priorities. A value of 0 disables
            prioritization.
        :param internal_keys: extra internal keys. The caller's sequence is
            copied and never mutated.
        :param exploration_schedule_kwargs: if given, kwargs for a
            ``PiecewiseLinearSchedule`` that replaces the constant bonus scale.
        :param priority_function_kwargs: extra kwargs passed to the
            reward / priority functions.
        :param exploration_counter_kwargs: kwargs for ``CountExploration``;
            used with ``'hash_count'`` only.
        :param relabeling_goal_sampling_mode: goal sampling mode set on the
            env in ``_sample_goals_from_env``.
        :param decode_vae_goals: if True, ``add_path`` decodes desired goals
            into image goals before storing the path.
        :raises KeyError: if either type name is not in the function table.
        """
        # Copy so the appends below never mutate a list owned by the caller.
        internal_keys = [] if internal_keys is None else list(internal_keys)

        for key in [
            decoded_obs_key,
            decoded_achieved_goal_key,
            decoded_desired_goal_key
        ]:
            if key not in internal_keys:
                internal_keys.append(key)
        super().__init__(*args, internal_keys=internal_keys, **kwargs)
        self.vae = vae
        self.decoded_obs_key = decoded_obs_key
        self.decoded_desired_goal_key = decoded_desired_goal_key
        self.decoded_achieved_goal_key = decoded_achieved_goal_key
        self.exploration_rewards_type = exploration_rewards_type
        self.exploration_rewards_scale = exploration_rewards_scale
        self.start_skew_epoch = start_skew_epoch
        self.vae_priority_type = vae_priority_type
        self.power = power
        self._relabeling_goal_sampling_mode = relabeling_goal_sampling_mode
        self.decode_vae_goals = decode_vae_goals

        if exploration_schedule_kwargs is None:
            self.explr_reward_scale_schedule = \
                ConstantSchedule(self.exploration_rewards_scale)
        else:
            self.explr_reward_scale_schedule = \
                PiecewiseLinearSchedule(**exploration_schedule_kwargs)

        self._give_explr_reward_bonus = (
                exploration_rewards_type != 'None'
                and exploration_rewards_scale != 0.
        )
        self._exploration_rewards = np.zeros((self.max_size, 1), dtype=np.float32)
        self._prioritize_vae_samples = (
                vae_priority_type != 'None'
                and power != 0.
        )
        self._vae_sample_priorities = np.zeros((self.max_size, 1), dtype=np.float32)
        self._vae_sample_probs = None

        self.use_dynamics_model = (
                self.exploration_rewards_type == 'forward_model_error'
        )
        if self.use_dynamics_model:
            self.initialize_dynamics_model()

        type_to_function = {
            'reconstruction_error': self.reconstruction_mse,
            'bce': self.binary_cross_entropy,
            'latent_distance': self.latent_novelty,
            'latent_distance_true_prior': self.latent_novelty_true_prior,
            'forward_model_error': self.forward_model_error,
            'gaussian_inv_prob': self.gaussian_inv_prob,
            'bernoulli_inv_prob': self.bernoulli_inv_prob,
            'vae_prob': self.vae_prob,
            'hash_count': self.hash_count_reward,
            'None': self.no_reward,
        }

        self.exploration_reward_func = (
            type_to_function[self.exploration_rewards_type]
        )
        self.vae_prioritization_func = (
            type_to_function[self.vae_priority_type]
        )

        if priority_function_kwargs is None:
            self.priority_function_kwargs = dict()
        else:
            self.priority_function_kwargs = priority_function_kwargs

        if self.exploration_rewards_type == 'hash_count':
            if exploration_counter_kwargs is None:
                exploration_counter_kwargs = dict()
            self.exploration_counter = CountExploration(env=self.env, **exploration_counter_kwargs)
        # refresh_latents overwrites this. Initializing it means
        # sample_weighted_indices never sees a missing attribute.
        self.skew = False
        self.epoch = 0

    def add_path(self, path):
        """Store ``path``; when ``decode_vae_goals`` is set, first attach decoded goal images."""
        if not self.decode_vae_goals:
            super().add_path(path)
            return
        self.add_decoded_vae_goals_to_path(path)
        super().add_path(path)

    def add_decoded_vae_goals_to_path(self, path):
        """Attach decoded desired-goal images to every step of ``path`` in place.

        All latent desired goals of the path are decoded in a single
        ``self.env._decode`` call. Batching here is cheaper than decoding per
        step inside the env. Each decoded goal is flattened to one row and
        written under ``self.decoded_desired_goal_key`` in both
        ``path['observations'][i]`` and ``path['next_observations'][i]``.
        """
        desired_goals = flatten_dict(
            path['observations'],
            [self.desired_goal_key]
        )[self.desired_goal_key]
        decoded_goals = self.env._decode(desired_goals)
        decoded_goals = decoded_goals.reshape(len(decoded_goals), -1)
        for idx, obs in enumerate(path['observations']):
            goal = decoded_goals[idx]
            obs[self.decoded_desired_goal_key] = goal
            path['next_observations'][idx][self.decoded_desired_goal_key] = goal

    def random_batch(self, batch_size):
        """Sample a batch from the parent buffer, adding the exploration bonus if enabled.

        When the bonus is on, ``batch['exploration_rewards']`` holds the stored
        bonus for the sampled indices. ``batch['rewards']`` is then increased
        in place by the bonus times the schedule value at ``self.epoch``.
        """
        batch = super().random_batch(batch_size)
        scale = float(self.explr_reward_scale_schedule.get_value(self.epoch))
        if not self._give_explr_reward_bonus:
            return batch
        sampled_idxs = batch['indices'].flatten()
        bonus = self._exploration_rewards[sampled_idxs]
        batch['exploration_rewards'] = bonus
        batch['rewards'] += scale * bonus
        return batch

    def get_diagnostics(self):
        """Return stats for the current VAE sampling priorities and probabilities.

        Zeros of length ``self._size`` are reported until a sampling
        distribution has been computed.
        """
        not_ready = (
            self._vae_sample_probs is None
            or self._vae_sample_priorities is None
        )
        if not_ready:
            weights = np.zeros(self._size)
            probs = np.zeros(self._size)
        else:
            weights = self._vae_sample_priorities[:self._size]
            probs = self._vae_sample_probs[:self._size]
        stats = create_stats_ordered_dict('VAE Sample Weights', weights)
        stats.update(create_stats_ordered_dict('VAE Sample Probs', probs))
        return stats

    def refresh_latents(self, epoch, batch_size=512):
        """Re-encode all stored images with the current VAE and refresh derived data.

        For the ``self._size`` stored transitions this method:
          * for ``'hash_count'``, updates hash counts over the whole buffer
            before any reward is computed;
          * re-encodes the observation / next-observation latents and the
            next-step desired / achieved goal latents via ``self.env._encode``;
          * recomputes exploration rewards and VAE sample priorities when
            those features are enabled;
          * writes the per-dimension mean / std of the latent observations to
            ``self.vae.dist_mu`` / ``self.vae.dist_std``;
          * recomputes the normalized sampling distribution used by
            ``sample_weighted_indices``.

        :param epoch: current epoch. It is stored on ``self.epoch`` and
            enables skewed sampling once it exceeds ``start_skew_epoch``.
        :param batch_size: number of transitions encoded per chunk.
        """
        self.epoch = epoch
        self.skew = (self.epoch > self.start_skew_epoch)
        if self._size == 0:
            # Nothing stored yet. The statistics below would divide by zero
            # (turning the VAE's dist_mu/dist_std into NaN), and the priority
            # normalization would fail its p_sum > 0 assertion.
            return

        if self.exploration_rewards_type == 'hash_count':
            # Counts must cover the whole buffer before rewards are computed.
            for idxs in self._index_batches(batch_size):
                normalized_imgs = self._next_obs[self.decoded_obs_key][idxs]
                self.update_hash_count(normalized_imgs)

        # Each pass starts its own chunk iteration from index 0. The previous
        # version reused a stale end index after the hash_count pass and
        # encoded the entire buffer as a single batch.
        obs_sum = np.zeros(self.vae.representation_size)
        obs_square_sum = np.zeros(self.vae.representation_size)
        for idxs in self._index_batches(batch_size):
            self._obs[self.observation_key][idxs] = \
                self.env._encode(self._obs[self.decoded_obs_key][idxs])
            self._next_obs[self.observation_key][idxs] = \
                self.env._encode(self._next_obs[self.decoded_obs_key][idxs])
            # WARNING: we only refresh the desired/achieved latents for
            # "next_obs". This means that obs[desired/achieve] will be invalid,
            # so make sure there's no code that references this.
            # TODO: enforce this with code and not a comment
            self._next_obs[self.desired_goal_key][idxs] = \
                self.env._encode(self._next_obs[self.decoded_desired_goal_key][idxs])
            self._next_obs[self.achieved_goal_key][idxs] = \
                self.env._encode(self._next_obs[self.decoded_achieved_goal_key][idxs])
            normalized_imgs = self._next_obs[self.decoded_obs_key][idxs]
            if self._give_explr_reward_bonus:
                rewards = self.exploration_reward_func(
                    normalized_imgs,
                    idxs,
                    **self.priority_function_kwargs
                )
                self._exploration_rewards[idxs] = rewards.reshape(-1, 1)
            if self._prioritize_vae_samples:
                if (
                        self.exploration_rewards_type == self.vae_priority_type
                        and self._give_explr_reward_bonus
                ):
                    # Same function already computed above; reuse its output.
                    self._vae_sample_priorities[idxs] = (
                        self._exploration_rewards[idxs]
                    )
                else:
                    self._vae_sample_priorities[idxs] = (
                        self.vae_prioritization_func(
                            normalized_imgs,
                            idxs,
                            **self.priority_function_kwargs
                        ).reshape(-1, 1)
                    )
            latent_obs = self._obs[self.observation_key][idxs]
            obs_sum += latent_obs.sum(axis=0)
            obs_square_sum += np.power(latent_obs, 2).sum(axis=0)

        dist_mu = obs_sum / self._size
        # E[x^2] - E[x]^2 can come out slightly negative through
        # floating-point cancellation; clamp so sqrt never produces NaN.
        dist_var = np.maximum(
            obs_square_sum / self._size - np.power(dist_mu, 2), 0.
        )
        self.vae.dist_mu = dist_mu
        self.vae.dist_std = np.sqrt(dist_var)

        if self._prioritize_vae_samples:
            # For 'vae_prob', the power is applied inside the priority function
            # (vae_prob passes power=self.power). For every other type it is
            # applied here.
            if self.vae_priority_type == 'vae_prob':
                self._vae_sample_priorities[:self._size] = relative_probs_from_log_probs(
                    self._vae_sample_priorities[:self._size]
                )
                # NOTE(review): this is a view into _vae_sample_priorities, so
                # the in-place normalization below also rescales the stored
                # priorities — confirm that is intended.
                self._vae_sample_probs = self._vae_sample_priorities[:self._size]
            else:
                self._vae_sample_probs = self._vae_sample_priorities[:self._size] ** self.power
            p_sum = np.sum(self._vae_sample_probs)
            assert p_sum > 0, "Unnormalized p sum is {}".format(p_sum)
            self._vae_sample_probs /= p_sum
            self._vae_sample_probs = self._vae_sample_probs.flatten()

    def _index_batches(self, batch_size):
        """Yield consecutive ``np.arange`` index chunks covering ``[0, self._size)``."""
        for start in range(0, self._size, batch_size):
            yield np.arange(start, min(start + batch_size, self._size))

    def sample_weighted_indices(self, batch_size):
        """Draw ``batch_size`` buffer indices, skewed by VAE priorities when active.

        The skewed distribution ``self._vae_sample_probs`` is used only when
        prioritization is enabled, a distribution has been computed, and
        ``self.skew`` is set. Otherwise indices come from
        ``self._sample_indices``.
        """
        use_skewed = (
            self._prioritize_vae_samples
            and self._vae_sample_probs is not None
            and self.skew
        )
        if not use_skewed:
            return self._sample_indices(batch_size)
        probs = self._vae_sample_probs
        indices = np.random.choice(len(probs), batch_size, p=probs)
        assert np.max(probs) <= 1 and np.min(probs) >= 0
        return indices

    def _sample_goals_from_env(self, batch_size):
        """Sample ``batch_size`` goals from the env using the relabeling sampling mode."""
        env = self.env
        env.goal_sampling_mode = self._relabeling_goal_sampling_mode
        return env.sample_goals(batch_size)

    def sample_buffer_goals(self, batch_size):
        """Sample goals from the (possibly skewed) replay buffer.

        Used for relabeling or exploration. Returns None when the buffer is
        empty. Otherwise, indices are drawn with ``sample_weighted_indices``
        and the result maps:
          * ``self.decoded_desired_goal_key`` to the stored next image
            observations at those indices;
          * ``self.desired_goal_key`` to the stored next achieved-goal
            latents at those indices.
        """
        if not self._size:
            return None
        idxs = self.sample_weighted_indices(batch_size)
        image_goals = self._next_obs[self.decoded_obs_key][idxs]
        latent_goals = self._next_obs[self.achieved_goal_key][idxs]
        goals = {}
        goals[self.decoded_desired_goal_key] = image_goals
        goals[self.desired_goal_key] = latent_goals
        return goals

    def random_vae_training_data(self, batch_size, epoch):
        """Return ``dict(observations=...)`` with sampled next image observations.

        The images are converted with ``ptu.from_numpy``. ``epoch`` is unused;
        skewing is controlled by ``self.skew`` inside
        ``sample_weighted_indices``.
        """
        idxs = self.sample_weighted_indices(batch_size)
        images = self._next_obs[self.decoded_obs_key][idxs]
        return dict(observations=ptu.from_numpy(images))

    def reconstruction_mse(self, next_vae_obs, indices):
        """Per-sample squared reconstruction error of ``self.vae``.

        Despite the name, the squared error is *summed* (not averaged) over
        dim 1, giving one value per row of ``next_vae_obs``. The result is
        returned as a numpy array.

        ``indices`` is unused. Presumably it is kept so all priority/reward
        functions share one signature.
        """
        x = ptu.from_numpy(next_vae_obs)
        x_hat, _, _ = self.vae(x)
        squared_err = (x - x_hat).pow(2)
        return ptu.get_numpy(squared_err.sum(dim=1))

    def gaussian_inv_prob(self, next_vae_obs, indices):
        """Return ``exp(reconstruction_mse)`` for each sample.

        Presumably this is an unnormalised inverse Gaussian likelihood.
        NOTE(review): ``np.exp`` overflows to inf for large reconstruction
        errors, which would poison any later normalisation.
        """
        squared_error = self.reconstruction_mse(next_vae_obs, indices)
        return np.exp(squared_error)

    def binary_cross_entropy(self, next_vae_obs, indices):
        """Per-sample binary cross-entropy between inputs and VAE reconstructions.

        Computes, for each row of ``next_vae_obs`` (``x``) and its
        reconstruction ``x_hat`` from ``self.vae``:

            -sum_j [ x_j * log(x_hat_j) + (1 - x_j) * log(1 - x_hat_j) ]

        over dim 1, and returns the result as a numpy array.

        The previous version dropped the ``(1 - x) log(1 - x_hat)`` term, so
        elements with target 0 never contributed. Both terms are now included,
        matching the two-sided likelihood used by ``bernoulli_inv_prob``.

        Assumes reconstructions lie in [0, 1] (e.g. sigmoid outputs); confirm
        against the VAE decoder. ``indices`` is unused.
        """
        torch_input = ptu.from_numpy(next_vae_obs)
        recon_next_vae_obs, _, _ = self.vae(torch_input)

        # Clamp before the log so reconstructions of exactly 0 or 1 give a
        # large finite penalty instead of -inf/nan (log(1e-30) is about -69).
        log_recon = torch.log(torch.clamp(recon_next_vae_obs, min=1e-30))
        log_one_minus_recon = torch.log(
            torch.clamp(1 - recon_next_vae_obs, min=1e-30)
        )
        error = -(
            torch_input * log_recon
            + (1 - torch_input) * log_one_minus_recon
        )
        bce = torch.sum(error, dim=1)
        return ptu.get_numpy(bce)

    def bernoulli_inv_prob(self, next_vae_obs, indices):
        """Inverse of the Bernoulli-style likelihood of each input row.

        Per row, the likelihood is the product over dim 1 of
        ``x * x_hat + (1 - x) * (1 - x_hat)``, where ``x_hat`` is the
        reconstruction from ``self.vae``. Returns ``1 / likelihood`` as a
        numpy array.

        NOTE(review): for high-dimensional inputs the product can underflow
        to 0, making the result inf. A log-space formulation would be more
        robust. ``indices`` is unused.
        """
        x = ptu.from_numpy(next_vae_obs)
        x_hat, _, _ = self.vae(x)
        per_dim = x * x_hat + (1 - x) * (1 - x_hat)
        likelihood = per_dim.prod(dim=1)
        return ptu.get_numpy(1 / likelihood)

    def vae_prob(self, next_vae_obs, indices, **kwargs):
        """Delegate to ``compute_p_x_np_to_np`` for ``self.vae``.

        ``self.power`` is passed as ``power``. Any extra keyword arguments
        are forwarded unchanged. ``indices`` is unused.
        """
        power = self.power
        return compute_p_x_np_to_np(
            self.vae, next_vae_obs, power=power, **kwargs
        )

    def forward_model_error(self, next_vae_obs, indices):
        """Per-transition prediction error of the learned dynamics model.

        For each transition in ``indices``, feeds ``[obs, action]`` through
        ``self.dynamics_model`` and returns the mean squared error against
        the target. The mean is taken over the observation dimension only, so
        the result is a numpy array with one entry per index.

        When ``self.exploration_rewards_type == 'inverse_model_error'``, the
        roles of ``obs`` and ``next_obs`` are swapped. This matches how
        ``train_dynamics_model`` trains the model in that mode.

        ``next_vae_obs`` is unused.
        """
        obs = self._obs[self.observation_key][indices]
        next_obs = self._next_obs[self.observation_key][indices]
        actions = self._actions[indices]
        if getattr(self, 'exploration_rewards_type', None) == 'inverse_model_error':
            obs, next_obs = next_obs, obs

        with torch.no_grad():  # scoring only; no gradients needed
            state_action_pair = ptu.from_numpy(np.c_[obs, actions])
            prediction = self.dynamics_model(state_action_pair)
            target = ptu.from_numpy(next_obs)
            # Reduce only over the observation dimension. self.dynamics_loss
            # (MSELoss() with its default 'mean' reduction) would collapse the
            # whole batch to a single scalar, giving every transition the same
            # score.
            per_sample_mse = torch.mean((prediction - target) ** 2, dim=1)
        return ptu.get_numpy(per_sample_mse)

    def latent_novelty(self, next_vae_obs, indices):
        """Squared standardized distance of each encoding from the VAE's latent stats.

        Encodes ``next_vae_obs`` with ``self.env._encode``, then subtracts
        ``self.vae.dist_mu`` and divides by ``self.vae.dist_std``. Returns the
        sum of squares over axis 1. ``indices`` is unused.
        """
        z = self.env._encode(next_vae_obs)
        standardized = (z - self.vae.dist_mu) / self.vae.dist_std
        return (standardized ** 2).sum(axis=1)

    def latent_novelty_true_prior(self, next_vae_obs, indices):
        """Squared L2 norm of each encoding, summed over axis 1.

        This is ``latent_novelty`` with a zero mean and unit std instead of
        the VAE's fitted latent statistics. ``indices`` is unused.
        """
        encoded = self.env._encode(next_vae_obs)
        squared = encoded ** 2
        return squared.sum(axis=1)

    def _kl_np_to_np(self, next_vae_obs, indices):
        """Per-sample KL-style term between the VAE posterior and a standard normal.

        Encodes ``next_vae_obs`` with ``self.vae.encode`` (unpacked as
        ``mu, log_var``). Returns, as a numpy array,
        ``-sum(1 + log_var - mu**2 - exp(log_var))`` over dim 1.

        NOTE(review): the analytic KL(N(mu, sigma^2) || N(0, I)) has a 0.5
        factor in front of this sum, so the value here is twice the KL. It is
        left unchanged because callers may depend on the current scale.
        ``indices`` is unused.
        """
        x = ptu.from_numpy(next_vae_obs)
        mu, log_var = self.vae.encode(x)
        inner = 1 + log_var - mu.pow(2) - log_var.exp()
        return ptu.get_numpy(-inner.sum(dim=1))

    def update_hash_count(self, next_vae_obs):
        """Feed the latent means of ``next_vae_obs`` into the exploration counter.

        Only the first output of ``self.vae.encode`` (the means) is counted;
        the second output is discarded. Always returns None.
        """
        x = ptu.from_numpy(next_vae_obs)
        means, _ = self.vae.encode(x)
        self.exploration_counter.increment_counts(ptu.get_numpy(means))
        return None

    def hash_count_reward(self, next_vae_obs, indices):
        """Count-based exploration bonus for the env encodings of ``next_vae_obs``.

        ``indices`` is unused.
        """
        counter = self.exploration_counter
        return counter.compute_count_based_reward(
            self.env._encode(next_vae_obs)
        )

    def no_reward(self, next_vae_obs, indices):
        """Return an all-zero float column of shape ``(len(next_vae_obs), 1)``."""
        num_samples = len(next_vae_obs)
        return np.zeros(shape=(num_samples, 1))

    def initialize_dynamics_model(self):
        """Create the dynamics MLP, its Adam optimizer and an MSE loss.

        The network takes a concatenated ``[obs, action]`` vector
        (``obs_dim + self._action_dim`` inputs) and outputs ``obs_dim``
        values. It has two hidden layers of 128 units and is moved to
        ``ptu.device``. ``obs_dim`` is read from the second axis of the
        stored observations.
        """
        obs_dim = self._obs[self.observation_key].shape[1]
        in_dim = obs_dim + self._action_dim
        self.dynamics_model = Mlp(
            hidden_sizes=[128, 128],
            output_size=obs_dim,
            input_size=in_dim,
        )
        self.dynamics_model.to(ptu.device)
        params = self.dynamics_model.parameters()
        self.dynamics_optimizer = Adam(params)
        self.dynamics_loss = MSELoss()

    def train_dynamics_model(self, batches=50, batch_size=100):
        """Take ``batches`` optimizer steps on the dynamics model.

        Does nothing unless ``self.use_dynamics_model`` is truthy. Each step
        draws ``batch_size`` indices from ``self._sample_indices``, predicts
        the next observation from ``[obs, action]`` and minimises
        ``self.dynamics_loss``.

        When ``self.exploration_rewards_type == 'inverse_model_error'``, the
        model is instead trained to predict ``obs`` from
        ``[next_obs, action]``.
        """
        if not self.use_dynamics_model:
            return
        for _ in range(batches):
            idxs = self._sample_indices(batch_size)
            self.dynamics_optimizer.zero_grad()
            current = self._obs[self.observation_key][idxs]
            following = self._next_obs[self.observation_key][idxs]
            acts = self._actions[idxs]
            if self.exploration_rewards_type == 'inverse_model_error':
                inputs, targets = following, current
            else:
                inputs, targets = current, following

            prediction = self.dynamics_model(ptu.from_numpy(np.c_[inputs, acts]))
            loss = self.dynamics_loss(prediction, ptu.from_numpy(targets))

            loss.backward()
            self.dynamics_optimizer.step()

    def log_loss_under_uniform(self, model, data, batch_size, rl_logger, priority_function_kwargs):
        """Log likelihood, KL and reconstruction statistics of ``model`` on ``data``.

        ``data`` is processed in minibatches of ``batch_size`` rows. For each
        batch this records:
          * the mean ``log_d`` from ``compute_log_p_log_q_log_d`` with
            ``sampling_method='true_prior_sampling'`` and with
            ``'biased_sampling'``;
          * the mean of ``log_p - log_q + log_d`` with
            ``'importance_sampling'``;
          * ``model.kl_divergence`` of the batch's latent distribution params;
          * the mean squared reconstruction error over all elements.
        The per-batch values are averaged (unweighted) and written into
        ``rl_logger`` under the ``"Uniform Data ..."`` keys.

        Args:
            model: VAE under evaluation. It is used for the forward pass, the
                KL term and the log-probability estimates. (Previously the
                forward pass used ``self.vae`` while the KL and log-prob terms
                used ``model``.)
            data: numpy array of inputs, sliced along axis 0.
            batch_size: rows per minibatch.
            rl_logger: dict-like sink that receives the averaged statistics.
            priority_function_kwargs: extra kwargs forwarded to
                ``compute_log_p_log_q_log_d``. The caller's dict is NOT
                modified; a per-call copy with ``sampling_method`` overridden
                is used instead.
        """
        base_kwargs = dict(priority_function_kwargs or {})

        def estimate(batch, sampling_method):
            # Work on a copy so the caller's kwargs (which may drive later
            # priority computations) keep their own sampling_method.
            kwargs = dict(base_kwargs)
            kwargs['sampling_method'] = sampling_method
            return compute_log_p_log_q_log_d(model, batch, **kwargs)

        log_probs_prior = []
        log_probs_biased = []
        log_probs_importance = []
        kles = []
        mses = []
        for i in range(0, data.shape[0], batch_size):
            img = data[i:i + batch_size]  # numpy clamps the slice at the end
            torch_img = ptu.from_numpy(img)
            reconstructions, obs_distribution_params, latent_distribution_params = model(torch_img)

            _, _, log_d = estimate(img, 'true_prior_sampling')
            log_prob_prior = log_d.mean()

            _, _, log_d = estimate(img, 'biased_sampling')
            log_prob_biased = log_d.mean()

            log_p, log_q, log_d = estimate(img, 'importance_sampling')
            log_prob_importance = (log_p - log_q + log_d).mean()

            kle = model.kl_divergence(latent_distribution_params)
            # Element-wise mean squared error, computed directly rather than
            # via F.mse_loss(reduction='elementwise_mean'), which PyTorch
            # deprecated in favour of 'mean'.
            mse = ((torch_img - reconstructions) ** 2).mean()

            mses.append(mse.item())
            kles.append(kle.item())
            log_probs_prior.append(log_prob_prior.item())
            log_probs_biased.append(log_prob_biased.item())
            log_probs_importance.append(log_prob_importance.item())

        rl_logger["Uniform Data Log Prob (Prior)"] = np.mean(log_probs_prior)
        rl_logger["Uniform Data Log Prob (Biased)"] = np.mean(log_probs_biased)
        rl_logger["Uniform Data Log Prob (Importance)"] = np.mean(log_probs_importance)
        rl_logger["Uniform Data KL"] = np.mean(kles)
        rl_logger["Uniform Data MSE"] = np.mean(mses)

    def _get_sorted_idx_and_train_weights(self):
        """Return ``(index, weight)`` pairs of ``self._vae_sample_probs``, ascending by weight.

        The sort is stable, so equal weights keep their original index order.
        """
        pairs = list(enumerate(self._vae_sample_probs))
        pairs.sort(key=lambda pair: pair[1])
        return pairs