def test(**kwargs):
    """Evaluate a trained model at a specific rejection cost, optionally under
    an adversarial perturbation.

    All keyword arguments are forwarded to ``FlagHolder.initialize`` and read
    back as ``FLAGS.*`` (dataset, weight path, attack settings, ...).

    Returns:
        dict: the averaged test metrics (``MetricDict.avg``).
    """
    FLAGS = FlagHolder()
    FLAGS.initialize(**kwargs)
    FLAGS.summary()
    assert FLAGS.nb_its > 0
    assert FLAGS.attack_eps >= 0

    # A non-None binary target class switches the whole pipeline (model head,
    # loss, attacker, evaluator) into binary-classification mode.
    is_binary = FLAGS.binary_target_class is not None

    # dataset
    dataset_builder = DatasetBuilder(name=FLAGS.dataset,
                                     root_path=FLAGS.dataroot)
    test_dataset = dataset_builder(
        train=False,
        normalize=FLAGS.normalize,
        binary_classification_target=FLAGS.binary_target_class)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=FLAGS.batch_size,
                                              shuffle=False,
                                              num_workers=FLAGS.num_workers,
                                              pin_memory=True)

    # model: one output unit in binary mode, one per class otherwise
    features = vgg16_variant(dataset_builder.input_size,
                             FLAGS.dropout_prob).cuda()
    num_outputs = 1 if is_binary else dataset_builder.num_classes
    model = DeepLinearSvmWithRejector(features, FLAGS.dim_features,
                                      num_outputs).cuda()
    load_model(model, FLAGS.weight)

    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)

    # loss
    if is_binary:
        criterion = MaxHingeLossBinaryWithRejection(FLAGS.cost)
    else:
        criterion = MaxHingeLossWithRejection(FLAGS.cost)

    # adversarial attack
    if FLAGS.attack:
        # get step_size
        if not FLAGS.step_size:
            FLAGS.step_size = get_step_size(FLAGS.attack_eps, FLAGS.nb_its)
        assert FLAGS.step_size >= 0

        # create attacker; the two former branches differed only in
        # is_binary_classification, so they are collapsed into one call
        if FLAGS.attack == 'pgd':
            attacker = PGDAttackVariant(
                FLAGS.nb_its, FLAGS.attack_eps, FLAGS.step_size,
                dataset=FLAGS.dataset, cost=FLAGS.cost,
                norm=FLAGS.attack_norm,
                num_classes=dataset_builder.num_classes,
                is_binary_classification=is_binary)
        else:
            raise NotImplementedError('invalid attack method.')

    # pre epoch
    test_metric_dict = MetricDict()

    # test
    for i, (x, t) in enumerate(test_loader):
        model.eval()
        x = x.to('cuda', non_blocking=True)
        t = t.to('cuda', non_blocking=True)
        loss_dict = OrderedDict()

        # adversarial samples (attack runs with gradients enabled, before the
        # no_grad evaluation below)
        if FLAGS.attack and FLAGS.attack_eps > 0:
            model.zero_grad()
            x = attacker(model, x.detach(), t.detach())

        with torch.autograd.no_grad():
            model.zero_grad()
            # forward
            out_class, out_reject = model(x)

            # selective max-hinge loss
            maxhinge_loss, loss_dict = criterion(out_class, out_reject, t)
            loss_dict['maxhinge_loss'] = maxhinge_loss.detach().cpu().item()

            # total loss (weight-penalty regularization is currently disabled)
            loss = maxhinge_loss
            loss_dict['loss'] = loss.detach().cpu().item()

            # evaluation: binary outputs are flattened to 1-D first
            if is_binary:
                evaluator = Evaluator(out_class.detach().view(-1),
                                      t.detach().view(-1),
                                      out_reject.detach().view(-1),
                                      FLAGS.cost)
            else:
                evaluator = Evaluator(out_class.detach(), t.detach(),
                                      out_reject.detach(), FLAGS.cost)
            loss_dict.update(evaluator())
            test_metric_dict.update(loss_dict)

    # post epoch
    print_metric_dict(None, None, test_metric_dict.avg, mode='test')

    return test_metric_dict.avg
def train(**kwargs):
    """Adversarially train a binary SVM-with-rejector model.

    Keyword arguments are forwarded to ``FlagHolder.initialize`` and read back
    as ``FLAGS.*``.  Trains with SGD + StepLR, optionally generating PGD
    adversarial examples each batch, logs per-epoch train/val metrics, and
    saves the final weights under ``FLAGS.log_dir``.
    """
    FLAGS = FlagHolder()
    FLAGS.initialize(**kwargs)
    FLAGS.summary()
    os.makedirs(FLAGS.log_dir, exist_ok=True)
    FLAGS.dump(
        path=os.path.join(FLAGS.log_dir, 'flags{}.json'.format(FLAGS.suffix)))

    # dataset
    dataset_builder = DatasetBuilder(name=FLAGS.dataset,
                                     root_path=FLAGS.dataroot)
    train_dataset = dataset_builder(
        train=True,
        normalize=FLAGS.normalize,
        binary_classification_target=FLAGS.binary_target_class)
    val_dataset = dataset_builder(
        train=False,
        normalize=FLAGS.normalize,
        binary_classification_target=FLAGS.binary_target_class)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=FLAGS.batch_size,
                                               shuffle=True,
                                               num_workers=FLAGS.num_workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=FLAGS.batch_size,
                                             shuffle=False,
                                             num_workers=FLAGS.num_workers,
                                             pin_memory=True)

    # model — num_classes=1: this entry point always trains the binary head
    # (it pairs with MaxHingeLossBinaryWithRejection below)
    features = vgg16_variant(dataset_builder.input_size,
                             FLAGS.dropout_prob).cuda()
    model = DeepLinearSvmWithRejector(features, FLAGS.dim_features,
                                      num_classes=1).cuda()
    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)

    # optimizer: SGD with LR halved every 25 epochs
    params = model.parameters()
    optimizer = torch.optim.SGD(params,
                                lr=FLAGS.lr,
                                momentum=FLAGS.momentum,
                                weight_decay=FLAGS.wd)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=25,
                                                gamma=0.5)

    # loss
    MHBRLoss = MaxHingeLossBinaryWithRejection(FLAGS.cost)

    # attacker (adversarial training only when an attack name AND a positive
    # epsilon are given)
    if FLAGS.at and FLAGS.at_eps > 0:
        # get step_size
        if not FLAGS.step_size:
            FLAGS.step_size = get_step_size(FLAGS.at_eps, FLAGS.nb_its)
        assert FLAGS.step_size >= 0

        # create attacker
        if FLAGS.at == 'pgd':
            attacker = PGDAttackVariant(
                FLAGS.nb_its, FLAGS.at_eps, FLAGS.step_size,
                dataset=FLAGS.dataset, cost=FLAGS.cost,
                norm=FLAGS.at_norm,
                num_classes=dataset_builder.num_classes,
                is_binary_classification=True)
        else:
            raise NotImplementedError('invalid at method.')

    # logger
    train_logger = Logger(path=os.path.join(
        FLAGS.log_dir, 'train_log{}.csv'.format(FLAGS.suffix)),
                          mode='train',
                          use_wandb=False,
                          flags=FLAGS._dict)
    val_logger = Logger(path=os.path.join(
        FLAGS.log_dir, 'val_log{}.csv'.format(FLAGS.suffix)),
                        mode='val',
                        use_wandb=FLAGS.use_wandb,
                        flags=FLAGS._dict)

    for ep in range(FLAGS.num_epochs):
        # pre epoch
        train_metric_dict = MetricDict()
        val_metric_dict = MetricDict()

        # train
        for i, (x, t) in enumerate(train_loader):
            x = x.to('cuda', non_blocking=True)
            t = t.to('cuda', non_blocking=True)

            # adversarial attack: craft perturbed inputs in eval mode, then
            # switch back to train mode for the actual update
            if FLAGS.at and FLAGS.at_eps > 0:
                model.eval()
                model.zero_grad()
                x = attacker(model, x.detach(), t.detach())

            # forward
            model.train()
            model.zero_grad()
            out_class, out_reject = model(x)

            # compute selective loss
            loss_dict = OrderedDict()
            # loss dict includes, 'A mean' / 'B mean'
            maxhinge_loss, loss_dict = MHBRLoss(out_class, out_reject, t)
            loss_dict['maxhinge_loss'] = maxhinge_loss.detach().cpu().item()

            # regularization_loss = 0.5*WeightPenalty()(model.classifier)
            # loss_dict['regularization_loss'] = regularization_loss.detach().cpu().item()

            # total loss (weight-penalty regularization currently disabled)
            loss = maxhinge_loss  #+ regularization_loss
            loss_dict['loss'] = loss.detach().cpu().item()

            # backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_metric_dict.update(loss_dict)

        # validation
        for i, (x, t) in enumerate(val_loader):
            x = x.to('cuda', non_blocking=True)
            t = t.to('cuda', non_blocking=True)

            # adversarial attack (gradients enabled; the no_grad block below
            # covers only the evaluation forward pass)
            if FLAGS.at and FLAGS.at_eps > 0:
                model.eval()
                model.zero_grad()
                x = attacker(model, x.detach(), t.detach())

            with torch.autograd.no_grad():
                # forward
                model.eval()
                model.zero_grad()
                out_class, out_reject = model(x)

                # compute selective loss
                loss_dict = OrderedDict()
                # loss dict includes, 'A mean' / 'B mean'
                maxhinge_loss, loss_dict = MHBRLoss(out_class, out_reject, t)
                loss_dict['maxhinge_loss'] = maxhinge_loss.detach().cpu().item(
                )

                # regularization_loss = 0.5*WeightPenalty()(model.classifier)
                # loss_dict['regularization_loss'] = regularization_loss.detach().cpu().item()

                # total loss
                loss = maxhinge_loss  #+ regularization_loss
                loss_dict['loss'] = loss.detach().cpu().item()

                # evaluation
                # NOTE(review): unlike test(), no FLAGS.cost is passed here —
                # presumably Evaluator has a default cost; confirm this is
                # intentional.
                evaluator = Evaluator(out_class.detach().view(-1),
                                      t.detach().view(-1),
                                      out_reject.detach().view(-1))
                loss_dict.update(evaluator())
                val_metric_dict.update(loss_dict)

        # post epoch
        # print_metric_dict(ep, FLAGS.num_epochs, train_metric_dict.avg, mode='train')
        print_metric_dict(ep, FLAGS.num_epochs, val_metric_dict.avg, mode='val')
        train_logger.log(train_metric_dict.avg, step=(ep + 1))
        val_logger.log(val_metric_dict.avg, step=(ep + 1))
        scheduler.step()

    # post training
    save_model(model,
               path=os.path.join(FLAGS.log_dir,
                                 'weight_final{}.pth'.format(FLAGS.suffix)))
# NOTE(review): this loop appears to be the tail of a weight-initialization
# helper whose `def` line lies outside this chunk; `module` is presumably its
# parameter (an nn.Module) — confirm against the full file before editing.
for m in module.modules():
    if isinstance(m, torch.nn.Conv2d):
        # He initialization suited to ReLU conv layers
        torch.nn.init.kaiming_normal_(m.weight,
                                      mode='fan_out',
                                      nonlinearity='relu')
        if m.bias is not None:
            torch.nn.init.constant_(m.bias, 0)
    elif isinstance(m, torch.nn.BatchNorm1d):
        # identity affine transform at start: scale 1, shift 0
        torch.nn.init.constant_(m.weight, 1)
        torch.nn.init.constant_(m.bias, 0)
    elif isinstance(m, torch.nn.Linear):
        torch.nn.init.normal_(m.weight, 0, 0.01)
        torch.nn.init.constant_(m.bias, 0)


if __name__ == '__main__':
    # smoke test: build the model and print classifier weight / Gram shapes
    import os
    import sys
    base = os.path.join(os.path.dirname(os.path.abspath(__file__)), '../')
    sys.path.append(base)

    from atro.vgg_variant import vgg16_variant

    features = vgg16_variant(32, 0.3).cuda()
    model = DeepLinearSvmWithRejector(features, 512, 10).cuda()

    for m in model.classifier.modules():
        if isinstance(m, torch.nn.Linear):
            print(m.weight.shape)
            print(torch.matmul(m.weight.t(), m.weight).shape)