def __init__(self, disc_classifier, rep_size=64, n_classes=10, mi_units=64,
             margin=5, alpha=0.6, beta=0.2, gamma=0.2):
    super().__init__()
    self.disc_classifier = disc_classifier  # .half() could save memory and time via half precision.
    self.disc_classifier.requires_grad_(requires_grad=False)  # shut down grad on pre-trained classifier.
    # self.disc_classifier.eval()  # set to eval mode.
    self.rep_size = rep_size
    self.n_classes = n_classes
    self.mi_units = mi_units
    self.margin = margin
    self.alpha = alpha
    self.beta = beta
    self.gamma = gamma

    self.feature_transformer = MLP(self.n_classes, self.rep_size)

    # 1x1 convs performed on the channel dimension only.
    self.local_MInet = MI1x1ConvNet(self.n_classes, self.mi_units)
    self.global_MInet = MI1x1ConvNet(self.rep_size, self.mi_units)

    self.class_conditional = ClassConditionalGaussianMixture(self.n_classes, self.rep_size)

    n_feat_transformer = cal_parameters(self.feature_transformer)
    n_local = cal_parameters(self.local_MInet)
    n_global = cal_parameters(self.global_MInet)
    n_class_conditional = cal_parameters(self.class_conditional)
    n_additional = n_feat_transformer + n_local + n_global + n_class_conditional

    self.cross_entropy = nn.CrossEntropyLoss()

    print('==> # Model parameters.')
    print('==> # discriminative classifier parameters: {}.'.format(cal_parameters(self.disc_classifier)))
    print('==> # additional parameters: {}.'.format(n_additional))
    print('==> # FeatureTransformer parameters: {}.'.format(n_feat_transformer))
    print('==> # T parameters: {}.'.format(n_local + n_global))
    print('==> # class conditional parameters: {}.'.format(n_class_conditional))
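# Hedged sketch (assumption, not from this repo's utils): `cal_parameters` is
# called throughout this section but never defined here. A minimal
# implementation consistent with its call sites, an optional `filter_func`
# over parameters and an integer count, could be:
def cal_parameters(model, filter_func=lambda p: True):
    """Count the parameters of `model` that pass `filter_func`."""
    return sum(p.numel() for p in model.parameters() if filter_func(p))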
def run(args: DictConfig) -> None:
    assert torch.cuda.is_available()
    torch.manual_seed(args.seed)

    n_classes = args.get(args.dataset).n_classes
    classifier = resnet18(n_classes=n_classes).to(args.device)
    logger.info('Base classifier resnet18: # parameters {}'.format(cal_parameters(classifier)))

    data_dir = hydra.utils.to_absolute_path(args.data_dir)
    train_data = get_dataset(data_name=args.dataset, data_dir=data_dir, train=True, crop_flip=True)
    test_data = get_dataset(data_name=args.dataset, data_dir=data_dir, train=False, crop_flip=False)

    train_loader = DataLoader(dataset=train_data, batch_size=args.n_batch_train, shuffle=True)
    test_loader = DataLoader(dataset=test_data, batch_size=args.n_batch_test, shuffle=False)

    if args.inference:
        save_name = 'resnet18_wd{}.pth'.format(args.weight_decay)
        classifier.load_state_dict(torch.load(save_name, map_location=lambda storage, loc: storage))
        loss, acc = run_epoch(classifier, test_loader, args)
        logger.info('Inference, test loss: {:.4f}, Acc: {:.4f}'.format(loss, acc))
    else:
        train(classifier, train_loader, test_loader, args)
def get_model(model_name='resnext50_32x4d'):
    if model_name == 'resnext101_32x8d':
        m = models.resnext101_32x8d(pretrained=True)
    elif model_name == 'resnext50_32x4d':
        m = models.resnext50_32x4d(pretrained=True)
    else:
        # Fail loudly instead of hitting an UnboundLocalError on `m` below.
        raise ValueError('Unsupported model name: {}'.format(model_name))
    print('Model name: {}, # parameters: {}'.format(model_name, cal_parameters(m)))
    return m
def desc(self):
    """
    Description of this model.
    :return: tuple of descriptions of SDIM components.
    """
    n_fixed = cal_parameters(self.disc_classifier, filter_func=lambda x: not x.requires_grad)
    n_trainable = cal_parameters(self.disc_classifier, filter_func=lambda x: x.requires_grad)
    n_T = cal_parameters(self.local_MInet) + cal_parameters(self.global_MInet)
    n_C = cal_parameters(self.class_conditional)

    base_desc = 'Base classifier, # fixed parameters: {}, # trainable parameters: {}'.format(
        n_fixed, n_trainable)
    T_desc = 'MI evaluation network, # parameters: {}.'.format(n_T)
    class_con_desc = 'Class conditional embedding layer, # parameters: {}.'.format(n_C)
    return base_desc, T_desc, class_con_desc
def run(args: DictConfig) -> None:
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    data_dir = hydra.utils.to_absolute_path(args.data_dir)
    clean_test_data = get_dataset(data_name=args.dataset, data_dir=data_dir, train=False, crop_flip=False)
    # advset_at = TensorDataset(torch.load(os.path.join(data_dir, 'advset_{}_at_fast.pt'.format(args.classifier_name))))
    advset_clean = torch.load(os.path.join(data_dir, 'advset_{}_clean.pt'.format(args.classifier_name)))

    clean_loader = DataLoader(dataset=clean_test_data, batch_size=args.n_batch_test, shuffle=False)
    advset_loader = DataLoader(dataset=advset_clean, batch_size=args.n_batch_test, shuffle=False)

    results_dict = dict()
    for width in args.width_list:
        # Ensemble of classifiers trained on clean data splits.
        classifier_list = []
        for split_id in range(args.n_split):
            classifier = eval(args.classifier_name)(width, args.n_classes).to(args.device)
            logger.info('Classifier: {}, width: {}, # parameters: {}'.format(
                args.classifier_name, width, cal_parameters(classifier)))
            checkpoint = '{}_w{}_split{}.pth'.format(args.classifier_name, width, split_id)
            classifier.load_state_dict(torch.load(checkpoint))
            classifier_list.append(classifier)

        results_dict['clean_on_clean_w{}'.format(width)] = eval_risk_bias_variance(classifier_list, clean_loader, args)
        results_dict['clean_on_adv_w{}'.format(width)] = eval_risk_bias_variance(classifier_list, advset_loader, args)
        del classifier_list

        # Ensemble of classifiers adversarially trained on the same splits.
        classifier_list = []
        for split_id in range(args.n_split):
            classifier = eval(args.classifier_name)(width, args.n_classes).to(args.device)
            logger.info('Classifier: {}, width: {}, # parameters: {}'.format(
                args.classifier_name, width, cal_parameters(classifier)))
            checkpoint = '{}_w{}_split{}_at_fast.pth'.format(args.classifier_name, width, split_id)
            classifier.load_state_dict(torch.load(checkpoint))
            classifier_list.append(classifier)

        results_dict['adv_on_clean_w{}'.format(width)] = eval_risk_bias_variance(classifier_list, clean_loader, args)
        results_dict['adv_on_adv_w{}'.format(width)] = eval_risk_bias_variance(classifier_list, advset_loader, args)

    torch.save(results_dict, 'adv_eval_width_results.pt')
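# Hedged sketch (assumption): `eval_risk_bias_variance` is not shown in this
# section. A minimal version consistent with its usage, averaging a squared
# loss over the split-trained ensemble and decomposing it into bias and
# variance against one-hot targets, might look like:
def eval_risk_bias_variance(classifier_list, loader, args):
    import torch.nn.functional as F
    risk = bias2 = variance = 0.0
    n = 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(args.device), y.to(args.device)
            # Per-split predictive distributions: (n_split, batch, n_classes).
            probs = torch.stack([F.softmax(c(x), dim=1) for c in classifier_list])
            onehot = F.one_hot(y, probs.size(-1)).float()
            mean_p = probs.mean(dim=0)  # ensemble-averaged prediction
            risk += ((probs - onehot) ** 2).sum(dim=-1).mean(dim=0).sum().item()
            bias2 += ((mean_p - onehot) ** 2).sum(dim=-1).sum().item()
            variance += ((probs - mean_p) ** 2).sum(dim=-1).mean(dim=0).sum().item()
            n += x.size(0)
    # For the squared loss, risk = bias2 + variance holds exactly.
    return {'risk': risk / n, 'bias2': bias2 / n, 'variance': variance / n}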
def run(args: DictConfig) -> None:
    cuda_available = torch.cuda.is_available()
    torch.manual_seed(args.seed)
    device = "cuda" if cuda_available and args.device == 'cuda' else "cpu"

    n_classes = args.get(args.dataset).n_classes
    if args.dataset == 'tiny_imagenet':
        args.epochs = 20
        args.learning_rate = 0.001
        classifier = get_model_for_tiny_imagenet(args.classifier_name, n_classes).to(device)
        args.data_dir = 'tiny_imagenet'
    else:
        classifier = get_model(name=args.classifier_name, n_classes=n_classes).to(device)

    # if device == 'cuda' and args.n_gpu > 1:
    #     classifier = torch.nn.DataParallel(classifier, device_ids=list(range(args.n_gpu)))

    logger.info('Base classifier name: {}, # parameters: {}'.format(
        args.classifier_name, cal_parameters(classifier)))

    data_dir = hydra.utils.to_absolute_path(args.data_dir)
    train_data = get_dataset(data_name=args.dataset, data_dir=data_dir, train=True, crop_flip=True)
    test_data = get_dataset(data_name=args.dataset, data_dir=data_dir, train=False, crop_flip=False)

    train_loader = DataLoader(dataset=train_data, batch_size=args.n_batch_train, shuffle=True)
    test_loader = DataLoader(dataset=test_data, batch_size=args.n_batch_test, shuffle=False)

    if args.inference:
        save_name = '{}.pth'.format(args.classifier_name)
        classifier.load_state_dict(torch.load(save_name, map_location=lambda storage, loc: storage))
        loss, acc = run_epoch(classifier, test_loader, args)
        logger.info('Inference loss: {:.4f}, acc: {:.4f}'.format(loss, acc))
    else:
        train(classifier, train_loader, test_loader, args)
def inference(hps: DictConfig) -> None:
    # This enables a Ctrl-C without triggering errors.
    import signal
    signal.signal(signal.SIGINT, lambda x, y: sys.exit(0))

    logger = logging.getLogger(__name__)

    cuda_available = torch.cuda.is_available()
    torch.manual_seed(hps.seed)
    device = "cuda" if cuda_available and hps.device == 'cuda' else "cpu"

    # Models
    local_channel = hps.get(hps.base_classifier).last_conv_channel
    classifier = get_model(model_name=hps.base_classifier, in_size=local_channel,
                           out_size=hps.rep_size).to(hps.device)
    logger.info('Base classifier name: {}, # parameters: {}'.format(
        hps.base_classifier, cal_parameters(classifier)))

    sdim = SDIM(disc_classifier=classifier, mi_units=hps.mi_units,
                n_classes=hps.n_classes, margin=hps.margin,
                rep_size=hps.rep_size, local_channel=local_channel).to(hps.device)

    model_path = 'SDIM_{}.pth'.format(hps.base_classifier)
    base_dir = '/userhome/cs/u3003679/generative-classification-with-rejection'
    path = os.path.join(base_dir, model_path)
    sdim.load_state_dict(torch.load(path)['model_state'])

    # Log the SDIM description.
    for desc in sdim.desc():
        logger.info(desc)

    eval_loader = Loader('eval', batch_size=hps.n_batch_test, device=device)

    if cuda_available and hps.n_gpu > 1:
        sdim = torch.nn.DataParallel(sdim, device_ids=list(range(hps.n_gpu)))

    torch.manual_seed(hps.seed)
    np.random.seed(hps.seed)

    n_iters = 0
    top1 = AverageMeter('Acc@1')
    top5 = AverageMeter('Acc@5')

    sdim.eval()
    for x, y in eval_loader:
        n_iters += 1
        if n_iters == len(eval_loader):
            break

        with torch.no_grad():
            log_lik = sdim.infer(x)

        acc1, acc5 = accuracy(log_lik, y, topk=(1, 5))
        top1.update(acc1, x.size(0))
        top5.update(acc5, x.size(0))

    logger.info('Test Acc@1: {:.3f}, Acc@5: {:.3f}'.format(top1.avg, top5.avg))
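# Hedged sketch: `AverageMeter` and `accuracy` are used above and in the
# training loop further below but are not defined in this section. Their usage
# matches the helpers from the PyTorch ImageNet reference example, roughly:
class AverageMeter:
    """Tracks the running average of a scalar metric."""
    def __init__(self, name):
        self.name = name
        self.reset()

    def reset(self):
        self.sum, self.count, self.avg = 0.0, 0, 0.0

    def update(self, val, n=1):
        self.sum += float(val) * n
        self.count += n
        self.avg = self.sum / self.count


def accuracy(output, target, topk=(1,)):
    """Top-k accuracy (in percent) of `output` scores against `target` labels."""
    maxk = max(topk)
    _, pred = output.topk(maxk, dim=1)
    correct = pred.eq(target.view(-1, 1))
    return [correct[:, :k].float().sum().mul_(100.0 / target.size(0)) for k in topk]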
# prepare test set
pair_filename = os.path.join(args.test_set,
                             'm50_{}_{}_0.txt'.format(args.n_samples_test, args.n_samples_test))
pairs, labels = ReadPairs(pair_filename)
test_set = CustomDataset(pairs, labels, args.test_set, transform=None)
test_loader = DataLoader(test_set, batch_size=100, shuffle=False, num_workers=8,
                         pin_memory=True, drop_last=True)

model = ComposedModel(residual=args.res_feature_net).cuda()
print('Model parameters: {}'.format(cal_parameters(model)))

if args.eval:
    if args.res_feature_net:
        state_dict = torch.load('res_model_{}.pt'.format(args.train_set))
    else:
        state_dict = torch.load('model_{}.pt'.format(args.train_set))
    model.load_state_dict(state_dict)
    model.eval()

    score_list = []
    label_list = []
    for idx, (left, right, label) in enumerate(test_loader):
        left = preprocess(left).cuda()
        right = preprocess(right).cuda()
        label = label.cuda()
def run(args: DictConfig) -> None:
    # Load datasets
    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, padding=4)
    ])
    preprocess = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize([mean_] * 3, [std_] * 3)])
    test_transform = preprocess

    data_dir = hydra.utils.to_absolute_path(args.data_dir)
    if args.dataset == 'cifar10':
        train_data = datasets.CIFAR10(data_dir, train=True, transform=train_transform, download=True)
        test_data = datasets.CIFAR10(data_dir, train=False, transform=test_transform, download=True)
        base_c_path = os.path.join(data_dir, 'CIFAR-10-C/')
        args.n_classes = 10
    else:
        train_data = datasets.CIFAR100(data_dir, train=True, transform=train_transform, download=True)
        test_data = datasets.CIFAR100(data_dir, train=False, transform=test_transform, download=True)
        base_c_path = os.path.join(data_dir, 'CIFAR-100-C/')
        args.n_classes = 100

    train_data = AugMixDataset(train_data, preprocess, args, args.no_jsd)
    train_loader = DataLoader(train_data, batch_size=args.batch_size, shuffle=True,
                              num_workers=args.num_workers, pin_memory=True)
    test_loader = DataLoader(test_data, batch_size=args.eval_batch_size, shuffle=False,
                             num_workers=args.num_workers, pin_memory=True)

    # Create model
    if args.model == 'densenet':
        classifier = densenet(num_classes=args.n_classes)
    elif args.model == 'wide_resnet':
        n_layers = 40
        widen_factor = 2
        droprate = 0.
        classifier = WideResNet(n_layers, args.n_classes, widen_factor, droprate)
    elif args.model == 'resnext':
        classifier = resnext29(num_classes=args.n_classes)
    classifier = classifier.to(args.device)
    logger.info('Model: {}, # parameters: {}'.format(args.model, cal_parameters(classifier)))

    cudnn.benchmark = True
    classifier = torch.nn.DataParallel(classifier).to(args.device)

    if args.inference:
        classifier.load_state_dict(
            torch.load('{}_{}.pth'.format(args.model, args.augmentation_type)))
        test_loss, test_acc = eval_epoch(classifier, test_loader, args, adversarial=False)
        logger.info('Clean Test CE:{:.4f}, acc:{:.4f}'.format(test_loss, test_acc))
    else:
        optimizer = torch.optim.SGD(classifier.parameters(), args.learning_rate,
                                    momentum=args.momentum, weight_decay=args.weight_decay,
                                    nesterov=True)
        best_acc = 0
        pre_adv_acc = 0
        scheduler = torch.optim.lr_scheduler.LambdaLR(
            optimizer,
            lr_lambda=lambda step: get_lr(  # pylint: disable=g-long-lambda
                step,
                args.n_epochs * len(train_loader),
                1,  # lr_lambda computes the multiplicative factor
                1e-6 / args.learning_rate))

        for epoch in range(args.n_epochs):
            if args.augmentation_type:
                loss, ce_loss, js_loss, acc = train_epoch_advmix(
                    classifier, train_loader, args, optimizer, scheduler)
            else:
                loss, ce_loss, js_loss, acc = train_epoch(
                    classifier, train_loader, args, optimizer, scheduler)
            lr = scheduler.get_lr()[0]
            logger.info('Epoch {}, lr:{:.4f}, loss:{:.4f}, CE:{:.4f}, JS:{:.4f}, Acc:{:.4f}'
                        .format(epoch + 1, lr, loss, ce_loss, js_loss, acc))

            test_loss, test_acc = eval_epoch(classifier, test_loader, args, adversarial=False)
            logger.info('Test CE:{:.4f}, acc:{:.4f}'.format(test_loss, test_acc))

            adv_loss, adv_acc = eval_epoch(classifier, test_loader, args, adversarial=True)
            logger.info('Adversarial evaluation, CE:{:.4f}, acc:{:.4f}'.format(adv_loss, adv_acc))

            if test_acc > best_acc:
                best_acc = test_acc
                if adv_acc + 0.1 < pre_adv_acc:
                    pre_adv_acc = adv_acc
                    logger.info('Catastrophic overfitting happens, early stopping')
                    break
                logger.info('===> New optimal, save checkpoint ...')
                torch.save(classifier.state_dict(),
                           '{}_{}.pth'.format(args.model, args.augmentation_type))

    test_c_acc = eval_c(classifier, base_c_path, args)
    logger.info('Mean Corruption Error:{:.4f}'.format(1 - test_c_acc))
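# Hedged sketch: `get_lr` is not defined in this section. The call sites pass
# (step, total_steps, 1, 1e-6 / learning_rate) and expect a multiplicative
# factor for LambdaLR, which matches an AugMix-style cosine decay:
import math

def get_lr(step, total_steps, lr_max, lr_min):
    """Cosine-anneal a multiplier from lr_max down to lr_min over total_steps."""
    return lr_min + (lr_max - lr_min) * 0.5 * (1 + math.cos(step / total_steps * math.pi))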
def run(args: DictConfig) -> None:
    assert torch.cuda.is_available()
    torch.manual_seed(args.seed)

    n_classes = args.get(args.dataset).n_classes
    classifier = resnet18(n_classes=n_classes).to(args.device)
    logger.info('Base classifier resnet18: # parameters {}'.format(cal_parameters(classifier)))

    data_dir = hydra.utils.to_absolute_path(args.data_dir)
    train_data = get_dataset(data_name=args.dataset, data_dir=data_dir, train=True, crop_flip=True)
    test_data = get_dataset(data_name=args.dataset, data_dir=data_dir, train=False, crop_flip=False)

    train_loader = DataLoader(dataset=train_data, batch_size=args.n_batch_train, shuffle=True)
    test_loader = DataLoader(dataset=test_data, batch_size=args.n_batch_test, shuffle=False)

    if args.inference:
        classifier.load_state_dict(torch.load('resnet18_wd{}.pth'.format(args.weight_decay)))
        logger.info('Load classifier from checkpoint')
    else:
        optimizer = torch.optim.SGD(classifier.parameters(), args.lr_max, momentum=args.momentum,
                                    weight_decay=args.weight_decay, nesterov=True)
        lr_steps = args.epochs * len(train_loader)
        if args.lr_schedule == 'cyclic':
            scheduler = torch.optim.lr_scheduler.CyclicLR(
                optimizer, base_lr=args.lr_min, max_lr=args.lr_max,
                step_size_up=lr_steps / 2, step_size_down=lr_steps / 2)
        elif args.lr_schedule == 'multistep':
            scheduler = torch.optim.lr_scheduler.MultiStepLR(
                optimizer, milestones=[lr_steps / 2, lr_steps * 3 / 4], gamma=0.1)
        else:
            raise ValueError('scheduler not implemented.')

        optimal_loss = 1e5
        for epoch in range(1, args.epochs + 1):
            loss, acc = train_epoch(classifier, train_loader, args, optimizer, scheduler)
            lr = scheduler.get_lr()[0]
            logger.info('Epoch {}, lr:{:.4f}, loss:{:.4f}, Acc:{:.4f}'.format(epoch, lr, loss, acc))

            if loss < optimal_loss:
                optimal_loss = loss
                torch.save(classifier.state_dict(), 'resnet18_at.pth')

    clean_loss, clean_acc = eval_epoch(classifier, test_loader, args, adversarial=False)
    logger.info('Clean loss: {:.4f}, acc: {:.4f}'.format(clean_loss, clean_acc))

    adv_loss, adv_acc = eval_epoch(classifier, test_loader, args, adversarial=True)
    logger.info('Adversarial loss: {:.4f}, acc: {:.4f}'.format(adv_loss, adv_acc))
    else:
        train_data = dset.CIFAR100(args.data_path, train=True, transform=train_transform, download=True)
        test_data = dset.CIFAR100(args.data_path, train=False, transform=test_transform, download=True)
        args.n_classes = 100

    train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True,
                                               num_workers=args.prefetch, pin_memory=True)
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=args.test_batch_size, shuffle=False,
                                              num_workers=args.prefetch, pin_memory=True)

    # Init checkpoints
    if not os.path.isdir(args.save):
        os.makedirs(args.save)

    # Init model, criterion, and optimizer
    classifier = CifarResNeXt(args.cardinality, args.depth, args.n_classes, args.base_width,
                              args.widen_factor).to(args.device)
    print('# Classifier parameters: ', cal_parameters(classifier))

    save_name = 'ResNeXt{}_{}x{}d.pth'.format(args.depth, args.cardinality, args.base_width)
    check_point = torch.load(os.path.join(args.save, save_name))
    classifier.load_state_dict(check_point['model_state'])
    train_acc = check_point['train_acc']
    test_acc = check_point['test_acc']
    print('Original Discriminative Classifier, train acc: {:.4f}, test acc: {:.4f}'.format(
        train_acc, test_acc))

    sdim = SDIM(disc_classifier=classifier, rep_size=args.rep_size, mi_units=args.mi_units,
                n_classes=args.n_classes).to(args.device)
    optimizer = torch.optim.Adam(filter(lambda param: param.requires_grad, sdim.parameters()),
                                 lr=args.learning_rate)

    if use_cuda and args.n_gpu > 1:
def run(args: DictConfig) -> None:
    # cuda_available = torch.cuda.is_available()
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    # device = "cuda" if cuda_available and args.device == 'cuda' else "cpu"

    classifier = eval(args.classifier_name)(args.width, args.n_classes).to(args.device)
    logger.info('Classifier: {}, width: {}, # parameters: {}'.format(
        args.classifier_name, args.width, cal_parameters(classifier)))

    data_dir = hydra.utils.to_absolute_path(args.data_dir)
    train_data = get_dataset(data_name=args.dataset, data_dir=data_dir, train=True, crop_flip=True)
    test_data = get_dataset(data_name=args.dataset, data_dir=data_dir, train=False, crop_flip=False)

    test_loader = DataLoader(dataset=test_data, batch_size=args.n_batch_test, shuffle=False)

    optimizer = SGD(classifier.parameters(), lr=args.lr_max, momentum=args.momentum,
                    weight_decay=args.weight_decay)

    def run_forward(scheduler):
        # `checkpoint` and `train_loader` are bound in the enclosing scope below.
        optimal_loss = 1e5
        for epoch in range(1, args.n_epochs + 1):
            loss, acc = train_epoch(classifier, train_loader, args, optimizer, scheduler=scheduler)
            if loss < optimal_loss:
                optimal_loss = loss
                torch.save(classifier.state_dict(), checkpoint)
            logger.info('Epoch {}, lr: {:.4f}, loss: {:.4f}, acc: {:.4f}'.format(
                epoch, scheduler.get_lr()[0], loss, acc))

    if args.adv_generation:
        checkpoint = '{}_w{}_at_fast.pth'.format(args.classifier_name, args.width)
        train_loader = DataLoader(dataset=train_data, batch_size=args.n_batch_train, shuffle=True)

        lr_steps = args.n_epochs * len(train_loader)
        scheduler = lr_scheduler.CyclicLR(optimizer, base_lr=args.lr_min, max_lr=args.lr_max,
                                          step_size_up=lr_steps / 2, step_size_down=lr_steps / 2)
        run_forward(scheduler)

        clean_loss, clean_acc = eval_epoch(classifier, test_loader, args, adversarial=False)
        adv_loss, adv_acc = eval_epoch(classifier, test_loader, args, adversarial=True, save=True)
        logger.info('Clean loss: {:.4f}, acc: {:.4f}'.format(clean_loss, clean_acc))
        logger.info('Adversarial loss: {:.4f}, acc: {:.4f}'.format(adv_loss, adv_acc))
    else:
        # Split the training set into n_split disjoint subsets; the last one
        # absorbs the remainder.
        n = len(train_data)
        split_size = n // args.n_split
        lengths = [split_size] * (args.n_split - 1) + [n % split_size + split_size]
        datasets_list = random_split(train_data, lengths=lengths)

        for split_id, dataset in enumerate(datasets_list):
            checkpoint = '{}_w{}_split{}_at_fast.pth'.format(args.classifier_name, args.width, split_id)
            logger.info('Running on subset {}, size: {}'.format(split_id + 1, len(dataset)))
            train_loader = DataLoader(dataset=dataset, batch_size=args.n_batch_train, shuffle=True)

            lr_steps = args.n_epochs * len(train_loader)
            scheduler = lr_scheduler.CyclicLR(optimizer, base_lr=args.lr_min, max_lr=args.lr_max,
                                              step_size_up=lr_steps / 2, step_size_down=lr_steps / 2)
            run_forward(scheduler)

            clean_loss, clean_acc = eval_epoch(classifier, test_loader, args, adversarial=False)
            adv_loss, adv_acc = eval_epoch(classifier, test_loader, args, adversarial=True)
            logger.info('Clean loss: {:.4f}, acc: {:.4f}'.format(clean_loss, clean_acc))
            logger.info('Adversarial loss: {:.4f}, acc: {:.4f}'.format(adv_loss, adv_acc))
def run(args: DictConfig) -> None:
    cuda_available = torch.cuda.is_available()
    torch.manual_seed(args.seed)
    device = "cuda" if cuda_available and args.device == 'cuda' else "cpu"

    # n_classes = args.n_classes
    # classifier = get_model(name=args.classifier_name, n_classes=n_classes).to(device)
    classifier = PreActResNet18().to(device)
    logger.info('Classifier: {}, # parameters: {}'.format(
        args.classifier_name, cal_parameters(classifier)))

    data_dir = hydra.utils.to_absolute_path(args.data_dir)
    train_data = get_dataset(data_name=args.dataset, data_dir=data_dir, train=True, crop_flip=True)
    test_data = get_dataset(data_name=args.dataset, data_dir=data_dir, train=False, crop_flip=False)

    train_loader = DataLoader(dataset=train_data, batch_size=args.n_batch_train, shuffle=True)
    test_loader = DataLoader(dataset=test_data, batch_size=args.n_batch_test, shuffle=False)

    if args.inference:
        classifier.load_state_dict(torch.load('{}_at.pth'.format(args.classifier_name)))
        logger.info('Load classifier from checkpoint')
    else:
        # optimizer = SGD(classifier.parameters(), lr=args.lr_max, momentum=args.momentum,
        #                 weight_decay=args.weight_decay)
        # lr_steps = args.n_epochs * len(train_loader)
        # scheduler = lr_scheduler.CyclicLR(optimizer, base_lr=args.lr_min, max_lr=args.lr_max,
        #                                   step_size_up=lr_steps/2, step_size_down=lr_steps/2)
        optimizer = torch.optim.SGD(classifier.parameters(), args.learning_rate,
                                    momentum=args.momentum, weight_decay=args.weight_decay,
                                    nesterov=True)
        scheduler = torch.optim.lr_scheduler.LambdaLR(
            optimizer,
            lr_lambda=lambda step: get_lr(  # pylint: disable=g-long-lambda
                step,
                args.n_epochs * len(train_loader),
                1,  # lr_lambda computes the multiplicative factor
                1e-6 / args.learning_rate))

        optimal_loss = 1e5
        for epoch in range(1, args.n_epochs + 1):
            loss, acc = train_epoch(classifier, train_loader, args, optimizer, scheduler=scheduler)
            lr = scheduler.get_lr()[0]
            logger.info('Epoch {}, lr:{:.4f}, loss:{:.4f}, Acc:{:.4f}'.format(epoch, lr, loss, acc))

            if loss < optimal_loss:
                optimal_loss = loss
                torch.save(classifier.state_dict(), '{}_at.pth'.format(args.classifier_name))

    clean_loss, clean_acc = eval_epoch(classifier, test_loader, args, adversarial=False)
    adv_loss, adv_acc = eval_epoch(classifier, test_loader, args, adversarial=True)
    logger.info('Clean loss: {:.4f}, acc: {:.4f}'.format(clean_loss, clean_acc))
    logger.info('Adversarial loss: {:.4f}, acc: {:.4f}'.format(adv_loss, adv_acc))
                       map_location=lambda storage, loc: storage))
    else:
        n_encoder_layers = int(hps.encoder_name.strip('resnet'))
        model = build_resnet_32x32(n=n_encoder_layers, fc_size=hps.n_classes,
                                   image_channel=hps.image_channel).to(hps.device)
        checkpoint_path = os.path.join(hps.log_dir, '{}_{}.pth'.format(hps.encoder_name, hps.problem))
        model.load_state_dict(torch.load(checkpoint_path, map_location=lambda storage, loc: storage))

    print('Model name: {}'.format(hps.encoder_name))
    print('==> # Model parameters: {}.'.format(cal_parameters(model)))

    if not os.path.exists(hps.log_dir):
        os.mkdir(hps.log_dir)
    if not os.path.exists(hps.attack_dir):
        os.mkdir(hps.attack_dir)

    if hps.attack == 'pgdinf':
        linfPGD_attack(model, hps)
    elif hps.attack == 'jsma':
        jsma_attack(model, hps)
    elif hps.attack == 'cw':
        cw_l2_attack(model, hps)
    elif hps.attack == 'fgsm':
        fgsm_attack(model, hps)
def train(hps: DictConfig) -> None:
    # This enables a Ctrl-C without triggering errors.
    import signal
    signal.signal(signal.SIGINT, lambda x, y: sys.exit(0))

    logger = logging.getLogger(__name__)

    cuda_available = torch.cuda.is_available()
    torch.manual_seed(hps.seed)
    device = "cuda" if cuda_available and hps.device == 'cuda' else "cpu"

    # Models
    local_channel = hps.get(hps.base_classifier).last_conv_channel
    classifier = get_model(model_name=hps.base_classifier, in_size=local_channel,
                           out_size=hps.rep_size).to(hps.device)
    logger.info('Base classifier name: {}, # parameters: {}'.format(
        hps.base_classifier, cal_parameters(classifier)))

    sdim = SDIM(disc_classifier=classifier, mi_units=hps.mi_units,
                n_classes=hps.n_classes, margin=hps.margin,
                rep_size=hps.rep_size, local_channel=local_channel).to(hps.device)

    # Log the SDIM description.
    for desc in sdim.desc():
        logger.info(desc)

    train_loader = Loader('train', batch_size=hps.n_batch_train, device=device)

    if cuda_available and hps.n_gpu > 1:
        sdim = torch.nn.DataParallel(sdim, device_ids=list(range(hps.n_gpu)))

    optimizer = Adam(filter(lambda param: param.requires_grad, sdim.parameters()), lr=hps.lr)

    torch.manual_seed(hps.seed)
    np.random.seed(hps.seed)

    # Create log dir
    logdir = os.path.abspath(hps.log_dir) + "/"
    if not os.path.exists(logdir):
        os.mkdir(logdir)

    loss_optimal = 1e5
    n_iters = 0

    losses = AverageMeter('Loss')
    MIs = AverageMeter('MI')
    nlls = AverageMeter('NLL')
    margins = AverageMeter('Margin')
    top1 = AverageMeter('Acc@1')
    top5 = AverageMeter('Acc@5')

    for x, y in train_loader:
        n_iters += 1
        if n_iters == hps.training_iters:
            break

        # backward
        optimizer.zero_grad()
        loss, mi_loss, nll_loss, ll_margin, log_lik = sdim(x, y)
        loss.mean().backward()
        optimizer.step()

        acc1, acc5 = accuracy(log_lik, y, topk=(1, 5))
        losses.update(loss.item(), x.size(0))
        top1.update(acc1, x.size(0))
        top5.update(acc5, x.size(0))

        MIs.update(mi_loss.item(), x.size(0))
        nlls.update(nll_loss.item(), x.size(0))
        margins.update(ll_margin.item(), x.size(0))

        if n_iters % hps.log_interval == hps.log_interval - 1:
            logger.info('Train loss: {:.4f}, mi: {:.4f}, nll: {:.4f}, ll_margin: {:.4f}'.format(
                losses.avg, MIs.avg, nlls.avg, margins.avg))
            logger.info('Train Acc@1: {:.3f}, Acc@5: {:.3f}'.format(top1.avg, top5.avg))

            if losses.avg < loss_optimal:
                loss_optimal = losses.avg
                model_path = 'SDIM_{}.pth'.format(hps.base_classifier)

                if cuda_available and hps.n_gpu > 1:
                    state = sdim.module.state_dict()
                else:
                    state = sdim.state_dict()

                check_point = {'model_state': state,
                               'train_acc_top1': top1.avg,
                               'train_acc_top5': top5.avg}
                torch.save(check_point, os.path.join(hps.log_dir, model_path))

            losses.reset()
            MIs.reset()
            nlls.reset()
            margins.reset()
            top1.reset()
            top5.reset()
                                               num_workers=args.prefetch, pin_memory=True)
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=args.test_batch_size, shuffle=False,
                                              num_workers=args.prefetch, pin_memory=True)

    # Init checkpoints
    if not os.path.isdir(args.save):
        os.makedirs(args.save)

    # Init model, criterion, and optimizer
    net = CifarResNeXt(args.cardinality, args.depth, n_classes, args.base_width,
                       args.widen_factor).to(args.device)
    print('# Classifier parameters: ', cal_parameters(net))

    if use_cuda and args.n_gpu > 1:
        net = torch.nn.DataParallel(net, device_ids=list(range(args.n_gpu)))

    optimizer = torch.optim.SGD(net.parameters(), lr=args.learning_rate, momentum=args.momentum,
                                weight_decay=args.decay, nesterov=True)

    best_train_loss = np.inf
    best_accuracy = 0.

    # train function (forward, backward, update)
    def train():
def run(args: DictConfig) -> None:
    # Load datasets
    train_transform = transforms.Compose(
        [transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, padding=4)])
    preprocess = transforms.ToTensor()
    test_transform = preprocess

    data_dir = hydra.utils.to_absolute_path(args.data_dir)
    if args.dataset == 'cifar10':
        train_data = datasets.CIFAR10(data_dir, train=True, transform=train_transform, download=True)
        test_data = datasets.CIFAR10(data_dir, train=False, transform=test_transform, download=True)
        base_c_path = os.path.join(data_dir, 'CIFAR-10-C/')
        # args.n_classes = 10
    else:
        train_data = datasets.CIFAR100(data_dir, train=True, transform=train_transform, download=True)
        test_data = datasets.CIFAR100(data_dir, train=False, transform=test_transform, download=True)
        base_c_path = os.path.join(data_dir, 'CIFAR-100-C/')
        # args.n_classes = 100

    train_data = AugMixDataset(train_data, preprocess, args, args.no_jsd)
    train_loader = DataLoader(train_data, batch_size=args.batch_size, shuffle=True,
                              num_workers=args.num_workers, pin_memory=True)
    test_loader = DataLoader(test_data, batch_size=args.eval_batch_size, shuffle=False,
                             num_workers=args.num_workers, pin_memory=True)

    n_classes = args.get(args.dataset).n_classes
    classifier = resnet18(n_classes=n_classes).to(args.device)
    logger.info('Model resnet18, # parameters: {}'.format(cal_parameters(classifier)))

    cudnn.benchmark = True

    if args.inference:
        classifier.load_state_dict(torch.load('resnet18_c.pth'))
        test_loss, test_acc = eval_epoch(classifier, test_loader, args)
        logger.info('Clean Test CE:{:.4f}, acc:{:.4f}'.format(test_loss, test_acc))
    else:
        optimizer = torch.optim.SGD(classifier.parameters(), args.learning_rate,
                                    momentum=args.momentum, weight_decay=args.weight_decay,
                                    nesterov=True)
        best_loss = 1e5
        scheduler = torch.optim.lr_scheduler.LambdaLR(
            optimizer,
            lr_lambda=lambda step: get_lr(  # pylint: disable=g-long-lambda
                step,
                args.epochs * len(train_loader),
                1,  # lr_lambda computes the multiplicative factor
                1e-6 / args.learning_rate))

        for epoch in range(args.epochs):
            loss, ce_loss, js_loss, acc = train_epoch(classifier, train_loader, args,
                                                      optimizer, scheduler)
            lr = scheduler.get_lr()[0]
            logger.info('Epoch {}, lr:{:.4f}, loss:{:.4f}, CE:{:.4f}, JS:{:.4f}, Acc:{:.4f}'
                        .format(epoch + 1, lr, loss, ce_loss, js_loss, acc))

            test_loss, test_acc = eval_epoch(classifier, test_loader, args)
            logger.info('Clean test CE:{:.4f}, acc:{:.4f}'.format(test_loss, test_acc))

            if loss < best_loss:
                best_loss = loss
                logger.info('===> New optimal, save checkpoint ...')
                torch.save(classifier.state_dict(), 'resnet18_c.pth')

    test_c_acc = eval_c(classifier, base_c_path, args)
    logger.info('Mean Corruption Error:{:.4f}'.format(1 - test_c_acc))
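# Hedged sketch: `eval_c` is assumed to follow the standard CIFAR-10-C/100-C
# protocol, i.e. for each corruption load a .npy image tensor plus the shared
# labels.npy and report accuracy averaged over all corruptions. The corruption
# names below are the standard CIFAR-C set; the batching details are assumptions.
CORRUPTIONS = [
    'gaussian_noise', 'shot_noise', 'impulse_noise', 'defocus_blur', 'glass_blur',
    'motion_blur', 'zoom_blur', 'snow', 'frost', 'fog', 'brightness', 'contrast',
    'elastic_transform', 'pixelate', 'jpeg_compression',
]

def eval_c(classifier, base_c_path, args):
    """Mean top-1 accuracy over all CIFAR-C corruptions (all severities pooled)."""
    classifier.eval()
    accs = []
    for corruption in CORRUPTIONS:
        images = np.load(os.path.join(base_c_path, corruption + '.npy'))  # (N, 32, 32, 3) uint8
        labels = np.load(os.path.join(base_c_path, 'labels.npy'))
        correct = 0
        for i in range(0, len(images), args.eval_batch_size):
            xb = torch.from_numpy(images[i:i + args.eval_batch_size]).permute(0, 3, 1, 2)
            xb = xb.float().div(255.).to(args.device)  # matches the ToTensor-only preprocess above
            yb = torch.from_numpy(labels[i:i + args.eval_batch_size]).long().to(args.device)
            with torch.no_grad():
                correct += classifier(xb).argmax(dim=1).eq(yb).sum().item()
        accs.append(correct / len(images))
    return sum(accs) / len(accs)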
        )

    def forward(self, x):
        return self.fc(x)


class Projection(nn.Module):
    def __init__(self, in_dim=4096, hidden_size=1024):
        super(Projection, self).__init__()  # was super(MetricNet, ...), which raises a TypeError here.
        self.fc = nn.Sequential(
            nn.Linear(in_dim, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size)
        )

    def forward(self, x):
        return self.fc(x)


if __name__ == "__main__":
    x = torch.randn(1, 1, 64, 64)
    # m = FeatureNet()
    m = ResFeatureNet()
    o = m(x)
    print(o.size())

    from utils import cal_parameters
    print(cal_parameters(m))