def run_encoder(self, inputs, start_ind):
    if 'demo_seq' in inputs:
        # encode the full demonstration sequence (if not already encoded)
        if not 'enc_demo_seq' in inputs:
            inputs.enc_demo_seq, inputs.skips = batch_apply(inputs.demo_seq, self.encoder)
            if self._hp.use_convs and self._hp.use_skips:
                inputs.skips = map_recursive(lambda s: s[:, 0], inputs.skips)  # only use start image activations

    # encode start and goal observations
    if self._hp.separate_cnn_start_goal_encoder:
        enc_e_0, inputs.skips = self.start_goal_enc(inputs.I_0_image)
        inputs.enc_e_0 = remove_spatial(enc_e_0)
        inputs.enc_e_g = remove_spatial(self.start_goal_enc(inputs.I_g_image)[0])
    else:
        inputs.enc_e_0, inputs.skips = self.encoder(inputs.I_0)
        inputs.enc_e_g = self.encoder(inputs.I_g)[0]

    # inference-network encodings over the demo sequence
    if 'demo_seq' in inputs:
        if self._hp.act_cond_inference:
            inputs.inf_enc_seq = self.inf_encoder(inputs.enc_demo_seq, inputs.actions)
        elif self._hp.states_inference:
            inputs.inf_enc_seq = batch_apply((inputs.enc_demo_seq, inputs.demo_seq_states[..., None, None]),
                                             self.inf_encoder, separate_arguments=True)
        else:
            inputs.inf_enc_seq = self.inf_encoder(inputs.enc_demo_seq)
        inputs.inf_enc_key_seq = self.inf_key_encoder(inputs.enc_demo_seq)

    if self._hp.action_conditioned_pred:
        inputs.enc_action_seq = batch_apply(inputs.actions, self.action_encoder)
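
# --- Hedged sketch (illustrative, not part of the model code above) ---
# run_encoder keeps only the first-frame skip activations via
# `map_recursive(lambda s: s[:, 0], inputs.skips)`.  The helper below is a
# hypothetical stand-in for such a recursive map, assuming the skips are a
# nested list/tuple/dict structure of tensors shaped [batch, time, ...].
import torch


def map_recursive_sketch(fn, struct):
    """Apply fn to every tensor in a nested list/tuple/dict structure."""
    if isinstance(struct, torch.Tensor):
        return fn(struct)
    if isinstance(struct, (list, tuple)):
        return type(struct)(map_recursive_sketch(fn, s) for s in struct)
    if isinstance(struct, dict):
        return {k: map_recursive_sketch(fn, v) for k, v in struct.items()}
    return struct


def _demo_map_recursive_sketch():
    # toy nested skip structure: [batch=2, time=8, channels, H, W]
    skips = [torch.randn(2, 8, 32, 16, 16), [torch.randn(2, 8, 64, 8, 8)]]
    first_frame = map_recursive_sketch(lambda s: s[:, 0], skips)
    assert first_frame[0].shape == (2, 32, 16, 16)
    return first_frame
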
def prune_sequence(self, inputs, outputs, key='images'):
    seq = getattr(outputs.tree.df, key)
    latent_seq = outputs.tree.df.e_g_prime

    existence = batch_apply(latent_seq, self.existence_predictor)[..., 0]
    outputs.existence_predictor = AttrDict(existence=existence)

    # keep only frames whose predicted existence probability exceeds 0.5
    existing_frames = torch.sigmoid(existence) > 0.5
    pruned_seq = [seq[i][existing_frames[i]] for i in range(seq.shape[0])]

    return pruned_seq
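
# --- Hedged sketch (illustrative only) ---
# prune_sequence above keeps the frames whose predicted existence logit passes
# 0.5 after a sigmoid.  The toy example below reproduces just that masking
# step; `existence_logits` and all shapes are made up for illustration.
import torch


def _demo_existence_pruning():
    seq = torch.randn(2, 6, 3, 16, 16)       # [batch, time, C, H, W]
    existence_logits = torch.randn(2, 6)     # one logit per frame
    existing_frames = torch.sigmoid(existence_logits) > 0.5
    # variable-length result: one tensor of kept frames per batch element
    pruned = [seq[i][existing_frames[i]] for i in range(seq.shape[0])]
    return pruned
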
def forward(self, input, actions):
    net_outputs = self.net(input)
    padded_actions = torch.nn.functional.pad(
        actions, (0, 0, 0, net_outputs.shape[1] - actions.shape[1], 0, 0))
    # TODO quite sure the concatenation is automatic
    net_outputs = batch_apply(
        self.ac_net, torch.cat([net_outputs, broadcast_final(padded_actions, input)], dim=2))
    return net_outputs
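
# --- Hedged sketch of the padding step above ---
# F.pad reads the pad tuple from the last dimension backwards in pairs, so
# (0, 0, 0, k, 0, 0) leaves dims -1 and -3 untouched and appends k zero steps
# to dim -2, i.e. the time dimension of a [batch, time, action_dim] tensor.
# Shapes below are illustrative.
import torch
import torch.nn.functional as F


def _demo_action_padding():
    actions = torch.randn(4, 10, 2)                            # [B, T, A]
    target_len = 15                                            # e.g. length of the predicted sequence
    padded = F.pad(actions, (0, 0, 0, target_len - actions.shape[1], 0, 0))
    assert padded.shape == (4, 15, 2)
    return padded
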
def main():
    args = get_trainer_args()

    def load_videos(path):
        print("Loading trajectories from {}".format(path))
        if not path.endswith('.npy'):
            raise ValueError("Can only read in .npy files!")
        seqs = np.load(path)
        assert len(seqs.shape) == 5     # need [batch, T, C, H, W] input data
        assert seqs.shape[2] == 3       # assume 3-channel images in [batch, T, C, H, W] layout
        seqs = torch.Tensor(seqs)
        if args.use_gpu:
            seqs = seqs.cuda()
        return seqs     # range [-1, 1]

    gt_seqs = load_videos(args.gt)
    pred_seqs = load_videos(args.pred)

    print('shape: ', gt_seqs.shape)
    assert gt_seqs.shape == pred_seqs.shape
    n_seqs, time, c, h, w = gt_seqs.shape
    n_batches = int(np.floor(n_seqs / args.batch_size))
    # import pdb; pdb.set_trace()

    # get sequence mask (for sequences with variable length)
    mask = 1 - torch.all(torch.all(torch.all((gt_seqs + 1.0).abs() < 1e-6, dim=-1), dim=-1), dim=-1)    # check for black images
    mask2 = 1 - torch.all(torch.all(torch.all(gt_seqs.abs() < 1e-6, dim=-1), dim=-1), dim=-1)           # check for gray images
    mask = mask * mask2

    # Initializing the model
    model = models.PerceptualLoss(model='net-lin', net='alex', use_gpu=args.use_gpu)

    # run forward pass to compute LPIPS distances
    distances = []
    for b in range(n_batches):
        x = gt_seqs[b * args.batch_size:(b + 1) * args.batch_size]
        y = pred_seqs[b * args.batch_size:(b + 1) * args.batch_size]
        lpips_dist = batch_apply((x, y), model, separate_arguments=True)
        distances.append(lpips_dist)
    distances = torch.cat(distances)
    mean_distance = distances[mask].mean()

    print("LPIPS distance: {}".format(mean_distance))
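
# --- Hedged sketch of the sequence-mask logic above ---
# With images in [-1, 1], a frame that is exactly -1 everywhere is black and
# a frame that is exactly 0 everywhere is gray; both are treated as padding
# and masked out before averaging the LPIPS distances.  The toy shapes below
# are illustrative.
import torch


def _demo_padding_mask():
    seqs = torch.rand(3, 5, 3, 8, 8) * 2 - 1               # [N, T, C, H, W] in [-1, 1]
    seqs[0, 4] = -1.0                                      # simulate a black (padded) frame
    not_black = ~torch.all(torch.all(torch.all((seqs + 1.0).abs() < 1e-6, dim=-1), dim=-1), dim=-1)
    not_gray = ~torch.all(torch.all(torch.all(seqs.abs() < 1e-6, dim=-1), dim=-1), dim=-1)
    mask = not_black & not_gray                            # [N, T] boolean mask of valid frames
    assert not mask[0, 4].item()
    return mask
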
def forward(self, *inp):
    inp = concat_inputs(*inp, dim=2)
    n = self._hp.n_lstm_layers
    if self.net.bidirectional:
        n = n * 2
    # nn.LSTM expects the state tuple ordered as (h_0, c_0); both are zero-initialized here
    h0 = inp.new_zeros(n, inp.shape[0], self._hp.nz_mid_lstm)
    c0 = inp.new_zeros(n, inp.shape[0], self._hp.nz_mid_lstm)
    out = self.net(inp, (h0, c0))[0]
    projected = batch_apply(self.out_projection, out.contiguous())
    return projected
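
# --- Hedged sketch (plain PyTorch, not the project's LSTM wrapper) ---
# Minimal batch-first LSTM call with explicitly zero-initialized states,
# mirroring the forward above: the state tuple is (h_0, c_0), each shaped
# [num_layers * num_directions, batch, hidden_size].
import torch
import torch.nn as nn


def _demo_lstm_zero_state():
    lstm = nn.LSTM(input_size=16, hidden_size=32, num_layers=2, batch_first=True)
    inp = torch.randn(4, 10, 16)                           # [batch, time, features]
    n = lstm.num_layers * (2 if lstm.bidirectional else 1)
    h0 = inp.new_zeros(n, inp.shape[0], lstm.hidden_size)
    c0 = inp.new_zeros(n, inp.shape[0], lstm.hidden_size)
    out, _ = lstm(inp, (h0, c0))
    assert out.shape == (4, 10, 32)
    return out
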
def prune_sequence(self, inputs, outputs, key='images'):
    seq = getattr(outputs.tree.df, key)
    latent_seq = outputs.tree.df.e_g_prime

    distances = batch_apply(self.distance_predictor,
                            latent_seq[:, :-1].contiguous(),
                            latent_seq[:, 1:].contiguous())[..., 0]
    outputs.distance_predictor = AttrDict(distances=distances)

    # distance_predictor outputs true if the two frames are too close
    close_frames = torch.sigmoid(distances) > self._hp.learned_pruning_threshold
    # Add a placeholder for the first frame
    close_frames = torch.cat([torch.zeros_like(close_frames[:, [0]]), close_frames], 1)

    pruned_seq = [seq[i][~close_frames[i]] for i in range(seq.shape[0])]

    return pruned_seq
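
# --- Hedged sketch of the adjacent-frame pruning above ---
# A per-pair "too close" decision only has length T-1, so a placeholder (never
# pruned) is prepended for the first frame before indexing.  The logits and
# shapes below are made up for illustration.
import torch


def _demo_distance_pruning(threshold=0.5):
    seq = torch.randn(2, 6, 3, 16, 16)                     # [batch, time, C, H, W]
    pair_logits = torch.randn(2, 5)                        # one logit per adjacent frame pair
    close_frames = torch.sigmoid(pair_logits) > threshold
    close_frames = torch.cat([torch.zeros_like(close_frames[:, [0]]), close_frames], 1)
    pruned = [seq[i][~close_frames[i]] for i in range(seq.shape[0])]
    return pruned
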
def decode_seq(self, inputs, encodings):
    """ Decodes a sequence of images given the encodings.

    :param inputs: AttrDict providing the skip activations (inputs.skips) and the
        start/goal images (inputs.I_0, inputs.I_g) used as pixel sources
    :param encodings: latent encodings of shape [batch, time, ...], one per output frame
    :return: decoded sequence
    """
    # TODO skip from the goal as well
    extend_to_seq = lambda x: x[:, None][:, [0] * encodings.shape[1]].contiguous()
    seq_skips = rmap(extend_to_seq, inputs.skips)
    pixel_source = rmap(extend_to_seq, [inputs.I_0, inputs.I_g])

    return batch_apply(self, input=encodings, skips=seq_skips, pixel_source=pixel_source)
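
# --- Hedged sketch of the `extend_to_seq` tiling above ---
# `x[:, None][:, [0] * T]` inserts a time axis and repeats the single entry
# T times, turning a per-sequence tensor [B, ...] into [B, T, ...] so it can
# be decoded alongside the per-frame encodings.  Shapes are illustrative.
import torch


def _demo_extend_to_seq():
    T = 7
    I_0 = torch.randn(4, 3, 32, 32)                        # one start image per sequence
    extend_to_seq = lambda x: x[:, None][:, [0] * T].contiguous()
    tiled = extend_to_seq(I_0)
    assert tiled.shape == (4, T, 3, 32, 32)
    return tiled
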
def produce_tree(self, root, tree, tree_inputs, inputs, outputs):
    # Produce the tree to get the matching
    root.produce_tree_cont_time(*tree_inputs, self.one_step_planner, self._hp.hierarchy_levels)

    if not self.one_step_planner._sample_prior:
        tree.set_attr_bf(**self.decoder.decode_seq(inputs, tree.bf.e_g_prime))
        tree.bf.match_dist = outputs.gt_match_dists = self.one_step_planner.matcher.get_w(
            inputs.pad_mask, inputs, outputs)

        # gather the ground-truth latent that each tree node was matched to
        matched_index = tree.bf.match_dist.argmax(-1)
        tiled_enc_demo_seq = inputs.enc_demo_seq[:, None].repeat_interleave(matched_index.shape[1], 1)
        matched_latents = batch_apply([tiled_enc_demo_seq, matched_index],
                                      lambda pair: batchwise_index(pair[0], pair[1]))
        tree.bf.e_g_prime = matched_latents
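
# --- Hedged sketch (a stand-in for the library's batchwise_index) ---
# produce_tree gathers, for every tree node, the ground-truth latent it was
# matched to: roughly `out[b, n] = enc_demo_seq[b, matched_index[b, n]]`.
# The helper below is a hypothetical equivalent for 2-D index tensors, not
# the project's actual implementation.
import torch


def batchwise_index_sketch(seq, idx):
    """seq: [B, T, D]; idx: [B, N] of time indices -> [B, N, D]."""
    batch_ids = torch.arange(seq.shape[0])[:, None].expand_as(idx)
    return seq[batch_ids, idx]


def _demo_batchwise_index():
    enc_demo_seq = torch.randn(2, 10, 8)
    matched_index = torch.randint(0, 10, (2, 5))
    matched = batchwise_index_sketch(enc_demo_seq, matched_index)
    assert matched.shape == (2, 5, 8)
    return matched
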
def forward(self, inputs, phase='train'):
    """
    forward pass at training time
    :param inputs:
        images shape = batch x time x height x width x channel
        pad mask shape = batch x time, 1 indicates actual image, 0 is padded
    :return:
    """
    if self._hp.non_goal_conditioned:
        if 'demo_seq' in inputs:
            inputs.demo_seq[torch.arange(inputs.demo_seq.shape[0]), inputs.end_ind] = 0.0
            inputs.demo_seq_images[torch.arange(inputs.demo_seq.shape[0]), inputs.end_ind] = 0.0
        inputs.I_g = torch.zeros_like(inputs.I_g)
        if "I_g_image" in inputs:
            inputs.I_g_image = torch.zeros_like(inputs.I_g_image)
        if inputs.I_0.shape[-1] == 5:   # special hack for maze
            inputs.I_0[..., -2:] = 0.0
            if "demo_seq" in inputs:
                inputs.demo_seq[..., -2:] = 0.0

    # swap in actions if we want to train action sequence decoder
    if self._hp.train_on_action_seqs:
        inputs.demo_seq = torch.cat([inputs.actions, torch.zeros_like(inputs.actions[:, :1])], dim=1)

    model_output = AttrDict()
    inputs.reference_tensor = find_tensor(inputs)

    if 'start_ind' not in inputs:
        start_ind = torch.zeros(self._hp.batch_size, dtype=torch.long, device=inputs.reference_tensor.device)
    else:
        start_ind = inputs.start_ind

    self.run_encoder(inputs, start_ind)

    end_ind = inputs.end_ind if 'end_ind' in inputs else None
    if self._hp.regress_length:
        # predict total sequence length
        model_output.update(self.length_pred(inputs.enc_e_0, inputs.enc_e_g))
        if self._use_pred_length and (self._hp.length_pred_weight > 0 or end_ind is None):
            end_ind = torch.argmax(model_output.seq_len_pred.sample().long(), dim=1)
            if self._hp.action_conditioned_pred or self._hp.non_goal_conditioned:
                # don't use predicted length when action conditioned
                end_ind = torch.ones_like(end_ind) * (self._hp.max_seq_len - 1)

    # TODO clean this up. model_output.end_ind is not currently used anywhere
    model_output.end_ind = end_ind

    # Run the model to generate sequences
    model_output.update(self.predict_sequence(inputs, model_output, start_ind, end_ind, phase))

    if self.prune_sequences:
        if phase == 'train':
            inputs.model_enc_seq = self.get_matched_pruned_seqs(inputs, model_output)
        else:
            inputs.model_enc_seq = self.get_predicted_pruned_seqs(inputs, model_output)
        inputs.model_enc_seq = pad_sequence(inputs.model_enc_seq, batch_first=True)
        if len(inputs.model_enc_seq.shape) == 5:
            inputs.model_enc_seq = inputs.model_enc_seq[..., 0, 0]

    if self._hp.attach_inv_mdl and phase == 'train':
        model_output.update(self.inv_mdl(inputs, full_seq=self._inv_mdl_full_seq or self._hp.train_inv_mdl_full_seq))

    if self._hp.attach_state_regressor:
        regressor_inputs = inputs.model_enc_seq
        if not self._hp.supervised_decoder:
            regressor_inputs = regressor_inputs.detach()
        model_output.regressed_state = batch_apply(regressor_inputs, self.state_regressor)

    if self._hp.attach_cost_mdl and self._hp.run_cost_mdl and phase == 'train':
        # There is an issue here since SVG doesn't output a latent for the first image
        # Beyond conceptual problems, this breaks if end_ind = 199
        model_output.update(self.cost_mdl(inputs))

    return model_output
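
# --- Hedged sketch of the pad_sequence step above ---
# Pruning returns one variable-length tensor per batch element;
# torch.nn.utils.rnn.pad_sequence stacks them back into a dense
# [batch, max_len, ...] tensor, zero-padding the shorter ones.
import torch
from torch.nn.utils.rnn import pad_sequence


def _demo_pad_pruned_seqs():
    pruned = [torch.randn(3, 8), torch.randn(5, 8), torch.randn(2, 8)]
    dense = pad_sequence(pruned, batch_first=True)
    assert dense.shape == (3, 5, 8)
    return dense
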
def forward(self, seq):
    return batch_apply(self.net, seq.contiguous())
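
# --- Hedged sketch (a stand-in for the library's batch_apply) ---
# The forwards above rely on a batch_apply helper to run a per-step module
# over a [batch, time, ...] tensor.  A common way to do this is to fold time
# into the batch dimension, apply the module once, and unfold again; the
# sketch below assumes that behavior and is not the project's implementation.
import torch
import torch.nn as nn


def batch_apply_sketch(fn, seq):
    """seq: [B, T, ...] -> fn applied per step, output reshaped back to [B, T, ...]."""
    b, t = seq.shape[:2]
    out = fn(seq.reshape(b * t, *seq.shape[2:]))
    return out.reshape(b, t, *out.shape[1:])


def _demo_batch_apply_sketch():
    net = nn.Linear(16, 4)
    seq = torch.randn(2, 9, 16)
    out = batch_apply_sketch(net, seq)
    assert out.shape == (2, 9, 4)
    return out
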
def forward(self, *args):
    return batch_apply(self.net, *args)