Example #1
 def forward(self, obs, action):
     """
     :param obs: torch Variable, [batch_size, sequence length, obs dim]
     :param action: torch Variable, [batch_size, sequence length, action dim]
     :return: torch Variable, [batch_size, sequence length, 1]
     """
     assert len(obs.size()) == 3
     inputs = torch.cat((obs, action), dim=2)
     batch_size, subsequence_length = obs.size()[:2]
     cx = Variable(
         ptu.FloatTensor(1, batch_size, self.hidden_size)
     )
     cx.data.fill_(0)
     hx = Variable(
         ptu.FloatTensor(1, batch_size, self.hidden_size)
     )
     hx.data.fill_(0)
     rnn_outputs, _ = self.lstm(inputs, (hx, cx))
     rnn_outputs = rnn_outputs.contiguous()  # .contiguous() is not in-place
     rnn_outputs_flat = rnn_outputs.view(-1, self.hidden_size)
     obs_flat = obs.view(-1, self.obs_dim)
     action_flat = action.view(-1, self.action_dim)
     h = torch.cat((rnn_outputs_flat, obs_flat), dim=1)
     h = F.relu(self.fc1(h))
     h = torch.cat((h, action_flat), dim=1)
     h = F.relu(self.fc2(h))
     outputs_flat = self.last_fc(h)
     return outputs_flat.view(batch_size, subsequence_length, 1)
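For comparison, the same recurrent Q-function pattern can be written without the deprecated Variable API. A minimal self-contained sketch (the class name and dimensions below are illustrative, not railrl's):

import torch
import torch.nn as nn
import torch.nn.functional as F

class RecurrentQ(nn.Module):
    def __init__(self, obs_dim, action_dim, hidden_size):
        super().__init__()
        self.lstm = nn.LSTM(obs_dim + action_dim, hidden_size, batch_first=True)
        self.fc1 = nn.Linear(hidden_size + obs_dim, hidden_size)
        self.fc2 = nn.Linear(hidden_size + action_dim, hidden_size)
        self.last_fc = nn.Linear(hidden_size, 1)

    def forward(self, obs, action):
        batch_size, seq_len, _ = obs.shape
        # nn.LSTM defaults (h0, c0) to zeros when no initial state is passed.
        rnn_outputs, _ = self.lstm(torch.cat((obs, action), dim=2))
        h = torch.cat((rnn_outputs.reshape(batch_size * seq_len, -1),
                       obs.reshape(batch_size * seq_len, -1)), dim=1)
        h = F.relu(self.fc1(h))
        h = torch.cat((h, action.reshape(batch_size * seq_len, -1)), dim=1)
        h = F.relu(self.fc2(h))
        return self.last_fc(h).view(batch_size, seq_len, 1)

q = RecurrentQ(obs_dim=4, action_dim=2, hidden_size=8)
print(q(torch.randn(3, 5, 4), torch.randn(3, 5, 2)).shape)  # torch.Size([3, 5, 1])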
Example #2
 def forward(self, obs, memory, action, write):
     """
     :param obs: torch Variable, [batch_size, sequence length, obs dim]
     :param memory: torch Variable, [batch_size, sequence length, memory dim]
     :param action: torch Variable, [batch_size, sequence length, action dim]
     :param write: torch Variable, [batch_size, sequence length, memory dim]
     :return: torch Variable, [batch_size, sequence length, 1]
     """
     rnn_inputs = torch.cat((obs, memory, action, write), dim=2)
     batch_size, subsequence_length, _ = obs.size()
     cx = Variable(
         ptu.FloatTensor(1, batch_size, self.hidden_size)
     )
     cx.data.fill_(0)
     hx = Variable(
         ptu.FloatTensor(1, batch_size, self.hidden_size)
     )
     hx.data.fill_(0)
     state = (hx, cx)
     rnn_outputs, _ = self.rnn(rnn_inputs, state)
     rnn_outputs = rnn_outputs.contiguous()  # .contiguous() is not in-place
     rnn_outputs_flat = rnn_outputs.view(
         batch_size * subsequence_length,
         self.fc1.in_features,
     )
     outputs_flat = self.output_activation(self.last_fc(rnn_outputs_flat))
     return outputs_flat.view(batch_size, subsequence_length, 1)
Example #3
    def forward(
        self,
        obs,
        deterministic=False,
        return_log_prob=False,
        return_entropy=False,
        return_log_prob_of_mean=False,
    ):
        obs, taus = split_tau(obs)
        h = obs
        batch_size = h.size()[0]
        # One-hot encode the (clamped) tau index along dim 1.
        y_binary = ptu.FloatTensor(batch_size, self.max_tau + 1)
        y_binary.zero_()
        t = taus.data.long()
        t = torch.clamp(t, min=0)
        y_binary.scatter_(1, t, 1)

        h = torch.cat((
            obs,
            ptu.Variable(y_binary),
        ), dim=1)

        return super().forward(
            obs=h,
            deterministic=deterministic,
            return_log_prob=return_log_prob,
            return_entropy=return_entropy,
            return_log_prob_of_mean=return_log_prob_of_mean,
        )
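The tau one-hot construction above, as a standalone sketch (tensor values are illustrative):

import torch

taus = torch.tensor([[0.], [2.], [5.]])  # [batch_size, 1], as produced by split_tau
max_tau = 5
y_binary = torch.zeros(taus.size(0), max_tau + 1)
# scatter_ writes a 1 at column tau for each row, i.e. one one-hot row per sample.
y_binary.scatter_(1, taus.long().clamp(min=0), 1)
print(y_binary)
# tensor([[1., 0., 0., 0., 0., 0.],
#         [0., 0., 1., 0., 0., 0.],
#         [0., 0., 0., 0., 0., 1.]])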
Example #4
    def forward(self, flat_obs, actions=None):
        obs, taus = split_tau(flat_obs)
        if actions is not None:
            h = torch.cat((obs, actions), dim=1)
        else:
            h = obs
        batch_size = h.size()[0]
        y_binary = ptu.FloatTensor(batch_size, self.max_tau + 1)
        y_binary.zero_()
        t = taus.data.long()
        t = torch.clamp(t, min=0)
        y_binary.scatter_(1, t, 1)
        if actions is not None:
            h = torch.cat((
                obs,
                ptu.Variable(y_binary),
                actions
            ), dim=1)
        else:
            h = torch.cat((
                obs,
                ptu.Variable(y_binary),
            ), dim=1)

        for i, fc in enumerate(self.fcs):
            h = self.hidden_activation(fc(h))
        # Constrain the output to be non-positive.
        return -torch.abs(self.last_fc(h))
Example #5
def dump_latent_plots(vae_env, epoch):
    from railrl.core import logger
    import os.path as osp
    from torchvision.utils import save_image

    if getattr(vae_env, "get_states_sweep", None) is None:
        return

    nx, ny = (vae_env.vis_granularity, vae_env.vis_granularity)
    states_sweep = vae_env.get_states_sweep(nx, ny)
    sweep_latents_mu, sweep_latents_logvar = vae_env.encode_states(states_sweep, clip_std=False)

    sweep_latents_std = np.exp(0.5*sweep_latents_logvar)
    sweep_latents_sample = vae_env.reparameterize(sweep_latents_mu, sweep_latents_logvar, noisy=True)
    images_mu_sc, images_std_sc, images_sample_sc = [], [], []
    imsize = 84
    for i in range(sweep_latents_mu.shape[1]):
        image_mu_sc = vae_env.transform_image(vae_env.get_image_plt(
            sweep_latents_mu[:,i].reshape((nx, ny)),
            vmin=-2.0, vmax=2.0,
            draw_state=False, imsize=imsize))
        images_mu_sc.append(image_mu_sc)

        image_std_sc = vae_env.transform_image(vae_env.get_image_plt(
            sweep_latents_std[:,i].reshape((nx, ny)),
            vmin=0.0, vmax=2.0,
            draw_state=False, imsize=imsize))
        images_std_sc.append(image_std_sc)

        image_sample_sc = vae_env.transform_image(vae_env.get_image_plt(
            sweep_latents_sample[:,i].reshape((nx, ny)),
            vmin=-3.0, vmax=3.0,
            draw_state=False, imsize=imsize))
        images_sample_sc.append(image_sample_sc)

    images = images_mu_sc + images_std_sc + images_sample_sc
    images = np.array(images)

    if vae_env.representation_size > 16:
        nrow = 16
    else:
        nrow = vae_env.representation_size

    if epoch is not None:
        save_dir = osp.join(logger.get_snapshot_dir(), 'z_%d.png' % epoch)
    else:
        save_dir = osp.join(logger.get_snapshot_dir(), 'z.png')
    save_image(
        ptu.FloatTensor(
            ptu.from_numpy(
                images.reshape(
                    (vae_env.representation_size*3, -1, imsize, imsize)
                ))),
        save_dir,
        nrow=nrow,
    )
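torchvision's save_image expects a float tensor of shape [N, C, H, W] with values in [0, 1]; the double ptu conversion above can be collapsed into a single torch.from_numpy call. A sketch with random data standing in for the sweep images:

import numpy as np
import torch
from torchvision.utils import save_image

images = np.random.rand(12, 3, 84, 84).astype(np.float32)
save_image(torch.from_numpy(images), 'z.png', nrow=4)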
Example #6
    def forward(
            self,
            flat_obs,
            return_preactivations=False
    ):
        obs, taus = split_tau(flat_obs)
        h = obs
        batch_size = h.size()[0]
        y_binary = ptu.FloatTensor(batch_size, self.max_tau + 1)
        y_binary.zero_()
        t = taus.data.long()
        t = torch.clamp(t, min=0)
        y_binary.scatter_(1, t, 1)

        h = torch.cat((
            obs,
            ptu.Variable(y_binary),
        ), dim=1)

        return super().forward(
            h,
            return_preactivations=return_preactivations,
        )
Example #7
def dump_latent_histogram(vae_env, epoch, noisy=False, reproj=False, use_true_prior=None, draw_dots=False):
    from railrl.core import logger
    import os.path as osp
    from torchvision.utils import save_image

    images = vae_env.get_image_latent_histogram(
        noisy=noisy, reproj=reproj, draw_dots=draw_dots, use_true_prior=use_true_prior
    )
    if noisy:
        prefix = 'h'
    elif reproj:
        prefix = 'h_r'
    else:
        prefix = 'h_mu'

    if epoch is None:
        save_dir = osp.join(logger.get_snapshot_dir(), prefix + '.png')
    else:
        save_dir = osp.join(logger.get_snapshot_dir(), prefix + '_%d.png' % epoch)
    save_image(
        ptu.FloatTensor(ptu.from_numpy(images)),
        save_dir,
        nrow=int(np.sqrt(images.shape[0])),
    )
Example #8
 def eval_model_np(state, action):
     state = ptu.Variable(ptu.FloatTensor([[state]]), requires_grad=False)
     action = ptu.Variable(ptu.FloatTensor([[action]]), requires_grad=False)
     a, v = model(state, action)
     q = a + v
     return ptu.get_numpy(q)[0]
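A modern equivalent without Variable, assuming (as above) that the model returns an (advantage, value) pair for a dueling-style Q-function:

import torch

def eval_model_np(model, state, action):
    with torch.no_grad():
        s = torch.tensor([[state]], dtype=torch.float32)
        a = torch.tensor([[action]], dtype=torch.float32)
        adv, v = model(s, a)
        return (adv + v).numpy()[0]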
Example #9
def dump_video(
    env,
    policy,
    filename,
    rollout_function,
    qf=None,
    vf=None,
    rows=3,
    columns=6,
    pad_length=0,
    pad_color=255,
    do_timer=True,
    imsize=84,
    epoch=None,
    vis_list=None,
    vis_blacklist=None,
):
    if vis_list is None:
        vis_list = [
            'image_desired_goal',
            'image_observation',
            'reconstr_image_observation',
            'reconstr_image_reproj_observation',
            'image_desired_subgoal',
            'image_desired_subgoal_reproj',
            'image_plt',
            'image_latent_histogram_2d',
            'image_latent_histogram_mu_2d',
            'image_v_latent',
            'image_v',
            'image_v_noisy_state_and_goal',
            'image_v_noisy_state',
            'image_v_noisy_goal',
            'image_rew',
            'image_rew_euclidean',
            'image_rew_mahalanobis',
            'image_rew_logp',
            'image_rew_kl',
            'image_rew_kl_rev',
        ]

    if vis_blacklist is not None:
        vis_list = [x for x in vis_list if x not in vis_blacklist]

    num_channels = 1 if env.grayscale else 3
    frames = []
    N = rows * columns

    subgoal_images = []

    for i in range(N):
        start = time.time()
        path = rollout_function(
            env,
            policy,
            qf=qf,
            vf=vf,
            animated=False,
            epoch=epoch,
            rollout_num=i,
        )

        if 'image_desired_subgoals_reproj' in path['full_observations'][1]:
            image_ob = path['full_observations'][1][
                'image_observation'].reshape((-1, 3, imsize, imsize))
            if 'image_desired_goal_annotated' in path['full_observations'][1]:
                image_goal = path['full_observations'][1][
                    'image_desired_goal_annotated'].reshape(
                        (-1, 3, imsize, imsize))
            else:
                image_goal = path['full_observations'][1][
                    'image_desired_goal'].reshape((-1, 3, imsize, imsize))
            image_sg = path['full_observations'][1][
                'image_desired_subgoals_reproj']
            image_sg = image_sg.reshape((-1, 3, imsize, imsize))
            image = np.concatenate((image_ob, image_sg, image_goal))
            subgoal_images.append(image)

        mini_frames = []
        for d in path['full_observations'][1:]:
            get_image_kwargs = dict(
                pad_length=pad_length,
                pad_color=pad_color,
                imsize=imsize,
            )
            get_image_sweeps = [d.get(key, None) for key in vis_list]
            img = get_image(
                *get_image_sweeps,
                **get_image_kwargs,
            )

            mini_frames.append(img)
        horizon = len(mini_frames)
        frames += mini_frames
        if do_timer:
            print(i, time.time() - start)

    if len(subgoal_images) != 0:
        from railrl.core import logger
        import os.path as osp
        logdir = logger.get_snapshot_dir()
        filename_subgoals = osp.join(logdir,
                                     'sg_{epoch}.png'.format(epoch=epoch))
        nrow = subgoal_images[0].shape[0]
        subgoal_images = np.concatenate(subgoal_images)
        save_image(
            ptu.FloatTensor(ptu.from_numpy(subgoal_images)),
            filename_subgoals,
            nrow=nrow,
        )

    frames = np.array(frames, dtype=np.uint8).reshape(
        (N, horizon, -1, imsize + 2 * pad_length, num_channels))
    f1 = []
    for k1 in range(columns):
        f2 = []
        for k2 in range(rows):
            k = k1 * rows + k2
            f2.append(frames[k:k + 1, :, :, :, :].reshape(
                (horizon, -1, imsize + 2 * pad_length, num_channels)))
        f1.append(np.concatenate(f2, axis=1))
    outputdata = np.concatenate(f1, axis=2)
    skvideo.io.vwrite(filename, outputdata)
    print("Saved video to ", filename)
Example #10
    mu_stds = np.std(np.vstack(latent_mus), axis=0)
    plt.bar(np.arange(len(mu_stds)), mu_stds)
    plt.title("X-axis puck sweep")
    plt.xlabel("latent dim")
    plt.ylabel("Mean std")
    plt.show()

    sigma_stds = np.mean(np.vstack(latent_sigmas), axis=0)
    plt.bar(np.arange(len(sigma_stds)), sigma_stds)
    plt.title("X-axis puck sweep")
    plt.xlabel("latent dim")
    plt.ylabel("Sigma std")
    plt.show()

    imgs = np.array(imgs)
    imgs = ptu.FloatTensor(imgs)
    save_image(imgs, 'x-puck-sweep.png')

    #  ------------------ X axis - arm
    goals = []
    for x in np.arange(env.hand_low[0], env.hand_high[0], 0.01):
        new_hand_xyz = hand_xyz.copy()
        new_hand_xyz[0] = x
        goals.append(np.hstack((new_hand_xyz, obj_xyz)))
    imgs, latent_mus, latent_sigmas = get_info(goals)

    mu_stds = np.std(np.vstack(latent_mus), axis=0)
    plt.bar(np.arange(len(mu_stds)), mu_stds)
    plt.title("X-axis arm sweep")
    plt.xlabel("latent dim")
    plt.ylabel("Mean std")
Example #11
 def contains_points_pytorch(self, points):
     less_op = points < ptu.Variable(
         ptu.FloatTensor([self.max_x, self.max_y]))
     greater_op = points > ptu.Variable(
         ptu.FloatTensor([self.min_x, self.min_y]))
     return less_op * greater_op
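In current PyTorch the byte-tensor multiplication above is usually written as an elementwise & on bool tensors. A sketch with illustrative box bounds:

import torch

def contains_points(points, min_xy, max_xy):
    # Elementwise AND replaces the old byte-tensor multiplication.
    return (points < max_xy) & (points > min_xy)

print(contains_points(
    torch.tensor([[0.5, 0.5], [2.0, 0.5]]),
    torch.tensor([0.0, 0.0]),
    torch.tensor([1.0, 1.0]),
))
# tensor([[ True,  True],
#         [False,  True]])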
    # return np.array(joint_angles), np.array(jacobians)
    return (
        np.hstack((np.array(states), np.array(actions))),
        np.array(next_states) - np.array(states)
    )


train_x_np, train_y_np = generate_data(N_PATHS, PATH_LENGTH)
test_x_np, test_y_np = generate_data(N_PATHS_TEST, PATH_LENGTH)
train_x = ptu.np_to_var(train_x_np)
train_y = ptu.np_to_var(train_y_np)
test_x = ptu.np_to_var(test_x_np)
test_y = ptu.np_to_var(test_y_np)

train_dataset = TensorDataset(
    ptu.FloatTensor(train_x_np),
    ptu.FloatTensor(train_y_np),
)
in_dim = train_x_np[0].size
out_dim = train_y_np[0].size
dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)


def train_network(net, title):
    train_losses = []
    test_losses = []
    times = []

    optimizer = Adam(net.parameters(), lr=1e-3)
    criterion = nn.MSELoss()
Example #13
 def get_new_state(self, batch_size):
     cx = Variable(ptu.FloatTensor(1, batch_size, self.hidden_size))
     cx.data.fill_(0)
     hx = Variable(ptu.FloatTensor(1, batch_size, self.hidden_size))
     hx.data.fill_(0)
     return hx, cx
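ptu.FloatTensor allocates uninitialized memory, hence the explicit fill_(0) above; in current PyTorch the same zero state is created directly (a sketch, with hidden_size as a plain argument):

import torch

def get_new_state(batch_size, hidden_size):
    # nn.LSTM expects (h0, c0), each of shape [num_layers, batch, hidden].
    hx = torch.zeros(1, batch_size, hidden_size)
    cx = torch.zeros(1, batch_size, hidden_size)
    return hx, cx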
Example #14
    def forward(self, obs, initial_memory):
        """
        :param obs: torch Variable, [batch_size, sequence length, obs dim]
        :param initial_memory: torch Variable, [batch_size, memory dim]
        :return: (actions, writes) tuple
            actions: [batch_size, sequence length, action dim]
            writes: [batch_size, sequence length, memory dim]
        """
        assert len(obs.size()) == 3
        assert len(initial_memory.size()) == 2
        batch_size, subsequence_length = obs.size()[:2]

        subtraj_writes = Variable(ptu.FloatTensor(batch_size,
                                                  subsequence_length,
                                                  self.memory_dim),
                                  requires_grad=False)
        subtraj_actions = Variable(ptu.FloatTensor(batch_size,
                                                   subsequence_length,
                                                   self.action_dim),
                                   requires_grad=False)
        if self.feed_action_to_memory:
            if self.num_splits_for_rnn_internally > 1:
                state = torch.split(
                    initial_memory,
                    self.memory_dim // self.num_splits_for_rnn_internally,
                    dim=1,
                )
                for i in range(subsequence_length):
                    current_obs = obs[:, i, :]
                    augmented_state = torch.cat((current_obs, ) + state, dim=1)
                    action = self.forward_action(augmented_state)
                    rnn_input = torch.cat([current_obs, action], dim=1)
                    state = self.rnn_cell(rnn_input, state)
                    subtraj_writes[:, i, :] = torch.cat(state, dim=1)
                    subtraj_actions[:, i, :] = action
            else:
                state = initial_memory
                for i in range(subsequence_length):
                    current_obs = obs[:, i, :]
                    augmented_state = torch.cat([current_obs, state], dim=1)
                    action = self.forward_action(augmented_state)
                    rnn_input = torch.cat([current_obs, action], dim=1)
                    state = self.rnn_cell(rnn_input, state)
                    subtraj_writes[:, i, :] = state
                    subtraj_actions[:, i, :] = action
            return subtraj_actions, subtraj_writes
        """
        Create the new writes.
        """
        if self.num_splits_for_rnn_internally > 1:
            state = torch.split(
                initial_memory,
                self.memory_dim // self.num_splits_for_rnn_internally,
                dim=1,
            )
            for i in range(subsequence_length):
                state = self.rnn_cell(obs[:, i, :], state)
                subtraj_writes[:, i, :] = torch.cat(state, dim=1)
        else:
            state = initial_memory
            for i in range(subsequence_length):
                state = self.rnn_cell(obs[:, i, :], state)
                subtraj_writes[:, i, :] = state

        # An nn.LSTM module is not used here because it returns only the
        # FINAL (hx, cx), not the state at every step.
        """
        Create the new subtrajectory memories with the initial memories and the
        new writes.
        """
        expanded_init_memory = initial_memory.unsqueeze(1)
        if subsequence_length > 1:
            memories = torch.cat(
                (
                    expanded_init_memory,
                    subtraj_writes[:, :-1, :],
                ),
                dim=1,
            )
        else:
            memories = expanded_init_memory
        """
        Use new memories to create env actions.
        """
        all_subtraj_inputs = torch.cat([obs, memories], dim=2)
        for i in range(subsequence_length):
            augmented_state = all_subtraj_inputs[:, i, :]
            action = self.forward_action(augmented_state)
            subtraj_actions[:, i, :] = action

        return subtraj_actions, subtraj_writes
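In current PyTorch the per-step outputs are usually collected in a list and stacked, rather than written into a pre-allocated Variable. A sketch of the single-split branch using a GRUCell (all names and dimensions below are illustrative, not railrl's):

import torch
import torch.nn as nn

obs_dim, action_dim, memory_dim = 4, 2, 8
rnn_cell = nn.GRUCell(obs_dim, memory_dim)
policy_head = nn.Linear(obs_dim + memory_dim, action_dim)

def unroll(obs, initial_memory):
    actions, writes = [], []
    state = initial_memory
    for i in range(obs.size(1)):
        current_obs = obs[:, i, :]
        # Action is computed from the pre-write memory, as in the source.
        action = policy_head(torch.cat([current_obs, state], dim=1))
        state = rnn_cell(current_obs, state)
        writes.append(state)
        actions.append(action)
    return torch.stack(actions, dim=1), torch.stack(writes, dim=1)

acts, mems = unroll(torch.randn(3, 5, obs_dim), torch.zeros(3, memory_dim))
print(acts.shape, mems.shape)  # torch.Size([3, 5, 2]) torch.Size([3, 5, 8])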