Code example #1
0
def train():
    """Train BASNet on the DUTS saliency dataset.

    Builds train/test file lists from DUTS-TR / DUTS-TE, optimizes the
    network with a fused multi-output BCE loss (8 main heads + 7 "struct"
    heads), periodically evaluates on the full DUTS-TE set (MAE, max
    F-measure and two relaxed-F-measure variants), saves a checkpoint after
    each evaluation pass, copies the checkpoint into per-metric "best"
    folders whenever a metric improves, and logs every loss/metric curve
    to a Visdom dashboard.

    All paths, hyper-parameters and the target device are hard-coded
    below; the function takes no arguments and returns None.
    """
    # --- dataset locations (Windows box vs. $HOME layout elsewhere) ---
    if os.name == 'nt':
        data_dir = 'C:/Users/marky/Documents/Courses/saliency/datasets/DUTS/'
    else:
        data_dir = os.getenv(
            "HOME") + '/Documents/Courses/EE298-CV/finalproj/datasets/DUTS/'
    tra_image_dir = 'DUTS-TR/DUTS-TR-Image/'
    tra_label_dir = 'DUTS-TR/DUTS-TR-Mask/'
    test_image_dir = 'DUTS-TE/DUTS-TE-Image/'
    test_label_dir = 'DUTS-TE/DUTS-TE-Mask/'

    image_ext = '.jpg'  # input photographs
    label_ext = '.png'  # ground-truth saliency masks

    model_dir = "./saved_models/basnet_bsi_aug/"
    resume_train = False  # set True to warm-start from resume_model_path
    resume_model_path = model_dir + "basnet_bsi_epoch_81_itr_106839_train_1.511335_tar_0.098392.pth"
    last_epoch = 1  # 1-based epoch index to resume counting from
    epoch_num = 100000  # effectively "train until manually stopped"
    batch_size_train = 8
    batch_size_val = 1  # the eval loop below assumes one image per batch
    train_num = 0
    val_num = 0
    enableInpaintAug = False
    # NOTE(review): hard-coded to the *second* GPU ("cuda:1") when CUDA is
    # available -- confirm this matches the intended machine.
    device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

    # ------- 5. training process --------
    print("---start training...")
    test_increments = 6250  # run evaluation/checkpointing roughly this often
    ite_num = 0  # global iteration counter (never reset)
    running_loss = 0.0  # fused loss accumulated since the last checkpoint
    running_tar_loss = 0.0  # target-head (d0) loss accumulated likewise
    ite_num4val = 1  # divisor for the two running averages above
    next_test = ite_num + 0  # iteration at which the next evaluation fires
    visdom_tab_title = "StructArchWithoutStructImgs(WithHFlip)"
    ############
    ############
    ############
    ############

    # Collect image paths for both splits; the glob patterns are printed
    # so a wrong data_dir is easy to spot in the log.
    tra_img_name_list = glob.glob(data_dir + tra_image_dir + '*' + image_ext)
    print("data_dir + tra_image_dir + '*' + image_ext: ",
          data_dir + tra_image_dir + '*' + image_ext)
    test_img_name_list = glob.glob(data_dir + test_image_dir + '*' + image_ext)
    print("data_dir + test_image_dir + '*' + image_ext: ",
          data_dir + test_image_dir + '*' + image_ext)

    # Derive each mask path from the image basename: strip the extension
    # after the *last* dot while preserving any interior dots.
    # NOTE(review): splitting on "/" assumes POSIX separators; on the
    # Windows branch above glob may return backslash paths -- verify.
    tra_lbl_name_list = []
    for img_path in tra_img_name_list:
        img_name = img_path.split("/")[-1]
        aaa = img_name.split(".")
        bbb = aaa[0:-1]
        imidx = bbb[0]
        for i in range(1, len(bbb)):
            imidx = imidx + "." + bbb[i]
        tra_lbl_name_list.append(data_dir + tra_label_dir + imidx + label_ext)
    test_lbl_name_list = []
    for img_path in test_img_name_list:
        img_name = img_path.split("/")[-1]
        aaa = img_name.split(".")
        bbb = aaa[0:-1]
        imidx = bbb[0]
        for i in range(1, len(bbb)):
            imidx = imidx + "." + bbb[i]
        test_lbl_name_list.append(data_dir + test_label_dir + imidx +
                                  label_ext)

    print("---")
    print("train images: ", len(tra_img_name_list))
    print("train labels: ", len(tra_lbl_name_list))
    print("---")

    print("---")
    print("test images: ", len(test_img_name_list))
    print("test labels: ", len(test_lbl_name_list))
    print("---")

    train_num = len(tra_img_name_list)
    test_num = len(test_img_name_list)
    # Both splits use the same rescale -> crop -> tensor transform chain.
    salobj_dataset = SalObjDataset(img_name_list=tra_img_name_list,
                                   lbl_name_list=tra_lbl_name_list,
                                   transform=transforms.Compose([
                                       RescaleT(256),
                                       RandomCrop(224),
                                       ToTensorLab(flag=0)
                                   ]),
                                   category="train",
                                   enableInpaintAug=enableInpaintAug)
    salobj_dataset_test = SalObjDataset(img_name_list=test_img_name_list,
                                        lbl_name_list=test_lbl_name_list,
                                        transform=transforms.Compose([
                                            RescaleT(256),
                                            RandomCrop(224),
                                            ToTensorLab(flag=0)
                                        ]),
                                        category="test",
                                        enableInpaintAug=enableInpaintAug)
    salobj_dataloader = DataLoader(salobj_dataset,
                                   batch_size=batch_size_train,
                                   shuffle=True,
                                   num_workers=1)
    # NOTE(review): shuffle=True here, yet the evaluation loop below indexes
    # test_img_name_list / test_lbl_name_list by the batch index `i`, so
    # predictions look paired with the *wrong* image/GT files -- confirm.
    salobj_dataloader_test = DataLoader(salobj_dataset_test,
                                        batch_size=batch_size_val,
                                        shuffle=True,
                                        num_workers=1)

    # ------- 3. define model --------
    # define the net (3-channel RGB in, 1-channel saliency map out)
    net = BASNet(3, 1)
    if resume_train:
        # print("resume_model_path:", resume_model_path)
        checkpoint = torch.load(resume_model_path)
        net.load_state_dict(checkpoint)
    if torch.cuda.is_available():
        net.to(device)

    # ------- 4. define optimizer --------
    print("---define optimizer...")
    optimizer = optim.Adam(net.parameters(),
                           lr=0.001,
                           betas=(0.9, 0.999),
                           eps=1e-08,
                           weight_decay=0)

    plotter = VisdomLinePlotter(env_name=visdom_tab_title)

    # Sentinels: any real metric value beats these on the first evaluation.
    best_ave_mae = 100000
    best_max_fmeasure = 0
    best_relaxed_fmeasure = 0
    best_ave_maxf = 0
    best_own_RelaxedFmeasure = 0
    for epoch in range(last_epoch - 1, epoch_num):
        ### Train network
        # Per-epoch averages for each of the 8 main + 7 struct loss heads;
        # muti_bce_loss_fusion updates these meters in place.
        train_loss0 = AverageMeter()
        train_loss1 = AverageMeter()
        train_loss2 = AverageMeter()
        train_loss3 = AverageMeter()
        train_loss4 = AverageMeter()
        train_loss5 = AverageMeter()
        train_loss6 = AverageMeter()
        train_loss7 = AverageMeter()
        train_struct_loss1 = AverageMeter()
        train_struct_loss2 = AverageMeter()
        train_struct_loss3 = AverageMeter()
        train_struct_loss4 = AverageMeter()
        train_struct_loss5 = AverageMeter()
        train_struct_loss6 = AverageMeter()
        train_struct_loss7 = AverageMeter()

        test_loss0 = AverageMeter()
        test_loss1 = AverageMeter()
        test_loss2 = AverageMeter()
        test_loss3 = AverageMeter()
        test_loss4 = AverageMeter()
        test_loss5 = AverageMeter()
        test_loss6 = AverageMeter()
        test_loss7 = AverageMeter()
        test_struct_loss1 = AverageMeter()
        test_struct_loss2 = AverageMeter()
        test_struct_loss3 = AverageMeter()
        test_struct_loss4 = AverageMeter()
        test_struct_loss5 = AverageMeter()
        test_struct_loss6 = AverageMeter()
        test_struct_loss7 = AverageMeter()

        # Per-epoch evaluation metrics.
        average_mae = AverageMeter()
        average_maxf = AverageMeter()
        average_relaxedf = AverageMeter()
        average_own_RelaxedFMeasure = AverageMeter()
        net.train()
        for i, data in enumerate(salobj_dataloader):
            ite_num = ite_num + 1
            ite_num4val = ite_num4val + 1
            inputs, labels, labels_struct = data['image'], data['label'], data[
                'label2']

            inputs = inputs.type(torch.FloatTensor)
            labels = labels.type(torch.FloatTensor)
            labels_struct = labels_struct.type(torch.FloatTensor)

            # wrap them in Variable (legacy torch<0.4 style; Variables are
            # tensors in modern PyTorch)
            if torch.cuda.is_available():
                inputs_v, labels_v, labels_struct_v = Variable(
                    inputs.to(device), requires_grad=False), Variable(
                        labels.to(device), requires_grad=False), Variable(
                            labels_struct.to(device), requires_grad=False)
            else:
                inputs_v, labels_v, labels_struct_v = Variable(
                    inputs, requires_grad=False), Variable(
                        labels,
                        requires_grad=False), Variable(labels_struct,
                                                       requires_grad=False)

            # y zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize: the net returns 8 main
            # predictions (d0..d7) and 7 structural predictions
            d0, d1, d2, d3, d4, d5, d6, d7, d1_struct, d2_struct, d3_struct, d4_struct, d5_struct, d6_struct, d7_struct = net(
                inputs_v)
            # loss2 = target-head (d0) loss; loss = fused total over all heads
            loss2, loss = muti_bce_loss_fusion(
                d0, d1, d2, d3, d4, d5, d6, d7, d1_struct, d2_struct,
                d3_struct, d4_struct, d5_struct, d6_struct, d7_struct,
                labels_v, train_loss0, train_loss1, train_loss2, train_loss3,
                train_loss4, train_loss5, train_loss6, train_loss7,
                train_struct_loss1, train_struct_loss2, train_struct_loss3,
                train_struct_loss4, train_struct_loss5, train_struct_loss6,
                train_struct_loss7)
            loss.backward()
            optimizer.step()

            # # print statistics
            running_loss += loss.data
            running_tar_loss += loss2.data

            # del temporary outputs and loss to release GPU memory early
            del d0, d1, d2, d3, d4, d5, d6, d7, d1_struct, d2_struct, d3_struct, d4_struct, d5_struct, d6_struct, d7_struct, loss2, loss

            print(
                "[train epoch: %3d/%3d, batch: %5d/%5d, ite: %d] train loss: %3f, tar: %3f "
                % (epoch + 1, epoch_num,
                   (i + 1) * batch_size_train, train_num, ite_num,
                   running_loss / ite_num4val, running_tar_loss / ite_num4val))
        # Log the per-epoch average of every training loss head to Visdom.
        plotter.plot('loss0', 'train', 'Main Loss 0', epoch + 1,
                     float(train_loss0.avg))
        plotter.plot('loss1', 'train', 'Main Loss 1', epoch + 1,
                     float(train_loss1.avg))
        plotter.plot('loss2', 'train', 'Main Loss 2', epoch + 1,
                     float(train_loss2.avg))
        plotter.plot('loss3', 'train', 'Main Loss 3', epoch + 1,
                     float(train_loss3.avg))
        plotter.plot('loss4', 'train', 'Main Loss 4', epoch + 1,
                     float(train_loss4.avg))
        plotter.plot('loss5', 'train', 'Main Loss 5', epoch + 1,
                     float(train_loss5.avg))
        plotter.plot('loss6', 'train', 'Main Loss 6', epoch + 1,
                     float(train_loss6.avg))
        plotter.plot('loss7', 'train', 'Main Loss 7', epoch + 1,
                     float(train_loss7.avg))
        plotter.plot('structloss1', 'train', 'Struct Loss 1', epoch + 1,
                     float(train_struct_loss1.avg))
        plotter.plot('structloss2', 'train', 'Struct Loss 2', epoch + 1,
                     float(train_struct_loss2.avg))
        plotter.plot('structloss3', 'train', 'Struct Loss 3', epoch + 1,
                     float(train_struct_loss3.avg))
        plotter.plot('structloss4', 'train', 'Struct Loss 4', epoch + 1,
                     float(train_struct_loss4.avg))
        plotter.plot('structloss5', 'train', 'Struct Loss 5', epoch + 1,
                     float(train_struct_loss5.avg))
        plotter.plot('structloss6', 'train', 'Struct Loss 6', epoch + 1,
                     float(train_struct_loss6.avg))
        plotter.plot('structloss7', 'train', 'Struct Loss 7', epoch + 1,
                     float(train_struct_loss7.avg))

        ### Validate model
        print("---Evaluate model---")
        if ite_num >= next_test:  # evaluate/save every test_increments (6250) iterations, due to very large DUTS-TE dataset
            next_test = ite_num + test_increments
            net.eval()
            max_epoch_fmeasure = 0
            for i, data in enumerate(salobj_dataloader_test):
                inputs, labels = data['image'], data['label']
                inputs = inputs.type(torch.FloatTensor)
                labels = labels.type(torch.FloatTensor)
                if torch.cuda.is_available():
                    inputs_v, labels_v = Variable(
                        inputs.to(device),
                        requires_grad=False), Variable(labels.to(device),
                                                       requires_grad=False)
                else:
                    inputs_v, labels_v = Variable(
                        inputs,
                        requires_grad=False), Variable(labels,
                                                       requires_grad=False)
                # NOTE(review): forward pass inside eval is not wrapped in
                # torch.no_grad(), so activations/grads are still tracked.
                d0, d1, d2, d3, d4, d5, d6, d7, d1_struct, d2_struct, d3_struct, d4_struct, d5_struct, d6_struct, d7_struct = net(
                    inputs_v)

                # Take the target head, normalize to [0, 1], and resize the
                # prediction back to the original image resolution.
                pred = d0[:, 0, :, :]
                pred = normPRED(pred)
                pred = pred.squeeze()
                predict_np = pred.cpu().data.numpy()
                im = Image.fromarray(predict_np * 255).convert('RGB')
                # NOTE(review): `i` indexes the *unshuffled* name list while
                # the dataloader is shuffled (see above) -- pairing suspect.
                img_name = test_img_name_list[i]
                image = cv2.imread(img_name)
                imo = im.resize((image.shape[1], image.shape[0]),
                                resample=Image.BILINEAR)
                imo = imo.convert("L")  ###  Convert to grayscale 1-channel
                resizedImg_np = np.array(
                    imo)  ### Result is 2D numpy array predicted salient map
                img__lbl_name = test_lbl_name_list[i]
                gt_img = np.array(Image.open(img__lbl_name).convert(
                    "L"))  ### Ground truth salient map

                ### Compute metrics
                result_mae = getMAE(gt_img, resizedImg_np)
                average_mae.update(result_mae, 1)
                precision, recall = getPRCurve(gt_img, resizedImg_np)
                result_maxfmeasure = getMaxFMeasure(precision, recall)
                result_maxfmeasure = result_maxfmeasure.mean()
                average_maxf.update(result_maxfmeasure, 1)
                if (result_maxfmeasure > max_epoch_fmeasure):
                    max_epoch_fmeasure = result_maxfmeasure
                result_relaxedfmeasure = getRelaxedFMeasure(
                    gt_img, resizedImg_np)
                result_ownrelaxedfmeasure = own_RelaxedFMeasure(
                    gt_img, resizedImg_np)
                average_relaxedf.update(result_relaxedfmeasure, 1)
                average_own_RelaxedFMeasure.update(result_ownrelaxedfmeasure,
                                                   1)
                # Accumulate test-loss meters (values themselves unused here).
                loss2, loss = muti_bce_loss_fusion(
                    d0, d1, d2, d3, d4, d5, d6, d7, d1_struct, d2_struct,
                    d3_struct, d4_struct, d5_struct, d6_struct, d7_struct,
                    labels_v, test_loss0, test_loss1, test_loss2, test_loss3,
                    test_loss4, test_loss5, test_loss6, test_loss7,
                    test_struct_loss1, test_struct_loss2, test_struct_loss3,
                    test_struct_loss4, test_struct_loss5, test_struct_loss6,
                    test_struct_loss7)
                del d0, d1, d2, d3, d4, d5, d6, d7, d1_struct, d2_struct, d3_struct, d4_struct, d5_struct, d6_struct, d7_struct, loss2, loss
                print(
                    "[test epoch: %3d/%3d, batch: %5d/%5d, ite: %d] test loss: %3f, tar: %3f "
                    % (epoch + 1, epoch_num, (i + 1) * batch_size_val,
                       test_num, ite_num, running_loss / ite_num4val,
                       running_tar_loss / ite_num4val))
            # Checkpoint with the running training averages in the filename.
            model_name = model_dir + "basnet_bsi_epoch_%d_itr_%d_train_%3f_tar_%3f.pth" % (
                epoch + 1, ite_num, running_loss / ite_num4val,
                running_tar_loss / ite_num4val)
            torch.save(net.state_dict(), model_name)
            running_loss = 0.0
            running_tar_loss = 0.0
            net.train()  # resume train
            ite_num4val = 1
            # NOTE(review): running_loss/running_tar_loss were just reset and
            # ite_num4val set to 1, so every "best" filename below embeds
            # train_0.000000 / tar_0.000000 -- presumably the names were
            # meant to be formatted *before* the reset; confirm intent.
            # Each branch copies the checkpoint into a per-metric folder.
            # os.mkdir creates only the leaf folder, so model_dir must exist.
            if (average_mae.avg < best_ave_mae):
                best_ave_mae = average_mae.avg
                newname = model_dir + "bestMAE/basnet_bsi_epoch_%d_itr_%d_train_%3f_tar_%3f_mae_%3f.pth" % (
                    epoch + 1, ite_num, running_loss / ite_num4val,
                    running_tar_loss / ite_num4val, best_ave_mae)
                fold_dir = newname.rsplit("/", 1)
                if not os.path.isdir(fold_dir[0]): os.mkdir(fold_dir[0])
                copyfile(model_name, newname)
            if (max_epoch_fmeasure > best_max_fmeasure):
                best_max_fmeasure = max_epoch_fmeasure
                newname = model_dir + "bestEpochMaxF/basnet_bsi_epoch_%d_itr_%d_train_%3f_tar_%3f_maxfmeas_%3f.pth" % (
                    epoch + 1, ite_num, running_loss / ite_num4val,
                    running_tar_loss / ite_num4val, best_max_fmeasure)
                fold_dir = newname.rsplit("/", 1)
                if not os.path.isdir(fold_dir[0]): os.mkdir(fold_dir[0])
                copyfile(model_name, newname)
            if (average_maxf.avg > best_ave_maxf):
                best_ave_maxf = average_maxf.avg
                newname = model_dir + "bestAveMaxF/basnet_bsi_epoch_%d_itr_%d_train_%3f_tar_%3f_avemfmeas_%3f.pth" % (
                    epoch + 1, ite_num, running_loss / ite_num4val,
                    running_tar_loss / ite_num4val, best_ave_maxf)
                fold_dir = newname.rsplit("/", 1)
                if not os.path.isdir(fold_dir[0]): os.mkdir(fold_dir[0])
                copyfile(model_name, newname)
            if (average_relaxedf.avg > best_relaxed_fmeasure):
                best_relaxed_fmeasure = average_relaxedf.avg
                newname = model_dir + "bestAveRelaxF/basnet_bsi_epoch_%d_itr_%d_train_%3f_tar_%3f_averelaxfmeas_%3f.pth" % (
                    epoch + 1, ite_num, running_loss / ite_num4val,
                    running_tar_loss / ite_num4val, best_relaxed_fmeasure)
                fold_dir = newname.rsplit("/", 1)
                if not os.path.isdir(fold_dir[0]): os.mkdir(fold_dir[0])
                copyfile(model_name, newname)
            if (average_own_RelaxedFMeasure.avg > best_own_RelaxedFmeasure):
                best_own_RelaxedFmeasure = average_own_RelaxedFMeasure.avg
                newname = model_dir + "bestOwnRelaxedF/basnet_bsi_epoch_%d_itr_%d_train_%3f_tar_%3f_averelaxfmeas_%3f.pth" % (
                    epoch + 1, ite_num, running_loss / ite_num4val,
                    running_tar_loss / ite_num4val, best_own_RelaxedFmeasure)
                fold_dir = newname.rsplit("/", 1)
                if not os.path.isdir(fold_dir[0]): os.mkdir(fold_dir[0])
                copyfile(model_name, newname)
            # Log per-epoch test loss heads and evaluation metrics to Visdom.
            plotter.plot('loss0', 'test', 'Main Loss 0', epoch + 1,
                         float(test_loss0.avg))
            plotter.plot('loss1', 'test', 'Main Loss 1', epoch + 1,
                         float(test_loss1.avg))
            plotter.plot('loss2', 'test', 'Main Loss 2', epoch + 1,
                         float(test_loss2.avg))
            plotter.plot('loss3', 'test', 'Main Loss 3', epoch + 1,
                         float(test_loss3.avg))
            plotter.plot('loss4', 'test', 'Main Loss 4', epoch + 1,
                         float(test_loss4.avg))
            plotter.plot('loss5', 'test', 'Main Loss 5', epoch + 1,
                         float(test_loss5.avg))
            plotter.plot('loss6', 'test', 'Main Loss 6', epoch + 1,
                         float(test_loss6.avg))
            plotter.plot('loss7', 'test', 'Main Loss 7', epoch + 1,
                         float(test_loss7.avg))
            plotter.plot('structloss1', 'test', 'Struct Loss 1', epoch + 1,
                         float(test_struct_loss1.avg))
            plotter.plot('structloss2', 'test', 'Struct Loss 2', epoch + 1,
                         float(test_struct_loss2.avg))
            plotter.plot('structloss3', 'test', 'Struct Loss 3', epoch + 1,
                         float(test_struct_loss3.avg))
            plotter.plot('structloss4', 'test', 'Struct Loss 4', epoch + 1,
                         float(test_struct_loss4.avg))
            plotter.plot('structloss5', 'test', 'Struct Loss 5', epoch + 1,
                         float(test_struct_loss5.avg))
            plotter.plot('structloss6', 'test', 'Struct Loss 6', epoch + 1,
                         float(test_struct_loss6.avg))
            plotter.plot('structloss7', 'test', 'Struct Loss 7', epoch + 1,
                         float(test_struct_loss7.avg))
            plotter.plot('mae', 'test', 'Average Epoch MAE', epoch + 1,
                         float(average_mae.avg))
            plotter.plot('max_maxf', 'test', 'Max Max Epoch F-Measure',
                         epoch + 1, float(max_epoch_fmeasure))
            plotter.plot('ave_maxf', 'test', 'Average Max F-Measure',
                         epoch + 1, float(average_maxf.avg))
            plotter.plot('ave_relaxedf', 'test', 'Average Relaxed F-Measure',
                         epoch + 1, float(average_relaxedf.avg))
            plotter.plot('own_RelaxedFMeasure', 'test',
                         'Own Average Relaxed F-Measure', epoch + 1,
                         float(average_own_RelaxedFMeasure.avg))
    print('-------------Congratulations! Training Done!!!-------------')
Code example #2
0
def train(opt):
    """Train an EfficientDet detector on the FreiCar dataset.

    Loads project hyper-parameters from ``projects/<opt.project>.yml``,
    builds train/validation dataloaders, optionally resumes from saved
    weights, then runs the training loop with FocalLoss, Visdom logging,
    periodic checkpointing, ReduceLROnPlateau LR scheduling, validation
    every ``opt.val_interval`` epochs, and patience-based early stopping.

    Args:
        opt: parsed command-line options (project, batch_size, num_workers,
            lr, num_epochs, save_interval, val_interval, es_min_delta,
            es_patience, load_weights, saved_path, compound_coef, ...).
    """

    global plotter
    plotter = VisdomLinePlotter(env_name='FreiCar Object Detection')

    params = Params(f'projects/{opt.project}.yml')

    # Hide all GPUs from CUDA when the project is configured as CPU-only.
    if params.num_gpus == 0:
        os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

    # Fixed seed for reproducibility (only one of the two RNGs is seeded,
    # depending on CUDA availability).
    if torch.cuda.is_available():
        torch.cuda.manual_seed(42)
    else:
        torch.manual_seed(42)

    opt.saved_path = opt.saved_path + f'/{params.project_name}/'
    os.makedirs(opt.saved_path, exist_ok=True)

    # define paramteters for model training
    training_params = {
        'batch_size': opt.batch_size,
        'shuffle': True,
        'drop_last': True,
        'collate_fn': collater,
        'num_workers': opt.num_workers
    }

    # define paramteters for model evaluation
    val_params = {
        'batch_size': opt.batch_size,
        'shuffle': False,
        'drop_last': True,
        'collate_fn': collater,
        'num_workers': opt.num_workers
    }

    # get training dataset
    training_set = FreiCarDataset(data_dir="./dataloader/data/",
                                  padding=(0, 0, 12, 12),
                                  split='training',
                                  load_real=True)

    # and make data generator from dataset
    training_generator = DataLoader(training_set, **training_params)

    # get validation dataset
    val_set = FreiCarDataset(data_dir="./dataloader/data/",
                             padding=(0, 0, 12, 12),
                             split='validation',
                             load_real=False)

    # and make data generator from dataset
    val_generator = DataLoader(val_set, **val_params)

    # Instantiation of the EfficientDet model
    # NOTE(review): eval() on strings read from the project YAML -- fine for
    # a trusted config file, unsafe if the file can come from untrusted input.
    model = EfficientDetBackbone(num_classes=len(params.obj_list),
                                 compound_coef=opt.compound_coef,
                                 ratios=eval(params.anchors_ratios),
                                 scales=eval(params.anchors_scales))

    # load last weights if training from checkpoint
    if opt.load_weights is not None:
        if opt.load_weights.endswith('.pth'):
            weights_path = opt.load_weights
        else:
            weights_path = get_last_weights(opt.saved_path)
        # The global step is encoded as the last "_<step>.pth" chunk of the
        # checkpoint filename; fall back to 0 if it cannot be parsed.
        try:
            last_step = int(
                os.path.basename(weights_path).split('_')[-1].split('.')[0])
        except:  # NOTE(review): bare except also swallows KeyboardInterrupt
            last_step = 0

        # strict=False tolerates head-shape mismatches (e.g. different
        # number of classes) while still loading the backbone weights.
        try:
            ret = model.load_state_dict(torch.load(weights_path), strict=False)
        except RuntimeError as e:
            print(f'[Warning] Ignoring {e}')
            print(
                '[Warning] Don\'t panic if you see this, this might be because you load a pretrained weights with '
                'different number of classes. The rest of the weights should be loaded already.'
            )

        print(
            f'[Info] loaded weights: {os.path.basename(weights_path)}, resuming checkpoint from step: {last_step}'
        )
    else:
        last_step = 0
        print('[Info] initializing weights...')
        init_weights(model)

    if params.num_gpus > 0:
        model = model.cuda()
        if params.num_gpus > 1:
            model = CustomDataParallel(model, params.num_gpus)

    optimizer = torch.optim.AdamW(model.parameters(), opt.lr)

    # Halve-style LR reduction when the epoch loss plateaus for 3 epochs.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           patience=3,
                                                           verbose=True)

    best_loss = 1e5  # sentinel: any real validation loss is lower
    best_epoch = 0
    step = max(0, last_step)  # global optimizer-step counter (resumable)

    # Define training criterion
    criterion = FocalLoss()

    # Set model to train mode
    model.train()

    num_iter_per_epoch = len(training_generator)

    print('Started Training')

    # Train loop
    for epoch in range(opt.num_epochs):
        # Fast-forward whole epochs already covered by the resumed step.
        last_epoch = step // num_iter_per_epoch
        if epoch < last_epoch:
            continue

        epoch_loss = []  # here we append new total losses for each step

        progress_bar = tqdm(training_generator)
        # NOTE: loop variable "iter" shadows the builtin of the same name.
        for iter, data in enumerate(progress_bar):
            # Skip batches already processed before the checkpoint was taken.
            if iter < step - last_epoch * num_iter_per_epoch:
                progress_bar.update()
                continue

            ##########################################
            # TODO: implement me!
            # Made by DavideRezzoli
            ##########################################
            # Forward pass; the model returns (features, regression,
            # classification, anchors) and FocalLoss consumes the last three.
            optimizer.zero_grad()
            _, reg, clas, anchor = model(data['img'].cuda())
            cls_loss, reg_loss = criterion(clas, reg, anchor,
                                           data['annot'].cuda())
            loss = cls_loss.mean() + reg_loss.mean()
            loss.backward()
            optimizer.step()

            epoch_loss.append(float(loss))

            progress_bar.set_description(
                'Step: {}. Epoch: {}/{}. Iteration: {}/{}. Cls loss: {:.5f}. Reg loss: {:.5f}. Total loss: {:.5f}'
                .format(step, epoch, opt.num_epochs, iter + 1,
                        num_iter_per_epoch, cls_loss.item(), reg_loss.item(),
                        loss.item()))

            plotter.plot('Total loss', 'train', 'Total loss', step,
                         loss.item())
            plotter.plot('Regression_loss', 'train', 'Regression_loss', step,
                         reg_loss.item())
            plotter.plot('Classfication_loss', 'train', 'Classfication_loss',
                         step, cls_loss.item())

            # log learning_rate
            current_lr = optimizer.param_groups[0]['lr']
            plotter.plot('learning rate', 'train', 'Classfication_loss', step,
                         current_lr)

            # increment step counter
            step += 1

            if step % opt.save_interval == 0 and step > 0:
                save_checkpoint(
                    model,
                    f'efficientdet-d{opt.compound_coef}_{epoch}_{step}.pth')
                print('saved checkpoint...')

        # adjust learning rate via learning rate scheduler
        scheduler.step(np.mean(epoch_loss))

        if epoch % opt.val_interval == 0:

            print('Evaluating model')

            model.eval()
            loss_regression_ls = [
            ]  # here we append new regression losses for each step
            loss_classification_ls = [
            ]  # here we append new classification losses for each step

            for iter, data in enumerate(val_generator):

                with torch.no_grad():
                    ##########################################
                    # TODO: implement me!
                    # Made by Davide Rezzoli
                    #########################################
                    _, reg, clas, anchor = model(data['img'].cuda())
                    cls_loss, reg_loss = criterion(clas, reg, anchor,
                                                   data['annot'].cuda())

                    loss_classification_ls.append(cls_loss.item())
                    loss_regression_ls.append(reg_loss.item())

                    # Overwrite the tensors with running means; after the
                    # loop these hold the epoch-mean numpy scalars (so the
                    # .item() calls below still work on them).
                    cls_loss = np.mean(loss_classification_ls)
                    reg_loss = np.mean(loss_regression_ls)
                    loss = cls_loss + reg_loss

            # LOGGING
            print(
                'Val. Epoch: {}/{}. Classification loss: {:1.5f}. Regression loss: {:1.5f}. Total loss: {:1.5f}'
                .format(epoch, opt.num_epochs, cls_loss, reg_loss, loss))

            plotter.plot('Total loss', 'val', 'Total loss', step, loss.item())
            plotter.plot('Regression_loss', 'val', 'Regression_loss', step,
                         reg_loss.item())
            plotter.plot('Classfication_loss', 'val', 'Classfication_loss',
                         step, cls_loss.item())

            # Save model checkpoint if new best validation loss
            # (must improve by at least es_min_delta to count)
            if loss + opt.es_min_delta < best_loss:
                best_loss = loss
                best_epoch = epoch

                save_checkpoint(
                    model,
                    f'efficientdet-d{opt.compound_coef}_{epoch}_{step}.pth')

            model.train()

            # Early stopping: chained comparison -- only active when
            # es_patience > 0, and fires after that many stale epochs.
            if epoch - best_epoch > opt.es_patience > 0:
                print(
                    '[Info] Stop training at epoch {}. The lowest loss achieved is {}'
                    .format(epoch, best_loss))
                break
Code example #3
0
        for data, _ in train_loader:
            data = data.requires_grad_()
            data = data.to(device)

            opt.zero_grad()

            output, mu, logvar = model(data)

            loss_obj = loss_function(output, data, mu, logvar)

            loss_obj.backward()
            opt.step()

            losses.update(loss_obj.item(), len(data))

        plotter.plot('loss', 'train', 'Class Loss', epoch, losses.avg)
        print('Epoch {} loss: {}'.format(epoch, losses.avg))

    if args.save is not None:
        with open(args.save, 'wb') as f:
            torch.save(model, f)
else:
    with open(args.load, 'rb') as f:
        model = torch.load(f)

model = model.eval()

# Go through our adversarial samples and get error
total_err = 0
for i in range(len(adv_samples)):
    data = adv_samples.iloc[i]
            labels = labels.to(device)

            opt.zero_grad()

            output = model(data)

            #print('output shape:', output.shape)
            #print('data shape:', data.shape)

            loss_obj = loss(output, data)
            loss_obj.backward()
            opt.step()

            losses.update(loss_obj.data, len(data))

        plotter.plot('loss-ae', 'train', 'Class Loss', epoch, losses.avg.cpu())
        print('Epoch {} loss: {}'.format(epoch, losses.avg.cpu()))

    if args.save is not None:
        with open(args.save, 'wb') as f:
            torch.save(model, f)
else:
    with open(args.load, 'rb') as f:
        model = torch.load(f)

model = model.eval()

print(
    '====================== Calculating final reconstruction loss for all train data'
)
# Get the normal and adv. samples from our train data
Code example #5
0
File: joint_solver.py  Project: paul028/PoolNet
    def train(self):
        """Joint edge + saliency training loop.

        Per batch, runs two forward passes through ``self.net`` -- edge mode
        (``mode=0``) and saliency mode (``mode=1``) -- sums both losses, and
        accumulates gradients over ``self.iter_size`` batches before each
        optimizer step (as done in DSS). Checkpoints are written every
        ``self.config.epoch_save`` epochs and the learning rate is decayed
        10x at the epochs listed in ``self.lr_decay_epoch``.

        NOTE(review): relies on a module-level ``visdom_tab_title`` global --
        confirm it is defined in the module that uses this solver.
        """
        plotter = VisdomLinePlotter(env_name=visdom_tab_title)
        iter_num = 30000  # each batch only train 30000 iters.(This number is just a random choice...)
        aveGrad = 0  # number of batches whose gradients are currently accumulated
        for epoch in range(self.config.epoch):
            r_edge_loss, r_sal_loss, r_sum_loss = 0, 0, 0
            self.net.zero_grad()
            for i, data_batch in enumerate(self.train_loader):
                if (i + 1) == iter_num: break
                edge_image, edge_label, sal_image, sal_label = data_batch[
                    'edge_image'], data_batch['edge_label'], data_batch[
                        'sal_image'], data_batch['sal_label']
                # Skip pairs whose image/mask spatial sizes disagree.
                if (sal_image.size(2) != sal_label.size(2)) or (
                        sal_image.size(3) != sal_label.size(3)):
                    print('IMAGE ERROR, PASSING```')
                    continue
                edge_image, edge_label, sal_image, sal_label = Variable(
                    edge_image), Variable(edge_label), Variable(
                        sal_image), Variable(sal_label)
                if self.config.cuda:
                    edge_image, edge_label, sal_image, sal_label = edge_image.cuda(
                    ), edge_label.cuda(), sal_image.cuda(), sal_label.cuda()

                # edge part
                edge_pred = self.net(edge_image, mode=0)
                #edge_loss_fuse = bce2d(edge_pred[0], edge_label)
                # NOTE(review): edge_loss_fuse is computed and printed but is
                # NOT part of the optimized loss (see commented variants below).
                edge_loss_fuse = bce2d(edge_pred[0], edge_label)
                print(edge_loss_fuse)
                edge_loss_part = []
                for ix in edge_pred[1]:
                    edge_loss_part.append(
                        bce2d(ix, edge_label, reduction='sum'))
                #edge_loss = (edge_loss_fuse + sum(edge_loss_part)) / (self.iter_size * self.config.batch_size)
                #edge_loss = (sum(edge_loss_part)) / (self.iter_size * self.config.batch_size)
                # Mean of the per-side-output summed BCE losses.
                edge_loss = sum(edge_loss_part) / len(
                    edge_loss_part
                )  #(sum(edge_loss_part)) / (self.iter_size * self.config.batch_size)
                # NOTE(review): plain assignment (not +=) -- the running totals
                # only ever hold the most recent batch's loss, so the
                # `/ x_showEvery` in the log below is effectively a no-op.
                r_edge_loss = edge_loss
                # sal part
                sal_pred = self.net(sal_image, mode=1)
                sal_loss_fuse = F.binary_cross_entropy_with_logits(
                    sal_pred, sal_label)
                sal_loss = sal_loss_fuse  #/ (self.iter_size * self.config.batch_size)
                r_sal_loss = sal_loss

                loss = sal_loss + edge_loss
                r_sum_loss = loss
                loss.backward()

                aveGrad += 1

                # accumulate gradients as done in DSS
                if aveGrad % self.iter_size == 0:
                    self.optimizer.step()
                    self.optimizer.zero_grad()
                    aveGrad = 0

                if i % (self.show_every // self.config.batch_size) == 0:
                    # x_showEvery is first bound on the i == 0 iteration, which
                    # always satisfies the modulo test above.
                    if i == 0:
                        x_showEvery = 1
                    print(
                        'epoch: [%2d/%2d], iter: [%5d/%5d]  ||  Edge : %10.4f  ||  Sal : %10.4f  ||  Sum : %10.4f'
                        % (epoch, self.config.epoch, i, iter_num,
                           r_edge_loss / x_showEvery, r_sal_loss / x_showEvery,
                           r_sum_loss / x_showEvery))
                    print('Learning rate: ' + str(self.lr))
                    r_edge_loss, r_sal_loss, r_sum_loss = 0, 0, 0

            # Per-epoch curves plot the LAST batch's losses of the epoch.
            plotter.plot('edge_loss', 'train',
                         'Balanced Binary Cross Entropy Loss', epoch + 1,
                         float(edge_loss))
            plotter.plot('sal_loss', 'train', 'Binary Cross Entropy Loss',
                         epoch + 1, float(sal_loss))
            plotter.plot('loss', 'train', '', epoch + 1, float(loss))
            if (epoch + 1) % self.config.epoch_save == 0:
                torch.save(
                    self.net.state_dict(), '%s/models/epoch_%d.pth' %
                    (self.config.save_folder, epoch + 1))

            # Decay LR 10x and rebuild Adam so the new rate takes effect.
            if epoch in self.lr_decay_epoch:
                self.lr = self.lr * 0.1
                self.optimizer = Adam(filter(lambda p: p.requires_grad,
                                             self.net.parameters()),
                                      lr=self.lr,
                                      weight_decay=self.wd)

        torch.save(self.net.state_dict(),
                   '%s/models/final.pth' % self.config.save_folder)
コード例 #6
0
                loss = kl_loss(block_log_scores, targets)
                loss.backward()

                #grad clipping
                nn.utils.clip_grad_norm_(
                    filter(lambda p: p.requires_grad, qt.parameters()),
                    CONFIG['norm_threshold'])
                optimizer.step()

                temp.set_description(
                    "loss {:.4f} | failed/skipped {:3d}".format(
                        loss, failed_or_skipped_batches))

                if i % 100 == 0:
                    plotter.plot(
                        'loss', 'train', 'Run: {} Loss'.format(
                            CONFIG['checkpoint_dir'].split('/')[-1]), i,
                        loss.item())

                if i % 5000 == 0:
                    checkpoint_training(CONFIG['checkpoint_dir'], i, qt,
                                        optimizer)
                    qt.eval()
                    for dataset in ['MR', 'CR', 'MPQA', 'SUBJ']:
                        acc = test_performance(qt,
                                               WV_MODEL.vocab,
                                               dataset,
                                               '../data',
                                               seed=int(time.time()))
                        plotter.plot('acc',
                                     dataset,
                                     'Downstream Accuracy',
コード例 #7
0
def main():
    """Entry point: parse CLI args, build data loaders and the segmentation
    network, optionally resume from a checkpoint, then train and evaluate.

    Relies on module-level names defined elsewhere in this file: ``parser``,
    ``normalize``, ``best_mIoU``, ``ImageLoader``, ``_get_model_instance``,
    ``cross_entropy_loss``, ``adjust_learning_rate``, ``train``, ``test`` and
    ``save_checkpoint``.
    """
    global args, best_mIoU, NUM_CLASSES, COMB_DICTs
    args = parser.parse_args()
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    # Seed CPU (and GPU, when used) RNGs for reproducibility.
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)
    if args.visdom:
        global plotter
        plotter = VisdomLinePlotter(env_name=args.name + '_' + args.dataset)

    # Label-merging tables per dataset: COMB_DICT0/COMB_DICT1 map the finest
    # label set onto two coarser class sets; NUM_CLASSES lists the class
    # counts of the coarse-to-fine hierarchy.
    if args.dataset == 'HelenFace':
        COMB_DICT0 = {0: 0, 1: 1, 2: 1, 3: 1, 4: 1, 5: 1,
                      6: 1, 7: 1, 8: 1, 9: 1, 10: 2}
        COMB_DICT1 = {0: 0, 1: 1, 2: 2, 3: 2, 4: 2, 5: 2,
                      6: 3, 7: 4, 8: 4, 9: 4, 10: 5}
        NUM_CLASSES = [3, 6, 11]
    elif args.dataset == 'PASCALPersonParts':
        COMB_DICT0 = {0: 0, 1: 1, 2: 1, 3: 1, 4: 1, 5: 2, 6: 2}
        COMB_DICT1 = {0: 0, 1: 1, 2: 1, 3: 2, 4: 2, 5: 3, 6: 4}
        NUM_CLASSES = [3, 5, 7]
    elif args.dataset == 'ATR':
        COMB_DICT0 = {0: 0, 1: 1, 2: 1, 3: 1, 4: 2, 5: 3, 6: 3, 7: 3, 8: 3,
                      9: 3, 10: 3, 11: 1, 12: 3, 13: 3, 14: 2, 15: 2,
                      16: 4, 17: 4}
        COMB_DICT1 = {0: 0, 1: 1, 2: 1, 3: 2, 4: 3, 5: 5, 6: 5, 7: 5, 8: 5,
                      9: 6, 10: 6, 11: 2, 12: 7, 13: 7, 14: 4, 15: 4,
                      16: 8, 17: 8}
        NUM_CLASSES = [5, 9, 18]
    else:
        # Robustness fix: an unrecognized dataset previously fell through and
        # crashed later with a NameError on COMB_DICT0; fail fast instead.
        raise ValueError('Unsupported dataset: %r' % args.dataset)
    COMB_DICTs = [COMB_DICT0, COMB_DICT1]

    # Train at the finest granularity of the hierarchy.
    args.n_class = NUM_CLASSES[-1]
    print(args.name + '_' + args.dataset, 'n_class: ', args.n_class)

    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    kwargs = {'num_workers': 8, 'pin_memory': True} if args.cuda else {}
    train_loader = torch.utils.data.DataLoader(ImageLoader(
        '../Dataset/',
        args.dataset,
        'train.txt',
        ignore_label=args.ignore_label,
        n_imgs=args.num_trainimgs,
        crop_size=input_size,
        transform=transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               **kwargs)
    test_loader = torch.utils.data.DataLoader(ImageLoader(
        '../Dataset/',
        args.dataset,
        'test.txt',
        ignore_label=args.ignore_label,
        n_imgs=10000,
        transform=transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])),
                                              batch_size=1,
                                              shuffle=True,
                                              **kwargs)

    net = _get_model_instance(args.name)(num_classes=NUM_CLASSES,
                                         pretrain=True,
                                         nIn=3)
    if args.cuda:
        net.cuda()

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_mIoU = checkpoint['best_mIoU']
            net.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    criterion = cross_entropy_loss
    # Optimize only parameters that require gradients (frozen layers excluded).
    parameters = filter(lambda p: p.requires_grad, net.parameters())
    optimizer = optim.Adam(parameters, lr=args.lr, weight_decay=1e-4)

    n_parameters = sum(p.data.nelement() for p in net.parameters())
    print('  + Number of params: {}'.format(n_parameters))

    if args.test:
        # Evaluation-only mode: run a single test pass and exit.
        print('Epoch: %d' % (args.start_epoch))
        test_acc, test_mIoU = test(test_loader,
                                   net,
                                   criterion,
                                   args.start_epoch,
                                   showall=True)
        sys.exit()

    for epoch in range(args.start_epoch, args.epochs + 1):
        # update learning rate per the schedule in adjust_learning_rate
        lr = adjust_learning_rate(args.lr, optimizer, epoch)
        if args.visdom:
            plotter.plot('lr',
                         'learning rate',
                         epoch,
                         lr,
                         exp_name=args.name + '_' + args.dataset)

        # train for one epoch
        cudnn.benchmark = True
        train(train_loader, net, criterion, optimizer, epoch)

        # evaluate on validation set
        cudnn.benchmark = False
        acc, mIoU = test(test_loader, net, criterion, epoch)

        # record best mIoU and save checkpoint
        is_best = mIoU > best_mIoU
        best_mIoU = max(mIoU, best_mIoU)

        save_checkpoint(
            {
                'epoch': epoch,
                'state_dict': net.state_dict(),
                'best_mIoU': best_mIoU,
                'acc': acc
            },
            is_best,
            exp_name=args.name + '_' + args.dataset,
            filename='checkpoint_%d.pth.tar' % (epoch))
コード例 #8
0
ファイル: train.py プロジェクト: larsh0103/scene_synth
class Trainer():
    """DCGAN training harness.

    Wraps a GAN container (exposing generator ``GAN.G`` and discriminator
    ``GAN.D`` modules) plus a dataloader, runs the classic two-phase DCGAN
    update, and optionally streams losses and sample grids to Visdom.

    Parameters mirror the standard PyTorch DCGAN tutorial: ``nz`` is the
    latent size, ``real_label``/``fake_label`` the target values fed to the
    BCE criterion, and ``beta1`` the Adam first-moment coefficient.
    """

    def __init__(self,
                 device,
                 GAN,
                 dataloader,
                 model_dir='../models',
                 num_epochs=1,
                 criterion=nn.BCELoss(),
                 lr=0.0002,
                 beta1=0.5,
                 nz=100,
                 real_label=1.,
                 fake_label=0.,
                 plotting=False):
        # NOTE: the default ``criterion=nn.BCELoss()`` is evaluated once at
        # class-definition time and shared across instances; harmless for a
        # stateless loss module, kept for interface compatibility.
        self.device = device
        self.GAN = GAN
        self.dataloader = dataloader
        self.model_dir = model_dir
        self.num_epochs = num_epochs
        self.criterion = criterion
        self.nz = nz
        # Fixed latent batch used to visualize generator progress over time.
        self.fixed_noise = torch.randn(64, self.nz, 1, 1, device=self.device)
        self.real_label = real_label
        self.fake_label = fake_label
        self.optimizerD = optim.Adam(self.GAN.D.parameters(),
                                     lr=lr,
                                     betas=(beta1, 0.999))
        self.optimizerG = optim.Adam(self.GAN.G.parameters(),
                                     lr=lr,
                                     betas=(beta1, 0.999))
        self.plotting = plotting
        if plotting:
            self.line_plotter = VisdomLinePlotter()
            self.image_plotter = VisdomImagePlotter()
        else:
            # Defined explicitly so attribute access can never raise.
            self.line_plotter = None
            self.image_plotter = None

    def train(self):
        """Run the adversarial training loop for ``num_epochs`` epochs.

        Each iteration performs: (1) a discriminator update on a real batch
        and a detached fake batch, then (2) a generator update against the
        freshly-updated discriminator. Both networks are checkpointed at the
        end of every epoch under ``self.model_dir``.
        """
        iters = 0

        print("Starting Training Loop...")
        # For each epoch
        for epoch in range(self.num_epochs):
            # For each batch in the dataloader
            for i, data in enumerate(self.dataloader, 0):

                # Periodically visualize G's output on the fixed latent batch.
                # Bug fix: this block previously ran even when plotting was
                # disabled, raising AttributeError on the missing
                # ``image_plotter``.
                if self.plotting and ((iters % 500 == 0) or
                                      ((epoch == self.num_epochs - 1) and
                                       (i == len(self.dataloader) - 1))):
                    with torch.no_grad():
                        fake = self.GAN.G(self.fixed_noise).detach().cpu()
                    self.image_plotter.plot(vutils.make_grid(fake,
                                                             padding=2,
                                                             normalize=True),
                                            name="generator-output")

                ############################
                # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
                ###########################
                ## Train with all-real batch
                self.GAN.D.zero_grad()
                # Format batch
                real_cpu = data[0].to(self.device)
                b_size = real_cpu.size(0)
                label = torch.full((b_size, ),
                                   self.real_label,
                                   dtype=torch.float,
                                   device=self.device)
                # Forward pass real batch through D
                output = self.GAN.D(real_cpu).view(-1)
                errD_real = self.criterion(output, label)
                errD_real.backward()
                D_x = output.mean().item()

                ## Train with all-fake batch
                noise = torch.randn(b_size, self.nz, 1, 1, device=self.device)
                fake = self.GAN.G(noise)
                label.fill_(self.fake_label)
                # detach() so this backward pass does not touch G's graph.
                output = self.GAN.D(fake.detach()).view(-1)
                errD_fake = self.criterion(output, label)
                errD_fake.backward()
                D_G_z1 = output.mean().item()
                # Total D loss is the sum over the real and fake batches.
                errD = errD_real + errD_fake
                self.optimizerD.step()

                ############################
                # (2) Update G network: maximize log(D(G(z)))
                ###########################
                self.GAN.G.zero_grad()
                label.fill_(
                    self.real_label)  # fake labels are real for generator cost
                # D was just updated; run the fake batch through it again.
                output = self.GAN.D(fake).view(-1)
                errG = self.criterion(output, label)
                errG.backward()
                D_G_z2 = output.mean().item()
                self.optimizerG.step()

                # Output training stats
                if i % 50 == 0:
                    print(
                        '[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                        % (epoch, self.num_epochs, i, len(self.dataloader),
                           errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
                    if self.plotting:
                        self.line_plotter.plot(var_name="loss",
                                               split_name='Discriminator',
                                               title_name='Training Loss',
                                               x=epoch +
                                               i / len(self.dataloader),
                                               y=errD.item())

                        self.line_plotter.plot(var_name="loss",
                                               split_name='Generator',
                                               title_name='Training Loss',
                                               x=epoch +
                                               i / len(self.dataloader),
                                               y=errG.item())

                # Check how the generator is doing by saving G's output on fixed_noise
                if (iters % 500 == 0) or ((epoch == self.num_epochs - 1) and
                                          (i == len(self.dataloader) - 1)):
                    with torch.no_grad():
                        fake = self.GAN.G(self.fixed_noise).detach().cpu()
                    if self.plotting:
                        self.image_plotter.plot(vutils.make_grid(
                            fake, padding=2, normalize=True),
                                                name="generator-output")

                iters += 1

            # End-of-epoch sample grid and checkpointing.
            with torch.no_grad():
                fake = self.GAN.G(self.fixed_noise).detach().cpu()
                if self.plotting:
                    self.image_plotter.plot(vutils.make_grid(fake,
                                                             padding=2,
                                                             normalize=True),
                                            name="generator-output")

            torch.save(
                {
                    'epoch': epoch,
                    'model_state_dict': self.GAN.G.state_dict(),
                    'optimizer_state_dict': self.optimizerG.state_dict(),
                    'loss': errG.item(),
                }, os.path.join(self.model_dir, f"Generator-{epoch}.pth"))
            torch.save(
                {
                    'epoch': epoch,
                    'model_state_dict': self.GAN.D.state_dict(),
                    'optimizer_state_dict': self.optimizerD.state_dict(),
                    'loss': errD.item(),
                }, os.path.join(self.model_dir, f"Discriminator-{epoch}.pth"))
コード例 #9
0
ファイル: stylegan2.py プロジェクト: larsh0103/scene_synth
class Trainer():
    def __init__(
        self,
        name = 'default',
        results_dir = 'results',
        models_dir = 'models',
        transfer_from_checkpoint = None,
        plotting = False,
        base_dir = './',
        image_size = 128,
        network_capacity = 16,
        fmap_max = 512,
        transparent = False,
        batch_size = 16,
        mixed_prob = 0.9,
        gradient_accumulate_every=1,
        lr = 2e-4,
        lr_mlp = 0.1,
        ttur_mult = 2,
        rel_disc_loss = False,
        num_workers = None,
        save_every = 1000,
        evaluate_every = 1000,
        num_image_tiles = 8,
        trunc_psi = 0.6,
        fp16 = False,
        cl_reg = False,
        no_pl_reg = False,
        fq_layers = [],
        fq_dict_size = 256,
        attn_layers = [],
        no_const = False,
        aug_prob = 0.,
        aug_types = ['translation', 'cutout'],
        top_k_training = False,
        generator_top_k_gamma = 0.99,
        generator_top_k_frac = 0.5,
        dataset_aug_prob = 0.,
        calculate_fid_every = None,
        calculate_fid_num_images = 12800,
        clear_fid_cache = False,
        is_ddp = False,
        rank = 0,
        world_size = 1,
        log = False,
        *args,
        **kwargs
    ):
        """Configure a StyleGAN2 trainer; the networks themselves are built
        lazily by ``init_GAN`` (extra ``*args``/``**kwargs`` are stored and
        forwarded to the ``StyleGAN2`` constructor at that point).

        NOTE(review): the mutable defaults ``fq_layers = []`` and
        ``aug_types = [...]`` are shared across calls -- safe only as long as
        they are never mutated in place. The ``log`` parameter is accepted
        but never read in this constructor.
        """
        # Deferred constructor arguments for StyleGAN2 (see init_GAN).
        self.GAN_params = [args, kwargs]
        self.GAN = None

        self.name = name

        # Output locations: results, model checkpoints, FID cache and the
        # serialized architecture config.
        base_dir = Path(base_dir)
        self.base_dir = base_dir
        self.results_dir = base_dir / results_dir
        self.models_dir = base_dir / models_dir
        self.transfer_from_checkpoint = transfer_from_checkpoint
        self.fid_dir = base_dir / 'fid' / name
        self.config_path = self.models_dir / name / '.config.json'

        # Network architecture settings.
        assert log2(image_size).is_integer(), 'image size must be a power of 2 (64, 128, 256, 512, 1024)'
        self.image_size = image_size
        self.network_capacity = network_capacity
        self.fmap_max = fmap_max
        self.transparent = transparent

        # Feature-quantization layers (vector quantization on D features).
        self.fq_layers = cast_list(fq_layers)
        self.fq_dict_size = fq_dict_size
        self.has_fq = len(self.fq_layers) > 0

        self.attn_layers = cast_list(attn_layers)
        self.no_const = no_const

        # Differentiable-augmentation settings for the discriminator.
        self.aug_prob = aug_prob
        self.aug_types = aug_types

        # Optimization settings (ttur_mult scales D's LR relative to G's).
        self.lr = lr
        self.lr_mlp = lr_mlp
        self.ttur_mult = ttur_mult
        self.rel_disc_loss = rel_disc_loss
        self.batch_size = batch_size
        self.num_workers = num_workers
        # Probability of mixed (style-mixing) latents vs a single latent.
        self.mixed_prob = mixed_prob

        # Checkpointing / evaluation cadence; `steps` counts train() calls.
        self.num_image_tiles = num_image_tiles
        self.evaluate_every = evaluate_every
        self.save_every = save_every
        self.steps = 0

        self.av = None
        self.trunc_psi = trunc_psi

        # Path-length regularization state (moving average of PL lengths).
        self.no_pl_reg = no_pl_reg
        self.pl_mean = None

        self.gradient_accumulate_every = gradient_accumulate_every

        assert not fp16 or fp16 and APEX_AVAILABLE, 'Apex is not available for you to use mixed precision training'
        self.fp16 = fp16

        self.cl_reg = cl_reg

        # Last-seen loss values, surfaced for logging/monitoring.
        self.d_loss = 0
        self.g_loss = 0
        self.q_loss = None
        self.last_gp_loss = None
        self.last_cr_loss = None
        self.last_fid = None

        self.pl_length_ma = EMA(0.99)
        self.init_folders()

        self.loader = None
        self.dataset_aug_prob = dataset_aug_prob

        # FID evaluation cadence (disabled when calculate_fid_every is None).
        self.calculate_fid_every = calculate_fid_every
        self.calculate_fid_num_images = calculate_fid_num_images
        self.clear_fid_cache = clear_fid_cache

        # Top-k generator training (train G only on its k best samples).
        self.top_k_training = top_k_training
        self.generator_top_k_gamma = generator_top_k_gamma
        self.generator_top_k_frac = generator_top_k_frac

        # Distributed-training bookkeeping; rank 0 is the main process.
        assert not (is_ddp and cl_reg), 'Contrastive loss regularization does not work well with multi GPUs yet'
        self.is_ddp = is_ddp
        self.is_main = rank == 0
        self.rank = rank
        self.world_size = world_size
        self.plotting = plotting
        if plotting:
            self.image_plotter = VisdomImagePlotter()
            self.line_plotter = VisdomLinePlotter()
        else:
            self.image_plotter = None
            self.line_plotter = None
    @property
    def image_extension(self):
        return 'jpg' if not self.transparent else 'png'

    @property
    def checkpoint_num(self):
        return floor(self.steps // self.save_every)

    @property
    def hparams(self):
        return {'image_size': self.image_size, 'network_capacity': self.network_capacity}
        
    def init_GAN(self):
        """Instantiate the StyleGAN2 networks from the stored constructor args.

        Under DDP, additionally wraps the style mapper (S), generator (G),
        discriminator (D) and augmented discriminator (D_aug) in
        DistributedDataParallel pinned to this process's device.
        """
        # args/kwargs were captured in __init__ and forwarded verbatim.
        args, kwargs = self.GAN_params
        self.GAN = StyleGAN2(lr = self.lr, lr_mlp = self.lr_mlp, ttur_mult = self.ttur_mult, image_size = self.image_size, network_capacity = self.network_capacity, fmap_max = self.fmap_max, transparent = self.transparent, fq_layers = self.fq_layers, fq_dict_size = self.fq_dict_size, attn_layers = self.attn_layers, fp16 = self.fp16, cl_reg = self.cl_reg, no_const = self.no_const, rank = self.rank, *args, **kwargs)

        if self.is_ddp:
            ddp_kwargs = {'device_ids': [self.rank]}
            self.S_ddp = DDP(self.GAN.S, **ddp_kwargs)
            self.G_ddp = DDP(self.GAN.G, **ddp_kwargs)
            self.D_ddp = DDP(self.GAN.D, **ddp_kwargs)
            self.D_aug_ddp = DDP(self.GAN.D_aug, **ddp_kwargs)

    def write_config(self):
        self.config_path.write_text(json.dumps(self.config()))

    def load_config(self):
        """Restore architecture settings from disk and rebuild the GAN.

        Falls back to the in-memory config when no file exists yet; missing
        keys in older config files get their historical defaults.
        """
        if self.config_path.exists():
            saved = json.loads(self.config_path.read_text())
        else:
            saved = self.config()
        self.image_size = saved['image_size']
        self.network_capacity = saved['network_capacity']
        self.transparent = saved['transparent']
        self.fq_layers = saved['fq_layers']
        self.fq_dict_size = saved['fq_dict_size']
        # Keys added after the first release may be absent in old configs.
        self.fmap_max = saved.pop('fmap_max', 512)
        self.attn_layers = saved.pop('attn_layers', [])
        self.no_const = saved.pop('no_const', False)
        self.lr_mlp = saved.pop('lr_mlp', 0.1)
        del self.GAN
        self.init_GAN()

    def config(self):
        return {'image_size': self.image_size, 'network_capacity': self.network_capacity, 'lr_mlp': self.lr_mlp, 'transparent': self.transparent, 'fq_layers': self.fq_layers, 'fq_dict_size': self.fq_dict_size, 'attn_layers': self.attn_layers, 'no_const': self.no_const}

    def set_data_src(self, folder):
        """Point the trainer at an image folder and build an endless loader.

        Shards the global batch across DDP ranks when distributed, and
        auto-raises the augmentation probability for small datasets
        (< 100k images) when the user did not set one explicitly.
        """
        self.dataset = Dataset(folder, self.image_size, transparent = self.transparent, aug_prob = self.dataset_aug_prob)
        # Bug fix: original read `num_workers = num_workers = default(...)`,
        # a duplicated (though harmless) assignment.
        num_workers = default(self.num_workers, NUM_CORES if not self.is_ddp else 0)
        sampler = DistributedSampler(self.dataset, rank=self.rank, num_replicas=self.world_size, shuffle=True) if self.is_ddp else None
        dataloader = data.DataLoader(self.dataset, num_workers = num_workers, batch_size = math.ceil(self.batch_size / self.world_size), sampler = sampler, shuffle = not self.is_ddp, drop_last = True, pin_memory = True)
        self.loader = cycle(dataloader)

        # auto set augmentation prob for user if dataset is detected to be low
        num_samples = len(self.dataset)
        if not exists(self.aug_prob) and num_samples < 1e5:
            self.aug_prob = min(0.5, (1e5 - num_samples) * 3e-6)
            print(f'autosetting augmentation probability to {round(self.aug_prob * 100)}%')

    def train(self):
        """Run ONE optimization step (``self.steps`` advances by one per call).

        Sequence: optional contrastive-loss (CL) regularization of D, then a
        discriminator pass with hinge loss (plus optional gradient penalty and
        feature-quantization loss), then a generator pass (with optional
        top-k filtering and path-length penalty), followed by EMA updates,
        NaN recovery from the last checkpoint, and periodic save / evaluate /
        FID bookkeeping on the main process.
        """
        assert exists(self.loader), 'You must first initialize the data source with `.set_data_src(<folder of images>)`'

        if not exists(self.GAN):
            self.init_GAN()

        self.GAN.train()
        total_disc_loss = torch.tensor(0.).cuda(self.rank)
        total_gen_loss = torch.tensor(0.).cuda(self.rank)

        # Per-process batch size when the global batch is sharded over ranks.
        batch_size = math.ceil(self.batch_size / self.world_size)

        image_size = self.GAN.G.image_size
        latent_dim = self.GAN.G.latent_dim
        num_layers = self.GAN.G.num_layers

        aug_prob   = self.aug_prob
        aug_types  = self.aug_types
        aug_kwargs = {'prob': aug_prob, 'types': aug_types}

        # Lazy regularization cadence: GP every 4 steps, path-length penalty
        # every 32 steps once past the warm-up threshold.
        apply_gradient_penalty = self.steps % 4 == 0
        apply_path_penalty = not self.no_pl_reg and self.steps > 5000 and self.steps % 32 == 0
        apply_cl_reg_to_generated = self.steps > 20000

        # Use the DDP-wrapped modules when running distributed.
        S = self.GAN.S if not self.is_ddp else self.S_ddp
        G = self.GAN.G if not self.is_ddp else self.G_ddp
        D = self.GAN.D if not self.is_ddp else self.D_ddp
        D_aug = self.GAN.D_aug if not self.is_ddp else self.D_aug_ddp

        backwards = partial(loss_backwards, self.fp16)

        # Optional contrastive regularization of the discriminator.
        if exists(self.GAN.D_cl):
            self.GAN.D_opt.zero_grad()

            if apply_cl_reg_to_generated:
                for i in range(self.gradient_accumulate_every):
                    get_latents_fn = mixed_list if random() < self.mixed_prob else noise_list
                    style = get_latents_fn(batch_size, num_layers, latent_dim, device=self.rank)
                    noise = image_noise(batch_size, image_size, device=self.rank)

                    w_space = latent_to_w(self.GAN.S, style)
                    w_styles = styles_def_to_tensor(w_space)

                    generated_images = self.GAN.G(w_styles, noise)
                    self.GAN.D_cl(generated_images.clone().detach(), accumulate=True)

            for i in range(self.gradient_accumulate_every):
                image_batch = next(self.loader).cuda(self.rank)
                self.GAN.D_cl(image_batch, accumulate=True)

            loss = self.GAN.D_cl.calculate_loss()
            self.last_cr_loss = loss.clone().detach().item()
            backwards(loss, self.GAN.D_opt, loss_id = 0)

            self.GAN.D_opt.step()

        # train discriminator

        avg_pl_length = self.pl_mean
        self.GAN.D_opt.zero_grad()

        for i in gradient_accumulate_contexts(self.gradient_accumulate_every, self.is_ddp, ddps=[D_aug, S, G]):
            # Style mixing with probability `mixed_prob`.
            get_latents_fn = mixed_list if random() < self.mixed_prob else noise_list
            style = get_latents_fn(batch_size, num_layers, latent_dim, device=self.rank)
            noise = image_noise(batch_size, image_size, device=self.rank)

            w_space = latent_to_w(S, style)
            w_styles = styles_def_to_tensor(w_space)

            generated_images = G(w_styles, noise)
            fake_output, fake_q_loss = D_aug(generated_images.clone().detach(), detach = True, **aug_kwargs)

            # requires_grad_ on the real batch enables the R1 gradient penalty.
            image_batch = next(self.loader).cuda(self.rank)
            image_batch.requires_grad_()
            real_output, real_q_loss = D_aug(image_batch, **aug_kwargs)

            real_output_loss = real_output
            fake_output_loss = fake_output

            if self.rel_disc_loss:
                real_output_loss = real_output_loss - fake_output.mean()
                fake_output_loss = fake_output_loss - real_output.mean()

            # Hinge loss for the discriminator.
            divergence = (F.relu(1 + real_output_loss) + F.relu(1 - fake_output_loss)).mean()
            disc_loss = divergence

            if self.has_fq:
                quantize_loss = (fake_q_loss + real_q_loss).mean()
                self.q_loss = float(quantize_loss.detach().item())

                disc_loss = disc_loss + quantize_loss

            if apply_gradient_penalty:
                gp = gradient_penalty(image_batch, real_output)
                self.last_gp_loss = gp.clone().detach().item()
                self.track(y=self.last_gp_loss, var_name ='Penalty', name = 'GP',title ="Penalties", x=self.steps)
                disc_loss = disc_loss + gp

            disc_loss = disc_loss / self.gradient_accumulate_every
            disc_loss.register_hook(raise_if_nan)
            backwards(disc_loss, self.GAN.D_opt, loss_id = 1)

            total_disc_loss += divergence.detach().item() / self.gradient_accumulate_every

        self.d_loss = float(total_disc_loss)
        self.track(y=self.d_loss, var_name ='Loss', name='D',title = 'Training Loss',x=self.steps)

        self.GAN.D_opt.step()

        # train generator

        self.GAN.G_opt.zero_grad()

        for i in gradient_accumulate_contexts(self.gradient_accumulate_every, self.is_ddp, ddps=[S, G, D_aug]):
            style = get_latents_fn(batch_size, num_layers, latent_dim, device=self.rank)
            noise = image_noise(batch_size, image_size, device=self.rank)

            w_space = latent_to_w(S, style)
            w_styles = styles_def_to_tensor(w_space)

            generated_images = G(w_styles, noise)
            fake_output, _ = D_aug(generated_images, **aug_kwargs)
            fake_output_loss = fake_output

            # Top-k training: only back-prop G through its k best samples,
            # with k annealed from the full batch down to top_k_frac.
            if self.top_k_training:
                epochs = (self.steps * batch_size * self.gradient_accumulate_every) / len(self.dataset)
                k_frac = max(self.generator_top_k_gamma ** epochs, self.generator_top_k_frac)
                k = math.ceil(batch_size * k_frac)

                if k != batch_size:
                    fake_output_loss, _ = fake_output_loss.topk(k=k, largest=False)

            loss = fake_output_loss.mean()
            gen_loss = loss

            # Path-length regularization toward the moving-average length.
            if apply_path_penalty:
                pl_lengths = calc_pl_lengths(w_styles, generated_images)
                avg_pl_length = np.mean(pl_lengths.detach().cpu().numpy())

                if not is_empty(self.pl_mean):
                    pl_loss = ((pl_lengths - self.pl_mean) ** 2).mean()
                    if not torch.isnan(pl_loss):
                        gen_loss = gen_loss + pl_loss

            gen_loss = gen_loss / self.gradient_accumulate_every
            gen_loss.register_hook(raise_if_nan)
            backwards(gen_loss, self.GAN.G_opt, loss_id = 2)

            total_gen_loss += loss.detach().item() / self.gradient_accumulate_every

        self.g_loss = float(total_gen_loss)
        self.track(y=self.g_loss, var_name ='Loss', name='G',title = 'Training Loss',x=self.steps)

        self.GAN.G_opt.step()

        # calculate moving averages

        if apply_path_penalty and not np.isnan(avg_pl_length):
            self.pl_mean = self.pl_length_ma.update_average(self.pl_mean, avg_pl_length)
            self.track(y=self.pl_mean,var_name ='Penalty', name='PL',title = 'Penalties',x=self.steps)

        # EMA of weights after warm-up; periodic hard resets early in training.
        if self.is_main and self.steps % 10 == 0 and self.steps > 20000:
            self.GAN.EMA()

        if self.is_main and self.steps <= 25000 and self.steps % 1000 == 2:
            self.GAN.reset_parameter_averaging()

        # save from NaN errors

        if any(torch.isnan(l) for l in (total_gen_loss, total_disc_loss)):
            print(f'NaN detected for generator or discriminator. Loading from checkpoint #{self.checkpoint_num}')
            self.load(self.checkpoint_num)
            raise NanException

        # periodically save results

        if self.is_main:
            if self.steps % self.save_every == 0:
                self.save(self.checkpoint_num)

            if self.steps % self.evaluate_every == 0 or (self.steps % 100 == 0 and self.steps < 2500):
                self.evaluate(floor(self.steps / self.evaluate_every))

            if exists(self.calculate_fid_every) and self.steps % self.calculate_fid_every == 0 and self.steps != 0:
                num_batches = math.ceil(self.calculate_fid_num_images / self.batch_size)
                fid = self.calculate_fid(num_batches)
                self.last_fid = fid

                with open(str(self.results_dir / self.name / f'fid_scores.txt'), 'a') as f:
                    f.write(f'{self.steps},{fid}\n')

        self.steps += 1
        self.av = None

    @torch.no_grad()
    def evaluate(self, num = 0, trunc = 1.0):
        """Render sample grids from the current generator and its EMA copy.

        Writes three image grids under results_dir/name: '{num}.{ext}' from
        the live generator, '{num}-ema.{ext}' from the EMA generator, and
        '{num}-mr.{ext}' demonstrating style mixing. When self.plotting is
        set, the first two grids are also mirrored to the image plotter.

        NOTE(review): the `trunc` argument is unused — truncation always uses
        self.trunc_psi; confirm whether that is intentional.
        """
        self.GAN.eval()
        ext = self.image_extension
        num_rows = self.num_image_tiles
    
        latent_dim = self.GAN.G.latent_dim
        image_size = self.GAN.G.image_size
        num_layers = self.GAN.G.num_layers

        # latents and noise (shared across all three grids so they are comparable)

        latents = noise_list(num_rows ** 2, num_layers, latent_dim, device=self.rank)
        n = image_noise(num_rows ** 2, image_size, device=self.rank)

        # regular

        generated_images = self.generate_truncated(self.GAN.S, self.GAN.G, latents, n, trunc_psi = self.trunc_psi)
        torchvision.utils.save_image(generated_images, str(self.results_dir / self.name / f'{str(num)}.{ext}'), nrow=num_rows)
        if self.plotting:
            self.image_plotter.plot(vutils.make_grid(generated_images, padding=2, normalize=True),name="generator-S-output")
        # moving averages

        generated_images = self.generate_truncated(self.GAN.SE, self.GAN.GE, latents, n, trunc_psi = self.trunc_psi)
        torchvision.utils.save_image(generated_images, str(self.results_dir / self.name / f'{str(num)}-ema.{ext}'), nrow=num_rows)
        if self.plotting:
            self.image_plotter.plot(vutils.make_grid(generated_images, padding=2, normalize=True),name="generator-SE-output")

        # mixing regularities

        def tile(a, dim, n_tile):
            # interleaved repeat along `dim`: [a, b] with n_tile=2 -> [a, a, b, b]
            init_dim = a.size(dim)
            repeat_idx = [1] * a.dim()
            repeat_idx[dim] = n_tile
            a = a.repeat(*(repeat_idx))
            order_index = torch.LongTensor(np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])).cuda(self.rank)
            return torch.index_select(a, dim, order_index)

        nn = noise(num_rows, latent_dim, device=self.rank)
        tmp1 = tile(nn, 0, num_rows)   # each latent repeated num_rows times consecutively
        tmp2 = nn.repeat(num_rows, 1)  # the whole latent batch repeated num_rows times

        # first tt layers use one latent, the rest another -> style-mixing grid
        tt = int(num_layers / 2)
        mixed_latents = [(tmp1, tt), (tmp2, num_layers - tt)]

        generated_images = self.generate_truncated(self.GAN.SE, self.GAN.GE, mixed_latents, n, trunc_psi = self.trunc_psi)
        torchvision.utils.save_image(generated_images, str(self.results_dir / self.name / f'{str(num)}-mr.{ext}'), nrow=num_rows)

    @torch.no_grad()
    def calculate_fid(self, num_batches):
        """Compute the FID between real dataset images and EMA-generated ones.

        Saves `num_batches` batches of real images (cached unless
        self.clear_fid_cache) and freshly generated images under self.fid_dir,
        then scores the two folders with pytorch_fid.

        Returns the FID score as a float.
        """
        from pytorch_fid import fid_score
        torch.cuda.empty_cache()

        real_path = self.fid_dir / 'real'
        fake_path = self.fid_dir / 'fake'

        # remove any existing files used for fid calculation and recreate directories

        if not real_path.exists() or self.clear_fid_cache:
            rmtree(real_path, ignore_errors=True)
            os.makedirs(real_path)

            for batch_num in tqdm(range(num_batches), desc='calculating FID - saving reals'):
                real_batch = next(self.loader)
                for k, image in enumerate(real_batch.unbind(0)):
                    filename = str(k + batch_num * self.batch_size)
                    # BUG FIX: every real image was previously saved to the same
                    # literal path, so only one real sample survived on disk and
                    # the FID statistics were computed against a single file.
                    torchvision.utils.save_image(image, str(real_path / f'{filename}.png'))

        # generate a bunch of fake images in results / name / fid_fake

        rmtree(fake_path, ignore_errors=True)
        os.makedirs(fake_path)

        self.GAN.eval()
        ext = self.image_extension

        latent_dim = self.GAN.G.latent_dim
        image_size = self.GAN.G.image_size
        num_layers = self.GAN.G.num_layers

        # renamed from `noise` to avoid shadowing the module-level noise() helper
        image_noise_batch = None
        for batch_num in tqdm(range(num_batches), desc='calculating FID - saving generated'):
            # latents and noise
            latents = noise_list(self.batch_size, num_layers, latent_dim, device=self.rank)
            image_noise_batch = image_noise(self.batch_size, image_size, device=self.rank)

            # moving averages
            generated_images = self.generate_truncated(self.GAN.SE, self.GAN.GE, latents, image_noise_batch, trunc_psi = self.trunc_psi)

            for j, image in enumerate(generated_images.unbind(0)):
                torchvision.utils.save_image(image, str(fake_path / f'{str(j + batch_num * self.batch_size)}-ema.{ext}'))

        return fid_score.calculate_fid_given_paths([str(real_path), str(fake_path)], 256, image_noise_batch.device, 2048)

    @torch.no_grad()
    def truncate_style(self, tensor, trunc_psi = 0.75):
        """Pull a style tensor toward the mean style vector by `trunc_psi`.

        The mean style (self.av) is estimated lazily from 2000 random latents
        pushed through the style network, then cached for later calls.
        """
        # Lazily estimate the average w-space vector on first use.
        if not exists(self.av):
            mapping_net = self.GAN.S
            sample_z = noise(2000, self.GAN.G.latent_dim, device=self.rank)
            w_samples = evaluate_in_chunks(self.batch_size, mapping_net, sample_z).cpu().numpy()
            self.av = np.expand_dims(np.mean(w_samples, axis = 0), axis = 0)

        # Interpolate between the cached mean style and the given style.
        mean_style = torch.from_numpy(self.av).cuda(self.rank)
        return trunc_psi * (tensor - mean_style) + mean_style

    @torch.no_grad()
    def truncate_style_defs(self, w, trunc_psi = 0.75):
        """Apply style truncation to every (tensor, num_layers) pair in `w`."""
        return [
            (self.truncate_style(style_tensor, trunc_psi = trunc_psi), n_layers)
            for style_tensor, n_layers in w
        ]

    @torch.no_grad()
    def generate_truncated(self, S, G, style, noi, trunc_psi = 0.75, num_image_tiles = 8):
        """Map latents through style network S, truncate, and render with G.

        `style` is a list of (latent_batch, num_layers) pairs; returns the
        generated image batch clamped to [0, 1].
        """
        mapped = [(S(z), n_layers) for z, n_layers in style]
        truncated = self.truncate_style_defs(mapped, trunc_psi = trunc_psi)
        w_styles = styles_def_to_tensor(truncated)
        images = evaluate_in_chunks(self.batch_size, G, w_styles, noi)
        return images.clamp_(0., 1.)

    @torch.no_grad()
    def generate_interpolation(self, num = 0, num_image_tiles = 8, trunc = 1.0, num_steps = 100, save_frames = False):
        """Save an animated GIF interpolating between two random latent grids.

        Spherically interpolates (slerp) between two latent batches over
        `num_steps` frames using the EMA generator and writes '{num}.gif'
        to the results directory; when `save_frames` is True, each frame is
        also saved as a separate image.

        NOTE(review): `trunc` is unused — truncation always uses self.trunc_psi.
        """
        self.GAN.eval()
        ext = self.image_extension
        num_rows = num_image_tiles

        latent_dim = self.GAN.G.latent_dim
        image_size = self.GAN.G.image_size
        num_layers = self.GAN.G.num_layers

        # latents and noise

        latents_low = noise(num_rows ** 2, latent_dim, device=self.rank)
        latents_high = noise(num_rows ** 2, latent_dim, device=self.rank)
        n = image_noise(num_rows ** 2, image_size, device=self.rank)

        # NOTE(review): ratios run 0..8 rather than 0..1, so the slerp
        # extrapolates well past both endpoints — confirm this is intended.
        ratios = torch.linspace(0., 8., num_steps)

        frames = []
        for ratio in tqdm(ratios):
            interp_latents = slerp(ratio, latents_low, latents_high)
            latents = [(interp_latents, num_layers)]
            generated_images = self.generate_truncated(self.GAN.SE, self.GAN.GE, latents, n, trunc_psi = self.trunc_psi)
            images_grid = torchvision.utils.make_grid(generated_images, nrow = num_rows)
            pil_image = transforms.ToPILImage()(images_grid.cpu())
            
            if self.transparent:
                # composite RGBA frames onto a white background for the GIF
                background = Image.new("RGBA", pil_image.size, (255, 255, 255))
                pil_image = Image.alpha_composite(background, pil_image)
                
            frames.append(pil_image)

        frames[0].save(str(self.results_dir / self.name / f'{str(num)}.gif'), save_all=True, append_images=frames[1:], duration=80, loop=0, optimize=True)

        if save_frames:
            folder_path = (self.results_dir / self.name / f'{str(num)}')
            folder_path.mkdir(parents=True, exist_ok=True)
            for ind, frame in enumerate(frames):
                frame.save(str(folder_path / f'{str(ind)}.{ext}'))

    def print_log(self):
        """Print a one-line summary of the currently populated training metrics."""
        candidates = [
            ('G', self.g_loss),
            ('D', self.d_loss),
            ('GP', self.last_gp_loss),
            ('PL', self.pl_mean),
            ('CR', self.last_cr_loss),
            ('Q', self.q_loss),
            ('FID', self.last_fid),
        ]

        # only show metrics that have actually been set
        parts = [f'{label}: {value:.2f}' for label, value in candidates if exists(value)]
        print(' | '.join(parts))

    def track(self, y, x, var_name, name, title):
        """Forward a scalar metric to the line plotter, if one is configured."""
        if exists(self.line_plotter):
            self.line_plotter.plot(var_name = var_name, split_name = name,
                        title_name = title, y = y, x = x)

    def model_name(self, num):
        """Return the checkpoint file path for checkpoint number `num`."""
        checkpoint_path = self.models_dir / self.name / f'model_{num}.pt'
        return str(checkpoint_path)

    def init_folders(self):
        """Create the per-run results and models directories if absent."""
        for base_dir in (self.results_dir, self.models_dir):
            (base_dir / self.name).mkdir(parents=True, exist_ok=True)

    def clear(self):
        """Delete all artifacts for this run, then recreate empty folders."""
        doomed = [
            self.models_dir / self.name,
            self.results_dir / self.name,
            self.fid_dir,
            self.config_path,
        ]
        for path in doomed:
            rmtree(str(path), True)  # ignore_errors=True: missing paths are fine
        self.init_folders()

    def save(self, num):
        """Serialize GAN weights (and amp state when fp16) to checkpoint `num`."""
        payload = {
            'GAN': self.GAN.state_dict(),
            'version': __version__,
        }

        # mixed precision keeps loss-scaling state that must be saved alongside weights
        if self.GAN.fp16:
            payload['amp'] = amp.state_dict()

        torch.save(payload, self.model_name(num))
        self.write_config()

    def load(self, num = -1):
        """Restore model weights from a checkpoint.

        num = -1 resumes from the newest checkpoint in the models directory
        (a no-op when none exist); any other value loads that checkpoint
        number. When `transfer_from_checkpoint` is set, weights come from that
        path instead and the step counter restarts from 0.

        Raises whatever `load_state_dict` raises on incompatible checkpoints,
        after printing a version-mismatch hint.
        """
        load_data = None
        self.load_config()

        name = num
        if self.transfer_from_checkpoint:
            # transfer learning: load external weights, reset step counter
            # (removed leftover debug prints here)
            load_data = torch.load(self.transfer_from_checkpoint)
            name = 0

        elif num == -1:
            # resume from the highest-numbered saved checkpoint, if any exist
            file_paths = list(Path(self.models_dir / self.name).glob('model_*.pt'))
            saved_nums = sorted(map(lambda x: int(x.stem.split('_')[1]), file_paths))
            if len(saved_nums) == 0:
                return
            name = saved_nums[-1]
            print(f'continuing from previous epoch - {name}')

        self.steps = name * self.save_every
        # explicit None check (a loaded dict is always truthy, but be precise)
        if load_data is None:
            load_data = torch.load(self.model_name(name))

        if 'version' in load_data:
            print(f"loading from version {load_data['version']}")

        try:
            self.GAN.load_state_dict(load_data['GAN'])
        except Exception as e:
            print('unable to load save model. please try downgrading the package to the version specified by the saved model')
            raise e
        if self.GAN.fp16 and 'amp' in load_data:
            amp.load_state_dict(load_data['amp'])