Example #1
0
    def _test_g(
        self,
        G: ModelTrainer,
        real,
        *,
        mse_normalizer: float = 32768,
    ):
        """Run the generator on a test batch and record its test loss and metric.

        The output of ``G.forward(real)`` is passed through a sigmoid and
        compared against ``real`` itself, i.e. ``real`` is both the input and
        the target.

        Args:
            G: Trainer for the generator model.
            real: Input batch, also used as the comparison target.
            mse_normalizer: Divisor applied to the "MeanSquaredError" metric
                before it is recorded. Defaults to 32768, the value that was
                previously hard-coded here.

        Returns:
            The sigmoid-activated generator prediction.
        """
        # torch.sigmoid is the canonical op; torch.nn.functional.sigmoid is an
        # alias that older PyTorch releases flagged as deprecated.
        gen_pred = torch.sigmoid(G.forward(real))

        G.compute_and_update_test_loss("MSELoss", gen_pred, real)

        # NOTE(review): 32768 == 32**3; presumably the number of elements per
        # sample, turning a per-sample summed error into a per-element one —
        # confirm against the metric implementation and the input shape.
        metric = G.compute_metric("MeanSquaredError", gen_pred, real)
        G.update_test_metric("MeanSquaredError", metric / mse_normalizer)

        return gen_pred
Example #2
0
    def _valid_s(self, S: ModelTrainer, inputs, target):
        """Evaluate the segmenter on a validation batch and record loss/metrics.

        Args:
            S: Trainer for the segmentation model.
            inputs: Batch passed to ``S.forward``.
            target: Label tensor. It is squeezed along dim 1 (presumably a
                singleton channel axis — confirm with the data loader) and
                cast to ``long`` class indices.

        Returns:
            Tuple ``(seg_pred, loss_S)``: the prediction after a softmax over
            dim 1, and the "DiceLoss" value exactly as returned by
            ``S.compute_loss`` (only its mean is recorded).
        """
        # Squeeze/cast the target once and derive the one-hot encoding from
        # that result, rather than repeating the squeeze/cast.
        target = torch.squeeze(target, dim=1).long()
        # NOTE(review): the number of classes (4) is hard-coded.
        target_ohe = to_onehot(target, num_classes=4)

        seg_pred = torch.nn.functional.softmax(S.forward(inputs), dim=1)

        # The Dice loss uses the one-hot target; only its mean is logged.
        loss_S = S.compute_loss("DiceLoss", seg_pred, target_ohe)
        S.update_valid_loss("DiceLoss", loss_S.mean())

        # Metrics take the index target; Dice and IoU are averaged before
        # being recorded.
        metrics = S.compute_metrics(seg_pred, target)
        for name in ("Dice", "IoU"):
            metrics[name] = metrics[name].mean()
        S.update_valid_metrics(metrics)

        return seg_pred, loss_S
Example #3
0
    def _train_g(
        self,
        G: ModelTrainer,
        real,
        backward=True,
        *,
        mse_normalizer: float = 32768,
    ):
        """Run one generator training step on ``real`` and record loss/metric.

        The output of ``G.forward(real)`` is passed through a sigmoid and
        compared against ``real`` itself with the "MSELoss" loss.

        Args:
            G: Trainer for the generator model.
            real: Input batch, also used as the comparison target.
            backward: If True, backpropagate the loss and call ``G.step()``.
                If False, gradients are still zeroed and the loss/metric are
                still recorded, but no backward pass or step happens.
            mse_normalizer: Divisor applied to the "MeanSquaredError" metric
                before it is recorded. Defaults to 32768, the value that was
                previously hard-coded here.

        Returns:
            The sigmoid-activated generator prediction.
        """
        G.zero_grad()

        # torch.sigmoid is the canonical op; torch.nn.functional.sigmoid is an
        # alias that older PyTorch releases flagged as deprecated.
        gen_pred = torch.sigmoid(G.forward(real))

        loss_G = G.compute_and_update_train_loss("MSELoss", gen_pred, real)

        # NOTE(review): 32768 == 32**3; presumably the number of elements per
        # sample, turning a per-sample summed error into a per-element one —
        # confirm against the metric implementation and the input shape.
        metric = G.compute_metric("MeanSquaredError", gen_pred, real)
        G.update_train_metric("MeanSquaredError", metric / mse_normalizer)

        if backward:
            loss_G.backward()
            G.step()

        return gen_pred
Example #4
0
    def _train_s(self, S: ModelTrainer, inputs, target, backward=True):
        """Run one segmenter training step and record loss/metrics.

        Args:
            S: Trainer for the segmentation model.
            inputs: Batch passed to ``S.forward``.
            target: Label tensor. It is squeezed along dim 1 (presumably a
                singleton channel axis — confirm with the data loader) and
                cast to ``long`` class indices.
            backward: If True, backpropagate the mean Dice loss and call
                ``S.step()``. If False, gradients are still zeroed and the
                loss/metrics are still recorded, but no backward pass or step
                happens.

        Returns:
            Tuple ``(seg_pred, loss_S)``: the prediction after a softmax over
            dim 1, and the "DiceLoss" value exactly as returned by
            ``S.compute_loss`` (its mean is what gets logged/backpropagated).
        """
        S.zero_grad()

        # Squeeze/cast the target once and derive the one-hot encoding from
        # that result, rather than repeating the squeeze/cast.
        target = torch.squeeze(target, dim=1).long()
        # NOTE(review): the number of classes (4) is hard-coded.
        target_ohe = to_onehot(target, num_classes=4)

        seg_pred = torch.nn.functional.softmax(S.forward(inputs), dim=1)

        loss_S = S.compute_loss("DiceLoss", seg_pred, target_ohe)
        # Reduce once: the same scalar is both logged and (optionally)
        # backpropagated, instead of building two separate mean() nodes.
        mean_loss = loss_S.mean()
        S.update_train_loss("DiceLoss", mean_loss)

        # Metrics take the index target; Dice and IoU are averaged before
        # being recorded.
        metrics = S.compute_metrics(seg_pred, target)
        for name in ("Dice", "IoU"):
            metrics[name] = metrics[name].mean()
        S.update_train_metrics(metrics)

        if backward:
            mean_loss.backward()
            S.step()

        return seg_pred, loss_S