def test_step(self, batch, batch_idx):
    """Run one test batch.

    Computes the NLL on the real batch, generates a free-running sequence,
    and repeats both measurements with each conditioning modality deranged
    (shuffled across the batch) to probe whether the model uses it.
    Returns a dict of losses and generated/ground-truth sequences.
    """
    history = get_longest_history(self.hparams.Conditioning)
    seq_len = self.hparams.Test["seq_len"]

    def _conditioning(src):
        # Seed p1 with zeros so inference cannot peek at ground truth;
        # the remaining modalities pass through unchanged (may be absent).
        return {
            "p1_face": torch.zeros_like(src["p1_face"][:, :history]),
            "p2_face": src.get("p2_face"),
            "p1_speech": src.get("p1_speech"),
            "p2_speech": src.get("p2_speech"),
        }

    _, loss, losses = self(batch, test=True)
    output = {"test_loss": loss, "test_losses": losses}

    predicted_seq = self.inference(seq_len, data=_conditioning(batch))
    output["predicted_prop_seq"] = predicted_seq.cpu().detach()
    # Ground truth aligned with the tail of the generated sequence.
    output["gt_seq"] = batch["p1_face"][:, -predicted_seq.shape[1]:].cpu().detach()

    for modality in ["p2_face", "p2_speech", "p1_speech"]:
        if self.hparams.Conditioning[modality]["history"] <= 0:
            continue  # modality not conditioned on — nothing to probe
        shuffled = self.derange_batch(batch, [modality])
        _, missaligned_nll, misaligned_losses = self(shuffled, test=True)
        output[f"nll_mismatched_{modality}"] = missaligned_nll.cpu().detach()
        output[f"losses_mismatched_{modality}"] = misaligned_losses
        mismatch_pred = self.inference(seq_len, data=_conditioning(shuffled))
        output[f"predicted_mismatch_{modality}_seq"] = mismatch_pred.cpu().detach()
    return output
def inference(self, seq_len, data=None):
    """Autoregressively sample p1 face frames from the flow.

    Starts from the history frames already present in ``data["p1_face"]``
    and generates one frame per step until ``seq_len``; returns only the
    newly generated frames (history stripped off).
    """
    self.glow.init_rnn_hidden()
    seed = data["p1_face"]
    # Template tensor fixing the per-frame output shape/dtype/device.
    output_shape = torch.zeros_like(seed[:, 0, :])
    use_frame_nb = self.hparams.Conditioning["use_frame_nb"]
    frame_nb = (
        torch.ones((seed.shape[0], 1)).type_as(seed) if use_frame_nb else None
    )
    generated = seed
    warmup = get_longest_history(self.hparams.Conditioning)
    eps_std = self.hparams.Infer["eps"]
    for step in range(warmup, seq_len):
        cond = self.create_conditioning(data, step, frame_nb, generated)
        frame, _ = self.glow(
            condition=cond,
            eps_std=eps_std,
            reverse=True,
            output_shape=output_shape,
        )
        generated = torch.cat([generated, frame.unsqueeze(1)], dim=1)
        if use_frame_nb:
            frame_nb += 2  # counter advances by 2 per step — TODO confirm stride
    return generated[:, warmup:]
def forward(self, batch, test=False):
    """Teacher-forced pass: encode each frame after the history window and
    accumulate the flow NLL.

    Args:
        batch: dict of tensors; reads ``p1_face`` (and ``frame_nb`` when
            frame-number conditioning is enabled).
        test: accepted for compatibility with callers that invoke
            ``self(batch, test=True)`` (see ``test_step``); currently
            ignored. NOTE(review): previously this kwarg raised a
            TypeError via ``Module.__call__`` — confirm no test-specific
            behavior was intended.

    Returns:
        (z_seq, mean_loss, losses): detached latents per timestep, the
        averaged loss as a 1-element tensor, and per-timestep losses on CPU.
    """
    self.glow.init_rnn_hidden()
    loss = 0
    start_ts = get_longest_history(self.hparams.Conditioning)
    use_frame_nb = self.hparams.Conditioning["use_frame_nb"]
    frame_nb = None
    if use_frame_nb:
        # Offset by the warm-up window; counter steps by 2 per frame.
        frame_nb = batch["frame_nb"].clone() + start_ts * 2
    z_seq = []
    losses = []
    for time_st in range(start_ts, batch["p1_face"].shape[1]):
        curr_input = batch["p1_face"][:, time_st, :]
        condition = self.create_conditioning(
            batch, time_st, frame_nb, batch["p1_face"]
        )
        z_enc, objective = self.glow(x=curr_input, condition=condition)
        tmp_loss = self.loss(objective, z_enc)
        losses.append(tmp_loss.cpu().detach())
        loss += torch.mean(tmp_loss)
        if use_frame_nb:
            frame_nb += 2
        z_seq.append(z_enc.detach())
    return z_seq, (loss / len(z_seq)).unsqueeze(-1), losses
def on_validation_batch_end(self, trainer, pl_module, outputs, batch,
                            batch_idx, dataloader_idx):
    """Callback-side validation diagnostics, run on the first batch only.

    Logs jerk statistics of generated motion, a flow-invertibility check,
    scale histograms, and NLL on modality-mismatched batches.
    Fixes: removed an f-string with no placeholder (F541) and the dead
    ``output``/``idx`` locals (this hook returns nothing).
    """
    if batch_idx != 0:  # diagnostics are expensive; one batch suffices
        return
    # NOTE(review): assumes every value in `batch` is a tensor and that
    # `outputs` is a tensor to borrow dtype/device from — confirm upstream.
    new_batch = {key: val.type_as(outputs) for key, val in batch.items()}
    z_seq, loss, _ = pl_module.seq_glow(new_batch)

    if pl_module.hparams.Validation["inference"]:
        seq_len = pl_module.hparams.Validation["seq_len"]
        cond_data = {
            "p1_face": new_batch["p1_face"]
            [:, :get_longest_history(pl_module.hparams.Conditioning)],
            "p2_face": new_batch.get("p2_face"),
            "p1_speech": new_batch.get("p1_speech"),
            "p2_speech": new_batch.get("p2_speech"),
        }
        predicted_seq = pl_module.seq_glow.inference(seq_len, data=cond_data)
        gt_mean_jerk = calc_jerk(new_batch["p1_face"][:, -predicted_seq.shape[1]:])
        generated_mean_jerk = calc_jerk(predicted_seq)
        pl_module.log("jerk/gt_mean", gt_mean_jerk)
        pl_module.log("jerk/generated_mean", generated_mean_jerk)
        pl_module.log("jerk/generated_mean_ratio",
                      generated_mean_jerk / gt_mean_jerk)
        if pl_module.hparams.Validation["render"]:
            # Render a single randomly chosen sample from the batch.
            idx = random.randint(0, cond_data["p1_face"].shape[0] - 1)
            self.render_results(predicted_seq, new_batch, idx, pl_module)

    if pl_module.hparams.Validation["check_invertion"]:
        # Round-trip through the flow to confirm it inverts correctly.
        det_check = self.test_invertability(z_seq, loss, new_batch, pl_module)
        pl_module.log("reconstruction/error_percentage", det_check)

    if pl_module.hparams.Validation["scale_logging"]:
        self.log_scales(pl_module)

    # Probe that the flow listens to its conditioning: NLL should worsen
    # when a modality is shuffled across the batch or across time.
    if pl_module.hparams.Validation["wrong_context_test"]:
        mismatch = pl_module.hparams.Mismatch
        pl_module.log("mismatched_nll/actual_nll", loss)
        for group, shuffle_time in (("shuffle_batch", False),
                                    ("shuffle_time", True)):
            for key, modalities in mismatch[group].items():
                if not all(pl_module.hparams.Conditioning[m]["history"] > 0
                           for m in modalities):
                    continue
                # shuffle_time=False matches derange_batch's default —
                # TODO confirm the helper's signature.
                deranged_batch = derange_batch(new_batch, modalities,
                                               shuffle_time=shuffle_time)
                _, shuffled_nll, _ = pl_module.seq_glow(deranged_batch)
                pl_module.log(f"mismatched_nll/{group}_{key}", shuffled_nll)
                pl_module.log(f"mismatched_nll_ratios/{group}_{key}",
                              loss - shuffled_nll)
def validation_step(self, batch, batch_idx):
    """Validation pass: NLL, plus (first batch only) inference diagnostics —
    jerk, invertibility check, scale logging, and mismatch probes.

    Fixes: removed f-string prefixes with no placeholders (F541), a dead
    initial ``idx`` draw, and flattened the big nested block with an early
    return; the duplicated shuffle loops are unified.

    Returns:
        dict with ``val_loss`` and, on batch 0, the diagnostic entries.

    Raises:
        optuna.exceptions.TrialPruned: when running under optuna and the
        loss is still positive after 20 global steps.
    """
    z_seq, loss, _ = self(batch)
    if self.hparams.optuna and self.global_step > 20 and loss > 0:
        # Positive NLL after warm-up means the trial is hopeless.
        raise optuna.exceptions.TrialPruned("Trial was pruned since loss > 0")

    output = {"val_loss": loss}
    if batch_idx != 0:  # heavy diagnostics only on the first batch
        return output

    output["jerk"] = {}
    if self.hparams.Validation["inference"]:
        seq_len = self.hparams.Validation["seq_len"]
        cond_data = {
            "p1_face": batch["p1_face"][
                :, :get_longest_history(self.hparams.Conditioning)
            ],
            "p2_face": batch.get("p2_face"),
            "p1_speech": batch.get("p1_speech"),
            "p2_speech": batch.get("p2_speech"),
        }
        predicted_seq = self.inference(seq_len, data=cond_data)
        output["jerk"]["gt_mean"] = calc_jerk(
            batch["p1_face"][:, -predicted_seq.shape[1]:]
        )
        output["jerk"]["generated_mean"] = calc_jerk(predicted_seq)
        if self.hparams.Validation["render"]:
            # Render one randomly chosen sample.
            idx = random.randint(0, cond_data["p1_face"].shape[0] - 1)
            self.render_results(predicted_seq, batch, idx)

    if self.hparams.Validation["check_invertion"]:
        # Test if the Flow works correctly (z -> x -> z round trip).
        output["det_check"] = self.test_invertability(z_seq, loss, batch)

    if self.hparams.Validation["scale_logging"]:
        self.log_scales()

    # Test if the Flow is listening to other modalities.
    if self.hparams.Validation["wrong_context_test"]:
        mismatch = self.hparams.Mismatch
        output["mismatched_nll"] = {"actual_nll": loss}
        for group, shuffle_time in (("shuffle_batch", False),
                                    ("shuffle_time", True)):
            for key, modalities in mismatch[group].items():
                if not all(self.hparams.Conditioning[m]["history"] > 0
                           for m in modalities):
                    continue
                # shuffle_time=False matches derange_batch's default —
                # TODO confirm the helper's signature.
                deranged_batch = self.derange_batch(
                    batch, modalities, shuffle_time=shuffle_time
                )
                _, shuffled_nll, _ = self(deranged_batch)
                output["mismatched_nll"][f"{group}_{key}"] = shuffled_nll
    return output