Ejemplo n.º 1
0
def main():
    """Train a network on noisy-label CIFAR-10/100 with TruncatedLoss.

    Configuration comes from the module-level ``args`` namespace. Builds the
    noisy train/test datasets and loaders, creates (or resumes) the model,
    then runs the train/test loop, appending per-epoch metrics to a CSV log
    under ``./results/``.
    """
    use_cuda = torch.cuda.is_available()
    global best_acc  # restored from the checkpoint when resuming

    # ---- dataset -------------------------------------------------------
    if args.dataset == 'cifar10':
        num_classes = 10
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            # CIFAR-10 per-channel mean/std.
            transforms.Normalize((0.491, 0.482, 0.447), (0.247, 0.243, 0.262)),
        ])

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.491, 0.482, 0.447), (0.247, 0.243, 0.262)),
        ])

        train_dataset = CIFAR10(root='../data/',
                                download=True,
                                train=True,
                                transform=transform_train,
                                noise_type=args.noise_type,
                                noise_rate=args.noise_rate)

        test_dataset = CIFAR10(root='../data/',
                               download=True,
                               train=False,
                               transform=transform_test,
                               noise_type=args.noise_type,
                               noise_rate=args.noise_rate)

    elif args.dataset == 'cifar100':
        num_classes = 100
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            # CIFAR-100 per-channel mean/std.
            transforms.Normalize((0.507, 0.487, 0.441), (0.267, 0.256, 0.276)),
        ])

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.507, 0.487, 0.441), (0.267, 0.256, 0.276))
        ])
        train_dataset = CIFAR100(root='../data/',
                                 download=True,
                                 train=True,
                                 transform=transform_train,
                                 noise_type=args.noise_type,
                                 noise_rate=args.noise_rate)

        test_dataset = CIFAR100(root='../data/',
                                download=True,
                                train=False,
                                transform=transform_test,
                                noise_type=args.noise_type,
                                noise_rate=args.noise_rate)
    else:
        # Fail fast: the original fell through here and later crashed with a
        # NameError on num_classes / train_dataset.
        raise ValueError('Unsupported dataset: {}'.format(args.dataset))

    testloader = torch.utils.data.DataLoader(test_dataset,
                                             batch_size=100,
                                             shuffle=False,
                                             num_workers=4)

    trainloader = torch.utils.data.DataLoader(train_dataset,
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              num_workers=4)

    # ---- model ---------------------------------------------------------
    if args.resume:
        # Load checkpoint (net object, best accuracy, epoch and RNG state).
        print('==> Resuming from checkpoint..')
        assert os.path.isdir(
            'checkpoint'), 'Error: no checkpoint directory found!'
        checkpoint = torch.load('./checkpoint/ckpt.t7.' + args.sess)
        net = checkpoint['net']
        best_acc = checkpoint['acc']
        start_epoch = checkpoint['epoch'] + 1
        torch.set_rng_state(checkpoint['rng_state'])
    else:
        print('==> Building model.. (Default : ResNet32)')
        start_epoch = 0
        if args.model_type == "resnet32":
            from models.resnet32 import ResNet32
            net = ResNet32(num_classes=num_classes)
        # NOTE(review): "resent18" looks like a typo for "resnet18" but is
        # kept as-is so existing command lines keep working.
        elif args.model_type == "resent18":
            from models.resnet_imselar import resnet18
            net = resnet18(num_classes=num_classes)
        else:
            net = ResNet34(num_classes)

    result_folder = './results/'
    if not os.path.exists(result_folder):
        os.makedirs(result_folder)

    # One CSV log per model class / session.
    logname = result_folder + net.__class__.__name__ + \
        '_' + args.sess + '.csv'

    if use_cuda:
        net.cuda()
        cudnn.benchmark = True
        print('Using CUDA..')

    criterion = TruncatedLoss(trainset_size=len(train_dataset))
    if use_cuda:
        # Only move the loss to the GPU when one exists; the original
        # unconditional .cuda() crashed on CPU-only machines.
        criterion = criterion.cuda()
    # NOTE(review): net.params() suggests a MetaModule-style model; plain
    # nn.Modules expose .parameters() — confirm against the model classes.
    optimizer = optim.SGD(net.params(),
                          lr=args.lr,
                          momentum=0.9,
                          weight_decay=args.decay)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=args.schedule,
                                                     gamma=args.gamma)

    # Write the CSV header only once per log file.
    if not os.path.exists(logname):
        with open(logname, 'w') as logfile:
            logwriter = csv.writer(logfile, delimiter=',')
            logwriter.writerow(
                ['epoch', 'train loss', 'train acc', 'test loss', 'test acc'])

    for epoch in range(start_epoch, args.epochs):

        train_loss, train_acc = train(epoch, trainloader, net, criterion,
                                      optimizer)
        test_loss, test_acc = test(epoch, testloader, net, criterion)

        with open(logname, 'a') as logfile:
            logwriter = csv.writer(logfile, delimiter=',')
            logwriter.writerow(
                [epoch, train_loss, train_acc, test_loss, test_acc])
        scheduler.step()
Ejemplo n.º 2
0
print('==> Preparing data..')
# Training-time augmentation: random crop with padding plus horizontal flip,
# then per-channel normalization with the usual CIFAR-10 mean/std.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Test-time pipeline: normalization only, no augmentation.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# noise_type="clean" requests the unperturbed labels from this project's
# noisy-label CIFAR10 wrapper.
trainset = CIFAR10(root='./data',
                   train=True,
                   download=True,
                   transform=transform_train,
                   noise_type="clean")
# NOTE(review): shuffle=False on the *training* loader is unusual — confirm
# this is intentional (e.g. for a fixed sample order) rather than a bug.
trainloader = torch.utils.data.DataLoader(trainset,
                                          batch_size=args.bs,
                                          shuffle=False,
                                          num_workers=16,
                                          drop_last=True)

testset = CIFAR10(root='./data',
                  train=False,
                  download=True,
                  transform=transform_test,
                  noise_type="clean")
testloader = torch.utils.data.DataLoader(testset,
                                         batch_size=100,
Ejemplo n.º 3
0
                               download=True,  
                               train=False, 
                               transform=transforms.ToTensor(),
                               noise_type=args.noise_type,
                               noise_rate=args.noise_rate
                                )
    
# Per-dataset configuration: channel count, warm-up epoch count, class count
# and the noisy-label train/test datasets (raw tensors, no augmentation).
if args.dataset=='cifar10':
    input_channel=3
    init_epoch = 20
    num_classes = 10
    args.n_epoch = 200
    train_dataset = CIFAR10(root='./data/',
                                download=True,  
                                train=True, 
                                transform=transforms.ToTensor(),
                                noise_type=args.noise_type,
                                noise_rate=args.noise_rate
                                )
    
    test_dataset = CIFAR10(root='./data/',
                                download=True,  
                                train=False, 
                                transform=transforms.ToTensor(),
                                noise_type=args.noise_type,
                                noise_rate=args.noise_rate
                                )

if args.dataset=='cifar100':
    input_channel=3
    # Shorter warm-up for CIFAR-100 than for CIFAR-10 (20 epochs above).
    init_epoch = 5
def main():
    """Train a PreResNet-18 on MNIST/CIFAR with noisy labels.

    Parses the command line, builds the noise-corrupted dataset and loaders,
    then runs ``train``/``test`` for ``args.n_epoch`` epochs while collecting
    per-epoch loss and accuracy.
    """
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')

    parser.add_argument('--result_dir',
                        type=str,
                        help='dir to save result txt files',
                        default='results/')
    parser.add_argument('--noise_rate',
                        type=float,
                        help='corruption rate, should be less than 1',
                        default=0.5)
    parser.add_argument('--forget_rate',
                        type=float,
                        help='forget rate',
                        default=None)
    parser.add_argument('--noise_type',
                        type=str,
                        help='[pairflip, symmetric]',
                        default='symmetric')
    parser.add_argument(
        '--num_gradual',
        type=int,
        default=10,
        help=
        'how many epochs for linear drop rate, can be 5, 10, 15. This parameter is equal to Tk for R(T) in Co-teaching paper.'
    )
    parser.add_argument(
        '--exponent',
        type=float,
        default=1,
        help=
        'exponent of the forget rate, can be 0.5, 1, 2. This parameter is equal to c in Tc for R(T) in Co-teaching paper.'
    )
    parser.add_argument('--top_bn', action='store_true')
    parser.add_argument('--dataset',
                        type=str,
                        help='mnist, cifar10, or cifar100',
                        default='mnist')
    parser.add_argument('--n_epoch', type=int, default=200)
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--print_freq', type=int, default=50)
    parser.add_argument('--num_workers',
                        type=int,
                        default=4,
                        help='how many subprocesses to use for data loading')
    parser.add_argument('--num_iter_per_epoch', type=int, default=400)
    parser.add_argument('--epoch_decay_start', type=int, default=80)
    parser.add_argument('--eps', type=float, default=9.9)

    parser.add_argument('--batch-size',
                        type=int,
                        default=128,
                        metavar='N',
                        help='input batch size for training (default: 256)')
    parser.add_argument('--test-batch-size',
                        type=int,
                        default=4000,
                        metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--lr',
                        type=float,
                        default=0.01,
                        metavar='LR',
                        help='learning rate (default: 0.01)')
    parser.add_argument('--momentum',
                        type=float,
                        default=0.9,
                        metavar='M',
                        help='SGD momentum (default: 0.5)')
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')
    parser.add_argument(
        '--log-interval',
        type=int,
        default=100,
        metavar='N',
        help='how many batches to wait before logging training status')

    parser.add_argument('--save-model',
                        action='store_true',
                        default=False,
                        help='For Saving the current Model')
    args = parser.parse_args()

    # Seed both CPU and CUDA RNGs for reproducibility.
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    batch_size = args.batch_size

    # Per-dataset configuration and noisy train/test datasets.
    if args.dataset == 'mnist':
        input_channel = 1
        num_classes = 10
        args.top_bn = False
        args.epoch_decay_start = 80
        args.n_epoch = 200
        train_dataset = MNIST(root='./data/',
                              download=True,
                              train=True,
                              transform=transforms.ToTensor(),
                              noise_type=args.noise_type,
                              noise_rate=args.noise_rate)

        test_dataset = MNIST(root='./data/',
                             download=True,
                             train=False,
                             transform=transforms.ToTensor(),
                             noise_type=args.noise_type,
                             noise_rate=args.noise_rate)

    elif args.dataset == 'cifar10':
        input_channel = 3
        num_classes = 10
        args.top_bn = False
        args.epoch_decay_start = 80
        args.n_epoch = 200
        train_dataset = CIFAR10(root='./data/',
                                download=True,
                                train=True,
                                transform=transforms.ToTensor(),
                                noise_type=args.noise_type,
                                noise_rate=args.noise_rate)

        test_dataset = CIFAR10(root='./data/',
                               download=True,
                               train=False,
                               transform=transforms.ToTensor(),
                               noise_type=args.noise_type,
                               noise_rate=args.noise_rate)

    elif args.dataset == 'cifar100':
        input_channel = 3
        num_classes = 100
        args.top_bn = False
        args.epoch_decay_start = 100
        args.n_epoch = 200
        train_dataset = CIFAR100(root='./data/',
                                 download=True,
                                 train=True,
                                 transform=transforms.ToTensor(),
                                 noise_type=args.noise_type,
                                 noise_rate=args.noise_rate)

        test_dataset = CIFAR100(root='./data/',
                                download=True,
                                train=False,
                                transform=transforms.ToTensor(),
                                noise_type=args.noise_type,
                                noise_rate=args.noise_rate)
    else:
        # Fail fast: the original fell through and later crashed with a
        # NameError on train_dataset.
        raise ValueError('Unsupported dataset: {}'.format(args.dataset))

    if args.forget_rate is None:
        forget_rate = args.noise_rate
    else:
        forget_rate = args.forget_rate

    noise_or_not = train_dataset.noise_or_not
    # Data Loader (Input Pipeline)
    print('loading dataset...')
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=batch_size,
                                               num_workers=args.num_workers,
                                               drop_last=True,
                                               shuffle=True)

    test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                              batch_size=batch_size,
                                              num_workers=args.num_workers,
                                              drop_last=True,
                                              shuffle=False)

    # Define the model. Use the dataset-dependent class count instead of the
    # original hard-coded num_classes=10 (wrong for CIFAR-100), and rely on
    # .to(device) alone — the extra unconditional .cuda() crashed on
    # CPU-only machines.
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    cnn = PreResNet.ResNet18(num_classes=num_classes).to(device)
    optimizer = torch.optim.SGD(cnn.parameters(),
                                lr=args.lr,
                                momentum=args.momentum)

    acc = []
    loss = []
    loss_pure = []
    loss_corrupt = []
    out = []
    for epoch in range(1, args.n_epoch + 1):
        # The original `if epoch < 20` branch was byte-identical to its
        # `else`, so the condition has been removed.
        l1 = train(args,
                   cnn,
                   device,
                   train_loader,
                   optimizer,
                   epoch,
                   eps=args.eps)
        loss.append(l1)
        acc.append(test(args, cnn, device, test_loader))

    # Identifier for result files; the visible function body ends before
    # `name` is consumed.
    name = str(args.dataset) + " " + str(args.noise_type) + " " + str(
        args.noise_rate)
Ejemplo n.º 5
0
    input_channel = 3
    num_classes = 10
    init_epoch = 20
    args.epoch_decay_start = 80
    filter_outlier = True
    args.model_type = "cnn"
    # args.n_epoch = 200
    transform1 = torchvision.transforms.Compose([
        # torchvision.transforms.RandomHorizontalFlip(),
        # torchvision.transforms.RandomCrop(32, 4),
        torchvision.transforms.ToTensor(),
    ])
    train_dataset = CIFAR10(root='./../Co-correcting_plus/data/cifar10/',
                            download=False,
                            train=True,
                            transform=transform1,
                            noise_type=args.noise_type,
                            noise_rate=args.noise_rate
                            )

    test_dataset = CIFAR10(root='./../Co-correcting_plus/data/cifar10/',
                           download=False,
                           train=False,
                           transform=transforms.ToTensor(),
                           noise_type=args.noise_type,
                           noise_rate=args.noise_rate
                           )

if args.dataset == 'cifar100':
    input_channel = 3
    num_classes = 100
Ejemplo n.º 6
0
def get_dataset(args):
    """Build train/val datasets and loaders for CIFAR-10/100.

    The augmentation pipeline depends on ``args.train_type``:
    'contrastive' (color jitter + grayscale + flip + resized crop, with a
    DistributedSampler on the training set), 'linear_eval' (same training
    augmentation, plain tensors for validation) or 'test' (flip/crop only).

    Returns:
        ``(train_loader, train_dst, val_loader, val_dst)`` — plus
        ``train_sampler`` as a fifth value for contrastive training.

    Raises:
        ValueError: for an unknown ``args.train_type`` or ``args.dataset``.
    """
    ### color augmentation ###
    color_jitter = transforms.ColorJitter(0.8 * args.color_jitter_strength,
                                          0.8 * args.color_jitter_strength,
                                          0.8 * args.color_jitter_strength,
                                          0.2 * args.color_jitter_strength)
    rnd_color_jitter = transforms.RandomApply([color_jitter], p=0.8)
    rnd_gray = transforms.RandomGrayscale(p=0.2)

    learning_type = args.train_type

    if args.dataset == 'cifar-10':

        if learning_type == 'contrastive':
            transform_train = transforms.Compose([
                rnd_color_jitter,
                rnd_gray,
                transforms.RandomHorizontalFlip(),
                transforms.RandomResizedCrop(32),
                transforms.ToTensor(),
            ])

            # Contrastive evaluation uses the same stochastic pipeline.
            transform_test = transform_train

        elif learning_type == 'linear_eval':
            transform_train = transforms.Compose([
                rnd_color_jitter,
                rnd_gray,
                transforms.RandomHorizontalFlip(),
                transforms.RandomResizedCrop(32),
                transforms.ToTensor(),
            ])

            transform_test = transforms.Compose([
                transforms.ToTensor(),
            ])

        elif learning_type == 'test':
            transform_train = transforms.Compose([
                transforms.RandomHorizontalFlip(),
                transforms.RandomResizedCrop(32),
                transforms.ToTensor(),
            ])

            transform_test = transforms.Compose([
                transforms.ToTensor(),
            ])
        else:
            # The original `assert ('wrong learning type')` asserted a
            # truthy string literal and could never fire; raise instead.
            raise ValueError('wrong learning type')

        train_dst = CIFAR10(root='./Data',
                            train=True,
                            download=True,
                            transform=transform_train,
                            contrastive_learning=learning_type)
        val_dst = CIFAR10(root='./Data',
                          train=False,
                          download=True,
                          transform=transform_test,
                          contrastive_learning=learning_type)

        if learning_type == 'contrastive':
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dst,
                num_replicas=args.ngpu,
                rank=args.local_rank,
            )
            # shuffle and sampler are mutually exclusive in DataLoader, so
            # shuffle is only enabled when no sampler is attached.
            train_loader = torch.utils.data.DataLoader(
                train_dst,
                batch_size=args.batch_size,
                num_workers=4,
                pin_memory=False,
                shuffle=(train_sampler is None),
                sampler=train_sampler,
            )

            val_loader = torch.utils.data.DataLoader(
                val_dst,
                batch_size=100,
                num_workers=4,
                pin_memory=False,
                shuffle=False,
            )

            return train_loader, train_dst, val_loader, val_dst, train_sampler
        else:
            train_loader = torch.utils.data.DataLoader(
                train_dst,
                batch_size=args.batch_size,
                shuffle=True,
                num_workers=4)
            val_batch = 100
            val_loader = torch.utils.data.DataLoader(val_dst,
                                                     batch_size=val_batch,
                                                     shuffle=False,
                                                     num_workers=4)

            return train_loader, train_dst, val_loader, val_dst

    if args.dataset == 'cifar-100':

        if learning_type == 'contrastive':
            transform_train = transforms.Compose([
                rnd_color_jitter, rnd_gray,
                transforms.RandomHorizontalFlip(),
                transforms.RandomResizedCrop(32),
                transforms.ToTensor()
            ])

            transform_test = transform_train

        elif learning_type == 'linear_eval':
            transform_train = transforms.Compose([
                rnd_color_jitter, rnd_gray,
                transforms.RandomHorizontalFlip(),
                transforms.RandomResizedCrop(32),
                transforms.ToTensor()
            ])

            transform_test = transforms.Compose([transforms.ToTensor()])

        elif learning_type == 'test':
            transform_train = transforms.Compose([
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor()
            ])

            transform_test = transforms.Compose([transforms.ToTensor()])
        else:
            # Same fix as above: the assert on a non-empty string was a no-op.
            raise ValueError('wrong learning type')

        train_dst = CIFAR100(root='./Data',
                             train=True,
                             download=True,
                             transform=transform_train,
                             contrastive_learning=learning_type)
        val_dst = CIFAR100(root='./Data',
                           train=False,
                           download=True,
                           transform=transform_test,
                           contrastive_learning=learning_type)

        if learning_type == 'contrastive':
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dst,
                num_replicas=args.ngpu,
                rank=args.local_rank,
            )
            train_loader = torch.utils.data.DataLoader(
                train_dst,
                batch_size=args.batch_size,
                num_workers=4,
                pin_memory=True,
                shuffle=(train_sampler is None),
                sampler=train_sampler,
            )

            # shuffle=False is the DataLoader default; spelled out for
            # consistency with the cifar-10 branch.
            val_loader = torch.utils.data.DataLoader(
                val_dst,
                batch_size=100,
                num_workers=4,
                pin_memory=True,
                shuffle=False,
            )
            return train_loader, train_dst, val_loader, val_dst, train_sampler

        else:
            train_loader = torch.utils.data.DataLoader(
                train_dst,
                batch_size=args.batch_size,
                shuffle=True,
                num_workers=4)

            val_loader = torch.utils.data.DataLoader(val_dst,
                                                     batch_size=100,
                                                     shuffle=False,
                                                     num_workers=4)

            return train_loader, train_dst, val_loader, val_dst

    # Fail fast for unknown datasets instead of implicitly returning None,
    # which previously surfaced as a TypeError at the call site.
    raise ValueError('unknown dataset: {}'.format(args.dataset))
Ejemplo n.º 7
0
def main():
    """Poison-train a ResNet-18 with a clean-label backdoor and evaluate it.

    Reads a YAML config, builds poisoned train/test datasets from CIFAR-10,
    then alternates poison training with evaluation on clean and poisoned
    test data for the configured number of epochs.
    """
    print("===Setup running===")
    cli = argparse.ArgumentParser()
    cli.add_argument("--config", default="./config/poison_train.yaml")
    cli.add_argument("--gpu", default="0", type=str)
    opts = cli.parse_args()
    cfg, _, _ = load_config(opts.config)

    print("===Prepare data===")
    backdoor_cfg = cfg["backdoor"]
    print("Load backdoor config:\n{}".format(backdoor_cfg))
    trigger = CLBD(backdoor_cfg["clbd"]["trigger_path"])
    target = backdoor_cfg["target_label"]
    ratio = backdoor_cfg["poison_ratio"]

    # Normalize is stateless, so one instance is shared by both pipelines.
    normalize = transforms.Normalize([0.4914, 0.4822, 0.4465],
                                     [0.2023, 0.1994, 0.2010])
    train_tf = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    test_tf = transforms.Compose([transforms.ToTensor(), normalize])

    print("Load dataset from: {}".format(cfg["dataset_dir"]))
    clean_train = CIFAR10(cfg["dataset_dir"], train_tf, train=True)
    train_poison_idx = gen_poison_idx(clean_train,
                                      target,
                                      poison_ratio=ratio)
    print("Load the adversarially perturbed dataset from: {}".format(
        cfg["adv_dataset_path"]))
    poisoned_train = CleanLabelDataset(
        clean_train,
        cfg["adv_dataset_path"],
        trigger,
        train_poison_idx,
        target,
    )
    poison_train_loader = DataLoader(poisoned_train,
                                     **cfg["loader"],
                                     shuffle=True)

    clean_test = CIFAR10(cfg["dataset_dir"], test_tf, train=False)
    test_poison_idx = gen_poison_idx(clean_test, target)
    poisoned_test = CleanLabelDataset(
        clean_test,
        cfg["adv_dataset_path"],
        trigger,
        test_poison_idx,
        target,
    )
    clean_test_loader = DataLoader(clean_test, **cfg["loader"])
    poison_test_loader = DataLoader(poisoned_test, **cfg["loader"])

    # Restrict visible devices before querying the current one.
    os.environ["CUDA_VISIBLE_DEVICES"] = opts.gpu
    gpu = torch.cuda.current_device()
    print("Set gpu to: {}".format(opts.gpu))

    net = resnet18().cuda(gpu)
    loss_fn = nn.CrossEntropyLoss().cuda(gpu)
    sgd = torch.optim.SGD(net.parameters(), **cfg["optimizer"]["SGD"])
    lr_sched = torch.optim.lr_scheduler.MultiStepLR(
        sgd, **cfg["lr_scheduler"]["multi_step"])

    for epoch in range(cfg["num_epochs"]):
        print("===Epoch: {}/{}===".format(epoch + 1, cfg["num_epochs"]))
        print("Poison training...")
        poison_train(net, poison_train_loader, loss_fn, sgd)
        print("Test model on clean data...")
        test(net, clean_test_loader, loss_fn)
        print("Test model on poison data...")
        test(net, poison_test_loader, loss_fn)

        lr_sched.step()
        print("Adjust learning rate to {}".format(
            sgd.param_groups[0]["lr"]))
Ejemplo n.º 8
0
def main():
    """Train a CNN under label noise on MNIST, CIFAR-10, or CIFAR-100.

    Parses the command line, builds noisy train/validation/test loaders,
    trains for 100 epochs, and saves the per-epoch test-accuracy and
    validation-loss traces as ``vl_<run name> acc.npy`` and
    ``vl_<run name> loss.npy``.
    """
    args = _parse_args()

    # Seed both CPU and GPU RNGs so the validation split and the shuffled
    # training order are reproducible for a given --seed.
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    train_dataset, test_dataset, num_classes = _load_datasets(args)

    # Hold out 10% of the training data for validation. This split is now
    # applied to every dataset: the original code only split in the mnist
    # branch, so --dataset cifar10/cifar100 crashed with a NameError on
    # valid_dataset when building valid_loader below.
    train_size = int(0.9 * len(train_dataset))
    valid_size = len(train_dataset) - train_size
    train_dataset, valid_dataset = random_split(train_dataset,
                                                [train_size, valid_size])

    print('loading dataset...')
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=args.batch_size,
                                               num_workers=args.num_workers,
                                               drop_last=True,
                                               shuffle=True)

    valid_loader = torch.utils.data.DataLoader(dataset=valid_dataset,
                                               batch_size=args.batch_size,
                                               num_workers=args.num_workers,
                                               drop_last=True,
                                               shuffle=True)

    test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                              batch_size=args.batch_size,
                                              num_workers=args.num_workers,
                                              drop_last=True,
                                              shuffle=False)

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    cnn1 = Net().to(device)
    optimizer = torch.optim.SGD(cnn1.parameters(),
                                lr=args.lr,
                                momentum=args.momentum)

    acc = []       # per-epoch test accuracy
    loss = []      # per-epoch loss returned for the validation split
    out = []       # per-epoch auxiliary output from train() on the train split

    # NOTE(review): the epoch count is fixed at 100 here; args.n_epoch (set
    # per dataset in _load_datasets) is never consulted. Preserved as-is to
    # keep behavior identical — confirm whether args.n_epoch should be used.
    for epoch in range(1, 101):
        _, out10 = train(args,
                         cnn1,
                         device,
                         train_loader,
                         optimizer,
                         epoch,
                         eps=args.eps,
                         nums=num_classes)
        # NOTE(review): the held-out split is also passed through train()
        # with the optimizer, so the model keeps updating on it; only the
        # returned loss is recorded. Preserved as-is — confirm intent.
        cur, _ = train(args,
                       cnn1,
                       device,
                       valid_loader,
                       optimizer,
                       epoch,
                       eps=args.eps,
                       nums=num_classes)
        loss.append(cur)
        out.append(out10)
        acc.append(test(args, cnn1, device, test_loader, num_classes))

    # Run identifier: "<dataset> <noise_type> <noise_rate> <eps> <seed>".
    name = str(args.dataset) + " " + str(args.noise_type) + " " + str(
        args.noise_rate) + " " + str(args.eps) + " " + str(args.seed)
    np.save("vl_" + name + " acc.npy", acc)
    np.save("vl_" + name + " loss.npy", loss)


def _parse_args():
    """Build the command-line interface and parse ``sys.argv``.

    Help strings for batch sizes, lr, and momentum have been corrected to
    match the actual defaults (they previously quoted stale values).
    """
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')

    parser.add_argument('--result_dir',
                        type=str,
                        help='dir to save result txt files',
                        default='results/')
    parser.add_argument('--noise_rate',
                        type=float,
                        help='corruption rate, should be less than 1',
                        default=0.5)
    parser.add_argument('--forget_rate',
                        type=float,
                        help='forget rate',
                        default=None)
    parser.add_argument('--noise_type',
                        type=str,
                        help='[pairflip, symmetric]',
                        default='symmetric')
    parser.add_argument(
        '--num_gradual',
        type=int,
        default=10,
        help=
        'how many epochs for linear drop rate, can be 5, 10, 15. This parameter is equal to Tk for R(T) in Co-teaching paper.'
    )
    parser.add_argument(
        '--exponent',
        type=float,
        default=1,
        help=
        'exponent of the forget rate, can be 0.5, 1, 2. This parameter is equal to c in Tc for R(T) in Co-teaching paper.'
    )
    parser.add_argument('--top_bn', action='store_true')
    parser.add_argument('--dataset',
                        type=str,
                        help='mnist, cifar10, or cifar100',
                        default='mnist')
    parser.add_argument('--n_epoch', type=int, default=10)
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--print_freq', type=int, default=50)
    parser.add_argument('--num_workers',
                        type=int,
                        default=4,
                        help='how many subprocesses to use for data loading')
    parser.add_argument('--num_iter_per_epoch', type=int, default=400)
    parser.add_argument('--epoch_decay_start', type=int, default=80)
    parser.add_argument('--eps', type=float, default=9.9)

    parser.add_argument('--batch-size',
                        type=int,
                        default=1000,
                        metavar='N',
                        help='input batch size for training (default: 1000)')
    parser.add_argument('--test-batch-size',
                        type=int,
                        default=4000,
                        metavar='N',
                        help='input batch size for testing (default: 4000)')
    parser.add_argument('--lr',
                        type=float,
                        default=0.005,
                        metavar='LR',
                        help='learning rate (default: 0.005)')
    parser.add_argument('--momentum',
                        type=float,
                        default=0.9,
                        metavar='M',
                        help='SGD momentum (default: 0.9)')
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')
    parser.add_argument(
        '--log-interval',
        type=int,
        default=100,
        metavar='N',
        help='how many batches to wait before logging training status')

    parser.add_argument('--save-model',
                        action='store_true',
                        default=False,
                        help='For Saving the current Model')
    return parser.parse_args()


def _load_datasets(args):
    """Instantiate the noisy train/test datasets selected by ``args.dataset``.

    Also overwrites per-dataset training hyper-parameters on ``args``
    (``top_bn``, ``epoch_decay_start``, ``n_epoch``), exactly as the
    original inline branches did.

    Returns:
        tuple: ``(train_dataset, test_dataset, num_classes)``.

    Raises:
        ValueError: if ``args.dataset`` is not one of the supported names
            (the original code fell through to a confusing NameError).
    """
    if args.dataset == 'mnist':
        args.top_bn = False
        args.epoch_decay_start = 80
        args.n_epoch = 200
        train_dataset = MNIST(root='./data/',
                              download=True,
                              train=True,
                              transform=transforms.ToTensor(),
                              noise_type=args.noise_type,
                              noise_rate=args.noise_rate)
        test_dataset = MNIST(root='./data/',
                             download=True,
                             train=False,
                             transform=transforms.ToTensor(),
                             noise_type=args.noise_type,
                             noise_rate=args.noise_rate)
        return train_dataset, test_dataset, 10

    if args.dataset == 'cifar10':
        args.top_bn = False
        args.epoch_decay_start = 80
        args.n_epoch = 200
        train_dataset = CIFAR10(root='./data/',
                                download=True,
                                train=True,
                                transform=transforms.ToTensor(),
                                noise_type=args.noise_type,
                                noise_rate=args.noise_rate)
        test_dataset = CIFAR10(root='./data/',
                               download=True,
                               train=False,
                               transform=transforms.ToTensor(),
                               noise_type=args.noise_type,
                               noise_rate=args.noise_rate)
        return train_dataset, test_dataset, 10

    if args.dataset == 'cifar100':
        args.top_bn = False
        args.epoch_decay_start = 100
        args.n_epoch = 200
        train_dataset = CIFAR100(root='./data/',
                                 download=True,
                                 train=True,
                                 transform=transforms.ToTensor(),
                                 noise_type=args.noise_type,
                                 noise_rate=args.noise_rate)
        test_dataset = CIFAR100(root='./data/',
                                download=True,
                                train=False,
                                transform=transforms.ToTensor(),
                                noise_type=args.noise_type,
                                noise_rate=args.noise_rate)
        return train_dataset, test_dataset, 100

    raise ValueError(
        "unknown dataset {!r}; expected mnist, cifar10, or cifar100".format(
            args.dataset))
Ejemplo n.º 9
0
                         noise_rate=args.noise_rate)

if args.dataset == 'cifar10':
    input_channel = 3
    num_classes = 10
    args.top_bn = False
    args.epoch_decay_start = 80
    args.n_epoch = 250
    train_dataset = CIFAR10(
        root='/home/cgn/data/',
        download=True,
        train=True,
        #         transform=transforms.ToTensor(),
        transform=transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            #             transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ]),
        noise_type=args.noise_type,
        noise_rate=args.noise_rate,
        fn=args.fn,  # path to noisy labels
        train_mask_dir=args.train_mask_dir,  # path to train mask
    )

    test_dataset = CIFAR10(
        root='/home/cgn/data/',
        download=True,
        train=False,
        transform=transforms.ToTensor(),
        noise_type=args.noise_type,
        noise_rate=args.noise_rate,