def test_logger(log_path_root='/home/gatheluck/Scratch/selectivenet/logs'):
    """Smoke-test ``Logger``: write a few steps to a time-stamped log file.

    Args:
        log_path_root: directory under which the test log is created.
            Parameterized (default keeps prior behavior) so the test is not
            hard-wired to one machine's home directory.
    """
    log_basename = 'log_test_' + get_time_stamp('short')
    log_path = os.path.join(log_path_root, log_basename)
    logger = Logger(log_path)

    log_dict = {'loss01': 1.0, 'loss02': 2.0}
    # second dict drops 'loss02' and adds 'loss03' to exercise a changing key set
    log_dict_ = {'loss01': 1.0, 'loss03': 3.0}

    logger.log(log_dict, 1)
    logger.log(log_dict, 2)
    logger.log(log_dict, 3)
    logger.log(log_dict_, 4)
def test_multi_adv(**kwargs):
    """
    this script loads all 'weight_final_{something}.pth' files which exisits under 'kwargs.target_dir' and execute test.
    if there is exactly same file, the result becomes the mean of them. the results are saved as csv file.

    'target_dir' should be like follow (.pth file name should be "weight_final_cost_{}")

    ~/target_dir/XXXX/weight_final_cost_0.10_pgd-linf_eps-0.pth
                  ...
                 /weight_final_cost_0.10_pgd-linf_eps-8.pth
                 /weight_final_cost_0.10_pgd-linf_eps-16.pth
                  ...
                /YYYY/weight_final_cost_0.10_pgd-linf_eps-0.pth
                  ...
                 /weight_final_cost_0.10_pgd-linf_eps-8.pth
                 /weight_final_cost_0.10_pgd-linf_eps-16.pth
                  ...
    """
    # flags
    FLAGS = FlagHolder()
    FLAGS.initialize(**kwargs)
    FLAGS.summary()

    # specify target weight paths: recursive glob, sorted by basename so
    # identical basenames from different run directories are grouped together
    # (removed unused local `run_dir`)
    target_path = os.path.join(FLAGS.target_dir, '**/weight_final*.pth')
    weight_paths = sorted(glob.glob(target_path, recursive=True),
                          key=lambda x: os.path.basename(x))

    # optionally narrow down by cost and by adversarial-training method
    if FLAGS.cost is not None:
        weight_paths = [
            wpath for wpath in weight_paths
            if 'cost-{cost:0.2f}'.format(cost=FLAGS.cost) in wpath
        ]
    if FLAGS.at is not None:
        weight_paths = [
            wpath for wpath in weight_paths
            if '{at}-{at_norm}'.format(at=FLAGS.at, at_norm=FLAGS.at_norm) in wpath
        ]

    log_path = os.path.join(FLAGS.target_dir, 'test{}.csv'.format(FLAGS.suffix))

    # logging
    logger = Logger(path=log_path, mode='test', use_wandb=False, flags=FLAGS)

    # epsilons to sweep for the requested attack/norm combination
    key = FLAGS.attack + '_' + FLAGS.attack_norm
    attack_epses = EPS[key]

    for weight_path in weight_paths:
        for attack_eps in attack_epses:
            # parse basename to recover training-time settings (cost, at, ...)
            basename = os.path.basename(weight_path)
            ret_dict = parse_weight_basename(basename)

            # keyword args for test function
            # variable args
            kw_args = {}
            kw_args['weight'] = weight_path
            kw_args['dataset'] = FLAGS.dataset
            kw_args['dataroot'] = FLAGS.dataroot
            kw_args['binary_target_class'] = FLAGS.binary_target_class
            kw_args['cost'] = ret_dict['cost']
            kw_args['attack'] = FLAGS.attack
            kw_args['nb_its'] = FLAGS.nb_its
            kw_args['step_size'] = None  # let test() derive step size from eps/its
            kw_args['attack_eps'] = attack_eps
            kw_args['attack_norm'] = FLAGS.attack_norm
            # default args
            kw_args['dim_features'] = 512
            kw_args['dropout_prob'] = 0.3
            kw_args['num_workers'] = 8
            kw_args['batch_size'] = 128
            kw_args['normalize'] = True
            kw_args['alpha'] = 0.5

            # run test
            out_dict = test(**kw_args)

            metric_dict = OrderedDict()
            metric_dict['cost'] = ret_dict['cost']
            metric_dict['binary_target_class'] = FLAGS.binary_target_class
            # at: adversarial-training settings recovered from the file name
            metric_dict['at'] = ret_dict['at']
            metric_dict['at_norm'] = ret_dict['at_norm']
            metric_dict['at_eps'] = ret_dict['at_eps']
            # attack: test-time attack settings
            metric_dict['attack'] = FLAGS.attack
            metric_dict['attack_norm'] = FLAGS.attack_norm
            metric_dict['attack_eps'] = attack_eps
            # path
            metric_dict['path'] = weight_path
            metric_dict.update(out_dict)

            # log
            logger.log(metric_dict)
def train(**kwargs):
    """Train a DeepLinearSvmWithRejector model (binary classification with
    rejection) using a max-hinge loss, optionally with PGD adversarial
    training, then save the final weights.

    All hyperparameters arrive via **kwargs and are held in a FlagHolder;
    flags are dumped to 'flags{suffix}.json' in FLAGS.log_dir.
    """
    FLAGS = FlagHolder()
    FLAGS.initialize(**kwargs)
    FLAGS.summary()
    os.makedirs(FLAGS.log_dir, exist_ok=True)
    FLAGS.dump(
        path=os.path.join(FLAGS.log_dir, 'flags{}.json'.format(FLAGS.suffix)))

    # dataset: both splits request a binary-classification target
    dataset_builder = DatasetBuilder(name=FLAGS.dataset,
                                     root_path=FLAGS.dataroot)
    train_dataset = dataset_builder(
        train=True,
        normalize=FLAGS.normalize,
        binary_classification_target=FLAGS.binary_target_class)
    val_dataset = dataset_builder(
        train=False,
        normalize=FLAGS.normalize,
        binary_classification_target=FLAGS.binary_target_class)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=FLAGS.batch_size,
                                               shuffle=True,
                                               num_workers=FLAGS.num_workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=FLAGS.batch_size,
                                             shuffle=False,
                                             num_workers=FLAGS.num_workers,
                                             pin_memory=True)

    # model: VGG16-variant feature extractor + linear-SVM head with rejector;
    # num_classes=1 because this is a binary (single-logit) setup
    features = vgg16_variant(dataset_builder.input_size,
                             FLAGS.dropout_prob).cuda()
    model = DeepLinearSvmWithRejector(features, FLAGS.dim_features,
                                      num_classes=1).cuda()
    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)

    # optimizer: SGD with momentum; LR halves every 25 epochs
    params = model.parameters()
    optimizer = torch.optim.SGD(params,
                                lr=FLAGS.lr,
                                momentum=FLAGS.momentum,
                                weight_decay=FLAGS.wd)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=25,
                                                gamma=0.5)

    # loss
    MHBRLoss = MaxHingeLossBinaryWithRejection(FLAGS.cost)

    # attacker: only built when adversarial training is requested
    if FLAGS.at and FLAGS.at_eps > 0:
        # get step_size (derived from eps and iteration count if not given)
        if not FLAGS.step_size:
            FLAGS.step_size = get_step_size(FLAGS.at_eps, FLAGS.nb_its)
        assert FLAGS.step_size >= 0

        # create attacker
        if FLAGS.at == 'pgd':
            attacker = PGDAttackVariant(
                FLAGS.nb_its,
                FLAGS.at_eps,
                FLAGS.step_size,
                dataset=FLAGS.dataset,
                cost=FLAGS.cost,
                norm=FLAGS.at_norm,
                num_classes=dataset_builder.num_classes,
                is_binary_classification=True)
        else:
            raise NotImplementedError('invalid at method.')

    # logger: train log never uses wandb; val log follows FLAGS.use_wandb
    train_logger = Logger(path=os.path.join(
        FLAGS.log_dir, 'train_log{}.csv'.format(FLAGS.suffix)),
                          mode='train',
                          use_wandb=False,
                          flags=FLAGS._dict)
    val_logger = Logger(path=os.path.join(
        FLAGS.log_dir, 'val_log{}.csv'.format(FLAGS.suffix)),
                        mode='val',
                        use_wandb=FLAGS.use_wandb,
                        flags=FLAGS._dict)

    for ep in range(FLAGS.num_epochs):
        # pre epoch
        train_metric_dict = MetricDict()
        val_metric_dict = MetricDict()

        # train
        for i, (x, t) in enumerate(train_loader):
            x = x.to('cuda', non_blocking=True)
            t = t.to('cuda', non_blocking=True)

            # adversarial attack: crafted in eval mode with cleared grads,
            # on detached copies so attack grads don't leak into training
            if FLAGS.at and FLAGS.at_eps > 0:
                model.eval()
                model.zero_grad()
                x = attacker(model, x.detach(), t.detach())

            # forward (back to train mode for the actual update)
            model.train()
            model.zero_grad()
            out_class, out_reject = model(x)

            # compute selective loss
            # NOTE(review): loss_dict is immediately replaced by the dict
            # returned from MHBRLoss; the OrderedDict() here is effectively dead
            loss_dict = OrderedDict()
            # loss dict includes, 'A mean' / 'B mean'
            maxhinge_loss, loss_dict = MHBRLoss(out_class, out_reject, t)
            loss_dict['maxhinge_loss'] = maxhinge_loss.detach().cpu().item()

            # regularization_loss = 0.5*WeightPenalty()(model.classifier)
            # loss_dict['regularization_loss'] = regularization_loss.detach().cpu().item()

            # total loss
            loss = maxhinge_loss  #+ regularization_loss
            loss_dict['loss'] = loss.detach().cpu().item()

            # backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_metric_dict.update(loss_dict)

        # validation
        for i, (x, t) in enumerate(val_loader):
            x = x.to('cuda', non_blocking=True)
            t = t.to('cuda', non_blocking=True)

            # adversarial attack (outside no_grad: the attack needs gradients)
            if FLAGS.at and FLAGS.at_eps > 0:
                model.eval()
                model.zero_grad()
                x = attacker(model, x.detach(), t.detach())

            with torch.autograd.no_grad():
                # forward
                model.eval()
                model.zero_grad()
                out_class, out_reject = model(x)

                # compute selective loss
                loss_dict = OrderedDict()
                # loss dict includes, 'A mean' / 'B mean'
                maxhinge_loss, loss_dict = MHBRLoss(out_class, out_reject, t)
                loss_dict['maxhinge_loss'] = maxhinge_loss.detach().cpu().item(
                )

                # regularization_loss = 0.5*WeightPenalty()(model.classifier)
                # loss_dict['regularization_loss'] = regularization_loss.detach().cpu().item()

                # total loss
                loss = maxhinge_loss  #+ regularization_loss
                loss_dict['loss'] = loss.detach().cpu().item()

                # evaluation: flatten to 1-D since outputs are single-logit
                evaluator = Evaluator(out_class.detach().view(-1),
                                      t.detach().view(-1),
                                      out_reject.detach().view(-1))
                loss_dict.update(evaluator())

                val_metric_dict.update(loss_dict)

        # post epoch
        # print_metric_dict(ep, FLAGS.num_epochs, train_metric_dict.avg, mode='train')
        print_metric_dict(ep, FLAGS.num_epochs, val_metric_dict.avg, mode='val')
        train_logger.log(train_metric_dict.avg, step=(ep + 1))
        val_logger.log(val_metric_dict.avg, step=(ep + 1))
        scheduler.step()

    # post training
    save_model(model,
               path=os.path.join(FLAGS.log_dir,
                                 'weight_final{}.pth'.format(FLAGS.suffix)))
def train(**kwargs):
    """Train a SelectiveNet model: a selective (rejection) head trained with
    SelectiveLoss plus an auxiliary standard cross-entropy head, blended by
    FLAGS.alpha. Saves the final weights to FLAGS.log_dir.

    All hyperparameters arrive via **kwargs and are held in a FlagHolder;
    flags are dumped to 'flags{suffix}.json' in FLAGS.log_dir.
    """
    FLAGS = FlagHolder()
    FLAGS.initialize(**kwargs)
    FLAGS.summary()
    # FIX: ensure the log directory exists before dumping flags — the sibling
    # train scripts in this file do this, and FLAGS.dump would otherwise fail
    # on a fresh log_dir.
    os.makedirs(FLAGS.log_dir, exist_ok=True)
    FLAGS.dump(
        path=os.path.join(FLAGS.log_dir, 'flags{}.json'.format(FLAGS.suffix)))

    # dataset
    dataset_builder = DatasetBuilder(name=FLAGS.dataset,
                                     root_path=FLAGS.dataroot)
    train_dataset = dataset_builder(train=True, normalize=FLAGS.normalize)
    val_dataset = dataset_builder(train=False, normalize=FLAGS.normalize)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=FLAGS.batch_size,
                                               shuffle=True,
                                               num_workers=FLAGS.num_workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=FLAGS.batch_size,
                                             shuffle=False,
                                             num_workers=FLAGS.num_workers,
                                             pin_memory=True)

    # model: VGG16-variant feature extractor + SelectiveNet heads
    features = vgg16_variant(dataset_builder.input_size,
                             FLAGS.dropout_prob).cuda()
    model = SelectiveNet(features, FLAGS.dim_features,
                         dataset_builder.num_classes).cuda()
    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)

    # optimizer: SGD with momentum; LR halves every 25 epochs
    params = model.parameters()
    optimizer = torch.optim.SGD(params,
                                lr=FLAGS.lr,
                                momentum=FLAGS.momentum,
                                weight_decay=FLAGS.wd)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=25,
                                                gamma=0.5)

    # loss: per-sample CE wrapped by SelectiveLoss (coverage-constrained)
    base_loss = torch.nn.CrossEntropyLoss(reduction='none')
    SelectiveCELoss = SelectiveLoss(base_loss, coverage=FLAGS.coverage)

    # logger
    train_logger = Logger(path=os.path.join(
        FLAGS.log_dir, 'train_log{}.csv'.format(FLAGS.suffix)),
                          mode='train')
    val_logger = Logger(path=os.path.join(
        FLAGS.log_dir, 'val_log{}.csv'.format(FLAGS.suffix)),
                        mode='val')

    for ep in range(FLAGS.num_epochs):
        # pre epoch
        train_metric_dict = MetricDict()
        val_metric_dict = MetricDict()

        # train
        for i, (x, t) in enumerate(train_loader):
            model.train()
            x = x.to('cuda', non_blocking=True)
            t = t.to('cuda', non_blocking=True)

            # forward: class logits, selection score, auxiliary head logits
            out_class, out_select, out_aux = model(x)

            # compute selective loss
            loss_dict = OrderedDict()
            # loss dict includes, 'empirical_risk' / 'emprical_coverage' / 'penulty'
            selective_loss, loss_dict = SelectiveCELoss(out_class, out_select, t)
            selective_loss *= FLAGS.alpha
            loss_dict['selective_loss'] = selective_loss.detach().cpu().item()

            # compute standard cross entropy loss on the auxiliary head,
            # weighted by the complement of alpha
            ce_loss = torch.nn.CrossEntropyLoss()(out_aux, t)
            ce_loss *= (1.0 - FLAGS.alpha)
            loss_dict['ce_loss'] = ce_loss.detach().cpu().item()

            # total loss
            loss = selective_loss + ce_loss
            loss_dict['loss'] = loss.detach().cpu().item()

            # backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_metric_dict.update(loss_dict)

        # validation (whole loop under no_grad: no attack here, pure eval)
        with torch.autograd.no_grad():
            for i, (x, t) in enumerate(val_loader):
                model.eval()
                x = x.to('cuda', non_blocking=True)
                t = t.to('cuda', non_blocking=True)

                # forward
                out_class, out_select, out_aux = model(x)

                # compute selective loss
                loss_dict = OrderedDict()
                # loss dict includes, 'empirical_risk' / 'emprical_coverage' / 'penulty'
                selective_loss, loss_dict = SelectiveCELoss(
                    out_class, out_select, t)
                selective_loss *= FLAGS.alpha
                loss_dict['selective_loss'] = selective_loss.detach().cpu(
                ).item()

                # compute standard cross entropy loss
                ce_loss = torch.nn.CrossEntropyLoss()(out_aux, t)
                ce_loss *= (1.0 - FLAGS.alpha)
                loss_dict['ce_loss'] = ce_loss.detach().cpu().item()

                # total loss
                loss = selective_loss + ce_loss
                loss_dict['loss'] = loss.detach().cpu().item()

                # evaluation
                evaluator = Evaluator(out_class.detach(), t.detach(),
                                      out_select.detach())
                loss_dict.update(evaluator())

                val_metric_dict.update(loss_dict)

        # post epoch
        # print_metric_dict(ep, FLAGS.num_epochs, train_metric_dict.avg, mode='train')
        print_metric_dict(ep, FLAGS.num_epochs, val_metric_dict.avg, mode='val')
        train_logger.log(train_metric_dict.avg, step=(ep + 1))
        val_logger.log(val_metric_dict.avg, step=(ep + 1))
        scheduler.step()

    # post training
    save_model(model,
               path=os.path.join(FLAGS.log_dir,
                                 'weight_final{}.pth'.format(FLAGS.suffix)))
def test_fourier(**kwargs):
    """
    this script loads all 'weight_final_{something}.pth' files which exisits under 'kwargs.target_dir' and execute test.
    if there is exactly same file, the result becomes the mean of them. the results are saved as csv file.

    'target_dir' should be like follow (.pth file name should be "weight_final_coverage_{}")

    ~/target_dir/XXXX/weight_final_pgd-linf_eps-0.pth
                  ...
                 /weight_final_pgd-linf_eps-8.pth
                 /weight_final_pgd-linf_eps-16.pth
                  ...
                /YYYY/weight_final_pgd-linf_eps-0.pth
                  ...
                 /weight_final_pgd-linf_eps-8.pth
                 /weight_final_pgd-linf_eps-16.pth
                  ...
    """
    # flags
    FLAGS = FlagHolder()
    FLAGS.initialize(**kwargs)
    FLAGS.summary()

    # paths: target_dir may be either a directory (glob for weights) or a
    # single .pth file (removed unused local `run_dir`)
    if os.path.splitext(FLAGS.target_dir)[-1] != '.pth':
        target_path = os.path.join(FLAGS.target_dir, '**/weight_final*.pth')
        weight_paths = sorted(glob.glob(target_path, recursive=True),
                              key=lambda x: os.path.basename(x))
        log_path = os.path.join(FLAGS.target_dir,
                                'test{}.csv'.format(FLAGS.suffix))
    else:
        weight_paths = [FLAGS.target_dir]
        log_path = os.path.join(os.path.dirname(FLAGS.target_dir),
                                'test{}.csv'.format(FLAGS.suffix))

    # logging
    logger = Logger(path=log_path, mode='test', use_wandb=False, flags=FLAGS)

    # sweep every (h, w) Fourier-basis index in the configured window
    for weight_path in weight_paths:
        for index_h in range(-FLAGS.fn_max_index_h, FLAGS.fn_max_index_h + 1):
            for index_w in range(-FLAGS.fn_max_index_w,
                                 FLAGS.fn_max_index_w + 1):
                # continue when indices are 0
                if index_h == 0 or index_w == 0:
                    continue

                # parse basename to recover training-time settings
                basename = os.path.basename(weight_path)
                ret_dict = parse_weight_basename(basename)

                # keyword args for test function
                # variable args
                kw_args = {}
                kw_args['arch'] = FLAGS.arch
                kw_args['weight'] = weight_path
                kw_args['dataset'] = FLAGS.dataset
                kw_args['dataroot'] = FLAGS.dataroot
                kw_args['batch_size'] = FLAGS.batch_size
                kw_args['fn_eps'] = FLAGS.fn_eps
                kw_args['fn_index_h'] = index_h
                kw_args['fn_index_w'] = index_w
                # default args
                kw_args['num_workers'] = 8
                kw_args['normalize'] = True

                # run test
                out_dict = test(**kw_args)

                metric_dict = OrderedDict()
                # model
                metric_dict['arch'] = FLAGS.arch
                # Fourier noise
                metric_dict['fn_eps'] = FLAGS.fn_eps
                metric_dict['fn_index_h'] = index_h
                metric_dict['fn_index_w'] = index_w
                # at: adversarial-training settings from the file name
                metric_dict['at'] = ret_dict['at']
                metric_dict['at_norm'] = ret_dict['at_norm']
                metric_dict['at_eps'] = ret_dict['at_eps']
                # path
                metric_dict['path'] = weight_path
                metric_dict.update(out_dict)

                # log
                logger.log(metric_dict)
def test_adv(**kwargs):
    """
    this script loads all 'weight_final_{something}.pth' files which exisits under 'kwargs.target_dir' and execute test.
    if there is exactly same file, the result becomes the mean of them. the results are saved as csv file.

    'target_dir' should be like follow (.pth file name should be "weight_final_coverage_{}")

    ~/target_dir/XXXX/weight_final_pgd-linf_eps-0.pth
                  ...
                 /weight_final_pgd-linf_eps-8.pth
                 /weight_final_pgd-linf_eps-16.pth
                  ...
                /YYYY/weight_final_pgd-linf_eps-0.pth
                  ...
                 /weight_final_pgd-linf_eps-8.pth
                 /weight_final_pgd-linf_eps-16.pth
                  ...
    """
    # flags
    FLAGS = FlagHolder()
    FLAGS.initialize(**kwargs)
    FLAGS.summary()

    # paths: target_dir may be either a directory (glob for weights) or a
    # single .pth file (removed unused local `run_dir`)
    if os.path.splitext(FLAGS.target_dir)[-1] != '.pth':
        target_path = os.path.join(FLAGS.target_dir, '**/weight_final*.pth')
        weight_paths = sorted(glob.glob(target_path, recursive=True),
                              key=lambda x: os.path.basename(x))
        log_path = os.path.join(FLAGS.target_dir,
                                'test{}.csv'.format(FLAGS.suffix))
    else:
        # FIX: was `list(FLAGS.target_dir)`, which split the path string into
        # a list of single characters; wrap the single path in a list instead
        # (matches test_fourier's handling of the same case).
        weight_paths = [FLAGS.target_dir]
        log_path = os.path.join(os.path.dirname(FLAGS.target_dir),
                                'test{}.csv'.format(FLAGS.suffix))

    # logging
    logger = Logger(path=log_path, mode='test', use_wandb=False, flags=FLAGS)

    # divisions to sweep; fall back to a default ladder when not specified
    num_divides = [0, 2, 4, 8, 16] if not FLAGS.num_divide else list(
        FLAGS.num_divide)

    for weight_path in weight_paths:
        for num_divide in num_divides:
            # parse basename to recover training-time settings
            basename = os.path.basename(weight_path)
            ret_dict = parse_weight_basename(basename)

            # keyword args for test function
            # variable args
            kw_args = {}
            kw_args['arch'] = FLAGS.arch
            kw_args['weight'] = weight_path
            kw_args['dataset'] = FLAGS.dataset
            kw_args['dataroot'] = FLAGS.dataroot
            kw_args['batch_size'] = FLAGS.batch_size
            kw_args['attack'] = None  # no test-time attack in this sweep
            kw_args['attack_eps'] = 0
            kw_args['attack_norm'] = None
            kw_args['nb_its'] = 0
            kw_args['step_size'] = None
            kw_args['num_divide'] = num_divide
            # default args
            kw_args['num_workers'] = 8
            kw_args['normalize'] = True

            # run test
            out_dict = test(**kw_args)

            metric_dict = OrderedDict()
            # model
            metric_dict['arch'] = FLAGS.arch
            # at: adversarial-training settings from the file name
            metric_dict['at'] = ret_dict['at']
            metric_dict['at_norm'] = ret_dict['at_norm']
            metric_dict['at_eps'] = ret_dict['at_eps']
            # transform
            metric_dict['num_divide'] = num_divide
            # path
            metric_dict['path'] = weight_path
            metric_dict.update(out_dict)

            # log
            logger.log(metric_dict)
def test_multi(**kwargs):
    """
    this script loads all 'weight_final_{something}.pth' files which exisits under 'kwargs.target_dir' and execute test.
    if there is exactly same file, the result becomes the mean of them. the results are saved as csv file.

    'target_dir' should be like follow

    ~/target_dir/XXXX/weight_final_coverage_0.10.pth
                     /weight_final_coverage_0.95.pth
                     /weight_final_coverage_0.90.pth
                      ...
                /YYYY/weight_final_coverage_0.10.pth
                     /weight_final_coverage_0.95.pth
                     /weight_final_coverage_0.90.pth
                      ...
    """
    # flags
    FLAGS = FlagHolder()
    FLAGS.initialize(**kwargs)
    FLAGS.summary()

    # paths: recursive glob, sorted by basename so identical basenames from
    # different run directories are grouped (removed unused local `run_dir`)
    target_path = os.path.join(FLAGS.target_dir, '**/weight_final*.pth')
    weight_paths = sorted(glob.glob(target_path, recursive=True),
                          key=lambda x: os.path.basename(x))
    log_path = os.path.join(FLAGS.target_dir, 'test.csv')

    # logging
    logger = Logger(path=log_path, mode='test')

    for weight_path in weight_paths:
        # get coverage from the file name
        # name should be like, '~_coverage_{}.pth'
        basename = os.path.basename(weight_path)
        basename, ext = os.path.splitext(basename)
        coverage = float(basename.split('_')[-1])

        # keyword args for test function
        # variable args
        kw_args = {}
        kw_args['weight'] = weight_path
        kw_args['dataset'] = FLAGS.dataset
        kw_args['dataroot'] = FLAGS.dataroot
        kw_args['coverage'] = coverage
        # default args
        kw_args['dim_features'] = 512
        kw_args['dropout_prob'] = 0.3
        kw_args['num_workers'] = 8
        kw_args['batch_size'] = 128
        kw_args['normalize'] = True
        kw_args['alpha'] = 0.5

        # run test
        out_dict = test(**kw_args)

        metric_dict = OrderedDict()
        metric_dict['coverage'] = coverage
        metric_dict['path'] = weight_path
        metric_dict.update(out_dict)

        # log
        logger.log(metric_dict)
def train(**kwargs):
    """
    this function executes standard training and adversarial training.
    """
    # flags
    FLAGS = FlagHolder()
    FLAGS.initialize(**kwargs)
    FLAGS.summary()
    os.makedirs(FLAGS.log_dir, exist_ok=True)
    FLAGS.dump(
        path=os.path.join(FLAGS.log_dir, 'flags{}.json'.format(FLAGS.suffix)))

    # dataset
    dataset_builder = DatasetBuilder(name=FLAGS.dataset,
                                     root_path=FLAGS.dataroot)
    train_dataset = dataset_builder(train=True, normalize=FLAGS.normalize)
    val_dataset = dataset_builder(train=False, normalize=FLAGS.normalize)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=FLAGS.batch_size,
                                               shuffle=True,
                                               num_workers=FLAGS.num_workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=FLAGS.batch_size,
                                             shuffle=False,
                                             num_workers=FLAGS.num_workers,
                                             pin_memory=True)

    # model
    num_classes = dataset_builder.num_classes
    model = ModelBuilder(num_classes=num_classes,
                         pretrained=False)[FLAGS.arch].cuda()
    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)

    # optimizer
    params = model.parameters()
    optimizer = torch.optim.SGD(params,
                                lr=FLAGS.lr,
                                momentum=FLAGS.momentum,
                                weight_decay=FLAGS.wd)

    # scheduler: one milestone -> StepLR, otherwise MultiStepLR (an empty
    # milestone list yields a constant LR).
    # FIX: removed `assert len(FLAGS.ms) == 0`, which contradicted the
    # branches below — any non-empty milestone list crashed at the assert,
    # making both scheduler branches effectively unreachable.
    if len(FLAGS.ms) == 1:
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                    step_size=FLAGS.ms[0],
                                                    gamma=FLAGS.gamma)
    else:
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                         milestones=sorted(
                                                             list(FLAGS.ms)),
                                                         gamma=FLAGS.gamma)

    # attacker: only built when adversarial training is requested
    if FLAGS.at and FLAGS.at_eps > 0:
        # get step_size (derived from eps and iteration count if not given)
        step_size = get_step_size(
            FLAGS.at_eps, FLAGS.nb_its) if not FLAGS.step_size else FLAGS.step_size
        FLAGS._dict['step_size'] = step_size
        assert step_size >= 0

        # create attacker
        attacker = AttackerBuilder()(method=FLAGS.at,
                                     norm=FLAGS.at_norm,
                                     eps=FLAGS.at_eps,
                                     **FLAGS._dict)

    # logger: train log never uses wandb; val log follows FLAGS.use_wandb
    train_logger = Logger(path=os.path.join(
        FLAGS.log_dir, 'train_log{}.csv'.format(FLAGS.suffix)),
                          mode='train',
                          use_wandb=False,
                          flags=FLAGS._dict)
    val_logger = Logger(path=os.path.join(
        FLAGS.log_dir, 'val_log{}.csv'.format(FLAGS.suffix)),
                        mode='val',
                        use_wandb=FLAGS.use_wandb,
                        flags=FLAGS._dict)

    for ep in range(FLAGS.num_epochs):
        # pre epoch
        train_metric_dict = MetricDict()
        val_metric_dict = MetricDict()

        # train
        for x, t in train_loader:
            x = x.to('cuda', non_blocking=True)
            t = t.to('cuda', non_blocking=True)

            # adversarial attack: crafted in eval mode with cleared grads,
            # on detached copies so attack grads don't leak into training
            if FLAGS.at and FLAGS.at_eps > 0:
                model.eval()
                model.zero_grad()
                x = attacker(model, x.detach(), t.detach())

            # forward (back to train mode for the actual update)
            model.train()
            model.zero_grad()
            out = model(x)

            # compute loss
            loss_dict = OrderedDict()
            # cross entropy
            ce_loss = torch.nn.CrossEntropyLoss()(out, t)
            #loss_dict['ce_loss'] = ce_loss.detach().cpu().item()

            # total loss
            loss = ce_loss
            loss_dict['loss'] = loss.detach().cpu().item()

            # backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_metric_dict.update(loss_dict)

        # validation
        for x, t in val_loader:
            x = x.to('cuda', non_blocking=True)
            t = t.to('cuda', non_blocking=True)

            # adversarial attack (outside no_grad: the attack needs gradients)
            if FLAGS.at and FLAGS.at_eps > 0:
                model.eval()
                model.zero_grad()
                x = attacker(model, x.detach(), t.detach())

            with torch.autograd.no_grad():
                # forward
                model.eval()
                model.zero_grad()
                out = model(x)

                # compute loss
                loss_dict = OrderedDict()
                # cross entropy
                ce_loss = torch.nn.CrossEntropyLoss()(out, t)
                #loss_dict['ce_loss'] = ce_loss.detach().cpu().item()

                # total loss
                loss = ce_loss
                loss_dict['loss'] = loss.detach().cpu().item()

                # evaluation (no rejection head here, so selection_out=None)
                evaluator = Evaluator(out.detach(),
                                      t.detach(),
                                      selection_out=None)
                loss_dict.update(evaluator())

                val_metric_dict.update(loss_dict)

        # post epoch
        # print_metric_dict(ep, FLAGS.num_epochs, train_metric_dict.avg, mode='train')
        print_metric_dict(ep, FLAGS.num_epochs, val_metric_dict.avg, mode='val')
        train_logger.log(train_metric_dict.avg, step=(ep + 1))
        val_logger.log(val_metric_dict.avg, step=(ep + 1))
        scheduler.step()

    # post training
    save_model(model,
               path=os.path.join(FLAGS.log_dir,
                                 'weight_final{}.pth'.format(FLAGS.suffix)))