Example #1
def get_datasets(dataset_name, noise_type, noise_rate):
    if (dataset_name == "MNIST"):
        train_dataset = MNIST(root='./data/',
                              download=True,
                              train=True,
                              transform=transforms.ToTensor(),
                              noise_type=noise_type,
                              noise_rate=noise_rate)
        test_dataset = MNIST(root='./data/',
                             download=True,
                             train=False,
                             transform=transforms.ToTensor(),
                             noise_type=noise_type,
                             noise_rate=noise_rate)
    elif (dataset_name == "CIFAR10"):
        train_dataset = CIFAR10(root='./data/',
                                download=True,
                                train=True,
                                transform=transforms.ToTensor(),
                                noise_type=noise_type,
                                noise_rate=noise_rate)
        test_dataset = CIFAR10(root='./data/',
                               download=True,
                               train=False,
                               transform=transforms.ToTensor(),
                               noise_type=noise_type,
                               noise_rate=noise_rate)
    elif (dataset_name == "CIFAR100"):
        train_dataset = CIFAR100(root='./data/',
                                 download=True,
                                 train=True,
                                 transform=transforms.ToTensor(),
                                 noise_type=noise_type,
                                 noise_rate=noise_rate)
        test_dataset = CIFAR100(root='./data/',
                                download=True,
                                train=False,
                                transform=transforms.ToTensor(),
                                noise_type=noise_type,
                                noise_rate=noise_rate)
    # elif(dataset_name == "Clothes1M"):
    # 	train_dataset, test_dataset =
    return train_dataset, test_dataset
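
A typical way to consume the pair returned by get_datasets(); the batch size and worker count below are illustrative, not taken from the project:

import torch

train_dataset, test_dataset = get_datasets("CIFAR10",
                                           noise_type="symmetric",
                                           noise_rate=0.2)

# Standard DataLoader wrapping; shuffle only the training split.
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=128,
                                           shuffle=True,
                                           num_workers=4,
                                           drop_last=True)
test_loader = torch.utils.data.DataLoader(test_dataset,
                                          batch_size=128,
                                          shuffle=False,
                                          num_workers=4)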
Example #2
def get_datasets(dataset_name, noise_type=None, noise_rate=None):
    val_dataset = None
    if (dataset_name == "MNIST"):
        train_dataset = MNIST(root='./data/',
                              download=True,
                              train=True,
                              transform=transforms.ToTensor(),
                              noise_type=noise_type,
                              noise_rate=noise_rate)
        test_dataset = MNIST(root='./data/',
                             download=True,
                             train=False,
                             transform=transforms.ToTensor(),
                             noise_type=noise_type,
                             noise_rate=noise_rate)
    elif (dataset_name == "CIFAR10"):
        train_dataset = CIFAR10(root='./data/',
                                download=True,
                                train=True,
                                transform=get_transform(),
                                noise_type=noise_type,
                                noise_rate=noise_rate)
        test_dataset = CIFAR10(root='./data/',
                               download=True,
                               train=False,
                               transform=transforms.ToTensor(),
                               noise_type=noise_type,
                               noise_rate=noise_rate)
    elif (dataset_name == "CIFAR100"):
        train_dataset = CIFAR100(root='./data/',
                                 download=True,
                                 train=True,
                                 transform=get_transform(),
                                 noise_type=noise_type,
                                 noise_rate=noise_rate)
        test_dataset = CIFAR100(root='./data/',
                                download=True,
                                train=False,
                                transform=transforms.ToTensor(),
                                noise_type=noise_type,
                                noise_rate=noise_rate)
    elif (dataset_name == "Clothes1M"):
        img_sz = 224
        train_dataset = Clothing(data_dir='../clothing1M/',
                                 train=True,
                                 val=False,
                                 test=False,
                                 transform=clothing_transform(img_sz))
        val_dataset = Clothing(data_dir='../clothing1M/',
                               train=False,
                               val=True,
                               test=False,
                               transform=clothing_transform(img_sz))
        test_dataset = Clothing(data_dir='../clothing1M/',
                                train=False,
                                val=False,
                                test=True,
                                transform=clothing_transform(img_sz))
    elif (dataset_name == "Food101N"):
        img_sz = 224
        train_dataset = Food101N(data_dir='../Food-101N_release/',
                                 train=True,
                                 val=False,
                                 test=False,
                                 transform=food_transform(img_sz))
        val_dataset = Food101N(data_dir='../Food-101N_release/',
                               train=False,
                               val=True,
                               test=False,
                               transform=food_transform(img_sz))
        test_dataset = Food101N(data_dir='../Food-101N_release/',
                                train=False,
                                val=False,
                                test=True,
                                transform=food_transform(img_sz))

    return train_dataset, val_dataset, test_dataset
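
Example #2 relies on helpers that are not shown: get_transform(), clothing_transform(img_sz), and food_transform(img_sz) (the last would be analogous to the second). A plausible minimal version of the first two, assuming standard CIFAR augmentation and ImageNet normalization; the project's actual helpers may differ:

from torchvision import transforms

def get_transform():
    # Assumed CIFAR-style training augmentation; the project's actual
    # helper may normalize or augment differently.
    return transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])

def clothing_transform(img_sz):
    # Assumed ImageNet-style preprocessing for 224x224 Clothing1M images.
    return transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(img_sz),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406),
                             (0.229, 0.224, 0.225)),
    ])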
Example #3
File: main.py Project: pizard/JoCoR
                           train=False,
                           transform=transforms.ToTensor(),
                           noise_type=args.noise_type,
                           noise_rate=args.noise_rate)

if args.dataset == 'cifar100':
    input_channel = 3
    num_classes = 100
    init_epoch = 5
    args.epoch_decay_start = 100
    # args.n_epoch = 200
    filter_outlier = False

    train_dataset = CIFAR100(root='../../data/raw/',
                             download=True,
                             train=True,
                             transform=transforms.ToTensor(),
                             noise_type=args.noise_type,
                             noise_rate=args.noise_rate)

    test_dataset = CIFAR100(root='../../data/raw/',
                            download=True,
                            train=False,
                            transform=transforms.ToTensor(),
                            noise_type=args.noise_type,
                            noise_rate=args.noise_rate)

if args.forget_rate is None:
    forget_rate = args.noise_rate
else:
    forget_rate = args.forget_rate
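
In the Co-teaching family this snippet belongs to, forget_rate drives the drop-rate schedule R(T) that --num_gradual and --exponent parameterize in the other examples. A sketch consistent with the paper's definition, assuming the usual argument names:

import numpy as np

# Drop-rate schedule R(T) from Co-teaching: ramp linearly from 0 to
# forget_rate ** exponent over the first num_gradual epochs, then hold.
rate_schedule = np.ones(args.n_epoch) * forget_rate
rate_schedule[:args.num_gradual] = np.linspace(0,
                                               forget_rate ** args.exponent,
                                               args.num_gradual)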
Example #4
def main():
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')

    parser.add_argument('--result_dir',
                        type=str,
                        help='dir to save result txt files',
                        default='results/')
    parser.add_argument('--noise_rate',
                        type=float,
                        help='corruption rate, should be less than 1',
                        default=0.5)
    parser.add_argument('--forget_rate',
                        type=float,
                        help='forget rate',
                        default=None)
    parser.add_argument('--noise_type',
                        type=str,
                        help='[pairflip, symmetric]',
                        default='symmetric')
    parser.add_argument(
        '--num_gradual',
        type=int,
        default=10,
        help=
        'how many epochs for linear drop rate, can be 5, 10, 15. This parameter is equal to Tk for R(T) in Co-teaching paper.'
    )
    parser.add_argument(
        '--exponent',
        type=float,
        default=1,
        help=
        'exponent of the forget rate, can be 0.5, 1, 2. This parameter is equal to c in Tc for R(T) in Co-teaching paper.'
    )
    parser.add_argument('--top_bn', action='store_true')
    parser.add_argument('--dataset',
                        type=str,
                        help='mnist, cifar10, or cifar100',
                        default='mnist')
    parser.add_argument('--n_epoch', type=int, default=10)
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--print_freq', type=int, default=50)
    parser.add_argument('--num_workers',
                        type=int,
                        default=4,
                        help='how many subprocesses to use for data loading')
    parser.add_argument('--num_iter_per_epoch', type=int, default=400)
    parser.add_argument('--epoch_decay_start', type=int, default=80)
    parser.add_argument('--eps', type=float, default=9.9)

    parser.add_argument('--batch-size',
                        type=int,
                        default=1000,
                        metavar='N',
                        help='input batch size for training (default: 1000)')
    parser.add_argument('--test-batch-size',
                        type=int,
                        default=4000,
                        metavar='N',
                        help='input batch size for testing (default: 4000)')
    parser.add_argument('--lr',
                        type=float,
                        default=0.005,
                        metavar='LR',
                        help='learning rate (default: 0.005)')
    parser.add_argument('--momentum',
                        type=float,
                        default=0.9,
                        metavar='M',
                        help='SGD momentum (default: 0.9)')
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')
    parser.add_argument(
        '--log-interval',
        type=int,
        default=100,
        metavar='N',
        help='how many batches to wait before logging training status')

    parser.add_argument('--save-model',
                        action='store_true',
                        default=False,
                        help='For Saving the current Model')
    args = parser.parse_args()

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    batch_size = args.batch_size

    if args.dataset == 'mnist':
        input_channel = 1
        num_classes = 10
        args.top_bn = False
        args.epoch_decay_start = 80
        args.n_epoch = 200
        train_dataset = MNIST(root='./data/',
                              download=True,
                              train=True,
                              transform=transforms.ToTensor(),
                              noise_type=args.noise_type,
                              noise_rate=args.noise_rate)

        test_dataset = MNIST(root='./data/',
                             download=True,
                             train=False,
                             transform=transforms.ToTensor(),
                             noise_type=args.noise_type,
                             noise_rate=args.noise_rate)

    if args.dataset == 'cifar10':
        input_channel = 3
        num_classes = 10
        args.top_bn = False
        args.epoch_decay_start = 80
        args.n_epoch = 200
        train_dataset = CIFAR10(root='./data/',
                                download=True,
                                train=True,
                                transform=transforms.ToTensor(),
                                noise_type=args.noise_type,
                                noise_rate=args.noise_rate)

        test_dataset = CIFAR10(root='./data/',
                               download=True,
                               train=False,
                               transform=transforms.ToTensor(),
                               noise_type=args.noise_type,
                               noise_rate=args.noise_rate)

    if args.dataset == 'cifar100':
        input_channel = 3
        num_classes = 100
        args.top_bn = False
        args.epoch_decay_start = 100
        args.n_epoch = 200
        train_dataset = CIFAR100(root='./data/',
                                 download=True,
                                 train=True,
                                 transform=transforms.ToTensor(),
                                 noise_type=args.noise_type,
                                 noise_rate=args.noise_rate)

        test_dataset = CIFAR100(root='./data/',
                                download=True,
                                train=False,
                                transform=transforms.ToTensor(),
                                noise_type=args.noise_type,
                                noise_rate=args.noise_rate)

    if args.forget_rate is None:
        forget_rate = args.noise_rate
    else:
        forget_rate = args.forget_rate

    noise_or_not = train_dataset.noise_or_not
    # Data Loader (Input Pipeline)
    print('loading dataset...')
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=args.batch_size,
                                               num_workers=args.num_workers,
                                               drop_last=True,
                                               shuffle=True)

    test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                              batch_size=batch_size,
                                              num_workers=args.num_workers,
                                              drop_last=True,
                                              shuffle=False)
    # Define models
    #print('building model...')
    #cnn1 = CNN(input_channel=input_channel, n_outputs=num_classes+1)
    #cnn1.cuda()
    #print(cnn1.parameters)
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    cnn1 = Net().to(device)
    #cnn1=nn.DataParallel(cnn1,device_ids=[0,1,2,3]).cuda()
    #print(model.parameters)
    #optimizer1 = torch.optim.SGD(cnn1.parameters(), lr=learning_rate)
    #optimizer = torch.optim.Adam(cnn1.parameters(), lr=args.lr)

    optimizer = torch.optim.SGD(cnn1.parameters(),
                                lr=args.lr,
                                momentum=args.momentum)
    #optimizer = nn.DataParallel(optimizer, device_ids=[0,1,2,3])

    acc = []
    loss = []
    loss_pure = []
    loss_corrupt = []
    out = []
    for epoch in range(1, args.n_epoch + 1):
        l1, out10 = train(args,
                          cnn1,
                          device,
                          train_loader,
                          optimizer,
                          epoch,
                          eps=args.eps,
                          nums=num_classes)
        loss.append(l1)
        out.append(out10)
        acc.append(test(args, cnn1, device, test_loader, num_classes))

    name = "raw " + str(args.dataset) + " " + str(args.noise_type) + " " + str(
        args.noise_rate) + " " + str(args.eps)
    np.save(name + " acc.npy", acc)
    np.save(name + " loss.npy", loss)
Example #5
def main():
    parser = argparse.ArgumentParser(
        description='PyTorch Gambler\'s Loss Runner')

    parser.add_argument('--result_dir',
                        type=str,
                        help='directory to save result txt files',
                        default='results')
    parser.add_argument('--noise_rate',
                        type=float,
                        help='corruption rate, should be less than 1',
                        default=0.5)
    parser.add_argument('--noise_type',
                        type=str,
                        help='[pairflip, symmetric]',
                        default='symmetric')
    parser.add_argument('--dataset',
                        type=str,
                        help='mnist, cifar10, or imdb',
                        default='mnist')
    parser.add_argument('--n_epoch', type=int, default=10)
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--num_workers',
                        type=int,
                        default=4,
                        help='how many subprocesses to use for data loading')
    parser.add_argument('--epoch_decay_start', type=int, default=80)
    parser.add_argument('--load_model', type=str, default="")
    parser.add_argument('--model', type=str, default='default')
    parser.add_argument('--batch_size',
                        type=int,
                        default=128,
                        metavar='N',
                        help='input batch size for training (default: 128)')
    parser.add_argument(
        '--log-interval',
        type=int,
        default=100,
        metavar='N',
        help=
        'how many batches to wait before logging training status (default: 100)'
    )

    parser.add_argument('--lr',
                        type=float,
                        default=0.001,
                        metavar='LR',
                        help='learning rate (default: 0.001)')
    parser.add_argument('--eps',
                        type=float,
                        help='set lambda for lambda type \'gmblers\' only',
                        default=1000.0)
    parser.add_argument('--lambda_type',
                        type=str,
                        help='[nll, euc, mid, exp, gmblers]',
                        default="euc")
    parser.add_argument('--start_gamblers',
                        type=int,
                        help='number of epochs before starting gamblers',
                        default=0)

    # label smoothing args
    parser.add_argument('--smoothing',
                        type=float,
                        default=1.0,
                        help='smoothing parameter (default: 1)')

    args = parser.parse_args()
    args.use_scheduler = False

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    if args.dataset == 'mnist':
        input_channel = 1
        num_classes = 10
        train_dataset = MNIST(root='./data/',
                              download=True,
                              train=True,
                              transform=transforms.ToTensor(),
                              noise_type=args.noise_type,
                              noise_rate=args.noise_rate)
        test_dataset = MNIST(root='./data/',
                             download=True,
                             train=False,
                             transform=transforms.ToTensor(),
                             noise_type=args.noise_type,
                             noise_rate=args.noise_rate)
        print('loading dataset...')
        train_loader = torch.utils.data.DataLoader(
            dataset=train_dataset,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
            drop_last=True,
            shuffle=True)
        test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                                  batch_size=args.batch_size,
                                                  num_workers=args.num_workers,
                                                  drop_last=True,
                                                  shuffle=False)

    if args.dataset == 'cifar10':
        input_channel = 3
        num_classes = 10
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])
        train_dataset = CIFAR10(root='./data/',
                                download=True,
                                train=True,
                                transform=transform_train,
                                noise_type=args.noise_type,
                                noise_rate=args.noise_rate)
        test_dataset = CIFAR10(root='./data/',
                               download=True,
                               train=False,
                               transform=transform_test,
                               noise_type=args.noise_type,
                               noise_rate=args.noise_rate)
        print('loading dataset...')
        train_loader = torch.utils.data.DataLoader(
            dataset=train_dataset,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
            drop_last=True,
            shuffle=True)
        test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                                  batch_size=args.batch_size,
                                                  num_workers=args.num_workers,
                                                  drop_last=True,
                                                  shuffle=False)

    if args.dataset == 'cifar100':
        input_channel = 3
        num_classes = 100
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])
        train_dataset = CIFAR100(root='./data/',
                                 download=True,
                                 train=True,
                                 transform=transform_train,
                                 noise_type=args.noise_type,
                                 noise_rate=args.noise_rate)
        test_dataset = CIFAR100(root='./data/',
                                download=True,
                                train=False,
                                transform=transform_test,
                                noise_type=args.noise_type,
                                noise_rate=args.noise_rate)
        print('loading dataset...')
        train_loader = torch.utils.data.DataLoader(
            dataset=train_dataset,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
            drop_last=True,
            shuffle=True)
        test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                                  batch_size=args.batch_size,
                                                  num_workers=args.num_workers,
                                                  drop_last=True,
                                                  shuffle=False)

    if args.dataset == 'imdb':
        num_classes = 2
        embedding_length = 300
        hidden_size = 256
        print('loading dataset...')
        TEXT, vocab_size, word_embeddings, train_loader, valid_iter, test_loader = load_data.load_dataset(
            rate=args.noise_rate, batch_size=args.batch_size)

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    print("using {}".format(device))

    print('building model...')
    if args.dataset == 'mnist':
        model = CNN_basic(num_classes=num_classes).to(device)
        optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)
    if args.dataset == 'cifar10':
        if args.model == 'small':
            model = CNN_small(num_classes=num_classes).to(device)
            optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)
        else:
            model = resnet.ResNet18(num_classes=num_classes).to(device)
            change_lr = lambda epoch: 0.1 if epoch >= 50 else 1.0
            optimizer = LaProp(filter(lambda p: p.requires_grad,
                                      model.parameters()),
                               lr=4e-4)
            scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
                                                          lr_lambda=change_lr)
            args.use_scheduler = True
    if args.dataset == 'cifar100':
        if args.model == 'small':
            model = CNN_small(num_classes=num_classes).to(device)
            optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)
        else:
            model = resnet.ResNet18(num_classes=num_classes).to(device)
            change_lr = lambda epoch: 0.1 if epoch >= 50 else 1.0
            optimizer = LaProp(filter(lambda p: p.requires_grad,
                                      model.parameters()),
                               lr=4e-4)
            scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
                                                          lr_lambda=change_lr)
            args.use_scheduler = True
    if args.dataset == 'imdb':
        model = LSTMClassifier(args.batch_size, num_classes, hidden_size,
                               vocab_size, embedding_length,
                               word_embeddings).to(device)
        optimizer = LaProp(filter(lambda p: p.requires_grad,
                                  model.parameters()),
                           lr=args.lr)

    test_accs = []
    train_losses = []
    test_losses = []
    out = []

    name = "{}_{}_{:.2f}_{:.2f}_{}_{}".format(args.dataset, args.noise_type,
                                              args.smoothing, args.noise_rate,
                                              args.eps, args.seed)

    if not os.path.exists(args.result_dir):
        os.makedirs(args.result_dir)

    save_file = args.result_dir + "/" + name + ".json"
    if os.path.exists(save_file):
        print('case processed')
        exit()

    for epoch in range(1, args.n_epoch + 1):
        for param_group in optimizer.param_groups:
            print(epoch, param_group['lr'])
        print(name)
        train_loss = train(args,
                           model,
                           device,
                           train_loader,
                           optimizer,
                           epoch,
                           num_classes=num_classes,
                           use_gamblers=(epoch >= args.start_gamblers),
                           text=(args.dataset == 'imdb'))
        train_losses.append(train_loss)
        test_acc, test_loss = test(args,
                                   model,
                                   device,
                                   test_loader,
                                   num_classes,
                                   text=(args.dataset == 'imdb'))
        test_accs.append(test_acc)
        test_losses.append(test_loss)

        if (args.use_scheduler):
            scheduler.step()

        # torch.save({
        #     'model_state_dict': model.state_dict(),
        #     'optimizer_state_dict': optimizer.state_dict(),
        #     'loss': loss,
        #     'test_acc': acc
        # }, args.result_dir + "/" + name + "_model.npy")

    save_data = {
        "train_loss": train_losses,
        "test_loss": test_losses,
        "test_acc": test_accs
    }
    json.dump(save_data, open(save_file, 'w'))
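
train() receives eps and num_classes but its body is not shown. Under the Deep Gamblers formulation this runner's name suggests, the model emits num_classes + 1 outputs, the extra one being an abstention score, and eps plays the role of the payoff o. A minimal sketch of that loss, assuming plain softmax outputs; the repo's actual train() may differ:

import torch
import torch.nn.functional as F

def gamblers_loss(logits, targets, eps):
    # The last of num_classes + 1 outputs is the abstention score; with
    # payoff o = eps the loss is -log(p_y + p_reserve / o).
    probs = F.softmax(logits, dim=1)
    p_class, p_reserve = probs[:, :-1], probs[:, -1]
    p_y = p_class.gather(1, targets.unsqueeze(1)).squeeze(1)
    return -torch.log(p_y + p_reserve / eps).mean()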
Example #6
def main():

    use_cuda = torch.cuda.is_available()
    global best_acc

    # load dataset

    if args.dataset == 'cifar10':
        print(f"Using {args.dataset}")
        num_classes = 10
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.491, 0.482, 0.447), (0.247, 0.243, 0.262)),
        ])

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.491, 0.482, 0.447), (0.247, 0.243, 0.262)),
        ])

        train_dataset = CIFAR10(root='/home/kanchanaranasinghe/data/cifar',
                                download=True,
                                train=True,
                                transform=transform_train,
                                noise_type=args.noise_type,
                                noise_rate=args.noise_rate)

        test_dataset = CIFAR10(root='/home/kanchanaranasinghe/data/cifar',
                               download=True,
                               train=False,
                               transform=transform_test,
                               noise_type=args.noise_type,
                               noise_rate=args.noise_rate)

    elif args.dataset == 'cifar100':
        print(f"Using {args.dataset}")
        num_classes = 100
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.507, 0.487, 0.441), (0.267, 0.256, 0.276)),
        ])

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.507, 0.487, 0.441), (0.267, 0.256, 0.276))
        ])
        train_dataset = CIFAR100(root='/home/kanchanaranasinghe/data/cifar',
                                 download=True,
                                 train=True,
                                 transform=transform_train,
                                 noise_type=args.noise_type,
                                 noise_rate=args.noise_rate)

        test_dataset = CIFAR100(root='/home/kanchanaranasinghe/data/cifar',
                                download=True,
                                train=False,
                                transform=transform_test,
                                noise_type=args.noise_type,
                                noise_rate=args.noise_rate)

    else:
        raise NotImplementedError(f"invalid dataset: {args.dataset}")

    testloader = torch.utils.data.DataLoader(test_dataset,
                                             batch_size=100,
                                             shuffle=False,
                                             num_workers=2)

    trainloader = torch.utils.data.DataLoader(train_dataset,
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              num_workers=2)

    # Model
    if args.resume is not None:
        # Load checkpoint.
        print(f'==> Resuming from checkpoint: {args.resume}')
        checkpoint = torch.load(f"{args.resume}")
        net = checkpoint['net']
        best_acc = checkpoint['acc']
        start_epoch = checkpoint['epoch'] + 1
        torch.set_rng_state(checkpoint['rng_state'])
    else:
        print(f'==> Building model.. (Default : {args.model})')
        start_epoch = 0
        if args.model == "resnet18":
            net = ResNet18(num_classes)
        elif args.model == "resnet34":
            net = ResNet34(num_classes)
        else:
            raise NotImplementedError(f"Invalid model: {args.model}")

    result_folder = f"checkpoint/{args.sess}"
    log_name = f"{result_folder}/{args.model}_{args.sess}.csv"

    if use_cuda:
        net.cuda()
        net = torch.nn.DataParallel(net)
        print('Using', torch.cuda.device_count(), 'GPUs.')
        cudnn.benchmark = True
        print('Using CUDA..')

    if args.ce:
        print(f"Using CE loss")
        criterion = CrossEntropyLoss()
    else:
        print(f"Using Truncated loss")
        criterion = TruncatedLoss(trainset_size=len(train_dataset)).cuda()

    if args.opl:
        print(f"Using OP loss")
        aux_criterion = OrthogonalProjectionLoss()
    else:
        aux_criterion = None
    optimizer = optim.SGD(net.parameters(),
                          lr=args.lr,
                          momentum=0.9,
                          weight_decay=args.decay)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=args.schedule,
                                                     gamma=args.gamma)

    if not os.path.exists(log_name):
        with open(log_name, 'w') as logfile:
            log_writer = csv.writer(logfile, delimiter=',')
            log_writer.writerow(
                ['epoch', 'train loss', 'train acc', 'test loss', 'test acc'])

    for epoch in range(start_epoch, args.epochs):

        train_loss, train_acc = train(epoch, trainloader, net, criterion,
                                      optimizer, aux_criterion, args)
        test_loss, test_acc = test(epoch, testloader, net, criterion)

        with open(log_name, 'a') as logfile:
            log_writer = csv.writer(logfile, delimiter=',')
            log_writer.writerow(
                [epoch, train_loss, train_acc, test_loss, test_acc])
        scheduler.step()
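
OrthogonalProjectionLoss is only referenced here. A rough sketch of the OPL idea, with the caveat that the repo's exact weighting and masking may differ:

import torch
import torch.nn.functional as F

class OrthogonalProjectionLoss(torch.nn.Module):
    # Pull same-class features toward cosine similarity 1 and push
    # different-class features toward orthogonality. Assumes each class
    # appears at least twice in the batch.
    def forward(self, features, labels):
        feats = F.normalize(features, p=2, dim=1)
        sim = feats @ feats.t()
        eye = torch.eye(len(labels), dtype=torch.bool, device=labels.device)
        same = (labels.unsqueeze(0) == labels.unsqueeze(1)) & ~eye
        diff = labels.unsqueeze(0) != labels.unsqueeze(1)
        s = sim[same].mean()          # mean intra-class similarity
        d = sim[diff].abs().mean()    # mean inter-class |similarity|
        return (1.0 - s) + d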
Example #7
def main():

    use_cuda = torch.cuda.is_available()
    global best_acc

    # load dataset

    if args.dataset == 'cifar10':
        num_classes = 10
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.491, 0.482, 0.447), (0.247, 0.243, 0.262)),
        ])

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.491, 0.482, 0.447), (0.247, 0.243, 0.262)),
        ])

        train_dataset = CIFAR10(root='../data/',
                                download=True,
                                train=True,
                                transform=transform_train,
                                noise_type=args.noise_type,
                                noise_rate=args.noise_rate)

        test_dataset = CIFAR10(root='../data/',
                               download=True,
                               train=False,
                               transform=transform_test,
                               noise_type=args.noise_type,
                               noise_rate=args.noise_rate)

    if args.dataset == 'cifar100':
        num_classes = 100
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.507, 0.487, 0.441), (0.267, 0.256, 0.276)),
        ])

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.507, 0.487, 0.441), (0.267, 0.256, 0.276))
        ])
        train_dataset = CIFAR100(root='../data/',
                                 download=True,
                                 train=True,
                                 transform=transform_train,
                                 noise_type=args.noise_type,
                                 noise_rate=args.noise_rate)

        test_dataset = CIFAR100(root='../data/',
                                download=True,
                                train=False,
                                transform=transform_test,
                                noise_type=args.noise_type,
                                noise_rate=args.noise_rate)
    testloader = torch.utils.data.DataLoader(test_dataset,
                                             batch_size=100,
                                             shuffle=False,
                                             num_workers=4)

    trainloader = torch.utils.data.DataLoader(train_dataset,
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              num_workers=4)
    # Model
    if args.resume:
        # Load checkpoint.
        print('==> Resuming from checkpoint..')
        assert os.path.isdir(
            'checkpoint'), 'Error: no checkpoint directory found!'
        checkpoint = torch.load('./checkpoint/ckpt.t7.' + args.sess)
        net = checkpoint['net']
        best_acc = checkpoint['acc']
        start_epoch = checkpoint['epoch'] + 1
        torch.set_rng_state(checkpoint['rng_state'])
    else:
        print('==> Building model.. (Default : ResNet32)')
        start_epoch = 0
        if args.model_type == "resnet32":
            from models.resnet32 import ResNet32
            net = ResNet32(num_classes=num_classes)
        elif args.model_type == "resent18":
            from models.resnet_imselar import resnet18
            net = resnet18(num_classes=num_classes)
        else:
            net = ResNet34(num_classes)

    result_folder = './results/'
    if not os.path.exists(result_folder):
        os.makedirs(result_folder)

    logname = result_folder + net.__class__.__name__ + \
        '_' + args.sess + '.csv'

    if use_cuda:
        net.cuda()
        # net = torch.nn.DataParallel(net)
        # print('Using', torch.cuda.device_count(), 'GPUs.')
        cudnn.benchmark = True
        print('Using CUDA..')

    criterion = TruncatedLoss(trainset_size=len(train_dataset)).cuda()
    optimizer = optim.SGD(net.parameters(),
                          lr=args.lr,
                          momentum=0.9,
                          weight_decay=args.decay)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=args.schedule,
                                                     gamma=args.gamma)

    if not os.path.exists(logname):
        with open(logname, 'w') as logfile:
            logwriter = csv.writer(logfile, delimiter=',')
            logwriter.writerow(
                ['epoch', 'train loss', 'train acc', 'test loss', 'test acc'])

    for epoch in range(start_epoch, args.epochs):

        train_loss, train_acc = train(epoch, trainloader, net, criterion,
                                      optimizer)
        test_loss, test_acc = test(epoch, testloader, net, criterion)

        with open(logname, 'a') as logfile:
            logwriter = csv.writer(logfile, delimiter=',')
            logwriter.writerow(
                [epoch, train_loss, train_acc, test_loss, test_acc])
        scheduler.step()
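
TruncatedLoss itself is not shown in these snippets; it is the truncated generalized cross entropy of Zhang and Sabuncu (2018). A minimal sketch of the untruncated L_q core, for orientation only, since the real class also keeps a per-sample weight buffer (hence the trainset_size argument):

import torch
import torch.nn.functional as F

class GCELoss(torch.nn.Module):
    # L_q = (1 - p_y ** q) / q interpolates between cross entropy
    # (q -> 0) and MAE (q = 1); q = 0.7 is the paper's default.
    def __init__(self, q=0.7):
        super().__init__()
        self.q = q

    def forward(self, logits, targets):
        p_y = F.softmax(logits, dim=1).gather(1, targets.unsqueeze(1))
        return ((1.0 - p_y.squeeze(1) ** self.q) / self.q).mean()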
Example #8
def main():
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')

    parser.add_argument('--result_dir',
                        type=str,
                        help='dir to save result txt files',
                        default='results/')
    parser.add_argument('--noise_rate',
                        type=float,
                        help='corruption rate, should be less than 1',
                        default=0.5)
    parser.add_argument('--forget_rate',
                        type=float,
                        help='forget rate',
                        default=None)
    parser.add_argument('--noise_type',
                        type=str,
                        help='[pairflip, symmetric]',
                        default='symmetric')
    parser.add_argument(
        '--num_gradual',
        type=int,
        default=10,
        help=
        'how many epochs for linear drop rate, can be 5, 10, 15. This parameter is equal to Tk for R(T) in Co-teaching paper.'
    )
    parser.add_argument(
        '--exponent',
        type=float,
        default=1,
        help=
        'exponent of the forget rate, can be 0.5, 1, 2. This parameter is equal to c in Tc for R(T) in Co-teaching paper.'
    )
    parser.add_argument('--top_bn', action='store_true')
    parser.add_argument('--dataset',
                        type=str,
                        help='mnist, cifar10, or cifar100',
                        default='mnist')
    parser.add_argument('--n_epoch', type=int, default=200)
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--print_freq', type=int, default=50)
    parser.add_argument('--num_workers',
                        type=int,
                        default=4,
                        help='how many subprocesses to use for data loading')
    parser.add_argument('--num_iter_per_epoch', type=int, default=400)
    parser.add_argument('--epoch_decay_start', type=int, default=80)
    parser.add_argument('--eps', type=float, default=9.3)

    parser.add_argument('--batch-size',
                        type=int,
                        default=600,
                        metavar='N',
                        help='input batch size for training (default: 600)')
    parser.add_argument('--test-batch-size',
                        type=int,
                        default=4000,
                        metavar='N',
                        help='input batch size for testing (default: 4000)')
    parser.add_argument('--lr',
                        type=float,
                        default=0.005,
                        metavar='LR',
                        help='learning rate (default: 0.005)')
    parser.add_argument('--momentum',
                        type=float,
                        default=0.9,
                        metavar='M',
                        help='SGD momentum (default: 0.9)')
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')
    parser.add_argument(
        '--log-interval',
        type=int,
        default=100,
        metavar='N',
        help='how many batches to wait before logging training status')

    parser.add_argument('--save-model',
                        action='store_true',
                        default=False,
                        help='For Saving the current Model')
    args = parser.parse_args()

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    batch_size = args.batch_size

    if args.dataset == 'mnist':
        input_channel = 1
        num_classes = 10
        args.top_bn = False
        args.epoch_decay_start = 80
        args.n_epoch = 200
        train_dataset = MNIST(root='./data/',
                              download=True,
                              train=True,
                              transform=transforms.ToTensor(),
                              noise_type=args.noise_type,
                              noise_rate=args.noise_rate)

        test_dataset = MNIST(root='./data/',
                             download=True,
                             train=False,
                             transform=transforms.ToTensor(),
                             noise_type=args.noise_type,
                             noise_rate=args.noise_rate)

    if args.dataset == 'cifar10':
        input_channel = 3
        num_classes = 10
        args.top_bn = False
        args.epoch_decay_start = 80
        args.n_epoch = 200
        train_dataset = CIFAR10(root='./data/',
                                download=True,
                                train=True,
                                transform=transforms.ToTensor(),
                                noise_type=args.noise_type,
                                noise_rate=args.noise_rate)

        test_dataset = CIFAR10(root='./data/',
                               download=True,
                               train=False,
                               transform=transforms.ToTensor(),
                               noise_type=args.noise_type,
                               noise_rate=args.noise_rate)

    if args.dataset == 'cifar100':
        input_channel = 3
        num_classes = 100
        args.top_bn = False
        args.epoch_decay_start = 100
        args.n_epoch = 200
        train_dataset = CIFAR100(root='./data/',
                                 download=True,
                                 train=True,
                                 transform=transforms.ToTensor(),
                                 noise_type=args.noise_type,
                                 noise_rate=args.noise_rate)

        test_dataset = CIFAR100(root='./data/',
                                download=True,
                                train=False,
                                transform=transforms.ToTensor(),
                                noise_type=args.noise_type,
                                noise_rate=args.noise_rate)

    if args.forget_rate is None:
        forget_rate = args.noise_rate
    else:
        forget_rate = args.forget_rate

    noise_or_not = train_dataset.noise_or_not
    # Data Loader (Input Pipeline)
    print('loading dataset...')
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=batch_size,
                                               num_workers=args.num_workers,
                                               drop_last=True,
                                               shuffle=True)

    test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                              batch_size=batch_size,
                                              num_workers=args.num_workers,
                                              drop_last=True,
                                              shuffle=False)
    # Define models
    print('building model...')
    cnn1 = CNN(input_channel=input_channel, n_outputs=num_classes + 1)
    cnn1.cuda()
    print(cnn1.parameters)
    #optimizer1 = torch.optim.SGD(cnn1.parameters(), lr=learning_rate)
    optimizer = torch.optim.SGD(cnn1.parameters(),
                                lr=args.lr,
                                momentum=args.momentum)
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    save_dir = args.result_dir + '/' + args.dataset + '/new_loss/'

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    model_str = args.dataset + '_new_loss_' + args.noise_type + '_' + str(
        args.noise_rate) + '_' + str(args.eps) + ' '
    nowTime = datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S')
    txtfile = save_dir + "/" + model_str + nowTime + ".txt"
    if os.path.exists(txtfile):
        os.system('mv %s %s' % (txtfile, txtfile + ".bak-%s" % nowTime))
    acc = []
    loss = []
    loss_pure = []
    loss_corrupt = []
    for epoch in range(1, args.n_epoch + 1):
        #         if epoch<20:
        #             loss.append(train(args, model, device, train_loader, optimizer, epoch, eps=9.5))
        #         else:
        #             loss.append(train(args, model, device, train_loader, optimizer, epoch, eps=5))
        #loss.append(train(args, model, device, train_loader, optimizer, epoch, eps=9.3))
        l1 = train(args,
                   cnn1,
                   device,
                   train_loader,
                   optimizer,
                   epoch,
                   eps=args.eps)
        loss.append(l1)
        acc.append(test(args, cnn1, device, test_loader))
        with open(txtfile, "a") as myfile:
            myfile.write(str(int(epoch)) + ': ' + str(acc[-1]) + "\n")

    name = str(args.dataset) + " " + str(args.noise_type) + " " + str(
        args.noise_rate)
    np.save(name + "acc.npy", acc)
    np.save(name + "loss.npy", loss)
Example #9
                           noise_rate=args.noise_rate
                           )

if args.dataset == 'cifar100':
    input_channel = 3
    num_classes = 100
    init_epoch = 5
    args.epoch_decay_start = 100
    # args.n_epoch = 200
    filter_outlier = False
    args.model_type = "cnn"

    train_dataset = CIFAR100(root='./../Co-correcting_plus/data/cifar100/',
                             download=False,
                             train=True,
                             transform=transforms.ToTensor(),
                             noise_type=args.noise_type,
                             noise_rate=args.noise_rate
                             )

    test_dataset = CIFAR100(root='./../Co-correcting_plus/data/cifar100/',
                            download=False,
                            train=False,
                            transform=transforms.ToTensor(),
                            noise_type=args.noise_type,
                            noise_rate=args.noise_rate
                            )

if args.dataset == 'isic':
    input_channel = 3
    num_classes = 2
Example #10
def get_dataset(args):

    ### color augmentation ###
    color_jitter = transforms.ColorJitter(0.8 * args.color_jitter_strength,
                                          0.8 * args.color_jitter_strength,
                                          0.8 * args.color_jitter_strength,
                                          0.2 * args.color_jitter_strength)
    rnd_color_jitter = transforms.RandomApply([color_jitter], p=0.8)
    rnd_gray = transforms.RandomGrayscale(p=0.2)

    learning_type = args.train_type

    if args.dataset == 'cifar-10':

        if learning_type == 'contrastive':
            transform_train = transforms.Compose([
                rnd_color_jitter,
                rnd_gray,
                transforms.RandomHorizontalFlip(),
                transforms.RandomResizedCrop(32),
                transforms.ToTensor(),
            ])

            transform_test = transform_train

        elif learning_type == 'linear_eval':
            transform_train = transforms.Compose([
                rnd_color_jitter,
                rnd_gray,
                transforms.RandomHorizontalFlip(),
                transforms.RandomResizedCrop(32),
                transforms.ToTensor(),
            ])

            transform_test = transforms.Compose([
                transforms.ToTensor(),
            ])

        elif learning_type == 'test':
            transform_train = transforms.Compose([
                transforms.RandomHorizontalFlip(),
                transforms.RandomResizedCrop(32),
                transforms.ToTensor(),
            ])

            transform_test = transforms.Compose([
                transforms.ToTensor(),
            ])
        else:
            raise ValueError('wrong learning type')

        train_dst = CIFAR10(root='./Data',
                            train=True,
                            download=True,
                            transform=transform_train,
                            contrastive_learning=learning_type)
        val_dst = CIFAR10(root='./Data',
                          train=False,
                          download=True,
                          transform=transform_test,
                          contrastive_learning=learning_type)

        if learning_type == 'contrastive':
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dst,
                num_replicas=args.ngpu,
                rank=args.local_rank,
            )
            train_loader = torch.utils.data.DataLoader(
                train_dst,
                batch_size=args.batch_size,
                num_workers=4,
                pin_memory=False,
                shuffle=(train_sampler is None),
                sampler=train_sampler,
            )

            val_loader = torch.utils.data.DataLoader(
                val_dst,
                batch_size=100,
                num_workers=4,
                pin_memory=False,
                shuffle=False,
            )

            return train_loader, train_dst, val_loader, val_dst, train_sampler
        else:
            train_loader = torch.utils.data.DataLoader(
                train_dst,
                batch_size=args.batch_size,
                shuffle=True,
                num_workers=4)
            val_batch = 100
            val_loader = torch.utils.data.DataLoader(val_dst,
                                                     batch_size=val_batch,
                                                     shuffle=False,
                                                     num_workers=4)

            return train_loader, train_dst, val_loader, val_dst

    if args.dataset == 'cifar-100':

        if learning_type == 'contrastive':
            transform_train = transforms.Compose([
                rnd_color_jitter, rnd_gray,
                transforms.RandomHorizontalFlip(),
                transforms.RandomResizedCrop(32),
                transforms.ToTensor()
            ])

            transform_test = transform_train

        elif learning_type == 'linear_eval':
            transform_train = transforms.Compose([
                rnd_color_jitter, rnd_gray,
                transforms.RandomHorizontalFlip(),
                transforms.RandomResizedCrop(32),
                transforms.ToTensor()
            ])

            transform_test = transforms.Compose([transforms.ToTensor()])

        elif learning_type == 'test':
            transform_train = transforms.Compose([
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor()
            ])

            transform_test = transforms.Compose([transforms.ToTensor()])
        else:
            raise ValueError('wrong learning type')

        train_dst = CIFAR100(root='./Data',
                             train=True,
                             download=True,
                             transform=transform_train,
                             contrastive_learning=learning_type)
        val_dst = CIFAR100(root='./Data',
                           train=False,
                           download=True,
                           transform=transform_test,
                           contrastive_learning=learning_type)

        if learning_type == 'contrastive':
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dst,
                num_replicas=args.ngpu,
                rank=args.local_rank,
            )
            train_loader = torch.utils.data.DataLoader(
                train_dst,
                batch_size=args.batch_size,
                num_workers=4,
                pin_memory=True,
                shuffle=(train_sampler is None),
                sampler=train_sampler,
            )

            val_loader = torch.utils.data.DataLoader(
                val_dst,
                batch_size=100,
                num_workers=4,
                pin_memory=True,
            )
            return train_loader, train_dst, val_loader, val_dst, train_sampler

        else:
            train_loader = torch.utils.data.DataLoader(
                train_dst,
                batch_size=args.batch_size,
                shuffle=True,
                num_workers=4)

            val_loader = torch.utils.data.DataLoader(val_dst,
                                                     batch_size=100,
                                                     shuffle=False,
                                                     num_workers=4)

            return train_loader, train_dst, val_loader, val_dst
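
When the contrastive branch returns a DistributedSampler, the caller is expected to reshuffle it every epoch; a minimal consumption sketch, with args.epochs assumed:

# Consumption sketch for the contrastive branch above; args.epochs is an
# assumed attribute, and set_epoch() reshuffles shards across replicas.
train_loader, train_dst, val_loader, val_dst, train_sampler = get_dataset(args)

for epoch in range(args.epochs):
    train_sampler.set_epoch(epoch)
    for batch in train_loader:
        ...  # training step goes here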
Example #11
                            train=True,
                            transform=transform,
                            noise_type=args.noise_type,
                            noise_rate=args.noise_rate)

    test_dataset = CIFAR10(
        root='./data/',
        download=True,
        train=False,
        transform=transform,
    )
    nclass = 10
elif args.dataset == 'cifar100':
    train_dataset = CIFAR100(root='./data/',
                             download=True,
                             train=True,
                             transform=transform,
                             noise_type=args.noise_type,
                             noise_rate=args.noise_rate)

    test_dataset = CIFAR100(root='./data/',
                            download=True,
                            train=False,
                            transform=transform)
    nclass = 100

logdir = None
if args.save_model:
    path = os.path.join(
        'results',
        args.dataset + '_' + args.noise_type + '_' + str(args.noise_rate))
    if not (os.path.exists(path)):
Example #12
def main():
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')

    parser.add_argument('--result_dir',
                        type=str,
                        help='dir to save result txt files',
                        default='results/')
    parser.add_argument('--noise_rate',
                        type=float,
                        help='corruption rate, should be less than 1',
                        default=0.5)
    parser.add_argument('--forget_rate',
                        type=float,
                        help='forget rate',
                        default=None)
    parser.add_argument('--noise_type',
                        type=str,
                        help='[pairflip, symmetric]',
                        default='symmetric')
    parser.add_argument(
        '--num_gradual',
        type=int,
        default=10,
        help=
        'how many epochs for linear drop rate, can be 5, 10, 15. This parameter is equal to Tk for R(T) in Co-teaching paper.'
    )
    parser.add_argument(
        '--exponent',
        type=float,
        default=1,
        help=
        'exponent of the forget rate, can be 0.5, 1, 2. This parameter is equal to c in Tc for R(T) in Co-teaching paper.'
    )
    parser.add_argument('--top_bn', action='store_true')
    parser.add_argument('--dataset',
                        type=str,
                        help='mnist, cifar10, or cifar100',
                        default='mnist')
    parser.add_argument('--n_epoch', type=int, default=10)
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--print_freq', type=int, default=50)
    parser.add_argument('--num_workers',
                        type=int,
                        default=4,
                        help='how many subprocesses to use for data loading')
    parser.add_argument('--num_iter_per_epoch', type=int, default=400)
    parser.add_argument('--epoch_decay_start', type=int, default=80)
    parser.add_argument('--eps',
                        type=float,
                        default=9.9,
                        help='eps parameter passed to train(); also appears in the criteria computation below')

    parser.add_argument('--batch-size',
                        type=int,
                        default=1000,
                        metavar='N',
                        help='input batch size for training (default: 1000)')
    parser.add_argument('--test-batch-size',
                        type=int,
                        default=4000,
                        metavar='N',
                        help='input batch size for testing (default: 4000)')
    parser.add_argument('--lr',
                        type=float,
                        default=0.005,
                        metavar='LR',
                        help='learning rate (default: 0.005)')
    parser.add_argument('--momentum',
                        type=float,
                        default=0.9,
                        metavar='M',
                        help='SGD momentum (default: 0.9)')
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')
    parser.add_argument(
        '--log-interval',
        type=int,
        default=100,
        metavar='N',
        help='how many batches to wait before logging training status')

    parser.add_argument('--save-model',
                        action='store_true',
                        default=False,
                        help='For Saving the current Model')
    args = parser.parse_args()

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    batch_size = args.batch_size

    if args.dataset == 'mnist':
        input_channel = 1
        num_classes = 10
        args.top_bn = False
        args.epoch_decay_start = 80
        args.n_epoch = 200
        train_dataset = MNIST(root='./data/',
                              download=True,
                              train=True,
                              transform=transforms.ToTensor(),
                              noise_type=args.noise_type,
                              noise_rate=args.noise_rate)
        train_size = int(0.9 * len(train_dataset))
        valid_size = len(train_dataset) - train_size
        train_dataset, valid_dataset = random_split(train_dataset,
                                                    [train_size, valid_size])

        test_dataset = MNIST(root='./data/',
                             download=True,
                             train=False,
                             transform=transforms.ToTensor(),
                             noise_type=args.noise_type,
                             noise_rate=args.noise_rate)

    elif args.dataset == 'cifar10':
        input_channel = 3
        num_classes = 10
        args.top_bn = False
        args.epoch_decay_start = 80
        args.n_epoch = 200
        train_dataset = CIFAR10(root='./data/',
                                download=True,
                                train=True,
                                transform=transforms.ToTensor(),
                                noise_type=args.noise_type,
                                noise_rate=args.noise_rate)
        # Hold out 10% for validation, as in the MNIST branch; without this,
        # valid_loader below would fail with a NameError for CIFAR runs.
        train_size = int(0.9 * len(train_dataset))
        valid_size = len(train_dataset) - train_size
        train_dataset, valid_dataset = random_split(train_dataset,
                                                    [train_size, valid_size])

        test_dataset = CIFAR10(root='./data/',
                               download=True,
                               train=False,
                               transform=transforms.ToTensor(),
                               noise_type=args.noise_type,
                               noise_rate=args.noise_rate)

    elif args.dataset == 'cifar100':
        input_channel = 3
        num_classes = 100
        args.top_bn = False
        args.epoch_decay_start = 100
        args.n_epoch = 200
        train_dataset = CIFAR100(root='./data/',
                                 download=True,
                                 train=True,
                                 transform=transforms.ToTensor(),
                                 noise_type=args.noise_type,
                                 noise_rate=args.noise_rate)
        # Same 90/10 validation split as the other branches.
        train_size = int(0.9 * len(train_dataset))
        valid_size = len(train_dataset) - train_size
        train_dataset, valid_dataset = random_split(train_dataset,
                                                    [train_size, valid_size])

        test_dataset = CIFAR100(root='./data/',
                                download=True,
                                train=False,
                                transform=transforms.ToTensor(),
                                noise_type=args.noise_type,
                                noise_rate=args.noise_rate)

    if args.forget_rate is None:
        forget_rate = args.noise_rate
    else:
        forget_rate = args.forget_rate
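    # Sketch (assumption; not shown in this excerpt): in the Co-teaching
    # reference code, num_gradual (Tk) and exponent (c) turn forget_rate
    # into the drop-rate schedule R(T), roughly:
    #   rate_schedule = np.ones(args.n_epoch) * forget_rate
    #   rate_schedule[:args.num_gradual] = np.linspace(
    #       0, forget_rate ** args.exponent, args.num_gradual)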

    #noise_or_not = train_dataset.noise_or_not
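    # (After random_split above, train_dataset is a torch.utils.data.Subset,
    # which does not forward the wrapped dataset's noise_or_not attribute;
    # hence the line above is commented out.)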

    print('loading dataset...')
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=args.batch_size,
                                               num_workers=args.num_workers,
                                               drop_last=True,
                                               shuffle=True)

    valid_loader = torch.utils.data.DataLoader(dataset=valid_dataset,
                                               batch_size=args.batch_size,
                                               num_workers=args.num_workers,
                                               drop_last=True,
                                               shuffle=True)

    test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                              batch_size=batch_size,
                                              num_workers=args.num_workers,
                                              drop_last=True,
                                              shuffle=False)
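    # Caveat: drop_last=True on the test loader discards up to batch_size - 1
    # test samples, so test accuracy is measured on a slightly truncated set.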

    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    cnn1 = Net().to(device)
    optimizer = torch.optim.SGD(cnn1.parameters(),
                                lr=args.lr,
                                momentum=args.momentum)

    acc = []
    loss = []
    loss_pure = []
    loss_corrupt = []
    out = []
    eee = 1 - args.noise_rate
    criteria = (-1) * (eee * np.log(eee) + (1 - eee) * np.log(
        (1 - eee) / (args.eps - 1)))
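    # criteria appears to be the expected cross-entropy under symmetric label
    # noise: a label stays clean with probability eee = 1 - noise_rate and is
    # flipped uniformly over roughly (eps - 1) other classes otherwise. It is
    # computed here but not used later in this excerpt.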
    last = float("inf")
    count = 0

    for epoch in range(1, 101):  # fixed 100 epochs; args.n_epoch set above is not used here
        l1, out10 = train(args,
                          cnn1,
                          device,
                          train_loader,
                          optimizer,
                          epoch,
                          eps=args.eps,
                          nums=num_classes)
        cur, out101 = train(args,
                            cnn1,
                            device,
                            valid_loader,
                            optimizer,
                            epoch,
                            eps=args.eps,
                            nums=num_classes)
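        # Note: this reuses train() on the validation split; if train() steps
        # the optimizer, the model is also updated on validation data here.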
        #if cur>last:
        #    count+=1
        #else:
        #    last=cur
        #    count=0
        #if count >= 4:
        #    break;
        loss.append(cur)
        out.append(out10)
        acc.append(test(args, cnn1, device, test_loader, num_classes))

    name = str(args.dataset) + " " + str(args.noise_type) + " " + str(
        args.noise_rate) + " " + str(args.eps) + " " + str(args.seed)
    np.save("vl_" + name + " acc.npy", acc)
    np.save("vl_" + name + " loss.npy", loss)
Example no. 13
0
    test_dataset = CIFAR10(root=args.result_dir,
                           download=True,
                           train=False,
                           transform=transforms.ToTensor(),
                           noise_type=args.noise_type,
                           noise_rate=args.noise_rate)

if args.dataset == 'cifar100':
    input_channel = 3
    num_classes = 100
    args.top_bn = False
    args.epoch_decay_start = 100
    train_dataset = CIFAR100(root=args.result_dir,
                             download=True,
                             train=True,
                             transform=transformer,
                             noise_type=args.noise_type,
                             noise_rate=args.noise_rate)

    test_dataset = CIFAR100(root=args.result_dir,
                            download=True,
                            train=False,
                            transform=transforms.ToTensor(),
                            noise_type=args.noise_type,
                            noise_rate=args.noise_rate)
if args.forget_rate is None:
    forget_rate = args.noise_rate
else:
    forget_rate = args.forget_rate

noise_or_not = train_dataset.noise_or_not
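# noise_or_not in these noisy-label dataset wrappers is typically a boolean
# array over training indices, with True marking labels that stayed clean
# after corruption; downstream code uses it to evaluate sample selection.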