def build_dataset():
    normalize = transforms.Normalize(
        mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
        std=[x / 255.0 for x in [63.0, 62.1, 66.7]])

    if args.augment:
        train_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(lambda x: F.pad(x.unsqueeze(0), (4, 4, 4, 4),
                                              mode='reflect').squeeze()),
            transforms.ToPILImage(),
            transforms.RandomCrop(32),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    else:
        train_transform = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    if args.dataset == 'cifar10':
        train_data_meta = CIFAR10(
            root='../data', train=True, meta=True, num_meta=args.num_meta,
            corruption_prob=args.corruption_prob,
            corruption_type=args.corruption_type,
            transform=train_transform, download=True)
        train_data = CIFAR10(
            root='../data', train=True, meta=False, num_meta=args.num_meta,
            corruption_prob=args.corruption_prob,
            corruption_type=args.corruption_type,
            transform=train_transform, download=True, seed=args.seed)
        test_data = CIFAR10(root='../data', train=False,
                            transform=test_transform, download=True)
    elif args.dataset == 'cifar100':
        train_data_meta = CIFAR100(
            root='../data', train=True, meta=True, num_meta=args.num_meta,
            corruption_prob=args.corruption_prob,
            corruption_type=args.corruption_type,
            transform=train_transform, download=True)
        train_data = CIFAR100(
            root='../data', train=True, meta=False, num_meta=args.num_meta,
            corruption_prob=args.corruption_prob,
            corruption_type=args.corruption_type,
            transform=train_transform, download=True, seed=args.seed)
        test_data = CIFAR100(root='../data', train=False,
                             transform=test_transform, download=True)

    train_loader = torch.utils.data.DataLoader(
        train_data, batch_size=args.batch_size, shuffle=True,
        num_workers=args.prefetch, pin_memory=True)
    train_meta_loader = torch.utils.data.DataLoader(
        train_data_meta, batch_size=args.batch_size, shuffle=True,
        num_workers=args.prefetch, pin_memory=True)
    test_loader = torch.utils.data.DataLoader(
        test_data, batch_size=args.batch_size, shuffle=False,
        num_workers=args.prefetch, pin_memory=True)

    return train_loader, train_meta_loader, test_loader
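# Hypothetical usage sketch (not part of the original script): the three
# loaders returned by build_dataset() are typically consumed together in a
# meta-learning epoch, cycling the small clean meta loader alongside the
# noisy train loader. The helper name and loop body below are illustrative
# assumptions, not this repo's training code.
import itertools

def example_epoch(train_loader, train_meta_loader):
    meta_iter = itertools.cycle(train_meta_loader)
    for inputs, targets in train_loader:
        meta_inputs, meta_targets = next(meta_iter)
        inputs = inputs.cuda()
        targets = targets.cuda(non_blocking=True)
        meta_inputs = meta_inputs.cuda()
        meta_targets = meta_targets.cuda(non_blocking=True)
        # ... one training step that uses the noisy batch for the model
        # update and the clean meta batch for the reweighting/meta update ...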
                                                  args.gold_fraction,
                                                  args.corruption_prob,
                                                  args.corruption_type,
                                                  transform=test_transform,
                                                  download=True)
        test_data = CIFAR10(args.data_path, train=False,
                            transform=test_transform, download=True)
        num_classes = 10
    elif args.dataset == 'cifar100':
        train_data_gold = CIFAR100(args.data_path, True, True,
                                   args.gold_fraction, args.corruption_prob,
                                   args.corruption_type,
                                   transform=train_transform, download=True)
        train_data_silver = CIFAR100(args.data_path, True, False,
                                     args.gold_fraction, args.corruption_prob,
                                     args.corruption_type,
                                     transform=train_transform, download=True)
        train_data_gold_deterministic = CIFAR100(args.data_path, True, True,
                                                 args.gold_fraction,
def main():
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    random.seed(args.seed)
    np.random.seed(args.data_seed)  # cutout and load_corrupted_data use np.random
    torch.cuda.set_device(args.gpu)
    cudnn.benchmark = False
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    cudnn.deterministic = True
    torch.cuda.manual_seed(args.seed)
    logging.info('gpu device = %d' % args.gpu)
    logging.info("args = %s", args)

    if args.loss_func == 'cce':
        criterion = nn.CrossEntropyLoss().cuda()
    elif args.loss_func == 'rll':
        criterion = utils.RobustLogLoss(alpha=args.alpha).cuda()
    else:
        # Note: the literal braces must be escaped, otherwise .format() raises
        # a KeyError when the assert fires.
        assert False, "Invalid loss function '{}' given. Must be in {{'cce', 'rll'}}".format(args.loss_func)

    if args.valid_loss_func == 'cce':
        valid_criterion = nn.CrossEntropyLoss().cuda()
    elif args.valid_loss_func == 'rll':
        valid_criterion = utils.RobustLogLoss(alpha=args.alpha).cuda()
    else:
        valid_criterion = None

    model = Network(args.init_channels, CIFAR_CLASSES, args.layers,
                    criterion, valid_criterion)
    model = model.cuda()
    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

    optimizer = torch.optim.SGD(
        model.parameters(), args.learning_rate,
        momentum=args.momentum, weight_decay=args.weight_decay)

    train_transform, valid_transform = utils._data_transforms_cifar10(args)

    # Load dataset
    if args.dataset == 'cifar10':
        noisy_train_data = CIFAR10(
            root=args.data, train=True, gold=False, gold_fraction=0.0,
            corruption_prob=args.corruption_prob,
            corruption_type=args.corruption_type,
            transform=train_transform, download=True, seed=args.data_seed)
        gold_train_data = CIFAR10(
            root=args.data, train=True, gold=True, gold_fraction=1.0,
            corruption_prob=args.corruption_prob,
            corruption_type=args.corruption_type,
            transform=train_transform, download=True, seed=args.data_seed)
    elif args.dataset == 'cifar100':
        noisy_train_data = CIFAR100(
            root=args.data, train=True, gold=False, gold_fraction=0.0,
            corruption_prob=args.corruption_prob,
            corruption_type=args.corruption_type,
            transform=train_transform, download=True, seed=args.data_seed)
        gold_train_data = CIFAR100(
            root=args.data, train=True, gold=True, gold_fraction=1.0,
            corruption_prob=args.corruption_prob,
            corruption_type=args.corruption_type,
            transform=train_transform, download=True, seed=args.data_seed)

    num_train = len(gold_train_data)
    indices = list(range(num_train))
    split = int(np.floor(args.train_portion * num_train))

    if args.gold_fraction == 1.0:
        train_data = gold_train_data
    else:
        train_data = noisy_train_data
    train_queue = torch.utils.data.DataLoader(
        train_data, batch_size=args.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
        pin_memory=True, num_workers=0)

    if args.clean_valid:
        valid_data = gold_train_data
    else:
        valid_data = noisy_train_data
    valid_queue = torch.utils.data.DataLoader(
        valid_data, batch_size=args.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[split:]),
        pin_memory=True, num_workers=0)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, float(args.epochs), eta_min=args.learning_rate_min)

    architect = Architect(model, args)

    for epoch in range(args.epochs):
        scheduler.step()
        lr = scheduler.get_lr()[0]
        logging.info('epoch %d lr %e', epoch, lr)

        genotype = model.genotype()
        logging.info('genotype = %s', genotype)

        print(F.softmax(model.alphas_normal, dim=-1))
        print(F.softmax(model.alphas_reduce, dim=-1))

        # training
        train_acc, train_obj = train(train_queue, valid_queue, model,
                                     architect, criterion, optimizer, lr)
        logging.info('train_acc %f', train_acc)

        # validation
        valid_acc, valid_obj = infer(valid_queue, model, criterion)
        logging.info('valid_acc %f', valid_acc)

        utils.save(model, os.path.join(args.save, 'weights.pt'))
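# A minimal sketch of the bilevel search step that train() above is assumed
# to perform, following the common DARTS recipe (alternating an architecture
# update on a validation batch with a weight update on a training batch).
# The Architect.step signature and the utils.AvgrageMeter/accuracy helpers
# are assumptions based on the standard DARTS implementation; this repo's
# train() may differ.
def train(train_queue, valid_queue, model, architect, criterion, optimizer, lr):
    objs = utils.AvgrageMeter()
    top1 = utils.AvgrageMeter()
    valid_iter = iter(valid_queue)
    for step, (input, target) in enumerate(train_queue):
        model.train()
        n = input.size(0)
        input = input.cuda()
        target = target.cuda(non_blocking=True)

        # draw a held-out batch for the architecture (alpha) update
        try:
            input_search, target_search = next(valid_iter)
        except StopIteration:
            valid_iter = iter(valid_queue)
            input_search, target_search = next(valid_iter)
        input_search = input_search.cuda()
        target_search = target_search.cuda(non_blocking=True)
        architect.step(input, target, input_search, target_search,
                       lr, optimizer, unrolled=args.unrolled)

        # weight update on the training batch
        optimizer.zero_grad()
        logits = model(input)
        loss = criterion(logits, target)
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
        optimizer.step()

        prec1 = utils.accuracy(logits, target, topk=(1,))[0]
        objs.update(loss.item(), n)
        top1.update(prec1.item(), n)
    return top1.avg, objs.avg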
def main():
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.cuda.set_device(args.gpu)
    cudnn.benchmark = False
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    cudnn.deterministic = True
    torch.cuda.manual_seed(args.seed)
    logging.info('gpu device = %d' % args.gpu)
    logging.info("args = %s", args)

    if args.loss_func == 'cce':
        criterion = nn.CrossEntropyLoss().cuda()
    elif args.loss_func == 'rll':
        criterion = utils.RobustLogLoss(alpha=args.alpha).cuda()
    else:
        assert False, "Invalid loss function '{}' given. Must be in {{'cce', 'rll'}}".format(args.loss_func)

    model = Network(args.init_channels, CIFAR_CLASSES, args.layers, criterion)
    model = model.cuda()
    model.train()
    model.apply(weights_init)
    # clip_grad_norm was renamed clip_grad_norm_; note this call is a no-op
    # here since no gradients have been computed yet.
    nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

    train_transform, valid_transform = utils._data_transforms_cifar10(args)

    # Load dataset
    if args.dataset == 'cifar10':
        train_data = CIFAR10(root=args.data, train=True, gold=False,
                             gold_fraction=0.0,
                             corruption_prob=args.corruption_prob,
                             corruption_type=args.corruption_type,
                             transform=train_transform, download=True,
                             seed=args.seed)
        gold_train_data = CIFAR10(root=args.data, train=True, gold=True,
                                  gold_fraction=1.0,
                                  corruption_prob=args.corruption_prob,
                                  corruption_type=args.corruption_type,
                                  transform=train_transform, download=True,
                                  seed=args.seed)
    elif args.dataset == 'cifar100':
        train_data = CIFAR100(root=args.data, train=True, gold=False,
                              gold_fraction=0.0,
                              corruption_prob=args.corruption_prob,
                              corruption_type=args.corruption_type,
                              transform=train_transform, download=True,
                              seed=args.seed)
        gold_train_data = CIFAR100(root=args.data, train=True, gold=True,
                                   gold_fraction=1.0,
                                   corruption_prob=args.corruption_prob,
                                   corruption_type=args.corruption_type,
                                   transform=train_transform, download=True,
                                   seed=args.seed)

    num_train = len(train_data)
    indices = list(range(num_train))
    split = int(np.floor(args.train_portion * num_train))

    clean_train_queue = torch.utils.data.DataLoader(
        gold_train_data, batch_size=args.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
        pin_memory=True, num_workers=0)
    noisy_train_queue = torch.utils.data.DataLoader(
        train_data, batch_size=args.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
        pin_memory=True, num_workers=0)
    clean_valid_queue = torch.utils.data.DataLoader(
        gold_train_data, batch_size=args.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[split:]),
        pin_memory=True, num_workers=0)
    noisy_valid_queue = torch.utils.data.DataLoader(
        train_data, batch_size=args.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[split:]),
        pin_memory=True, num_workers=0)

    # Materialize every batch on the GPU once, without tracking gradients.
    # (The original used Variable(..., volatile=True) and .cuda(async=True),
    # which are removed/invalid in modern PyTorch; async is now a keyword.)
    clean_train_list, clean_valid_list = [], []
    noisy_train_list, noisy_valid_list = [], []
    with torch.no_grad():
        for dst_list, queue in [
                (clean_train_list, clean_train_queue),
                (clean_valid_list, clean_valid_queue),
                (noisy_train_list, noisy_train_queue),
                (noisy_valid_list, noisy_valid_queue)]:
            for input, target in queue:
                input = input.cuda()
                target = target.cuda(non_blocking=True)
                dst_list.append((input, target))

    for epoch in range(args.epochs):
        logging.info('Epoch %d, random architecture with fixed weights', epoch)

        genotype = model.genotype()
        logging.info('genotype = %s', genotype)
        logging.info(F.softmax(model.alphas_normal, dim=-1))
        logging.info(F.softmax(model.alphas_reduce, dim=-1))

        # training
        clean_train_acc, clean_train_obj = infer(clean_train_list, model,
                                                 criterion, kind='clean_train')
        logging.info('clean_train_acc %f, clean_train_loss %f',
                     clean_train_acc, clean_train_obj)
        noisy_train_acc, noisy_train_obj = infer(noisy_train_list, model,
                                                 criterion, kind='noisy_train')
        logging.info('noisy_train_acc %f, noisy_train_loss %f',
                     noisy_train_acc, noisy_train_obj)

        # validation
        clean_valid_acc, clean_valid_obj = infer(clean_valid_list, model,
                                                 criterion, kind='clean_valid')
        logging.info('clean_valid_acc %f, clean_valid_loss %f',
                     clean_valid_acc, clean_valid_obj)
        noisy_valid_acc, noisy_valid_obj = infer(noisy_valid_list, model,
                                                 criterion, kind='noisy_valid')
        logging.info('noisy_valid_acc %f, noisy_valid_loss %f',
                     noisy_valid_acc, noisy_valid_obj)

        utils.save(model, os.path.join(args.save, 'weights.pt'))

        # Randomly change the alphas
        k = sum(1 for i in range(model._steps) for n in range(2 + i))
        num_ops = len(PRIMITIVES)
        model.alphas_normal.data.copy_(torch.randn(k, num_ops))
        model.alphas_reduce.data.copy_(torch.randn(k, num_ops))
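# A minimal sketch of the infer() helper assumed above: unlike the usual
# DataLoader-based version, it consumes the pre-materialized (input, target)
# lists built earlier, and `kind` is only a label for logging. The signature
# matches the call sites; the body is an assumption.
def infer(batch_list, model, criterion, kind=''):
    objs = utils.AvgrageMeter()
    top1 = utils.AvgrageMeter()
    model.eval()
    with torch.no_grad():
        for input, target in batch_list:
            logits = model(input)
            loss = criterion(logits, target)
            prec1 = utils.accuracy(logits, target, topk=(1,))[0]
            n = input.size(0)
            objs.update(loss.item(), n)
            top1.update(prec1.item(), n)
    return top1.avg, objs.avg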
def main():
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    random.seed(args.seed)
    np.random.seed(args.data_seed)  # cutout and load_corrupted_data use np.random
    torch.cuda.set_device(args.gpu)
    cudnn.benchmark = False
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    cudnn.deterministic = True
    torch.cuda.manual_seed(args.seed)
    logging.info('gpu device = %d' % args.gpu)
    logging.info("args = %s", args)

    if args.arch == 'resnet':
        model = ResNet18(CIFAR_CLASSES).cuda()
        args.auxiliary = False
    elif args.arch == 'resnet50':
        model = ResNet50(CIFAR_CLASSES).cuda()
        args.auxiliary = False
    elif args.arch == 'resnet34':
        model = ResNet34(CIFAR_CLASSES).cuda()
        args.auxiliary = False
    else:
        genotype = eval("genotypes.%s" % args.arch)
        model = Network(args.init_channels, CIFAR_CLASSES, args.layers,
                        args.auxiliary, genotype)
        model = model.cuda()
    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

    optimizer = torch.optim.SGD(model.parameters(), args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    train_transform, test_transform = utils._data_transforms_cifar10(args)

    # Load dataset
    if args.dataset == 'cifar10':
        noisy_train_data = CIFAR10(root=args.data, train=True, gold=False,
                                   gold_fraction=0.0,
                                   corruption_prob=args.corruption_prob,
                                   corruption_type=args.corruption_type,
                                   transform=train_transform, download=True,
                                   seed=args.data_seed)
        gold_train_data = CIFAR10(root=args.data, train=True, gold=True,
                                  gold_fraction=1.0,
                                  corruption_prob=args.corruption_prob,
                                  corruption_type=args.corruption_type,
                                  transform=train_transform, download=True,
                                  seed=args.data_seed)
        test_data = dset.CIFAR10(root=args.data, train=False, download=True,
                                 transform=test_transform)
    elif args.dataset == 'cifar100':
        noisy_train_data = CIFAR100(root=args.data, train=True, gold=False,
                                    gold_fraction=0.0,
                                    corruption_prob=args.corruption_prob,
                                    corruption_type=args.corruption_type,
                                    transform=train_transform, download=True,
                                    seed=args.data_seed)
        gold_train_data = CIFAR100(root=args.data, train=True, gold=True,
                                   gold_fraction=1.0,
                                   corruption_prob=args.corruption_prob,
                                   corruption_type=args.corruption_type,
                                   transform=train_transform, download=True,
                                   seed=args.data_seed)
        test_data = dset.CIFAR100(root=args.data, train=False, download=True,
                                  transform=test_transform)

    num_train = len(gold_train_data)
    indices = list(range(num_train))
    split = int(np.floor(args.train_portion * num_train))

    if args.gold_fraction == 1.0:
        train_data = gold_train_data
    else:
        train_data = noisy_train_data
    train_queue = torch.utils.data.DataLoader(
        train_data, batch_size=args.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
        pin_memory=True, num_workers=0)

    if args.clean_valid:
        valid_data = gold_train_data
    else:
        valid_data = noisy_train_data
    valid_queue = torch.utils.data.DataLoader(
        valid_data, batch_size=args.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[split:]),
        pin_memory=True, num_workers=0)

    test_queue = torch.utils.data.DataLoader(
        test_data, batch_size=args.batch_size, shuffle=False,
        pin_memory=True, num_workers=2)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, float(args.epochs))

    if args.loss_func == 'cce':
        criterion = nn.CrossEntropyLoss().cuda()
    elif args.loss_func == 'rll':
        criterion = utils.RobustLogLoss(alpha=args.alpha).cuda()
    elif args.loss_func == 'forward_gold':
        corruption_matrix = train_data.corruption_matrix
        criterion = utils.ForwardGoldLoss(corruption_matrix=corruption_matrix)
    else:
        # Message updated to list all three accepted values; literal braces
        # must be escaped so .format() does not raise when the assert fires.
        assert False, "Invalid loss function '{}' given. Must be in {{'cce', 'rll', 'forward_gold'}}".format(args.loss_func)

    for epoch in range(args.epochs):
        scheduler.step()
        logging.info('epoch %d lr %e', epoch, scheduler.get_lr()[0])
        model.drop_path_prob = args.drop_path_prob * epoch / args.epochs

        train_acc, train_obj = train(train_queue, model, criterion, optimizer)
        logging.info('train_acc %f', train_acc)

        valid_acc, valid_obj = infer_valid(valid_queue, model, criterion)
        logging.info('valid_acc %f', valid_acc)

        test_acc, test_obj = infer(test_queue, model, criterion)
        logging.info('test_acc %f', test_acc)

        utils.save(model, os.path.join(args.save, 'weights.pt'))
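# For reference, a forward-correction loss like utils.ForwardGoldLoss is
# commonly implemented as below (Patrini et al.'s forward correction): push
# the predicted clean-label probabilities through the row-stochastic
# corruption matrix C, with C[i][j] = p(noisy label j | true label i), before
# taking the negative log-likelihood. This sketch is an assumption about the
# repo's implementation, shown for clarity.
class ForwardCorrectionLoss(nn.Module):
    def __init__(self, corruption_matrix):
        super().__init__()
        self.register_buffer(
            'C', torch.as_tensor(corruption_matrix, dtype=torch.float32))

    def forward(self, logits, target):
        probs = F.softmax(logits, dim=1)   # p(true class | x)
        noisy_probs = probs @ self.C       # p(noisy class | x) under the noise model
        return F.nll_loss(torch.log(noisy_probs + 1e-12), target)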
def main():
    np.random.seed(args.seed)
    torch.cuda.set_device(device)
    cudnn.benchmark = True
    cudnn.enabled = True
    torch.manual_seed(args.seed)
    logging.info('gpu device = %d' % args.gpu)
    logging.info("args = %s", args)

    train_transform, valid_transform = utils._data_transforms_cifar10(args)

    # Load dataset
    if args.dataset == 'cifar10':
        if args.gold_fraction == 0:
            train_data = CIFAR10(root=args.data, train=True, gold=False,
                                 gold_fraction=args.gold_fraction,
                                 corruption_prob=args.corruption_prob,
                                 corruption_type=args.corruption_type,
                                 transform=train_transform, download=True,
                                 seed=args.seed)
        else:
            train_data = CIFAR10(root=args.data, train=True, gold=True,
                                 gold_fraction=args.gold_fraction,
                                 corruption_prob=args.corruption_prob,
                                 corruption_type=args.corruption_type,
                                 transform=train_transform, download=True,
                                 seed=args.seed)
        valid_data = dset.CIFAR10(root=args.data, train=False, download=True,
                                  transform=valid_transform)
        num_classes = 10
    elif args.dataset == 'cifar100':
        if args.gold_fraction == 0:
            train_data = CIFAR100(root=args.data, train=True, gold=False,
                                  gold_fraction=args.gold_fraction,
                                  corruption_prob=args.corruption_prob,
                                  corruption_type=args.corruption_type,
                                  transform=train_transform, download=True,
                                  seed=args.seed)
        else:
            train_data = CIFAR100(root=args.data, train=True, gold=True,
                                  gold_fraction=args.gold_fraction,
                                  corruption_prob=args.corruption_prob,
                                  corruption_type=args.corruption_type,
                                  transform=train_transform, download=True,
                                  seed=args.seed)
        valid_data = dset.CIFAR100(root=args.data, train=False, download=True,
                                   transform=valid_transform)
        num_classes = 100

    genotype = eval("genotypes.%s" % args.arch)
    model = Network(args.init_ch, num_classes, args.layers, args.auxiliary,
                    genotype).cuda()
    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

    if args.loss_func == 'cce':
        criterion = nn.CrossEntropyLoss().to(device)
    elif args.loss_func == 'rll':
        criterion = utils.RobustLogLoss().to(device)
    else:
        assert False, "Invalid loss function '{}' given. Must be in {{'cce', 'rll'}}".format(args.loss_func)

    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum, weight_decay=args.wd,
                                nesterov=True)

    train_queue = torch.utils.data.DataLoader(
        train_data, batch_size=args.batchsz, shuffle=True,
        pin_memory=True, num_workers=2)
    valid_queue = torch.utils.data.DataLoader(
        valid_data, batch_size=args.batchsz, shuffle=False,
        pin_memory=True, num_workers=2)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, float(args.epochs))

    for epoch in range(args.epochs):
        scheduler.step()
        logging.info('epoch %d lr %e', epoch, scheduler.get_lr()[0])
        model.drop_path_prob = args.drop_path_prob * epoch / args.epochs

        valid_acc, valid_obj = infer(valid_queue, model, criterion)
        logging.info('valid_acc: %f', valid_acc)

        train_acc, train_obj = train(train_queue, model, criterion, optimizer)
        logging.info('train_acc: %f', train_acc)

        utils.save(model, os.path.join(args.save, 'trained.pt'))
        print('saved to: trained.pt')
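# model.drop_path_prob above is annealed linearly from 0 to
# args.drop_path_prob over training. For reference, the drop_path operation
# it controls is usually implemented as in the standard DARTS utilities
# (a sketch; this repo's version may differ): each sample's residual branch
# is zeroed independently with probability drop_prob, and survivors are
# rescaled so the expected activation is unchanged.
def drop_path(x, drop_prob):
    if drop_prob > 0.:
        keep_prob = 1. - drop_prob
        mask = torch.cuda.FloatTensor(x.size(0), 1, 1, 1).bernoulli_(keep_prob)
        x = x.div(keep_prob).mul(mask)
    return x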
                                                  args.corruption_type,
                                                  transform=test_transform,
                                                  download=True,
                                                  shuffle_indices=train_data_gold.shuffle_indices)
        test_data = CIFAR10(args.data_path, train=False,
                            transform=test_transform, download=True)
        num_classes = 10
    elif args.dataset == 'cifar100':
        train_data_gold = CIFAR100(args.data_path, True, True,
                                   args.gold_fraction, args.corruption_prob,
                                   args.corruption_type,
                                   transform=train_transform, download=True,
                                   distinguish_gold=False)
        train_data_silver = CIFAR100(
            args.data_path, True, False, args.gold_fraction,
            args.corruption_prob, args.corruption_type,
            transform=train_transform, download=True,
            shuffle_indices=train_data_gold.shuffle_indices)
        train_data_gold_deterministic = CIFAR100(
def main():
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    np.random.seed(args.seed)
    torch.cuda.set_device(args.gpu)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(args.seed)
    logging.info('gpu device = %d' % args.gpu)
    logging.info("args = %s", args)

    model = ResNet18()
    model = model.cuda()
    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

    if args.loss_func == 'cce':
        criterion = nn.CrossEntropyLoss().cuda()
    elif args.loss_func == 'rll':
        criterion = utils.RobustLogLoss().cuda()
    else:
        assert False, "Invalid loss function '{}' given. Must be in {{'cce', 'rll'}}".format(args.loss_func)

    optimizer = torch.optim.SGD(model.parameters(), args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    train_transform, valid_transform = utils._data_transforms_cifar10(args)

    if args.dataset == 'cifar10':
        if args.gold_fraction == 0:
            train_data = CIFAR10(root=args.data, train=True, gold=False,
                                 gold_fraction=args.gold_fraction,
                                 corruption_prob=args.corruption_prob,
                                 corruption_type=args.corruption_type,
                                 transform=train_transform, download=True,
                                 seed=args.seed)
        else:
            train_data = CIFAR10(root=args.data, train=True, gold=True,
                                 gold_fraction=args.gold_fraction,
                                 corruption_prob=args.corruption_prob,
                                 corruption_type=args.corruption_type,
                                 transform=train_transform, download=True,
                                 seed=args.seed)
        valid_data = dset.CIFAR10(root=args.data, train=False, download=True,
                                  transform=valid_transform)
    elif args.dataset == 'cifar100':
        if args.gold_fraction == 0:
            train_data = CIFAR100(root=args.data, train=True, gold=False,
                                  gold_fraction=args.gold_fraction,
                                  corruption_prob=args.corruption_prob,
                                  corruption_type=args.corruption_type,
                                  transform=train_transform, download=True,
                                  seed=args.seed)
        else:
            train_data = CIFAR100(root=args.data, train=True, gold=True,
                                  gold_fraction=args.gold_fraction,
                                  corruption_prob=args.corruption_prob,
                                  corruption_type=args.corruption_type,
                                  transform=train_transform, download=True,
                                  seed=args.seed)
        valid_data = dset.CIFAR100(root=args.data, train=False, download=True,
                                   transform=valid_transform)

    train_queue = torch.utils.data.DataLoader(
        train_data, batch_size=args.batch_size, shuffle=True,
        pin_memory=True, num_workers=2)
    valid_queue = torch.utils.data.DataLoader(
        valid_data, batch_size=args.batch_size, shuffle=False,
        pin_memory=True, num_workers=2)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, float(args.epochs))

    for epoch in range(args.epochs):
        scheduler.step()
        logging.info('epoch %d lr %e', epoch, scheduler.get_lr()[0])

        train_acc, train_obj = train(train_queue, model, criterion, optimizer)
        logging.info('train_acc %f', train_acc)

        valid_acc, valid_obj = infer(valid_queue, model, criterion)
        logging.info('valid_acc %f', valid_acc)

        utils.save(model, os.path.join(args.save, 'weights.pt'))
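# For reference, the CosineAnnealingLR schedule used above sets, at epoch t,
#   lr_t = eta_min + 0.5 * (lr_0 - eta_min) * (1 + cos(pi * t / T_max)),
# with T_max = args.epochs and eta_min = 0 by default, so the learning rate
# decays from args.learning_rate toward 0 over training. For example, with
# lr_0 = 0.1 and T_max = 100, the rate at epoch 50 is 0.05.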
                                        corruption_type='unif',
                                        transform=test_transform)
            val_indices = np.random.choice(len(train_data), args.val_set_size,
                                           replace=False)
            train_indices = np.array(
                list(set(range(len(train_data))) - set(val_indices)))
            val_data = torch.utils.data.Subset(train_data_clean, val_indices)
            train_data = torch.utils.data.Subset(train_data, train_indices)
        test_data = CIFAR10(args.data_path, train=False,
                            transform=test_transform)
        num_classes = 10
    else:
        train_data = CIFAR100(args.data_path, True, False, 0,
                              corruption_prob=args.corruption_prob,
                              corruption_type='unif',
                              transform=train_transform)
        if args.val_set_size > 0:
            train_data_clean = CIFAR100(args.data_path, True, False, 0,
                                        corruption_prob=0,
                                        corruption_type='unif',
                                        transform=test_transform)
            val_indices = np.random.choice(len(train_data), args.val_set_size,
                                           replace=False)
            train_indices = np.array(
def main():
    np.random.seed(args.seed)
    cudnn.benchmark = True
    cudnn.enabled = True
    torch.manual_seed(args.seed)

    # ================================================
    for id in device_ids:
        total, used = os.popen(
            'nvidia-smi --query-gpu=memory.total,memory.used --format=csv,nounits,noheader'
        ).read().split('\n')[id].split(',')
        print('GPU ({}) mem:'.format(id), total, 'used:', used)
    # try:
    #     block_mem = 0.85 * (total - used)
    #     print(block_mem)
    #     x = torch.empty((256, 1024, int(block_mem))).cuda()
    #     del x
    # except RuntimeError as err:
    #     print(err)
    #     block_mem = 0.8 * (total - used)
    #     print(block_mem)
    #     x = torch.empty((256, 1024, int(block_mem))).cuda()
    #     del x
    #
    # print('reuse mem now ...')
    # ================================================

    args.unrolled = True
    logging.info('GPU device = %s' % args.gpu)
    logging.info("args = %s", args)

    train_transform, valid_transform = utils._data_transforms_cifar10(args)

    # Load dataset
    if args.dataset == 'cifar10':
        if args.gold_fraction == 0:
            train_data = CIFAR10(
                root=args.data, train=True, gold=False,
                gold_fraction=args.gold_fraction,
                corruption_prob=args.corruption_prob,
                corruption_type=args.corruption_type,
                transform=train_transform, download=True, seed=args.seed)
            if args.clean_valid:
                gold_train_data = CIFAR10(
                    root=args.data, train=True, gold=True, gold_fraction=1.0,
                    corruption_prob=args.corruption_prob,
                    corruption_type=args.corruption_type,
                    transform=train_transform, download=True, seed=args.seed)
        else:
            train_data = CIFAR10(
                root=args.data, train=True, gold=True,
                gold_fraction=args.gold_fraction,
                corruption_prob=args.corruption_prob,
                corruption_type=args.corruption_type,
                transform=train_transform, download=True, seed=args.seed)
        num_classes = 10
    elif args.dataset == 'cifar100':
        if args.gold_fraction == 0:
            train_data = CIFAR100(
                root=args.data, train=True, gold=False,
                gold_fraction=args.gold_fraction,
                corruption_prob=args.corruption_prob,
                corruption_type=args.corruption_type,
                transform=train_transform, download=True, seed=args.seed)
            if args.clean_valid:
                gold_train_data = CIFAR100(
                    root=args.data, train=True, gold=True, gold_fraction=1.0,
                    corruption_prob=args.corruption_prob,
                    corruption_type=args.corruption_type,
                    transform=train_transform, download=True, seed=args.seed)
        else:
            train_data = CIFAR100(
                root=args.data, train=True, gold=True,
                gold_fraction=args.gold_fraction,
                corruption_prob=args.corruption_prob,
                corruption_type=args.corruption_type,
                transform=train_transform, download=True, seed=args.seed)
        num_classes = 100

    # Split data into train and validation
    num_train = len(train_data)  # 50000
    indices = list(range(num_train))
    split = int(np.floor(args.train_portion * num_train))  # 45000

    train_queue = torch.utils.data.DataLoader(
        train_data, batch_size=batchsz,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
        pin_memory=True, num_workers=2)

    if args.clean_valid:
        valid_queue = torch.utils.data.DataLoader(
            gold_train_data, batch_size=batchsz,
            sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[split:]),
            pin_memory=True, num_workers=2)
    else:
        valid_queue = torch.utils.data.DataLoader(
            train_data, batch_size=batchsz,
            sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[split:]),
            pin_memory=True, num_workers=2)

    if args.loss_func == 'cce':
        criterion = nn.CrossEntropyLoss().cuda()
    elif args.loss_func == 'rll':
        criterion = utils.RobustLogLoss().cuda()
    else:
        assert False, "Invalid loss function '{}' given. Must be in {{'cce', 'rll'}}".format(args.loss_func)

    model = Network(args.init_ch, num_classes, args.layers, criterion)
    if len(device_ids) > 1:
        model = MyDataParallel(model).cuda()
    else:
        model.cuda()
    # model = para_model.module.cuda()
    logging.info("Total param size = %f MB",
                 utils.count_parameters_in_MB(model))

    # this is the optimizer that updates the model weights
    optimizer = optim.SGD(model.parameters(), args.lr, momentum=args.momentum,
                          weight_decay=args.wd, nesterov=True)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(
        optimizer, float(args.epochs), eta_min=args.lr_min)

    arch = Arch(model, args)

    global start
    start = time.time()
    for epoch in range(args.epochs):
        current_time = time.time()
        if current_time - start >= args.time_limit:
            break
        scheduler.step()
        lr = scheduler.get_lr()[0]
        logging.info('\nEpoch: %d lr: %e', epoch, lr)

        genotype = model.genotype()
        logging.info('Genotype: %s', genotype)
        # print(F.softmax(model.alphas_normal, dim=-1))
        # print(F.softmax(model.alphas_reduce, dim=-1))

        # training
        train_acc, train_obj = train(train_queue, valid_queue, model, arch,
                                     criterion, optimizer, lr)
        logging.info('train acc: %f', train_acc)

        # validation
        valid_acc, valid_obj = infer(valid_queue, model, criterion)
        logging.info('valid acc: %f', valid_acc)

        utils.save(model, os.path.join(args.exp_path, 'search_epoch1.pt'))
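# Worked example of the split above, assuming the full CIFAR training set and
# train_portion = 0.9:
#   num_train = 50000
#   split     = int(np.floor(0.9 * 50000)) = 45000
# so indices[:split] feeds 45000 images to train_queue and indices[split:]
# the remaining 5000 to valid_queue. The two SubsetRandomSamplers draw from
# disjoint index sets, and args.clean_valid only swaps which dataset object
# (gold vs. noisy labels) backs the same validation indices.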