Example 1
    def _fast_path_dist_cost(self, inputs):
        """Vectorized computation of path distance cost."""
        # sample start and goal indices
        batch_size = inputs.end_ind.shape[0]
        start_idx = (torch.rand((batch_size,), device=inputs.end_ind.device)
                     * (inputs.end_ind.float() - 1))
        end_idx = (torch.rand((batch_size,), device=inputs.end_ind.device)
                   * (inputs.end_ind.float() - (start_idx + 1))
                   + (start_idx + 1))
        start_idx, end_idx = start_idx.long(), end_idx.long()

        # get start and goal latents
        start = batchwise_index(inputs.model_enc_seq, start_idx).detach()
        end = batchwise_index(inputs.model_enc_seq, end_idx).detach()

        # compute path distance cost
        step_dists = torch.norm(inputs.demo_seq[:, 1:] - inputs.demo_seq[:, :-1], dim=-1)
        cum_diff = torch.cumsum(step_dists, dim=1)
        cum_diff = torch.cat(
            (torch.zeros((batch_size, 1), dtype=cum_diff.dtype, device=cum_diff.device),
             cum_diff), dim=1)  # prepend 0
        gt_cost = batchwise_index(cum_diff, end_idx) - batchwise_index(cum_diff, start_idx)

        return start, end, gt_cost[:, None].detach()
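
Note: every example on this page calls the helper batchwise_index, whose definition is not shown. Judging from the call sites, it selects one entry per batch element along a chosen dimension. Below is a minimal gather-based sketch of that behavior; it is an inference from usage, not the library's actual implementation, and it only covers the 1-D index case (Examples 2, 7, and 9 additionally pass two-dimensional index tensors, which the real helper evidently supports).

import torch

def batchwise_index(tensor, idx, dim=1):
    """Hypothetical sketch: returns tensor[b, ..., idx[b], ...] along `dim`
    for each batch element b (1-D index case only)."""
    view = [tensor.shape[0]] + [1] * (tensor.dim() - 1)
    expand = list(tensor.shape)
    expand[dim] = 1
    gather_idx = idx.long().view(view).expand(expand)
    return tensor.gather(dim, gather_idx).squeeze(dim)

# toy check: pick one timestep per batch element
seq = torch.arange(24.).reshape(2, 4, 3)           # (batch, time, feature)
print(batchwise_index(seq, torch.tensor([1, 3])))  # rows seq[0, 1] and seq[1, 3]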
Example 2
    def get_matched_sequence(self, tree, key):
        latents = tree.bf[key]
        indices = tree.bf.match_dist.argmax(1)
        # Two-dimensional indexing
        matched_sequence = rmap(lambda x: batchwise_index(x, indices), latents)

        return matched_sequence
Example 3
def get_n_top_bias_nodes(targets, weights):
    """Return the probability that the downweighted nodes match the noisy frame."""
    inds = WeightsHacker.get_index(targets)

    noise_frames = batchwise_index(weights, inds, 2)
    n = (noise_frames.mean(0)[:global_params.hp.n_top_bias_nodes].sum()
         / global_params.hp.top_bias)

    return n
Example 4
def plot_balanced_tree_with_actions(model_output,
                                    inputs,
                                    n_logged_samples,
                                    get_prob_fcn=None):
    tree = model_output.tree
    batch, channels, res, _ = tree.subgoal.images.shape
    _, action_dim = tree.subgoal.a_l.shape
    max_depth = tree.depth
    n_sg = (2**max_depth) - 1

    im_height = (max_depth * 2 + 1) * res  # plot all actions (*2) and start/end frame again (+1)
    im_width = (n_sg + 2) * res
    im = np.asarray(0.7 * np.ones((n_logged_samples, im_height, im_width, 3)),
                    dtype=np.float32)

    # insert start and goal frame
    if inputs is not None:
        im[:, :res, :res] = imgtensor2np(
            inputs.traj_seq[:n_logged_samples, 0],
            n_logged_samples).transpose(0, 2, 3, 1)
        im[:, :res, -res:] = imgtensor2np(
            batchwise_index(inputs.traj_seq[:n_logged_samples],
                            model_output.end_ind[:n_logged_samples]),
            n_logged_samples).transpose(0, 2, 3, 1)

    if 'norm_gt_action_match_dists' in model_output:
        action_usage_prob = np.max(
            tensor2np(model_output.norm_gt_action_match_dists, n_logged_samples),
            axis=2)

    step = 1
    for i, segment in enumerate(tree):
        level = 2 * (max_depth - segment.depth + 1)
        dx = 2**(segment.depth - 2)
        im[:, level * res : (level + 1) * res, step * res: (step + 1) * res] = \
            imgtensor2np(segment.subgoal.images[:n_logged_samples], n_logged_samples).transpose(0, 2, 3, 1)
        a_l = tensor2np(segment.subgoal.a_l, n_logged_samples)
        a_r = tensor2np(segment.subgoal.a_r, n_logged_samples)
        if get_prob_fcn is not None:
            usage_prob_l, usage_prob_r = get_prob_fcn(segment)
        else:
            usage_prob_l = action_usage_prob[:, 2 * i]
            usage_prob_r = action_usage_prob[:, 2 * i + 1]
        for b in range(n_logged_samples):
            im[b, (level-1) * res : level * res, int((step-dx) * res): int((step - dx + 1) * res)] = \
                framed_action2img(a_l[b], usage_prob_l[b], res, channels)
            im[b, (level - 1) * res: level * res, int((step + dx) * res): int((step + dx + 1) * res)] = \
                framed_action2img(a_r[b], usage_prob_r[b], res, channels)
        step += 1
    return im
Example 5
    def index_input(self, input, t, aggregate=False, t1=None):
        """Selects input at timestep t per batch element; if aggregate, sums over [t, t1)."""
        if aggregate:
            assert t1 is not None  # need end time step for aggregation
            selected = torch.zeros_like(input[:, 0])
            for b in range(input.shape[0]):
                selected[b] = torch.sum(input[b, t[b]:t1[b]], dim=0)
        else:
            selected = batchwise_index(input, t)
        return selected
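
A quick toy check of the two branches (hypothetical shapes; batchwise_index is replaced by an equivalent gather, under the same per-batch selection assumption as in the sketch above):

import torch

seq = torch.arange(12.).reshape(2, 3, 2)  # (batch=2, time=3, feature=2)
t = torch.tensor([0, 1])
t1 = torch.tensor([2, 3])

# aggregate=True: per-batch sum over the window [t[b], t1[b])
agg = torch.stack([seq[b, t[b]:t1[b]].sum(0) for b in range(seq.shape[0])])

# aggregate=False: per-batch selection of a single timestep
sel = seq.gather(1, t.view(-1, 1, 1).expand(-1, 1, seq.shape[-1])).squeeze(1)

print(agg)  # sums of frames {0,1} and {1,2} respectively
print(sel)  # frames seq[0, 0] and seq[1, 1]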
Example 6
def plot_val_tree(model_output, inputs, n_logged_samples=3):
    tree = model_output.tree
    batch, _, channels, res, _ = tree.subgoals.images.shape
    max_depth = tree.depth
    n_sg = (2**max_depth) - 1

    dpi = 10
    fig_height, fig_width = 2 * res, n_sg * res

    im_height, im_width = max_depth * res + fig_height, 2 * res + fig_width
    im = np.asarray(0.7 * np.ones((n_logged_samples, im_height, im_width, 3)),
                    dtype=np.float32)

    # plot existence probabilities
    if 'p_n_hat' in tree.subgoals:
        p_n_hat = tensor2np(tree.df.p_n_hat, n_logged_samples)
        for i in range(n_logged_samples):
            im[i, :res, res:-res] = plot_dists([p_n_hat[i]], res, fig_width,
                                               dpi)
    if 'p_a_l_hat' in tree.subgoals:
        p_a_hat = tensor2np(
            sort_actions_depth_first(model_output, ['p_a_l_hat', 'p_a_r_hat'],
                                     n_logged_samples))
        for i in range(n_logged_samples):
            im[i, res:2 * res,
               int(3 * res / 4):int(-3 * res / 4)] = plot_dists(
                   [p_a_hat[i]], res, fig_width + int(res / 2), dpi)
        im = np.concatenate(
            (im[:, :fig_height],
             plot_balanced_tree_with_actions(
                 model_output,
                 inputs,
                 n_logged_samples,
                 get_prob_fcn=lambda s: (tensor2np(s.subgoal.p_a_l_hat),
                                         tensor2np(s.subgoal.p_a_r_hat)))),
            axis=1)
    else:
        with param(n_logged_samples=n_logged_samples):
            im[:, fig_height:,
               res:-res] = plot_balanced_tree(model_output).transpose(
                   (0, 2, 3, 1))

    # insert start and goal frame
    if inputs is not None:
        im[:, :res, :res] = imgtensor2np(
            inputs.traj_seq[:n_logged_samples, 0],
            n_logged_samples).transpose(0, 2, 3, 1)
        im[:, :res, -res:] = imgtensor2np(
            batchwise_index(inputs.traj_seq[:n_logged_samples],
                            model_output.end_ind[:n_logged_samples]),
            n_logged_samples).transpose(0, 2, 3, 1)

    return im
Example 7
    def forward(self, inputs, e_l, e_r, start_ind, end_ind, timestep):
        assert timestep is not None
        output = AttrDict(gamma=None)

        if self.deterministic:
            output.q_z = self.q(e_l)
            return output

        values = inputs.inf_enc_seq
        keys = inputs.inf_enc_key_seq

        # several query timesteps may be packed per sequence; unpack them before indexing
        mult = timestep.shape[0] // keys.shape[0]
        if mult > 1:
            timestep = timestep.reshape(-1, mult)
            result = batchwise_index(values, timestep.long())
            e_tilde = result.reshape([-1] + list(result.shape[2:]))
        else:
            e_tilde = batchwise_index(values, timestep[:, 0].long())

        output.q_z = self.q(e_l, e_r, e_tilde)
        return output
Example 8
    def get_timesteps(self, inputs):
        """
         # sample temporal distances between 1 and temp_dist, regress only first action
        :return:  None, call by reference
        """

        t0 = np.zeros(self._hp.batch_size)
        for b in range(self._hp.batch_size):
            t0[b] = np.random.randint(
                0, abs(inputs.end_ind[b].cpu().numpy() - self._hp.temp_dist),
                1)
        delta_t = np.random.randint(1, self._hp.temp_dist + 1,
                                    self._hp.batch_size)
        t1 = t0 + delta_t
        t0 = torch.tensor(t0,
                          device=inputs.traj_seq_states.device,
                          dtype=torch.long)
        t1 = torch.tensor(t1,
                          device=inputs.traj_seq_states.device,
                          dtype=torch.long)
        inputs.state_t0 = batchwise_index(inputs.traj_seq_states, t0)
        inputs.state_t1 = batchwise_index(inputs.traj_seq_states, t1)
        inputs.selected_action = batchwise_index(inputs.actions, t0)
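
The per-element loop above can also be collapsed into a single vectorized draw, since recent NumPy versions broadcast array-valued bounds in np.random.randint. A toy sketch with made-up shapes, not taken from the original code:

import numpy as np

batch_size, temp_dist = 4, 5
end_ind = np.array([20, 18, 25, 22])  # toy per-sequence end indices

t0 = np.random.randint(0, np.abs(end_ind - temp_dist))
delta_t = np.random.randint(1, temp_dist + 1, batch_size)
t1 = t0 + delta_t
print(t0, t1)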
Example 9
    def sample_target(self, seq, end_inds, repeats):
        """Samples target frames from each sequence in the batch.

        :param seq: batch of sequences to sample from
        :param end_inds: last valid index of each sequence
        :param repeats: how many times to sample from each sequence
        :return: the sampled target frames
        """
        # Note: this could be implemented more simply using a randomized length.

        def get_random_ints(min_i, max_i):
            index = torch.rand(max_i.shape + (repeats,), device=max_i.device)
            index = (min_i + index * (max_i[:, None].float() - min_i)).floor()
            return index

        index = get_random_ints(1, end_inds + 1)
        target = batchwise_index(seq, index.long())

        return target
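
Here index has shape (batch, repeats), i.e. the two-dimensional indexing case. Assuming batchwise_index then returns one frame per (batch, repeat) pair, a gather-based equivalent looks like this (toy, hypothetical shapes):

import torch

seq = torch.arange(30.).reshape(2, 5, 3)  # (batch, time, feature)
index = torch.tensor([[1, 3], [0, 4]])    # (batch, repeats)

target = seq.gather(1, index[..., None].expand(-1, -1, seq.shape[-1]))
print(target.shape)  # torch.Size([2, 2, 3])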
Example 10
    def dump_metrics(self, it):
        with self._logger.log_to('results', it, 'metric'):
            best_idxs = self._get_best_idxs(
                self.full_evaluation[self._top_comp_metric])
            print_st = []
            for metric in sorted(self.full_evaluation):
                vals = self.full_evaluation[metric]
                if metric in ['psnr', 'ssim', 'mse']:
                    if metric not in self.evaluation_buffer:
                        continue
                    best_vals = batchwise_index(vals, best_idxs)
                    print_st.extend([
                        best_vals.mean(),
                        best_vals.std(),
                        vals.std(axis=1).mean()
                    ])
                    self._logger.log(metric,
                                     vals if self._top_of_100 else None,
                                     best_vals)
            print(*print_st, sep=',')
Example 11
    def produce_tree(self, root, tree, tree_inputs, inputs, outputs):
        # Produce the tree to get the matching
        root.produce_tree_cont_time(*tree_inputs, self.one_step_planner,
                                    self._hp.hierarchy_levels)

        if not self.one_step_planner._sample_prior:
            tree.set_attr_bf(
                **self.decoder.decode_seq(inputs, tree.bf.e_g_prime))
            tree.bf.match_dist = outputs.gt_match_dists = self.one_step_planner.matcher.get_w(
                inputs.pad_mask, inputs, outputs)

            matched_index = tree.bf.match_dist.argmax(-1)
            tiled_enc_demo_seq = inputs.enc_demo_seq[:, None].repeat_interleave(
                matched_index.shape[1], 1)
            matched_latents = batch_apply(
                [tiled_enc_demo_seq, matched_index],
                lambda pair: batchwise_index(pair[0], pair[1]))

            tree.bf.e_g_prime = matched_latents
Example 12
def select_e_0_e_g(seq, start_ind, end_ind):
    e_0 = batchwise_index(seq, start_ind)
    e_g = batchwise_index(seq, end_ind)
    return e_0, e_g
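
For reference, a toy invocation with hypothetical shapes (assuming the per-batch selection behavior sketched under Example 1):

import torch

seq = torch.randn(4, 10, 32)  # (batch, time, latent)
start_ind = torch.zeros(4, dtype=torch.long)
end_ind = torch.tensor([9, 7, 5, 9])

e_0, e_g = select_e_0_e_g(seq, start_ind, end_ind)
print(e_0.shape, e_g.shape)  # torch.Size([4, 32]) twice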