def train(**kwargs):
    FLAGS = FlagHolder()
    FLAGS.initialize(**kwargs)
    FLAGS.summary()
    os.makedirs(FLAGS.log_dir, exist_ok=True)
    FLAGS.dump(
        path=os.path.join(FLAGS.log_dir, 'flags{}.json'.format(FLAGS.suffix)))

    # dataset
    dataset_builder = DatasetBuilder(name=FLAGS.dataset,
                                     root_path=FLAGS.dataroot)
    train_dataset = dataset_builder(train=True, normalize=FLAGS.normalize)
    val_dataset = dataset_builder(train=False, normalize=FLAGS.normalize)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=FLAGS.batch_size,
                                               shuffle=True,
                                               num_workers=FLAGS.num_workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=FLAGS.batch_size,
                                             shuffle=False,
                                             num_workers=FLAGS.num_workers,
                                             pin_memory=True)

    # model
    features = vgg16_variant(dataset_builder.input_size,
                             FLAGS.dropout_prob).cuda()
    model = SelectiveNet(features, FLAGS.dim_features,
                         dataset_builder.num_classes).cuda()
    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)

    # optimizer
    params = model.parameters()
    optimizer = torch.optim.SGD(params,
                                lr=FLAGS.lr,
                                momentum=FLAGS.momentum,
                                weight_decay=FLAGS.wd)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=25,
                                                gamma=0.5)

    # loss
    base_loss = torch.nn.CrossEntropyLoss(reduction='none')
    SelectiveCELoss = SelectiveLoss(base_loss, coverage=FLAGS.coverage)

    # logger
    train_logger = Logger(path=os.path.join(
        FLAGS.log_dir, 'train_log{}.csv'.format(FLAGS.suffix)), mode='train')
    val_logger = Logger(path=os.path.join(
        FLAGS.log_dir, 'val_log{}.csv'.format(FLAGS.suffix)), mode='val')

    for ep in range(FLAGS.num_epochs):
        # pre epoch
        train_metric_dict = MetricDict()
        val_metric_dict = MetricDict()

        # train
        model.train()
        for i, (x, t) in enumerate(train_loader):
            x = x.to('cuda', non_blocking=True)
            t = t.to('cuda', non_blocking=True)

            # forward
            out_class, out_select, out_aux = model(x)

            # compute selective loss
            # loss dict includes 'empirical_risk' / 'emprical_coverage' / 'penulty'
            selective_loss, loss_dict = SelectiveCELoss(out_class, out_select, t)
            selective_loss *= FLAGS.alpha
            loss_dict['selective_loss'] = selective_loss.detach().cpu().item()

            # compute standard cross entropy loss
            ce_loss = torch.nn.CrossEntropyLoss()(out_aux, t)
            ce_loss *= (1.0 - FLAGS.alpha)
            loss_dict['ce_loss'] = ce_loss.detach().cpu().item()

            # total loss
            loss = selective_loss + ce_loss
            loss_dict['loss'] = loss.detach().cpu().item()

            # backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_metric_dict.update(loss_dict)

        # validation
        with torch.no_grad():
            model.eval()
            for i, (x, t) in enumerate(val_loader):
                x = x.to('cuda', non_blocking=True)
                t = t.to('cuda', non_blocking=True)

                # forward
                out_class, out_select, out_aux = model(x)

                # compute selective loss
                # loss dict includes 'empirical_risk' / 'emprical_coverage' / 'penulty'
                selective_loss, loss_dict = SelectiveCELoss(out_class, out_select, t)
                selective_loss *= FLAGS.alpha
                loss_dict['selective_loss'] = selective_loss.detach().cpu().item()

                # compute standard cross entropy loss
                ce_loss = torch.nn.CrossEntropyLoss()(out_aux, t)
                ce_loss *= (1.0 - FLAGS.alpha)
                loss_dict['ce_loss'] = ce_loss.detach().cpu().item()

                # total loss
                loss = selective_loss + ce_loss
                loss_dict['loss'] = loss.detach().cpu().item()

                # evaluation
                evaluator = Evaluator(out_class.detach(), t.detach(),
                                      out_select.detach())
                loss_dict.update(evaluator())

                val_metric_dict.update(loss_dict)

        # post epoch
        # print_metric_dict(ep, FLAGS.num_epochs, train_metric_dict.avg, mode='train')
        print_metric_dict(ep, FLAGS.num_epochs, val_metric_dict.avg, mode='val')
        train_logger.log(train_metric_dict.avg, step=(ep + 1))
        val_logger.log(val_metric_dict.avg, step=(ep + 1))
        scheduler.step()

    # post training
    save_model(model,
               path=os.path.join(FLAGS.log_dir,
                                 'weight_final{}.pth'.format(FLAGS.suffix)))
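
# For reference, a minimal sketch of the objective a SelectiveLoss like the
# one above could implement, following the SelectiveNet paper: empirical
# selective risk plus a quadratic coverage penalty. The class name, the `lm`
# penalty weight, and the dict keys here are illustrative assumptions, not
# the repository's actual implementation.
class SelectiveLossSketch:

    def __init__(self, base_loss, coverage, lm=32.0):
        self.base_loss = base_loss  # per-sample loss with reduction='none'
        self.coverage = coverage    # target coverage c in (0, 1]
        self.lm = lm                # penalty weight (lambda = 32 in the paper)

    def __call__(self, out_class, out_select, target):
        # selection head g(x) in [0, 1]; empirical coverage is its batch mean
        g = out_select.view(-1)
        emp_coverage = g.mean()
        # empirical selective risk: g-weighted loss normalized by coverage
        emp_risk = (self.base_loss(out_class, target) * g).mean() / emp_coverage
        # quadratic penalty pushing empirical coverage up toward the target
        penalty = self.lm * torch.clamp(self.coverage - emp_coverage, min=0) ** 2
        loss = emp_risk + penalty
        loss_dict = OrderedDict(
            empirical_risk=emp_risk.detach().cpu().item(),
            empirical_coverage=emp_coverage.detach().cpu().item(),
            penalty=penalty.detach().cpu().item())
        return loss, loss_dict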
def train(**kwargs):
    FLAGS = FlagHolder()
    FLAGS.initialize(**kwargs)
    FLAGS.summary()
    os.makedirs(FLAGS.log_dir, exist_ok=True)
    FLAGS.dump(
        path=os.path.join(FLAGS.log_dir, 'flags{}.json'.format(FLAGS.suffix)))

    # dataset
    dataset_builder = DatasetBuilder(name=FLAGS.dataset,
                                     root_path=FLAGS.dataroot)
    train_dataset = dataset_builder(
        train=True,
        normalize=FLAGS.normalize,
        binary_classification_target=FLAGS.binary_target_class)
    val_dataset = dataset_builder(
        train=False,
        normalize=FLAGS.normalize,
        binary_classification_target=FLAGS.binary_target_class)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=FLAGS.batch_size,
                                               shuffle=True,
                                               num_workers=FLAGS.num_workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=FLAGS.batch_size,
                                             shuffle=False,
                                             num_workers=FLAGS.num_workers,
                                             pin_memory=True)

    # model
    features = vgg16_variant(dataset_builder.input_size,
                             FLAGS.dropout_prob).cuda()
    model = DeepLinearSvmWithRejector(features, FLAGS.dim_features,
                                      num_classes=1).cuda()
    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)

    # optimizer
    params = model.parameters()
    optimizer = torch.optim.SGD(params,
                                lr=FLAGS.lr,
                                momentum=FLAGS.momentum,
                                weight_decay=FLAGS.wd)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=25,
                                                gamma=0.5)

    # loss
    MHBRLoss = MaxHingeLossBinaryWithRejection(FLAGS.cost)

    # attacker
    if FLAGS.at and FLAGS.at_eps > 0:
        # get step_size
        if not FLAGS.step_size:
            FLAGS.step_size = get_step_size(FLAGS.at_eps, FLAGS.nb_its)
        assert FLAGS.step_size >= 0

        # create attacker
        if FLAGS.at == 'pgd':
            attacker = PGDAttackVariant(
                FLAGS.nb_its,
                FLAGS.at_eps,
                FLAGS.step_size,
                dataset=FLAGS.dataset,
                cost=FLAGS.cost,
                norm=FLAGS.at_norm,
                num_classes=dataset_builder.num_classes,
                is_binary_classification=True)
        else:
            raise NotImplementedError('invalid at method.')

    # logger
    train_logger = Logger(path=os.path.join(
        FLAGS.log_dir, 'train_log{}.csv'.format(FLAGS.suffix)),
                          mode='train',
                          use_wandb=False,
                          flags=FLAGS._dict)
    val_logger = Logger(path=os.path.join(
        FLAGS.log_dir, 'val_log{}.csv'.format(FLAGS.suffix)),
                        mode='val',
                        use_wandb=FLAGS.use_wandb,
                        flags=FLAGS._dict)

    for ep in range(FLAGS.num_epochs):
        # pre epoch
        train_metric_dict = MetricDict()
        val_metric_dict = MetricDict()

        # train
        for i, (x, t) in enumerate(train_loader):
            x = x.to('cuda', non_blocking=True)
            t = t.to('cuda', non_blocking=True)

            # adversarial attack
            if FLAGS.at and FLAGS.at_eps > 0:
                model.eval()
                model.zero_grad()
                x = attacker(model, x.detach(), t.detach())

            # forward
            model.train()
            model.zero_grad()
            out_class, out_reject = model(x)

            # compute max hinge loss with rejection
            # loss dict includes 'A mean' / 'B mean'
            maxhinge_loss, loss_dict = MHBRLoss(out_class, out_reject, t)
            loss_dict['maxhinge_loss'] = maxhinge_loss.detach().cpu().item()

            # regularization_loss = 0.5*WeightPenalty()(model.classifier)
            # loss_dict['regularization_loss'] = regularization_loss.detach().cpu().item()

            # total loss
            loss = maxhinge_loss  # + regularization_loss
            loss_dict['loss'] = loss.detach().cpu().item()

            # backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_metric_dict.update(loss_dict)

        # validation
        for i, (x, t) in enumerate(val_loader):
            x = x.to('cuda', non_blocking=True)
            t = t.to('cuda', non_blocking=True)

            # adversarial attack (needs gradients, so it stays outside no_grad)
            if FLAGS.at and FLAGS.at_eps > 0:
                model.eval()
                model.zero_grad()
                x = attacker(model, x.detach(), t.detach())

            with torch.no_grad():
                # forward
                model.eval()
                model.zero_grad()
                out_class, out_reject = model(x)

                # compute max hinge loss with rejection
                # loss dict includes 'A mean' / 'B mean'
                maxhinge_loss, loss_dict = MHBRLoss(out_class, out_reject, t)
                loss_dict['maxhinge_loss'] = maxhinge_loss.detach().cpu().item()

                # regularization_loss = 0.5*WeightPenalty()(model.classifier)
                # loss_dict['regularization_loss'] = regularization_loss.detach().cpu().item()

                # total loss
                loss = maxhinge_loss  # + regularization_loss
                loss_dict['loss'] = loss.detach().cpu().item()

                # evaluation
                evaluator = Evaluator(out_class.detach().view(-1),
                                      t.detach().view(-1),
                                      out_reject.detach().view(-1))
                loss_dict.update(evaluator())

                val_metric_dict.update(loss_dict)

        # post epoch
        # print_metric_dict(ep, FLAGS.num_epochs, train_metric_dict.avg, mode='train')
        print_metric_dict(ep, FLAGS.num_epochs, val_metric_dict.avg, mode='val')
        train_logger.log(train_metric_dict.avg, step=(ep + 1))
        val_logger.log(val_metric_dict.avg, step=(ep + 1))
        scheduler.step()

    # post training
    save_model(model,
               path=os.path.join(FLAGS.log_dir,
                                 'weight_final{}.pth'.format(FLAGS.suffix)))
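
# A hypothetical invocation of the trainer above. The keyword names mirror
# the FLAGS attributes the function reads; the concrete values are
# placeholders for illustration, not tuned settings from this repository.
if __name__ == '__main__':
    train(dataset='cifar10', dataroot='./data', normalize=True,
          binary_target_class=0, batch_size=128, num_workers=8,
          dropout_prob=0.3, dim_features=512, lr=0.1, momentum=0.9, wd=5e-4,
          cost=0.3, at='pgd', at_eps=8.0 / 255.0, step_size=None, nb_its=7,
          at_norm='linf', num_epochs=100, log_dir='./logs', suffix='',
          use_wandb=False)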
def train(**kwargs):
    """This function executes both standard training and adversarial training."""
    # flags
    FLAGS = FlagHolder()
    FLAGS.initialize(**kwargs)
    FLAGS.summary()
    os.makedirs(FLAGS.log_dir, exist_ok=True)
    FLAGS.dump(
        path=os.path.join(FLAGS.log_dir, 'flags{}.json'.format(FLAGS.suffix)))

    # dataset
    dataset_builder = DatasetBuilder(name=FLAGS.dataset,
                                     root_path=FLAGS.dataroot)
    train_dataset = dataset_builder(train=True, normalize=FLAGS.normalize)
    val_dataset = dataset_builder(train=False, normalize=FLAGS.normalize)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=FLAGS.batch_size,
                                               shuffle=True,
                                               num_workers=FLAGS.num_workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=FLAGS.batch_size,
                                             shuffle=False,
                                             num_workers=FLAGS.num_workers,
                                             pin_memory=True)

    # model
    num_classes = dataset_builder.num_classes
    model = ModelBuilder(num_classes=num_classes,
                         pretrained=False)[FLAGS.arch].cuda()
    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)

    # optimizer
    params = model.parameters()
    optimizer = torch.optim.SGD(params,
                                lr=FLAGS.lr,
                                momentum=FLAGS.momentum,
                                weight_decay=FLAGS.wd)

    # scheduler
    assert len(FLAGS.ms) > 0
    if len(FLAGS.ms) == 1:
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                    step_size=FLAGS.ms[0],
                                                    gamma=FLAGS.gamma)
    else:
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=sorted(list(FLAGS.ms)), gamma=FLAGS.gamma)

    # attacker
    if FLAGS.at and FLAGS.at_eps > 0:
        # get step_size
        step_size = get_step_size(
            FLAGS.at_eps,
            FLAGS.nb_its) if not FLAGS.step_size else FLAGS.step_size
        FLAGS._dict['step_size'] = step_size
        assert step_size >= 0

        # create attacker
        attacker = AttackerBuilder()(method=FLAGS.at,
                                     norm=FLAGS.at_norm,
                                     eps=FLAGS.at_eps,
                                     **FLAGS._dict)

    # logger
    train_logger = Logger(path=os.path.join(
        FLAGS.log_dir, 'train_log{}.csv'.format(FLAGS.suffix)),
                          mode='train',
                          use_wandb=False,
                          flags=FLAGS._dict)
    val_logger = Logger(path=os.path.join(
        FLAGS.log_dir, 'val_log{}.csv'.format(FLAGS.suffix)),
                        mode='val',
                        use_wandb=FLAGS.use_wandb,
                        flags=FLAGS._dict)

    for ep in range(FLAGS.num_epochs):
        # pre epoch
        train_metric_dict = MetricDict()
        val_metric_dict = MetricDict()

        # train
        for i, (x, t) in enumerate(train_loader):
            x = x.to('cuda', non_blocking=True)
            t = t.to('cuda', non_blocking=True)

            # adversarial attack
            if FLAGS.at and FLAGS.at_eps > 0:
                model.eval()
                model.zero_grad()
                x = attacker(model, x.detach(), t.detach())

            # forward
            model.train()
            model.zero_grad()
            out = model(x)

            # compute loss
            loss_dict = OrderedDict()

            # cross entropy
            ce_loss = torch.nn.CrossEntropyLoss()(out, t)
            # loss_dict['ce_loss'] = ce_loss.detach().cpu().item()

            # total loss
            loss = ce_loss
            loss_dict['loss'] = loss.detach().cpu().item()

            # backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_metric_dict.update(loss_dict)

        # validation
        for i, (x, t) in enumerate(val_loader):
            x = x.to('cuda', non_blocking=True)
            t = t.to('cuda', non_blocking=True)

            # adversarial attack (needs gradients, so it stays outside no_grad)
            if FLAGS.at and FLAGS.at_eps > 0:
                model.eval()
                model.zero_grad()
                x = attacker(model, x.detach(), t.detach())

            with torch.no_grad():
                # forward
                model.eval()
                model.zero_grad()
                out = model(x)

                # compute loss
                loss_dict = OrderedDict()

                # cross entropy
                ce_loss = torch.nn.CrossEntropyLoss()(out, t)
                # loss_dict['ce_loss'] = ce_loss.detach().cpu().item()

                # total loss
                loss = ce_loss
                loss_dict['loss'] = loss.detach().cpu().item()

                # evaluation
                evaluator = Evaluator(out.detach(), t.detach(),
                                      selection_out=None)
                loss_dict.update(evaluator())

                val_metric_dict.update(loss_dict)

        # post epoch
        # print_metric_dict(ep, FLAGS.num_epochs, train_metric_dict.avg, mode='train')
        print_metric_dict(ep, FLAGS.num_epochs, val_metric_dict.avg, mode='val')
        train_logger.log(train_metric_dict.avg, step=(ep + 1))
        val_logger.log(val_metric_dict.avg, step=(ep + 1))
        scheduler.step()

    # post training
    save_model(model,
               path=os.path.join(FLAGS.log_dir,
                                 'weight_final{}.pth'.format(FLAGS.suffix)))
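
# For intuition, a minimal l-inf PGD sketch of the kind of attacker that
# AttackerBuilder could return; the repository's attacker may differ (random
# start, input normalization handling, other norms). It only illustrates why
# the code above calls model.eval() / model.zero_grad() before attacking and
# keeps the attack outside torch.no_grad(): crafting x_adv needs input
# gradients.
def pgd_linf_sketch(model, x, t, eps, step_size, nb_its):
    x_adv = x.clone().detach()
    for _ in range(nb_its):
        x_adv.requires_grad_(True)
        loss = torch.nn.CrossEntropyLoss()(model(x_adv), t)
        grad = torch.autograd.grad(loss, x_adv)[0]
        # ascend the loss, then project back onto the eps-ball and valid range
        x_adv = x_adv.detach() + step_size * grad.sign()
        x_adv = torch.min(torch.max(x_adv, x - eps), x + eps).clamp(0.0, 1.0)
    return x_adv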