def _train_g(self, G: ModelTrainer, real, backward=True):
    """Run one generator training pass on ``real`` and record loss/metric.

    The generator output is passed through a sigmoid and compared against
    ``real`` itself, using the trainer's "MSELoss" loss and
    "MeanSquaredError" metric.

    Args:
        G: Trainer wrapping the generator model.
        real: Input batch; also used as the comparison target.
        backward: When True, call ``backward()`` on the loss and then
            ``G.step()``; when False, only the loss and metric are recorded.

    Returns:
        The sigmoid-activated generator prediction.
    """
    G.zero_grad()
    # torch.sigmoid replaces the deprecated torch.nn.functional.sigmoid;
    # the result is numerically identical.
    gen_pred = torch.sigmoid(G.forward(real))
    loss_G = G.compute_and_update_train_loss("MSELoss", gen_pred, real)
    metric = G.compute_metric("MeanSquaredError", gen_pred, real)
    # NOTE(review): 32768 == 32**3. Presumably this is the number of elements
    # per sample, used to normalize the metric. Confirm against compute_metric.
    G.update_train_metric("MeanSquaredError", metric / 32768)
    if backward:
        loss_G.backward()
        G.step()
    return gen_pred
def _train_s(self, S: ModelTrainer, inputs, target, backward=True):
    """Run one segmentation training pass and record Dice loss and metrics.

    Args:
        S: Trainer wrapping the segmentation model.
        inputs: Batch fed to ``S.forward``.
        target: Label tensor. Axis 1 is squeezed away and the result is cast
            to ``long`` (presumably a singleton channel axis; TODO confirm).
        backward: When True, call ``backward()`` on the mean loss and then
            ``S.step()``; when False, only loss and metrics are recorded.

    Returns:
        Tuple ``(seg_pred, loss_S)``:
        - ``seg_pred`` is the softmax (over dim=1) of the model output.
        - ``loss_S`` is the loss as returned by ``S.compute_loss``, without
          the ``.mean()`` reduction applied here.
    """
    S.zero_grad()
    # Compute the integer label tensor once. Both the one-hot loss target
    # and the metric target are derived from it.
    target = torch.squeeze(target, dim=1).long()
    target_ohe = to_onehot(target, num_classes=4)
    seg_pred = torch.nn.functional.softmax(S.forward(inputs), dim=1)
    loss_S = S.compute_loss("DiceLoss", seg_pred, target_ohe)
    # Reduce once and reuse the same tensor for logging and backprop,
    # instead of building two separate mean() graph nodes.
    mean_loss = loss_S.mean()
    S.update_train_loss("DiceLoss", mean_loss)
    metrics = S.compute_metrics(seg_pred, target)
    metrics["Dice"] = metrics["Dice"].mean()
    metrics["IoU"] = metrics["IoU"].mean()
    S.update_train_metrics(metrics)
    if backward:
        mean_loss.backward()
        S.step()
    return seg_pred, loss_S