コード例 #1
0
    def test_step(self, batch, batch_idx):
        """Evaluate one test batch.

        The returned dict contains:

        * ``test_loss`` / ``test_losses``: the aggregated and per-step losses
          returned by ``self(batch, test=True)``.
        * ``predicted_prop_seq``: frames produced by ``self.inference`` for
          ``hparams.Test["seq_len"]`` steps, seeded with an all-zero
          ``p1_face`` history. ``gt_seq`` is the ground-truth ``p1_face`` tail
          of the same length. Both are moved to CPU and detached.
        * For every modality in ``p2_face``, ``p2_speech``, ``p1_speech``
          whose conditioning history is > 0, three entries computed on a copy
          of the batch where that modality was passed through
          ``self.derange_batch``:
          - ``nll_mismatched_<m>``: the NLL;
          - ``losses_mismatched_<m>``: the per-step losses;
          - ``predicted_mismatch_<m>_seq``: the sequence generated from that
            deranged conditioning.
        """

        def zero_history_conditioning(source):
            # Build the conditioning dict for ``self.inference``: the p1_face
            # seed keeps only its shape (values zeroed) up to the longest
            # history, other modalities are forwarded unchanged (or None).
            history_len = get_longest_history(self.hparams.Conditioning)
            return {
                "p1_face": torch.zeros_like(source["p1_face"][:, :history_len]),
                "p2_face": source.get("p2_face"),
                "p1_speech": source.get("p1_speech"),
                "p2_speech": source.get("p2_speech"),
            }

        _, loss, losses = self(batch, test=True)
        output = {"test_loss": loss, "test_losses": losses}

        seq_len = self.hparams.Test["seq_len"]
        generated = self.inference(seq_len, data=zero_history_conditioning(batch))
        output["predicted_prop_seq"] = generated.cpu().detach()
        # Ground truth aligned with the generated frames (same trailing length).
        output["gt_seq"] = batch["p1_face"][:, -generated.shape[1]:].cpu().detach()

        for modality in ("p2_face", "p2_speech", "p1_speech"):
            # Only test modalities the model is actually conditioned on.
            if not self.hparams.Conditioning[modality]["history"] > 0:
                continue

            shuffled = self.derange_batch(batch, [modality])
            _, mismatched_nll, mismatched_losses = self(shuffled, test=True)
            output[f"nll_mismatched_{modality}"] = mismatched_nll.cpu().detach()
            output[f"losses_mismatched_{modality}"] = mismatched_losses

            mismatched_seq = self.inference(
                seq_len, data=zero_history_conditioning(shuffled)
            )
            output[f"predicted_mismatch_{modality}_seq"] = (
                mismatched_seq.cpu().detach()
            )

        return output
コード例 #2
0
    def inference(self, seq_len, data=None):
        """Autoregressively sample ``p1_face`` frames by running the flow in reverse.

        The RNN hidden state of ``self.glow`` is reset. Then, for every time
        step from ``get_longest_history(hparams.Conditioning)`` up to
        ``seq_len - 1``, one frame is sampled with ``self.glow(reverse=True)``
        using ``eps_std = hparams.Infer["eps"]``. Each frame is conditioned via
        ``self.create_conditioning`` on ``data`` and on every frame
        seeded/generated so far. Each sampled frame is appended along dim 1.

        Args:
            seq_len: Exclusive upper bound of the time steps to generate.
            data: Dict of conditioning inputs. ``data["p1_face"]`` is required
                despite the ``None`` default. It is indexed as ``[:, 0, :]``,
                presumably (batch, time, features) — TODO confirm.

        Returns:
            The seed+generated sequence with its first ``start_ts`` time steps
            removed. This is exactly the generated frames when the seed holds
            ``start_ts`` steps.

        NOTE(review): ``frame_nb`` starts at 1 here regardless of ``start_ts``,
        whereas ``forward`` starts from ``batch["frame_nb"] + 2 * start_ts``.
        Confirm the two numbering schemes are meant to agree.
        """
        self.glow.init_rnn_hidden()

        seed = data["p1_face"]
        # Zero tensor shaped like one frame; handed to the flow as output_shape.
        frame_template = torch.zeros_like(seed[:, 0, :])

        use_frame_nb = self.hparams.Conditioning["use_frame_nb"]
        if use_frame_nb:
            frame_nb = torch.ones((seed.shape[0], 1)).type_as(seed)
        else:
            frame_nb = None

        history = seed
        first_step = get_longest_history(self.hparams.Conditioning)

        for step in range(first_step, seq_len):
            condition = self.create_conditioning(data, step, frame_nb, history)
            sampled, _ = self.glow(
                condition=condition,
                eps_std=self.hparams.Infer["eps"],
                reverse=True,
                output_shape=frame_template,
            )
            # Append the new frame on the time axis so later steps see it.
            history = torch.cat([history, sampled.unsqueeze(1)], dim=1)
            if use_frame_nb:
                frame_nb += 2

        return history[:, first_step:]
コード例 #3
0
    def forward(self, batch):
        """Score ``batch["p1_face"]`` under the conditional flow, step by step.

        The RNN hidden state of ``self.glow`` is reset. Then every time step
        from ``get_longest_history(hparams.Conditioning)`` to the end of
        ``batch["p1_face"]`` is encoded by ``self.glow``, conditioned on
        ``self.create_conditioning(batch, time_st, frame_nb, batch["p1_face"])``.

        Args:
            batch: Dict of tensors. ``batch["p1_face"]`` is indexed as
                ``[:, time_st, :]``, so time is dim 1 (presumably
                (batch, seq, features) — TODO confirm). ``batch["frame_nb"]``
                is required when ``Conditioning["use_frame_nb"]`` is set.

        Returns:
            Tuple ``(z_seq, loss, losses)``:

            * ``z_seq``: list of detached latent codes, one per scored step.
            * ``loss``: the element-wise mean of each step's ``self.loss``
              output, averaged over steps, with a trailing dim added
              (``unsqueeze(-1)``).
            * ``losses``: list of the unreduced per-step losses, detached and
              on CPU.

        Raises:
            ValueError: if ``p1_face`` has no more time steps than the
                conditioning history, so there is no step to score.
        """
        self.glow.init_rnn_hidden()

        start_ts = get_longest_history(self.hparams.Conditioning)
        faces = batch["p1_face"]
        seq_len = faces.shape[1]
        if seq_len <= start_ts:
            # Without this guard the loop body never runs and the final
            # average divides by zero.
            raise ValueError(
                f"p1_face has {seq_len} time steps, which does not exceed the "
                f"conditioning history length {start_ts}; there is no step to score."
            )

        frame_nb = None
        if self.hparams.Conditioning["use_frame_nb"]:
            # Frame numbers advance by 2 per step; start at the first scored step.
            frame_nb = batch["frame_nb"].clone() + start_ts * 2

        loss = 0
        z_seq = []
        losses = []
        for time_st in range(start_ts, seq_len):
            curr_input = faces[:, time_st, :]
            condition = self.create_conditioning(batch, time_st, frame_nb, faces)

            z_enc, objective = self.glow(x=curr_input, condition=condition)
            tmp_loss = self.loss(objective, z_enc)
            # Detach before the device copy so autograd does not record it.
            losses.append(tmp_loss.detach().cpu())
            loss += torch.mean(tmp_loss)

            if frame_nb is not None:
                frame_nb += 2
            z_seq.append(z_enc.detach())

        return z_seq, (loss / len(z_seq)).unsqueeze(-1), losses
コード例 #4
0
    def on_validation_batch_end(self, trainer, pl_module, outputs, batch,
                                batch_idx, dataloader_idx=0):
        """Run extra validation diagnostics on the first validation batch.

        Only ``batch_idx == 0`` is processed. Depending on the flags in
        ``pl_module.hparams.Validation``:

        * ``inference``: generates ``Validation["seq_len"]`` steps with
          ``pl_module.seq_glow.inference``, seeded with the ground-truth
          ``p1_face`` history. Logs the mean jerk of the ground truth, of the
          generated sequence, and generated/gt. With ``render`` it also
          renders one randomly chosen sample.
        * ``check_invertion``: logs ``self.test_invertability`` as
          ``reconstruction/error_percentage``.
        * ``scale_logging``: calls ``self.log_scales``.
        * ``wrong_context_test``: re-scores the batch with the modalities
          listed in ``hparams.Mismatch`` shuffled across the batch or in time.
          Logs each mismatched NLL and ``actual - mismatched`` (logged under
          ``mismatched_nll_ratios/`` although it is a difference).

        Args:
            trainer: The Lightning trainer (unused).
            pl_module: Module under validation; must expose ``seq_glow`` and
                ``hparams``.
            outputs: Output of the validation step. Used only as the
                ``type_as`` reference when casting the batch, so presumably a
                tensor — TODO confirm.
            batch: Dict of tensors for the current batch.
            batch_idx: Index of the batch within the validation epoch.
            dataloader_idx: Dataloader index. Defaults to 0 because
                Lightning >= 2.0 omits it when there is a single dataloader.
        """
        if batch_idx != 0:
            return

        hparams = pl_module.hparams
        # Cast every tensor in the batch to the dtype/device of ``outputs``.
        new_batch = {name: tensor.type_as(outputs) for name, tensor in batch.items()}
        z_seq, loss, _ = pl_module.seq_glow(new_batch)

        if hparams.Validation["inference"]:
            seq_len = hparams.Validation["seq_len"]
            history_len = get_longest_history(hparams.Conditioning)
            cond_data = {
                "p1_face": new_batch["p1_face"][:, :history_len],
                "p2_face": new_batch.get("p2_face"),
                "p1_speech": new_batch.get("p1_speech"),
                "p2_speech": new_batch.get("p2_speech"),
            }
            predicted_seq = pl_module.seq_glow.inference(seq_len, data=cond_data)

            # Compare against the ground-truth tail of the same length.
            gt_mean_jerk = calc_jerk(new_batch["p1_face"][:, -predicted_seq.shape[1]:])
            generated_mean_jerk = calc_jerk(predicted_seq)

            pl_module.log("jerk/gt_mean", gt_mean_jerk)
            pl_module.log("jerk/generated_mean", generated_mean_jerk)
            pl_module.log("jerk/generated_mean_ratio",
                          generated_mean_jerk / gt_mean_jerk)

            if hparams.Validation["render"]:
                idx = random.randint(0, cond_data["p1_face"].shape[0] - 1)
                self.render_results(predicted_seq, new_batch, idx, pl_module)

        if hparams.Validation["check_invertion"]:
            # Test if the Flow works correctly
            det_check = self.test_invertability(z_seq, loss, new_batch, pl_module)
            pl_module.log("reconstruction/error_percentage", det_check)

        if hparams.Validation["scale_logging"]:
            self.log_scales(pl_module)

        # Test if the Flow is listening to other modalities
        if hparams.Validation["wrong_context_test"]:
            mismatch = hparams.Mismatch
            pl_module.log("mismatched_nll/actual_nll", loss)

            # (config section / log prefix, extra kwargs for derange_batch)
            shuffle_modes = (
                ("shuffle_batch", {}),
                ("shuffle_time", {"shuffle_time": True}),
            )
            for mode, derange_kwargs in shuffle_modes:
                for key, modalities in mismatch[mode].items():
                    # Skip modalities the model is not conditioned on.
                    if not all(hparams.Conditioning[m]["history"] > 0
                               for m in modalities):
                        continue
                    deranged_batch = derange_batch(new_batch, modalities,
                                                   **derange_kwargs)
                    _, mismatched_nll, _ = pl_module.seq_glow(deranged_batch)

                    pl_module.log(f"mismatched_nll/{mode}_{key}", mismatched_nll)
                    pl_module.log(f"mismatched_nll_ratios/{mode}_{key}",
                                  loss - mismatched_nll)
コード例 #5
0
    def validation_step(self, batch, batch_idx):
        """Compute the validation loss and, on the first batch, extra diagnostics.

        Args:
            batch: Dict of tensors (``p1_face`` required, other modalities
                optional).
            batch_idx: Index of the batch within the validation epoch.

        Returns:
            Dict with ``val_loss``. For ``batch_idx == 0`` it also holds:

            * ``jerk``: contains ``gt_mean`` / ``generated_mean`` when
              ``Validation["inference"]`` is set, else an empty dict.
            * ``det_check``: when ``Validation["check_invertion"]`` is set.
            * ``mismatched_nll``: when ``Validation["wrong_context_test"]`` is
              set. Holds ``actual_nll`` plus ``shuffle_batch_<key>`` /
              ``shuffle_time_<key>`` NLLs for each configured mismatch whose
              modalities all have a conditioning history > 0.

        Raises:
            optuna.exceptions.TrialPruned: under Optuna (``hparams.optuna``),
                once ``global_step > 20``, if the loss is positive.
        """
        z_seq, loss, _ = self(batch)
        if self.hparams.optuna and self.global_step > 20 and loss > 0:
            raise optuna.exceptions.TrialPruned("Trial was pruned since loss > 0")
        output = {"val_loss": loss}

        # The expensive diagnostics only run on the first batch.
        if batch_idx != 0:
            return output

        validation = self.hparams.Validation
        output["jerk"] = {}

        if validation["inference"]:
            seq_len = validation["seq_len"]
            history_len = get_longest_history(self.hparams.Conditioning)
            cond_data = {
                "p1_face": batch["p1_face"][:, :history_len],
                "p2_face": batch.get("p2_face"),
                "p1_speech": batch.get("p1_speech"),
                "p2_speech": batch.get("p2_speech"),
            }
            predicted_seq = self.inference(seq_len, data=cond_data)

            # Compare against the ground-truth tail of the same length.
            output["jerk"]["gt_mean"] = calc_jerk(
                batch["p1_face"][:, -predicted_seq.shape[1]:]
            )
            output["jerk"]["generated_mean"] = calc_jerk(predicted_seq)

            if validation["render"]:
                idx = random.randint(0, cond_data["p1_face"].shape[0] - 1)
                self.render_results(predicted_seq, batch, idx)

        if validation["check_invertion"]:
            # Test if the Flow works correctly
            output["det_check"] = self.test_invertability(z_seq, loss, batch)

        if validation["scale_logging"]:
            self.log_scales()

        # Test if the Flow is listening to other modalities
        if validation["wrong_context_test"]:
            mismatch = self.hparams.Mismatch
            mismatched = {"actual_nll": loss}
            output["mismatched_nll"] = mismatched

            # (config section / output key prefix, extra kwargs for derange_batch)
            shuffle_modes = (
                ("shuffle_batch", {}),
                ("shuffle_time", {"shuffle_time": True}),
            )
            for mode, derange_kwargs in shuffle_modes:
                for key, modalities in mismatch[mode].items():
                    # Skip modalities the model is not conditioned on.
                    if not all(self.hparams.Conditioning[m]["history"] > 0
                               for m in modalities):
                        continue
                    deranged_batch = self.derange_batch(batch, modalities,
                                                        **derange_kwargs)
                    _, mismatched_nll, _ = self(deranged_batch)
                    mismatched[f"{mode}_{key}"] = mismatched_nll

        return output