def train():
    startTime = time.time()
    args = parameters.parse_arguments()
    logging.basicConfig(filename=args.logfile, level=logging.INFO)
    logging.critical("\n\n" + args.log_header)
    logging.info(args)
    device = ("cuda" if torch.cuda.is_available() else "cpu")
    logging.info(f"TIME: {time.time() - startTime}s Using device {device}")

    logging.info(f"TIME: {time.time()-startTime}s Loading dataset")
    try:
        with open(os.path.join(args.datadir, "data.pkl"), "rb") as f:
            data = pickle.load(f)
    except:
        data = DataLoader(args.datadir,
                          int(args.batchsize),
                          shuffle=int(args.shuffle))
        with open(os.path.join(args.datadir, "data.pkl"), "wb") as f:
            pickle.dump(data, f)
    data.batchSize = int(args.batchsize)
    logging.info(f"TIME: {time.time()-startTime}s Dataset Loaded")

    random.seed(args.seed)
    indices = list(range(len(data)))
    random.shuffle(
        indices
    )  # 0:floor((1-validationFrac)*len(data)) will be training data, rest will be validation data
    trainEndIndex = math.floor((1 - args.validation_frac) * (len(data)))

    model = UNet(in_channels=1,
                 num_classes=2,
                 start_filts=int(args.conv_filters),
                 up_mode=args.mode,
                 depth=int(args.depth),
                 batchnorm=args.batchnorm)
    model.reset_params()
    model = model.to(device)
    optimizer = None
    if args.optimizer == 'adam':
        optimizer = optim.Adam(model.parameters(), lr=args.lrstart)
        logging.info(f"TIME: {time.time()-startTime}s Optimizer: adam")
    elif args.optimizer == 'sgd':
        optimizer = optim.SGD(model.parameters(),
                              lr=args.lrstart,
                              momentum=args.momentum)
        logging.info(f"TIME: {time.time()-startTime}s Optimizer: SGD")
    elif args.optimizer == 'rmsprop':
        optimizer = optim.RMSprop(model.parameters(), lr=args.lrstart)
        logging.info(f"TIME: {time.time()-startTime}s Optimizer: RMSProp")
    else:
        logging.error(
            f"TIME: {time.time()-startTime}s Incorrect optimizer given")

    scheduler = []
    if args.lrscheduler == "steplr":
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.decay)
        logging.info(f"TIME: {time.time()-startTime}s LRScheduler: StepLR")
    elif args.lrscheduler == "exponentiallr":
        scheduler = optim.lr_scheduler.ExponentialLR(optimizer,
                                                     gamma=args.decay)
        logging.info(
            f"TIME: {time.time()-startTime}s LRScheduler: exponentialLR")
    else:
        scheduler = optim.lr_scheduler.StepLR(optimizer,
                                              step_size=int(args.epochs))
        logging.info(
            f"TIME: {time.time()-startTime}s LRScheduler: lr shouldn't change with epochs"
        )

    criteria = CombinedLoss(args.lambda_loss, args.loss_type)
    diceCoeff = DiceLoss()
    TL = []
    VL = []
    if not os.path.exists(os.path.join(os.getcwd(), "loss_files")):
        os.makedirs(os.path.join(os.getcwd(), "loss_files"))
    lossFile = open(os.path.join("loss_files", args.log_header + ".csv"), "w+")
    lossFile.write("Epoch,TrainLoss,ValidationLoss,Dice Coefficient\n")

    for epoch in tqdm(range(1, int(args.epochs) + 1), desc="Training model"):
        trainLoss = 0
        valLoss = 0
        trainingSample = 0
        testSample = 0
        netCoeff = 0
        for i in range(len(data)):
            images, masks = data[i]
            images = torch.tensor(images.astype(np.float32))
            masks = torch.tensor(masks.astype(np.float32))
            images = images.to(device)
            masks = masks.to(device)
            images = torch.transpose(images, 1, 3)
            masks = torch.transpose(masks, 1, 3)
            if i in indices[:trainEndIndex]:
                trainingSample += images.shape[0]
                networkPred = model(images)
                if args.regularization == 'l1':
                    reg = L1_regularization(model, args.reg_lamda1)
                    loss = criteria(masks, networkPred) + reg
                elif args.regularization == 'l1l2':
                    reg = L1L2_regularization(model, args.reg_lamda1,
                                              args.reg_lamda2)
                    loss = criteria(masks, networkPred) + reg
                else:
                    loss = criteria(masks, networkPred)
                loss.backward()
                trainLoss += loss.item()
                optimizer.step()
                model.zero_grad()
            else:
                with torch.no_grad():
                    testSample += images.shape[0]
                    prediction = model(images)
                    if (epoch % args.save_epochs
                            == 0) or (epoch == 1) or (epoch == args.epochs):
                        imgPath = os.path.join("validation_sample",
                                               args.log_header,
                                               f"epoch {epoch}")
                        if not os.path.exists(imgPath):
                            os.makedirs(imgPath)
                        hrt = images[0, 0, :, :].to("cpu")
                        plt.imshow(np.array(hrt), cmap='gray')
                        plt.title("Heart Image")
                        plt.savefig(os.path.join(imgPath, "heart.png"))
                        plt.clf()
                        # ax = figure.add_subplot(232, title="Mask 1 Predicted")
                        msk1 = prediction[0, 0, :, :].to("cpu")
                        plt.imshow(np.array(msk1), cmap='gray')
                        plt.title("Predicted Mask 1")
                        plt.savefig(os.path.join(imgPath, "pred-mask1.png"))
                        plt.clf()
                        # ax = figure.add_subplot(231, title="Mask 2 Predicted")
                        msk2 = prediction[0, 1, :, :].to("cpu")
                        plt.imshow(np.array(msk2), cmap='gray')
                        plt.title("Predicted Mask 2")
                        plt.savefig(os.path.join(imgPath, "pred-mask2.png"))
                        plt.clf()

                        msk = np.zeros((192, 192, 3))
                        msk[:, :, 0] = np.array(msk1)
                        msk[:, :, 1] = np.array(msk2)
                        plt.imshow(np.array(hrt), cmap='gray')
                        plt.imshow(msk, cmap='jet', alpha=0.4)
                        plt.title("predicted-RV")
                        plt.savefig(os.path.join(imgPath, "pred-RV.png"))
                        plt.clf()
                        # ax = figure.add_subplot(231, title="Mask 2 Real")
                        msk1 = masks[0, 0, :, :].to("cpu")
                        plt.imshow(np.array(msk1), cmap='gray')
                        plt.title("Actual Mask 1")
                        plt.savefig(os.path.join(imgPath, "actual-mask1.png"))
                        plt.clf()
                        # ax = figure.add_subplot(231, title="Mask 2 Real")
                        msk2 = masks[0, 1, :, :].to("cpu")
                        plt.imshow(np.array(msk2), cmap='gray')
                        plt.title("Actual Mask 2")
                        plt.savefig(os.path.join(imgPath, "actual-mask2.png"))
                        plt.clf()
                        # plt.savefig(os.path.join("validation_sample", f"{args.log_header}-epoch {epoch}.png"))
                        msk = np.zeros((192, 192, 3))
                        msk[:, :, 0] = np.array(msk1)
                        msk[:, :, 1] = np.array(msk2)
                        plt.imshow(np.array(hrt), cmap='gray')
                        plt.imshow(msk, cmap='jet', alpha=0.4)
                        plt.title("actual-RV")
                        plt.savefig(os.path.join(imgPath, "actual-RV.png"))
                        plt.clf()

                    if args.regularization == 'l1':
                        reg = L1_regularization(model, args.reg_lamda1)
                        loss = criteria(masks, prediction) + reg
                    elif args.regularization == 'l1l2':
                        reg = L1L2_regularization(model, args.reg_lamda1,
                                                  args.reg_lamda2)
                        loss = criteria(masks, prediction) + reg
                    else:
                        loss = criteria(masks, prediction)
                    valLoss += loss.item()
                    coeff = diceCoeff(masks, prediction)
                    netCoeff += torch.sum(1 - coeff).item()
        if (epoch % int(args.save_epochs) == 0) or (epoch == int(args.epochs)):
            if not os.path.exists(args.model_save_dir):
                os.makedirs(args.model_save_dir)
            # save model
            torch.save(
                {
                    "epoch": epoch,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                },
                os.path.join(args.model_save_dir,
                             f"model-epoch({epoch}).hdf5"))
            logging.info(
                f"TIME: {time.time()-startTime}s Model state saved for epoch: {epoch}"
            )
        logging.info(
            f"TIME: {time.time()-startTime}s TRAINING: Epoch: {epoch}, lr: {scheduler.get_last_lr()}, loss: {trainLoss/(2*trainingSample)}"
        )
        logging.info(
            f"TIME: {time.time()-startTime}s VALIDATION: Epoch: {epoch}, lr: {scheduler.get_last_lr()}, loss: {valLoss/(2*testSample)}"
        )
        TL.append(trainLoss / (2 * trainingSample))
        VL.append(valLoss / (2 * testSample))
        lossFile.write(
            f"{epoch},{trainLoss/(2*trainingSample)},{valLoss/(2*testSample)},{netCoeff/(2*testSample)}\n"
        )
        scheduler.step(
        )  # https://www.deeplearningwizard.com/deep_learning/boosting_models_pytorch/lr_scheduling/
    plt.plot(list(range(1, int(args.epochs) + 1)), TL, label="Training loss")
    plt.plot(list(range(1, int(args.epochs) + 1)), VL, label="Validation loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend(loc="best")
    if not os.path.exists(os.path.join(os.getcwd(), "plots")):
        os.makedirs(os.path.join(os.getcwd(), "plots"))
    plt.savefig(os.path.join("plots", args.log_header + ".png"))
Exemple #2
0
class Trainer:
    def __init__(self, seq_length, color_channels, unet_path="pretrained/unet.mdl",
                 discrim_path="pretrained/dicrim.mdl",
                 facenet_path="pretrained/facenet.mdl",
                 vgg_path="",
                 embedding_size=1000,
                 unet_depth=3,
                 unet_filts=32,
                 facenet_filts=32,
                 resnet=18):

        self.color_channels = color_channels
        self.margin = 0.5
        self.writer = SummaryWriter(log_dir="logs")

        self.unet_path = unet_path
        self.discrim_path = discrim_path
        self.facenet_path = facenet_path

        self.unet = UNet(in_channels=color_channels, out_channels=color_channels,
                         depth=unet_depth,
                         start_filts=unet_filts,
                         up_mode="upsample",
                         merge_mode='concat').to(device)

        self.discrim = FaceNetModel(embedding_size=embedding_size, start_filts=facenet_filts,
                                    in_channels=color_channels, resnet=resnet,
                                    pretrained=False).to(device)

        self.facenet = FaceNetModel(embedding_size=embedding_size, start_filts=facenet_filts,
                                    in_channels=color_channels, resnet=resnet,
                                    pretrained=False).to(device)

        if os.path.isfile(unet_path):
            self.unet.load_state_dict(torch.load(unet_path))
            print("unet loaded")

        if os.path.isfile(discrim_path):
            self.discrim.load_state_dict(torch.load(discrim_path))
            print("discrim loaded")

        if os.path.isfile(facenet_path):
            self.facenet.load_state_dict(torch.load(facenet_path))
            print("facenet loaded")
        if os.path.isfile(vgg_path):
            self.vgg_loss_network = LossNetwork(vgg_face_dag(vgg_path)).to(device)
            self.vgg_loss_network.eval()

            print("vgg loaded")

        self.mse_loss_function = nn.MSELoss().to(device)
        self.discrim_loss_function = nn.BCELoss().to(device)
        self.triplet_loss_function = TripletLoss(margin=self.margin)

        self.unet_optimizer = torch.optim.Adam(self.unet.parameters(), betas=(0.9, 0.999))
        self.discrim_optimizer = torch.optim.Adam(self.discrim.parameters(), betas=(0.9, 0.999))
        self.facenet_optimizer = torch.optim.Adam(self.facenet.parameters(), betas=(0.9, 0.999))

    def test(self, test_loader, epoch=0):
        X, y = next(iter(test_loader))

        B, D, C, W, H = X.shape
        # X = X.view(B, C * D, W, H)

        self.unet.eval()
        self.facenet.eval()
        self.discrim.eval()
        with torch.no_grad():
            y_ = self.unet(X.to(device))

            mse = self.mse_loss_function(y_, y.to(device))
            loss_G = self.loss_GAN_generator(btch_X=X.to(device))
            loss_D = self.loss_GAN_discrimator(btch_X=X.to(device), btch_y=y.to(device))

            loss_facenet, _, n_bad = self.loss_facenet(X.to(device), y.to(device))

        plt.title(f"epoch {epoch} mse={mse.item():.4} facenet={loss_facenet.item():.4} bad={n_bad / B ** 2}")
        i = np.random.randint(0, B)
        a = np.hstack((y[i].transpose(0, 1).transpose(1, 2), y_[i].transpose(0, 1).transpose(1, 2).to(cpu)))
        b = np.hstack((X[i][0].transpose(0, 1).transpose(1, 2),
                       X[i][-1].transpose(0, 1).transpose(1, 2)))
        plt.imshow(np.vstack((a, b)))
        plt.axis('off')
        plt.show()

        self.writer.add_scalar("test bad_percent", n_bad / B ** 2, global_step=epoch)
        self.writer.add_scalar("test loss", mse.item(), global_step=epoch)
        # self.writer.add_scalars("test GAN", {"discrim": loss_D.item(),
        #                                      "gen": loss_G.item()}, global_step=epoch)

        with torch.no_grad():
            n_for_show = 10
            y_show_ = y_.to(device)
            y_show = y.to(device)
            embeddings_anc, _ = self.facenet(y_show_)
            embeddings_pos, _ = self.facenet(y_show)

            embeds = torch.cat((embeddings_anc[:n_for_show], embeddings_pos[:n_for_show]))
            imgs = torch.cat((y_show_[:n_for_show], y_show[:n_for_show]))
            names = list(range(n_for_show)) * 2
            # print(embeds.shape, imgs.shape, len(names))
            # self.writer.add_embedding(mat=embeds, metadata=names, label_img=imgs, tag="embeddings", global_step=epoch)

        trshs, fprs, tprs = roc_curve(embeddings_anc.detach().to(cpu), embeddings_pos.detach().to(cpu))
        rnk1 = rank1(embeddings_anc.detach().to(cpu), embeddings_pos.detach().to(cpu))
        plt.step(fprs, tprs)
        # plt.xlim((1e-4, 1))
        plt.yticks(np.arange(0, 1, 0.05))
        plt.xticks(np.arange(min(fprs), max(fprs), 10))
        plt.xscale('log')
        plt.title(f"ROC auc={auc(fprs, tprs)} rnk1={rnk1}")
        self.writer.add_figure("ROC test", plt.gcf(), global_step=epoch)
        self.writer.add_scalar("auc", auc(fprs, tprs), global_step=epoch)
        self.writer.add_scalar("rank1", rnk1, global_step=epoch)
        print(f"\n###### {epoch} TEST mse={mse.item():.4} GAN(G/D)={loss_G.item():.4}/{loss_D.item():.4} "
              f"facenet={loss_facenet.item():.4} bad={n_bad / B ** 2:.4} auc={auc(fprs, tprs)} rank1={rnk1} #######")

    def test_test(self, test_loader):
        X, ys = next(iter(test_loader))
        true_idx = 0
        x = X[true_idx]

        D, C, W, H = x.shape
        # x = x.view(C * D, W, H)

        dists = list()
        with torch.no_grad():
            y_ = self.unet(x.to(device))

            embedding_anc, _ = self.facenet(y_)
            embeddings_pos, _ = self.facenet(ys)
            for emb_pos_item in embeddings_pos:
                dist = l2_dist.forward(embedding_anc, emb_pos_item)
                dists.append(dist)

        a_sorted = np.argsort(dists)

        a = np.hstack((ys[true_idx].transpose(0, 1).transpose(1, 2),
                       y_.transpose(0, 1).transpose(1, 2).to(cpu).numpy(),
                       ys[a_sorted[0]].transpose(0, 1).transpose(1, 2)))

        b = np.hstack((x[0:3].transpose(0, 1).transpose(1, 2),
                       x[D // 2 * C:D // 2 * C + 3].transpose(0, 1).transpose(1, 2),
                       x[-3:].transpose(0, 1).transpose(1, 2)))

        b_ = b - np.min(b)
        b_ = b_ / np.max(b)
        b_ = equalize_func([(b_ * 255).astype(np.uint8)], use_clahe=True)[0]
        b = b_.astype(np.float32) / 255

        plt.imshow(cv2.cvtColor(np.vstack((a, b)), cv2.COLOR_BGR2RGB))
        plt.axis('off')
        plt.show()

    def loss_facenet(self, X, y, is_detached=False):
        B, D, C, W, H = X.shape

        y_ = self.unet(X)

        embeddings_anc, D_fake = self.facenet(y_ if not is_detached else y_.detach())
        embeddings_pos, D_real = self.facenet(y)

        target_real = torch.full_like(D_fake, 1)
        loss_gen = self.discrim_loss_function(D_fake, target_real)

        pos_dist = l2_dist.forward(embeddings_anc, embeddings_pos)
        bad_triplets_loss = None

        n_bad = 0
        for shift in range(1, B):

            embeddings_neg = torch.roll(embeddings_pos, shift, 0)
            neg_dist = l2_dist.forward(embeddings_anc, embeddings_neg)

            bad_triplets_idxs = np.where((neg_dist - pos_dist < self.margin).cpu().numpy().flatten())[0]

            if shift == 1:
                bad_triplets_loss = self.triplet_loss_function.forward(embeddings_anc[bad_triplets_idxs],
                                                                       embeddings_pos[bad_triplets_idxs],
                                                                       embeddings_neg[bad_triplets_idxs]).to(
                    device)
            else:
                bad_triplets_loss += self.triplet_loss_function.forward(embeddings_anc[bad_triplets_idxs],
                                                                        embeddings_pos[bad_triplets_idxs],
                                                                        embeddings_neg[bad_triplets_idxs]).to(device)
            n_bad += len(bad_triplets_idxs)

        bad_triplets_loss /= B
        return bad_triplets_loss, torch.mean(loss_gen), n_bad

    # def loss_mse(self, btch_X, btch_y):
    #     btch_y_ = self.unet(btch_X)
    #     loss_unet = self.mse_loss_function(btch_y_, btch_y)
    #
    #     features_target = self.facenet.forward_mse(btch_y)
    #     features = self.facenet.forward_mse(btch_y_)
    #
    #     loss_first_layer = self.mse_loss_function(features, features_target)
    #     return loss_unet + loss_first_layer

    def loss_mse_vgg(self, btch_X, btch_y, k_mse, k_vgg):
        btch_y_ = self.unet(btch_X)
        # print(btch_y_.shape,btch_y.shape)
        perceptual_btch_y_ = self.vgg_loss_network(btch_y_)
        perceptual_btch_y = self.vgg_loss_network(btch_y)
        perceptual_loss = 0.0
        for a, b in zip(perceptual_btch_y_, perceptual_btch_y):
            perceptual_loss += self.mse_loss_function(a, b)
        return k_vgg * perceptual_loss + k_mse * self.mse_loss_function(btch_y_, btch_y)

    def loss_GAN_discrimator(self, btch_X, btch_y):
        btch_y_ = self.unet(btch_X)

        _, y_D_fake_ = self.discrim(btch_y_.detach())
        _, y_D_real_ = self.discrim(btch_y)

        target_fake = torch.full_like(y_D_fake_, 0)
        target_real = torch.full_like(y_D_real_, 1)

        loss_D_fake_ = self.discrim_loss_function(y_D_fake_, target_fake)
        loss_D_real_ = self.discrim_loss_function(y_D_real_, target_real)

        loss_discrim = (loss_D_real_ + loss_D_fake_)

        return loss_discrim

    def loss_GAN_generator(self, btch_X):
        btch_y_ = self.unet(btch_X)

        _, y_D_fake_ = self.discrim(btch_y_)

        target_real = torch.full_like(y_D_fake_, 1)

        loss_gen = self.discrim_loss_function(y_D_fake_, target_real)

        return loss_gen

    def relax_discriminator(self, btch_X, btch_y):
        self.discrim.zero_grad()

        # train with real
        y_discrim_real_ = self.discrim(btch_y)
        y_discrim_real_ = y_discrim_real_.mean()
        y_discrim_real_.backward(self.mone)

        # train with fake
        btch_y_ = self.unet(btch_X)
        y_discrim_fake_detached_ = self.discrim(btch_y_.detach())
        y_discrim_fake_detached_ = y_discrim_fake_detached_.mean()
        y_discrim_fake_detached_.backward(self.one)

        # gradient_penalty
        gradient_penalty = self.discrim_gradient_penalty(btch_y, btch_y_)
        gradient_penalty.backward()

        self.discrim_optimizer.step()

    def relax_generator(self, btch_X):
        self.unet.zero_grad()

        btch_y_ = self.unet(btch_X)

        y_discrim_fake_ = self.discrim(btch_y_)
        y_discrim_fake_ = y_discrim_fake_.mean()
        y_discrim_fake_.backward(self.mone)
        self.unet_optimizer.step()

    def discrim_gradient_penalty(self, real_y, fake_y):
        lambd = 10
        btch_size = real_y.shape[0]

        alpha = torch.rand(btch_size, 1, 1, 1).to(device)
        # print(alpha.shape, real_y.shape)
        alpha = alpha.expand_as(real_y)

        interpolates = alpha * real_y + (1 - alpha) * fake_y
        interpolates = interpolates.to(device)

        interpolates = autograd.Variable(interpolates, requires_grad=True)

        interpolates_out = self.discrim(interpolates)

        gradients = autograd.grad(outputs=interpolates_out, inputs=interpolates,
                                  grad_outputs=torch.ones(interpolates_out.size()).to(device),
                                  create_graph=True, retain_graph=True, only_inputs=True)[0]

        gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * lambd
        return gradient_penalty

    def train(self, train_loader, test_loader, batch_size=2, epochs=30,
              k_gen=1, k_discrim=1, k_mse=1, k_facenet=1, k_facenet_back=1, k_vgg=1):
        """
        :param X: np.array shape=(n_videos, n_frames, h, w)
        :param y: np.array shape=(n_videos, h, w)
        :param epochs: int
        """
        print("\nSTART TRAINING\n")

        for epoch in range(epochs):
            self.test(test_loader, epoch)
            self.unet.train()
            self.facenet.train()
            self.discrim.train()
            # train by batches
            for idx, (btch_X, btch_y) in enumerate(train_loader):
                B, D, C, W, H = btch_X.shape
                # btch_X = btch_X.view(B, C * D, W, H)

                btch_X = btch_X.to(device)
                btch_y = btch_y.to(device)

                # Mse loss
                self.unet.zero_grad()

                mse = self.loss_mse_vgg(btch_X, btch_y, k_mse, k_vgg)

                mse.backward()
                self.unet_optimizer.step()

                # facenet_backup = deepcopy(self.facenet.state_dict())
                # for i in range(unrolled_iterations):
                self.discrim.zero_grad()
                loss_D = self.loss_GAN_discrimator(btch_X, btch_y)
                loss_D = k_discrim * loss_D
                loss_D.backward()
                self.discrim_optimizer.step()

                self.discrim.zero_grad()
                self.unet.zero_grad()
                loss_G = self.loss_GAN_generator(btch_X)
                loss_G = k_gen * loss_G
                loss_G.backward()
                self.unet_optimizer.step()

                # Facenet
                self.unet.zero_grad()
                self.facenet.zero_grad()
                facenet_loss, _, n_bad = self.loss_facenet(btch_X, btch_y)

                facenet_loss = k_facenet * facenet_loss
                facenet_loss.backward()
                self.facenet_optimizer.step()

                self.unet.zero_grad()
                self.facenet.zero_grad()
                facenet_back_loss, _, n_bad = self.loss_facenet(btch_X, btch_y)

                facenet_back_loss = k_facenet_back * facenet_back_loss
                facenet_back_loss.backward()
                self.unet_optimizer.step()

                print(f"btch {idx * batch_size} mse={mse.item():.4} GAN(G/D)={loss_G.item():.4}/{loss_D.item():.4} "
                      f"facenet={facenet_loss.item():.4} bad={n_bad / B ** 2:.4}")

                global_step = epoch * len(train_loader.dataset) // batch_size + idx
                self.writer.add_scalar("train bad_percent", n_bad / B ** 2, global_step=global_step)
                self.writer.add_scalar("train loss", mse.item(), global_step=global_step)
                # self.writer.add_scalars("train GAN", {"discrim": loss_D.item(),
                #                                       "gen": loss_G.item()}, global_step=global_step)

            torch.save(self.unet.state_dict(), self.unet_path)
            torch.save(self.discrim.state_dict(), self.discrim_path)
            torch.save(self.facenet.state_dict(), self.facenet_path)