Example #1
    def _do_training(self):
        batch, weights, idxes = self.get_batch()
        self.t += 1
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']
        weights = torch.autograd.Variable(ptu.from_numpy(
            weights.reshape(weights.shape + (1,))
            ), requires_grad=False)

        """
        Critic operations.
        """

        next_actions = self.target_policy(next_obs)
        noise = torch.normal(
            torch.zeros_like(next_actions),
            self.target_policy_noise,
        )
        noise = torch.clamp(
            noise,
            -self.target_policy_noise_clip,
            self.target_policy_noise_clip
        )
        noisy_next_actions = next_actions + noise

        target_q1_values = self.target_qf1(next_obs, noisy_next_actions)
        target_q2_values = self.target_qf2(next_obs, noisy_next_actions)
        target_q_values = torch.min(target_q1_values, target_q2_values)
        q_target = rewards + (1. - terminals) * self.discount * target_q_values
        q_target = q_target.detach()

        q1_pred = self.qf1(obs, actions)
        bellman_errors_1 = (q1_pred - q_target) ** 2

        # apply importance-sampling weights from the prioritized replay buffer
        bellman_errors_1 = weights * bellman_errors_1

        qf1_loss = bellman_errors_1.mean()

        q2_pred = self.qf2(obs, actions)
        bellman_errors_2 = (q2_pred - q_target) ** 2

        # apply importance-sampling weights from the prioritized replay buffer
        bellman_errors_2 = weights * bellman_errors_2

        qf2_loss = bellman_errors_2.mean()

        """
        Update Networks
        """
        self.qf1_optimizer.zero_grad()
        qf1_loss.backward()
        self.qf1_optimizer.step()

        self.qf2_optimizer.zero_grad()
        qf2_loss.backward()
        self.qf2_optimizer.step()

        # Update priorities
        EPS = 10 ** -6
        td_error_1 = np.abs(ptu.get_numpy(q1_pred - q_target))
        td_error_2 = np.abs(ptu.get_numpy(q2_pred - q_target))
        new_priorities = (td_error_1 + td_error_2) / 2 + EPS
        self.replay_buffer.update_priorities(idxes, new_priorities)

        policy_actions = policy_loss = None
        if self._n_train_steps_total % self.policy_and_target_update_period == 0:
            policy_actions = self.policy(obs)
            q_output = self.qf1(obs, policy_actions)
            policy_loss = - q_output.mean()

            self.policy_optimizer.zero_grad()
            policy_loss.backward()
            self.policy_optimizer.step()

            ptu.soft_update_from_to(self.policy, self.target_policy, self.tau)
            ptu.soft_update_from_to(self.qf1, self.target_qf1, self.tau)
            ptu.soft_update_from_to(self.qf2, self.target_qf2, self.tau)

        """
        Save some statistics for eval using just one batch.
        """
        if self.need_to_update_eval_statistics:
            self.need_to_update_eval_statistics = False
            if policy_loss is None:
                policy_actions = self.policy(obs)
                q_output = self.qf1(obs, policy_actions)
                policy_loss = - q_output.mean()

            self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
            self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
            self.eval_statistics['Policy Loss'] = np.mean(ptu.get_numpy(
                policy_loss
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Q1 Predictions',
                ptu.get_numpy(q1_pred),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Q2 Predictions',
                ptu.get_numpy(q2_pred),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Q Targets',
                ptu.get_numpy(q_target),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Bellman Errors 1',
                ptu.get_numpy(bellman_errors_1),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Bellman Errors 2',
                ptu.get_numpy(bellman_errors_2),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Policy Action',
                ptu.get_numpy(policy_actions),
            ))
Example #2
def _elem_or_tuple_to_variable(elem_or_tuple):
    if isinstance(elem_or_tuple, tuple):
        return tuple(_elem_or_tuple_to_variable(e) for e in elem_or_tuple)
    return ptu.from_numpy(elem_or_tuple).float()
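A quick usage sketch of the recursive helper above, assuming it is in scope and that ptu is rlkit's pytorch_util module (the array shapes are made up for illustration):

import numpy as np

nested = (np.zeros((4, 3)), (np.ones((4, 2)), np.arange(4, dtype=np.float64)))
converted = _elem_or_tuple_to_variable(nested)
# 'converted' mirrors the input structure: a tuple whose leaves are float tensors.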
Example #3
 def encode_np(self, imgs, cond):
     return ptu.get_numpy(
         self.encode(ptu.from_numpy(imgs), ptu.from_numpy(cond)))
Example #4
def torch_ify(np_array_or_other):
    if isinstance(np_array_or_other, np.ndarray):
        return ptu.from_numpy(np_array_or_other)
    else:
        return np_array_or_other
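Usage sketch for torch_ify: numpy arrays are converted through ptu.from_numpy, while anything else (already-converted tensors, scalars) is returned unchanged. The inputs below are illustrative.

import numpy as np
import torch

a = torch_ify(np.random.rand(8, 3))   # numpy array -> torch tensor via ptu.from_numpy
b = torch_ify(torch.zeros(8, 3))      # tensor passed through unchanged
c = torch_ify(0.5)                    # non-array input returned as-is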
Example #5
 def v_function(obs):
     action = policy.get_actions(obs)
     obs, action = ptu.from_numpy(obs), ptu.from_numpy(action)
     return qf1(obs, action)
Example #6
    def get_train_dict(self, batch):
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']

        batch_size = obs.size()[0]
        """
        Policy operations.
        """
        policy_actions = self.policy(obs)
        z_output = self.qf(obs, policy_actions)  # BATCH_SIZE x NUM_BINS
        q_output = (z_output * self.create_atom_values(batch_size)).sum(1)
        policy_loss = -q_output.mean()
        """
        Critic operations.
        """
        next_actions = self.target_policy(next_obs)
        target_qf_histogram = self.target_qf(
            next_obs,
            next_actions,
        )
        # z_target = ptu.Variable(torch.zeros(self.batch_size, self.num_bins))

        rewards_batch = rewards.repeat(1, self.num_bins)
        terminals_batch = terminals.repeat(1, self.num_bins)
        projected_returns = (self.reward_scale * rewards_batch +
                             (1. - terminals_batch) * self.discount *
                             self.create_atom_values(batch_size))
        projected_returns = torch.clamp(projected_returns, self.returns_min,
                                        self.returns_max)
        bin_values = (projected_returns - self.returns_min) / self.bin_width
        lower_bin_indices = torch.floor(bin_values)
        upper_bin_indices = torch.ceil(bin_values)
        lower_bin_deltas = target_qf_histogram * (upper_bin_indices -
                                                  bin_values)
        upper_bin_deltas = target_qf_histogram * (bin_values -
                                                  lower_bin_indices)

        z_target_np = np.zeros((batch_size, self.num_bins))
        lower_deltas_np = lower_bin_deltas.data.numpy()
        upper_deltas_np = upper_bin_deltas.data.numpy()
        lower_idxs_np = lower_bin_indices.data.numpy().astype(int)
        upper_idxs_np = upper_bin_indices.data.numpy().astype(int)
        for batch_i in range(batch_size):
            for bin_i in range(self.num_bins):
                z_target_np[batch_i,
                            bin_i] += (lower_deltas_np[batch_i,
                                                       lower_idxs_np[batch_i,
                                                                     bin_i]])
                z_target_np[batch_i,
                            bin_i] += (upper_deltas_np[batch_i,
                                                       upper_idxs_np[batch_i,
                                                                     bin_i]])
        z_target = ptu.Variable(ptu.from_numpy(z_target_np).float())

        # for j in range(self.num_bins):
        #     import ipdb; ipdb.set_trace()
        #     atom_value = self.atom_values_batch[:, j:j+1]
        #     projected_returns = self.reward_scale * rewards + (1. - terminals) * self.discount * (
        #         atom_value
        #     )
        #     bin_values = (projected_returns - self.returns_min) / self.bin_width
        #     lower_bin_indices = torch.floor(bin_values)
        #     upper_bin_indices = torch.ceil(bin_values)
        #     lower_bin_deltas = target_qf_histogram[:, j:j+1] * (
        #         upper_bin_indices - bin_values
        #     )
        #     upper_bin_deltas = target_qf_histogram[:, j:j+1] * (
        #         bin_values - lower_bin_indices
        #     )
        #     new_lower_bin_values = torch.gather(
        #         z_target, 1, lower_bin_indices.long().data
        #     ) + lower_bin_deltas
        #     new_upper_bin_values = torch.gather(
        #         z_target, 1, upper_bin_indices.long().data
        #     ) + upper_bin_deltas

        # noinspection PyUnresolvedReferences
        z_pred = self.qf(obs, actions)
        qf_loss = -(z_target * torch.log(z_pred)).sum(1).mean(0)

        return OrderedDict([
            ('Policy Actions', policy_actions),
            ('Policy Loss', policy_loss),
            ('QF Outputs', q_output),
            ('Z targets', z_target),
            ('Z predictions', z_pred),
            ('QF Loss', qf_loss),
        ])
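The nested Python loop above distributes each projected return's probability mass over neighbouring atoms of the categorical value distribution. For reference, the standard C51 projection can be written in vectorized form with scatter_add_; this is a sketch of the conventional formulation (with illustrative names), not a drop-in replacement for the class above.

import torch

def project_categorical_target(target_hist, rewards, terminals, atom_values,
                               returns_min, returns_max, bin_width,
                               discount, reward_scale=1.0):
    # target_hist: (B, N) probabilities from the target critic; atom_values: (N,)
    projected = reward_scale * rewards + (1. - terminals) * discount * atom_values
    projected = torch.clamp(projected, returns_min, returns_max)
    b = (projected - returns_min) / bin_width      # fractional bin positions, (B, N)
    lower = b.floor().long()
    upper = b.ceil().long()
    z_target = torch.zeros_like(target_hist)
    # scatter each atom's mass to its lower and upper neighbouring bins
    z_target.scatter_add_(1, lower, target_hist * (upper.float() - b))
    z_target.scatter_add_(1, upper, target_hist * (b - lower.float()))
    return z_target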
Example #7
 def _kl_np_to_np(self, np_imgs):
     torch_input = ptu.from_numpy(normalize_image(np_imgs))
     mu, log_var = self.model.encode(torch_input)
     return ptu.get_numpy(
         -torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=1))
Example #8
def visualize_rollout(
    env,
    world_model,
    logdir,
    max_path_length,
    low_level_primitives,
    policy=None,
    img_size=64,
    num_rollouts=4,
    use_raps_obs=False,
    use_true_actions=True,
):
    file_path = logdir + "/"
    os.makedirs(file_path, exist_ok=True)
    print(
        f"Generating Imagination Reconstructions No Intermediate Obs: {use_raps_obs} Actual Actions: {use_true_actions}"
    )
    file_suffix = f"imagination_reconstructions_raps_obs_{use_raps_obs}_actual_actions_{use_true_actions}.png"
    file_path += file_suffix
    img_shape = (img_size, img_size, 3)
    reconstructions = np.zeros(
        (num_rollouts, max_path_length + 1, *img_shape),
        dtype=np.uint8,
    )
    obs = np.zeros(
        (num_rollouts, max_path_length + 1, *img_shape),
        dtype=np.uint8,
    )
    for rollout in range(num_rollouts):
        for step in range(0, max_path_length + 1):
            if step == 0:
                observation = env.reset()
                new_img = ptu.from_numpy(observation)
                policy.reset(observation.reshape(1, -1))
                reward = 0
                if low_level_primitives:
                    policy_o = (None, observation.reshape(1, -1))
                else:
                    policy_o = observation.reshape(1, -1)
                # hack to pass typing checks
                vis = convert_img_to_save(
                    world_model.get_image_from_obs(
                        torch.from_numpy(observation.reshape(1, -1))).numpy())
                add_text(vis, "Ground Truth", (1, 60), 0.25, (0, 255, 0))
            else:
                high_level_action, state = policy.get_action(
                    policy_o, use_raps_obs, use_true_actions)
                state = state["state"]
                observation, reward, done, info = env.step(
                    high_level_action[0], )
                if low_level_primitives:
                    low_level_obs = np.expand_dims(
                        np.array(info["low_level_obs"]), 0)
                    low_level_action = np.expand_dims(
                        np.array(info["low_level_action"]), 0)
                    policy_o = (low_level_action, low_level_obs)
                else:
                    policy_o = observation.reshape(1, -1)
                (
                    primitive_name,
                    _,
                    _,
                ) = env.get_primitive_info_from_high_level_action(
                    high_level_action[0])
                # hack to pass typing checks
                vis = convert_img_to_save(
                    world_model.get_image_from_obs(
                        torch.from_numpy(observation.reshape(1, -1))).numpy())
                add_text(vis, primitive_name, (1, 60), 0.25, (0, 255, 0))
                add_text(vis, f"r: {reward}", (35, 7), 0.3, (0, 0, 0))

            obs[rollout, step] = vis
            if step != 0:
                new_img = reconstruct_from_state(state, world_model)
                if step == 1:
                    add_text(new_img, "Reconstruction", (1, 60), 0.25,
                             (0, 255, 0))
                reconstructions[rollout, step - 1] = new_img
                reward_pred = (world_model.reward(
                    world_model.get_features(
                        state)).detach().cpu().numpy().item())
                discount_pred = world_model.get_dist(
                    world_model.pred_discount(world_model.get_features(state)),
                    std=None,
                    normal=False,
                ).mean.detach().cpu().numpy().item()

                print(
                    f"Rollout {rollout} Step {step - 1} Predicted Reward {reward_pred}"
                )
                print(f"Rollout {rollout} Step {step - 1} Reward {prev_r}")
                print(
                    f"Rollout {rollout} Step {step - 1} Predicted Discount {discount_pred}"
                )
                print()
            prev_r = reward
        _, state = policy.get_action(policy_o)
        state = state["state"]
        new_img = reconstruct_from_state(state, world_model)
        reconstructions[rollout, max_path_length] = new_img
        reward_pred = (world_model.reward(
            world_model.get_features(state)).detach().cpu().numpy().item())
        discount_pred = world_model.get_dist(
            world_model.pred_discount(world_model.get_features(state)),
            std=None,
            normal=False,
        ).mean.detach().cpu().numpy().item()
        print(f"Rollout {rollout} Final Predicted Reward {reward_pred}")
        print(f"Rollout {rollout} Final Reward {reward}")
        print(f"Rollout {rollout} Final Predicted Discount {discount_pred}")
        print()

    im = np.zeros(
        (img_size * 2 * num_rollouts, (max_path_length + 1) * img_size, 3),
        dtype=np.uint8,
    )

    for rollout in range(num_rollouts):
        for step in range(max_path_length + 1):
            im[img_size * 2 * rollout:img_size * 2 * rollout + img_size,
               img_size * step:img_size * (step + 1), ] = obs[rollout, step]
            im[img_size * 2 * rollout + img_size:img_size * 2 * (rollout + 1),
               img_size * step:img_size *
               (step + 1), ] = reconstructions[rollout, step]
    cv2.imwrite(file_path, im)
    print(f"Saved Rollout Visualization to {file_path}")
    print()
Example #9
 def get_action(self, obs, deterministic=False):
     ''' sample action from the policy, conditioned on the task embedding '''
     z = self.z
     obs = ptu.from_numpy(obs[None])
     in_ = torch.cat([obs, z], dim=1)
     return self.policy.get_action(in_, deterministic=deterministic)
Example #10
 def get_action(self, obs):
     obs = np.expand_dims(obs, axis=0)
     obs = Variable(ptu.from_numpy(obs).float(), requires_grad=False)
     action, _, _ = self.__call__(obs, None)
     action = action.squeeze(0)
     return ptu.get_numpy(action), {}
Example #11
def post_epoch_video_func(
    algorithm,
    epoch,
    policy,
    img_size=256,
    mode="eval",
):
    if epoch == -1 or epoch % 100 == 0:
        print("Generating Video: ")
        env = algorithm.eval_env

        file_path = osp.join(logger.get_snapshot_dir(),
                             mode + "_" + str(epoch) + "_video.avi")

        img_array1 = []
        path_length = 0
        observation = env.reset()
        policy.reset(observation)
        obs = np.zeros(
            (4, algorithm.max_path_length, env.observation_space.shape[0]),
            dtype=np.uint8,
        )
        actions = np.zeros(
            (4, algorithm.max_path_length, env.action_space.shape[0]))
        while path_length < algorithm.max_path_length:
            action, agent_info = policy.get_action(observation, )
            observation, reward, done, info = env.step(
                action,
                render_every_step=True,
                render_mode="rgb_array",
                render_im_shape=(img_size, img_size),
            )
            img_array1.extend(env.envs[0].img_array)
            obs[0, path_length] = observation
            actions[0, path_length] = action
            path_length += 1

        img_array2 = []
        path_length = 0
        observation = env.reset()
        policy.reset(observation)
        while path_length < algorithm.max_path_length:
            action, agent_info = policy.get_action(observation, )
            observation, reward, done, info = env.step(
                action,
                render_every_step=True,
                render_mode="rgb_array",
                render_im_shape=(img_size, img_size),
            )
            img_array2.extend(env.envs[0].img_array)
            obs[1, path_length] = observation
            actions[1, path_length] = action
            path_length += 1

        img_array3 = []
        path_length = 0
        observation = env.reset()
        policy.reset(observation)
        while path_length < algorithm.max_path_length:
            action, agent_info = policy.get_action(observation, )
            observation, reward, done, info = env.step(
                action,
                render_every_step=True,
                render_mode="rgb_array",
                render_im_shape=(img_size, img_size),
            )
            img_array3.extend(env.envs[0].img_array)
            obs[2, path_length] = observation
            actions[2, path_length] = action
            path_length += 1

        img_array4 = []
        path_length = 0
        observation = env.reset()
        policy.reset(observation)
        while path_length < algorithm.max_path_length:
            action, agent_info = policy.get_action(observation, )
            observation, reward, done, info = env.step(
                action,
                render_every_step=True,
                render_mode="rgb_array",
                render_im_shape=(img_size, img_size),
            )
            img_array4.extend(env.envs[0].img_array)
            obs[3, path_length] = observation
            actions[3, path_length] = action
            path_length += 1

        fourcc = cv2.VideoWriter_fourcc(*"DIVX")
        out = cv2.VideoWriter(file_path, fourcc, 100.0,
                              (img_size * 2, img_size * 2))
        max_len = max(len(img_array1), len(img_array2), len(img_array3),
                      len(img_array4))
        gif_clip = []
        for i in range(max_len):
            if i >= len(img_array1):
                im1 = img_array1[-1]
            else:
                im1 = img_array1[i]

            if i >= len(img_array2):
                im2 = img_array2[-1]
            else:
                im2 = img_array2[i]

            if i >= len(img_array3):
                im3 = img_array3[-1]
            else:
                im3 = img_array3[i]

            if i >= len(img_array4):
                im4 = img_array4[-1]
            else:
                im4 = img_array4[i]

            im12 = np.concatenate((im1, im2), 1)
            im34 = np.concatenate((im3, im4), 1)
            im = np.concatenate((im12, im34), 0)

            out.write(im)
            gif_clip.append(im)
        out.release()
        print("video saved to :", file_path)

        # gif_file_path = osp.join(
        #     logger.get_snapshot_dir(), mode + "_" + str(epoch) + ".gif"
        # )
        # clip = ImageSequenceClip(list(gif_clip), fps=20)
        # clip.write_gif(gif_file_path, fps=20)
        # takes way too much space
        obs, actions = ptu.from_numpy(obs), ptu.from_numpy(actions)
        (
            post,
            prior,
            post_dist,
            prior_dist,
            image_dist,
            reward_dist,
            pred_discount_dist,
            embed,
        ) = algorithm.trainer.world_model(obs.detach(), actions.detach())
        if isinstance(image_dist, tuple):
            image_dist, _ = image_dist
        image_dist_mean = image_dist.mean.detach()
        reconstructions = image_dist_mean[:, :3, :, :]
        reconstructions = (torch.clamp(
            reconstructions.permute(0, 2, 3, 1).reshape(
                4, algorithm.max_path_length, 64, 64, 3) + 0.5,
            0,
            1,
        ) * 255.0)
        reconstructions = ptu.get_numpy(reconstructions).astype(np.uint8)

        obs_np = ptu.get_numpy(obs[:, :, :64 * 64 * 3].reshape(
            4, algorithm.max_path_length, 3, 64,
            64).permute(0, 1, 3, 4, 2)).astype(np.uint8)
        file_path = osp.join(logger.get_snapshot_dir(),
                             mode + "_" + str(epoch) + "_reconstructions.png")
        im = np.zeros((128 * 4, algorithm.max_path_length * 64, 3),
                      dtype=np.uint8)
        for i in range(4):
            for j in range(algorithm.max_path_length):
                im[128 * i:128 * i + 64, 64 * j:64 * (j + 1)] = obs_np[i, j]
                im[128 * i + 64:128 * (i + 1),
                   64 * j:64 * (j + 1)] = reconstructions[i, j]
        cv2.imwrite(file_path, im)
        if image_dist_mean.shape[1] == 6:
            reconstructions = image_dist_mean[:, 3:6, :, :]
            reconstructions = (torch.clamp(
                reconstructions.permute(0, 2, 3, 1).reshape(
                    4, algorithm.max_path_length, 64, 64, 3) + 0.5,
                0,
                1,
            )) * 255.0
            reconstructions = ptu.get_numpy(reconstructions).astype(np.uint8)

            file_path = osp.join(
                logger.get_snapshot_dir(),
                mode + "_" + str(epoch) + "_reconstructions_wrist_cam.png",
            )
            obs_np = ptu.get_numpy(obs[:, :, 64 * 64 * 3:64 * 64 * 6].reshape(
                4, algorithm.max_path_length, 3, 64,
                64).permute(0, 1, 3, 4, 2)).astype(np.uint8)
            im = np.zeros((128 * 4, algorithm.max_path_length * 64, 3),
                          dtype=np.uint8)
            for i in range(4):
                for j in range(algorithm.max_path_length):
                    im[128 * i:128 * i + 64, 64 * j:64 * (j + 1)] = obs_np[i,
                                                                           j]
                    im[128 * i + 64:128 * (i + 1),
                       64 * j:64 * (j + 1)] = reconstructions[i, j]
            cv2.imwrite(file_path, im)
Example #12
    def forward(self, context=None, mask=None, r=None):
        if r is None:
            # hack for now to make things efficient
            min_context_len = min([
                d['observations'].shape[0] for task_trajs in context
                for d in task_trajs
            ])

            obs = np.array(
                [[d['observations'][:min_context_len] for d in task_trajs]
                 for task_trajs in context])
            next_obs = np.array([[
                d['next_observations'][:min_context_len] for d in task_trajs
            ] for task_trajs in context])
            if not self.state_only:
                acts = np.array(
                    [[d['actions'][:min_context_len] for d in task_trajs]
                     for task_trajs in context])
                all_timesteps = np.concatenate([obs, acts, next_obs], axis=-1)
            else:
                all_timesteps = np.concatenate([obs, next_obs], axis=-1)

            # FOR DEBUGGING THE ENCODER
            # all_timesteps = all_timesteps[:,:,-1:,:]

            # print(all_timesteps)
            # print(all_timesteps.shape)
            # print(all_timesteps.dtype)
            # print('----')
            # print(acts.shape)
            # print(obs.shape)
            # print(next_obs.shape)
            # if acts.shape[0] == 10:
            #     print(acts)
            #     print(obs)
            #     print(next_obs)
            # if acts.shape[0]

            all_timesteps = Variable(ptu.from_numpy(all_timesteps),
                                     requires_grad=False)

            # N_tasks x N_trajs x Len x Dim
            N_tasks, N_trajs, Len, Dim = all_timesteps.size(
                0), all_timesteps.size(1), all_timesteps.size(
                    2), all_timesteps.size(3)
            all_timesteps = all_timesteps.view(-1, Dim)
            embeddings = self.timestep_encoder(all_timesteps)
            embeddings = embeddings.view(N_tasks, N_trajs, Len, self.r_dim)

            if self.use_sum_for_traj_agg:
                traj_embeddings = torch.sum(embeddings, dim=2)
            else:
                traj_embeddings = torch.mean(embeddings, dim=2)

            # get r
            if mask is None:
                r = self.agg(traj_embeddings)
            else:
                r = self.agg_masked(traj_embeddings, mask)
        post_mean, post_log_sig_diag = self.r2z_map(r)
        return ReparamMultivariateNormalDiag(post_mean, post_log_sig_diag)
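The aggregation modules (agg, agg_masked) are not shown here. A hypothetical sketch of a masked mean over per-trajectory embeddings, assuming mask marks which trajectories are valid for each task:

import torch

def masked_mean_agg(traj_embeddings, mask):
    # traj_embeddings: (N_tasks, N_trajs, r_dim); mask: (N_tasks, N_trajs), 1 for valid trajectories
    mask = mask.unsqueeze(-1).float()
    summed = (traj_embeddings * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1.0)
    return summed / counts                        # (N_tasks, r_dim)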
Example #13
    def fit(self, data, weights=None):
        if weights is None:
            weights = np.ones(len(data))
        sum_of_weights = weights.flatten().sum()
        weights = weights / sum_of_weights
        all_weights_pt = ptu.from_numpy(weights)

        indexed_train_data = IndexedData(data)
        if self.skew_sampling:
            base_sampler = WeightedRandomSampler(weights, len(weights))
        else:
            base_sampler = RandomSampler(indexed_train_data)

        train_dataloader = DataLoader(
            indexed_train_data,
            sampler=BatchSampler(
                base_sampler,
                batch_size=self.batch_size,
                drop_last=False,
            ),
        )
        if self.reset_vae_every_epoch:
            raise NotImplementedError()

        epoch_stats_list = defaultdict(list)
        for _ in range(self.num_inner_vae_epochs):
            for _, indexed_batch in enumerate(train_dataloader):
                idxs, batch = indexed_batch
                batch = batch[0].float().to(ptu.device)

                latents, means, log_vars, stds = (
                    self.encoder.get_encoding_and_suff_stats(batch))
                beta = 1
                kl = self.kl_to_prior(means, log_vars, stds)
                reconstruction_log_prob = self.compute_log_prob(
                    batch, self.decoder, latents)

                elbo = -kl * beta + reconstruction_log_prob
                if self.weight_loss:
                    idxs = torch.cat(idxs)
                    batch_weights = all_weights_pt[idxs].unsqueeze(1)
                    loss = -(batch_weights * elbo).sum()
                else:
                    loss = -elbo.mean()
                self.encoder_opt.zero_grad()
                self.decoder_opt.zero_grad()
                loss.backward()
                self.encoder_opt.step()
                self.decoder_opt.step()

                epoch_stats_list['losses'].append(ptu.get_numpy(loss))
                epoch_stats_list['kls'].append(ptu.get_numpy(kl.mean()))
                epoch_stats_list['log_probs'].append(
                    ptu.get_numpy(reconstruction_log_prob.mean()))
                epoch_stats_list['latent-mean'].append(
                    ptu.get_numpy(latents.mean()))
                epoch_stats_list['latent-std'].append(
                    ptu.get_numpy(latents.std()))
                for k, v in create_stats_ordered_dict(
                        'weights', ptu.get_numpy(all_weights_pt)).items():
                    epoch_stats_list[k].append(v)

        self._epoch_stats = {
            'unnormalized weight sum': sum_of_weights,
        }
        for k in epoch_stats_list:
            self._epoch_stats[k] = np.mean(epoch_stats_list[k])
Example #14
 def reconstruct(self, data):
     latents = self.encoder.encode(ptu.from_numpy(data))
     return ptu.get_numpy(self.decoder.reconstruct(latents))
Example #15
 def _decode(self, latents):
     reconstructions, _ = self.vae.decode(ptu.from_numpy(latents))
     decoded = ptu.get_numpy(reconstructions)
     return decoded
Example #16
    def __init__(
            self,
            env,
            policy,
            discriminator,
            policy_optimizer,
            expert_replay_buffer,
            disc_optim_batch_size=1024,
            policy_optim_batch_size=1024,
            num_update_loops_per_train_call=1000,
            num_disc_updates_per_loop_iter=1,
            num_policy_updates_per_loop_iter=1,

            # initial_only_disc_train_epochs=0,
            pretrain_disc=False,
            num_disc_pretrain_iters=1000,
            disc_lr=1e-3,
            disc_momentum=0.0,
            disc_optimizer_class=optim.Adam,
            use_grad_pen=True,
            grad_pen_weight=10,
            plotter=None,
            render_eval_paths=False,
            eval_deterministic=False,
            train_objective='airl',
            num_disc_input_dims=2,
            plot_reward_surface=True,
            use_survival_reward=False,
            use_ctrl_cost=False,
            ctrl_cost_weight=0.0,
            **kwargs):
        assert disc_lr != 1e-3, 'Just checking that this is being taken from the spec file'
        if eval_deterministic:
            eval_policy = MakeDeterministic(policy)
        else:
            eval_policy = policy

        super().__init__(env=env,
                         exploration_policy=policy,
                         eval_policy=eval_policy,
                         expert_replay_buffer=expert_replay_buffer,
                         **kwargs)

        self.policy_optimizer = policy_optimizer

        self.discriminator = discriminator
        self.rewardf_eval_statistics = None
        self.disc_optimizer = disc_optimizer_class(
            self.discriminator.parameters(),
            lr=disc_lr,
            betas=(disc_momentum, 0.999))
        print('\n\nDISC MOMENTUM: %f\n\n' % disc_momentum)

        self.disc_optim_batch_size = disc_optim_batch_size
        self.policy_optim_batch_size = policy_optim_batch_size

        assert train_objective in ['airl', 'fairl', 'gail', 'w1']
        self.train_objective = train_objective

        self.bce = nn.BCEWithLogitsLoss()
        target_batch_size = self.disc_optim_batch_size
        self.bce_targets = torch.cat([
            torch.ones(target_batch_size, 1),
            torch.zeros(target_batch_size, 1)
        ],
                                     dim=0)
        self.bce_targets = Variable(self.bce_targets)
        if ptu.gpu_enabled():
            self.bce.cuda()
            self.bce_targets = self.bce_targets.cuda()

        self.use_grad_pen = use_grad_pen
        self.grad_pen_weight = grad_pen_weight

        self.num_update_loops_per_train_call = num_update_loops_per_train_call
        self.num_disc_updates_per_loop_iter = num_disc_updates_per_loop_iter
        self.num_policy_updates_per_loop_iter = num_policy_updates_per_loop_iter

        # self.initial_only_disc_train_epochs = initial_only_disc_train_epochs
        self.pretrain_disc = pretrain_disc
        self.did_disc_pretraining = False
        self.num_disc_pretrain_iters = num_disc_pretrain_iters
        self.cur_epoch = -1

        self.plot_reward_surface = plot_reward_surface
        if plot_reward_surface:
            d = 6
            self._d = d
            self._d_len = np.arange(-d, d + 0.25, 0.25).shape[0]
            self.xy_var = []
            for i in np.arange(d, -d - 0.25, -0.25):
                for j in np.arange(-d, d + 0.25, 0.25):
                    self.xy_var.append([float(j), float(i)])
            self.xy_var = np.array(self.xy_var)
            self.xy_var = Variable(ptu.from_numpy(self.xy_var),
                                   requires_grad=False)

            # d = 20
            # self._d = d
            # # self._d_len = np.arange(0.98697072 - 0.3, 0.98697072 + 0.175, 0.02).shape[0]
            # self._d_len_rows = np.arange(0.74914774 - 0.35, 0.74914774 + 0.45, 0.01).shape[0]
            # self._d_len_cols = np.arange(0.98697072 - 0.3, 0.98697072 + 0.175, 0.01).shape[0]
            # self.xy_var = []
            # for i in np.arange(0.74914774 + 0.45, 0.74914774 - 0.35, -0.01):
            #     for j in np.arange(0.98697072 - 0.3, 0.98697072 + 0.175, 0.01):
            #         self.xy_var.append([float(j), float(i)])
            # self.xy_var = np.array(self.xy_var)
            # self.xy_var = Variable(ptu.from_numpy(self.xy_var), requires_grad=False)

        self.num_disc_input_dims = num_disc_input_dims

        self.use_survival_reward = use_survival_reward
        self.use_ctrl_cost = use_ctrl_cost
        self.ctrl_cost_weight = ctrl_cost_weight
Example #17
 def _encode(self, imgs):
     latent_distribution_params = self.vae.encode(ptu.from_numpy(imgs))
     return ptu.get_numpy(latent_distribution_params[0])
Example #18
    def _do_training(self):
        '''
        '''
        # train the discriminator (and the encoder)
        # print('$$$$$$$$$')
        # print(self.num_disc_updates_per_epoch)
        for i in range(self.num_disc_updates_per_epoch):
            self.encoder_optimizer.zero_grad()
            self.disc_optimizer.zero_grad()

            context_batch, context_pred_batch, test_pred_batch, policy_test_pred_batch, traj_len = self._get_disc_training_batch(
            )

            # convert it to a pytorch tensor
            # note that our objective says we should maximize likelihood of
            # BOTH the context_batch and the test_batch
            exp_obs_batch = np.concatenate((context_pred_batch['observations'],
                                            test_pred_batch['observations']),
                                           axis=0)
            exp_obs_batch = Variable(ptu.from_numpy(exp_obs_batch),
                                     requires_grad=False)
            exp_acts_batch = np.concatenate(
                (context_pred_batch['actions'], test_pred_batch['actions']),
                axis=0)
            exp_acts_batch = Variable(ptu.from_numpy(exp_acts_batch),
                                      requires_grad=False)

            policy_obs_batch = Variable(ptu.from_numpy(
                policy_test_pred_batch['observations']),
                                        requires_grad=False)
            policy_acts_batch = Variable(ptu.from_numpy(
                policy_test_pred_batch['actions']),
                                         requires_grad=False)

            post_dist = self.encoder(context_batch)
            # z = post_dist.sample() # N_tasks x Dim
            z = post_dist.mean

            # z_reg_loss = 0.0001 * z.norm(2, dim=1).mean()
            z_reg_loss = 0.0

            # make z's for expert samples
            context_pred_z = z.repeat(
                1, traj_len * self.num_context_trajs_for_training).view(
                    -1, z.size(1))
            test_pred_z = z.repeat(1, traj_len *
                                   self.num_test_trajs_for_training).view(
                                       -1, z.size(1))
            z_batch = torch.cat([context_pred_z, test_pred_z], dim=0)
            positive_obs_batch = torch.cat([exp_obs_batch, z_batch], dim=1)
            positive_acts_batch = exp_acts_batch

            # make z's for policy samples
            z_policy = z_batch
            negative_obs_batch = torch.cat([policy_obs_batch, z_policy], dim=1)
            negative_acts_batch = policy_acts_batch

            # compute the loss for the discriminator
            obs_batch = torch.cat([positive_obs_batch, negative_obs_batch],
                                  dim=0)
            acts_batch = torch.cat([positive_acts_batch, negative_acts_batch],
                                   dim=0)
            disc_logits = self.disc(obs_batch, acts_batch)
            disc_preds = (disc_logits > 0).type(torch.FloatTensor)
            # disc_percent_policy_preds_one = disc_preds[z.size(0):].mean()
            disc_loss = self.bce(disc_logits, self.bce_targets)
            accuracy = (disc_preds == self.bce_targets).type(
                torch.FloatTensor).mean()

            if self.use_grad_pen:
                eps = Variable(torch.rand(positive_obs_batch.size(0), 1),
                               requires_grad=True)
                if ptu.gpu_enabled(): eps = eps.cuda()

                # old and probably has a bad weird effect on the encoder
                # difference is that before I was also taking into account norm of grad of disc
                # wrt the z
                # interp_obs = eps*positive_obs_batch + (1-eps)*negative_obs_batch

                # permute the exp_obs_batch (not just within a single traj, but overall)
                # This is actually a really tricky question how to permute the batches
                # 1) permute within each of trajectories
                # z's will be matched, colors won't be matched anyways
                # 2) permute within trajectories corresponding to a single context set
                # z's will be matched, colors will be "more unmatched"
                # 3) just shuffle everything up
                # Also, the z's need to be handled appropriately

                interp_obs = eps * exp_obs_batch + (1 - eps) * policy_obs_batch
                # interp_z = z_batch.detach()
                # interp_obs = torch.cat([interp_obs, interp_z], dim=1)
                interp_obs.detach()
                # interp_obs.requires_grad = True
                interp_actions = eps * positive_acts_batch + (
                    1 - eps) * negative_acts_batch
                interp_actions.detach()
                # interp_actions.requires_grad = True
                gradients = autograd.grad(
                    outputs=self.disc(
                        torch.cat([interp_obs, z_batch.detach()], dim=1),
                        interp_actions).sum(),
                    inputs=[interp_obs, interp_actions],
                    # grad_outputs=torch.ones(exp_specs['batch_size'], 1).cuda(),
                    create_graph=True,
                    retain_graph=True,
                    only_inputs=True)
                # print(gradients[0].size())
                # z_norm = gradients[0][:,-50:].norm(2, dim=1)
                # print('Z grad norm: %.4f +/- %.4f' % (torch.mean(z_norm), torch.std(z_norm)))
                # print(gradients[0][:,-50:].size())
                # o_norm = gradients[0][:,:-50].norm(2, dim=1)

                # o_norm = gradients[0].norm(2, dim=1)
                # print('Obs grad norm: %.4f +/- %.4f' % (torch.mean(o_norm), torch.std(o_norm)))
                # print(gradients[0].size())

                # print(gradients[0][:,:50].norm(2, dim=1))

                total_grad = torch.cat([gradients[0], gradients[1]], dim=1)
                # print(total_grad.size())
                gradient_penalty = ((total_grad.norm(2, dim=1) - 1)**2).mean()

                # another form of grad pen
                # gradient_penalty = (total_grad.norm(2, dim=1) ** 2).mean()

                disc_loss = disc_loss + gradient_penalty * self.grad_pen_weight

            total_reward_loss = z_reg_loss + disc_loss

            total_reward_loss.backward()
            self.disc_optimizer.step()
            self.encoder_optimizer.step()

            # print(self.disc.fc0.bias[0])
            # print(self.encoder.traj_encoder.traj_enc_mlp.fc0.bias[0])

        # train the policy
        # print('--------')
        # print(self.num_policy_updates_per_epoch)
        for i in range(self.num_policy_updates_per_epoch):
            context_batch, policy_batch = self._get_policy_training_batch()
            policy_batch = np_to_pytorch_batch(policy_batch)

            post_dist = self.encoder(context_batch)
            # z = post_dist.sample() # N_tasks x Dim
            z = post_dist.mean
            z = z.detach()

            # repeat z to have the right size
            z = z.repeat(1, self.policy_batch_size_per_task).view(
                self.num_tasks_used_per_update *
                self.policy_batch_size_per_task, -1).detach()

            # now augment the obs with the latent sample z
            policy_batch['observations'] = torch.cat(
                [policy_batch['observations'], z], dim=1)
            policy_batch['next_observations'] = torch.cat(
                [policy_batch['next_observations'], z], dim=1)

            # compute the rewards
            # If you compute log(D) - log(1-D) then you just get the logits
            policy_rewards = self.disc(policy_batch['observations'],
                                       policy_batch['actions']).detach()
            policy_batch['rewards'] = policy_rewards
            # rew_more_than_zero = (rewards > 0).type(torch.FloatTensor).mean()
            # print(rew_more_than_zero.data[0])

            # do a policy update (the zeroing of grads etc. should be handled internally)
            # print(policy_rewards.size())
            self.policy_optimizer.train_step(policy_batch)
            # print(self.main_policy.fc0.bias[0])

        if self.eval_statistics is None:
            """
            Eval should set this to None.
            This way, these statistics are only computed for one batch.
            """
            self.eval_statistics = OrderedDict()
            self.eval_statistics['Disc Loss'] = np.mean(
                ptu.get_numpy(disc_loss))
            self.eval_statistics['Disc Acc'] = np.mean(ptu.get_numpy(accuracy))
            # self.eval_statistics['Disc Percent Policy Preds 1'] = np.mean(ptu.get_numpy(disc_percent_policy_preds_one))
            self.eval_statistics['Disc Rewards Mean'] = np.mean(
                ptu.get_numpy(policy_rewards))
            self.eval_statistics['Disc Rewards Std'] = np.std(
                ptu.get_numpy(policy_rewards))
            self.eval_statistics['Disc Rewards Max'] = np.max(
                ptu.get_numpy(policy_rewards))
            self.eval_statistics['Disc Rewards Min'] = np.min(
                ptu.get_numpy(policy_rewards))
            # self.eval_statistics['Disc Rewards GT Zero'] = np.mean(ptu.get_numpy(rew_more_than_zero))

            z_norm = z.norm(2, dim=1).mean()
            self.eval_statistics['Z Norm'] = np.mean(ptu.get_numpy(z_norm))

            if self.policy_optimizer.eval_statistics is not None:
                self.eval_statistics.update(
                    self.policy_optimizer.eval_statistics)
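The use_grad_pen branch above is a WGAN-GP style gradient penalty on interpolated expert/policy samples. A minimal stand-alone sketch of that penalty, assuming disc(obs, acts) returns logits and leaving out the latent-z handling:

import torch
from torch import autograd

def gradient_penalty(disc, expert_obs, policy_obs, expert_acts, policy_acts):
    eps = torch.rand(expert_obs.size(0), 1, device=expert_obs.device)
    interp_obs = (eps * expert_obs + (1 - eps) * policy_obs).detach().requires_grad_(True)
    interp_acts = (eps * expert_acts + (1 - eps) * policy_acts).detach().requires_grad_(True)
    grads = autograd.grad(
        outputs=disc(interp_obs, interp_acts).sum(),
        inputs=[interp_obs, interp_acts],
        create_graph=True,
        retain_graph=True,
    )
    total_grad = torch.cat(grads, dim=1)
    # penalize deviation of the interpolants' gradient norm from 1
    return ((total_grad.norm(2, dim=1) - 1) ** 2).mean()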
Example #19
    def __init__(
            self,
            input_width,
            input_height,
            input_channels,
            output_size,
            kernel_sizes,
            n_channels,
            strides,
            paddings,
            hidden_sizes=None,
            added_fc_input_size=0,
            batch_norm_conv=False,
            batch_norm_fc=False,
            init_w=1e-4,
            hidden_init=nn.init.xavier_uniform_,
            hidden_activation=nn.ReLU(),
            output_activation=identity,
    ):
        if hidden_sizes is None:
            hidden_sizes = []
        assert len(kernel_sizes) == \
               len(n_channels) == \
               len(strides) == \
               len(paddings)
        super().__init__()

        self.hidden_sizes = hidden_sizes
        self.input_width = input_width
        self.input_height = input_height
        self.input_channels = input_channels
        self.output_size = output_size
        self.output_activation = output_activation
        self.hidden_activation = hidden_activation
        self.batch_norm_conv = batch_norm_conv
        self.batch_norm_fc = batch_norm_fc
        self.added_fc_input_size = added_fc_input_size
        self.conv_input_length = self.input_width * self.input_height * self.input_channels

        self.conv_layers = nn.ModuleList()
        self.conv_norm_layers = nn.ModuleList()
        self.fc_layers = nn.ModuleList()
        self.fc_norm_layers = nn.ModuleList()

        for out_channels, kernel_size, stride, padding in \
                zip(n_channels, kernel_sizes, strides, paddings):
            conv = nn.Conv2d(input_channels,
                             out_channels,
                             kernel_size,
                             stride=stride,
                             padding=padding)
            hidden_init(conv.weight)
            conv.bias.data.fill_(0)

            conv_layer = conv
            self.conv_layers.append(conv_layer)
            input_channels = out_channels

        xcoords = np.expand_dims(np.linspace(-1, 1, self.input_width), 0).repeat(self.input_height, 0)
        ycoords = np.repeat(np.linspace(-1, 1, self.input_height), self.input_width).reshape((self.input_height, self.input_width))

        self.coords = from_numpy(np.expand_dims(np.stack([xcoords, ycoords], 0), 0))
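The coords buffer above suggests CoordConv-style inputs, where normalized x/y coordinate channels are appended to the image before the convolution stack. A generic sketch of that idea (hypothetical; the forward pass of this class is not shown):

import torch

def append_coord_channels(img_batch):
    # img_batch: (B, C, H, W) -> (B, C + 2, H, W) with x/y coordinates in [-1, 1]
    b, _, h, w = img_batch.shape
    xs = torch.linspace(-1, 1, w).view(1, 1, 1, w).expand(b, 1, h, w)
    ys = torch.linspace(-1, 1, h).view(1, 1, h, 1).expand(b, 1, h, w)
    return torch.cat([img_batch, xs.to(img_batch), ys.to(img_batch)], dim=1)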
Example #20
def test_vae_traj(vae, env_id, save_path=None, save_name=None):
    pjhome = os.environ['PJHOME']

    # optimal traj
    data_path = osp.join(pjhome,
                         'data/local/env/{}-optimal-traj.npy'.format(env_id))
    if not osp.exists(data_path):
        return
    imgs = np.load(data_path)
    traj_len, batch_size, imlen = imgs.shape
    imgs = imgs.reshape((-1, imlen))
    latents, _ = vae.encode(ptu.from_numpy(imgs))
    latents = ptu.get_numpy(latents)
    latent_distances = np.linalg.norm(latents - latents[-1], axis=1)
    puck_distances = np.load(
        osp.join(
            pjhome,
            'data/local/env/{}-optimal-traj-puck-distance.npy'.format(env_id)))

    fig, axs = plt.subplots(1, 2, figsize=(14, 5))
    axs = axs.reshape(-1)

    ax = axs[0]
    ax2 = ax.twinx()
    ax.plot(puck_distances, label='puck distance', color='r')
    ax2.plot(latent_distances, label='vae distance', color='b')
    ax.legend(loc='upper right')
    ax2.legend(loc='center right')

    # sub-optimal
    imgs = np.load(
        osp.join(pjhome,
                 'data/local/env/{}-local-optimal-traj.npy'.format(env_id)))
    goal_image = np.load(
        osp.join(
            pjhome,
            'data/local/env/{}-local-optimal-traj-goal.npy'.format(env_id)))
    traj_len, batch_size, imlen = imgs.shape
    puck_distances = np.load(
        osp.join(
            pjhome,
            'data/local/env/{}-local-optimal-traj-puck-distance.npy'.format(
                env_id)))

    imgs = imgs.reshape((-1, imlen))
    goal_image = goal_image.reshape((-1, imlen))
    latents = vae.encode(ptu.from_numpy(imgs))[0]
    latent_goal = vae.encode(ptu.from_numpy(goal_image))[0]

    latents = ptu.get_numpy(latents).reshape(traj_len, -1)
    latent_goal = ptu.get_numpy(latent_goal).flatten()
    latent_distances = np.linalg.norm(latents - latent_goal, axis=1)

    ax = axs[1]
    ax2 = ax.twinx()
    ax.plot(puck_distances, label='puck distance', color='r')
    ax2.plot(latent_distances, label='vae distance', color='b')
    ax.legend(loc='upper right')
    ax2.legend(loc='center right')

    plt.savefig(osp.join(save_path, save_name))
    plt.close('all')
Example #21
 def _reconstruction_squared_error_np_to_np(self, np_imgs):
     torch_input = ptu.from_numpy(normalize_image(np_imgs))
     recons, *_ = self.model(torch_input)
     error = torch_input - recons
     return ptu.get_numpy((error**2).sum(dim=1))
Example #22
 def _decode(self, latents, vae=None):
     if vae is None:
         vae = self.vae_original
     reconstructions, _ = vae.decode(ptu.from_numpy(latents))
     decoded = ptu.get_numpy(reconstructions)
     return decoded
Example #23
 def set_param_values_np(self, param_values):
     torch_dict = OrderedDict()
     for key, tensor in param_values.items():
         torch_dict[key] = ptu.from_numpy(tensor)
     self.load_state_dict(torch_dict)
Example #24
 def _encode(self, imgs, vae=None):
     if vae is None:
         vae = self.vae_original
     latent_distribution_params = vae.encode(ptu.from_numpy(imgs))
     return ptu.get_numpy(latent_distribution_params[0])
Example #25
 def v_function(obs):
     action = policy.get_actions(obs)
     obs, action = ptu.from_numpy(obs), ptu.from_numpy(action)
     return qf1(obs, action, return_individual_q_vals=True)
Example #26
    def __init__(
        self,
        env,
        policy,
        qf1,
        qf2,
        target_qf1,
        target_qf2,
        bonus_network,
        beta,
        use_bonus_critic,
        use_bonus_policy,
        use_log,
        bonus_norm_param,
        rewards_shift_param,
        device,
        discount=0.99,
        reward_scale=1.0,
        policy_lr=1e-3,
        qf_lr=1e-3,
        alpha_lr=3e-5,
        optimizer_class=optim.Adam,
        soft_target_tau=1e-2,
        target_update_period=1,
        plotter=None,
        render_eval_paths=False,
        use_automatic_entropy_tuning=True,
        target_entropy=None,
    ):
        super().__init__()
        self.env = env
        self.policy = policy
        self.qf1 = qf1
        self.qf2 = qf2
        self.target_qf1 = target_qf1
        self.target_qf2 = target_qf2

        self.device = device

        self.bonus_network = bonus_network
        self.beta = beta

        # type of adding bonus to critic or policy
        self.use_bonus_critic = use_bonus_critic
        self.use_bonus_policy = use_bonus_policy

        # use log in the bonus
        # if use_log : log(bonus)
        # else 1 - bonus
        self.use_log = use_log

        # normalization
        self.obs_mu, self.obs_std = bonus_norm_param
        self.normalize = self.obs_mu is not None

        if self.normalize:
            print('.......Using normalization in bonus........')
            self.obs_mu = ptu.from_numpy(self.obs_mu).to(device)
            self.obs_std = ptu.from_numpy(self.obs_std).to(device)
            # self.actions_mu = ptu.from_numpy(self.actions_mu).to(device)
            # self.actions_std = ptu.from_numpy(self.actions_std).to(device)

        self.rewards_shift_param = rewards_shift_param

        self.soft_target_tau = soft_target_tau
        self.target_update_period = target_update_period

        self.use_automatic_entropy_tuning = use_automatic_entropy_tuning
        if self.use_automatic_entropy_tuning:
            if target_entropy:
                self.target_entropy = target_entropy
            else:
                self.target_entropy = -np.prod(
                    self.env.action_space.shape).item(
                    )  # heuristic value from Tuomas
            self.log_alpha = ptu.zeros(1, requires_grad=True)
            self.alpha_optimizer = optimizer_class(
                [self.log_alpha],
                lr=alpha_lr,
            )

        self.plotter = plotter
        self.render_eval_paths = render_eval_paths

        self.qf_criterion = nn.MSELoss()
        self.vf_criterion = nn.MSELoss()

        self.policy_optimizer = optimizer_class(
            self.policy.parameters(),
            lr=policy_lr,
        )
        self.qf1_optimizer = optimizer_class(
            self.qf1.parameters(),
            lr=qf_lr,
        )
        self.qf2_optimizer = optimizer_class(
            self.qf2.parameters(),
            lr=qf_lr,
        )

        self.discrete = False
        self.discount = discount
        self.reward_scale = reward_scale
        self.eval_statistics = OrderedDict()
        self._n_train_steps_total = 0
        self._need_to_update_eval_statistics = True
        self.clip_val = 1e3
Example #27
def _elem_or_tuple_to_variable(elem_or_tuple):
    if isinstance(elem_or_tuple, tuple):
        return tuple(_elem_or_tuple_to_variable(e) for e in elem_or_tuple)
    return Variable(ptu.from_numpy(elem_or_tuple).float(), requires_grad=False)
Example #28
    def train(self):
        for t in range(self.num_updates):
            self.mus = Variable(ptu.from_numpy(self.np_mus), requires_grad=True)
            self.log_sigmas = Variable(ptu.from_numpy(self.np_log_sigmas), requires_grad=True)

            # generate rollouts for each task
            normal = ReparamMultivariateNormalDiag(self.mus, self.log_sigmas)
            all_samples = []
            all_log_probs = []
            all_rewards = []
            for sample_num in range(self.num_z_samples_per_update):
                sample = normal.sample()
                all_samples.append(sample)
                all_log_probs.append(normal.log_prob(sample))

                # evaluate each of the z's for each task
                rewards = []
                for i in range(len(self.obs_task_params)):
                    successes = self.gen_rollout(
                        self.obs_task_params[i],
                        self.task_params[i],
                        sample[i],
                        self.num_trajs_per_z_sample
                    )
                    rewards.append(np.mean(successes))
                all_rewards.append(rewards)
            
            all_log_probs = torch.cat(all_log_probs, dim=-1) # num_tasks x num_samples

            np_all_rewards = np.array(all_rewards).T # num_tasks x num_samples
            all_rewards = Variable(ptu.from_numpy(np_all_rewards))
            all_rewards = all_rewards - torch.mean(all_rewards, dim=-1, keepdim=True)
            all_rewards = self.reward_scale * all_rewards

            # compute gradients wrt mus and sigmas
            pg_loss = -1.0 * torch.sum(
                torch.mean(all_log_probs * all_rewards, dim=-1)
            )
            grads = autograd.grad(
                outputs=pg_loss,
                inputs=[self.mus, self.log_sigmas],
                only_inputs=True
            )

            # update the mus and sigmas
            if self.use_nat_grad:
                print('Nat Grad')
                mu_grad = grads[0] * (torch.exp(2 * self.log_sigmas)).detach()
                log_sig_grad = 0.5 * grads[1]
            else:
                print('Normal Grad')
                mu_grad = grads[0]
                log_sig_grad = grads[1]
            self.mus = self.mus - self.mu_lr * mu_grad
            self.log_sigmas = self.log_sigmas - self.log_sig_lr * log_sig_grad
            self.np_mus = ptu.get_numpy(self.mus)
            self.np_log_sigmas = ptu.get_numpy(self.log_sigmas)

            # logging
            np_all_rewards = np.mean(np_all_rewards, axis=-1)
            print('\n-----------------------------------------------')
            # print('Avg Reward: {}'.format(np.mean(np_all_rewards)))
            # print('Std Reward: {}'.format(np.std(np_all_rewards)))
            # print('Max Reward: {}'.format(np.max(np_all_rewards)))
            # print('Min Reward: {}'.format(np.min(np_all_rewards)))
            print(np_all_rewards)
            print(self.np_mus)
            print(np.exp(2*self.np_log_sigmas))
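The natural-gradient branch above rescales the policy gradient by the inverse Fisher information of a diagonal Gaussian N(mu, sigma^2) parameterized by (mu, log sigma). Since the Fisher is 1/sigma^2 for mu and 2 for log sigma, the natural gradient is

\[
\tilde{\nabla}_{\mu} = \sigma^{2}\,\nabla_{\mu} = e^{2\log\sigma}\,\nabla_{\mu},
\qquad
\tilde{\nabla}_{\log\sigma} = \tfrac{1}{2}\,\nabla_{\log\sigma},
\]

which matches mu_grad = grads[0] * exp(2 * log_sigmas) and log_sig_grad = 0.5 * grads[1].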
Example #29
 def decode_np(self, latents):
     reconstructions = self.decode(ptu.from_numpy(latents))
     decoded = ptu.get_numpy(reconstructions)
     return decoded
Example #30
    def __init__(
        self,
        discriminator,

        exp_data,
        pol_data,

        disc_optim_batch_size=1024,
        num_update_loops_per_train_call=1,
        num_disc_updates_per_loop_iter=1,

        disc_lr=1e-3,
        disc_momentum=0.0,
        disc_optimizer_class=optim.Adam,

        use_grad_pen=True,
        grad_pen_weight=10,

        train_objective='airl',
    ):
        assert disc_lr != 1e-3, 'Just checking that this is being taken from the spec file'
        
        self.exp_data, self.pol_data = exp_data, pol_data

        self.discriminator = discriminator
        self.rewardf_eval_statistics = None
        self.disc_optimizer = disc_optimizer_class(
            self.discriminator.parameters(),
            lr=disc_lr,
            betas=(disc_momentum, 0.999)
        )
        print('\n\nDISC MOMENTUM: %f\n\n' % disc_momentum)

        self.disc_optim_batch_size = disc_optim_batch_size

        assert train_objective in ['airl', 'fairl', 'gail', 'w1']
        self.train_objective = train_objective

        self.bce = nn.BCEWithLogitsLoss()
        target_batch_size = self.disc_optim_batch_size
        self.bce_targets = torch.cat(
            [
                torch.zeros(target_batch_size),
                torch.ones(target_batch_size),
                2*torch.ones(target_batch_size),
            ],
            dim=0
        ).type(torch.LongTensor)
        self.bce_targets = Variable(self.bce_targets)
        if ptu.gpu_enabled():
            self.bce.cuda()
            self.bce_targets = self.bce_targets.cuda()
        
        self.use_grad_pen = use_grad_pen
        self.grad_pen_weight = grad_pen_weight

        self.num_update_loops_per_train_call = num_update_loops_per_train_call
        self.num_disc_updates_per_loop_iter = num_disc_updates_per_loop_iter

        d = 5.0
        self._d = d
        self._d_len = np.arange(-d,d+0.25,0.25).shape[0]
        self.xy_var = []
        for i in np.arange(d,-d-0.25,-0.25):
            for j in np.arange(-d,d+0.25,0.25):
                self.xy_var.append([float(j),float(i)])
        self.xy_var = np.array(self.xy_var)
        self.xy_var = Variable(ptu.from_numpy(self.xy_var), requires_grad=False)
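The xy_var grid built above is presumably fed to the discriminator to plot a reward surface over a 2-D slice of the state space. A hypothetical evaluation sketch, assuming the discriminator maps each 2-D point to a single score (its exact signature is not shown here):

import torch

def eval_reward_surface(disc, xy_var, d_len):
    with torch.no_grad():
        scores = disc(xy_var)              # assumed shape: (d_len * d_len, 1)
    return scores.view(d_len, d_len).cpu().numpy()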