Esempio n. 1
0
 def __init__(self,
              device,
              GAN,
              dataloader,
              model_dir='../models',
              num_epochs=1,
              criterion=nn.BCELoss(),
              lr=0.0002,
              beta1=0.5,
              nz=100,
              real_label=1.,
              fake_label=0.,
              plotting=False):
     """Store the training configuration and build one Adam optimizer per sub-network.

     Args:
         device: Device on which the fixed noise batch is allocated.
         GAN: Object exposing ``D`` and ``G`` attributes, each providing
             ``parameters()`` (presumably discriminator and generator --
             confirm against the GAN class).
         dataloader: Stored as-is for later use.
         model_dir: Stored as-is; defaults to ``'../models'``.
         num_epochs: Stored as-is.
         criterion: Loss module; the default instance is created once, when
             the function is defined.
         lr: Learning rate given to both Adam optimizers.
         beta1: First Adam beta; the second is fixed at 0.999.
         nz: Size of the second dimension of the fixed noise tensor.
         real_label: Stored as-is.
         fake_label: Stored as-is.
         plotting: When truthy, Visdom line and image plotters are created.
     """
     # Plain configuration, kept on the instance unchanged.
     self.device, self.GAN, self.dataloader = device, GAN, dataloader
     self.model_dir, self.num_epochs = model_dir, num_epochs
     self.criterion = criterion
     self.nz = nz

     # A single noise batch of shape (64, nz, 1, 1) drawn once at construction.
     self.fixed_noise = torch.randn(64, self.nz, 1, 1, device=self.device)
     self.real_label, self.fake_label = real_label, fake_label

     # Both optimizers share the same hyper-parameters.
     adam_settings = dict(lr=lr, betas=(beta1, 0.999))
     self.optimizerD = optim.Adam(self.GAN.D.parameters(), **adam_settings)
     self.optimizerG = optim.Adam(self.GAN.G.parameters(), **adam_settings)

     self.plotting = plotting
     if not plotting:
         return
     # Plotters are only instantiated on request.
     self.line_plotter = VisdomLinePlotter()
     self.image_plotter = VisdomImagePlotter()
def main(FLAGS):
    """Assemble the U-Net training setup described by ``FLAGS`` and run ``train_val``.

    Args:
        FLAGS: Object providing ``dataset_dir``, ``log_dir``,
            ``train_batch_size``, ``val_batch_size``, ``aug`` and ``epochs``.
    """
    cuda_ok = torch.cuda.is_available()
    device = torch.device("cuda:0" if cuda_ok else "cpu")

    # Read every setting up front, before any heavier work starts.
    dataset_root = FLAGS.dataset_dir
    log_root = FLAGS.log_dir
    train_bs, val_bs = FLAGS.train_batch_size, FLAGS.val_batch_size
    use_augmentation = FLAGS.aug
    epochs = FLAGS.epochs
    n_classes = 2

    # Train and validation dataloaders.
    loaders = get_dataloaders(dataset_root, train_bs, val_bs, use_augmentation)
    model = Unet(3, n_classes)

    # Wrap the model for data-parallel training when several GPUs are visible.
    gpu_count = torch.cuda.device_count()
    if gpu_count <= 1:
        print("no multiple gpu found")
    else:
        print("Let's use", gpu_count, "GPUs!")
        model = nn.DataParallel(model, device_ids=[0, 1])
    model.to(device)

    loss_fn = nn.CrossEntropyLoss()
    sgd = optim.SGD(model.parameters(),
                    lr=0.02,
                    momentum=0.9,
                    weight_decay=0.0005)

    # The scheduler and the plotter are created but not handed to train_val.
    step_scheduler = lr_scheduler.StepLR(sgd, step_size=10, gamma=0.1)
    visdom_plotter = VisdomLinePlotter(env_name='Unet Train')

    train_val(loaders, model, loss_fn, sgd, epochs, log_root, device)
Esempio n. 3
0
def _bsi_label_paths(img_paths, label_dir, label_ext):
    """Return, for each image path, the label path that shares its file stem.

    ``os.path.basename`` is used rather than splitting on "/" because
    ``glob`` joins matched names with the platform separator, so on Windows
    the file name is preceded by a backslash.
    """
    return [
        label_dir + os.path.splitext(os.path.basename(path))[0] + label_ext
        for path in img_paths
    ]


def _bsi_new_loss_meters():
    """Return 15 fresh meters: main losses 0-7 followed by struct losses 1-7."""
    return [AverageMeter() for _ in range(15)]


def _bsi_plot_loss_meters(plotter, split, epoch, meters):
    """Plot the average of each of the 15 loss meters for ``split`` at ``epoch``."""
    specs = [('loss%d' % k, 'Main Loss %d' % k) for k in range(8)]
    specs += [('structloss%d' % k, 'Struct Loss %d' % k) for k in range(1, 8)]
    for (var_name, title), meter in zip(specs, meters):
        plotter.plot(var_name, split, title, epoch, float(meter.avg))


def _bsi_copy_best_checkpoint(model_name, model_dir, subdir, stem, tag, value):
    """Copy ``model_name`` to ``model_dir/subdir`` under a name ending in the metric."""
    best_dir = model_dir + subdir
    os.makedirs(best_dir, exist_ok=True)
    copyfile(model_name,
             best_dir + "/" + "%s_%s_%3f.pth" % (stem, tag, value))


def train():
    """Train BASNet with structure side outputs on DUTS-TR, evaluating on DUTS-TE.

    Each epoch runs over the training loader and plots the per-output average
    losses. Once at least ``test_increments`` iterations have passed since
    the last evaluation, the network is evaluated on the test set, a
    checkpoint is written to ``model_dir`` and copied into a ``best*``
    sub-folder for every metric that improved.

    All paths and hyper-parameters are hard-coded below.
    """
    if os.name == 'nt':
        data_dir = 'C:/Users/marky/Documents/Courses/saliency/datasets/DUTS/'
    else:
        # expanduser falls back to the password database when HOME is unset.
        data_dir = (os.path.expanduser('~') +
                    '/Documents/Courses/EE298-CV/finalproj/datasets/DUTS/')
    tra_image_dir = 'DUTS-TR/DUTS-TR-Image/'
    tra_label_dir = 'DUTS-TR/DUTS-TR-Mask/'
    test_image_dir = 'DUTS-TE/DUTS-TE-Image/'
    test_label_dir = 'DUTS-TE/DUTS-TE-Mask/'

    image_ext = '.jpg'
    label_ext = '.png'

    model_dir = "./saved_models/basnet_bsi_aug/"
    resume_train = False
    resume_model_path = model_dir + "basnet_bsi_epoch_81_itr_106839_train_1.511335_tar_0.098392.pth"
    last_epoch = 1
    epoch_num = 100000
    batch_size_train = 8
    # The evaluation loop squeezes the prediction to 2-D and looks up the
    # source file by batch index, so it relies on a batch size of 1.
    batch_size_val = 1
    enableInpaintAug = False
    test_increments = 6250
    visdom_tab_title = "StructArchWithoutStructImgs(WithHFlip)"

    # Prefer the second GPU as before, but fall back to the first one on
    # single-GPU hosts instead of failing on an invalid device ordinal.
    if torch.cuda.device_count() > 1:
        device = torch.device("cuda:1")
    elif torch.cuda.is_available():
        device = torch.device("cuda:0")
    else:
        device = torch.device("cpu")

    # ------- 5. training process --------
    print("---start training...")

    tra_img_name_list = glob.glob(data_dir + tra_image_dir + '*' + image_ext)
    print("data_dir + tra_image_dir + '*' + image_ext: ",
          data_dir + tra_image_dir + '*' + image_ext)
    test_img_name_list = glob.glob(data_dir + test_image_dir + '*' + image_ext)
    print("data_dir + test_image_dir + '*' + image_ext: ",
          data_dir + test_image_dir + '*' + image_ext)

    tra_lbl_name_list = _bsi_label_paths(tra_img_name_list,
                                         data_dir + tra_label_dir, label_ext)
    test_lbl_name_list = _bsi_label_paths(test_img_name_list,
                                          data_dir + test_label_dir, label_ext)

    print("---")
    print("train images: ", len(tra_img_name_list))
    print("train labels: ", len(tra_lbl_name_list))
    print("---")

    print("---")
    print("test images: ", len(test_img_name_list))
    print("test labels: ", len(test_lbl_name_list))
    print("---")

    train_num = len(tra_img_name_list)
    test_num = len(test_img_name_list)
    salobj_dataset = SalObjDataset(img_name_list=tra_img_name_list,
                                   lbl_name_list=tra_lbl_name_list,
                                   transform=transforms.Compose([
                                       RescaleT(256),
                                       RandomCrop(224),
                                       ToTensorLab(flag=0)
                                   ]),
                                   category="train",
                                   enableInpaintAug=enableInpaintAug)
    # NOTE(review): the test transform also applies RandomCrop(224), so the
    # prediction presumably covers only a random crop while the metrics
    # compare it with the full ground-truth mask -- confirm whether a
    # deterministic resize was intended.
    salobj_dataset_test = SalObjDataset(img_name_list=test_img_name_list,
                                        lbl_name_list=test_lbl_name_list,
                                        transform=transforms.Compose([
                                            RescaleT(256),
                                            RandomCrop(224),
                                            ToTensorLab(flag=0)
                                        ]),
                                        category="test",
                                        enableInpaintAug=enableInpaintAug)
    salobj_dataloader = DataLoader(salobj_dataset,
                                   batch_size=batch_size_train,
                                   shuffle=True,
                                   num_workers=1)
    # The evaluation loop maps batch index i back to test_img_name_list[i],
    # which is only valid when the loader keeps the dataset order.
    salobj_dataloader_test = DataLoader(salobj_dataset_test,
                                        batch_size=batch_size_val,
                                        shuffle=False,
                                        num_workers=1)

    # ------- 3. define model --------
    net = BASNet(3, 1)
    if resume_train:
        # map_location lets a GPU-saved checkpoint load on the chosen device.
        net.load_state_dict(
            torch.load(resume_model_path, map_location=device))
    net.to(device)

    # ------- 4. define optimizer --------
    print("---define optimizer...")
    optimizer = optim.Adam(net.parameters(),
                           lr=0.001,
                           betas=(0.9, 0.999),
                           eps=1e-08,
                           weight_decay=0)

    plotter = VisdomLinePlotter(env_name=visdom_tab_title)

    # Checkpoints are written here; create the folder up front.
    os.makedirs(model_dir, exist_ok=True)

    ite_num = 0
    ite_num4val = 0  # training iterations since the last evaluation
    running_loss = 0.0
    running_tar_loss = 0.0
    next_test = 0  # evaluate after the very first epoch

    best_ave_mae = float('inf')
    best_max_fmeasure = 0
    best_relaxed_fmeasure = 0
    best_ave_maxf = 0
    best_own_RelaxedFmeasure = 0

    for epoch in range(last_epoch - 1, epoch_num):
        ### Train network
        train_meters = _bsi_new_loss_meters()
        net.train()
        for i, data in enumerate(salobj_dataloader):
            ite_num += 1
            ite_num4val += 1

            inputs_v = data['image'].float().to(device)
            labels_v = data['label'].float().to(device)

            optimizer.zero_grad()

            # forward + backward + optimize; the network's outputs are passed
            # to the loss in order, followed by the labels and the meters.
            outputs = net(inputs_v)
            loss2, loss = muti_bce_loss_fusion(*outputs, labels_v,
                                               *train_meters)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            running_tar_loss += loss2.item()

            # Release the graph-holding tensors before the next batch.
            del outputs, loss2, loss

            print(
                "[train epoch: %3d/%3d, batch: %5d/%5d, ite: %d] train loss: %3f, tar: %3f "
                % (epoch + 1, epoch_num,
                   (i + 1) * batch_size_train, train_num, ite_num,
                   running_loss / ite_num4val, running_tar_loss / ite_num4val))
        _bsi_plot_loss_meters(plotter, 'train', epoch + 1, train_meters)

        ### Validate model
        print("---Evaluate model---")
        # Evaluate and save only every `test_increments` iterations because
        # the DUTS-TE dataset is very large.
        if ite_num < next_test:
            continue
        next_test = ite_num + test_increments

        test_meters = _bsi_new_loss_meters()
        average_mae = AverageMeter()
        average_maxf = AverageMeter()
        average_relaxedf = AverageMeter()
        average_own_RelaxedFMeasure = AverageMeter()
        max_epoch_fmeasure = 0

        net.eval()
        with torch.no_grad():  # no gradients are needed for evaluation
            for i, data in enumerate(salobj_dataloader_test):
                inputs_v = data['image'].float().to(device)
                labels_v = data['label'].float().to(device)
                outputs = net(inputs_v)

                # First channel of the first output, normalised and squeezed
                # to a 2-D map.
                pred = normPRED(outputs[0][:, 0, :, :]).squeeze()
                predict_np = pred.cpu().numpy()
                im = Image.fromarray(predict_np * 255).convert('RGB')
                # Resize the prediction to the source image's resolution.
                image = cv2.imread(test_img_name_list[i])
                imo = im.resize((image.shape[1], image.shape[0]),
                                resample=Image.BILINEAR)
                imo = imo.convert("L")  # grayscale, single channel
                resizedImg_np = np.array(imo)  # predicted saliency map
                gt_img = np.array(
                    Image.open(test_lbl_name_list[i]).convert("L"))

                ### Compute metrics
                average_mae.update(getMAE(gt_img, resizedImg_np), 1)
                precision, recall = getPRCurve(gt_img, resizedImg_np)
                result_maxfmeasure = getMaxFMeasure(precision, recall).mean()
                average_maxf.update(result_maxfmeasure, 1)
                if result_maxfmeasure > max_epoch_fmeasure:
                    max_epoch_fmeasure = result_maxfmeasure
                average_relaxedf.update(
                    getRelaxedFMeasure(gt_img, resizedImg_np), 1)
                average_own_RelaxedFMeasure.update(
                    own_RelaxedFMeasure(gt_img, resizedImg_np), 1)

                loss2, loss = muti_bce_loss_fusion(*outputs, labels_v,
                                                   *test_meters)
                # Report this test batch's own losses.
                print(
                    "[test epoch: %3d/%3d, batch: %5d/%5d, ite: %d] test loss: %3f, tar: %3f "
                    % (epoch + 1, epoch_num, (i + 1) * batch_size_val,
                       test_num, ite_num, loss.item(), loss2.item()))
                del outputs, loss2, loss

        # Compute the averages BEFORE resetting the accumulators so that the
        # checkpoint and every "best" copy carry the real training losses.
        seen = max(ite_num4val, 1)
        checkpoint_stem = "basnet_bsi_epoch_%d_itr_%d_train_%3f_tar_%3f" % (
            epoch + 1, ite_num, running_loss / seen, running_tar_loss / seen)
        model_name = model_dir + checkpoint_stem + ".pth"
        torch.save(net.state_dict(), model_name)
        running_loss = 0.0
        running_tar_loss = 0.0
        ite_num4val = 0
        net.train()  # resume train

        if average_mae.avg < best_ave_mae:
            best_ave_mae = average_mae.avg
            _bsi_copy_best_checkpoint(model_name, model_dir, "bestMAE",
                                      checkpoint_stem, "mae", best_ave_mae)
        if max_epoch_fmeasure > best_max_fmeasure:
            best_max_fmeasure = max_epoch_fmeasure
            _bsi_copy_best_checkpoint(model_name, model_dir, "bestEpochMaxF",
                                      checkpoint_stem, "maxfmeas",
                                      best_max_fmeasure)
        if average_maxf.avg > best_ave_maxf:
            best_ave_maxf = average_maxf.avg
            _bsi_copy_best_checkpoint(model_name, model_dir, "bestAveMaxF",
                                      checkpoint_stem, "avemfmeas",
                                      best_ave_maxf)
        if average_relaxedf.avg > best_relaxed_fmeasure:
            best_relaxed_fmeasure = average_relaxedf.avg
            _bsi_copy_best_checkpoint(model_name, model_dir, "bestAveRelaxF",
                                      checkpoint_stem, "averelaxfmeas",
                                      best_relaxed_fmeasure)
        if average_own_RelaxedFMeasure.avg > best_own_RelaxedFmeasure:
            best_own_RelaxedFmeasure = average_own_RelaxedFMeasure.avg
            _bsi_copy_best_checkpoint(model_name, model_dir,
                                      "bestOwnRelaxedF", checkpoint_stem,
                                      "averelaxfmeas",
                                      best_own_RelaxedFmeasure)

        _bsi_plot_loss_meters(plotter, 'test', epoch + 1, test_meters)
        plotter.plot('mae', 'test', 'Average Epoch MAE', epoch + 1,
                     float(average_mae.avg))
        plotter.plot('max_maxf', 'test', 'Max Max Epoch F-Measure',
                     epoch + 1, float(max_epoch_fmeasure))
        plotter.plot('ave_maxf', 'test', 'Average Max F-Measure',
                     epoch + 1, float(average_maxf.avg))
        plotter.plot('ave_relaxedf', 'test', 'Average Relaxed F-Measure',
                     epoch + 1, float(average_relaxedf.avg))
        plotter.plot('own_RelaxedFMeasure', 'test',
                     'Own Average Relaxed F-Measure', epoch + 1,
                     float(average_own_RelaxedFMeasure.avg))
    print('-------------Congratulations! Training Done!!!-------------')
# Earlier dataset construction, kept for reference:
#dataset = ShapDatasetTop(args.normal_path, args.adversarial_path, normal_only=True, normalise=True)
#dataset_all = ShapDatasetTop(args.normal_path, args.adversarial_path, normal_only=False, normalise=True)
# Two datasets built from the same paths and model: `dataset` passes
# normal_only=True, while `dataset_all` leaves normal_only at its default.
dataset = ShapDatasetFly(args.normal_path,
                         args.adversarial_path,
                         args.collate_path,
                         args.model,
                         normal_only=True,
                         large_normal=True)
dataset_all = ShapDatasetFly(args.normal_path,
                             args.adversarial_path,
                             args.collate_path,
                             args.model,
                             large_normal=True)

# NOTE(review): a `global` statement has no effect at module level; this looks
# like a leftover from code that once lived inside a function -- confirm.
global plotter
# Line plotter configured from the `--plot` argument value.
plotter = VisdomLinePlotter(args.plot)

# https://docs.microsoft.com/en-us/archive/msdn-magazine/2019/april/test-run-neural-anomaly-detection-using-pytorch


class Autoencoder(nn.Module):
    """Four linear layers with widths 100 -> 50 -> 25 -> 50 -> 100."""

    def __init__(self):
        """Create the four linear layers ``l1`` to ``l4``."""
        super(Autoencoder, self).__init__()

        # l1 and l2 narrow the width from 100 to 25.
        self.l1 = nn.Linear(100, 50)
        self.l2 = nn.Linear(50, 25)
        # l3 and l4 widen it back from 25 to 100.
        self.l3 = nn.Linear(25, 50)
        self.l4 = nn.Linear(50, 100)

    def forward(self, x):
        """Apply ``l1`` followed by ``tanh`` to ``x``.

        NOTE(review): as written, the result is never returned and ``l2`` to
        ``l4`` are unused, so this method returns ``None``. The body looks
        truncated -- confirm against the original source.
        """
        z = torch.tanh(self.l1(x))
def train():
    """Evaluate a saved BASNet checkpoint on DUTS-TE and print summary metrics.

    Despite its name, this function performs no optimisation step: it loads
    ``basnet-original.pth``, runs the network over the test set once, and
    prints the average MAE, the maximum and average max-F-measure and the
    average relaxed F-measure.

    All paths and settings are hard-coded below.
    """
    data_dir = '/media/markytools/New Volume/Courses/EE298CompVis/finalproject/datasets/'
    test_image_dir = 'DUTS/DUTS-TE/DUTS-TE-Image/'
    test_label_dir = 'DUTS/DUTS-TE/DUTS-TE-Mask/'
    image_ext = '.jpg'
    label_ext = '.png'

    model_dir = "../saved_models/"
    resume_train = True
    resume_model_path = model_dir + "basnet-original.pth"
    # The loop below squeezes the prediction to 2-D and looks up the source
    # file by batch index, so it relies on a batch size of 1.
    batch_size_val = 1
    enableInpaintAug = False

    # Prefer the second GPU as before, but fall back to the first one on
    # single-GPU hosts instead of failing on an invalid device ordinal.
    if torch.cuda.device_count() > 1:
        device = torch.device("cuda:1")
    elif torch.cuda.is_available():
        device = torch.device("cuda:0")
    else:
        device = torch.device("cpu")

    print("---start training...")
    test_img_name_list = glob.glob(data_dir + test_image_dir + '*' + image_ext)
    print("data_dir + test_image_dir + '*' + image_ext: ", data_dir + test_image_dir + '*' + image_ext)

    # Label path = label folder + image file stem + label extension.
    test_lbl_name_list = [
        data_dir + test_label_dir
        + img_path.split("/")[-1].rsplit(".", 1)[0] + label_ext
        for img_path in test_img_name_list
    ]

    print("---")
    print("test images: ", len(test_img_name_list))
    print("test labels: ", len(test_lbl_name_list))
    print("---")

    # Sanity check: every label path must map back to a globbed image path.
    known_images = set(test_img_name_list)  # O(1) membership tests
    for test_lbl in test_lbl_name_list:
        test_jpg = test_lbl.replace("png", "jpg").replace("Mask", "Image")
        if test_jpg not in known_images:
            print("test_lbl not in label: ", test_lbl)

    # NOTE(review): RandomCrop(224) is applied at test time, so the prediction
    # presumably covers only a random crop while the metrics compare it with
    # the full ground-truth mask -- confirm whether a deterministic resize
    # was intended.
    salobj_dataset_test = SalObjDataset(
        img_name_list=test_img_name_list,
        lbl_name_list=test_lbl_name_list,
        transform=transforms.Compose([
            RescaleT(256),
            RandomCrop(224),
            ToTensorLab(flag=0)]),
        category="test",
        enableInpaintAug=enableInpaintAug)
    # Batch index i is mapped back to test_img_name_list[i] below, which is
    # only valid when the loader keeps the dataset order.
    salobj_dataloader_test = DataLoader(salobj_dataset_test,
                                        batch_size=batch_size_val,
                                        shuffle=False,
                                        num_workers=1)

    # ------- 3. define model --------
    net = BASNet(3, 1)
    if resume_train:
        # map_location lets a GPU-saved checkpoint load on the chosen device.
        net.load_state_dict(
            torch.load(resume_model_path, map_location=device))
    net.to(device)

    # One meter per network output; they are updated by the loss function
    # but not reported by this function.
    test_meters = [AverageMeter() for _ in range(8)]
    average_mae = AverageMeter()
    average_maxf = AverageMeter()
    average_relaxedf = AverageMeter()

    ### Validate model
    print("---Evaluate model---")
    net.eval()
    max_epoch_fmeasure = 0
    with torch.no_grad():  # no gradients are needed for evaluation
        for i, data in enumerate(salobj_dataloader_test):
            inputs_v = data['image'].float().to(device)
            labels_v = data['label'].float().to(device)
            outputs = net(inputs_v)

            # First channel of the first output, normalised and squeezed to
            # a 2-D map.
            pred = normPRED(outputs[0][:, 0, :, :]).squeeze()
            predict_np = pred.cpu().numpy()
            im = Image.fromarray(predict_np * 255).convert('RGB')
            # Resize the prediction to the source image's resolution.
            image = cv2.imread(test_img_name_list[i])
            imo = im.resize((image.shape[1], image.shape[0]),
                            resample=Image.BILINEAR)
            imo = imo.convert("L")  # grayscale, single channel
            resizedImg_np = np.array(imo)  # predicted saliency map
            gt_img = np.array(
                Image.open(test_lbl_name_list[i]).convert("L"))

            ### Compute metrics
            average_mae.update(getMAE(gt_img, resizedImg_np), 1)
            precision, recall = getPRCurve(gt_img, resizedImg_np)
            result_maxfmeasure = getMaxFMeasure(precision, recall).mean()
            average_maxf.update(result_maxfmeasure, 1)
            if result_maxfmeasure > max_epoch_fmeasure:
                max_epoch_fmeasure = result_maxfmeasure
            average_relaxedf.update(
                getRelaxedFMeasure(gt_img, resizedImg_np), 1)

            loss2, loss = muti_bce_loss_fusion(*outputs, labels_v,
                                               *test_meters)
            del outputs, loss2, loss

    print("Average Epoch MAE: ", average_mae.avg)
    # Each label is paired with the matching statistic: the maximum over
    # images first, then the average over images.
    print("Max Max Epoch F-Measure: ", max_epoch_fmeasure)
    print("Average Max F-Measure: ", average_maxf.avg)
    print("Average Relaxed F-Measure: ", average_relaxedf.avg)

    print('-------------Congratulations! Training Done!!!-------------')
Esempio n. 6
0
    def train(self):
        """Train the network jointly on edge and saliency batches.

        For every epoch the training loader is enumerated until the batch
        index reaches ``iter_num - 1``. Each batch provides an edge sample and
        a saliency sample:

        * the edge loss is the mean of the ``bce2d(..., reduction='sum')``
          losses over the side outputs ``edge_pred[1]``;
        * the saliency loss is ``F.binary_cross_entropy_with_logits``.

        Their sum is back-propagated on every batch and the optimizer steps
        once every ``self.iter_size`` batches (gradient accumulation). After
        each epoch the last batch's losses are sent to Visdom, a checkpoint is
        written every ``self.config.epoch_save`` epochs, and on the epochs in
        ``self.lr_decay_epoch`` the learning rate is multiplied by 0.1 and a
        new Adam optimizer is built. The final weights are saved to
        ``<save_folder>/models/final.pth``.
        """
        plotter = VisdomLinePlotter(env_name=visdom_tab_title)
        iter_num = 30000  # each batch only train 30000 iters.(This number is just a random choice...)
        aveGrad = 0
        # Divisor applied to the running losses when they are printed. It used
        # to be bound only when batch 0 reached the logging branch, so a
        # skipped first batch ("IMAGE ERROR") caused a NameError later on.
        x_showEvery = 1
        # Guard against a zero interval (show_every < batch_size), which would
        # otherwise raise ZeroDivisionError in the modulo below.
        show_interval = max(1, self.show_every // self.config.batch_size)

        for epoch in range(self.config.epoch):
            r_edge_loss, r_sal_loss, r_sum_loss = 0, 0, 0
            # Reset so the epoch-end plots never use unbound or stale values.
            edge_loss = sal_loss = loss = None
            self.net.zero_grad()
            for i, data_batch in enumerate(self.train_loader):
                if (i + 1) == iter_num:
                    break
                edge_image = data_batch['edge_image']
                edge_label = data_batch['edge_label']
                sal_image = data_batch['sal_image']
                sal_label = data_batch['sal_label']
                # Skip samples whose saliency image and label disagree in
                # their last two dimensions.
                if (sal_image.size(2) != sal_label.size(2)) or (
                        sal_image.size(3) != sal_label.size(3)):
                    print('IMAGE ERROR, PASSING```')
                    continue
                edge_image, edge_label = Variable(edge_image), Variable(edge_label)
                sal_image, sal_label = Variable(sal_image), Variable(sal_label)
                if self.config.cuda:
                    edge_image, edge_label = edge_image.cuda(), edge_label.cuda()
                    sal_image, sal_label = sal_image.cuda(), sal_label.cuda()

                # edge part
                edge_pred = self.net(edge_image, mode=0)
                # NOTE(review): the fused edge loss is only printed; it is not
                # part of the optimized objective below.
                edge_loss_fuse = bce2d(edge_pred[0], edge_label)
                print(edge_loss_fuse)
                edge_loss_part = [
                    bce2d(ix, edge_label, reduction='sum')
                    for ix in edge_pred[1]
                ]
                edge_loss = sum(edge_loss_part) / len(edge_loss_part)
                r_edge_loss = edge_loss

                # sal part
                sal_pred = self.net(sal_image, mode=1)
                sal_loss = F.binary_cross_entropy_with_logits(
                    sal_pred, sal_label)
                r_sal_loss = sal_loss

                loss = sal_loss + edge_loss
                r_sum_loss = loss
                loss.backward()

                aveGrad += 1

                # accumulate gradients as done in DSS
                if aveGrad % self.iter_size == 0:
                    self.optimizer.step()
                    self.optimizer.zero_grad()
                    aveGrad = 0

                if i % show_interval == 0:
                    print(
                        'epoch: [%2d/%2d], iter: [%5d/%5d]  ||  Edge : %10.4f  ||  Sal : %10.4f  ||  Sum : %10.4f'
                        % (epoch, self.config.epoch, i, iter_num,
                           r_edge_loss / x_showEvery, r_sal_loss / x_showEvery,
                           r_sum_loss / x_showEvery))
                    print('Learning rate: ' + str(self.lr))
                    r_edge_loss, r_sal_loss, r_sum_loss = 0, 0, 0

            # Plot the losses of the last processed batch; nothing to plot if
            # every batch of this epoch was skipped or the loader was empty.
            if loss is not None:
                plotter.plot('edge_loss', 'train',
                             'Balanced Binary Cross Entropy Loss', epoch + 1,
                             float(edge_loss))
                plotter.plot('sal_loss', 'train', 'Binary Cross Entropy Loss',
                             epoch + 1, float(sal_loss))
                plotter.plot('loss', 'train', '', epoch + 1, float(loss))

            if (epoch + 1) % self.config.epoch_save == 0:
                torch.save(
                    self.net.state_dict(), '%s/models/epoch_%d.pth' %
                    (self.config.save_folder, epoch + 1))

            if epoch in self.lr_decay_epoch:
                # Rebuilding the optimizer also discards Adam's moment
                # estimates; this mirrors the original behaviour.
                self.lr = self.lr * 0.1
                self.optimizer = Adam(filter(lambda p: p.requires_grad,
                                             self.net.parameters()),
                                      lr=self.lr,
                                      weight_decay=self.wd)

        torch.save(self.net.state_dict(),
                   '%s/models/final.pth' % self.config.save_folder)
Esempio n. 7
0
    def __init__(
        self,
        name = 'default',
        results_dir = 'results',
        models_dir = 'models',
        transfer_from_checkpoint = None,
        plotting = False,
        base_dir = './',
        image_size = 128,
        network_capacity = 16,
        fmap_max = 512,
        transparent = False,
        batch_size = 16,
        mixed_prob = 0.9,
        gradient_accumulate_every=1,
        lr = 2e-4,
        lr_mlp = 0.1,
        ttur_mult = 2,
        rel_disc_loss = False,
        num_workers = None,
        save_every = 1000,
        evaluate_every = 1000,
        num_image_tiles = 8,
        trunc_psi = 0.6,
        fp16 = False,
        cl_reg = False,
        no_pl_reg = False,
        fq_layers = [],
        fq_dict_size = 256,
        attn_layers = [],
        no_const = False,
        aug_prob = 0.,
        aug_types = ['translation', 'cutout'],
        top_k_training = False,
        generator_top_k_gamma = 0.99,
        generator_top_k_frac = 0.5,
        dataset_aug_prob = 0.,
        calculate_fid_every = None,
        calculate_fid_num_images = 12800,
        clear_fid_cache = False,
        is_ddp = False,
        rank = 0,
        world_size = 1,
        log = False,
        *args,
        **kwargs
    ):
        """Record the training configuration and prepare output folders.

        Every keyword argument is stored on the instance under the same name
        (``log`` is accepted but not stored). Extra positional and keyword
        arguments are kept in ``self.GAN_params``. Paths are resolved relative
        to ``base_dir`` and ``self.init_folders()`` is called once the
        path-related attributes are set.

        Raises:
            AssertionError: if ``image_size`` is not a power of two, if
                ``fp16`` is requested while ``APEX_AVAILABLE`` is false, or if
                ``is_ddp`` and ``cl_reg`` are both enabled.
        """
        self.GAN_params = [args, kwargs]
        self.GAN = None

        self.name = name

        # Filesystem layout, all rooted at base_dir.
        base_dir = Path(base_dir)
        self.base_dir = base_dir
        self.results_dir = base_dir / results_dir
        self.models_dir = base_dir / models_dir
        self.transfer_from_checkpoint = transfer_from_checkpoint
        self.fid_dir = base_dir / 'fid' / name
        self.config_path = self.models_dir / name / '.config.json'

        assert log2(image_size).is_integer(), 'image size must be a power of 2 (64, 128, 256, 512, 1024)'
        self.image_size = image_size
        self.network_capacity = network_capacity
        self.fmap_max = fmap_max
        self.transparent = transparent

        # The list-valued defaults in the signature are shared between calls,
        # so store copies: mutating one trainer's lists must not leak into
        # other trainers (or into the caller's list).
        self.fq_layers = list(cast_list(fq_layers))
        self.fq_dict_size = fq_dict_size
        self.has_fq = len(self.fq_layers) > 0

        self.attn_layers = list(cast_list(attn_layers))
        self.no_const = no_const

        self.aug_prob = aug_prob
        self.aug_types = list(aug_types)

        self.lr = lr
        self.lr_mlp = lr_mlp
        self.ttur_mult = ttur_mult
        self.rel_disc_loss = rel_disc_loss
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.mixed_prob = mixed_prob

        self.num_image_tiles = num_image_tiles
        self.evaluate_every = evaluate_every
        self.save_every = save_every
        self.steps = 0

        self.av = None
        self.trunc_psi = trunc_psi

        self.no_pl_reg = no_pl_reg
        self.pl_mean = None

        self.gradient_accumulate_every = gradient_accumulate_every

        # Equivalent to the former `not fp16 or fp16 and APEX_AVAILABLE`.
        assert not fp16 or APEX_AVAILABLE, 'Apex is not available for you to use mixed precision training'
        self.fp16 = fp16

        self.cl_reg = cl_reg

        # Running loss / metric slots, filled in during training.
        self.d_loss = 0
        self.g_loss = 0
        self.q_loss = None
        self.last_gp_loss = None
        self.last_cr_loss = None
        self.last_fid = None

        self.pl_length_ma = EMA(0.99)
        self.init_folders()

        self.loader = None
        self.dataset_aug_prob = dataset_aug_prob

        self.calculate_fid_every = calculate_fid_every
        self.calculate_fid_num_images = calculate_fid_num_images
        self.clear_fid_cache = clear_fid_cache

        self.top_k_training = top_k_training
        self.generator_top_k_gamma = generator_top_k_gamma
        self.generator_top_k_frac = generator_top_k_frac

        assert not (is_ddp and cl_reg), 'Contrastive loss regularization does not work well with multi GPUs yet'
        self.is_ddp = is_ddp
        self.is_main = rank == 0
        self.rank = rank
        self.world_size = world_size

        # Visdom plotters are only created on request; otherwise the
        # attributes exist but are None.
        self.plotting = plotting
        if plotting:
            self.image_plotter = VisdomImagePlotter()
            self.line_plotter = VisdomLinePlotter()
        else:
            self.image_plotter = None
            self.line_plotter = None
                            help='Dimension of hidden layer',
                            type=int,
                            default=60)
    arg_parser.add_argument('--name',
                            '-n',
                            help='Name of raytune experiment',
                            type=str,
                            default='train_mlp')

    args = arg_parser.parse_args()

    print('=================== Loading dataset')
    dataset = ShapDatasetTop(args.normal_path, args.adversarial_path)

    global plotter
    plotter = VisdomLinePlotter(args.plot)


class Net(nn.Module):
    """Small fully connected network with a sigmoid output.

    Two linear layers (``input_dim -> hidden_dim -> 1``) followed by a
    sigmoid, so each input row maps to one value in (0, 1).

    Args:
        hidden_dim: width of the hidden layer.
        input_dim: number of input features per sample (previously fixed
            at 100, which remains the default).
    """

    def __init__(self, hidden_dim=60, input_dim=100):
        super().__init__()

        self.l1 = torch.nn.Linear(input_dim, hidden_dim)
        self.l2 = torch.nn.Linear(hidden_dim, 1)

    def forward(self, x):
        """Return ``sigmoid(l2(l1(x)))`` for input ``x`` of shape (..., input_dim)."""
        # NOTE(review): there is no non-linearity between l1 and l2, so the
        # two layers compose to a single affine map — confirm this is intended.
        x = self.l1(x)
        x = self.l2(x)

        # torch.sigmoid replaces the deprecated torch.nn.functional.sigmoid.
        return torch.sigmoid(x)
import numpy as np
import matplotlib.pyplot as plt

from torch.utils.data import DataLoader

from visdom import Visdom

from utils import AverageMeter, VisdomLinePlotter

from ShapDataset import ShapDataset
from models import BinaryClassModel

import argparse

# Module-level Visdom line plotter used by this script. The `global`
# statement has no effect at module scope; the assignment alone creates the
# name. The string argument is presumably the Visdom environment name —
# confirm against utils.VisdomLinePlotter.
global plotter
plotter = VisdomLinePlotter('attack-classifier')

# Command-line interface: optimiser learning rate and number of epochs.
arg_parser = argparse.ArgumentParser(description='Use Captum to explain samples from RETAIN')

arg_parser.add_argument(
    '--learning_rate', '-lr',
    type=float,
    default=0.0001,
    help='Learning rate for optimiser',
)
arg_parser.add_argument(
    '--epochs', '-e',
    type=int,
    default=10,
    help='Number of epochs to train for',
)
arg_parser.add_argument('--batch_size',
Esempio n. 10
0
from sklearn.metrics import roc_auc_score, average_precision_score
import matplotlib.pyplot as plt

import sys
import os
import argparse
import pickle

from tqdm import trange

from utils import VisitSequenceWithLabelDataset, visit_collate_fn, VisdomLinePlotter

# Visdom plotting: module-level line plotter used by this script. The
# `global` statement has no effect at module scope; the assignment alone
# creates the name. The string is presumably the Visdom environment name —
# confirm against utils.VisdomLinePlotter.
global plotter
plotter = VisdomLinePlotter('lava-adversarial-retain')

# ---- Command-line arguments ----
parser = argparse.ArgumentParser()

# Positional: where the dataset lives.
parser.add_argument(
    'data_path',
    help='Path to the dataset',
    metavar='DATA_PATH',
)

# Model size and dropout settings.
parser.add_argument(
    '--num_features',
    metavar='N',
    default=4894,
    type=int,
    help='number of features (i.e., input dimension',
)
parser.add_argument(
    '--dim-emb',
    type=int,
    default=128,
    help='embedding dimension (default: 128)',
)
parser.add_argument(
    '--drop-emb',
    type=float,
    default=0.5,
    help='embedding layer dropout rate (default: 0.5)',
)
parser.add_argument(
    '--dim-alpha',
    type=int,
    default=128,
    help='RNN-Alpha hidden size (default: 128)',
)
parser.add_argument(
    '--dim-beta',
    type=int,
    default=128,
    help='RNN-Beta hidden size (default: 128)',
)
parser.add_argument(
    '--drop-context',
    type=float,
    default=0.5,
    help='context layer dropout rate (default: 0.5)',
)

# Optimiser settings.
parser.add_argument(
    '--lr', '--learning-rate',
    type=float,
    default=1e-2,
    help='learning rate (default: 1e-2)',
)
parser.add_argument(
    '--momentum',
    type=float,
    default=0.9,
    help='momentum (default: 0.9)',
)
Esempio n. 11
0
        criterion = criterion.cuda()

    print("======================= Running model on data...")
    inputs, labels = next(test_dataloader)
    inputs = inputs.to(device)
    labels = labels.to(device)

    outputs = model(inputs)
    _, preds = torch.max(outputs, 1)

    correct = torch.sum(preds == labels).item()

    print('======================= Got {} sample correct'.format(correct))

    if args.normal:
        plotter = VisdomLinePlotter('cxr-explain-normal')
        type = 'Normal'
    else:
        plotter = VisdomLinePlotter('cxr-explain-adv')
        type = 'Adversarial'

    print("\n============================================== Explanations")

    # Let's try looking at the final layer of the model
    batch_size = args.batch_size
    #cond = LayerConductance(model, model.classifier)

    # Reset our dataloader
    test_loader = DataLoader(dataset=dataset, batch_size=args.batch_size, shuffle=False, num_workers=0)

    # Get an input batch from our dataset
Esempio n. 12
0
from visdom import Visdom

from utils import AverageMeter, VisdomLinePlotter

from ShapDataset import ShapDataset, ShapDatasetChunked, ShapDatasetDict, ShapDatasetTop

import argparse

import pickle
import random

# Module-level Visdom line plotter used by this script. The `global`
# statement has no effect at module scope; the assignment alone creates the
# name. The string argument is presumably the Visdom environment name —
# confirm against utils.VisdomLinePlotter.
global plotter
plotter = VisdomLinePlotter('attack-svm')

from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.svm import SVC
from sklearn import preprocessing
from sklearn.svm import libsvm

import tqdm

import os

# Command-line interface for training / testing the SVM models.
arg_parser = argparse.ArgumentParser(
    description='Use Captum to explain samples from RETAIN',
)

arg_parser.add_argument(
    '--all', '-a',
    action='store_true',
    help='Train all SVM models',
)
arg_parser.add_argument(
    '--save', '-s',
    type=str,
    default=None,
    help='Location to save model to',
)
arg_parser.add_argument(
    '--load', '-l',
    type=str,
    default=None,
    help='Test a pretrained model',
)
arg_parser.add_argument(
    '--load_all_csv',
    type=str,
    default=None,
    help='Load a single CSV file containing all SHAP values',
)
Esempio n. 13
0
class Trainer():
    """Adversarial training driver for a generator/discriminator pair.

    ``GAN`` must expose the two networks as ``GAN.G`` and ``GAN.D``. Each
    network gets its own Adam optimizer. When ``plotting`` is true, losses
    and generator samples are sent to Visdom; otherwise the plotter
    attributes are ``None`` and no plotting calls are made.

    Args:
        device: device on which batches, labels and noise are created.
        GAN: object holding the generator ``G`` and discriminator ``D``.
        dataloader: sized iterable of batches; ``batch[0]`` is the real data.
        model_dir: directory receiving one checkpoint pair per epoch.
        num_epochs: number of passes over ``dataloader``.
        criterion: loss applied to discriminator outputs and labels.
        lr: learning rate for both Adam optimizers.
        beta1: first Adam beta for both optimizers.
        nz: number of channels of the latent noise fed to the generator.
        real_label: target value for real samples.
        fake_label: target value for generated samples.
        plotting: enable Visdom plotting.
    """

    def __init__(self,
                 device,
                 GAN,
                 dataloader,
                 model_dir='../models',
                 num_epochs=1,
                 criterion=nn.BCELoss(),
                 lr=0.0002,
                 beta1=0.5,
                 nz=100,
                 real_label=1.,
                 fake_label=0.,
                 plotting=False):
        self.device = device
        self.GAN = GAN
        self.dataloader = dataloader
        self.model_dir = model_dir
        self.num_epochs = num_epochs
        self.criterion = criterion
        self.nz = nz
        # Fixed latent batch so successive generator snapshots are comparable.
        self.fixed_noise = torch.randn(64, self.nz, 1, 1, device=self.device)
        self.real_label = real_label
        self.fake_label = fake_label
        self.optimizerD = optim.Adam(self.GAN.D.parameters(),
                                     lr=lr,
                                     betas=(beta1, 0.999))
        self.optimizerG = optim.Adam(self.GAN.G.parameters(),
                                     lr=lr,
                                     betas=(beta1, 0.999))
        self.plotting = plotting
        # Always define the plotter attributes so that code touching them
        # without plotting enabled fails clearly instead of with a missing
        # attribute.
        if plotting:
            self.line_plotter = VisdomLinePlotter()
            self.image_plotter = VisdomImagePlotter()
        else:
            self.line_plotter = None
            self.image_plotter = None

    def _snapshot_generator(self):
        """Run the generator on the fixed noise and plot the grid if enabled."""
        with torch.no_grad():
            fake = self.GAN.G(self.fixed_noise).detach().cpu()
        if self.plotting:
            self.image_plotter.plot(vutils.make_grid(fake,
                                                     padding=2,
                                                     normalize=True),
                                    name="generator-output")

    def _plot_loss(self, split_name, x, y):
        """Send one training-loss point to the Visdom line plot."""
        self.line_plotter.plot(var_name="loss",
                               split_name=split_name,
                               title_name='Training Loss',
                               x=x,
                               y=y)

    def train(self):
        """Run the training loop and save one checkpoint pair per epoch.

        Per batch: the discriminator is updated on a real batch and on a
        detached generated batch, then the generator is updated using real
        labels for its generated batch. Statistics are printed (and plotted
        when enabled) every 50 batches. A generator snapshot on the fixed
        noise is taken before and after the update every 500 iterations, on
        the very last batch, and at the end of each epoch.
        """
        iters = 0
        errD = errG = None
        num_batches = len(self.dataloader)

        print("Starting Training Loop...")
        # For each epoch
        for epoch in range(self.num_epochs):
            # For each batch in the dataloader
            for i, data in enumerate(self.dataloader, 0):
                is_final_batch = (epoch == self.num_epochs - 1) and (
                    i == num_batches - 1)
                snapshot_due = (iters % 500 == 0) or is_final_batch

                # Snapshot before the update. The plotting call here used to
                # be unguarded and raised AttributeError on the first batch
                # whenever plotting was disabled.
                if snapshot_due:
                    self._snapshot_generator()

                ############################
                # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
                ###########################
                ## Train with all-real batch
                self.GAN.D.zero_grad()
                real_cpu = data[0].to(self.device)
                b_size = real_cpu.size(0)
                label = torch.full((b_size, ),
                                   self.real_label,
                                   dtype=torch.float,
                                   device=self.device)
                output = self.GAN.D(real_cpu).view(-1)
                errD_real = self.criterion(output, label)
                errD_real.backward()
                D_x = output.mean().item()

                ## Train with all-fake batch
                noise = torch.randn(b_size, self.nz, 1, 1, device=self.device)
                fake = self.GAN.G(noise)
                label.fill_(self.fake_label)
                # detach() keeps this backward pass out of the generator.
                output = self.GAN.D(fake.detach()).view(-1)
                errD_fake = self.criterion(output, label)
                errD_fake.backward()
                D_G_z1 = output.mean().item()
                # Gradients of both passes were accumulated; report their sum.
                errD = errD_real + errD_fake
                self.optimizerD.step()

                ############################
                # (2) Update G network: maximize log(D(G(z)))
                ###########################
                self.GAN.G.zero_grad()
                label.fill_(
                    self.real_label)  # fake labels are real for generator cost
                # D was just updated, so score the fake batch again.
                output = self.GAN.D(fake).view(-1)
                errG = self.criterion(output, label)
                errG.backward()
                D_G_z2 = output.mean().item()
                self.optimizerG.step()

                # Output training stats
                if i % 50 == 0:
                    print(
                        '[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                        % (epoch, self.num_epochs, i, num_batches,
                           errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
                    if self.plotting:
                        progress = epoch + i / num_batches
                        self._plot_loss('Discriminator', progress, errD.item())
                        self._plot_loss('Generator', progress, errG.item())

                # Snapshot after the update.
                if snapshot_due:
                    self._snapshot_generator()

                iters += 1

            # End-of-epoch snapshot.
            self._snapshot_generator()

            # Make sure the checkpoint directory exists before writing.
            os.makedirs(self.model_dir, exist_ok=True)
            torch.save(
                {
                    'epoch': epoch,
                    'model_state_dict': self.GAN.G.state_dict(),
                    'optimizer_state_dict': self.optimizerG.state_dict(),
                    # None when the dataloader produced no batches.
                    'loss': None if errG is None else errG.item(),
                }, os.path.join(self.model_dir, f"Generator-{epoch}.pth"))
            torch.save(
                {
                    'epoch': epoch,
                    'model_state_dict': self.GAN.D.state_dict(),
                    'optimizer_state_dict': self.optimizerD.state_dict(),
                    'loss': None if errD is None else errD.item(),
                }, os.path.join(self.model_dir, f"Discriminator-{epoch}.pth"))
Esempio n. 14
0
def train(opt):
    """Train and periodically validate the EfficientDet object detector.

    Reads the project configuration from ``projects/<opt.project>.yml``,
    builds the FreiCar training/validation loaders, optionally restores
    weights, and then trains with AdamW and a ``ReduceLROnPlateau``
    scheduler. Training losses and the learning rate are plotted to Visdom on
    every step. Every ``opt.val_interval`` epochs the validation set is
    evaluated, a checkpoint is saved on a new best validation loss, and
    training stops early once ``opt.es_patience`` epochs passed without
    improvement.

    Args:
        opt: parsed options. The function reads ``project``, ``saved_path``,
            ``batch_size``, ``num_workers``, ``compound_coef``,
            ``load_weights``, ``lr``, ``num_epochs``, ``save_interval``,
            ``val_interval``, ``es_min_delta`` and ``es_patience``, and
            overwrites ``opt.saved_path`` with the project sub-directory.
    """
    global plotter
    plotter = VisdomLinePlotter(env_name='FreiCar Object Detection')

    params = Params(f'projects/{opt.project}.yml')

    if params.num_gpus == 0:
        os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

    if torch.cuda.is_available():
        torch.cuda.manual_seed(42)
    else:
        torch.manual_seed(42)

    # Batches go where the model goes. The batches used to be moved with an
    # unconditional .cuda(), which crashed CPU-only runs (num_gpus == 0).
    device = torch.device('cuda' if params.num_gpus > 0 else 'cpu')

    opt.saved_path = opt.saved_path + f'/{params.project_name}/'
    os.makedirs(opt.saved_path, exist_ok=True)

    # define parameters for model training
    training_params = {
        'batch_size': opt.batch_size,
        'shuffle': True,
        'drop_last': True,
        'collate_fn': collater,
        'num_workers': opt.num_workers
    }

    # define parameters for model evaluation
    val_params = {
        'batch_size': opt.batch_size,
        'shuffle': False,
        'drop_last': True,
        'collate_fn': collater,
        'num_workers': opt.num_workers
    }

    # get training dataset and its data generator
    training_set = FreiCarDataset(data_dir="./dataloader/data/",
                                  padding=(0, 0, 12, 12),
                                  split='training',
                                  load_real=True)
    training_generator = DataLoader(training_set, **training_params)

    # get validation dataset and its data generator
    val_set = FreiCarDataset(data_dir="./dataloader/data/",
                             padding=(0, 0, 12, 12),
                             split='validation',
                             load_real=False)
    val_generator = DataLoader(val_set, **val_params)

    # Instantiation of the EfficientDet model
    # NOTE(review): eval() executes whatever the project YAML contains; only
    # use configuration files from a trusted source.
    model = EfficientDetBackbone(num_classes=len(params.obj_list),
                                 compound_coef=opt.compound_coef,
                                 ratios=eval(params.anchors_ratios),
                                 scales=eval(params.anchors_scales))

    # load last weights if training from checkpoint
    if opt.load_weights is not None:
        if opt.load_weights.endswith('.pth'):
            weights_path = opt.load_weights
        else:
            weights_path = get_last_weights(opt.saved_path)
        # The step counter is encoded as the last '_'-separated field of the
        # file name; fall back to 0 when it cannot be parsed.
        try:
            last_step = int(
                os.path.basename(weights_path).split('_')[-1].split('.')[0])
        except (TypeError, ValueError):
            last_step = 0

        try:
            model.load_state_dict(torch.load(weights_path), strict=False)
        except RuntimeError as e:
            print(f'[Warning] Ignoring {e}')
            print(
                '[Warning] Don\'t panic if you see this, this might be because you load a pretrained weights with '
                'different number of classes. The rest of the weights should be loaded already.'
            )

        print(
            f'[Info] loaded weights: {os.path.basename(weights_path)}, resuming checkpoint from step: {last_step}'
        )
    else:
        last_step = 0
        print('[Info] initializing weights...')
        init_weights(model)

    if params.num_gpus > 0:
        model = model.cuda()
        if params.num_gpus > 1:
            model = CustomDataParallel(model, params.num_gpus)

    optimizer = torch.optim.AdamW(model.parameters(), opt.lr)

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           patience=3,
                                                           verbose=True)

    best_loss = 1e5
    best_epoch = 0
    step = max(0, last_step)

    # Define training criterion
    criterion = FocalLoss()

    # Set model to train mode
    model.train()

    num_iter_per_epoch = len(training_generator)

    print('Started Training')

    # Train loop
    for epoch in range(opt.num_epochs):
        # When resuming, skip the epochs already covered by `step`.
        last_epoch = step // num_iter_per_epoch
        if epoch < last_epoch:
            continue

        epoch_loss = []  # total loss of every step in this epoch

        progress_bar = tqdm(training_generator)
        for batch_idx, data in enumerate(progress_bar):
            # When resuming mid-epoch, skip the batches already trained on.
            if batch_idx < step - last_epoch * num_iter_per_epoch:
                progress_bar.update()
                continue

            optimizer.zero_grad()
            _, reg, clas, anchor = model(data['img'].to(device))
            cls_loss, reg_loss = criterion(clas, reg, anchor,
                                           data['annot'].to(device))
            loss = cls_loss.mean() + reg_loss.mean()
            loss.backward()
            optimizer.step()

            epoch_loss.append(float(loss))

            progress_bar.set_description(
                'Step: {}. Epoch: {}/{}. Iteration: {}/{}. Cls loss: {:.5f}. Reg loss: {:.5f}. Total loss: {:.5f}'
                .format(step, epoch, opt.num_epochs, batch_idx + 1,
                        num_iter_per_epoch, cls_loss.item(), reg_loss.item(),
                        loss.item()))

            plotter.plot('Total loss', 'train', 'Total loss', step,
                         loss.item())
            plotter.plot('Regression_loss', 'train', 'Regression_loss', step,
                         reg_loss.item())
            plotter.plot('Classfication_loss', 'train', 'Classfication_loss',
                         step, cls_loss.item())

            # log learning_rate (the plot title used to be a copy-pasted
            # 'Classfication_loss')
            current_lr = optimizer.param_groups[0]['lr']
            plotter.plot('learning rate', 'train', 'learning rate', step,
                         current_lr)

            # increment step counter
            step += 1

            if step % opt.save_interval == 0 and step > 0:
                save_checkpoint(
                    model,
                    f'efficientdet-d{opt.compound_coef}_{epoch}_{step}.pth')
                print('saved checkpoint...')

        # adjust learning rate via learning rate scheduler
        if epoch_loss:
            scheduler.step(np.mean(epoch_loss))

        if epoch % opt.val_interval == 0:

            print('Evaluating model')

            model.eval()
            loss_regression_ls = []  # regression loss of every val batch
            loss_classification_ls = []  # classification loss of every val batch

            with torch.no_grad():
                for data in val_generator:
                    _, reg, clas, anchor = model(data['img'].to(device))
                    cls_loss, reg_loss = criterion(clas, reg, anchor,
                                                   data['annot'].to(device))

                    loss_classification_ls.append(cls_loss.item())
                    loss_regression_ls.append(reg_loss.item())

            model.train()

            if not loss_classification_ls:
                # Nothing was evaluated; do not log or compare stale
                # training losses as if they were validation results.
                print('[Warning] validation loader produced no batches, skipping evaluation')
                continue

            # Average once, after all validation batches were seen.
            val_cls_loss = float(np.mean(loss_classification_ls))
            val_reg_loss = float(np.mean(loss_regression_ls))
            val_loss = val_cls_loss + val_reg_loss

            # LOGGING
            print(
                'Val. Epoch: {}/{}. Classification loss: {:1.5f}. Regression loss: {:1.5f}. Total loss: {:1.5f}'
                .format(epoch, opt.num_epochs, val_cls_loss, val_reg_loss,
                        val_loss))

            plotter.plot('Total loss', 'val', 'Total loss', step, val_loss)
            plotter.plot('Regression_loss', 'val', 'Regression_loss', step,
                         val_reg_loss)
            plotter.plot('Classfication_loss', 'val', 'Classfication_loss',
                         step, val_cls_loss)

            # Save model checkpoint if new best validation loss
            if val_loss + opt.es_min_delta < best_loss:
                best_loss = val_loss
                best_epoch = epoch

                save_checkpoint(
                    model,
                    f'efficientdet-d{opt.compound_coef}_{epoch}_{step}.pth')

            # Early stopping (disabled when es_patience is 0)
            if epoch - best_epoch > opt.es_patience > 0:
                print(
                    '[Info] Stop training at epoch {}. The lowest loss achieved is {}'
                    .format(epoch, best_loss))
                break
Esempio n. 15
0
def main():
    """Entry point: parse arguments, build loaders and model, train and test.

    Sets the module-level ``args``, ``NUM_CLASSES`` and ``COMB_DICTs`` (and
    ``plotter`` when ``--visdom`` is given) and updates ``best_mIoU``. With
    ``args.test`` the model is evaluated once and the process exits;
    otherwise each epoch adjusts the learning rate, trains, evaluates, and
    saves a checkpoint flagged with whether the mIoU improved.

    Raises:
        ValueError: if ``args.dataset`` is not one of the supported datasets.
    """
    global args, best_mIoU, NUM_CLASSES, COMB_DICTs
    args = parser.parse_args()
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)
    if args.visdom:
        global plotter
        plotter = VisdomLinePlotter(env_name=args.name + '_' + args.dataset)

    # Per-dataset label mappings. Index = fine label id, value = id in the
    # coarser label set; NUM_CLASSES lists the class count of every level,
    # the last entry being the fine level.
    if args.dataset == 'HelenFace':
        COMB_DICT0 = dict(enumerate([0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2]))
        COMB_DICT1 = dict(enumerate([0, 1, 2, 2, 2, 2, 3, 4, 4, 4, 5]))
        NUM_CLASSES = [3, 6, 11]
    elif args.dataset == 'PASCALPersonParts':
        COMB_DICT0 = dict(enumerate([0, 1, 1, 1, 1, 2, 2]))
        COMB_DICT1 = dict(enumerate([0, 1, 1, 2, 2, 3, 4]))
        NUM_CLASSES = [3, 5, 7]
    elif args.dataset == 'ATR':
        COMB_DICT0 = dict(enumerate(
            [0, 1, 1, 1, 2, 3, 3, 3, 3, 3, 3, 1, 3, 3, 2, 2, 4, 4]))
        COMB_DICT1 = dict(enumerate(
            [0, 1, 1, 2, 3, 5, 5, 5, 5, 6, 6, 2, 7, 7, 4, 4, 8, 8]))
        NUM_CLASSES = [5, 9, 18]
    else:
        # Previously an unknown dataset fell through and failed later with
        # an unbound-variable error.
        raise ValueError(
            "Unsupported dataset '{}'; expected one of 'HelenFace', "
            "'PASCALPersonParts', 'ATR'".format(args.dataset))
    COMB_DICTs = [COMB_DICT0, COMB_DICT1]

    args.n_class = NUM_CLASSES[-1]
    exp_name = args.name + '_' + args.dataset
    print(exp_name, 'n_class: ', args.n_class)

    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    to_normalized_tensor = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    kwargs = {'num_workers': 8, 'pin_memory': True} if args.cuda else {}
    train_set = ImageLoader('../Dataset/',
                            args.dataset,
                            'train.txt',
                            ignore_label=args.ignore_label,
                            n_imgs=args.num_trainimgs,
                            crop_size=input_size,
                            transform=to_normalized_tensor)
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               **kwargs)
    test_set = ImageLoader('../Dataset/',
                           args.dataset,
                           'test.txt',
                           ignore_label=args.ignore_label,
                           n_imgs=10000,
                           transform=test_transform)
    test_loader = torch.utils.data.DataLoader(test_set,
                                              batch_size=1,
                                              shuffle=True,
                                              **kwargs)

    net = _get_model_instance(args.name)(num_classes=NUM_CLASSES,
                                         pretrain=True,
                                         nIn=3)
    if args.cuda:
        net.cuda()

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            # Map to CPU when CUDA is not in use so checkpoints written on a
            # GPU machine can still be restored.
            checkpoint = torch.load(
                args.resume, map_location=None if args.cuda else 'cpu')
            args.start_epoch = checkpoint['epoch']
            best_mIoU = checkpoint['best_mIoU']
            net.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    criterion = cross_entropy_loss
    # Only optimize parameters that require gradients.
    parameters = filter(lambda p: p.requires_grad, net.parameters())
    optimizer = optim.Adam(parameters, lr=args.lr, weight_decay=1e-4)

    n_parameters = sum(p.data.nelement() for p in net.parameters())
    print('  + Number of params: {}'.format(n_parameters))

    if args.test:
        print('Epoch: %d' % (args.start_epoch))
        test_acc, test_mIoU = test(test_loader,
                                   net,
                                   criterion,
                                   args.start_epoch,
                                   showall=True)
        sys.exit()

    for epoch in range(args.start_epoch, args.epochs + 1):
        # update learning rate
        lr = adjust_learning_rate(args.lr, optimizer, epoch)
        if args.visdom:
            plotter.plot('lr',
                         'learning rate',
                         epoch,
                         lr,
                         exp_name=exp_name)

        # train for one epoch
        cudnn.benchmark = True
        train(train_loader, net, criterion, optimizer, epoch)

        # evaluate on validation set
        cudnn.benchmark = False
        acc, mIoU = test(test_loader, net, criterion, epoch)

        # record best mIoU and save checkpoint
        is_best = mIoU > best_mIoU
        best_mIoU = max(mIoU, best_mIoU)

        save_checkpoint(
            {
                'epoch': epoch,
                'state_dict': net.state_dict(),
                'best_mIoU': best_mIoU,
                'acc': acc
            },
            is_best,
            exp_name=exp_name,
            filename='checkpoint_%d.pth.tar' % (epoch))
# Test-set metrics: both scores are computed from column 1 of the prediction
# array (presumably the positive-class score — confirm against the model's
# output layout).
test_auc = roc_auc_score(test_y_true.numpy(),
                         test_y_pred.numpy()[:, 1],
                         average="weighted")
test_aupr = average_precision_score(test_y_true.numpy(),
                                    test_y_pred.numpy()[:, 1],
                                    average="weighted")

print("======================= Finished!")
print('Test Loss: {}\n'.format(test_loss))
print('Test AUROC: {}\n'.format(test_auc))
print('Test AUPR: {}\n'.format(test_aupr))

# Run some simple explanations using Captum

# Pick a separate Visdom plotter and label for normal vs. adversarial data.
# NOTE(review): `type` shadows the builtin of the same name for the rest of
# the script; later code in this file may rely on the name, so it is left
# unchanged here.
if args.normal:
    plotter = VisdomLinePlotter('aa-normal')
    type = 'Normal'
else:
    plotter = VisdomLinePlotter('aa-adversarial')
    type = 'Adversarial'

print("\n============================================== Explanations")

# Let's try looking at the final layer of the model
cond = LayerConductance(model, model.output)

# Batch size used for the explanation pass.
batch_size = 256

# Reset our dataloader
test_loader = DataLoader(dataset=test_set,
                         batch_size=batch_size,
Esempio n. 17
0
from data.bookcorpus import BookCorpus
from qt_model import QuickThoughts
from utils import checkpoint_training, restore_training, safe_pack_sequence, VisdomLinePlotter
from config import CONFIG
from pprint import pformat, pprint
from tqdm import tqdm
import gensim.downloader as api
import os
import json
from eval import test_performance

_LOGGER = logging.getLogger(__name__)

# Script entry point: prepares the checkpoint directory, persists the run
# configuration, loads word vectors and builds the dataset.
if __name__ == '__main__':

    plotter = VisdomLinePlotter()

    #setting up training
    # os.mkdir raises FileExistsError if the checkpoint directory already
    # exists and does not create missing parent directories.
    os.mkdir(CONFIG['checkpoint_dir'])
    config_filepath = "{}/{}".format(CONFIG['checkpoint_dir'], 'config.json')
    # Log the configuration and write it as JSON next to the checkpoints.
    with open(config_filepath, 'w') as fp:
        _LOGGER.info(pformat(CONFIG))
        json.dump(CONFIG, fp)
    _LOGGER.info("Wrote config to file: {}".format(config_filepath))

    #load in word vectors
    # Loaded through gensim's downloader API using the name in CONFIG['embedding'].
    WV_MODEL = api.load(CONFIG['embedding'])

    # create dataset
    # The dataset is given the data path and the vocabulary of the word-vector model.
    bookcorpus = BookCorpus(CONFIG['data_path'], WV_MODEL.vocab)
    train_iter = DataLoader(
from visdom import Visdom

from utils import AverageMeter, VisdomLinePlotter

from ShapDataset import ShapDataset, ShapDatasetChunked, ShapDatasetDict, ShapDatasetTop

import argparse

import pickle
import random

# Module-level Visdom line plotter created with the name 'attack-svm'.
# NOTE(review): a `global` statement at module scope has no effect; the
# assignment below already binds a module-level name.
global plotter
plotter = VisdomLinePlotter('attack-svm')

from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.svm import SVC
from sklearn import preprocessing
from sklearn.svm import libsvm

import tqdm

import os

# Command-line interface for this script.
# NOTE(review): the description mentions Captum/RETAIN while the imports
# around it are SVM tooling -- confirm the description is still accurate.
arg_parser = argparse.ArgumentParser(
    description='Use Captum to explain samples from RETAIN')

# --all / -a: boolean flag, False unless given on the command line.
arg_parser.add_argument('--all',
                        '-a',
                        help='Train all SVM models',
                        action='store_true')
Esempio n. 19
0
class Trainer():
    def __init__(
        self,
        name = 'default',
        results_dir = 'results',
        models_dir = 'models',
        transfer_from_checkpoint = None,
        plotting = False,
        base_dir = './',
        image_size = 128,
        network_capacity = 16,
        fmap_max = 512,
        transparent = False,
        batch_size = 16,
        mixed_prob = 0.9,
        gradient_accumulate_every=1,
        lr = 2e-4,
        lr_mlp = 0.1,
        ttur_mult = 2,
        rel_disc_loss = False,
        num_workers = None,
        save_every = 1000,
        evaluate_every = 1000,
        num_image_tiles = 8,
        trunc_psi = 0.6,
        fp16 = False,
        cl_reg = False,
        no_pl_reg = False,
        fq_layers = [],
        fq_dict_size = 256,
        attn_layers = [],
        no_const = False,
        aug_prob = 0.,
        aug_types = ['translation', 'cutout'],
        top_k_training = False,
        generator_top_k_gamma = 0.99,
        generator_top_k_frac = 0.5,
        dataset_aug_prob = 0.,
        calculate_fid_every = None,
        calculate_fid_num_images = 12800,
        clear_fid_cache = False,
        is_ddp = False,
        rank = 0,
        world_size = 1,
        log = False,
        *args,
        **kwargs
    ):
        """Store the training configuration and create the output folders.

        No network is built here: ``self.GAN`` stays ``None`` until
        ``init_GAN`` runs, and any extra positional / keyword arguments are
        kept in ``self.GAN_params`` to be forwarded to ``StyleGAN2`` then.

        Paths are derived from ``base_dir``: results go to
        ``base_dir / results_dir``, checkpoints to ``base_dir / models_dir``,
        FID scratch data to ``base_dir / 'fid' / name`` and the persisted
        config to ``<models_dir> / name / '.config.json'``.

        List-valued arguments (``fq_layers``, ``attn_layers``, ``aug_types``)
        are copied, so mutating the instance attributes later cannot leak into
        the shared default lists or into the caller's own list objects.

        When ``plotting`` is true a ``VisdomImagePlotter`` and a
        ``VisdomLinePlotter`` are created; otherwise both plotter attributes
        are ``None``. ``log`` is accepted but not used by this method.

        Raises:
            AssertionError: if ``image_size`` is not a power of two, if
                ``fp16`` is requested while ``APEX_AVAILABLE`` is false, or if
                ``is_ddp`` and ``cl_reg`` are both enabled.
        """
        # Extra arguments are forwarded verbatim to StyleGAN2 in init_GAN.
        self.GAN_params = [args, kwargs]
        self.GAN = None

        self.name = name

        base_dir = Path(base_dir)
        self.base_dir = base_dir
        self.results_dir = base_dir / results_dir
        self.models_dir = base_dir / models_dir
        self.transfer_from_checkpoint = transfer_from_checkpoint
        self.fid_dir = base_dir / 'fid' / name
        self.config_path = self.models_dir / name / '.config.json'

        assert log2(image_size).is_integer(), 'image size must be a power of 2 (64, 128, 256, 512, 1024)'
        self.image_size = image_size
        self.network_capacity = network_capacity
        self.fmap_max = fmap_max
        self.transparent = transparent

        # list(...) takes a private copy so the mutable default is never aliased.
        self.fq_layers = list(cast_list(fq_layers))
        self.fq_dict_size = fq_dict_size
        self.has_fq = len(self.fq_layers) > 0

        self.attn_layers = list(cast_list(attn_layers))
        self.no_const = no_const

        self.aug_prob = aug_prob
        # Only lists are copied; any other value is stored exactly as given.
        self.aug_types = aug_types.copy() if isinstance(aug_types, list) else aug_types

        self.lr = lr
        self.lr_mlp = lr_mlp
        self.ttur_mult = ttur_mult
        self.rel_disc_loss = rel_disc_loss
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.mixed_prob = mixed_prob

        self.num_image_tiles = num_image_tiles
        self.evaluate_every = evaluate_every
        self.save_every = save_every
        self.steps = 0

        # Cached average style vector, computed lazily by truncate_style.
        self.av = None
        self.trunc_psi = trunc_psi

        self.no_pl_reg = no_pl_reg
        self.pl_mean = None

        self.gradient_accumulate_every = gradient_accumulate_every

        assert not fp16 or APEX_AVAILABLE, 'Apex is not available for you to use mixed precision training'
        self.fp16 = fp16

        self.cl_reg = cl_reg

        # Most recent loss / metric values, reported by print_log.
        self.d_loss = 0
        self.g_loss = 0
        self.q_loss = None
        self.last_gp_loss = None
        self.last_cr_loss = None
        self.last_fid = None

        self.pl_length_ma = EMA(0.99)
        # Needs name, results_dir and models_dir, all set above.
        self.init_folders()

        self.loader = None
        self.dataset_aug_prob = dataset_aug_prob

        self.calculate_fid_every = calculate_fid_every
        self.calculate_fid_num_images = calculate_fid_num_images
        self.clear_fid_cache = clear_fid_cache

        self.top_k_training = top_k_training
        self.generator_top_k_gamma = generator_top_k_gamma
        self.generator_top_k_frac = generator_top_k_frac

        assert not (is_ddp and cl_reg), 'Contrastive loss regularization does not work well with multi GPUs yet'
        self.is_ddp = is_ddp
        self.is_main = rank == 0
        self.rank = rank
        self.world_size = world_size
        self.plotting = plotting
        if plotting:
            self.image_plotter = VisdomImagePlotter()
            self.line_plotter = VisdomLinePlotter()
        else:
            self.image_plotter = None
            self.line_plotter = None
    @property
    def image_extension(self):
        return 'jpg' if not self.transparent else 'png'

    @property
    def checkpoint_num(self):
        return floor(self.steps // self.save_every)

    @property
    def hparams(self):
        return {'image_size': self.image_size, 'network_capacity': self.network_capacity}
        
    def init_GAN(self):
        """Build ``self.GAN`` from the stored settings and, under DDP, wrap its sub-networks.

        The extra positional / keyword arguments captured in ``self.GAN_params``
        are forwarded to ``StyleGAN2`` together with the explicit settings.
        When ``self.is_ddp`` is set, ``S``, ``G``, ``D`` and ``D_aug`` are each
        wrapped in ``DDP`` on device ``self.rank`` and stored as ``<name>_ddp``.
        """
        extra_args, extra_kwargs = self.GAN_params
        self.GAN = StyleGAN2(
            *extra_args,
            lr = self.lr,
            lr_mlp = self.lr_mlp,
            ttur_mult = self.ttur_mult,
            image_size = self.image_size,
            network_capacity = self.network_capacity,
            fmap_max = self.fmap_max,
            transparent = self.transparent,
            fq_layers = self.fq_layers,
            fq_dict_size = self.fq_dict_size,
            attn_layers = self.attn_layers,
            fp16 = self.fp16,
            cl_reg = self.cl_reg,
            no_const = self.no_const,
            rank = self.rank,
            **extra_kwargs
        )

        if not self.is_ddp:
            return

        ddp_kwargs = {'device_ids': [self.rank]}
        # Same wrapping order as the attribute names: S, G, D, D_aug.
        for net_name in ('S', 'G', 'D', 'D_aug'):
            wrapped = DDP(getattr(self.GAN, net_name), **ddp_kwargs)
            setattr(self, f'{net_name}_ddp', wrapped)

    def write_config(self):
        self.config_path.write_text(json.dumps(self.config()))

    def load_config(self):
        config = self.config() if not self.config_path.exists() else json.loads(self.config_path.read_text())
        self.image_size = config['image_size']
        self.network_capacity = config['network_capacity']
        self.transparent = config['transparent']
        self.fq_layers = config['fq_layers']
        self.fq_dict_size = config['fq_dict_size']
        self.fmap_max = config.pop('fmap_max', 512)
        self.attn_layers = config.pop('attn_layers', [])
        self.no_const = config.pop('no_const', False)
        self.lr_mlp = config.pop('lr_mlp', 0.1)
        del self.GAN
        self.init_GAN()

    def config(self):
        return {'image_size': self.image_size, 'network_capacity': self.network_capacity, 'lr_mlp': self.lr_mlp, 'transparent': self.transparent, 'fq_layers': self.fq_layers, 'fq_dict_size': self.fq_dict_size, 'attn_layers': self.attn_layers, 'no_const': self.no_const}

    def set_data_src(self, folder):
        """Create the dataset for ``folder`` and an endlessly cycling loader over it.

        Under DDP a ``DistributedSampler`` does the shuffling and the worker
        fallback is 0; otherwise the loader shuffles itself and falls back to
        ``NUM_CORES`` workers. The per-process batch size is
        ``ceil(batch_size / world_size)``. When ``aug_prob`` is unset and the
        dataset has fewer than 1e5 samples, an augmentation probability is
        derived from the dataset size and announced on stdout.
        """
        self.dataset = Dataset(folder, self.image_size, transparent = self.transparent, aug_prob = self.dataset_aug_prob)

        fallback_workers = 0 if self.is_ddp else NUM_CORES
        num_workers = default(self.num_workers, fallback_workers)

        if self.is_ddp:
            sampler = DistributedSampler(self.dataset, rank=self.rank, num_replicas=self.world_size, shuffle=True)
        else:
            sampler = None

        per_process_batch = math.ceil(self.batch_size / self.world_size)
        dataloader = data.DataLoader(
            self.dataset,
            num_workers = num_workers,
            batch_size = per_process_batch,
            sampler = sampler,
            shuffle = not self.is_ddp,
            drop_last = True,
            pin_memory = True,
        )
        self.loader = cycle(dataloader)

        # auto set augmentation prob for user if dataset is detected to be low
        num_samples = len(self.dataset)
        if exists(self.aug_prob) or num_samples >= 1e5:
            return

        self.aug_prob = min(0.5, (1e5 - num_samples) * 3e-6)
        print(f'autosetting augmentation probability to {round(self.aug_prob * 100)}%')

    def train(self):
        """Run a single training step and advance ``self.steps`` by one.

        In order: an optional extra discriminator update driven by
        ``self.GAN.D_cl`` (when it exists), the discriminator update, the
        generator update, moving-average bookkeeping, a NaN check, and -- on
        the main process only -- periodic checkpointing, evaluation and FID
        computation.

        Raises:
            AssertionError: if ``set_data_src`` has not provided a loader yet.
            NanException: if the accumulated generator or discriminator loss
                is NaN; ``self.load(self.checkpoint_num)`` is called first.
        """
        assert exists(self.loader), 'You must first initialize the data source with `.set_data_src(<folder of images>)`'

        # The GAN is built lazily on the first step if nothing created it before.
        if not exists(self.GAN):
            self.init_GAN()

        self.GAN.train()
        # Loss accumulators live on the CUDA device indexed by self.rank.
        total_disc_loss = torch.tensor(0.).cuda(self.rank)
        total_gen_loss = torch.tensor(0.).cuda(self.rank)

        # Per-process batch size.
        batch_size = math.ceil(self.batch_size / self.world_size)

        image_size = self.GAN.G.image_size
        latent_dim = self.GAN.G.latent_dim
        num_layers = self.GAN.G.num_layers

        aug_prob   = self.aug_prob
        aug_types  = self.aug_types
        aug_kwargs = {'prob': aug_prob, 'types': aug_types}

        # Schedules: gradient penalty every 4th step; path-length penalty every
        # 32nd step once past step 5000 (unless disabled); generated images are
        # fed to D_cl only after step 20000.
        apply_gradient_penalty = self.steps % 4 == 0
        apply_path_penalty = not self.no_pl_reg and self.steps > 5000 and self.steps % 32 == 0
        apply_cl_reg_to_generated = self.steps > 20000

        # Use the DDP-wrapped networks when running distributed.
        S = self.GAN.S if not self.is_ddp else self.S_ddp
        G = self.GAN.G if not self.is_ddp else self.G_ddp
        D = self.GAN.D if not self.is_ddp else self.D_ddp
        D_aug = self.GAN.D_aug if not self.is_ddp else self.D_aug_ddp

        backwards = partial(loss_backwards, self.fp16)

        # Optional D_cl pass: accumulate generated (late in training) and real
        # batches, then take one discriminator optimizer step on its loss.
        if exists(self.GAN.D_cl):
            self.GAN.D_opt.zero_grad()

            if apply_cl_reg_to_generated:
                for i in range(self.gradient_accumulate_every):
                    get_latents_fn = mixed_list if random() < self.mixed_prob else noise_list
                    style = get_latents_fn(batch_size, num_layers, latent_dim, device=self.rank)
                    noise = image_noise(batch_size, image_size, device=self.rank)

                    w_space = latent_to_w(self.GAN.S, style)
                    w_styles = styles_def_to_tensor(w_space)

                    generated_images = self.GAN.G(w_styles, noise)
                    self.GAN.D_cl(generated_images.clone().detach(), accumulate=True)

            for i in range(self.gradient_accumulate_every):
                image_batch = next(self.loader).cuda(self.rank)
                self.GAN.D_cl(image_batch, accumulate=True)

            loss = self.GAN.D_cl.calculate_loss()
            self.last_cr_loss = loss.clone().detach().item()
            backwards(loss, self.GAN.D_opt, loss_id = 0)

            self.GAN.D_opt.step()

        # train discriminator

        avg_pl_length = self.pl_mean
        self.GAN.D_opt.zero_grad()

        for i in gradient_accumulate_contexts(self.gradient_accumulate_every, self.is_ddp, ddps=[D_aug, S, G]):
            # With probability mixed_prob use mixed_list, otherwise noise_list.
            get_latents_fn = mixed_list if random() < self.mixed_prob else noise_list
            style = get_latents_fn(batch_size, num_layers, latent_dim, device=self.rank)
            noise = image_noise(batch_size, image_size, device=self.rank)

            w_space = latent_to_w(S, style)
            w_styles = styles_def_to_tensor(w_space)

            generated_images = G(w_styles, noise)
            # Fakes are detached so this backward pass does not reach the generator.
            fake_output, fake_q_loss = D_aug(generated_images.clone().detach(), detach = True, **aug_kwargs)

            image_batch = next(self.loader).cuda(self.rank)
            # The gradient penalty below differentiates with respect to the real images.
            image_batch.requires_grad_()
            real_output, real_q_loss = D_aug(image_batch, **aug_kwargs)

            real_output_loss = real_output
            fake_output_loss = fake_output

            # Relativistic variant: each output is measured against the mean of the other.
            if self.rel_disc_loss:
                real_output_loss = real_output_loss - fake_output.mean()
                fake_output_loss = fake_output_loss - real_output.mean()

            # Hinge-style loss: non-zero where the real output exceeds -1 or the
            # fake output is below 1.
            divergence = (F.relu(1 + real_output_loss) + F.relu(1 - fake_output_loss)).mean()
            disc_loss = divergence

            if self.has_fq:
                quantize_loss = (fake_q_loss + real_q_loss).mean()
                self.q_loss = float(quantize_loss.detach().item())

                disc_loss = disc_loss + quantize_loss

            if apply_gradient_penalty:
                gp = gradient_penalty(image_batch, real_output)
                self.last_gp_loss = gp.clone().detach().item()
                self.track(y=self.last_gp_loss, var_name ='Penalty', name = 'GP',title ="Penalties", x=self.steps)
                disc_loss = disc_loss + gp

            # Scale so the accumulated gradient is an average over the micro-batches.
            disc_loss = disc_loss / self.gradient_accumulate_every
            disc_loss.register_hook(raise_if_nan)
            backwards(disc_loss, self.GAN.D_opt, loss_id = 1)

            # Only the divergence term (without penalties) is reported.
            total_disc_loss += divergence.detach().item() / self.gradient_accumulate_every

        self.d_loss = float(total_disc_loss)
        self.track(y=self.d_loss, var_name ='Loss', name='D',title = 'Training Loss',x=self.steps)

        self.GAN.D_opt.step()

        # train generator

        self.GAN.G_opt.zero_grad()

        for i in gradient_accumulate_contexts(self.gradient_accumulate_every, self.is_ddp, ddps=[S, G, D_aug]):
            # NOTE(review): get_latents_fn is not re-drawn here; it is whatever the
            # last discriminator iteration selected -- confirm this is intended.
            style = get_latents_fn(batch_size, num_layers, latent_dim, device=self.rank)
            noise = image_noise(batch_size, image_size, device=self.rank)

            w_space = latent_to_w(S, style)
            w_styles = styles_def_to_tensor(w_space)

            generated_images = G(w_styles, noise)
            fake_output, _ = D_aug(generated_images, **aug_kwargs)
            fake_output_loss = fake_output

            if self.top_k_training:
                # The kept fraction decays as generator_top_k_gamma ** epochs,
                # bounded below by generator_top_k_frac.
                epochs = (self.steps * batch_size * self.gradient_accumulate_every) / len(self.dataset)
                k_frac = max(self.generator_top_k_gamma ** epochs, self.generator_top_k_frac)
                k = math.ceil(batch_size * k_frac)

                if k != batch_size:
                    # Keep only the k smallest discriminator outputs.
                    fake_output_loss, _ = fake_output_loss.topk(k=k, largest=False)

            loss = fake_output_loss.mean()
            gen_loss = loss

            if apply_path_penalty:
                pl_lengths = calc_pl_lengths(w_styles, generated_images)
                avg_pl_length = np.mean(pl_lengths.detach().cpu().numpy())

                # The penalty is only added once a running mean exists and the
                # resulting loss is not NaN.
                if not is_empty(self.pl_mean):
                    pl_loss = ((pl_lengths - self.pl_mean) ** 2).mean()
                    if not torch.isnan(pl_loss):
                        gen_loss = gen_loss + pl_loss

            gen_loss = gen_loss / self.gradient_accumulate_every
            gen_loss.register_hook(raise_if_nan)
            backwards(gen_loss, self.GAN.G_opt, loss_id = 2)

            # As for the discriminator, the reported value excludes the penalty.
            total_gen_loss += loss.detach().item() / self.gradient_accumulate_every

        self.g_loss = float(total_gen_loss)
        self.track(y=self.g_loss, var_name ='Loss', name='G',title = 'Training Loss',x=self.steps)

        self.GAN.G_opt.step()

        # calculate moving averages

        if apply_path_penalty and not np.isnan(avg_pl_length):
            self.pl_mean = self.pl_length_ma.update_average(self.pl_mean, avg_pl_length)
            self.track(y=self.pl_mean,var_name ='Penalty', name='PL',title = 'Penalties',x=self.steps)

        # Main process only: GAN.EMA() every 10 steps after step 20000.
        if self.is_main and self.steps % 10 == 0 and self.steps > 20000:
            self.GAN.EMA()

        # Main process only: reset the parameter averaging at steps 2, 1002, ...
        # up to step 25000.
        if self.is_main and self.steps <= 25000 and self.steps % 1000 == 2:
            self.GAN.reset_parameter_averaging()

        # save from NaN errors

        # Runs after both optimizer steps; reloads the checkpoint and raises.
        if any(torch.isnan(l) for l in (total_gen_loss, total_disc_loss)):
            print(f'NaN detected for generator or discriminator. Loading from checkpoint #{self.checkpoint_num}')
            self.load(self.checkpoint_num)
            raise NanException

        # periodically save results

        if self.is_main:
            if self.steps % self.save_every == 0:
                self.save(self.checkpoint_num)

            # Evaluate on the regular schedule, plus every 100 steps during the
            # first 2500 steps.
            if self.steps % self.evaluate_every == 0 or (self.steps % 100 == 0 and self.steps < 2500):
                self.evaluate(floor(self.steps / self.evaluate_every))

            if exists(self.calculate_fid_every) and self.steps % self.calculate_fid_every == 0 and self.steps != 0:
                num_batches = math.ceil(self.calculate_fid_num_images / self.batch_size)
                fid = self.calculate_fid(num_batches)
                self.last_fid = fid

                # Append "<step>,<fid>" to the run's FID history file.
                with open(str(self.results_dir / self.name / f'fid_scores.txt'), 'a') as f:
                    f.write(f'{self.steps},{fid}\n')

        self.steps += 1
        # Drop the cached average style vector so truncate_style recomputes it.
        self.av = None

    @torch.no_grad()
    def evaluate(self, num = 0, trunc = 1.0):
        """Save three sample grids for the current model under ``results_dir / name``.

        ``<num>.<ext>`` comes from the regular S/G networks, ``<num>-ema.<ext>``
        from the SE/GE networks, and ``<num>-mr.<ext>`` from SE/GE with two
        latents split across the layers. The first two grids are also sent to
        the image plotter when ``self.plotting`` is set. ``trunc`` is accepted
        but not used; truncation always uses ``self.trunc_psi``.
        """
        self.GAN.eval()
        ext = self.image_extension
        num_rows = self.num_image_tiles
        out_dir = self.results_dir / self.name

        latent_dim = self.GAN.G.latent_dim
        image_size = self.GAN.G.image_size
        num_layers = self.GAN.G.num_layers

        # latents and noise
        grid_cells = num_rows ** 2
        latents = noise_list(grid_cells, num_layers, latent_dim, device=self.rank)
        n = image_noise(grid_cells, image_size, device=self.rank)

        # regular
        regular_images = self.generate_truncated(self.GAN.S, self.GAN.G, latents, n, trunc_psi = self.trunc_psi)
        torchvision.utils.save_image(regular_images, str(out_dir / f'{str(num)}.{ext}'), nrow=num_rows)
        if self.plotting:
            regular_grid = vutils.make_grid(regular_images, padding=2, normalize=True)
            self.image_plotter.plot(regular_grid, name="generator-S-output")

        # moving averages
        ema_images = self.generate_truncated(self.GAN.SE, self.GAN.GE, latents, n, trunc_psi = self.trunc_psi)
        torchvision.utils.save_image(ema_images, str(out_dir / f'{str(num)}-ema.{ext}'), nrow=num_rows)
        if self.plotting:
            ema_grid = vutils.make_grid(ema_images, padding=2, normalize=True)
            self.image_plotter.plot(ema_grid, name="generator-SE-output")

        # mixing regularities

        def repeat_each(tensor, dim, times):
            # Repeat every slice along `dim` `times` times, keeping copies adjacent.
            length = tensor.size(dim)
            reps = [times if axis == dim else 1 for axis in range(tensor.dim())]
            repeated = tensor.repeat(*reps)
            block_starts = length * np.arange(times)
            gather_order = np.concatenate([block_starts + i for i in range(length)])
            order_index = torch.LongTensor(gather_order).cuda(self.rank)
            return torch.index_select(repeated, dim, order_index)

        row_latents = noise(num_rows, latent_dim, device=self.rank)
        per_row = repeat_each(row_latents, 0, num_rows)
        per_col = row_latents.repeat(num_rows, 1)

        first_layers = int(num_layers / 2)
        mixed_latents = [(per_row, first_layers), (per_col, num_layers - first_layers)]

        mixed_images = self.generate_truncated(self.GAN.SE, self.GAN.GE, mixed_latents, n, trunc_psi = self.trunc_psi)
        torchvision.utils.save_image(mixed_images, str(out_dir / f'{str(num)}-mr.{ext}'), nrow=num_rows)

    @torch.no_grad()
    def calculate_fid(self, num_batches):
        """Write real and generated images to ``self.fid_dir`` and return their FID.

        Real images are only (re)written when the ``real`` folder is missing or
        ``self.clear_fid_cache`` is set; the ``fake`` folder is always rebuilt
        from the SE/GE networks. ``num_batches`` batches are written per folder.
        """
        from pytorch_fid import fid_score
        torch.cuda.empty_cache()

        real_path = self.fid_dir / 'real'
        fake_path = self.fid_dir / 'fake'

        # remove any existing files used for fid calculation and recreate directories
        needs_reals = not real_path.exists() or self.clear_fid_cache
        if needs_reals:
            rmtree(real_path, ignore_errors=True)
            os.makedirs(real_path)

            for batch_num in tqdm(range(num_batches), desc='calculating FID - saving reals'):
                real_batch = next(self.loader)
                first_index = batch_num * self.batch_size
                for k, image in enumerate(real_batch.unbind(0)):
                    filename = str(k + first_index)
                    torchvision.utils.save_image(image, str(real_path / f'{filename}.png'))

        # generate a bunch of fake images in results / name / fid_fake
        rmtree(fake_path, ignore_errors=True)
        os.makedirs(fake_path)

        self.GAN.eval()
        ext = self.image_extension

        latent_dim = self.GAN.G.latent_dim
        image_size = self.GAN.G.image_size
        num_layers = self.GAN.G.num_layers

        for batch_num in tqdm(range(num_batches), desc='calculating FID - saving generated'):
            # latents and noise
            latents = noise_list(self.batch_size, num_layers, latent_dim, device=self.rank)
            img_noise = image_noise(self.batch_size, image_size, device=self.rank)

            # moving averages
            generated_images = self.generate_truncated(self.GAN.SE, self.GAN.GE, latents, img_noise, trunc_psi = self.trunc_psi)

            first_index = batch_num * self.batch_size
            for j, image in enumerate(generated_images.unbind(0)):
                torchvision.utils.save_image(image, str(fake_path / f'{str(j + first_index)}-ema.{ext}'))

        # The device of the last noise tensor is passed on to the FID computation.
        folders = [str(real_path), str(fake_path)]
        return fid_score.calculate_fid_given_paths(folders, 256, img_noise.device, 2048)

    @torch.no_grad()
    def truncate_style(self, tensor, trunc_psi = 0.75):
        """Pull ``tensor`` towards the average style vector by the factor ``trunc_psi``.

        The average is computed from 2000 noise samples passed through
        ``self.GAN.S`` and cached in ``self.av`` until something resets it.
        """
        style_net = self.GAN.S
        chunk_size = self.batch_size
        latent_dim = self.GAN.G.latent_dim

        if not exists(self.av):
            z = noise(2000, latent_dim, device=self.rank)
            samples = evaluate_in_chunks(chunk_size, style_net, z).cpu().numpy()
            mean_style = np.mean(samples, axis = 0)
            self.av = np.expand_dims(mean_style, axis = 0)

        center = torch.from_numpy(self.av).cuda(self.rank)
        return trunc_psi * (tensor - center) + center

    @torch.no_grad()
    def truncate_style_defs(self, w, trunc_psi = 0.75):
        w_space = []
        for tensor, num_layers in w:
            tensor = self.truncate_style(tensor, trunc_psi = trunc_psi)            
            w_space.append((tensor, num_layers))
        return w_space

    @torch.no_grad()
    def generate_truncated(self, S, G, style, noi, trunc_psi = 0.75, num_image_tiles = 8):
        """Generate images from ``style`` with truncated styles, clamped to [0, 1].

        Each ``style`` entry is indexed as ``(latent, num_layers)``; the latent
        is mapped through ``S`` and truncated before ``G`` is evaluated in
        chunks of ``self.batch_size``. ``num_image_tiles`` is accepted but not
        used.
        """
        mapped = ((S(entry[0]), entry[1]) for entry in style)
        truncated = self.truncate_style_defs(mapped, trunc_psi = trunc_psi)
        w_styles = styles_def_to_tensor(truncated)
        images = evaluate_in_chunks(self.batch_size, G, w_styles, noi)
        return images.clamp_(0., 1.)

    @torch.no_grad()
    def generate_interpolation(self, num = 0, num_image_tiles = 8, trunc = 1.0, num_steps = 100, save_frames = False):
        """Render an interpolation between two latent sets as ``<num>.gif``.

        ``num_steps`` frames are produced with ``slerp`` ratios spaced evenly
        from 0 to 8, each a grid of ``num_image_tiles ** 2`` images from the
        SE/GE networks. With ``save_frames`` the frames are also written
        individually into a ``<num>`` folder. ``trunc`` is accepted but not
        used; truncation always uses ``self.trunc_psi``.
        """
        self.GAN.eval()
        ext = self.image_extension
        num_rows = num_image_tiles
        out_dir = self.results_dir / self.name

        latent_dim = self.GAN.G.latent_dim
        image_size = self.GAN.G.image_size
        num_layers = self.GAN.G.num_layers

        # latents and noise
        grid_cells = num_rows ** 2
        latents_low = noise(grid_cells, latent_dim, device=self.rank)
        latents_high = noise(grid_cells, latent_dim, device=self.rank)
        n = image_noise(grid_cells, image_size, device=self.rank)

        ratios = torch.linspace(0., 8., num_steps)
        to_pil = transforms.ToPILImage()

        frames = []
        for ratio in tqdm(ratios):
            interp_latents = slerp(ratio, latents_low, latents_high)
            generated_images = self.generate_truncated(self.GAN.SE, self.GAN.GE, [(interp_latents, num_layers)], n, trunc_psi = self.trunc_psi)
            images_grid = torchvision.utils.make_grid(generated_images, nrow = num_rows)
            pil_image = to_pil(images_grid.cpu())

            if self.transparent:
                # Composite the frame over a background image of the same size.
                background = Image.new("RGBA", pil_image.size, (255, 255, 255))
                pil_image = Image.alpha_composite(background, pil_image)

            frames.append(pil_image)

        gif_path = str(out_dir / f'{str(num)}.gif')
        frames[0].save(gif_path, save_all=True, append_images=frames[1:], duration=80, loop=0, optimize=True)

        if not save_frames:
            return

        folder_path = (out_dir / f'{str(num)}')
        folder_path.mkdir(parents=True, exist_ok=True)
        for ind, frame in enumerate(frames):
            frame.save(str(folder_path / f'{str(ind)}.{ext}'))

    def print_log(self):
        """Print the tracked metrics that currently have a value, separated by ' | '."""
        metrics = (
            ('G', self.g_loss),
            ('D', self.d_loss),
            ('GP', self.last_gp_loss),
            ('PL', self.pl_mean),
            ('CR', self.last_cr_loss),
            ('Q', self.q_loss),
            ('FID', self.last_fid),
        )

        # Metrics for which exists() is false are left out of the line.
        parts = [f'{label}: {value:.2f}' for label, value in metrics if exists(value)]
        print(' | '.join(parts))

    def track(self, y, x, var_name, name,title):
        """Forward one ``(x, y)`` point to the line plotter; do nothing when there is none."""
        plotter = self.line_plotter
        if exists(plotter):
            plotter.plot(var_name = var_name, split_name = name,
                         title_name = title, y = y, x = x)

    def model_name(self, num):
        return str(self.models_dir / self.name / f'model_{num}.pt')

    def init_folders(self):
        (self.results_dir / self.name).mkdir(parents=True, exist_ok=True)
        (self.models_dir / self.name).mkdir(parents=True, exist_ok=True)

    def clear(self):
        rmtree(str(self.models_dir / self.name), True)
        rmtree(str(self.results_dir / self.name), True)
        rmtree(str(self.fid_dir), True)
        rmtree(str(self.config_path), True)
        self.init_folders()

    def save(self, num):
        """Write checkpoint ``num`` (GAN state, package version, optional amp state) and the config."""
        checkpoint = {'GAN': self.GAN.state_dict(), 'version': __version__}

        # The amp state is only stored for fp16 runs.
        if self.GAN.fp16:
            checkpoint['amp'] = amp.state_dict()

        destination = self.model_name(num)
        torch.save(checkpoint, destination)
        self.write_config()

    def load(self, num = -1):
        load_data=None
        self.load_config()
        print(self.transfer_from_checkpoint)
        name = num
        if self.transfer_from_checkpoint:
            print("yeah boi")
            load_data = torch.load(self.transfer_from_checkpoint)
            name = 0
        
        elif num == -1:
            file_paths = [p for p in Path(self.models_dir / self.name).glob('model_*.pt')]
            saved_nums = sorted(map(lambda x: int(x.stem.split('_')[1]), file_paths))
            if len(saved_nums) == 0:
                return
            name = saved_nums[-1]
            print(f'continuing from previous epoch - {name}')

        self.steps = name * self.save_every
        if not load_data:
            load_data = torch.load(self.model_name(name))

        if 'version' in load_data:
            print(f"loading from version {load_data['version']}")

        try:
            self.GAN.load_state_dict(load_data['GAN'])
        except Exception as e:
            print('unable to load save model. please try downgrading the package to the version specified by the saved model')
            raise e
        if self.GAN.fp16 and 'amp' in load_data:
            amp.load_state_dict(load_data['amp'])