Example #1
def get_criterion():
    weights = [float(args.weight), 1.0]
    class_weights = torch.FloatTensor(weights)

    class_weights = class_weights.cuda()
    if args.loss == 'Xent':
        criterion = PULoss(Probability_P=0.49, loss_fn="Xent")
    elif args.loss == 'nnPU':
        criterion = PULoss(Probability_P=0.49)
    elif args.loss == 'Focal':
        # class_weights was already built and moved to the GPU above
        criterion = FocalLoss(gamma=0, weight=class_weights, one_hot=False)
    elif args.loss == 'uPU':
        criterion = PULoss(Probability_P=0.49, nnPU=False)
    elif args.loss == 'Xent_weighted':
        criterion = torch.nn.CrossEntropyLoss(weight=class_weights)
    else:
        raise ValueError("Unknown loss: {}".format(args.loss))

    return criterion
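
For reference, the Xent_weighted branch is the only one built from a standard PyTorch class; PULoss and FocalLoss are project-specific and not reproduced here. A minimal, self-contained sketch (CPU only, with a hypothetical weight value) of how a weighted torch.nn.CrossEntropyLoss behaves:

import torch

# hypothetical per-class weights, e.g. args.weight = 2.0
class_weights = torch.FloatTensor([2.0, 1.0])
criterion = torch.nn.CrossEntropyLoss(weight=class_weights)

logits = torch.randn(8, 2)            # a batch of 8 samples, 2 classes
targets = torch.randint(0, 2, (8,))   # random 0/1 labels
loss = criterion(logits, targets)
print(loss.item())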
Example #2
def main():
    global args, switched, single_epoch_steps, step

    args = parser.parse_args()

    criterion = get_criterion()
    criterion_meta = PULoss(Probability_P=0.49, loss_fn="sigmoid_eps")

    torch.cuda.set_device(int(args.gpu))
    cudnn.benchmark = True

    if args.dataset == "mnist":
        (trainX, trainY), (testX, testY) = get_mnist()
        _trainY, _testY = binarize_mnist_class(trainY, testY)

        dataset_train_clean = MNIST_Dataset_FixSample(
            1000,
            60000,
            trainX,
            _trainY,
            testX,
            _testY,
            split='train',
            ids=[],
            increasing=args.increasing,
            replacement=args.replacement,
            mode=args.self_paced_type,
            top=args.top,
            type="clean",
            flex=args.flex,
            pickout=args.pickout)
        # the clean dataset is initialized empty
        dataset_train_noisy = MNIST_Dataset_FixSample(
            1000,
            60000,
            trainX,
            _trainY,
            testX,
            _testY,
            split='train',
            increasing=args.increasing,
            replacement=args.replacement,
            mode=args.self_paced_type,
            top=args.top,
            type="noisy",
            flex=args.flex,
            pickout=args.pickout)

        dataset_train_noisy.copy(
            dataset_train_clean)  # use the same random ordering as the clean dataset
        dataset_train_noisy.reset_ids()  # let the freshly initialized noisy dataset use all samples

        dataset_test = MNIST_Dataset_FixSample(1000,
                                               60000,
                                               trainX,
                                               _trainY,
                                               testX,
                                               _testY,
                                               split='test',
                                               increasing=args.increasing,
                                               replacement=args.replacement,
                                               mode=args.self_paced_type,
                                               top=args.top,
                                               type="clean",
                                               flex=args.flex,
                                               pickout=args.pickout)
    elif args.dataset == 'cifar':
        data_transforms = {
            'train':
            transforms.Compose([
                transforms.ToPILImage(),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406],
                                     [0.229, 0.224, 0.225]),
            ]),
            'val':
            transforms.Compose([
                transforms.ToPILImage(),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406],
                                     [0.229, 0.224, 0.225]),
            ])
        }
        (trainX, trainY), (testX, testY) = get_cifar()
        _trainY, _testY = binarize_cifar_class(trainY, testY)
        dataset_train_clean = CIFAR_Dataset(1000,
                                            50000,
                                            trainX,
                                            _trainY,
                                            testX,
                                            _testY,
                                            split='train',
                                            ids=[],
                                            increasing=args.increasing,
                                            replacement=args.replacement,
                                            mode=args.self_paced_type,
                                            top=args.top,
                                            transform=data_transforms['train'],
                                            type="clean",
                                            flex=args.flex)
        # the clean dataset is initialized empty
        dataset_train_noisy = CIFAR_Dataset(1000,
                                            50000,
                                            trainX,
                                            _trainY,
                                            testX,
                                            _testY,
                                            split='train',
                                            increasing=args.increasing,
                                            replacement=args.replacement,
                                            mode=args.self_paced_type,
                                            top=args.top,
                                            transform=data_transforms['train'],
                                            type="noisy",
                                            flex=args.flex)

        dataset_train_noisy.copy(
            dataset_train_clean)  # use the same random ordering as the clean dataset
        dataset_train_noisy.reset_ids()  # let the freshly initialized noisy dataset use all samples

        dataset_test = CIFAR_Dataset(1000,
                                     50000,
                                     trainX,
                                     _trainY,
                                     testX,
                                     _testY,
                                     split='test',
                                     increasing=args.increasing,
                                     replacement=args.replacement,
                                     mode=args.self_paced_type,
                                     top=args.top,
                                     transform=data_transforms['val'],
                                     type="clean",
                                     flex=args.flex)

        criterion.update_p(0.4)

    assert np.all(dataset_train_noisy.X == dataset_train_clean.X)
    assert np.all(dataset_train_noisy.Y == dataset_train_clean.Y)
    assert np.all(dataset_train_noisy.oids == dataset_train_clean.oids)
    assert np.all(dataset_train_noisy.T == dataset_train_clean.T)

    #step = args.ema_start * 2 + 1

    if len(dataset_train_clean) > 0:
        dataloader_train_clean = DataLoader(dataset_train_clean,
                                            batch_size=args.batch_size,
                                            num_workers=args.workers,
                                            shuffle=True,
                                            pin_memory=True)
    else:
        dataloader_train_clean = None

    if len(dataset_train_noisy) > 0:
        dataloader_train_noisy = DataLoader(dataset_train_noisy,
                                            batch_size=args.batch_size,
                                            num_workers=args.workers,
                                            shuffle=False,
                                            pin_memory=True)
    else:
        dataloader_train_noisy = None

    if len(dataset_test) > 0:
        dataloader_test = DataLoader(dataset_test,
                                     batch_size=args.batch_size,
                                     num_workers=0,
                                     shuffle=False,
                                     pin_memory=True)
    else:
        dataloader_test = None

    single_epoch_steps = len(dataloader_train_noisy) + 1
    print('Steps: {}'.format(single_epoch_steps))
    consistency_criterion = losses.softmax_mse_loss
    if args.dataset == 'mnist':
        model = create_model()
        ema_model = create_model(ema=True)
    elif args.dataset == 'cifar':
        model = create_cifar_model()
        ema_model = create_cifar_model(ema=True)

    # the student model and its EMA (teacher) copy always run on the GPU
    model = model.cuda()
    ema_model = ema_model.cuda()

    params_list = [
        {
            'params': model.parameters(),
            'lr': args.lr
        },
    ]

    optimizer = torch.optim.Adam(params_list,
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)
    stats_ = stats(args.modeldir, 0)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                           args.epochs,
                                                           eta_min=args.lr * 0.2)

    if args.evaluation:
        print("Evaluation mode!")
    best_acc = 0

    val = []
    for epoch in range(args.epochs):
        print("Self paced status: {}".format(check_self_paced(epoch)))
        print("Mean teacher status: {}".format(check_mean_teacher(epoch)))
        print("Noisy status: {}".format(check_noisy(epoch)))

        if check_mean_teacher(epoch) and (
                not check_mean_teacher(epoch - 1)) and not switched:
            ema_model.load_state_dict(model.state_dict())
            switched = True
            print("SWITCHED!")
        if epoch == 0:
            switched = False

        if (not check_mean_teacher(epoch)
            ) and check_mean_teacher(epoch - 1) and not switched:
            model.load_state_dict(ema_model.state_dict())
            switched = True
            print("SWITCHED!")

        if check_self_paced(epoch):
            trainPacc, trainNacc, trainPNacc = train_with_meta(
                dataloader_train_clean,
                dataloader_train_noisy,
                dataloader_test,
                model,
                ema_model,
                criterion_meta,
                consistency_criterion,
                optimizer,
                scheduler,
                epoch,
                self_paced_pick=len(dataset_train_clean))
        else:
            trainPacc, trainNacc, trainPNacc = train(
                dataloader_train_clean,
                dataloader_train_noisy,
                model,
                ema_model,
                criterion,
                consistency_criterion,
                optimizer,
                scheduler,
                epoch,
                self_paced_pick=len(dataset_train_clean))
        valPacc, valNacc, valPNacc = validate(dataloader_test, model,
                                              ema_model, criterion,
                                              consistency_criterion, epoch)
        val.append(valPNacc)
        stats_._update(trainPacc, trainNacc, trainPNacc, valPacc, valNacc,
                       valPNacc)

        is_best = valPNacc > best_acc
        best_acc = max(valPNacc, best_acc)
        filename = []
        filename.append(os.path.join(args.modeldir, 'checkpoint.pth.tar'))
        filename.append(os.path.join(args.modeldir, 'model_best.pth.tar'))

        if (check_self_paced(epoch)) and (epoch - args.self_paced_start
                                          ) % args.self_paced_frequency == 0:

            dataloader_train_clean, dataloader_train_noisy = update_dataset(
                model, ema_model, dataset_train_clean, dataset_train_noisy,
                epoch)

        plot_curve(stats_, args.modeldir, 'model', True)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_prec1': best_acc,
                'optimizer': optimizer.state_dict(),
            }, is_best, filename)
        dataset_train_noisy.shuffle()
    print(best_acc)
    print(val)
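
In the epoch loop above, ema_model acts as a mean teacher that tracks model; the actual EMA update is performed inside train()/train_with_meta(), which are not part of this listing. A generic sketch of such an update, assumed rather than taken from the project, typically called once per optimizer step:

import torch

def update_ema(model, ema_model, alpha=0.999):
    # teacher = alpha * teacher + (1 - alpha) * student, updated in place
    # (assumed helper; not the project's exact implementation)
    with torch.no_grad():
        for ema_p, p in zip(ema_model.parameters(), model.parameters()):
            ema_p.mul_(alpha).add_(p, alpha=1 - alpha)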
def main():

    global args, switched
    args = parser.parse_args()

    print(args)
    criterion = get_criterion()
    criterion_meta = PULoss(Probability_P=0.49, loss_fn="sigmoid_eps")

    ids_train = np.load("rid.image_id.train.adni.npy")
    ids_val = np.load("rid.image_id.test.adni.npy")
    # load metadata from csv ######################################
    df = pd.read_csv("adni_dx_suvr_clean.csv")
    df = df.fillna('')
    tmp = []
    for i in range(len(ids_train)):
        id = ids_train[i]
        if '.' in id:
            id = id.split('.')
            dx = df[(df['RID'] == int(id[0])) & (df['MRI ImageID'] == int(id[1]))]['DX'].values[0]
        else:
            dx = df[(df['RID'] == int(id)) & (df['MRI ImageID'] == "")]['DX'].values[0]
        # train on AD/MCI/NL ([1,2,3]) or only AD/NL ([1,3])
        if dx in [1, 3]: tmp.append(ids_train[i])
    ids_train = np.array(tmp)
    tmp = []
    for i in range(len(ids_val)):
        id = ids_val[i]
        if '.' in id:
            id = id.split('.')
            dx = df[(df['RID'] == int(id[0])) & (df['MRI ImageID'] == int(id[1]))]['DX'].values[0]
        else:
            dx = df[(df['RID'] == int(id)) & (df['MRI ImageID'] == "")]['DX'].values[0]
        # train on AD/MCI/NL ([1,2,3]) or only AD/NL ([1,3])
        if dx in [1, 3]: tmp.append(ids_val[i])
    ids_val = np.array(tmp)
    print(len(ids_train), len(ids_val))


    dataset_train1_clean = ADNI("adni_dx_suvr_clean.csv", ids_train, [], '/ssd1/chenwy/adni', type="clean", transform=True)
    dataset_train2_clean = ADNI("adni_dx_suvr_clean.csv", ids_train, [], '/ssd1/chenwy/adni', type="clean", transform=True)
    dataset_train1_noisy = ADNI("adni_dx_suvr_clean.csv", ids_train, None, '/ssd1/chenwy/adni', type="noisy", transform=True)
    dataset_train2_noisy = ADNI("adni_dx_suvr_clean.csv", ids_train, None, '/ssd1/chenwy/adni', type="noisy", transform=True)
    dataset_test = ADNI("adni_dx_suvr_clean.csv", ids_val, None, '/ssd1/chenwy/adni', type="clean", transform=False)
    criterion.update_p(0.43)

        
    dataloader_train1_clean = None
    #dataloader_train1_clean = DataLoader(dataset_train1_clean, batch_size=args.batch_size, num_workers=args.workers, shuffle=True, pin_memory=True)
    dataloader_train1_noisy = DataLoader(dataset_train1_noisy, batch_size=args.batch_size, num_workers=args.workers, shuffle=False, pin_memory=True)

    dataloader_train2_clean = None
    #dataloader_train2_clean = DataLoader(dataset_train2_clean, batch_size=args.batch_size, num_workers=args.workers, shuffle=True, pin_memory=True)
    dataloader_train2_noisy = DataLoader(dataset_train2_noisy, batch_size=args.batch_size, num_workers=args.workers, shuffle=False, pin_memory=True)
    dataloader_test = DataLoader(dataset_test, batch_size=args.batch_size, num_workers=0, shuffle=False, pin_memory=True)
    consistency_criterion = losses.softmax_mse_loss

    model1 = create_lenet_model()
    model2 = create_lenet_model()
    ema_model1 = create_lenet_model(ema=True)
    ema_model2 = create_lenet_model(ema=True)
    
    # both models and their EMA (teacher) copies always run on the GPU
    model1 = model1.cuda(args.gpu)
    model2 = model2.cuda(args.gpu)
    ema_model1 = ema_model1.cuda(args.gpu)
    ema_model2 = ema_model2.cuda(args.gpu)

    optimizer1 = torch.optim.Adam(model1.parameters(), lr=args.lr,
                                  weight_decay=args.weight_decay)
    optimizer2 = torch.optim.Adam(model2.parameters(), lr=args.lr,
                                  weight_decay=args.weight_decay)

    stats_ = stats(args.modeldir, 0)
    #scheduler1 = torch.optim.lr_scheduler.MultiStepLR(optimizer1, milestones=[15, 60], gamma=0.7)
    #scheduler2 = torch.optim.lr_scheduler.MultiStepLR(optimizer2, milestones=[15, 60], gamma=0.7)
    scheduler1 = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer1, args.epochs)
    scheduler2 = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer2, args.epochs)

    best_acc1 = 0
    best_acc2 = 0
    best_acc3 = 0
    best_acc4 = 0
    for epoch in range(args.warmup):
        print("Warming up {}/{}".format(epoch + 1, args.warmup))

        trainPacc, trainNacc, trainPNacc = train(dataloader_train1_clean, dataloader_train1_noisy, dataloader_train2_clean, dataloader_train2_noisy, model1, model2, ema_model1, ema_model2, criterion, consistency_criterion, optimizer1, scheduler1, optimizer2, scheduler2, -1, warmup = True)

        valPacc, valNacc, valPNacc1, valPNacc2, valPNacc3, valPNacc4 = validate(dataloader_test, model1, model2, ema_model1, ema_model2, -1)

        dataset_train1_noisy.shuffle()
        dataset_train2_noisy.shuffle()
        dataloader_train1_noisy = DataLoader(dataset_train1_noisy, batch_size=args.batch_size, num_workers=args.workers, shuffle=False, pin_memory=True)
        dataloader_train2_noisy = DataLoader(dataset_train2_noisy, batch_size=args.batch_size, num_workers=args.workers, shuffle=False, pin_memory=True)

    for epoch in range(args.epochs):
        print("Self paced status: {}".format(check_self_paced(epoch)))
        print("Mean Teacher status: {}".format(check_mean_teacher(epoch)))
        if check_mean_teacher(epoch) and not check_mean_teacher(epoch - 1) and not switched:
            ema_model1.load_state_dict(model1.state_dict())
            ema_model2.load_state_dict(model2.state_dict())
            switched = True
            print("SWITCHED!")

        if check_self_paced(epoch):
            trainPacc, trainNacc, trainPNacc = train_with_meta(dataloader_train1_clean, dataloader_train1_noisy, dataloader_train2_clean, dataloader_train2_noisy, dataloader_test, model1, model2, ema_model1, ema_model2, criterion_meta, consistency_criterion, optimizer1, scheduler1, optimizer2, scheduler2, epoch)
        else:
            trainPacc, trainNacc, trainPNacc = train(dataloader_train1_clean, dataloader_train1_noisy, dataloader_train2_clean, dataloader_train2_noisy, model1, model2, ema_model1, ema_model2, criterion, consistency_criterion, optimizer1, scheduler1, optimizer2, scheduler2, epoch)

        valPacc, valNacc, valPNacc1, valPNacc2, valPNacc3, valPNacc4 = validate(dataloader_test, model1, model2, ema_model1, ema_model2, epoch)
        #print(valPacc, valNacc, valPNacc1, valPNacc2, valPNacc3)
        stats_._update(trainPacc, trainNacc, trainPNacc, valPacc, valNacc, valPNacc1)

        is_best1 = valPNacc1 > best_acc1
        is_best2 = valPNacc2 > best_acc2
        is_best3 = valPNacc3 > best_acc3
        is_best4 = valPNacc4 > best_acc4
        best_acc1 = max(valPNacc1, best_acc1)
        best_acc2 = max(valPNacc2, best_acc2)
        best_acc3 = max(valPNacc3, best_acc3)
        best_acc4 = max(valPNacc4, best_acc4)
        filename = []
        filename.append(os.path.join(args.modeldir, 'checkpoint.pth.tar'))
        filename.append(os.path.join(args.modeldir, 'model_best.pth.tar'))

        if (check_self_paced(epoch)) and (epoch - args.self_paced_start) % args.self_paced_frequency == 0:

            dataloader_train1_clean, dataloader_train1_noisy, dataloader_train2_clean, dataloader_train2_noisy = update_dataset(model1, model2, ema_model1, ema_model2, dataset_train1_clean, dataset_train1_noisy, dataset_train2_clean, dataset_train2_noisy, epoch)

        plot_curve(stats_, args.modeldir, 'model', True)
        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model1.state_dict(),
            'best_prec1': best_acc1,
        }, is_best1, filename)

        dataset_train1_noisy.shuffle()
        dataset_train2_noisy.shuffle()
        dataloader_train1_noisy = DataLoader(dataset_train1_noisy, batch_size=args.batch_size, num_workers=args.workers, shuffle=False, pin_memory=True)
        dataloader_train2_noisy = DataLoader(dataset_train2_noisy, batch_size=args.batch_size, num_workers=args.workers, shuffle=False, pin_memory=True)

    print(best_acc1)
    print(best_acc2)
    print(best_acc3)
    print(best_acc4)
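
save_checkpoint itself is not shown in either example; the two-entry filename list (checkpoint.pth.tar, model_best.pth.tar) follows the common pattern of always writing the latest state and copying it when validation accuracy improves. A plausible sketch of that pattern, not the project's exact implementation:

import shutil
import torch

def save_checkpoint(state, is_best, filename):
    # filename[0] holds the rolling checkpoint, filename[1] the best model so far
    torch.save(state, filename[0])
    if is_best:
        shutil.copyfile(filename[0], filename[1])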