Esempio n. 1
0
def main(args, dst_folder):
    """Train and evaluate a network on a noisy-label dataset, saving artefacts.

    The function seeds the RNGs, builds the train/test transforms, the data
    loaders (``data_config``), the network selected by ``args.network`` and an
    SGD optimizer (optionally wrapped in torchcontrib's ``SWA``).  When
    ``args.initial_epoch`` is non-zero (and this is not a warm-up run) the
    matching warm-up checkpoint is loaded and its predictions are handed to
    ``train_loader.dataset.update_labels_randRelab`` before training starts.
    After every epoch the model is evaluated and checkpoints / metric arrays
    are written under ``./noise_models_*``, ``./metrics_*`` and
    ``./checkpoints``.

    Args:
        args: Parsed option namespace.  Attributes read here: ``cuda_dev``,
            ``seed``, ``dataset``, ``DA``, ``network``, ``num_classes``,
            ``dropout``, ``M``, ``swa``, ``swa_lr``, ``lr``, ``momentum``,
            ``experiment_name``, ``labeled_samples``, ``initial_epoch``,
            ``dataset_type``, ``loss_term``, ``beta``, ``Mixup_Alpha``,
            ``label_noise``, ``epoch``, ``validation_exp`` and
            ``save_checkpoint``.
        dst_folder: Folder forwarded to ``data_config`` and ``record_result``.

    Raises:
        ValueError: If ``args.dataset``, ``args.DA``, ``args.network`` or
            ``args.loss_term`` holds a value this function does not handle.
    """
    # ------------------------------------------------------------------
    # Device selection and seeding
    # ------------------------------------------------------------------
    if args.cuda_dev == 1:
        torch.cuda.set_device(1)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    torch.backends.cudnn.deterministic = True  # fix the GPU to deterministic mode
    torch.manual_seed(args.seed)  # CPU seed
    # Comparing a torch.device with a plain string is always False, so the
    # device *type* has to be inspected to decide whether to seed CUDA.
    if device.type == "cuda":
        torch.cuda.manual_seed_all(args.seed)  # GPU seed
    random.seed(args.seed)  # python seed for image transformation
    np.random.seed(args.seed)

    # ------------------------------------------------------------------
    # Input normalisation and augmentation
    # ------------------------------------------------------------------
    if args.dataset == 'svhn':
        mean = [x / 255 for x in [127.5, 127.5, 127.5]]
        std = [x / 255 for x in [127.5, 127.5, 127.5]]
    elif args.dataset == 'cifar100':
        mean = [0.5071, 0.4867, 0.4408]
        std = [0.2675, 0.2565, 0.2761]
    else:
        # Fail here instead of with an unbound-local error further down.
        raise ValueError("Unsupported dataset: {!r}".format(args.dataset))

    if args.DA == "standard":
        transform_train = transforms.Compose([
            transforms.Pad(2, padding_mode='reflect'),
            transforms.RandomCrop(32),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])
    elif args.DA == "jitter":
        transform_train = transforms.Compose([
            transforms.Pad(2, padding_mode='reflect'),
            transforms.ColorJitter(brightness=0.4, contrast=0.4,
                                   saturation=0.4, hue=0.1),
            transforms.RandomCrop(32),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])
    else:
        raise ValueError(
            "Wrong value for --DA argument: {!r}".format(args.DA))

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    # data loader
    train_loader, test_loader, train_noisy_indexes = data_config(
        args, transform_train, transform_test, dst_folder)

    # ------------------------------------------------------------------
    # Model, optimizer and LR schedule
    # ------------------------------------------------------------------
    if args.network == "MT_Net":
        print("Loading MT_Net...")
        model = MT_Net(num_classes=args.num_classes,
                       dropRatio=args.dropout).to(device)
    elif args.network == "WRN28_2_wn":
        print("Loading WRN28_2...")
        model = WRN28_2_wn(num_classes=args.num_classes,
                           dropout=args.dropout).to(device)
    elif args.network == "PreactResNet18_WNdrop":
        print("Loading preActResNet18_WNdrop...")
        model = PreactResNet18_WNdrop(drop_val=args.dropout,
                                      num_classes=args.num_classes).to(device)
    else:
        raise ValueError("Unsupported network: {!r}".format(args.network))

    print('Total params: %.2fM' % (sum(p.numel() for p in model.parameters()) / 1000000.0))

    optimizer = optim.SGD(model.parameters(), lr=args.lr,
                          momentum=args.momentum, weight_decay=1e-4)
    if args.swa == 'True':
        # torchcontrib is optional, hence the local import.  To install it:
        #   pip3 install torchcontrib
        # or build it from https://github.com/pytorch/contrib.git
        from torchcontrib.optim import SWA
        optimizer = SWA(optimizer, swa_lr=args.swa_lr)

    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.M, gamma=0.1)

    # ------------------------------------------------------------------
    # Output locations and bookkeeping
    # ------------------------------------------------------------------
    loss_train_epoch = []
    loss_val_epoch = []
    acc_train_per_epoch = []
    acc_val_per_epoch = []
    new_labels = []

    exp_path = os.path.join('./', 'noise_models_{0}'.format(args.experiment_name), str(args.labeled_samples))
    res_path = os.path.join('./', 'metrics_{0}'.format(args.experiment_name), str(args.labeled_samples))
    for folder in (res_path, exp_path, './checkpoints'):
        os.makedirs(folder, exist_ok=True)
    metrics_prefix = res_path + '/' + str(args.labeled_samples)

    # Tag describing the training recipe; it is part of the warm-up
    # checkpoint file name, both when loading and when saving.
    if args.loss_term == 'Reg_ep':
        train_type = 'C'
    elif args.loss_term == 'MixUp_ep':
        train_type = 'M'
    else:
        raise ValueError("Unsupported loss_term: {!r}".format(args.loss_term))
    if args.dropout > 0.0:
        train_type = train_type + 'drop' + str(int(10 * args.dropout))
    if args.beta == 0.0:
        train_type = train_type + 'noReg'

    def warmup_checkpoint_path(epoch_idx):
        # Single place that defines the warm-up checkpoint naming scheme.
        return './checkpoints/warmUp_{0}_{1}_{2}_{3}_{4}_{5}_S{6}.hdf5'.format(
            train_type, args.Mixup_Alpha, epoch_idx, args.dataset,
            args.labeled_samples, args.network, args.seed)

    def save_snapshot(stem):
        # Store model and optimizer state side by side under exp_path.
        torch.save(model.state_dict(), os.path.join(exp_path, stem + '.pth'))
        torch.save(optimizer.state_dict(), os.path.join(exp_path, 'opt_' + stem + '.pth'))

    def evaluate():
        # Validation or test pass, depending on the experiment type.
        if args.validation_exp == "True":
            return validating(args, model, device, test_loader)
        return testing(args, model, device, test_loader)

    # ------------------------------------------------------------------
    # Optional: resume from a warm-up checkpoint and relabel the train set
    # ------------------------------------------------------------------
    # A warm-up run always starts from scratch, whatever initial_epoch says.
    load = args.initial_epoch != 0 and args.dataset_type != 'sym_noise_warmUp'
    if load:
        path = warmup_checkpoint_path(args.initial_epoch)
        # map_location lets a checkpoint written on GPU be read on a CPU host.
        checkpoint = torch.load(path, map_location=device)
        print("Load model in epoch " + str(checkpoint['epoch']))
        print("Path loaded: ", path)
        model.load_state_dict(checkpoint['state_dict'])
        print("Relabeling the unlabeled samples...")
        model.eval()
        # One row of class probabilities per training sample; the width must
        # follow the model's output size rather than a fixed 10.
        results = np.zeros((len(train_loader.dataset), args.num_classes), dtype=np.float32)

        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            for images, _, labels, soft_labels, index in train_loader:
                images = images.to(device)
                labels = labels.to(device)
                soft_labels = soft_labels.to(device)

                outputs = model(images)
                prob, _ = loss_soft_reg_ep(outputs, labels, soft_labels, device, args)
                results[index.detach().numpy()] = prob.detach().cpu().numpy()

        train_loader.dataset.update_labels_randRelab(results, train_noisy_indexes, args.label_noise)
        print("Start training...")

    # ------------------------------------------------------------------
    # Training
    # ------------------------------------------------------------------
    best_acc_val = 0.0
    snapBest = None  # file stem of the best snapshot currently on disk

    for epoch in range(1, args.epoch + 1):
        # NOTE(review): stepping the scheduler before the optimizer is the
        # pre-1.1 PyTorch ordering; it is kept on purpose because moving the
        # call would shift every milestone by one epoch.
        scheduler.step()
        print(args.experiment_name, args.labeled_samples)

        # train for one epoch
        train_loss, _, _, top1_train_ac, _ = train_CrossEntropy_partialRelab(
            args, model, device, train_loader, optimizer, epoch, train_noisy_indexes)
        loss_train_epoch += [train_loss]

        # test
        loss_per_epoch, acc_val_per_epoch_i = evaluate()
        loss_val_epoch += loss_per_epoch
        acc_train_per_epoch += [top1_train_ac]
        acc_val_per_epoch += acc_val_per_epoch_i

        ################################################################
        #                        SAVING MODELS                         #
        ################################################################
        # Keep exactly one "best" snapshot: the first epoch always counts,
        # afterwards only a strict improvement replaces it.
        if snapBest is None or acc_val_per_epoch_i[-1] > best_acc_val:
            best_acc_val = acc_val_per_epoch_i[-1]
            if snapBest is not None:
                for stale in ('opt_' + snapBest + '.pth', snapBest + '.pth'):
                    try:
                        os.remove(os.path.join(exp_path, stale))
                    except OSError:
                        # Best effort: a missing old snapshot is harmless.
                        pass
            snapBest = 'best_epoch_%d_valLoss_%.5f_valAcc_%.5f_noise_%d_bestAccVal_%.5f' % (
                epoch, loss_per_epoch[-1], acc_val_per_epoch_i[-1], args.labeled_samples, best_acc_val)
            save_snapshot(snapBest)

        if epoch == args.epoch:
            snapLast = 'last_epoch_%d_valLoss_%.5f_valAcc_%.5f_noise_%d_bestValLoss_%.5f' % (
                epoch, loss_per_epoch[-1], acc_val_per_epoch_i[-1], args.labeled_samples, best_acc_val)
            save_snapshot(snapLast)

        #### Save models for ensembles:
        if (epoch >= 150) and (epoch % 2 == 0) and (args.save_checkpoint == "True"):
            print("Saving model ...")
            out_path = './checkpoints/ENS_{0}_{1}'.format(args.experiment_name, args.labeled_samples)
            os.makedirs(out_path, exist_ok=True)
            torch.save(model.state_dict(), out_path + "/epoch_{0}.pth".format(epoch))

        ### Saving model to load it again (written once, after the last epoch)
        if epoch == args.epoch:
            print("Saving models...")
            save_checkpoint({
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'loss_train_epoch': np.asarray(loss_train_epoch),
                'loss_val_epoch': np.asarray(loss_val_epoch),
                'acc_train_per_epoch': np.asarray(acc_train_per_epoch),
                'acc_val_per_epoch': np.asarray(acc_val_per_epoch),
                'labels': np.asarray(train_loader.dataset.soft_labels)
            }, filename=warmup_checkpoint_path(epoch))

        ################################################################
        #                        SAVING METRICS                        #
        ################################################################
        # Save losses:
        np.save(metrics_prefix + '_LOSS_epoch_train.npy', np.asarray(loss_train_epoch))
        np.save(metrics_prefix + '_LOSS_epoch_val.npy', np.asarray(loss_val_epoch))

        # save accuracies:
        np.save(metrics_prefix + '_accuracy_per_epoch_train.npy', np.asarray(acc_train_per_epoch))
        np.save(metrics_prefix + '_accuracy_per_epoch_val.npy', np.asarray(acc_val_per_epoch))

        # save the new labels; np.array takes a copy so that a later in-place
        # relabeling by the dataset cannot rewrite the recorded history.
        new_labels.append(np.array(train_loader.dataset.labels))
        np.save(metrics_prefix + '_new_labels.npy', np.asarray(new_labels))

    # applying swa
    if args.swa == 'True':
        optimizer.swap_swa_sgd()
        optimizer.bn_update(train_loader, model, device)
        loss_swa, acc_val_swa = evaluate()

        snapLast = 'last_epoch_%d_valLoss_%.5f_valAcc_%.5f_noise_%d_bestValLoss_%.5f_swaAcc_%.5f' % (
            epoch, loss_per_epoch[-1], acc_val_per_epoch_i[-1], args.labeled_samples, best_acc_val, acc_val_swa[0])
        save_snapshot(snapLast)

    print('Best ac:%f' % best_acc_val)
    # Record the accuracy that was actually tracked, not a constant 0.0.
    record_result(dst_folder, best_acc_val)
Esempio n. 2
0
def main(args):
    """Train and evaluate a network in the semi-supervised setting.

    The function seeds the RNGs, builds the train/test transforms, the data
    loaders (``data_config``), the network selected by ``args.network`` and an
    SGD optimizer (optionally wrapped in torchcontrib's ``SWA``).  When
    ``args.load_epoch`` is non-zero (and this is not a warm-up run) the
    matching warm-up checkpoint is loaded and its predictions are handed to
    ``train_loader.dataset.update_labels`` before training starts.  After
    every epoch the model is evaluated and snapshots / metric arrays are
    written under ``./ssl_models_*`` and ``./metrics_*``; a warm-up run also
    writes a resumable checkpoint to ``./checkpoints`` after its last epoch.

    Args:
        args: Parsed option namespace.  Attributes read here: ``cuda_dev``,
            ``seed``, ``dataset``, ``DA``, ``network``, ``num_classes``,
            ``dropout``, ``M``, ``swa``, ``swa_lr``, ``lr``, ``momentum``,
            ``wd``, ``experiment_name``, ``labeled_samples``, ``load_epoch``,
            ``dataset_type``, ``loss_term``, ``Mixup_Alpha``, ``epoch`` and
            ``validation_exp``.

    Raises:
        ValueError: If ``args.dataset``, ``args.DA`` or ``args.network`` holds
            a value this function does not handle, or if ``args.loss_term`` is
            unsupported in a run that needs the warm-up checkpoint name.
    """
    #####################
    # Initializing seeds and preparing GPU
    if args.cuda_dev == 1:
        torch.cuda.set_device(1)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.backends.cudnn.deterministic = True  # fix the GPU to deterministic mode
    torch.manual_seed(args.seed)  # CPU seed
    # Comparing a torch.device with a plain string is always False, so the
    # device *type* has to be inspected to decide whether to seed CUDA.
    if device.type == "cuda":
        torch.cuda.manual_seed_all(args.seed)  # GPU seed
    random.seed(args.seed)  # python seed for image transformation
    np.random.seed(args.seed)
    #####################

    if args.dataset == 'cifar10':
        mean = [0.4914, 0.4822, 0.4465]
        std = [0.2023, 0.1994, 0.2010]
    elif args.dataset == 'cifar100':
        mean = [0.5071, 0.4867, 0.4408]
        std = [0.2675, 0.2565, 0.2761]
    else:
        # Fail here instead of with an unbound-local error further down.
        raise ValueError("Unsupported dataset: {!r}".format(args.dataset))

    if args.DA == "standard":
        transform_train = transforms.Compose([
            transforms.Pad(2, padding_mode='reflect'),
            transforms.RandomCrop(32),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])
    elif args.DA == "jitter":
        transform_train = transforms.Compose([
            transforms.Pad(2, padding_mode='reflect'),
            transforms.ColorJitter(brightness=0.4,
                                   contrast=0.4,
                                   saturation=0.4,
                                   hue=0.1),
            transforms.RandomCrop(32),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])
    else:
        raise ValueError(
            "Wrong value for --DA argument: {!r}".format(args.DA))

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    # data loader
    train_loader, test_loader, unlabeled_indexes = data_config(
        args, transform_train, transform_test)

    if args.network == "MT_Net":
        print("Loading MT_Net...")
        model = MT_Net(num_classes=args.num_classes,
                       dropRatio=args.dropout).to(device)
    elif args.network == "WRN28_2_wn":
        print("Loading WRN28_2...")
        model = WRN28_2_wn(num_classes=args.num_classes,
                           dropout=args.dropout).to(device)
    elif args.network == "PreactResNet18_WNdrop":
        print("Loading preActResNet18_WNdrop...")
        model = PreactResNet18_WNdrop(drop_val=args.dropout,
                                      num_classes=args.num_classes).to(device)
    else:
        raise ValueError("Unsupported network: {!r}".format(args.network))

    # '%.2f' (two decimals); the previous '%2.f' rounded to whole millions.
    print('Total params: %.2fM' %
          (sum(p.numel() for p in model.parameters()) / 1000000.0))

    optimizer = optim.SGD(model.parameters(),
                          lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=args.wd)
    if args.swa == 'True':
        # torchcontrib is optional, hence the local import.  To install it:
        #   pip3 install torchcontrib
        # or build it from https://github.com/pytorch/contrib.git
        from torchcontrib.optim import SWA
        optimizer = SWA(optimizer, swa_lr=args.swa_lr)

    scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                               milestones=args.M,
                                               gamma=0.1)

    loss_train_epoch = []
    loss_val_epoch = []
    acc_train_per_epoch = []
    acc_val_per_epoch = []

    exp_path = os.path.join('./',
                            'ssl_models_{0}'.format(args.experiment_name),
                            str(args.labeled_samples))
    res_path = os.path.join('./', 'metrics_{0}'.format(args.experiment_name),
                            str(args.labeled_samples))
    for folder in (res_path, exp_path, './checkpoints'):
        os.makedirs(folder, exist_ok=True)
    metrics_prefix = res_path + '/' + str(args.labeled_samples)

    # A warm-up run always starts from scratch (whatever load_epoch says) and
    # is the only kind of run that writes a resumable checkpoint at the end.
    warm_up = args.dataset_type == 'ssl_warmUp'
    load = args.load_epoch != 0 and not warm_up

    # Tag describing the training recipe; it is part of the warm-up
    # checkpoint file name.  It is only resolved when that name is needed so
    # that runs which never touch the checkpoint accept any loss_term.
    train_type = None
    if load or warm_up:
        if args.loss_term == 'Reg_ep':
            train_type = 'C'
        elif args.loss_term == 'MixUp_ep':
            train_type = 'M'
        else:
            raise ValueError(
                "Unsupported loss_term: {!r}".format(args.loss_term))
        if args.dropout > 0.0:
            train_type = train_type + 'drop' + str(int(10 * args.dropout))

    def warmup_checkpoint_path(epoch_idx):
        # Single place that defines the warm-up checkpoint naming scheme.
        return './checkpoints/warmUp_{0}_{1}_{2}_{3}_{4}_{5}_S{6}.hdf5'.format(
            train_type, args.Mixup_Alpha, epoch_idx, args.dataset,
            args.labeled_samples, args.network, args.seed)

    def save_snapshot(stem):
        # Store model and optimizer state side by side under exp_path.
        torch.save(model.state_dict(), os.path.join(exp_path, stem + '.pth'))
        torch.save(optimizer.state_dict(),
                   os.path.join(exp_path, 'opt_' + stem + '.pth'))

    def evaluate():
        # Validation or test pass, depending on the experiment type.
        if args.validation_exp == "True":
            return validating(args, model, device, test_loader)
        return testing(args, model, device, test_loader)

    if load:
        path = warmup_checkpoint_path(args.load_epoch)
        # map_location lets a checkpoint written on GPU be read on a CPU host.
        checkpoint = torch.load(path, map_location=device)
        print("Load model in epoch " + str(checkpoint['epoch']))
        print("Path loaded: ", path)
        model.load_state_dict(checkpoint['state_dict'])
        print("Relabeling the unlabeled samples...")
        model.eval()
        # One row of class probabilities per training sample.
        results = np.zeros((len(train_loader.dataset), args.num_classes),
                           dtype=np.float32)

        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            for images, _, labels, soft_labels, index in train_loader:
                images = images.to(device)
                labels = labels.to(device)
                soft_labels = soft_labels.to(device)

                outputs = model(images)
                prob, _ = loss_soft_reg_ep(outputs, labels, soft_labels,
                                           device, args)
                results[index.detach().numpy()] = prob.detach().cpu().numpy()

        train_loader.dataset.update_labels(results, unlabeled_indexes)
        print("Start training...")

    ####################################################################
    #                             TRAINING                             #
    ####################################################################
    best_acc_val = 0.0
    snapBest = None  # file stem of the best snapshot currently on disk

    for epoch in range(1, args.epoch + 1):
        # NOTE(review): stepping the scheduler before the optimizer is the
        # pre-1.1 PyTorch ordering; it is kept on purpose because moving the
        # call would shift every milestone by one epoch.
        scheduler.step()
        print(args.experiment_name, args.labeled_samples)

        # train for one epoch
        loss_per_epoch_train, _, top1_train_ac, _ = train_CrossEntropy(
            args, model, device, train_loader, optimizer, epoch,
            unlabeled_indexes)
        loss_train_epoch += [loss_per_epoch_train]

        # test
        loss_per_epoch_test, acc_val_per_epoch_i = evaluate()
        loss_val_epoch += loss_per_epoch_test
        acc_train_per_epoch += [top1_train_ac]
        acc_val_per_epoch += acc_val_per_epoch_i

        ################################################################
        #                        SAVING MODELS                         #
        ################################################################
        # Keep exactly one "best" snapshot: the first epoch always counts,
        # afterwards only a strict improvement replaces it.
        if snapBest is None or acc_val_per_epoch_i[-1] > best_acc_val:
            best_acc_val = acc_val_per_epoch_i[-1]
            if snapBest is not None:
                for stale in ('opt_' + snapBest + '.pth', snapBest + '.pth'):
                    try:
                        os.remove(os.path.join(exp_path, stale))
                    except OSError:
                        # Best effort: a missing old snapshot is harmless.
                        pass
            snapBest = 'best_epoch_%d_valLoss_%.5f_valAcc_%.5f_labels_%d_bestAccVal_%.5f' % (
                epoch, loss_per_epoch_test[-1], acc_val_per_epoch_i[-1],
                args.labeled_samples, best_acc_val)
            save_snapshot(snapBest)

        if epoch == args.epoch:
            snapLast = 'last_epoch_%d_valLoss_%.5f_valAcc_%.5f_labels_%d_bestValLoss_%.5f' % (
                epoch, loss_per_epoch_test[-1], acc_val_per_epoch_i[-1],
                args.labeled_samples, best_acc_val)
            save_snapshot(snapLast)

        ### Saving model to load it again (warm-up runs only, last epoch)
        if warm_up and epoch == args.epoch:
            print("Saving models...")
            save_checkpoint(
                {
                    'epoch': epoch,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'loss_train_epoch': np.asarray(loss_train_epoch),
                    'loss_val_epoch': np.asarray(loss_val_epoch),
                    'acc_train_per_epoch': np.asarray(acc_train_per_epoch),
                    'acc_val_per_epoch': np.asarray(acc_val_per_epoch),
                    'labels': np.asarray(train_loader.dataset.soft_labels)
                },
                filename=warmup_checkpoint_path(epoch))

        ################################################################
        #                        SAVING METRICS                        #
        ################################################################
        # Save losses:
        np.save(metrics_prefix + '_LOSS_epoch_train.npy',
                np.asarray(loss_train_epoch))
        np.save(metrics_prefix + '_LOSS_epoch_val.npy',
                np.asarray(loss_val_epoch))

        # save accuracies:
        np.save(metrics_prefix + '_accuracy_per_epoch_train.npy',
                np.asarray(acc_train_per_epoch))
        np.save(metrics_prefix + '_accuracy_per_epoch_val.npy',
                np.asarray(acc_val_per_epoch))

    # applying swa
    if args.swa == 'True':
        optimizer.swap_swa_sgd()
        optimizer.bn_update(train_loader, model, device)
        loss_swa, acc_val_swa = evaluate()

        snapLast = 'last_epoch_%d_valLoss_%.5f_valAcc_%.5f_labels_%d_bestValLoss_%.5f_swaAcc_%.5f' % (
            epoch, loss_per_epoch_test[-1], acc_val_per_epoch_i[-1],
            args.labeled_samples, best_acc_val, acc_val_swa[0])
        save_snapshot(snapLast)

    print('Best ac:%f' % best_acc_val)