Example #1
    def train_epoch(self, epoch):
        self.model.train()
        losses = []
        per_dim_losses = np.zeros((self.num_batches, self.y_train.shape[1]))
        for batch in range(self.num_batches):
            inputs_np, labels_np = self.random_batch(
                self.X_train, self.y_train, batch_size=self.batch_size)
            inputs = ptu.Variable(ptu.from_numpy(inputs_np))
            labels = ptu.Variable(ptu.from_numpy(labels_np))
            self.optimizer.zero_grad()
            outputs = self.model(inputs)
            loss = self.criterion(outputs, labels)
            loss.backward()
            self.optimizer.step()
            losses.append(loss.data[0])
            per_dim_loss = np.mean(np.power(ptu.get_numpy(outputs - labels),
                                            2),
                                   axis=0)
            per_dim_losses[batch] = per_dim_loss

        logger.record_tabular("train/epoch", epoch)
        logger.record_tabular("train/loss", np.mean(np.array(losses)))
        for i in range(self.y_train.shape[1]):
            logger.record_tabular("train/dim " + str(i) + " loss",
                                  np.mean(per_dim_losses[:, i]))
Example #2
    def test_values_correct_batch_full_matrix(self):
        """
        Check y = x^T M x for a batch of full (non-diagonal) matrices M
        """
        x = np.array([
            [2, 7],
            [2, 7],
        ])
        M = np.array([
            [
                [2, -1],
                [2, 1],
            ],
            [
                [0, -.1],
                [.2, .1],
            ]
        ])
        expected = np.array([
            [71],  # 2^2 * 2 + 7^2 * 1 + 2*7*(2-1) = 8 + 49 + 14 = 71
            [6.3],  # 2^2 * 0 + 7^2 * .1 + 2*7*(.2-.1) = 0 + 4.9 + 1.4 = 6.3
        ])

        x_var = ptu.from_numpy(x).float()
        M_var = ptu.from_numpy(M).float()
        result_var = ptu.batch_square_vector(vector=x_var, M=M_var)
        result = ptu.get_numpy(result_var)

        self.assertNpAlmostEqual(expected, result)
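
Both of these tests exercise a batched quadratic form y_b = x_b^T M_b x_b. The sketch below is a minimal illustration of such a helper using plain PyTorch; the actual ptu.batch_square_vector may be implemented differently.

import torch

def batch_quadratic_form(vector, M):
    """Compute y_b = x_b^T M_b x_b for a batch of vectors and matrices.

    vector: (batch, dim), M: (batch, dim, dim) -> returns (batch, 1).
    """
    x = vector.unsqueeze(1)                             # (batch, 1, dim)
    y = torch.bmm(torch.bmm(x, M), x.transpose(1, 2))   # (batch, 1, 1)
    return y.squeeze(1)                                 # (batch, 1)
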
Example #3
    def test_values_correct_batches_diag(self):
        """
        Check y = x^T diag(d) x
        batch-wise
        """
        x = np.array([
            [1, 1],
            [2, 1],
        ])
        M = np.array([
            [
                [3, 0],
                [0, -1],
            ],
            [
                [1, 0],
                [0, 1],
            ]
        ])

        expected = np.array([
            [2],  # 1^2 * 3 + 1^2 * (-1) = 2
            [5],  # 2^2 * 1 + 1^2 * 1 = 5
        ])
        x_var = ptu.from_numpy(x).float()
        M_var = ptu.from_numpy(M).float()
        result_var = ptu.batch_square_vector(vector=x_var, M=M_var)
        result = ptu.get_numpy(result_var)

        self.assertNpAlmostEqual(expected, result)
Example #4
    def get_debug_batch(self, train=True):
        dataset = self.train_dataset if train else self.test_dataset
        X, Y = dataset
        ind = np.random.randint(0, Y.shape[0], self.batch_size)
        X = X[ind, :]
        Y = Y[ind, :]
        return ptu.from_numpy(X), ptu.from_numpy(Y)
    def test_huber_loss_delta_3(self):
        criterion = modules.HuberLoss(3)

        x = np.array([
            [0],
        ])
        x_hat = np.array([
            [5],
        ])
        expected_loss = np.array([
            3 * (5 - 3 / 2),
        ])

        x_var = ptu.Variable(ptu.from_numpy(x).float())
        x_hat_var = ptu.Variable(ptu.from_numpy(x_hat).float())
        result_var = criterion(x_var, x_hat_var)
        result = ptu.get_numpy(result_var)
        self.assertNpAlmostEqual(expected_loss, result)

        x = np.array([
            [4],
        ])
        x_hat = np.array([
            [6],
        ])
        expected_loss = np.array([
            0.5 * 2 * 2,
        ])

        x_var = ptu.Variable(ptu.from_numpy(x).float())
        x_hat_var = ptu.Variable(ptu.from_numpy(x_hat).float())
        result_var = criterion(x_var, x_hat_var)
        result = ptu.get_numpy(result_var)
        self.assertNpAlmostEqual(expected_loss, result)
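
The expected values above follow the standard Huber loss with threshold delta: 0.5 * e^2 when |e| <= delta and delta * (|e| - delta / 2) otherwise. Below is a minimal elementwise sketch consistent with those numbers; modules.HuberLoss itself may reduce or reshape differently.

import torch
import torch.nn as nn

class HuberLossSketch(nn.Module):
    """Elementwise Huber loss with threshold delta (illustrative only)."""

    def __init__(self, delta=1.0):
        super().__init__()
        self.delta = delta

    def forward(self, prediction, target):
        error = prediction - target
        abs_error = error.abs()
        quadratic = 0.5 * error ** 2                           # |e| <= delta
        linear = self.delta * (abs_error - 0.5 * self.delta)   # |e| > delta
        return torch.where(abs_error <= self.delta, quadratic, linear)

With delta=3, an error of 5 gives 3 * (5 - 1.5) = 10.5 and an error of 2 gives 0.5 * 2 * 2 = 2, matching the two assertions above.
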
    def test_log_prob_gradient(self):
        """
        Same thing. The tanh term drops out since tanh has no params.
        With z = arctanh(x):
        d/d mu log f_X(x) = (z - mu) / sigma^2
        d/d sigma log f_X(x) = (z - mu)^2 / sigma^3 - 1 / sigma
        :return:
        """
        mean_var = ptu.from_numpy(np.array([0]), requires_grad=True)
        std_var = ptu.from_numpy(np.array([0.25]), requires_grad=True)
        tanh_normal = TanhNormal(mean_var, std_var)
        z = ptu.from_numpy(np.array([1]))
        x = torch.tanh(z)
        log_prob = tanh_normal.log_prob(x, pre_tanh_value=z)

        gradient = ptu.from_numpy(np.array([1]))

        log_prob.backward(gradient)

        self.assertNpArraysEqual(
            ptu.get_numpy(mean_var.grad),
            np.array([16]),
        )
        self.assertNpArraysEqual(
            ptu.get_numpy(std_var.grad),
            np.array([4**3 - 4]),
        )
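
A quick sanity check of those gradients using torch.distributions directly (not part of the code above): since the tanh correction term has no parameters, only the underlying Normal log-density contributes to the gradients.

import torch
from torch.distributions import Normal

mean = torch.tensor([0.0], requires_grad=True)
std = torch.tensor([0.25], requires_grad=True)
z = torch.tensor([1.0])  # pre-tanh value used in the test

Normal(mean, std).log_prob(z).backward()
print(mean.grad)  # (z - mean) / std^2 = 1 / 0.0625 = 16
print(std.grad)   # (z - mean)^2 / std^3 - 1 / std = 64 - 4 = 60
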
Example #7
	def run_experiment(self):
		all_imgs = []
		policy = OUStrategy(self.env.action_space)
		for i in range(self.num_episodes):
			state = self.env.reset()
			img = ptu.from_numpy(state['image_observation']).view(1, 6912)
			latent_state = self.vae.encode(img)[0]

			true_curr = state['image_observation'] * 255.0
			all_imgs.append(ptu.from_numpy(true_curr).view(3, 48, 48))

			actions = []
			for j in range(self.episode_len):
				u = policy.get_action_from_raw_action(self.env.action_space.sample())
				actions.append(u)
				state = self.env.step(u)[0]
				true_curr = state['image_observation'] * 255.0
				all_imgs.append(ptu.from_numpy(true_curr).view(3, 48, 48))

			pred_curr = self.vae.decode(latent_state)[0] * 255.0
			all_imgs.append(pred_curr.view(3, 48, 48))

			for j in range(self.episode_len):
				u = ptu.from_numpy(actions[j]).view(1, 2)
				latent_state = self.vae.process_dynamics(latent_state, u)
				pred_curr = self.vae.decode(latent_state)[0] * 255.0
				all_imgs.append(pred_curr.view(3, 48, 48))

		all_imgs = torch.stack(all_imgs)
		save_image(
			all_imgs.data,
			"/home/khazatsky/rail/data/rail-khazatsky/sasha/dynamics_visualizer/dynamics.png",
			nrow=self.episode_len + 1,
		)
Example #8
    def test_epoch(
        self,
        epoch,
    ):
        self.model.eval()
        val_losses = []
        per_dim_losses = np.zeros((self.num_batches, self.y_train.shape[1]))
        for batch in range(self.num_batches):
            inputs_np, labels_np = self.random_batch(
                self.X_test, self.y_test, batch_size=self.batch_size)
            inputs = ptu.Variable(ptu.from_numpy(inputs_np))
            labels = ptu.Variable(ptu.from_numpy(labels_np))
            outputs = self.model(inputs)
            loss = self.criterion(outputs, labels)
            val_losses.append(loss.data[0])
            per_dim_loss = np.mean(np.power(ptu.get_numpy(outputs - labels),
                                            2),
                                   axis=0)
            per_dim_losses[batch] = per_dim_loss

        logger.record_tabular("test/epoch", epoch)
        logger.record_tabular("test/loss", np.mean(np.array(val_losses)))
        for i in range(self.y_train.shape[1]):
            logger.record_tabular("test/dim " + str(i) + " loss",
                                  np.mean(per_dim_losses[:, i]))
        logger.dump_tabular()
Example #9
    def load_dataset(self, dataset_path):
        dataset = load_local_or_remote_file(dataset_path)
        dataset = dataset.item()

        observations = dataset['observations']
        actions = dataset['actions']

        # dataset['observations'].shape # (2000, 50, 6912)
        # dataset['actions'].shape # (2000, 50, 2)
        # dataset['env'].shape # (2000, 6912)
        N, H, imlength = observations.shape

        self.vae.eval()
        for n in range(N):
            x0 = ptu.from_numpy(dataset['env'][n:n + 1, :] / 255.0)
            x = ptu.from_numpy(observations[n, :, :] / 255.0)
            latents = self.vae.encode(x, x0, distrib=False)

            r1, r2 = self.vae.latent_sizes
            conditioning = latents[0, r1:]
            goal = torch.cat(
                [ptu.randn(self.vae.latent_sizes[0]), conditioning])
            goal = ptu.get_numpy(goal)  # latents[-1, :]

            latents = ptu.get_numpy(latents)
            latent_delta = latents - goal
            distances = np.zeros((H - 1, 1))
            for i in range(H - 1):
                distances[i, 0] = np.linalg.norm(latent_delta[i + 1, :])

            terminals = np.zeros((H - 1, 1))
            # terminals[-1, 0] = 1
            path = dict(
                observations=[],
                actions=actions[n, :H - 1, :],
                next_observations=[],
                rewards=-distances,
                terminals=terminals,
            )

            for t in range(H - 1):
                # reward = -np.linalg.norm(latent_delta[i, :])

                obs = dict(
                    latent_observation=latents[t, :],
                    latent_achieved_goal=latents[t, :],
                    latent_desired_goal=goal,
                )
                next_obs = dict(
                    latent_observation=latents[t + 1, :],
                    latent_achieved_goal=latents[t + 1, :],
                    latent_desired_goal=goal,
                )

                path['observations'].append(obs)
                path['next_observations'].append(next_obs)

            # import ipdb; ipdb.set_trace()
            self.replay_buffer.add_path(path)
    def forward_model_error(self, next_vae_obs, indices):
        obs = self._obs[self.observation_key][indices]
        next_obs = self._next_obs[self.observation_key][indices]
        actions = self._actions[indices]

        state_action_pair = ptu.from_numpy(np.c_[obs, actions])
        prediction = self.dynamics_model(state_action_pair)
        mse = self.dynamics_loss(prediction, ptu.from_numpy(next_obs))
        return ptu.get_numpy(mse)
Example #11
    def refresh_weights(self):
        if self.actual_weights is None or (self.counter % self.weight_update_period == 0 and self._use_weights):
            batch_size = 1024
            next_idx = min(batch_size, self._size)

            cur_idx = 0
            while cur_idx < self._size:
                idxs = np.arange(cur_idx, next_idx)
                obs = ptu.from_numpy(self._observations[idxs])
                actions = ptu.from_numpy(self._actions[idxs])

                q1_pred = self.qf1(obs, actions)
                q2_pred = self.qf2(obs, actions)

                (
                    new_obs_actions, policy_mean, policy_log_std, log_pi,
                    entropy, policy_std, mean_action_log_prob, pretanh_value,
                    dist,
                ) = self.policy(
                    obs, reparameterize=True, return_log_prob=True,
                )

                qf1_new_actions = self.qf1(obs, new_obs_actions)
                qf2_new_actions = self.qf2(obs, new_obs_actions)
                q_new_actions = torch.min(
                    qf1_new_actions,
                    qf2_new_actions,
                )

                if self.awr_use_mle_for_vf:
                    v_pi = self.qf1(obs, policy_mean)
                else:
                    v_pi = self.qf1(obs, new_obs_actions)

                if self.awr_sample_actions:
                    u = new_obs_actions
                    if self.awr_min_q:
                        q_adv = q_new_actions
                    else:
                        q_adv = qf1_new_actions
                else:
                    u = actions
                    if self.awr_min_q:
                        q_adv = torch.min(q1_pred, q2_pred)
                    else:
                        q_adv = q1_pred

                advantage = q_adv - v_pi

                self.weights[idxs] = (advantage/self.beta).cpu().detach()

                cur_idx = next_idx
                next_idx += batch_size
                next_idx = min(next_idx, self._size)

            self.actual_weights = ptu.get_numpy(F.softmax(self.weights[:self._size], dim=0))
            p_sum = np.sum(self.actual_weights)
            assert p_sum > 0, "Unnormalized p sum is {}".format(p_sum)
            self.actual_weights = (self.actual_weights/p_sum).flatten()
        self.counter += 1
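
The refreshed weights define a softmax distribution over buffer entries, w_i = exp(A_i / beta) / sum_j exp(A_j / beta). A hedged sketch of how such weights might drive prioritized index sampling (the helper name is hypothetical, not taken from this buffer class):

import numpy as np

def sample_prioritized_indices(actual_weights, batch_size):
    """Draw buffer indices with probability proportional to the advantage weights."""
    return np.random.choice(len(actual_weights), size=batch_size,
                            p=actual_weights, replace=True)
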
def compute_log_p_log_q_log_d(model,
                              batch,
                              decoder_distribution='bernoulli',
                              num_latents_to_sample=1,
                              sampling_method='importance_sampling'):
    x_0 = ptu.from_numpy(batch["x_0"])
    data = batch["x_t"]
    imgs = ptu.from_numpy(data)
    latent_distribution_params = model.encode(imgs, x_0)
    r1 = model.latent_sizes[0]
    batch_size = data.shape[0]
    log_p = ptu.zeros((batch_size, num_latents_to_sample))
    log_q = ptu.zeros((batch_size, num_latents_to_sample))
    log_d = ptu.zeros((batch_size, num_latents_to_sample))
    true_prior = Normal(ptu.zeros((batch_size, r1)), ptu.ones(
        (batch_size, r1)))
    mus, logvars = latent_distribution_params[:2]
    for i in range(num_latents_to_sample):
        if sampling_method == 'importance_sampling':
            latents = model.rsample(latent_distribution_params[:2])
        elif sampling_method == 'biased_sampling':
            latents = model.rsample(latent_distribution_params[:2])
        elif sampling_method == 'true_prior_sampling':
            latents = true_prior.rsample()
        else:
            raise EnvironmentError('Invalid Sampling Method Provided')

        stds = logvars.exp().pow(.5)
        vae_dist = Normal(mus, stds)
        log_p_z = true_prior.log_prob(latents).sum(dim=1)
        log_q_z_given_x = vae_dist.log_prob(latents).sum(dim=1)

        if len(latent_distribution_params) == 3:  # add conditioning for CVAEs
            latents = torch.cat((latents, latent_distribution_params[2]),
                                dim=1)

        if decoder_distribution == 'bernoulli':
            decoded = model.decode(latents)[0]
            log_d_x_given_z = torch.log(imgs * decoded + (1 - imgs) *
                                        (1 - decoded) + 1e-8).sum(dim=1)
        elif decoder_distribution == 'gaussian_identity_variance':
            _, obs_distribution_params = model.decode(latents)
            dec_mu, dec_logvar = obs_distribution_params
            dec_var = dec_logvar.exp()
            decoder_dist = Normal(dec_mu, dec_var.pow(.5))
            log_d_x_given_z = decoder_dist.log_prob(imgs).sum(dim=1)
        else:
            raise EnvironmentError('Invalid Decoder Distribution Provided')

        log_p[:, i] = log_p_z
        log_q[:, i] = log_q_z_given_x
        log_d[:, i] = log_d_x_given_z
    return log_p, log_q, log_d
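
The three per-sample tensors returned above are the usual ingredients of an importance-sampling estimate of the data log-likelihood. A hedged sketch of one way they might be combined (the function name and the log-sum-exp weighting are assumptions, not taken from this code):

import math
import torch

def estimate_log_likelihood(log_p, log_q, log_d):
    """Importance-weighted estimate of log p(x).

    Each input has shape (batch_size, num_latents_to_sample); per data point
    the estimate is log of the average importance weight p(z) d(x|z) / q(z|x).
    """
    num_samples = log_p.shape[1]
    log_weights = log_p + log_d - log_q
    per_point = torch.logsumexp(log_weights, dim=1) - math.log(num_samples)
    return per_point.mean()
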
    def _reconstruct_img(self, flat_img, x_0=None):
        self.vae.eval()

        if x_0 is None:
            x_0 = ptu.from_numpy(self._initial_obs["image_observation"][None])
        else:
            x_0 = ptu.from_numpy(x_0.reshape(1, -1))

        latent = self.vae.encode(ptu.from_numpy(flat_img.reshape(1,-1)), x_0, distrib=False)
        reconstructions, _ = self.vae.decode(latent)
        imgs = ptu.get_numpy(reconstructions)
        imgs = imgs.reshape(
            1, self.input_channels, self.imsize, self.imsize
        )
        return imgs[0]
    def reconstruction_mse(self, next_vae_obs, indices):
        torch_input = ptu.from_numpy(next_vae_obs)
        recon_next_vae_obs, _, _ = self.vae(torch_input)

        error = torch_input - recon_next_vae_obs
        mse = torch.sum(error ** 2, dim=1)
        return ptu.get_numpy(mse)
    def log_prob(self, value, pre_tanh_value=None):
        """
        Adapted from
        https://github.com/tensorflow/probability/blob/master/tensorflow_probability/python/bijectors/tanh.py#L73

        This formula is mathematically equivalent to log(1 - tanh(x)^2).

        Derivation:

        log(1 - tanh(x)^2)
         = log(sech(x)^2)
         = 2 * log(sech(x))
         = 2 * log(2e^-x / (e^-2x + 1))
         = 2 * (log(2) - x - log(e^-2x + 1))
         = 2 * (log(2) - x - softplus(-2x))

        :param value: some value, x
        :param pre_tanh_value: arctanh(x)
        :return:
        """
        if pre_tanh_value is None:
            value = torch.clamp(value, -0.999999, 0.999999)
            # arctanh(value) = log((1 + value) / (1 - value)) / 2
            pre_tanh_value = (
                torch.log(1 + value) - torch.log(1 - value)
            ) / 2
        return self.normal.log_prob(pre_tanh_value) - 2. * (
            ptu.from_numpy(np.log([2.])) - pre_tanh_value -
            torch.nn.functional.softplus(-2. * pre_tanh_value))
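
A quick numerical check of the identity used above, log(1 - tanh(x)^2) = 2 * (log(2) - x - softplus(-2x)); purely illustrative:

import math
import torch

x = torch.linspace(-5, 5, steps=11, dtype=torch.float64)
lhs = torch.log(1 - torch.tanh(x) ** 2)
rhs = 2. * (math.log(2.) - x - torch.nn.functional.softplus(-2. * x))
print(torch.allclose(lhs, rhs))  # True
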
    def get_dataset_stats(self, data):
        torch_input = ptu.from_numpy(normalize_image(data))
        mus, log_vars = self.model.encode(torch_input)
        mus = ptu.get_numpy(mus)
        mean = np.mean(mus, axis=0)
        std = np.std(mus, axis=0)
        return mus, mean, std
Example #17
    def get_batch(self, test_data=False, epoch=None):
        if self.use_parallel_dataloading:
            if test_data:
                dataloader = self.test_dataloader
            else:
                dataloader = self.train_dataloader
            samples = next(dataloader).to(ptu.device)
            return samples

        dataset = self.test_dataset if test_data else self.train_dataset
        skew = False
        if epoch is not None:
            skew = (self.start_skew_epoch < epoch)
        if not test_data and self.skew_dataset and skew:
            probs = self._train_weights / np.sum(self._train_weights)
            ind = np.random.choice(
                len(probs),
                self.batch_size,
                p=probs,
            )
        else:
            ind = np.random.randint(0, len(dataset), self.batch_size)
        samples = normalize_image(dataset[ind, :])
        if self.normalize:
            samples = ((samples - self.train_data_mean) + 1) / 2
        if self.background_subtract:
            samples = samples - self.train_data_mean
        return ptu.from_numpy(samples)
    def compute_rewards(self, actions, obs):
        self.vae.eval()
        # TODO: implement log_prob/mdist
        if self.reward_type == 'latent_distance':
            achieved_goals = obs['latent_achieved_goal']
            desired_goals = obs['latent_desired_goal']
            dist = np.linalg.norm(desired_goals - achieved_goals, ord=self.norm_order, axis=1)
            return -dist
        elif self.reward_type == 'vectorized_latent_distance':
            achieved_goals = obs['latent_achieved_goal']
            desired_goals = obs['latent_desired_goal']
            return -np.abs(desired_goals - achieved_goals)
        elif self.reward_type == 'latent_sparse':
            achieved_goals = obs['latent_achieved_goal']
            desired_goals = obs['latent_desired_goal']
            dist = np.linalg.norm(desired_goals - achieved_goals, ord=self.norm_order, axis=1)
            # -1 until within epsilon of the goal; vectorized so it works on a batched dist
            reward = -(dist >= self.epsilon).astype(float)
            return reward
        elif self.reward_type == 'success_prob':
            desired_goals = self._decode(obs['latent_desired_goal'])
            achieved_goals = self.vae.decode(ptu.from_numpy(obs['latent_achieved_goal']))
            prob = self.vae.logprob(desired_goals, achieved_goals).exp()
            reward = prob
            return 1/0  # not sure about this anymore, number will be too low
        elif self.reward_type == 'state_distance':
            achieved_goals = obs['state_achieved_goal']
            desired_goals = obs['state_desired_goal']
            return -np.linalg.norm(desired_goals - achieved_goals, ord=self.norm_order, axis=1)
        elif self.reward_type == 'wrapped_env':
            return self.wrapped_env.compute_rewards(actions, obs)
        else:
            raise NotImplementedError
Example #19
    def __init__(
            self,
            env,
            qf,
            replay_buffer,
            num_epochs=100,
            num_batches_per_epoch=100,
            qf_learning_rate=1e-3,
            batch_size=100,
            num_unique_batches=1000,
    ):
        self.qf = qf
        self.replay_buffer = replay_buffer
        self.env = env
        self.num_epochs = num_epochs
        self.num_batches_per_epoch = num_batches_per_epoch
        self.qf_learning_rate = qf_learning_rate
        self.batch_size = batch_size
        self.num_unique_batches = num_unique_batches

        self.qf_optimizer = optim.Adam(self.qf.parameters(),
                                       lr=self.qf_learning_rate)
        self.batch_iterator = None
        self.discount = ptu.Variable(
            ptu.from_numpy(np.zeros((batch_size, 1))).float()
        )
        self.mode_to_batch_iterator = {}
    def test_log_prob_value_give_pre_tanh_value(self):
        tanh_normal = TanhNormal(0, 1)
        z_np = np.array([1])
        x_np = np.tanh(z_np)
        z = ptu.from_numpy(z_np)
        x = ptu.from_numpy(x_np)
        log_prob = tanh_normal.log_prob(x, pre_tanh_value=z)

        log_prob_np = ptu.get_numpy(log_prob)
        log_prob_expected = (
            np.log(np.array([1 / np.sqrt(2 * np.pi)])) - 0.5  # from Normal
            - np.log(1 - x_np**2))
        self.assertNpArraysEqual(
            log_prob_expected,
            log_prob_np,
        )
Example #21
    def get_action_and_P_matrix(self, obs):
        obs = np.expand_dims(obs, axis=0)
        obs = Variable(ptu.from_numpy(obs).float(), requires_grad=False)
        action, _, _, P = self.__call__(obs, None, return_P=True)
        action = action.squeeze(0)
        P = P.squeeze(0)
        return ptu.get_numpy(action), ptu.get_numpy(P)
    def get_batch(self, train=True):
        dataset = self.train_dataset if train else self.test_dataset
        ind = np.random.randint(0, len(dataset), self.batch_size)
        samples = dataset[ind, :]
        # if self.normalize:
        #     samples = ((samples - self.train_data_mean) + 1) / 2
        return ptu.from_numpy(samples)
    def encode(self, input):
        h = self.encoder(input)

        if self.use_softmax:
            h = torch.exp(h / self.temperature)
            # sum over x, then sum over y
            total = h.sum(2).sum(2).view(h.shape[0], h.shape[1], 1, 1)
            h = h / total

        maps_x = torch.sum(h, 2)
        maps_y = torch.sum(h, 3)
        weights = ptu.from_numpy(
            np.arange(maps_x.shape[-1]) / maps_x.shape[-1])
        fp_x = torch.sum(maps_x * weights, 2)
        fp_y = torch.sum(maps_y * weights, 2)

        h = torch.cat([fp_x, fp_y], 1)

        mu = self.fc1(h) if self.encode_feature_points else h

        if self.log_min_variance is None:
            logvar = self.fc2(h)
        else:
            logvar = self.log_min_variance + torch.abs(self.fc2(h))
        return (mu, logvar)
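
The encoder above looks like a spatial soft(arg)max / feature-point bottleneck: after normalizing h into per-channel spatial probabilities p_c(i, j), it computes expected coordinates

    fp_x[c] = sum_j (j / W) * sum_i p_c(i, j)
    fp_y[c] = sum_i (i / H) * sum_j p_c(i, j)

so each channel is summarized by one (x, y) feature point before the fully connected heads. Note that the same weights vector (built from maps_x.shape[-1]) is reused for both axes, which implicitly assumes square feature maps (H == W).
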
Example #24
    def _decode(self, latents):
        # MAKE INTEGER
        self.vae.eval()
        latents = ptu.from_numpy(latents)
        reconstructions = self.vae.decode(latents, cont=True)
        decoded = ptu.get_numpy(reconstructions)
        decoded = np.clip(decoded, 0, 1)
        return decoded
    def random_vae_training_data(self, batch_size, epoch):
        # epoch no longer needed. Using self.skew in sample_weighted_indices
        # instead.
        weighted_idxs = self.sample_weighted_indices(batch_size, )

        next_image_obs = self._next_obs[self.decoded_obs_key][weighted_idxs]
        observations = ptu.from_numpy(next_image_obs)

        x_0_indices = (weighted_idxs // 100) * 100
        x_0 = self._next_obs[self.decoded_obs_key][x_0_indices]
        x_0 = ptu.from_numpy(x_0)

        return dict(
            observations=observations,
            x_t=observations,
            x_0=x_0,
        )
    def bernoulli_inv_prob(self, next_vae_obs, indices):
        torch_input = ptu.from_numpy(next_vae_obs)
        recon_next_vae_obs, _, _ = self.vae(torch_input)
        prob = (
            torch_input * recon_next_vae_obs
            + (1 - torch_input) * (1 - recon_next_vae_obs)
        ).prod(dim=1)
        return ptu.get_numpy(1 / prob)
    def test_batch_square_diagonal_module(self):
        x = np.array([
            [2, 7],
        ])
        diag_vals = np.array([
            [2, 1],
        ])
        expected = np.array([
            [57],  # 2^2 * 2 + 7^2 * 1 = 8 + 49 = 57
        ])

        x_var = ptu.Variable(ptu.from_numpy(x).float())
        diag_var = ptu.Variable(ptu.from_numpy(diag_vals).float())
        net = modules.BatchSquareDiagonal(2)
        result_var = net(vector=x_var, diag_values=diag_var)
        result = ptu.get_numpy(result_var)

        self.assertNpAlmostEqual(expected, result)
Example #28
    def save_image_util(self, img, name):
        im = img.reshape(-1, 3, 500, 300).transpose([0, 1, 3, 2]) / 255.0
        im = im[:, :, 60:, 60:500]
        pt_img = ptu.from_numpy(im).view(-1, 3, epic.CROP_HEIGHT,
                                         epic.CROP_WIDTH)
        save_image(pt_img.data.cpu(),
                   '%s_%d.png' % (name, self.episode_num),
                   nrow=1)
    def _reconstruct_img(self, flat_img):
        self.vae.eval()
        latent_distribution_params = self.vae.encode(ptu.from_numpy(flat_img.reshape(1, -1)))
        reconstructions, _ = self.vae.decode(latent_distribution_params[0])
        imgs = ptu.get_numpy(reconstructions)
        imgs = imgs.reshape(
            1, self.input_channels, self.imsize, self.imsize
        )
        return imgs[0]
    def get_train_dict(self, subtraj_batch):
        subtraj_rewards = subtraj_batch['rewards']
        subtraj_rewards_np = ptu.get_numpy(subtraj_rewards).squeeze(2)
        returns = np_util.batch_discounted_cumsum(subtraj_rewards_np,
                                                  self.discount)
        returns = np.expand_dims(returns, 2)
        returns = np.ascontiguousarray(returns).astype(np.float32)
        returns = ptu.Variable(ptu.from_numpy(returns))
        subtraj_batch['returns'] = returns
        batch = flatten_subtraj_batch(subtraj_batch)
        # rewards = batch['rewards']
        returns = batch['returns']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']
        """
        Policy operations.
        """
        policy_actions = self.policy(obs)
        q = self.qf(obs, policy_actions)
        policy_loss = -q.mean()
        """
        Critic operations.
        """
        next_actions = self.policy(next_obs)
        # TODO: try to get this to work
        # next_actions = None
        q_target = self.target_qf(
            next_obs,
            next_actions,
        )
        # y_target = self.reward_scale * rewards + (1. - terminals) * self.discount * v_target
        batch_size = q_target.size()[0]
        discount_factors = self.discount_factors.repeat(
            batch_size // self.subtraj_length,
            1,
        )
        y_target = self.reward_scale * returns + (
            1. - terminals) * discount_factors * q_target
        # noinspection PyUnresolvedReferences
        y_target = y_target.detach()
        y_pred = self.qf(obs, actions)
        bellman_errors = (y_pred - y_target)**2
        qf_loss = self.qf_criterion(y_pred, y_target)

        return OrderedDict([
            ('Policy Actions', policy_actions),
            ('Policy Loss', policy_loss),
            ('Policy Q Values', q),
            ('Target Y', y_target),
            ('Predicted Y', y_pred),
            ('Bellman Errors', bellman_errors),
            ('Y targets', y_target),
            ('Y predictions', y_pred),
            ('QF Loss', qf_loss),
        ])