Ejemplo n.º 1
0
def vis(args):
    """Rank a dataset by VAE reconstruction error and display extremes.

    Loads images from ``args.ds`` and a pickled VAE from ``args.file``,
    scores every image by summed squared reconstruction error, then shows
    the NUM_SHOWN worst and NUM_SHOWN best original/reconstruction pairs
    in OpenCV windows (blocking until a key is pressed).
    """
    dataset = np.load(args.ds)
    vae = joblib.load(args.file)

    # (index, squared reconstruction error) for every image.
    scored = []
    for idx, image_obs in enumerate(dataset):
        flat = normalize_image(image_obs)
        recon, *_ = eval_np(vae, flat)
        scored.append((idx, ((recon - flat) ** 2).sum()))

    # Worst-first: largest error at the front.
    scored.sort(key=lambda pair: -pair[1])

    def show(rank, idx, error):
        # Re-run the VAE for this image and render both sides.
        obs = dataset[idx]
        recon, *_ = eval_np(vae, normalize_image(obs))
        img = obs.reshape(3, 48, 48).transpose()
        rimg = recon.reshape(3, 48, 48).transpose()
        cv2.imshow("image, rank {}, loss {}".format(rank, error), img)
        cv2.imshow("recon, rank {}, loss {}".format(rank, error), rimg)
        print("rank {}\terror {}".format(rank, error))

    for rank, (idx, error) in enumerate(scored[:NUM_SHOWN]):
        show(rank, idx, error)
    for j, (idx, error) in enumerate(scored[-NUM_SHOWN:]):
        show(len(scored) - j - 1, idx, error)
    cv2.waitKey(0)
    cv2.destroyAllWindows()
Ejemplo n.º 2
0
    def __getitem__(self, idx):
        """Map a flat index to one (observation, environment) image pair.

        The flat index enumerates trajectory x timestep, so divmod by the
        trajectory length recovers which trajectory and which transition.
        """
        traj_i, trans_i = divmod(idx, self.traj_length)

        # Both images are stored flattened; normalize before returning.
        env = normalize_image(self.data['env'][traj_i, :])
        x_t = normalize_image(self.data['observations'][traj_i, trans_i, :])

        return {
            'x_t': x_t,
            'env': env,
        }
Ejemplo n.º 3
0
def add_heatmap_imgs_to_o_dict(env,
                               agent,
                               observation_key,
                               full_o,
                               v_function,
                               vectorized=False):
    """Render V-function heatmaps over the goal grid into ``full_o``.

    Evaluates ``v_function`` on every goal of the environment's mesh grid
    paired with the current observation, then stores one normalized plot
    per goal dimension (keys 'v_vals_dim_{i}'), plus an aggregate 'v_vals'
    plot when ``vectorized`` is False. All heatmaps share a color scale.
    """
    o = full_o[observation_key]
    goal_grid = env.get_mesh_grid(observation_key)
    # One row per goal: [current observation | candidate goal].
    o_grid = np.c_[np.tile(o, (len(goal_grid), 1)), goal_grid]

    v_vals, indiv_v_vals = v_function(o_grid)
    v_vals = ptu.get_numpy(v_vals)
    indiv_v_vals = [ptu.get_numpy(indiv_v_val) for indiv_v_val in indiv_v_vals]

    vmin = np.array(indiv_v_vals).min()
    vmax = np.array(indiv_v_vals).max()
    # Assuming square observation space, how many points on x axis
    vary_len = int(len(goal_grid)**(1 / 2))
    if not vectorized:
        vmin = min(np.array(indiv_v_vals).min(), v_vals.min())
        # BUG FIX: the upper bound previously took .min() of indiv_v_vals,
        # which could clip the shared color scale; use .max() to mirror
        # the vmin computation above.
        vmax = max(np.array(indiv_v_vals).max(), v_vals.max())
        full_o['v_vals'] = normalize_image(
            env.get_image_plt(v_vals.reshape((vary_len, vary_len)),
                              imsize=env.imsize,
                              vmin=vmin,
                              vmax=vmax))

    # One heatmap per goal dimension, all drawn with the shared scale.
    for goal_dim in range(len(indiv_v_vals)):
        full_o['v_vals_dim_{}'.format(goal_dim)] = normalize_image(
            env.get_image_plt(indiv_v_vals[goal_dim].reshape(
                (vary_len, vary_len)),
                              imsize=env.imsize,
                              vmin=vmin,
                              vmax=vmax))
Ejemplo n.º 4
0
 def get_dataset_stats(self, data):
     """Encode ``data`` and return latent means plus their mean/std.

     Returns a tuple ``(mus, mean, std)`` where ``mus`` holds per-sample
     encoder means and ``mean``/``std`` aggregate them over the dataset.
     """
     encoded = self.model.encode(ptu.from_numpy(normalize_image(data)))
     mus, log_vars = encoded
     mus = ptu.get_numpy(mus)
     return mus, np.mean(mus, axis=0), np.std(mus, axis=0)
Ejemplo n.º 5
0
    def get_batch(self, test_data=False, epoch=None):
        """Sample one batch of normalized images as a torch tensor.

        With parallel dataloading enabled the batch comes straight from
        the prefetching dataloader. Otherwise indices are drawn from the
        in-memory dataset -- skew-weighted (by ``_train_weights``) only
        for training data once ``epoch`` passes ``start_skew_epoch``.
        """
        if self.use_parallel_dataloading:
            loader = self.test_dataloader if test_data else self.train_dataloader
            return next(loader).to(ptu.device)

        dataset = self.test_dataset if test_data else self.train_dataset
        # Skewed sampling only applies after the warm-up epoch.
        skew = epoch is not None and self.start_skew_epoch < epoch
        if not test_data and self.skew_dataset and skew:
            probs = self._train_weights / np.sum(self._train_weights)
            ind = np.random.choice(len(probs), self.batch_size, p=probs)
        else:
            ind = np.random.randint(0, len(dataset), self.batch_size)

        samples = normalize_image(dataset[ind, :])
        if self.normalize:
            samples = ((samples - self.train_data_mean) + 1) / 2
        if self.background_subtract:
            samples = samples - self.train_data_mean
        return ptu.from_numpy(samples)
Ejemplo n.º 6
0
 def get_batch(self, train=True):
     """Draw a random normalized batch from the train or test dataset."""
     source = self.train_dataset if train else self.test_dataset
     idx = np.random.randint(0, len(source), self.batch_size)
     batch = normalize_image(source[idx, :])
     if self.normalize:
         # Shift by the dataset mean, then rescale back into [0, 1].
         batch = ((batch - self.train_data_mean) + 1) / 2
     return ptu.np_to_var(batch)
Ejemplo n.º 7
0
 def random_batch(self, batch_size):
     """Sample ``batch_size`` observations as a torch batch dict.

     Samples with replacement only when the buffer holds fewer than
     ``batch_size`` entries.
     """
     idx = np.random.choice(self.size, batch_size, replace=(self.size < batch_size))
     obs = self.data[idx, :]
     if self.normalize:
         obs = normalize_image(obs)
     return np_to_pytorch_batch({'observations': obs})
Ejemplo n.º 8
0
 def get_batch_smooth(self, train=True):
     """Sample a batch and split each row into (x_next, x) image pairs.

     Each dataset row stores [x_next | x]; the split point is
     ``self.x_next_index``.
     """
     source = self.train_dataset if train else self.test_dataset
     idx = np.random.randint(0, len(source), self.batch_size)
     batch = normalize_image(source[idx, :])
     if self.normalize:
         batch = ((batch - self.train_data_mean) + 1) / 2
     x_next = batch[:, :self.x_next_index]
     x = batch[:, self.x_next_index:]
     return ptu.np_to_var(x_next), ptu.np_to_var(x)
Ejemplo n.º 9
0
    def random_batch(self, batch_size):
        """Sample observations with environment frames and episode context.

        Returns a torch batch dict with per-sample images ``x_t``,
        environment frames ``env``, and the first 8 observations of one
        randomly chosen episode (``episode_obs``).
        """
        traj_i = np.random.choice(self.size, batch_size)
        trans_i = np.random.choice(self.traj_length, batch_size)

        # Some datasets lack a dedicated 'env' array; fall back to each
        # trajectory's first observation. FIX: this was a bare ``except:``,
        # which also swallowed KeyboardInterrupt and real bugs -- narrowed
        # to the missing-key case.
        try:
            env = normalize_image(self.data['env'][traj_i, :])
        except KeyError:
            env = normalize_image(self.data['observations'][traj_i, 0, :])
        x_t = normalize_image(self.data['observations'][traj_i, trans_i, :])

        episode_num = np.random.randint(0, self.size)
        episode_obs = normalize_image(self.data['observations'][episode_num, :8, :])

        data_dict = {
            'x_t': x_t,
            'env': env,
            'episode_obs': episode_obs,
        }
        return np_to_pytorch_batch(data_dict)
Ejemplo n.º 10
0
 def sample_goals(self, batch_size):
     """Sample latent/image goal pairs from the goal buffer.

     Returns ``None`` while the goal buffer is still empty.
     """
     if self.goal_buffer._size == 0:
         return None
     idxs = self.goal_buffer._sample_indices(batch_size)
     next_obs = self.goal_buffer._next_obs
     return {
         'latent_desired_goal': next_obs['latent_observation'][idxs],
         'image_desired_goal': normalize_image(
             next_obs['image_observation'][idxs]),
     }
Ejemplo n.º 11
0
    def random_batch(self, batch_size):
        """Sample transitions (x_t, x_next, action) plus episode context.

        ``trans_i`` is drawn from [0, traj_length - 1) so ``trans_i + 1``
        is always a valid next step. Also includes the first 8
        observations and first 7 actions of one random episode.
        """
        traj_i = np.random.choice(self.size, batch_size)
        trans_i = np.random.choice(self.traj_length - 1, batch_size)

        # Some datasets lack a dedicated 'env' array; fall back to each
        # trajectory's first observation. FIX: this was a bare ``except:``,
        # which also swallowed KeyboardInterrupt and real bugs -- narrowed
        # to the missing-key case.
        try:
            env = normalize_image(self.data['env'][traj_i, :])
        except KeyError:
            env = normalize_image(self.data['observations'][traj_i, 0, :])

        x_t = normalize_image(self.data['observations'][traj_i, trans_i, :])
        x_next = normalize_image(self.data['observations'][traj_i, trans_i + 1, :])

        episode_num = np.random.randint(0, self.size)
        episode_obs = normalize_image(self.data['observations'][episode_num, :8, :])

        data_dict = {
            'x_t': x_t,
            'x_next': x_next,
            'env': env,
            'actions': self.data['actions'][traj_i, trans_i, :],
            'episode_obs': episode_obs,
            'episode_acts': self.data['actions'][episode_num, :7, :],
        }
        return np_to_pytorch_batch(data_dict)
Ejemplo n.º 12
0
    def __getitem__(self, idx):
        """Return one color-jittered (observation, environment) pair.

        A single jitter transform is sampled per item and applied to both
        images so the pair stays photometrically consistent. Images are
        assumed to be 48x48 RGB stored flattened -- TODO confirm.
        """
        traj_i, trans_i = divmod(idx, self.traj_length)

        obs_img = Image.fromarray(
            self.data['observations'][traj_i, trans_i].reshape(48, 48, 3),
            mode='RGB')
        env_img = Image.fromarray(
            self.data['env'][traj_i].reshape(48, 48, 3), mode='RGB')

        # One shared brightness/contrast/saturation/hue jitter per item.
        jitter = self.jitter.get_params((0.75, 1.25), (0.9, 1.1), (0.9, 1.1),
                                        (-0.1, 0.1))
        obs_img, env_img = jitter(obs_img), jitter(env_img)

        return {
            'x_t': normalize_image(np.array(obs_img).flatten()).squeeze(),
            'env': normalize_image(np.array(env_img).flatten()).squeeze(),
        }
Ejemplo n.º 13
0
    def log_loss_under_uniform(self, model, data, priority_function_kwargs):
        """Evaluate and log VAE losses over ``data`` in minibatches.

        Records via the logger: three data log-probability estimates
        (true-prior, biased, and importance sampling), the KL divergence,
        and the reconstruction MSE, each averaged over minibatches.

        Note: ``priority_function_kwargs`` is mutated in place -- its
        'sampling_method' entry is overwritten three times per batch and
        is left as 'importance_sampling' after this call returns.
        """
        import torch.nn.functional as F
        log_probs_prior = []
        log_probs_biased = []
        log_probs_importance = []
        kles = []
        mses = []
        # Walk the dataset in chunks of self.batch_size rows.
        for i in range(0, data.shape[0], self.batch_size):
            img = normalize_image(data[i:min(data.shape[0], i +
                                             self.batch_size), :])
            torch_img = ptu.from_numpy(img)
            reconstructions, obs_distribution_params, latent_distribution_params = self.model(
                torch_img)

            # Log-prob estimate using samples from the true prior.
            priority_function_kwargs['sampling_method'] = 'true_prior_sampling'
            log_p, log_q, log_d = compute_log_p_log_q_log_d(
                model, img, **priority_function_kwargs)
            log_prob_prior = log_d.mean()

            # Log-prob estimate using biased sampling.
            priority_function_kwargs['sampling_method'] = 'biased_sampling'
            log_p, log_q, log_d = compute_log_p_log_q_log_d(
                model, img, **priority_function_kwargs)
            log_prob_biased = log_d.mean()

            # Importance-weighted estimate: log p - log q + log d.
            priority_function_kwargs['sampling_method'] = 'importance_sampling'
            log_p, log_q, log_d = compute_log_p_log_q_log_d(
                model, img, **priority_function_kwargs)
            log_prob_importance = (log_p - log_q + log_d).mean()

            kle = model.kl_divergence(latent_distribution_params)
            # NOTE(review): reduction='elementwise_mean' is the legacy
            # spelling; recent torch releases only accept 'mean' -- confirm
            # the pinned torch version before upgrading.
            mse = F.mse_loss(torch_img,
                             reconstructions,
                             reduction='elementwise_mean')
            mses.append(mse.item())
            kles.append(kle.item())
            log_probs_prior.append(log_prob_prior.item())
            log_probs_biased.append(log_prob_biased.item())
            log_probs_importance.append(log_prob_importance.item())

        logger.record_tabular("Uniform Data Log Prob (True Prior)",
                              np.mean(log_probs_prior))
        logger.record_tabular("Uniform Data Log Prob (Biased)",
                              np.mean(log_probs_biased))
        logger.record_tabular("Uniform Data Log Prob (Importance)",
                              np.mean(log_probs_importance))
        logger.record_tabular("Uniform Data KL", np.mean(kles))
        logger.record_tabular("Uniform Data MSE", np.mean(mses))
Ejemplo n.º 14
0
def add_heatmap_img_to_o_dict(env, agent, observation_key, full_o, v_function):
    """Render a V-function heatmap over the goal grid into ``full_o``.

    Evaluates ``v_function`` at every goal on the environment's mesh grid
    paired with the current observation, and stores the resulting plot,
    normalized to an image, under the 'v_vals' key.
    """
    obs = full_o[observation_key]
    goal_grid = env.get_mesh_grid(observation_key)
    # One row per goal: [current observation | candidate goal].
    stacked = np.c_[np.tile(obs, (len(goal_grid), 1)), goal_grid]

    v_vals = ptu.get_numpy(v_function(stacked))

    # Assuming square observation space, how many points on x axis.
    vary_len = int(len(goal_grid) ** 0.5)

    full_o['v_vals'] = normalize_image(
        env.get_image_plt(v_vals.reshape((vary_len, vary_len)),
                          imsize=env.imsize,
                          vmin=v_vals.min(),
                          vmax=v_vals.max()))
Ejemplo n.º 15
0
def viz_rewards(model, id, savefile=None):
    """Plot latent-space distance to the final frame for one clip.

    Encodes every frame of clip ``id`` with the model's encoder, measures
    each latent's L2 distance to the last frame's latent, plots the curve
    (saving to ``savefile`` when given), and returns the distances.
    """
    frames = transform_batch(load_clip(id))
    batch = ptu.from_numpy(normalize_image(frames)).to("cuda")

    z = ptu.get_numpy(model.encoder(batch).cpu())

    # Distance of each frame's latent to the goal (final-frame) latent.
    z_goal = z[-1, :]
    distances = [np.linalg.norm(z[t, :] - z_goal) for t in range(len(z))]

    plt.figure()
    plt.plot(distances)
    if savefile:
        plt.savefig(savefile)

    return np.array(distances)
Ejemplo n.º 16
0
    def _dump_imgs_and_reconstructions(self, idxs, filename):
        """Save a grid image: originals on top, reconstructions below."""
        originals = []
        reconstructions = []
        for idx in idxs:
            torch_img = ptu.from_numpy(normalize_image(self.train_dataset[idx]))
            recon, *_ = self.model(torch_img.view(1, -1))

            # Reshape flat vectors to CxHxW, swapping H/W for display.
            shape = (self.input_channels, self.imsize, self.imsize)
            originals.append(torch_img.view(*shape).transpose(1, 2))
            reconstructions.append(recon.view(*shape).transpose(1, 2))

        grid = torch.stack(originals + reconstructions)
        save_image(
            grid.data,
            osp.join(self.log_dir, filename),
            nrow=len(idxs),
        )
Ejemplo n.º 17
0
    def dump_uniform_imgs_and_reconstructions(self, dataset, epoch):
        """Save reconstructions of 4 uniformly sampled dataset images."""
        idxs = np.random.choice(range(dataset.shape[0]), 4)
        filename = 'uniform{}.png'.format(epoch)
        originals = []
        reconstructions = []
        for idx in idxs:
            torch_img = ptu.from_numpy(normalize_image(dataset[idx]))
            recon, *_ = self.model(torch_img.view(1, -1))

            # Reshape flat vectors to CxHxW, swapping H/W for display.
            shape = (self.input_channels, self.imsize, self.imsize)
            originals.append(torch_img.view(*shape).transpose(1, 2))
            reconstructions.append(recon.view(*shape).transpose(1, 2))

        grid = torch.stack(originals + reconstructions)
        save_image(
            grid.data,
            osp.join(self.log_dir, filename),
            nrow=4,
        )
Ejemplo n.º 18
0
 def __getitem__(self, idxs):
     """Fetch rows ``idxs`` as float32, normalizing when configured."""
     rows = self.dataset[idxs, :]
     if self.should_normalize:
         rows = normalize_image(rows)
     return np.float32(rows)
Ejemplo n.º 19
0
    def __init__(
        self,
        model,
        batch_size=128,
        log_interval=0,
        beta=0.5,
        beta_schedule=None,
        lr=None,
        do_scatterplot=False,
        normalize=False,
        mse_weight=0.1,
        is_auto_encoder=False,
        background_subtract=False,
        linearity_weight=0.0,
        distance_weight=0.0,
        loss_weights=None,
        use_linear_dynamics=False,
        use_parallel_dataloading=False,
        train_data_workers=2,
        skew_dataset=False,
        skew_config=None,
        priority_function_kwargs=None,
        start_skew_epoch=0,
        weight_decay=0,
        key_to_reconstruct='observations',
        num_epochs=None,
    ):
        """Configure a trainer around ``model``.

        Key behaviors visible below: ``is_auto_encoder`` forces beta to 0
        (and a constant beta schedule, with a higher default lr of 1e-2 vs
        1e-3); an Adam optimizer is built over all model parameters with
        the given ``weight_decay``; ``normalize``/``background_subtract``
        trigger computing a normalized train-data mean image.
        """
        #TODO:steven fix pickling
        assert not use_parallel_dataloading, "Have to fix pickling the dataloaders first"

        if skew_config is None:
            skew_config = {}
        self.log_interval = log_interval
        self.batch_size = batch_size
        self.beta = beta
        # Plain autoencoder training disables the KL term entirely.
        if is_auto_encoder:
            self.beta = 0
        # Default learning rate depends on the training mode.
        if lr is None:
            if is_auto_encoder:
                lr = 1e-2
            else:
                lr = 1e-3
        self.beta_schedule = beta_schedule
        self.num_epochs = num_epochs
        if self.beta_schedule is None or is_auto_encoder:
            self.beta_schedule = ConstantSchedule(self.beta)
        self.imsize = model.imsize
        self.do_scatterplot = do_scatterplot
        model.to(ptu.device)

        self.model = model
        self.representation_size = model.representation_size
        self.input_channels = model.input_channels
        self.imlength = model.imlength

        self.lr = lr
        params = list(self.model.parameters())
        self.optimizer = optim.Adam(
            params,
            lr=self.lr,
            weight_decay=weight_decay,
        )

        self.key_to_reconstruct = key_to_reconstruct
        self.use_parallel_dataloading = use_parallel_dataloading
        self.train_data_workers = train_data_workers
        self.skew_dataset = skew_dataset
        self.skew_config = skew_config
        self.start_skew_epoch = start_skew_epoch
        if priority_function_kwargs is None:
            self.priority_function_kwargs = dict()
        else:
            self.priority_function_kwargs = priority_function_kwargs

        # NOTE(review): this branch is currently unreachable (asserted off
        # above) and references ``train_dataset``/``test_dataset``, which
        # are not parameters of this constructor -- it would raise
        # NameError if re-enabled as-is. Also ``base_sampler`` is built but
        # never used.
        if use_parallel_dataloading:
            self.train_dataset_pt = ImageDataset(train_dataset,
                                                 should_normalize=True)
            self.test_dataset_pt = ImageDataset(test_dataset,
                                                should_normalize=True)

            if self.skew_dataset:
                base_sampler = InfiniteWeightedRandomSampler(
                    self.train_dataset, self._train_weights)
            else:
                base_sampler = InfiniteRandomSampler(self.train_dataset)
            self.train_dataloader = DataLoader(
                self.train_dataset_pt,
                sampler=InfiniteRandomSampler(self.train_dataset),
                batch_size=batch_size,
                drop_last=False,
                num_workers=train_data_workers,
                pin_memory=True,
            )
            self.test_dataloader = DataLoader(
                self.test_dataset_pt,
                sampler=InfiniteRandomSampler(self.test_dataset),
                batch_size=batch_size,
                drop_last=False,
                num_workers=0,
                pin_memory=True,
            )
            self.train_dataloader = iter(self.train_dataloader)
            self.test_dataloader = iter(self.test_dataloader)

        self.normalize = normalize
        self.mse_weight = mse_weight
        self.background_subtract = background_subtract

        # NOTE(review): assumes ``self.train_dataset`` is already set by a
        # subclass or earlier code outside this view -- confirm.
        if self.normalize or self.background_subtract:
            self.train_data_mean = np.mean(self.train_dataset, axis=0)
            self.train_data_mean = normalize_image(
                np.uint8(self.train_data_mean))
        self.linearity_weight = linearity_weight
        self.distance_weight = distance_weight
        self.loss_weights = loss_weights

        self.use_linear_dynamics = use_linear_dynamics
        self._extra_stats_to_log = None

        # stateful tracking variables, reset every epoch
        self.eval_statistics = collections.defaultdict(list)
        self.eval_data = collections.defaultdict(list)
        self.num_batches = 0
Ejemplo n.º 20
0
def get_clip_as_batch(id, max_frames=-1):
    """Load clip ``id``, transform it, and return a normalized torch batch.

    ``max_frames`` caps the number of frames loaded (-1 means all frames
    -- presumably; verify against ``load_clip``).
    """
    frames = load_clip(id, max_frames)
    transformed = transform_batch(frames)
    return ptu.from_numpy(normalize_image(transformed))
Ejemplo n.º 21
0
 def _reconstruction_squared_error_np_to_np(self, np_imgs):
     """Per-image summed squared reconstruction error (numpy in/out)."""
     torch_imgs = ptu.from_numpy(normalize_image(np_imgs))
     recons, *_ = self.model(torch_imgs)
     residual = torch_imgs - recons
     return ptu.get_numpy((residual ** 2).sum(dim=1))
Ejemplo n.º 22
0
 def _kl_np_to_np(self, np_imgs):
     """Per-image KL-style penalty from the encoder's mu/log_var.

     Computes -sum(1 + log_var - mu^2 - exp(log_var)) per row.
     NOTE(review): this is twice the conventional Gaussian KL (the usual
     0.5 factor is absent) -- confirm whether callers expect the scaled
     value.
     """
     torch_imgs = ptu.from_numpy(normalize_image(np_imgs))
     mu, log_var = self.model.encode(torch_imgs)
     kl = -torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=1)
     return ptu.get_numpy(kl)