Example #1
0
    def _annotate_image(self, image, text, color=(0, 0, 255)):
        """Overlay `text` on a flattened CHW image and return it re-flattened.

        Returns the input unchanged when annotated images are disabled.
        `color` is passed straight to OpenCV (BGR order after the flip below).
        """
        from multiworld.core.image_env import normalize_image
        from multiworld.core.image_env import unormalize_image

        if self.disable_annotated_images:
            return image

        # Flattened (3 * imsize * imsize) uint8 -> HWC, then channel-reverse
        # for OpenCV; copy so putText can draw on a contiguous array.
        frame = unormalize_image(image).reshape(
            3, self.imsize, self.imsize).transpose((1, 2, 0))
        frame = frame[::, :, ::-1].copy()

        # Smaller images need a smaller font to stay legible.
        scale_for_size = {84: 0.30, 48: 0.25}
        font_scale = scale_for_size.get(self.imsize, 0.50)

        cv2.putText(img=frame,
                    text=text,
                    org=(0, self.imsize - 3),  # bottom-left corner of the text
                    fontFace=0,
                    fontScale=font_scale,
                    color=color,
                    thickness=1)

        # Undo the channel flip and flatten back to the buffer layout.
        frame = frame[::, :, ::-1]
        return normalize_image(
            frame.transpose((2, 0, 1)).reshape(self.imsize * self.imsize * 3))
Example #2
0
    def get_batch(self, train=True, epoch=None):
        """Sample a batch of normalized images as a tensor on ptu.device.

        When skewed sampling is active (train mode, skew_dataset enabled, and
        past start_skew_epoch), indices are drawn from the precomputed train
        weights; otherwise uniformly at random.
        """
        if self.use_parallel_dataloading:
            loader = self.train_dataloader if train else self.test_dataloader
            return next(loader).to(ptu.device)

        dataset = self.train_dataset if train else self.test_dataset
        use_skew = (
            train
            and self.skew_dataset
            and epoch is not None
            and self.start_skew_epoch < epoch
        )
        if use_skew:
            probs = self._train_weights / np.sum(self._train_weights)
            ind = np.random.choice(len(probs), self.batch_size, p=probs)
        else:
            ind = np.random.randint(0, len(dataset), self.batch_size)

        samples = normalize_image(dataset[ind, :])
        if self.normalize:
            samples = ((samples - self.train_data_mean) + 1) / 2
        if self.background_subtract:
            samples = samples - self.train_data_mean
        return ptu.from_numpy(samples)
Example #3
0
    def get_batch(self, train=True, epoch=None, sample_factors: bool = False):
        """Sample a batch of normalized images; optionally also the matching
        ground-truth factors (same indices).

        Skewed sampling applies in train mode once epoch passes
        start_skew_epoch and skew_dataset is enabled.
        """
        dataset = self.train_dataset if train else self.test_dataset
        factors = self.train_factors if train else self.test_factors

        use_skew = (
            train
            and self.skew_dataset
            and epoch is not None
            and self.start_skew_epoch < epoch
        )
        if use_skew:
            probs = self._train_weights / np.sum(self._train_weights)
            ind = np.random.choice(len(probs), self.batch_size, p=probs)
        else:
            ind = np.random.randint(0, len(dataset), self.batch_size)

        batch = normalize_image(dataset[ind, :])
        if self.normalize:
            batch = ((batch - self.train_data_mean) + 1) / 2
        if self.background_subtract:
            batch = batch - self.train_data_mean
        batch = ptu.from_numpy(batch)

        if sample_factors:
            return batch, ptu.from_numpy(factors[ind, :])
        return batch
    def sample_buffer_goals(self, batch_size, skew=True, key='image_observation_segmented'):
        """Sample goals from the (optionally skew-weighted) replay buffer,
        for relabeling or exploration.

        Returns None when the buffer is empty; otherwise a dict mapping the
        decoded and latent desired-goal keys to batches drawn at the same
        weighted indices, e.g.::

            {decoded_desired_goal_key: image_achieved_goals[weighted_idxs],
             desired_goal_key:         latent_achieved_goals[weighted_idxs]}
        """
        if self._size == 0:
            return None

        weighted_idxs = self.sample_weighted_indices(batch_size, skew=skew)

        # The segmented images (selected by `key`) serve as the image goal.
        # The original RLkit code used self.decoded_obs_key
        # ('image_observation'), which cannot serve as 'image_desired_goal'
        # in the segmentation setting. Pass a different `key` if this ever
        # needs to change.
        latent_goals = self._next_obs[self.achieved_goal_key][weighted_idxs]
        image_goals = normalize_image(self._next_obs[key][weighted_idxs])

        return {
            self.decoded_desired_goal_key: image_goals,
            self.desired_goal_key: latent_goals,
        }
Example #5
0
 def get_dataset_stats(self, data):
     """Encode `data` with the model and return (latent means, their
     per-dimension mean, their per-dimension std)."""
     torch_input = ptu.from_numpy(normalize_image(data))
     latent_mus, _ = self.model.encode(torch_input)
     latent_mus = ptu.get_numpy(latent_mus)
     return latent_mus, np.mean(latent_mus, axis=0), np.std(latent_mus, axis=0)
Example #6
0
    def _compute_train_weights(self):
        """Compute per-example sampling weights over the train set.

        Processes the dataset in fixed-size batches; only the "vae_prob"
        method is supported, in which case weights are relative probabilities
        derived from the model's log-probs.
        """
        method = self.skew_config.get("method", "squared_error")
        power = self.skew_config.get("power", 1)
        chunk = 512
        size = self.train_dataset.shape[0]
        weights = np.zeros(size)

        start = 0
        while start < size:
            stop = min(start + chunk, size)
            idxs = np.arange(start, stop)
            batch = self.train_dataset[idxs, :]
            if method != "vae_prob":
                raise NotImplementedError("Method {} not supported".format(method))
            batch = normalize_image(batch)
            weights[idxs] = compute_p_x_np_to_np(
                self.model, batch, power=power, **self.priority_function_kwargs
            )
            start = stop

        if method == "vae_prob":
            weights = relative_probs_from_log_probs(weights)
        return weights
Example #7
0
    def __init__(
        self,
        train_dataset,
        test_dataset,
        model,
        train_seedsteps,
        test_seedsteps,
        train_actions=None,
        test_actions=None,
        batch_size=128,
        log_interval=0,
        gamma=0.5,
        lr=1e-3,
        do_scatterplot=False,
        normalize=False,
        mse_weight=0.1,
        is_auto_encoder=False,
        background_subtract=False,
    ):
        """Trainer over uint8 image datasets.

        Args:
            train_dataset / test_dataset: uint8 arrays, one flattened image
                per row (asserted below); normalized on the fly at batch time.
            model: model exposing beta, imsize, representation_size,
                input_channels, imlength and torch parameters.
            train_seedsteps / test_seedsteps: stored for later use.
            train_actions / test_actions: optional per-image action arrays.
            batch_size / log_interval / gamma / lr: training hyperparameters.
            normalize / background_subtract: enable mean-image preprocessing
                of sampled batches.
        """
        self.quick_init(locals())
        self.log_interval = log_interval
        self.batch_size = batch_size
        self.beta = model.beta
        self.gamma = gamma
        self.imsize = model.imsize
        self.do_scatterplot = do_scatterplot

        model.to(ptu.device)

        self.model = model
        self.representation_size = model.representation_size
        self.input_channels = model.input_channels
        self.imlength = model.imlength
        self.train_seedsteps = train_seedsteps
        self.test_seedsteps = test_seedsteps

        self.lr = lr
        self.optimizer = optim.Adam(list(self.model.parameters()), lr=self.lr)

        # Datasets stay raw uint8; get_batch normalizes per sample.
        # (The original code assigned train/test datasets and batch_size twice;
        # the duplicates have been removed.)
        self.train_dataset, self.test_dataset = train_dataset, test_dataset
        assert self.train_dataset.dtype == np.uint8
        assert self.test_dataset.dtype == np.uint8
        self.train_actions = train_actions
        self.test_actions = test_actions

        self.normalize = normalize
        self.mse_weight = mse_weight
        self.background_subtract = background_subtract

        if self.normalize or self.background_subtract:
            # Mean image in [0, 1], consumed by get_batch preprocessing.
            self.train_data_mean = np.mean(self.train_dataset, axis=0)
            self.train_data_mean = normalize_image(
                np.uint8(self.train_data_mean))
        self.vae_logger_stats_for_rl = {}
        self._extra_stats_to_log = None
    def random_vae_training_data(self, batch_size, epoch):
        """Sample a weighted batch of next-observation images for VAE
        training.

        `epoch` is unused; skewing is handled inside sample_weighted_indices
        via self.skew.
        """
        idxs = self.sample_weighted_indices(batch_size, )
        batch = normalize_image(self._next_obs[self.decoded_obs_key][idxs])
        return dict(next_obs=ptu.from_numpy(batch))
Example #9
0
def postprocess_obs_dict(obs_dict):
    """
    Undo internal replay buffer representation changes: save images as bytes
    """
    # Values are replaced under their existing keys, so mutating while
    # iterating the dict is safe here.
    for key in obs_dict:
        value = obs_dict[key]
        if 'image' in key and value is not None:
            obs_dict[key] = normalize_image(value)
    return obs_dict
Example #10
0
 def get_batch(self, train=True):
     """Sample a uniform batch of normalized images as a torch tensor,
     with optional mean-image preprocessing."""
     dataset = self.train_dataset if train else self.test_dataset
     idxs = np.random.randint(0, len(dataset), self.batch_size)
     batch = normalize_image(dataset[idxs, :])
     if self.normalize:
         batch = ((batch - self.train_data_mean) + 1) / 2
     if self.background_subtract:
         batch = batch - self.train_data_mean
     return ptu.from_numpy(batch)
Example #11
0
def compute_sampled_latents(vae_env):
    """Populate vae_env.sampled_latents / sampled_latents_noisy /
    sampled_latents_reproj with encodings of a sample of images, and set
    num_active_dims and (in)active_dims from the VAE's latent std.
    """
    # A latent dimension counts as "active" when its std exceeds 0.15
    # (hard-coded empirical cutoff).
    vae_env.num_active_dims = 0
    for std in vae_env.vae.dist_std:
        if std > 0.15:
            vae_env.num_active_dims += 1

    # Dimensions sorted by std: the num_active_dims largest are active,
    # the rest inactive (both listed largest-std first).
    vae_env.active_dims = vae_env.vae.dist_std.argsort()[-vae_env.num_active_dims:][::-1]
    vae_env.inactive_dims = vae_env.vae.dist_std.argsort()[:-vae_env.num_active_dims][::-1]

    if vae_env.use_vae_dataset and vae_env.vae_dataset_path is not None:
        from multiworld.core.image_env import normalize_image
        from railrl.misc.asset_loader import local_path_from_s3_or_local_path
        filename = local_path_from_s3_or_local_path(vae_env.vae_dataset_path)
        # NOTE(review): np.load(...).item() implies an old pickled-dict
        # dataset; newer numpy requires allow_pickle=True — confirm.
        dataset = np.load(filename).item()
        vae_env.num_samples_for_latent_histogram = min(dataset['next_obs'].shape[0], vae_env.num_samples_for_latent_histogram)
        sampled_idx = np.random.choice(dataset['next_obs'].shape[0], vae_env.num_samples_for_latent_histogram)
        if vae_env.vae_input_key_prefix == 'state':
            vae_dataset_samples = dataset['next_obs'][sampled_idx]
        else:
            vae_dataset_samples = normalize_image(dataset['next_obs'][sampled_idx])
        del dataset
    else:
        vae_dataset_samples = None

    n = vae_env.num_samples_for_latent_histogram

    # Fall back to sampling goals from the environment when no dataset is
    # available.
    if vae_dataset_samples is not None:
        imgs = vae_dataset_samples
    else:
        if vae_env.vae_input_key_prefix == 'state':
            imgs = vae_env.wrapped_env.wrapped_env.sample_goals(n)['state_desired_goal']
        else:
            imgs = vae_env.wrapped_env.sample_goals(n)['image_desired_goal']

    # Encode in chunks to bound memory; accumulate by concatenation.
    batch_size = 2500
    latents, latents_noisy, latents_reproj = None, None, None
    for i in range(0, n, batch_size):
        batch_latents_mean, batch_latents_logvar = vae_env.encode_imgs(imgs[i:i + batch_size], clip_std=False)
        batch_latents_noisy = vae_env.reparameterize(batch_latents_mean, batch_latents_logvar, noisy=True)
        if vae_env.use_reprojection_network:
            batch_latents_reproj = ptu.get_numpy(vae_env.reproject_encoding(ptu.np_to_var(batch_latents_noisy)))
        if latents is None:
            latents = batch_latents_mean
            latents_noisy = batch_latents_noisy
            if vae_env.use_reprojection_network:
                latents_reproj = batch_latents_reproj
        else:
            latents = np.concatenate((latents, batch_latents_mean), axis=0)
            latents_noisy = np.concatenate((latents_noisy, batch_latents_noisy), axis=0)
            if vae_env.use_reprojection_network:
                latents_reproj = np.concatenate((latents_reproj, batch_latents_reproj), axis=0)

    # latents_reproj stays None when no reprojection network is used.
    vae_env.sampled_latents = latents
    vae_env.sampled_latents_noisy = latents_noisy
    vae_env.sampled_latents_reproj = latents_reproj
Example #12
0
    def log_loss_under_uniform(self, model, data, priority_function_kwargs):
        """Evaluate `model` on uniformly-sampled `data` and log average log
        probabilities (under three sampling schemes), KL, and MSE.

        NOTE: `priority_function_kwargs` is mutated in place (its
        "sampling_method" entry is overwritten three times below) — callers
        should not rely on its prior value.
        """
        import torch.nn.functional as F

        log_probs_prior = []
        log_probs_biased = []
        log_probs_importance = []
        kles = []
        mses = []
        for i in range(0, data.shape[0], self.batch_size):
            img = normalize_image(data[i : min(data.shape[0], i + self.batch_size), :])
            torch_img = ptu.from_numpy(img)
            reconstructions, obs_distribution_params, latent_distribution_params = self.model(
                torch_img
            )

            priority_function_kwargs["sampling_method"] = "true_prior_sampling"
            log_p, log_q, log_d = compute_log_p_log_q_log_d(
                model, img, **priority_function_kwargs
            )
            log_prob_prior = log_d.mean()

            priority_function_kwargs["sampling_method"] = "biased_sampling"
            log_p, log_q, log_d = compute_log_p_log_q_log_d(
                model, img, **priority_function_kwargs
            )
            log_prob_biased = log_d.mean()

            priority_function_kwargs["sampling_method"] = "importance_sampling"
            log_p, log_q, log_d = compute_log_p_log_q_log_d(
                model, img, **priority_function_kwargs
            )
            # Importance-weighted estimate combines all three terms.
            log_prob_importance = (log_p - log_q + log_d).mean()

            kle = model.kl_divergence(latent_distribution_params)
            # NOTE(review): reduction="elementwise_mean" is legacy PyTorch
            # naming (pre-1.0); modern versions expect "mean" — confirm the
            # pinned torch version.
            mse = F.mse_loss(torch_img, reconstructions, reduction="elementwise_mean")
            mses.append(mse.item())
            kles.append(kle.item())
            log_probs_prior.append(log_prob_prior.item())
            log_probs_biased.append(log_prob_biased.item())
            log_probs_importance.append(log_prob_importance.item())

        logger.record_tabular(
            "Uniform Data Log Prob (True Prior)", np.mean(log_probs_prior)
        )
        logger.record_tabular(
            "Uniform Data Log Prob (Biased)", np.mean(log_probs_biased)
        )
        logger.record_tabular(
            "Uniform Data Log Prob (Importance)", np.mean(log_probs_importance)
        )
        logger.record_tabular("Uniform Data KL", np.mean(kles))
        logger.record_tabular("Uniform Data MSE", np.mean(mses))
    def random_vae_training_data(
            self,
            batch_size,
            epoch,
            key=None):
        """Sample a weighted batch of images for VAE training.

        `epoch` is unused; skewing is handled inside sample_weighted_indices.
        `key` selects which image field to draw from (NOTE yufei: pass in a
        chosen key); it defaults to self.decoded_obs_key.
        """
        # The (possibly None) key is forwarded to the index sampler first,
        # exactly as in the original ordering.
        idxs = self.sample_weighted_indices(batch_size, key)

        obs_key = key if key is not None else self.decoded_obs_key
        batch = normalize_image(self._next_obs[obs_key][idxs])
        return dict(next_obs=ptu.from_numpy(batch))
Example #14
0
def dump_reconstructions(vae_env, epoch, n_recon=16):
    """Save a grid of sampled dataset images next to their VAE
    reconstructions as r_<epoch>.png (or r.png) in the snapshot dir.

    No-op unless vae_env uses a VAE dataset with a valid path.
    """
    from railrl.core import logger
    import os.path as osp
    from torchvision.utils import save_image

    if vae_env.use_vae_dataset and vae_env.vae_dataset_path is not None:
        from multiworld.core.image_env import normalize_image
        from railrl.misc.asset_loader import local_path_from_s3_or_local_path
        filename = local_path_from_s3_or_local_path(vae_env.vae_dataset_path)
        # NOTE(review): np.load(...).item() implies an old pickled-dict
        # dataset; newer numpy requires allow_pickle=True — confirm.
        dataset = np.load(filename).item()
        sampled_idx = np.random.choice(dataset['next_obs'].shape[0], n_recon)
        if vae_env.vae_input_key_prefix == 'state':
            # State-space VAE: render both states and reconstructed states
            # to images for visualization.
            states = dataset['next_obs'][sampled_idx]
            imgs = ptu.np_to_var(
                vae_env.wrapped_env.states_to_images(states)
            )
            recon_samples, _, _ = vae_env.vae(ptu.np_to_var(states))
            recon_imgs = ptu.np_to_var(
                vae_env.wrapped_env.states_to_images(ptu.get_numpy(recon_samples))
            )
        else:
            imgs = ptu.np_to_var(
                normalize_image(dataset['next_obs'][sampled_idx])
            )
            recon_imgs, _, _, _ = vae_env.vae(imgs)
        del dataset  # release the (potentially large) dataset early
    else:
        return

    # Rows of originals followed by rows of reconstructions.
    comparison = torch.cat([
        imgs.narrow(start=0, length=vae_env.wrapped_env.image_length, dimension=1).contiguous().view(
            -1,
            vae_env.wrapped_env.channels,
            vae_env.wrapped_env.imsize,
            vae_env.wrapped_env.imsize
        ),
        recon_imgs.contiguous().view(
            n_recon,
            vae_env.wrapped_env.channels,
            vae_env.wrapped_env.imsize,
            vae_env.wrapped_env.imsize
        )[:n_recon]
    ])

    if epoch is not None:
        save_dir = osp.join(logger.get_snapshot_dir(), 'r_%d.png' % epoch)
    else:
        save_dir = osp.join(logger.get_snapshot_dir(), 'r.png')
    save_image(comparison.data.cpu(), save_dir, nrow=n_recon)
Example #15
0
    def get_batch(self, train=True):
        """Return a batch: a dict of Variables when the parallel dataloader
        is enabled, otherwise a Variable of uniformly-sampled images."""
        if self.use_parallel_dataloading:
            loader = self.train_dataloader if train else self.test_dataloader
            samples = next(loader)
            return {
                'obs': ptu.Variable(samples[0][0]),
                'actions': ptu.Variable(samples[1][0]),
                'next_obs': ptu.Variable(samples[2][0]),
            }

        dataset = self.train_dataset if train else self.test_dataset
        idxs = np.random.randint(0, len(dataset), self.batch_size)
        return ptu.np_to_var(normalize_image(dataset[idxs, :]))
Example #16
0
    def get_batch(self, train=True, epoch=None):
        """Sample a uniform batch of normalized images (`epoch` is accepted
        for interface compatibility but unused here)."""
        if self.use_parallel_dataloading:
            loader = self.train_dataloader if train else self.test_dataloader
            return next(loader).to(ptu.device)

        dataset = self.train_dataset if train else self.test_dataset
        idxs = np.random.randint(0, len(dataset), self.batch_size)
        batch = normalize_image(dataset[idxs, :])
        if self.normalize:
            batch = ((batch - self.train_data_mean) + 1) / 2
        if self.background_subtract:
            batch = batch - self.train_data_mean
        return ptu.from_numpy(batch)
    def random_lstm_training_data(self, batch_size, key=None):
        """Sample whole trajectories of images for LSTM training.

        Returns dict(next_obs=tensor) of shape
        (max_path_length, batch_size, imlen).
        """
        if key is None:
            key = self.decoded_obs_key

        traj_idxes = np.random.randint(0, self._traj_num, batch_size)
        # NOTE(review): imlen is read from self._next_obs but frames come
        # from self._obs — presumably both share the same layout; confirm.
        imlen = self._next_obs[key].shape[-1]
        data = np.zeros((batch_size, self.max_path_length, imlen), dtype=np.uint8)
        for i, traj_idx in enumerate(traj_idxes):
            start = traj_idx * self.max_path_length
            data[i] = self._obs[key][start:start + self.max_path_length]

        data = normalize_image(data)
        data = np.swapaxes(data, 0, 1)  # traj_len x batch_size x imlen
        return dict(next_obs=ptu.from_numpy(data))
Example #18
0
    def _dump_imgs_and_reconstructions(self, idxs, filename):
        """Save a grid of train-set images (top rows) and their model
        reconstructions (bottom rows) to `filename` in the snapshot dir."""
        originals = []
        reconstructions = []
        for i in idxs:
            img_np = self.train_dataset[i]
            img_torch = ptu.from_numpy(normalize_image(img_np))
            recon, *_ = self.model(img_torch.view(1, -1))

            originals.append(
                img_torch.view(self.input_channels, self.imsize,
                               self.imsize).transpose(1, 2))
            reconstructions.append(
                recon.view(self.input_channels, self.imsize,
                           self.imsize).transpose(1, 2))
        grid = torch.stack(originals + reconstructions)
        save_file = osp.join(logger.get_snapshot_dir(), filename)
        save_image(grid.data, save_file, nrow=len(idxs))
Example #19
0
    def dump_uniform_imgs_and_reconstructions(self, dataset, epoch):
        """Save reconstructions of 4 random images from `dataset` to
        uniform<epoch>.png in the snapshot dir."""
        idxs = np.random.choice(range(dataset.shape[0]), 4)
        filename = "uniform{}.png".format(epoch)
        originals = []
        reconstructions = []
        for i in idxs:
            img_np = dataset[i]
            img_torch = ptu.from_numpy(normalize_image(img_np))
            recon, *_ = self.model(img_torch.view(1, -1))

            originals.append(
                img_torch.view(self.input_channels, self.imsize,
                               self.imsize).transpose(1, 2))
            reconstructions.append(
                recon.view(self.input_channels, self.imsize,
                           self.imsize).transpose(1, 2))
        grid = torch.stack(originals + reconstructions)
        save_file = osp.join(logger.get_snapshot_dir(), filename)
        save_image(grid.data, save_file, nrow=4)
    def sample_buffer_goals(self, batch_size):
        """Sample goals from the weighted replay buffer, for relabeling or
        exploration.

        Returns None when the buffer is empty; otherwise a dict mapping the
        decoded and latent desired-goal keys to batches drawn at the same
        weighted indices, e.g.::

            {decoded_desired_goal_key: image_achieved_goals[weighted_idxs],
             desired_goal_key:         latent_achieved_goals[weighted_idxs]}
        """
        if self._size == 0:
            return None
        idxs = self.sample_weighted_indices(batch_size, )
        image_goals = normalize_image(
            self._next_obs[self.decoded_obs_key][idxs])
        latent_goals = self._next_obs[self.achieved_goal_key][idxs]
        return {
            self.decoded_desired_goal_key: image_goals,
            self.desired_goal_key: latent_goals,
        }
Example #21
0
def drop_puck(img, imsize=48):
    """Paint over the puck in a flattened normalized image.

    Args:
        img: flattened image with values in [0, 1], length 3*imsize*imsize.
        imsize: side length of the square image.

    Returns:
        1-D normalized image (same length as input) with the puck region
        replaced by an approximate table color.
    """
    # reshape() returns a view and `* 255` allocates a fresh array, so the
    # caller's `img` is never mutated. (The original deepcopy calls were
    # redundant and have been removed.)
    frame = (img.reshape(3, imsize, imsize).transpose() * 255).astype(np.uint8)

    # Color bounds (RGB) that isolate the puck.
    lowerColorBounds = (0, 60, 150)
    upperColorBounds = (10, 140, 255)
    mask_puck = cv2.inRange(frame, lowerColorBounds, upperColorBounds)

    # Dilate slightly so the painted patch fully covers the puck's edges.
    kernel = np.ones((2, 2), np.uint8)
    mask_puck = cv2.dilate(mask_puck, kernel)
    mask_puck = mask_puck > 200

    # Fill the masked pixels with an approximate table color.
    fill = np.zeros_like(frame[mask_puck])
    fill[:, 0] = 120
    fill[:, 1] = 120
    fill[:, 2] = 100
    frame[mask_puck] = fill

    # Back to the flattened buffer layout, normalized to [0, 1].
    frame = frame.transpose().reshape([1, -1])
    return normalize_image(frame).flatten()
Example #22
0
    def _dump_imgs_and_reconstructions(self, idxs, filename):
        """Save a grid of train-set images (top rows) and their model
        reconstructions (bottom rows) to ./result_image/<filename>.

        BUGFIX: the original assigned the path to `save_dir` but passed an
        undefined name `save_file` to save_image, raising NameError.
        """
        imgs = []
        recons = []
        for i in idxs:
            img_np = self.train_dataset[i]
            img_torch = ptu.from_numpy(normalize_image(img_np))
            recon, *_ = self.model(img_torch.view(1, -1))

            img = img_torch.view(self.input_channels, self.imsize,
                                 self.imsize).transpose(1, 2)
            rimg = recon.view(self.input_channels, self.imsize,
                              self.imsize).transpose(1, 2)
            imgs.append(img)
            recons.append(rimg)
        all_imgs = torch.stack(imgs + recons)
        # Save under <cwd>/result_image/ rather than the logger snapshot dir.
        project_path = osp.abspath(os.curdir)
        save_file = osp.join(project_path + str('/result_image/'), filename)
        save_image(
            all_imgs.data,
            save_file,
            nrow=len(idxs),
        )
Example #23
0
 def _reconstruction_squared_error_np_to_np(self, np_imgs):
     """Per-image summed squared reconstruction error; numpy in, numpy out."""
     model_input = ptu.from_numpy(normalize_image(np_imgs))
     recons, *_ = self.model(model_input)
     diff = model_input - recons
     return ptu.get_numpy((diff ** 2).sum(dim=1))
    def refresh_latents(self, epoch):
        """Re-encode every buffered image through the (re-trained) VAE,
        refresh exploration rewards and VAE sample priorities, and update
        the VAE's latent distribution statistics (dist_mu / dist_std).

        Processes the buffer in batches of 512 to bound memory.
        """
        self.epoch = epoch
        self.skew = self.epoch > self.start_skew_epoch
        batch_size = 512
        next_idx = min(batch_size, self._size)

        if self.exploration_rewards_type == "hash_count":
            # you have to count everything then compute exploration rewards
            # NOTE(review): this loop normalizes each batch but never uses
            # the result — presumably the counting call is missing or
            # happens elsewhere; confirm.
            cur_idx = 0
            next_idx = min(batch_size, self._size)
            while cur_idx < self._size:
                idxs = np.arange(cur_idx, next_idx)
                normalized_imgs = normalize_image(
                    self._next_obs[self.decoded_obs_key][idxs])
                cur_idx = next_idx
                next_idx += batch_size
                next_idx = min(next_idx, self._size)

        cur_idx = 0
        # Running sums for the latent mean/std computed at the end.
        obs_sum = np.zeros(self.vae.representation_size)
        obs_square_sum = np.zeros(self.vae.representation_size)
        while cur_idx < self._size:
            idxs = np.arange(cur_idx, next_idx)
            self._obs[self.observation_key][idxs] = self.env._encode(
                normalize_image(self._obs[self.decoded_obs_key][idxs]))
            self._next_obs[self.observation_key][idxs] = self.env._encode(
                normalize_image(self._next_obs[self.decoded_obs_key][idxs]))
            # WARNING: we only refresh the desired/achieved latents for
            # "next_obs". This means that obs[desired/achieve] will be invalid,
            # so make sure there's no code that references this.
            # TODO: enforce this with code and not a comment
            self._next_obs[self.desired_goal_key][idxs] = self.env._encode(
                normalize_image(
                    self._next_obs[self.decoded_desired_goal_key][idxs]))
            self._next_obs[self.achieved_goal_key][idxs] = self.env._encode(
                normalize_image(
                    self._next_obs[self.decoded_achieved_goal_key][idxs]))
            normalized_imgs = normalize_image(
                self._next_obs[self.decoded_obs_key][idxs])
            if self._give_explr_reward_bonus:
                rewards = self.exploration_reward_func(
                    normalized_imgs, idxs, **self.priority_function_kwargs)
                self._exploration_rewards[idxs] = rewards.reshape(-1, 1)
            if self._prioritize_vae_samples:
                # Reuse exploration rewards as priorities when both use the
                # same scoring function; otherwise recompute.
                if (self.exploration_rewards_type == self.vae_priority_type
                        and self._give_explr_reward_bonus):
                    self._vae_sample_priorities[
                        idxs] = self._exploration_rewards[idxs]
                else:
                    self._vae_sample_priorities[
                        idxs] = self.vae_prioritization_func(
                            normalized_imgs, idxs,
                            **self.priority_function_kwargs).reshape(-1, 1)
            obs_sum += self._obs[self.observation_key][idxs].sum(axis=0)
            obs_square_sum += np.power(self._obs[self.observation_key][idxs],
                                       2).sum(axis=0)

            cur_idx = next_idx
            next_idx += batch_size
            next_idx = min(next_idx, self._size)
        # Moment estimates over all refreshed latents.
        self.vae.dist_mu = obs_sum / self._size
        self.vae.dist_std = np.sqrt(obs_square_sum / self._size -
                                    np.power(self.vae.dist_mu, 2))

        if self._prioritize_vae_samples:
            """
            priority^power is calculated in the priority function
            for image_bernoulli_prob or image_gaussian_inv_prob and
            directly here if not.
            """
            if self.vae_priority_type == "vae_prob":
                self._vae_sample_priorities[:self.
                                            _size] = relative_probs_from_log_probs(
                                                self.
                                                _vae_sample_priorities[:self.
                                                                       _size])
                self._vae_sample_probs = self._vae_sample_priorities[:self.
                                                                     _size]
            else:
                self._vae_sample_probs = (
                    self._vae_sample_priorities[:self._size]**self.power)
            p_sum = np.sum(self._vae_sample_probs)
            assert p_sum > 0, "Unnormalized p sum is {}".format(p_sum)
            self._vae_sample_probs /= np.sum(self._vae_sample_probs)
            self._vae_sample_probs = self._vae_sample_probs.flatten()
Example #25
0
    def __init__(
        self,
        train_dataset,
        test_dataset,
        model,
        positive_range=2,
        negative_range=10,
        triplet_sample_num=8,
        triplet_loss_margin=0.5,
        batch_size=128,
        log_interval=0,
        recon_loss_coef=1,
        triplet_loss_coef=None,
        triplet_loss_type=None,
        ae_loss_coef=1,
        matching_loss_coef=1,
        vae_matching_loss_coef=1,
        matching_loss_one_side=False,
        contrastive_loss_coef=0,
        lstm_kl_loss_coef=0,
        adaptive_margin=0,
        beta=0.5,
        beta_schedule=None,
        lr=None,
        do_scatterplot=False,
        normalize=False,
        mse_weight=0.1,
        is_auto_encoder=False,
        background_subtract=False,
        use_parallel_dataloading=False,
        train_data_workers=2,
        skew_dataset=False,
        skew_config=None,
        priority_function_kwargs=None,
        start_skew_epoch=0,
        weight_decay=0,
    ):
        """LSTM trainer over uint8 image trajectory datasets.

        Notable fixes vs. the original:
        - triplet_loss_coef / triplet_loss_type previously used mutable
          list defaults shared across instances; they now default to None
          and get a fresh empty list per instance (same effective value).
        - The skewed `base_sampler` was built but then ignored (a fresh
          InfiniteRandomSampler was passed to the train DataLoader), which
          silently disabled skewed sampling; the train DataLoader now uses
          base_sampler.
        """
        print("In LSTM trainer, ae_loss_coef is: ", ae_loss_coef)
        print("In LSTM trainer, matching_loss_coef is: ", matching_loss_coef)
        print("In LSTM trainer, vae_matching_loss_coef is: ",
              vae_matching_loss_coef)

        if skew_config is None:
            skew_config = {}
        # Avoid mutable default arguments (shared across instances).
        if triplet_loss_coef is None:
            triplet_loss_coef = []
        if triplet_loss_type is None:
            triplet_loss_type = []
        self.log_interval = log_interval
        self.batch_size = batch_size
        self.beta = beta
        if is_auto_encoder:
            # Plain autoencoder: no KL term.
            self.beta = 0
        if lr is None:
            lr = 1e-2 if is_auto_encoder else 1e-3
        self.beta_schedule = beta_schedule
        if self.beta_schedule is None or is_auto_encoder:
            self.beta_schedule = ConstantSchedule(self.beta)
        self.imsize = model.imsize
        self.do_scatterplot = do_scatterplot

        # Loss coefficients.
        self.recon_loss_coef = recon_loss_coef
        self.triplet_loss_coef = triplet_loss_coef
        self.ae_loss_coef = ae_loss_coef
        self.matching_loss_coef = matching_loss_coef
        self.vae_matching_loss_coef = vae_matching_loss_coef
        self.contrastive_loss_coef = contrastive_loss_coef
        self.lstm_kl_loss_coef = lstm_kl_loss_coef
        self.matching_loss_one_side = matching_loss_one_side

        # triplet loss range
        self.positive_range = positive_range
        # Keep the historical misspelled attribute so existing readers of
        # `positve_range` keep working.
        self.positve_range = positive_range
        self.negative_range = negative_range
        self.triplet_sample_num = triplet_sample_num
        self.triplet_loss_margin = triplet_loss_margin
        self.triplet_loss_type = triplet_loss_type
        self.adaptive_margin = adaptive_margin

        model.to(ptu.device)

        self.model = model
        self.representation_size = model.representation_size
        self.input_channels = model.input_channels
        self.imlength = model.imlength

        self.lr = lr
        params = list(self.model.parameters())
        self.optimizer = optim.Adam(
            params,
            lr=self.lr,
            weight_decay=weight_decay,
        )
        self.train_dataset, self.test_dataset = train_dataset, test_dataset
        assert self.train_dataset.dtype == np.uint8
        assert self.test_dataset.dtype == np.uint8

        self.batch_size = batch_size
        self.use_parallel_dataloading = use_parallel_dataloading
        self.train_data_workers = train_data_workers
        self.skew_dataset = skew_dataset
        self.skew_config = skew_config
        self.start_skew_epoch = start_skew_epoch
        if priority_function_kwargs is None:
            self.priority_function_kwargs = dict()
        else:
            self.priority_function_kwargs = priority_function_kwargs

        if self.skew_dataset:
            self._train_weights = self._compute_train_weights()
        else:
            self._train_weights = None

        if use_parallel_dataloading:
            self.train_dataset_pt = ImageDataset(train_dataset,
                                                 should_normalize=True)
            self.test_dataset_pt = ImageDataset(test_dataset,
                                                should_normalize=True)

            if self.skew_dataset:
                base_sampler = InfiniteWeightedRandomSampler(
                    self.train_dataset, self._train_weights)
            else:
                base_sampler = InfiniteRandomSampler(self.train_dataset)
            self.train_dataloader = DataLoader(
                self.train_dataset_pt,
                # BUGFIX: use base_sampler (weighted when skew_dataset=True);
                # the original passed a fresh InfiniteRandomSampler here,
                # making the weighted sampler dead code.
                sampler=base_sampler,
                batch_size=batch_size,
                drop_last=False,
                num_workers=train_data_workers,
                pin_memory=True,
            )
            self.test_dataloader = DataLoader(
                self.test_dataset_pt,
                sampler=InfiniteRandomSampler(self.test_dataset),
                batch_size=batch_size,
                drop_last=False,
                num_workers=0,
                pin_memory=True,
            )
            self.train_dataloader = iter(self.train_dataloader)
            self.test_dataloader = iter(self.test_dataloader)

        self.normalize = normalize
        self.mse_weight = mse_weight
        self.background_subtract = background_subtract

        if self.normalize or self.background_subtract:
            # Mean image in [0, 1], used by batch preprocessing.
            self.train_data_mean = np.mean(self.train_dataset, axis=0)
            self.train_data_mean = normalize_image(
                np.uint8(self.train_data_mean))
        self.eval_statistics = OrderedDict()
        self._extra_stats_to_log = None
Example #26
0
    def __init__(
        self,
        train_dataset,
        test_dataset,
        model,
        batch_size=128,
        log_interval=0,
        beta=0.5,
        beta_schedule=None,
        lr=None,
        do_scatterplot=False,
        normalize=False,
        mse_weight=0.1,
        is_auto_encoder=False,
        background_subtract=False,
        use_parallel_dataloading=True,
        train_data_workers=2,
        skew_dataset=False,
        skew_config=None,
        priority_function_kwargs=None,
        start_skew_epoch=0,
        weight_decay=0,
    ):
        """Set up the trainer: model, Adam optimizer, datasets, and
        (optionally) parallel data loaders with skewed sampling.

        :param train_dataset: uint8 image array (one flattened image per row).
        :param test_dataset: uint8 image array, same layout as train_dataset.
        :param model: VAE-like model exposing imsize, representation_size,
            input_channels and imlength.
        :param is_auto_encoder: if True, train as a plain auto-encoder
            (beta forced to 0, constant beta schedule, larger default lr).
        :param skew_dataset: if True, draw training batches according to
            self._train_weights (skew-fit style prioritized sampling).
        """
        if skew_config is None:
            skew_config = {}
        self.log_interval = log_interval
        self.batch_size = batch_size
        self.beta = beta
        # A plain auto-encoder is a VAE with the KL term switched off.
        if is_auto_encoder:
            self.beta = 0
        if lr is None:
            lr = 1e-2 if is_auto_encoder else 1e-3
        self.beta_schedule = beta_schedule
        if self.beta_schedule is None or is_auto_encoder:
            self.beta_schedule = ConstantSchedule(self.beta)
        self.imsize = model.imsize
        self.do_scatterplot = do_scatterplot

        model.to(ptu.device)

        self.model = model
        self.representation_size = model.representation_size
        self.input_channels = model.input_channels
        self.imlength = model.imlength

        self.lr = lr
        self.optimizer = optim.Adam(
            list(self.model.parameters()),
            lr=self.lr,
            weight_decay=weight_decay,
        )
        self.train_dataset, self.test_dataset = train_dataset, test_dataset
        assert self.train_dataset.dtype == np.uint8
        assert self.test_dataset.dtype == np.uint8

        self.batch_size = batch_size
        self.use_parallel_dataloading = use_parallel_dataloading
        self.train_data_workers = train_data_workers
        self.skew_dataset = skew_dataset
        self.skew_config = skew_config
        self.start_skew_epoch = start_skew_epoch
        if priority_function_kwargs is None:
            self.priority_function_kwargs = dict()
        else:
            self.priority_function_kwargs = priority_function_kwargs

        if self.skew_dataset:
            self._train_weights = self._compute_train_weights()
        else:
            self._train_weights = None

        if use_parallel_dataloading:
            self.train_dataset_pt = ImageDataset(train_dataset,
                                                 should_normalize=True)
            self.test_dataset_pt = ImageDataset(test_dataset,
                                                should_normalize=True)

            if self.skew_dataset:
                base_sampler = InfiniteWeightedRandomSampler(
                    self.train_dataset, self._train_weights)
            else:
                base_sampler = InfiniteRandomSampler(self.train_dataset)
            self.train_dataloader = DataLoader(
                self.train_dataset_pt,
                # Bug fix: use base_sampler so that skewed (weighted)
                # sampling actually takes effect.  Previously a fresh
                # InfiniteRandomSampler was always passed here, silently
                # ignoring the weighted sampler built above.
                sampler=base_sampler,
                batch_size=batch_size,
                drop_last=False,
                num_workers=train_data_workers,
                pin_memory=True,
            )
            self.test_dataloader = DataLoader(
                self.test_dataset_pt,
                sampler=InfiniteRandomSampler(self.test_dataset),
                batch_size=batch_size,
                drop_last=False,
                num_workers=0,
                pin_memory=True,
            )
            self.train_dataloader = iter(self.train_dataloader)
            self.test_dataloader = iter(self.test_dataloader)

        self.normalize = normalize
        self.mse_weight = mse_weight
        self.background_subtract = background_subtract

        if self.normalize or self.background_subtract:
            # Mean image (as normalized floats) used for data whitening /
            # background subtraction downstream.
            self.train_data_mean = np.mean(self.train_dataset, axis=0)
            self.train_data_mean = normalize_image(
                np.uint8(self.train_data_mean))
        self.eval_statistics = OrderedDict()
        self._extra_stats_to_log = None
Example #27
0
def _lstm_latent_distance_eval(m, data_file_path, pos_file_path, obj_name,
                               epoch):
    """Load a cached eval dataset (images + object positions) if it exists
    and plot latent-space distance vs. true object distance."""
    if not osp.exists(data_file_path):
        return
    all_data = np.load(data_file_path)
    obj_pos = np.load(pos_file_path)
    all_data = normalize_image(all_data.copy())
    compare_latent_distance(
        m,
        all_data,
        obj_pos,
        save_dir=logger.get_snapshot_dir(),
        obj_name=obj_name,
        save_name='online_lstm_latent_distance_{}.png'.format(epoch))


def _test_lstm(lstm_trainer,
               epoch,
               replay_buffer,
               env_id,
               lstm_save_period=1,
               uniform_dataset=None,
               save_prefix='r',
               lstm_segmentation_method='color',
               lstm_test_N=500,
               key='image_observation_segmented'):
    """Run one LSTM evaluation epoch: test loss, reconstruction/sample
    dumps, skew-fit diagnostics, and per-environment latent-distance and
    trajectory plots.

    NOTE(review): lstm_segmentation_method and lstm_test_N are currently
    unused -- the segmentation name and dataset size below are hard-coded
    per environment.  Confirm intent before wiring them in.
    """

    batch_sampler = replay_buffer.random_lstm_training_data

    # Images (and the heavier diagnostics) are only dumped every
    # lstm_save_period epochs.
    save_imgs = epoch % lstm_save_period == 0
    log_fit_skew_stats = replay_buffer._prioritize_vae_samples and uniform_dataset is not None
    if uniform_dataset is not None:
        replay_buffer.log_loss_under_uniform(
            uniform_dataset,
            lstm_trainer.batch_size,
            rl_logger=lstm_trainer.vae_logger_stats_for_rl)
    lstm_trainer.test_epoch(epoch,
                            from_rl=True,
                            key=key,
                            sample_batch=batch_sampler,
                            save_reconstruction=save_imgs,
                            save_prefix=save_prefix)
    if not save_imgs:
        return

    sample_save_prefix = save_prefix.replace('r', 's')
    lstm_trainer.dump_samples(epoch, save_prefix=sample_save_prefix)
    if log_fit_skew_stats:
        replay_buffer.dump_best_reconstruction(epoch)
        replay_buffer.dump_worst_reconstruction(epoch)
        replay_buffer.dump_sampling_histogram(
            epoch, batch_size=lstm_trainer.batch_size)
    if uniform_dataset is not None:
        replay_buffer.dump_uniform_imgs_and_reconstructions(
            dataset=uniform_dataset, epoch=epoch)

    m = lstm_trainer.model
    pjhome = os.environ['PJHOME']
    data_dir = osp.join(pjhome, 'data/local/pre-train-lstm')
    seg_name = 'seg-' + 'color'

    # Each environment has its own cached eval file naming scheme and
    # tracked object; the actual plotting is shared via the helper above.
    if env_id in [
            'SawyerPushNIPSEasy-v0', 'SawyerPushHurdle-v0',
            'SawyerPushHurdleMiddle-v0'
    ]:
        N = 500
        _lstm_latent_distance_eval(
            m,
            osp.join(data_dir,
                     '{}-{}-{}-0.3-0.5.npy'.format(env_id, seg_name, N)),
            osp.join(
                data_dir,
                '{}-{}-{}-0.3-0.5-puck-pos.npy'.format(env_id, seg_name, N)),
            'puck', epoch)
    elif env_id == 'SawyerDoorHookResetFreeEnv-v1':
        N = 1000
        seg_name = 'seg-' + 'unet'
        _lstm_latent_distance_eval(
            m,
            osp.join(data_dir,
                     'vae-only-{}-{}-{}-0-0.npy'.format(env_id, seg_name, N)),
            osp.join(
                data_dir,
                'vae-only-{}-{}-{}-0-0-door-angle.npy'.format(
                    env_id, seg_name, N)),
            'door', epoch)
    elif env_id == 'SawyerPushHurdleResetFreeEnv-v0':
        N = 2000
        _lstm_latent_distance_eval(
            m,
            osp.join(
                data_dir,
                'vae-only-{}-{}-{}-0.3-0.5.npy'.format(env_id, seg_name, N)),
            osp.join(
                data_dir,
                'vae-only-{}-{}-{}-0.3-0.5-puck-pos.npy'.format(
                    env_id, seg_name, N)),
            'puck', epoch)

    test_lstm_traj(env_id,
                   m,
                   save_path=logger.get_snapshot_dir(),
                   save_name='online_lstm_test_traj_{}.png'.format(epoch))
    test_masked_traj_lstm(
        env_id,
        m,
        save_dir=logger.get_snapshot_dir(),
        save_name='online_masked_test_{}.png'.format(epoch))
    def refresh_latents(self, epoch):
        """Re-encode every stored image with the current VAEs.

        Refreshes the latent observation/goal keys, exploration rewards,
        and the skew-fit sampling priorities, then updates
        ``self.vae.dist_mu`` / ``self.vae.dist_std`` from the refreshed
        latent observations.  The buffer is processed in chunks of 512
        indices to bound peak memory.
        """
        self.epoch = epoch
        self.skew = (self.epoch > self.start_skew_epoch)
        batch_size = 512
        next_idx = min(batch_size, self._size)

        if self.exploration_rewards_type == 'hash_count':
            # you have to count everything then compute exploration rewards
            cur_idx = 0
            next_idx = min(batch_size, self._size)
            while cur_idx < self._size:
                idxs = np.arange(cur_idx, next_idx)
                normalized_imgs = (normalize_image(
                    self._next_obs[self.decoded_obs_key][idxs]))
                # NOTE(review): normalized_imgs is unused in this loop; the
                # counting step it was presumably fed to is missing here.
                cur_idx = next_idx
                next_idx += batch_size
                next_idx = min(next_idx, self._size)

        cur_idx = 0
        # Bug fix: reset next_idx as well.  After the hash_count pre-pass
        # above it equals self._size, which made the first chunk of the
        # loop below span the entire buffer instead of batch_size entries.
        next_idx = min(batch_size, self._size)
        obs_sum = np.zeros(self.vae.representation_size)
        obs_square_sum = np.zeros(self.vae.representation_size)
        while cur_idx < self._size:
            idxs = np.arange(cur_idx, next_idx)
            # Observations are encoded with env.vae_original
            # (non-segmented images).
            self._obs[self.observation_key][idxs] = \
                self.env._encode(
                    normalize_image(self._obs[self.decoded_obs_key][idxs]),
                    self.env.vae_original
                )
            self._next_obs[self.observation_key][idxs] = \
                self.env._encode(
                    normalize_image(
                        self._next_obs[self.decoded_obs_key][idxs]),
                    self.env.vae_original
                )
            # WARNING: we only refresh the desired/achieved latents for
            # "next_obs". This means that obs[desired/achieve] will be
            # invalid, so make sure there's no code that references this.
            # TODO: enforce this with code and not a comment

            # Goal keys are encoded with env.vae_segmented.
            self._next_obs[self.desired_goal_key][idxs] = \
                self.env._encode(
                    normalize_image(
                        self._next_obs[self.decoded_desired_goal_key][idxs]),
                    self.env.vae_segmented
                )
            self._next_obs[self.achieved_goal_key][idxs] = \
                self.env._encode(
                    normalize_image(
                        self._next_obs[self.decoded_achieved_goal_key][idxs]),
                    self.env.vae_segmented
                )

            normalized_imgs = (normalize_image(
                self._next_obs[self.decoded_obs_key][idxs]))
            normalized_imgs_seg = (normalize_image(
                self._next_obs[self.decoded_obs_key_seg][idxs]))

            if self._give_explr_reward_bonus:
                rewards = self.exploration_reward_func(
                    normalized_imgs, idxs, **self.priority_function_kwargs)
                self._exploration_rewards[idxs] = rewards.reshape(-1, 1)
            if self._prioritize_vae_samples:
                if (self.exploration_rewards_type == self.vae_priority_type
                        and self._give_explr_reward_bonus):
                    self._vae_sample_priorities[idxs] = (
                        self._exploration_rewards[idxs])
                else:
                    # This is the branch skew-fit actually uses, so it is
                    # the only one that also updates the segmented-VAE
                    # priorities.
                    self._vae_sample_priorities[idxs] = (
                        self.vae_prioritization_func(
                            self.vae, normalized_imgs, idxs,
                            **self.priority_function_kwargs).reshape(-1, 1))

                    self._vae_sample_priorities_seg[idxs] = (
                        self.vae_prioritization_func(
                            self.vae_seg, normalized_imgs_seg, idxs,
                            **self.priority_function_kwargs).reshape(-1, 1))

            obs_sum += self._obs[self.observation_key][idxs].sum(axis=0)
            obs_square_sum += np.power(self._obs[self.observation_key][idxs],
                                       2).sum(axis=0)

            cur_idx = next_idx
            next_idx += batch_size
            next_idx = min(next_idx, self._size)
        # Running moments of the latent observations, used elsewhere for
        # sampling/normalization.
        self.vae.dist_mu = obs_sum / self._size
        self.vae.dist_std = np.sqrt(obs_square_sum / self._size -
                                    np.power(self.vae.dist_mu, 2))

        if self._prioritize_vae_samples:
            # priority^power is calculated inside the priority function for
            # image_bernoulli_prob / image_gaussian_inv_prob, and directly
            # here otherwise.
            if self.vae_priority_type == 'vae_prob':
                self._vae_sample_priorities[:self._size] = \
                    relative_probs_from_log_probs(
                        self._vae_sample_priorities[:self._size])
                self._vae_sample_probs = \
                    self._vae_sample_priorities[:self._size]

                self._vae_sample_priorities_seg[:self._size] = \
                    relative_probs_from_log_probs(
                        self._vae_sample_priorities_seg[:self._size])
                self._vae_sample_probs_seg = \
                    self._vae_sample_priorities_seg[:self._size]
            else:
                self._vae_sample_probs = \
                    self._vae_sample_priorities[:self._size] ** self.power
                self._vae_sample_probs_seg = \
                    self._vae_sample_priorities_seg[:self._size] ** self.power

            p_sum = np.sum(self._vae_sample_probs)
            assert p_sum > 0, "Unnormalized p sum is {}".format(p_sum)
            self._vae_sample_probs /= np.sum(self._vae_sample_probs)
            self._vae_sample_probs = self._vae_sample_probs.flatten()

            p_sum = np.sum(self._vae_sample_probs_seg)
            assert p_sum > 0, "Unnormalized p sum is {}".format(p_sum)
            self._vae_sample_probs_seg /= np.sum(self._vae_sample_probs_seg)
            self._vae_sample_probs_seg = self._vae_sample_probs_seg.flatten()
Example #29
0
    def __init__(
            self,
            train_dataset,
            test_dataset,
            model,
            batch_size=128,
            log_interval=0,
            lr=None,
            do_scatterplot=False,
            normalize=False,
            mse_weight=0.1,
            is_auto_encoder=False,
            background_subtract=False,
            use_parallel_dataloading=True,
            train_data_workers=2,
            priority_function_kwargs=None,
            start_skew_epoch=0,
            weight_decay=0,
            **kwargs):
        """Trainer for a VQ-VAE with an attached PixelCNN prior.

        Builds two Adam optimizers -- one over the VQ-VAE parameters and
        one over the PixelCNN parameters -- plus (optionally) infinite
        parallel data loaders over the uint8 image datasets.

        :param train_dataset: uint8 image array (one flattened image per row).
        :param test_dataset: uint8 image array, same layout as train_dataset.
        :param model: VQ-VAE model with a ``pixelcnn`` submodule.
        :param lr: learning rate; unlike the VAE trainer there is no
            default, so it must be supplied.
        """
        self.log_interval = log_interval
        self.batch_size = batch_size

        assert lr is not None

        self.imsize = model.imsize
        self.do_scatterplot = do_scatterplot
        self.representation_size = model.representation_size
        model.to(ptu.device)

        self.model = model
        self.pixelcnn = model.pixelcnn

        self.input_channels = model.input_channels
        self.imlength = model.imlength
        # Skewed (prioritized) sampling is not supported by this trainer.
        self.skew_dataset = False
        self._train_weights = None
        self.lr = lr
        # Split parameters so the PixelCNN prior gets its own optimizer;
        # everything else belongs to the VQ-VAE optimizer.
        params = []
        pixel_params = []
        for name, param in model.named_parameters():
            if 'pixelcnn' in name:
                pixel_params.append(param)
            else:
                params.append(param)

        self.optimizer = optim.Adam(
            params,
            lr=self.lr,
            weight_decay=weight_decay,
        )
        self.pixelcnn_opt = optim.Adam(pixel_params,
                                       lr=self.lr,
                                       weight_decay=weight_decay,
                                       amsgrad=True)
        self.pixelcnn_criterion = nn.CrossEntropyLoss().to(ptu.device)

        self.train_dataset, self.test_dataset = train_dataset, test_dataset
        assert self.train_dataset.dtype == np.uint8
        assert self.test_dataset.dtype == np.uint8

        self.use_parallel_dataloading = use_parallel_dataloading
        self.train_data_workers = train_data_workers
        if priority_function_kwargs is None:
            self.priority_function_kwargs = dict()
        else:
            self.priority_function_kwargs = priority_function_kwargs

        if use_parallel_dataloading:
            self.train_dataset_pt = ImageDataset(train_dataset,
                                                 should_normalize=True)
            self.test_dataset_pt = ImageDataset(test_dataset,
                                                should_normalize=True)

            self.train_dataloader = DataLoader(
                self.train_dataset_pt,
                sampler=InfiniteRandomSampler(self.train_dataset),
                batch_size=batch_size,
                drop_last=False,
                num_workers=train_data_workers,
                pin_memory=True,
            )
            self.test_dataloader = DataLoader(
                self.test_dataset_pt,
                sampler=InfiniteRandomSampler(self.test_dataset),
                batch_size=batch_size,
                drop_last=False,
                num_workers=0,
                pin_memory=True,
            )
            self.train_dataloader = iter(self.train_dataloader)
            self.test_dataloader = iter(self.test_dataloader)

        self.normalize = normalize
        self.mse_weight = mse_weight
        self.background_subtract = background_subtract

        if self.normalize or self.background_subtract:
            # Mean image (as normalized floats) used for data whitening /
            # background subtraction downstream.
            self.train_data_mean = np.mean(self.train_dataset, axis=0)
            self.train_data_mean = normalize_image(
                np.uint8(self.train_data_mean))
        self.eval_statistics = OrderedDict()
        self._extra_stats_to_log = None
Example #30
0
 def _kl_np_to_np(self, np_imgs):
     """Return the (unscaled) KL term of the encoder posterior for each
     image in *np_imgs* as a numpy array.

     Images are normalized, encoded to (mu, log_var), and
     -sum(1 + log_var - mu^2 - exp(log_var)) is returned per sample.
     NOTE(review): this omits the conventional 0.5 factor of the Gaussian
     KL -- presumably intentional; confirm against how callers use it.
     """
     encoder_input = ptu.from_numpy(normalize_image(np_imgs))
     mu, log_var = self.model.encode(encoder_input)
     kl_per_dim = 1 + log_var - mu.pow(2) - log_var.exp()
     return ptu.get_numpy(-kl_per_dim.sum(dim=1))