Beispiel #1
0
    def run_encoder(self, inputs, start_ind):
        """Write encoder outputs for the demo, start and goal frames onto ``inputs``.

        Fields written in place on ``inputs``:
          * ``enc_demo_seq`` / ``skips`` -- ``self.encoder`` over ``demo_seq`` via
            ``batch_apply``; only when a demo is present and ``enc_demo_seq`` is
            not already set.
          * ``enc_e_0``, ``skips``, ``enc_e_g`` -- start (``I_0``) and goal (``I_g``)
            encodings; taken from ``self.start_goal_enc`` on ``I_0_image`` /
            ``I_g_image`` (spatial dims removed) when
            ``separate_cnn_start_goal_encoder`` is set, else from ``self.encoder``.
          * ``inf_enc_seq`` / ``inf_enc_key_seq`` -- inference-encoder outputs over
            ``enc_demo_seq``; only when a demo is present.
          * ``enc_action_seq`` -- ``self.action_encoder`` over ``actions``; only
            when ``action_conditioned_pred`` is set.

        ``start_ind`` is not used by this method.
        """
        hp = self._hp
        has_demo = 'demo_seq' in inputs

        if has_demo and 'enc_demo_seq' not in inputs:
            inputs.enc_demo_seq, inputs.skips = batch_apply(inputs.demo_seq, self.encoder)
            if hp.use_convs and hp.use_skips:
                # keep only the activations of the first (start) image of the sequence
                inputs.skips = map_recursive(lambda s: s[:, 0], inputs.skips)
            # NOTE(review): both start/goal branches below reassign inputs.skips
            # unconditionally, so the demo skips set here never survive this
            # method -- confirm this is intended.

        if hp.separate_cnn_start_goal_encoder:
            start_enc, inputs.skips = self.start_goal_enc(inputs.I_0_image)
            inputs.enc_e_0 = remove_spatial(start_enc)
            inputs.enc_e_g = remove_spatial(self.start_goal_enc(inputs.I_g_image)[0])
        else:
            inputs.enc_e_0, inputs.skips = self.encoder(inputs.I_0)
            inputs.enc_e_g = self.encoder(inputs.I_g)[0]

        if has_demo:
            if hp.act_cond_inference:
                inf_seq = self.inf_encoder(inputs.enc_demo_seq, inputs.actions)
            elif hp.states_inference:
                # pass the states alongside the encodings, with two trailing singleton axes added
                inf_seq = batch_apply((inputs.enc_demo_seq, inputs.demo_seq_states[..., None, None]),
                                      self.inf_encoder, separate_arguments=True)
            else:
                inf_seq = self.inf_encoder(inputs.enc_demo_seq)
            inputs.inf_enc_seq = inf_seq
            inputs.inf_enc_key_seq = self.inf_key_encoder(inputs.enc_demo_seq)

        if hp.action_conditioned_pred:
            inputs.enc_action_seq = batch_apply(inputs.actions, self.action_encoder)
Beispiel #2
0
 def prune_sequence(self, inputs, outputs, key='images'):
     """Keep, per batch element, only the frames the existence predictor accepts.

     The existence predictor is applied to ``outputs.tree.df.e_g_prime`` through
     ``batch_apply``; its scores (last axis squeezed) are stored as
     ``outputs.existence_predictor.existence``. A frame is kept when the sigmoid of
     its score exceeds 0.5. ``inputs`` is not used.

     :return: list with one tensor per batch element, holding the kept entries of
         ``outputs.tree.df.<key>``.
     """
     tree_df = outputs.tree.df
     frames = getattr(tree_df, key)
     scores = batch_apply(tree_df.e_g_prime, self.existence_predictor)[..., 0]
     outputs.existence_predictor = AttrDict(existence=scores)

     keep = torch.sigmoid(scores) > 0.5
     return [frames[b][keep[b]] for b in range(frames.shape[0])]
Beispiel #3
0
 def forward(self, input, actions):
     """Run ``self.net``, append time-padded actions, and apply ``self.ac_net``.

     ``actions`` is right-padded on its second-to-last axis so that its length on
     axis 1 matches the network output (assumes 3-D ``actions``, e.g.
     (batch, time, dim) -- TODO confirm). It is then broadcast with
     ``broadcast_final`` against ``input``, concatenated to the network output
     along dim 2, and the result is fed to ``self.ac_net`` via ``batch_apply``.
     """
     features = self.net(input)
     n_missing = features.shape[1] - actions.shape[1]
     padded_actions = torch.nn.functional.pad(actions, (0, 0, 0, n_missing, 0, 0))
     # TODO quite sure the concatenation is automatic
     joint = torch.cat([features, broadcast_final(padded_actions, input)], dim=2)
     return batch_apply(self.ac_net, joint)
Beispiel #4
0
def main():
    """Print the mean LPIPS distance between ground-truth and predicted videos.

    Both videos are loaded from ``.npy`` files (``args.gt`` and ``args.pred``) and
    must share the shape [batch, T, C, H, W] with C == 3. Distances come from
    ``models.PerceptualLoss`` ('net-lin', 'alex') applied through ``batch_apply``
    in chunks of ``args.batch_size`` sequences. Frames whose ground truth is
    entirely -1 (black) or entirely 0 (gray) are treated as padding and excluded
    from the mean.

    :raises ValueError: on a non-.npy path, a non-5-D array, a non-3-channel
        array, or mismatched ground-truth/prediction shapes.
    """
    args = get_trainer_args()

    def load_videos(path):
        """Load a [batch, T, C, H, W] video array from ``path`` as a float tensor."""
        print("Loading trajectories from {}".format(path))
        if not path.endswith('.npy'):
            raise ValueError("Can only read in .npy files!")
        seqs = np.load(path)
        if seqs.ndim != 5:
            raise ValueError(
                "Expected [batch, T, C, H, W] data, got shape {}".format(seqs.shape))
        if seqs.shape[2] != 3:
            # channels live in dim 2 (the C of [batch, T, C, H, W])
            raise ValueError(
                "Expected 3 channels in dim 2, got shape {}".format(seqs.shape))
        seqs = torch.Tensor(seqs)
        if args.use_gpu:
            seqs = seqs.cuda()
        return seqs  # expected value range [-1, 1]

    gt_seqs = load_videos(args.gt)
    pred_seqs = load_videos(args.pred)
    print('shape: ', gt_seqs.shape)

    if gt_seqs.shape != pred_seqs.shape:
        raise ValueError("Ground-truth shape {} does not match prediction shape {}".format(
            tuple(gt_seqs.shape), tuple(pred_seqs.shape)))

    n_seqs = gt_seqs.shape[0]

    # Per-frame validity mask [batch, T] (sequences may have variable length):
    # a frame is padding if all its values are -1 (black) or all are 0 (gray).
    # `== 0` negates the mask whether torch.all yields uint8 (old PyTorch) or
    # bool (new PyTorch, where `1 - bool_tensor` raises).
    frames = gt_seqs.flatten(start_dim=2)
    is_black = ((frames + 1.0).abs() < 1e-6).all(dim=-1)
    is_gray = (frames.abs() < 1e-6).all(dim=-1)
    mask = (is_black | is_gray) == 0

    # Initializing the model
    model = models.PerceptualLoss(model='net-lin',
                                  net='alex',
                                  use_gpu=args.use_gpu)

    # Run the forward pass to compute LPIPS distances. Step over *all* sequences,
    # including a final partial chunk, so that `distances` lines up with `mask`.
    distances = []
    with torch.no_grad():  # metric evaluation only; no autograd graph needed
        for start in range(0, n_seqs, args.batch_size):
            stop = start + args.batch_size
            x, y = gt_seqs[start:stop], pred_seqs[start:stop]
            distances.append(batch_apply((x, y), model, separate_arguments=True))
    distances = torch.cat(distances)
    mean_distance = distances[mask].mean()

    print("LPIPS distance: {}".format(mean_distance))
Beispiel #5
0
    def forward(self, *inp):
        """Run the recurrent net over the concatenated inputs and project each step.

        The inputs are joined with ``concat_inputs`` along dim 2 and fed to
        ``self.net`` with an all-zero initial state. The per-step outputs are then
        mapped through ``self.out_projection`` via ``batch_apply``.

        :param inp: tensors concatenated along dim 2; dim 0 is used as the batch
            size (presumably batch_first=True, i.e. (batch, time, features) --
            TODO confirm).
        :return: the projected per-step outputs.
        """
        x = concat_inputs(*inp, dim=2)

        # One initial-state slice per layer and per direction.
        n_states = self._hp.n_lstm_layers * (2 if self.net.bidirectional else 1)
        state_shape = (n_states, x.shape[0], self._hp.nz_mid_lstm)
        # torch.nn.LSTM takes its initial state as (h_0, c_0).
        h0 = x.new_zeros(state_shape)
        c0 = x.new_zeros(state_shape)

        # For nn.LSTM, element 0 is the per-step output; the final (h_n, c_n) is discarded.
        out = self.net(x, (h0, c0))[0]
        return batch_apply(self.out_projection, out.contiguous())
Beispiel #6
0
    def prune_sequence(self, inputs, outputs, key='images'):
        """Drop frames the distance predictor flags as too close to their predecessor.

        ``self.distance_predictor`` scores each consecutive pair of latents in
        ``outputs.tree.df.e_g_prime`` (step t-1 vs. step t). The scores (last axis
        squeezed) are stored as ``outputs.distance_predictor.distances``. A frame is
        dropped when the sigmoid of its score exceeds
        ``self._hp.learned_pruning_threshold``; the first frame is always kept.
        ``inputs`` is not used.

        :return: list with one tensor per batch element, holding the kept entries of
            ``outputs.tree.df.<key>``.
        """
        tree_df = outputs.tree.df
        frames = getattr(tree_df, key)
        latents = tree_df.e_g_prime

        prev_latents = latents[:, :-1].contiguous()
        next_latents = latents[:, 1:].contiguous()
        scores = batch_apply(self.distance_predictor, prev_latents, next_latents)[..., 0]
        outputs.distance_predictor = AttrDict(distances=scores)

        # True where a frame is too close to the one before it.
        too_close = torch.sigmoid(scores) > self._hp.learned_pruning_threshold
        # Frame 0 has no predecessor: prepend a "never too close" column.
        never = torch.zeros_like(too_close[:, [0]])
        too_close = torch.cat([never, too_close], 1)

        keep = ~too_close
        return [frames[b][keep[b]] for b in range(frames.shape[0])]
Beispiel #7
0
    def decode_seq(self, inputs, encodings):
        """Decode a sequence of images from a sequence of encodings.

        The skip activations (``inputs.skips``) and the start/goal frames
        (``inputs.I_0``, ``inputs.I_g``) are repeated along a new axis 1, once per
        step of ``encodings``. This gives every step the same context. The whole
        sequence is then decoded with ``batch_apply`` on ``self``.

        :param inputs: container providing ``skips``, ``I_0`` and ``I_g``.
        :param encodings: encoding sequence; its dim 1 is the number of steps
            to decode.
        :return: the result of ``batch_apply(self, input=..., skips=...,
            pixel_source=...)``.
        """
        # TODO skip from the goal as well
        seq_len = encodings.shape[1]

        def extend_to_seq(x):
            # Insert an axis after dim 0 and repeat it seq_len times.
            return x[:, None][:, [0] * seq_len].contiguous()

        seq_skips = rmap(extend_to_seq, inputs.skips)
        pixel_source = rmap(extend_to_seq, [inputs.I_0, inputs.I_g])

        return batch_apply(self,
                           input=encodings,
                           skips=seq_skips,
                           pixel_source=pixel_source)
    def produce_tree(self, root, tree, tree_inputs, inputs, outputs):
        """Expand the tree from ``root``; unless sampling from the prior, match it to the demo.

        Expansion calls ``root.produce_tree_cont_time`` with the one-step planner
        and ``self._hp.hierarchy_levels``. When the planner does not sample from its
        prior, three more steps follow:
          * the node latents ``tree.bf.e_g_prime`` are decoded and written back
            with ``tree.set_attr_bf``;
          * the planner's matcher computes a match distribution, which is stored on
            ``tree.bf.match_dist`` and ``outputs.gt_match_dists``;
          * ``tree.bf.e_g_prime`` is replaced by ``inputs.enc_demo_seq`` entries
            gathered with ``batchwise_index`` at the argmax of that distribution.

        Returns None; results are written onto ``tree`` and ``outputs``.
        """
        planner = self.one_step_planner
        # Produce the tree to get the matching
        root.produce_tree_cont_time(*tree_inputs, planner, self._hp.hierarchy_levels)

        if planner._sample_prior:
            return

        tree.set_attr_bf(**self.decoder.decode_seq(inputs, tree.bf.e_g_prime))
        match_dist = planner.matcher.get_w(inputs.pad_mask, inputs, outputs)
        tree.bf.match_dist = match_dist
        outputs.gt_match_dists = match_dist

        # Best-matching demo index (argmax over the last axis).
        matched_index = tree.bf.match_dist.argmax(-1)
        # Repeat the demo encodings along a new axis 1, once per entry of
        # matched_index's dim 1 (presumably one copy per node -- confirm).
        n_tiles = matched_index.shape[1]
        tiled_demo = inputs.enc_demo_seq[:, None].repeat_interleave(n_tiles, 1)

        def pick(pair):
            return batchwise_index(pair[0], pair[1])

        tree.bf.e_g_prime = batch_apply([tiled_demo, matched_index], pick)
Beispiel #9
0
    def forward(self, inputs, phase='train'):
        """Run the full model on a batch of ``inputs`` and collect its outputs.

        ``inputs`` is an attribute-style dict and is mutated in place. Depending
        on the hyperparameters, goal information is zeroed, ``demo_seq`` is
        replaced, ``reference_tensor`` is set, ``run_encoder`` adds encoder
        outputs, and ``model_enc_seq`` is set when pruning.

        :param inputs: batch container. Per the original notes:
            images shape = batch x time x height x width x channel,
            pad mask shape = batch x time, 1 indicates actual image 0 is padded
            (TODO confirm against the data loader).
        :param phase: 'train' selects matched pruning and enables the inverse
            and cost models; any other value uses predicted pruning.
        :return: AttrDict with ``end_ind``, the ``predict_sequence`` outputs, and
            any enabled auxiliary outputs (length prediction, inverse model,
            state regressor, cost model).
        """
        if self._hp.non_goal_conditioned:
            # Remove goal information from the inputs.
            if 'demo_seq' in inputs:
                # Zero the demo step at end_ind for every batch element.
                # NOTE(review): assumes demo_seq_images exists whenever demo_seq does.
                inputs.demo_seq[torch.arange(inputs.demo_seq.shape[0]), inputs.end_ind] = 0.0
                inputs.demo_seq_images[torch.arange(inputs.demo_seq.shape[0]), inputs.end_ind] = 0.0
            inputs.I_g = torch.zeros_like(inputs.I_g)
            if "I_g_image" in inputs:
                inputs.I_g_image = torch.zeros_like(inputs.I_g_image)
            if inputs.I_0.shape[-1] == 5: # special hack for maze
                # Zero the last two of the 5 features (presumably goal info -- confirm).
                inputs.I_0[..., -2:] = 0.0
                if "demo_seq" in inputs:
                    inputs.demo_seq[..., -2:] = 0.0

        # swap in actions if we want to train action sequence decoder
        if self._hp.train_on_action_seqs:
            # Append one all-zero step along dim 1 (presumably so the action
            # sequence is as long as the frame sequence -- confirm).
            inputs.demo_seq = torch.cat([inputs.actions, torch.zeros_like(inputs.actions[:, :1])], dim=1)

        model_output = AttrDict()
        inputs.reference_tensor = find_tensor(inputs)

        if 'start_ind' not in inputs:
            # Default: every sequence starts at index 0, on the reference tensor's device.
            # NOTE(review): sized by the configured batch_size, not the actual
            # batch size; a smaller final batch would get a mismatched start_ind.
            start_ind = torch.zeros(self._hp.batch_size, dtype=torch.long, device=inputs.reference_tensor.device)
        else:
            start_ind = inputs.start_ind

        self.run_encoder(inputs, start_ind)

        end_ind = inputs.end_ind if 'end_ind' in inputs else None
        if self._hp.regress_length:
            # predict total sequence length
            model_output.update(self.length_pred(inputs.enc_e_0, inputs.enc_e_g))
            if self._use_pred_length and (self._hp.length_pred_weight > 0 or end_ind is None):
                # end_ind := argmax over dim 1 of a sample from the predicted length
                # distribution (presumably a one-hot sample -- confirm).
                end_ind = torch.argmax(model_output.seq_len_pred.sample().long(), dim=1)
                if self._hp.action_conditioned_pred or self._hp.non_goal_conditioned:
                    # don't use predicted length when action conditioned or non-goal-conditioned;
                    # run every sequence to max_seq_len - 1 instead
                    end_ind = torch.ones_like(end_ind) * (self._hp.max_seq_len - 1)
        # TODO clean this up. model_output.end_ind is not currently used anywhere
        model_output.end_ind = end_ind

        # Run the model to generate sequences
        model_output.update(self.predict_sequence(inputs, model_output, start_ind, end_ind, phase))

        if self.prune_sequences:
            # Training prunes against the matched ground truth; other phases use the prediction.
            if phase == 'train':
                inputs.model_enc_seq = self.get_matched_pruned_seqs(inputs, model_output)
            else:
                inputs.model_enc_seq = self.get_predicted_pruned_seqs(inputs, model_output)
            # Pad the variable-length per-example sequences into one batch-first tensor.
            inputs.model_enc_seq = pad_sequence(inputs.model_enc_seq, batch_first=True)
            if len(inputs.model_enc_seq.shape) == 5:
                # Keep index 0 of the last two axes (presumably 1x1 spatial dims -- confirm).
                inputs.model_enc_seq = inputs.model_enc_seq[..., 0, 0]

            if self._hp.attach_inv_mdl and phase == 'train':
                model_output.update(self.inv_mdl(inputs, full_seq=self._inv_mdl_full_seq or self._hp.train_inv_mdl_full_seq))
            if self._hp.attach_state_regressor:
                regressor_inputs = inputs.model_enc_seq
                if not self._hp.supervised_decoder:
                    # Without a supervised decoder, stop regressor gradients from reaching the encodings.
                    regressor_inputs = regressor_inputs.detach()
                model_output.regressed_state = batch_apply(regressor_inputs, self.state_regressor)
            if self._hp.attach_cost_mdl and self._hp.run_cost_mdl and phase == 'train':
                # There is an issue here since SVG doesn't output a latent for the first image.
                # Beyond conceptual problems, this breaks if end_ind = 199
                model_output.update(self.cost_mdl(inputs))

        return model_output
Beispiel #10
0
 def forward(self, seq):
     """Apply ``self.net`` to ``seq`` (made contiguous first) via ``batch_apply``."""
     contiguous_seq = seq.contiguous()
     return batch_apply(self.net, contiguous_seq)
Beispiel #11
0
 def forward(self, *args):
     """Apply ``self.net`` to ``args`` via ``batch_apply``."""
     # Only the function and its inputs are passed, matching the
     # batch_apply(fn, *tensors) convention; ``self.net`` is no longer also
     # appended as a trailing positional input.
     # NOTE(review): confirm against batch_apply's signature.
     return batch_apply(self.net, *args)