Esempio n. 1
0
def test_logger():
    """
    smoke test for Logger.

    creates a Logger at a time-stamped path under a fixed log root and logs
    four records (steps 1-4); the last record uses a different key set
    ('loss03' instead of 'loss02').
    """
    root = '/home/gatheluck/Scratch/selectivenet/logs'
    stamp = get_time_stamp('short')
    logger = Logger(os.path.join(root, 'log_test_' + stamp))

    base_record = {'loss01': 1.0, 'loss02': 2.0}
    changed_record = {'loss01': 1.0, 'loss03': 3.0}

    # the same record is logged three times, then the one with different keys
    records = (base_record, base_record, base_record, changed_record)
    for step, record in enumerate(records, start=1):
        logger.log(record, step)
Esempio n. 2
0
def test_multi_adv(**kwargs):
    """
    this script finds all 'weight_final*.pth' files which exist (recursively) under
    'kwargs.target_dir' and executes the adversarial test on each of them.

    every weight file is tested once per attack eps taken from EPS['{attack}_{attack_norm}'],
    and one row per (weight, eps) pair is logged to 'target_dir/test{suffix}.csv'.
    no averaging over identical file names is done in this function.

    optional filters (matched as a substring of the whole path):
        - FLAGS.cost: keeps paths containing 'cost-{cost:0.2f}'
        - FLAGS.at  : keeps paths containing '{at}-{at_norm}'

    'target_dir' should be like follow
    NOTE(review): the layout below spells the cost as 'cost_0.10' while the cost
    filter looks for 'cost-0.10' -- confirm the real file naming.

    ~/target_dir/XXXX/weight_final_cost_0.10_pgd-linf_eps-0.pth
                     ...
                     /weight_final_cost_0.10_pgd-linf_eps-8.pth
                     /weight_final_cost_0.10_pgd-linf_eps-16.pth
                     ...
                /YYYY/weight_final_cost_0.10_pgd-linf_eps-0.pth
                     ...
                     /weight_final_cost_0.10_pgd-linf_eps-8.pth
                     /weight_final_cost_0.10_pgd-linf_eps-16.pth
                     ...
    """
    # flags
    FLAGS = FlagHolder()
    FLAGS.initialize(**kwargs)
    FLAGS.summary()

    # specify target weight path (sorted by file name, not by full path)
    target_path = os.path.join(FLAGS.target_dir, '**/weight_final*.pth')
    weight_paths = sorted(glob.glob(target_path, recursive=True),
                          key=os.path.basename)

    if FLAGS.cost is not None:
        cost_token = 'cost-{cost:0.2f}'.format(cost=FLAGS.cost)
        weight_paths = [wpath for wpath in weight_paths if cost_token in wpath]
    if FLAGS.at is not None:
        at_token = '{at}-{at_norm}'.format(at=FLAGS.at,
                                           at_norm=FLAGS.at_norm)
        weight_paths = [wpath for wpath in weight_paths if at_token in wpath]

    log_path = os.path.join(FLAGS.target_dir,
                            'test{}.csv'.format(FLAGS.suffix))

    # logging
    logger = Logger(path=log_path, mode='test', use_wandb=False, flags=FLAGS)

    # get epses
    key = FLAGS.attack + '_' + FLAGS.attack_norm
    attack_epses = EPS[key]

    for weight_path in weight_paths:
        # parse basename once per weight file; it does not depend on the eps
        ret_dict = parse_weight_basename(os.path.basename(weight_path))

        for attack_eps in attack_epses:
            # keyword args for test function
            kw_args = {
                # variable args
                'weight': weight_path,
                'dataset': FLAGS.dataset,
                'dataroot': FLAGS.dataroot,
                'binary_target_class': FLAGS.binary_target_class,
                'cost': ret_dict['cost'],
                'attack': FLAGS.attack,
                'nb_its': FLAGS.nb_its,
                'step_size': None,
                'attack_eps': attack_eps,
                'attack_norm': FLAGS.attack_norm,
                # default args
                'dim_features': 512,
                'dropout_prob': 0.3,
                'num_workers': 8,
                'batch_size': 128,
                'normalize': True,
                'alpha': 0.5,
            }

            # run test
            out_dict = test(**kw_args)

            metric_dict = OrderedDict()
            metric_dict['cost'] = ret_dict['cost']
            metric_dict['binary_target_class'] = FLAGS.binary_target_class
            # at (adversarial training setting parsed from the file name)
            metric_dict['at'] = ret_dict['at']
            metric_dict['at_norm'] = ret_dict['at_norm']
            metric_dict['at_eps'] = ret_dict['at_eps']
            # attack (setting used for this test run)
            metric_dict['attack'] = FLAGS.attack
            metric_dict['attack_norm'] = FLAGS.attack_norm
            metric_dict['attack_eps'] = attack_eps
            # path
            metric_dict['path'] = weight_path
            metric_dict.update(out_dict)

            # log
            logger.log(metric_dict)
Esempio n. 3
0
def train(**kwargs):
    """
    this function trains a binary classifier with rejector (max hinge loss with rejection).
    adversarial training is used when 'FLAGS.at' is set and 'FLAGS.at_eps' > 0.

    flags are dumped to 'flags{suffix}.json', per-epoch averaged metrics are logged to
    'train_log{suffix}.csv' / 'val_log{suffix}.csv' and the final model is saved as
    'weight_final{suffix}.pth', all under 'FLAGS.log_dir'.
    """
    FLAGS = FlagHolder()
    FLAGS.initialize(**kwargs)
    FLAGS.summary()
    os.makedirs(FLAGS.log_dir, exist_ok=True)
    flags_path = os.path.join(FLAGS.log_dir,
                              'flags{}.json'.format(FLAGS.suffix))
    FLAGS.dump(path=flags_path)

    # dataset
    dataset_builder = DatasetBuilder(name=FLAGS.dataset,
                                     root_path=FLAGS.dataroot)
    dataset_kwargs = dict(
        normalize=FLAGS.normalize,
        binary_classification_target=FLAGS.binary_target_class)
    train_dataset = dataset_builder(train=True, **dataset_kwargs)
    val_dataset = dataset_builder(train=False, **dataset_kwargs)

    loader_kwargs = dict(batch_size=FLAGS.batch_size,
                         num_workers=FLAGS.num_workers,
                         pin_memory=True)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               shuffle=True,
                                               **loader_kwargs)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             shuffle=False,
                                             **loader_kwargs)

    # model
    features = vgg16_variant(dataset_builder.input_size,
                             FLAGS.dropout_prob).cuda()
    model = DeepLinearSvmWithRejector(features,
                                      FLAGS.dim_features,
                                      num_classes=1).cuda()
    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)

    # optimizer (lr is halved every 25 epochs)
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=FLAGS.lr,
                                momentum=FLAGS.momentum,
                                weight_decay=FLAGS.wd)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=25,
                                                gamma=0.5)

    # loss
    MHBRLoss = MaxHingeLossBinaryWithRejection(FLAGS.cost)

    # attacker (only created when adversarial training is requested)
    use_at = FLAGS.at and FLAGS.at_eps > 0
    if use_at:
        # get step_size
        if not FLAGS.step_size:
            FLAGS.step_size = get_step_size(FLAGS.at_eps, FLAGS.nb_its)
        assert FLAGS.step_size >= 0

        # create attacker
        if FLAGS.at != 'pgd':
            raise NotImplementedError('invalid at method.')
        attacker = PGDAttackVariant(
            FLAGS.nb_its,
            FLAGS.at_eps,
            FLAGS.step_size,
            dataset=FLAGS.dataset,
            cost=FLAGS.cost,
            norm=FLAGS.at_norm,
            num_classes=dataset_builder.num_classes,
            is_binary_classification=True)

    # logger
    train_log_path = os.path.join(FLAGS.log_dir,
                                  'train_log{}.csv'.format(FLAGS.suffix))
    train_logger = Logger(path=train_log_path,
                          mode='train',
                          use_wandb=False,
                          flags=FLAGS._dict)
    val_log_path = os.path.join(FLAGS.log_dir,
                                'val_log{}.csv'.format(FLAGS.suffix))
    val_logger = Logger(path=val_log_path,
                        mode='val',
                        use_wandb=FLAGS.use_wandb,
                        flags=FLAGS._dict)

    for ep in range(FLAGS.num_epochs):
        # pre epoch
        train_metric_dict = MetricDict()
        val_metric_dict = MetricDict()

        # train
        for inputs, targets in train_loader:
            inputs = inputs.to('cuda', non_blocking=True)
            targets = targets.to('cuda', non_blocking=True)

            # adversarial attack (crafted with the model in eval mode)
            if use_at:
                model.eval()
                model.zero_grad()
                inputs = attacker(model, inputs.detach(), targets.detach())

            # forward
            model.train()
            model.zero_grad()
            out_class, out_reject = model(inputs)

            # the loss returns its own metric dict, extended here with scalar losses
            maxhinge_loss, loss_dict = MHBRLoss(out_class, out_reject,
                                                targets)
            loss_dict['maxhinge_loss'] = maxhinge_loss.detach().cpu().item()

            # total loss (currently only the max hinge loss)
            loss = maxhinge_loss
            loss_dict['loss'] = loss.detach().cpu().item()

            # backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_metric_dict.update(loss_dict)

        # validation
        for inputs, targets in val_loader:
            inputs = inputs.to('cuda', non_blocking=True)
            targets = targets.to('cuda', non_blocking=True)

            # adversarial attack is run outside of no_grad
            if use_at:
                model.eval()
                model.zero_grad()
                inputs = attacker(model, inputs.detach(), targets.detach())

            with torch.autograd.no_grad():
                # forward
                model.eval()
                model.zero_grad()
                out_class, out_reject = model(inputs)

                maxhinge_loss, loss_dict = MHBRLoss(out_class, out_reject,
                                                    targets)
                loss_dict['maxhinge_loss'] = (
                    maxhinge_loss.detach().cpu().item())

                # total loss (currently only the max hinge loss)
                loss = maxhinge_loss
                loss_dict['loss'] = loss.detach().cpu().item()

                # evaluation on flattened outputs / targets
                evaluator = Evaluator(out_class.detach().view(-1),
                                      targets.detach().view(-1),
                                      out_reject.detach().view(-1))
                loss_dict.update(evaluator())

                val_metric_dict.update(loss_dict)

        # post epoch
        print_metric_dict(ep,
                          FLAGS.num_epochs,
                          val_metric_dict.avg,
                          mode='val')

        step = ep + 1
        train_logger.log(train_metric_dict.avg, step=step)
        val_logger.log(val_metric_dict.avg, step=step)

        scheduler.step()

    # post training
    weight_path = os.path.join(FLAGS.log_dir,
                               'weight_final{}.pth'.format(FLAGS.suffix))
    save_model(model, path=weight_path)
Esempio n. 4
0
def train(**kwargs):
    """
    this function trains SelectiveNet.

    the total loss is 'alpha * selective_loss + (1 - alpha) * ce_loss', where the
    selective loss is computed from the class / selection outputs and the cross
    entropy loss from the auxiliary output.

    flags are dumped to 'flags{suffix}.json', per-epoch averaged metrics are logged to
    'train_log{suffix}.csv' / 'val_log{suffix}.csv' and the final model is saved as
    'weight_final{suffix}.pth', all under 'FLAGS.log_dir'.
    """
    FLAGS = FlagHolder()
    FLAGS.initialize(**kwargs)
    FLAGS.summary()
    # make sure the output directory exists before anything is written into it
    os.makedirs(FLAGS.log_dir, exist_ok=True)
    FLAGS.dump(
        path=os.path.join(FLAGS.log_dir, 'flags{}.json'.format(FLAGS.suffix)))

    # dataset
    dataset_builder = DatasetBuilder(name=FLAGS.dataset,
                                     root_path=FLAGS.dataroot)
    train_dataset = dataset_builder(train=True, normalize=FLAGS.normalize)
    val_dataset = dataset_builder(train=False, normalize=FLAGS.normalize)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=FLAGS.batch_size,
                                               shuffle=True,
                                               num_workers=FLAGS.num_workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=FLAGS.batch_size,
                                             shuffle=False,
                                             num_workers=FLAGS.num_workers,
                                             pin_memory=True)

    # model
    features = vgg16_variant(dataset_builder.input_size,
                             FLAGS.dropout_prob).cuda()
    model = SelectiveNet(features, FLAGS.dim_features,
                         dataset_builder.num_classes).cuda()
    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)

    # optimizer (lr is halved every 25 epochs)
    params = model.parameters()
    optimizer = torch.optim.SGD(params,
                                lr=FLAGS.lr,
                                momentum=FLAGS.momentum,
                                weight_decay=FLAGS.wd)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=25,
                                                gamma=0.5)

    # loss
    base_loss = torch.nn.CrossEntropyLoss(reduction='none')
    SelectiveCELoss = SelectiveLoss(base_loss, coverage=FLAGS.coverage)
    # standard (mean-reduced) cross entropy for the auxiliary head
    aux_ce_loss = torch.nn.CrossEntropyLoss()

    # logger
    train_logger = Logger(path=os.path.join(
        FLAGS.log_dir, 'train_log{}.csv'.format(FLAGS.suffix)),
                          mode='train')
    val_logger = Logger(path=os.path.join(
        FLAGS.log_dir, 'val_log{}.csv'.format(FLAGS.suffix)),
                        mode='val')

    for ep in range(FLAGS.num_epochs):
        # pre epoch
        train_metric_dict = MetricDict()
        val_metric_dict = MetricDict()

        # train
        model.train()
        for x, t in train_loader:
            x = x.to('cuda', non_blocking=True)
            t = t.to('cuda', non_blocking=True)

            # forward
            out_class, out_select, out_aux = model(x)

            # compute selective loss
            # the loss returns its own metric dict, extended below with scalar losses
            selective_loss, loss_dict = SelectiveCELoss(
                out_class, out_select, t)
            selective_loss *= FLAGS.alpha
            loss_dict['selective_loss'] = selective_loss.detach().cpu().item()
            # compute standard cross entropy loss
            ce_loss = aux_ce_loss(out_aux, t)
            ce_loss *= (1.0 - FLAGS.alpha)
            loss_dict['ce_loss'] = ce_loss.detach().cpu().item()

            # total loss
            loss = selective_loss + ce_loss
            loss_dict['loss'] = loss.detach().cpu().item()

            # backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_metric_dict.update(loss_dict)

        # validation
        model.eval()
        with torch.autograd.no_grad():
            for x, t in val_loader:
                x = x.to('cuda', non_blocking=True)
                t = t.to('cuda', non_blocking=True)

                # forward
                out_class, out_select, out_aux = model(x)

                # compute selective loss
                selective_loss, loss_dict = SelectiveCELoss(
                    out_class, out_select, t)
                selective_loss *= FLAGS.alpha
                loss_dict['selective_loss'] = (
                    selective_loss.detach().cpu().item())
                # compute standard cross entropy loss
                ce_loss = aux_ce_loss(out_aux, t)
                ce_loss *= (1.0 - FLAGS.alpha)
                loss_dict['ce_loss'] = ce_loss.detach().cpu().item()

                # total loss
                loss = selective_loss + ce_loss
                loss_dict['loss'] = loss.detach().cpu().item()

                # evaluation
                evaluator = Evaluator(out_class.detach(), t.detach(),
                                      out_select.detach())
                loss_dict.update(evaluator())

                val_metric_dict.update(loss_dict)

        # post epoch
        print_metric_dict(ep,
                          FLAGS.num_epochs,
                          val_metric_dict.avg,
                          mode='val')

        train_logger.log(train_metric_dict.avg, step=(ep + 1))
        val_logger.log(val_metric_dict.avg, step=(ep + 1))

        scheduler.step()

    # post training
    save_model(model,
               path=os.path.join(FLAGS.log_dir,
                                 'weight_final{}.pth'.format(FLAGS.suffix)))
Esempio n. 5
0
def test_fourier(**kwargs):
    """
    this script executes the Fourier noise test on trained weights.

    'kwargs.target_dir' is either
        - a directory: all 'weight_final*.pth' files which exist (recursively) under it are tested, or
        - a single '.pth' file: only that file is tested.

    every weight is tested once per (index_h, index_w) pair with
    index_h in [-fn_max_index_h, fn_max_index_h] and index_w in [-fn_max_index_w, fn_max_index_w],
    skipping pairs where either index is 0. one row per run is logged to 'test{suffix}.csv'
    (next to the weights). no averaging over identical file names is done in this function.

    'target_dir' should be like follow

    ~/target_dir/XXXX/weight_final_pgd-linf_eps-0.pth
                     ...
                     /weight_final_pgd-linf_eps-8.pth
                     /weight_final_pgd-linf_eps-16.pth
                     ...
                /YYYY/weight_final_pgd-linf_eps-0.pth
                     ...
                     /weight_final_pgd-linf_eps-8.pth
                     /weight_final_pgd-linf_eps-16.pth
                     ...
    """
    # flags
    FLAGS = FlagHolder()
    FLAGS.initialize(**kwargs)
    FLAGS.summary()

    # paths
    log_name = 'test{}.csv'.format(FLAGS.suffix)
    if os.path.splitext(FLAGS.target_dir)[-1] != '.pth':
        # directory: collect all weights, sorted by file name
        target_path = os.path.join(FLAGS.target_dir, '**/weight_final*.pth')
        weight_paths = sorted(glob.glob(target_path, recursive=True),
                              key=os.path.basename)
        log_path = os.path.join(FLAGS.target_dir, log_name)
    else:
        # single weight file
        weight_paths = [FLAGS.target_dir]
        log_path = os.path.join(os.path.dirname(FLAGS.target_dir), log_name)

    # logging
    logger = Logger(path=log_path, mode='test', use_wandb=False, flags=FLAGS)

    index_hs = range(-FLAGS.fn_max_index_h, FLAGS.fn_max_index_h + 1)
    index_ws = range(-FLAGS.fn_max_index_w, FLAGS.fn_max_index_w + 1)

    for weight_path in weight_paths:
        # parse basename once per weight file; it does not depend on the indices
        ret_dict = parse_weight_basename(os.path.basename(weight_path))

        for index_h in index_hs:
            for index_w in index_ws:
                # continue when indices are 0
                # NOTE(review): this skips every pair where EITHER index is 0
                # (the whole zero row and column), not only (0, 0) -- confirm
                # this is intended.
                if index_h == 0 or index_w == 0:
                    continue

                # keyword args for test function
                kw_args = {
                    # variable args
                    'arch': FLAGS.arch,
                    'weight': weight_path,
                    'dataset': FLAGS.dataset,
                    'dataroot': FLAGS.dataroot,
                    'batch_size': FLAGS.batch_size,
                    'fn_eps': FLAGS.fn_eps,
                    'fn_index_h': index_h,
                    'fn_index_w': index_w,
                    # default args
                    'num_workers': 8,
                    'normalize': True,
                }

                # run test
                out_dict = test(**kw_args)

                metric_dict = OrderedDict()
                # model
                metric_dict['arch'] = FLAGS.arch
                # Fourier noise
                metric_dict['fn_eps'] = FLAGS.fn_eps
                metric_dict['fn_index_h'] = index_h
                metric_dict['fn_index_w'] = index_w
                # at (adversarial training setting parsed from the file name)
                metric_dict['at'] = ret_dict['at']
                metric_dict['at_norm'] = ret_dict['at_norm']
                metric_dict['at_eps'] = ret_dict['at_eps']
                # path
                metric_dict['path'] = weight_path
                metric_dict.update(out_dict)

                # log
                logger.log(metric_dict)
Esempio n. 6
0
def test_adv(**kwargs):
    """
    this script executes the test with different 'num_divide' values on trained weights.
    the attack itself is disabled (attack=None, attack_eps=0, nb_its=0).

    'kwargs.target_dir' is either
        - a directory: all 'weight_final*.pth' files which exist (recursively) under it are tested, or
        - a single '.pth' file: only that file is tested.

    every weight is tested once per value in 'FLAGS.num_divide' ([0, 2, 4, 8, 16] when it is
    not given). one row per run is logged to 'test{suffix}.csv' (next to the weights).
    no averaging over identical file names is done in this function.

    'target_dir' should be like follow

    ~/target_dir/XXXX/weight_final_pgd-linf_eps-0.pth
                     ...
                     /weight_final_pgd-linf_eps-8.pth
                     /weight_final_pgd-linf_eps-16.pth
                     ...
                /YYYY/weight_final_pgd-linf_eps-0.pth
                     ...
                     /weight_final_pgd-linf_eps-8.pth
                     /weight_final_pgd-linf_eps-16.pth
                     ...
    """
    # flags
    FLAGS = FlagHolder()
    FLAGS.initialize(**kwargs)
    FLAGS.summary()

    # paths
    log_name = 'test{}.csv'.format(FLAGS.suffix)
    if os.path.splitext(FLAGS.target_dir)[-1] != '.pth':
        # directory: collect all weights, sorted by file name
        target_path = os.path.join(FLAGS.target_dir, '**/weight_final*.pth')
        weight_paths = sorted(glob.glob(target_path, recursive=True),
                              key=os.path.basename)
        log_path = os.path.join(FLAGS.target_dir, log_name)
    else:
        # single weight file: wrap the path in a list. list(path) would split
        # the string into single characters and "test" each of them.
        weight_paths = [FLAGS.target_dir]
        log_path = os.path.join(os.path.dirname(FLAGS.target_dir), log_name)

    # logging
    logger = Logger(path=log_path, mode='test', use_wandb=False, flags=FLAGS)

    # presumably FLAGS.num_divide is a (possibly empty) sequence of ints -- TODO confirm
    num_divides = [0, 2, 4, 8, 16] if not FLAGS.num_divide else list(
        FLAGS.num_divide)

    for weight_path in weight_paths:
        # parse basename once per weight file; it does not depend on num_divide
        ret_dict = parse_weight_basename(os.path.basename(weight_path))

        for num_divide in num_divides:
            # keyword args for test function
            kw_args = {
                # variable args
                'arch': FLAGS.arch,
                'weight': weight_path,
                'dataset': FLAGS.dataset,
                'dataroot': FLAGS.dataroot,
                'batch_size': FLAGS.batch_size,
                'attack': None,
                'attack_eps': 0,
                'attack_norm': None,
                'nb_its': 0,
                'step_size': None,
                'num_divide': num_divide,
                # default args
                'num_workers': 8,
                'normalize': True,
            }

            # run test
            out_dict = test(**kw_args)

            metric_dict = OrderedDict()
            # model
            metric_dict['arch'] = FLAGS.arch
            # at (adversarial training setting parsed from the file name)
            metric_dict['at'] = ret_dict['at']
            metric_dict['at_norm'] = ret_dict['at_norm']
            metric_dict['at_eps'] = ret_dict['at_eps']
            # transform
            metric_dict['num_divide'] = num_divide
            # path
            metric_dict['path'] = weight_path
            metric_dict.update(out_dict)

            # log
            logger.log(metric_dict)
Esempio n. 7
0
def test_multi(**kwargs):
    """
    this script finds all 'weight_final*.pth' files which exist (recursively) under
    'kwargs.target_dir' and executes test on each of them.

    the coverage of each model is read from its file name (the text after the last '_'),
    and one row per weight file is logged to 'target_dir/test.csv'.
    no averaging over identical file names is done in this function.

    raises ValueError when a file name does not end with a numeric coverage.

    'target_dir' should be like follow

    ~/target_dir/XXXX/weight_final_coverage_0.10.pth
                     /weight_final_coverage_0.95.pth
                     /weight_final_coverage_0.90.pth
                     ...
                /YYYY/weight_final_coverage_0.10.pth
                     /weight_final_coverage_0.95.pth
                     /weight_final_coverage_0.90.pth
                     ...
    """
    # flags
    FLAGS = FlagHolder()
    FLAGS.initialize(**kwargs)
    FLAGS.summary()

    # paths (sorted by file name, not by full path)
    target_path = os.path.join(FLAGS.target_dir, '**/weight_final*.pth')
    weight_paths = sorted(glob.glob(target_path, recursive=True),
                          key=os.path.basename)
    log_path = os.path.join(FLAGS.target_dir, 'test.csv')

    # logging
    logger = Logger(path=log_path, mode='test')

    for weight_path in weight_paths:
        # get coverage
        # name should be like, '~_coverage_{}.pth'
        stem, _ = os.path.splitext(os.path.basename(weight_path))
        try:
            coverage = float(stem.split('_')[-1])
        except ValueError as err:
            # the glob also matches names without a coverage suffix
            # (e.g. 'weight_final.pth'); say which file is the problem.
            raise ValueError(
                "cannot parse coverage from weight file name '{}'; "
                "expected '..._coverage_<float>.pth'".format(
                    weight_path)) from err

        # keyword args for test function
        kw_args = {
            # variable args
            'weight': weight_path,
            'dataset': FLAGS.dataset,
            'dataroot': FLAGS.dataroot,
            'coverage': coverage,
            # default args
            'dim_features': 512,
            'dropout_prob': 0.3,
            'num_workers': 8,
            'batch_size': 128,
            'normalize': True,
            'alpha': 0.5,
        }

        # run test
        out_dict = test(**kw_args)

        metric_dict = OrderedDict()
        metric_dict['coverage'] = coverage
        metric_dict['path'] = weight_path
        metric_dict.update(out_dict)

        # log
        logger.log(metric_dict)
Esempio n. 8
0
def train(**kwargs):
    """
    this function executes standard training and adversarial training.

    adversarial training is used when 'FLAGS.at' is set and 'FLAGS.at_eps' > 0; in that
    case both training and validation batches are replaced by attacked batches.

    learning rate schedule ('FLAGS.ms', decay factor 'FLAGS.gamma'):
        - exactly one value : StepLR with that value as step size
        - otherwise         : MultiStepLR with the sorted values as milestones
                              (an empty 'ms' means the lr is never decayed)

    flags are dumped to 'flags{suffix}.json', per-epoch averaged metrics are logged to
    'train_log{suffix}.csv' / 'val_log{suffix}.csv' and the final model is saved as
    'weight_final{suffix}.pth', all under 'FLAGS.log_dir'.
    """
    # flags
    FLAGS = FlagHolder()
    FLAGS.initialize(**kwargs)
    FLAGS.summary()
    os.makedirs(FLAGS.log_dir, exist_ok=True)
    FLAGS.dump(
        path=os.path.join(FLAGS.log_dir, 'flags{}.json'.format(FLAGS.suffix)))

    # dataset
    dataset_builder = DatasetBuilder(name=FLAGS.dataset,
                                     root_path=FLAGS.dataroot)
    train_dataset = dataset_builder(train=True, normalize=FLAGS.normalize)
    val_dataset = dataset_builder(train=False, normalize=FLAGS.normalize)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=FLAGS.batch_size,
                                               shuffle=True,
                                               num_workers=FLAGS.num_workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=FLAGS.batch_size,
                                             shuffle=False,
                                             num_workers=FLAGS.num_workers,
                                             pin_memory=True)

    # model
    num_classes = dataset_builder.num_classes
    model = ModelBuilder(num_classes=num_classes,
                         pretrained=False)[FLAGS.arch].cuda()
    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)

    # optimizer
    params = model.parameters()
    optimizer = torch.optim.SGD(params,
                                lr=FLAGS.lr,
                                momentum=FLAGS.momentum,
                                weight_decay=FLAGS.wd)

    # scheduler
    # do not assert on the number of milestones here: requiring it to be 0
    # rejects every real schedule and makes the StepLR branch unreachable.
    milestones = sorted(FLAGS.ms)
    if len(milestones) == 1:
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                    step_size=milestones[0],
                                                    gamma=FLAGS.gamma)
    else:
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=milestones, gamma=FLAGS.gamma)

    # attacker (only created when adversarial training is requested)
    use_at = FLAGS.at and FLAGS.at_eps > 0
    if use_at:
        # get step_size
        step_size = get_step_size(
            FLAGS.at_eps,
            FLAGS.nb_its) if not FLAGS.step_size else FLAGS.step_size
        # stored in the flag dict because the dict is forwarded to the builder below
        FLAGS._dict['step_size'] = step_size
        assert step_size >= 0

        # create attacker
        attacker = AttackerBuilder()(method=FLAGS.at,
                                     norm=FLAGS.at_norm,
                                     eps=FLAGS.at_eps,
                                     **FLAGS._dict)

    # loss
    criterion = torch.nn.CrossEntropyLoss()

    # logger
    train_logger = Logger(path=os.path.join(
        FLAGS.log_dir, 'train_log{}.csv'.format(FLAGS.suffix)),
                          mode='train',
                          use_wandb=False,
                          flags=FLAGS._dict)
    val_logger = Logger(path=os.path.join(
        FLAGS.log_dir, 'val_log{}.csv'.format(FLAGS.suffix)),
                        mode='val',
                        use_wandb=FLAGS.use_wandb,
                        flags=FLAGS._dict)

    for ep in range(FLAGS.num_epochs):
        # pre epoch
        train_metric_dict = MetricDict()
        val_metric_dict = MetricDict()

        # train
        for x, t in train_loader:
            x = x.to('cuda', non_blocking=True)
            t = t.to('cuda', non_blocking=True)

            # adversarial attack (crafted with the model in eval mode)
            if use_at:
                model.eval()
                model.zero_grad()
                x = attacker(model, x.detach(), t.detach())

            # forward
            model.train()
            model.zero_grad()
            out = model(x)

            # cross entropy is the only loss term
            loss_dict = OrderedDict()
            loss = criterion(out, t)
            loss_dict['loss'] = loss.detach().cpu().item()

            # backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_metric_dict.update(loss_dict)

        # validation
        for x, t in val_loader:
            x = x.to('cuda', non_blocking=True)
            t = t.to('cuda', non_blocking=True)

            # adversarial attack is run outside of no_grad
            if use_at:
                model.eval()
                model.zero_grad()
                x = attacker(model, x.detach(), t.detach())

            with torch.autograd.no_grad():
                # forward
                model.eval()
                model.zero_grad()
                out = model(x)

                # cross entropy is the only loss term
                loss_dict = OrderedDict()
                loss = criterion(out, t)
                loss_dict['loss'] = loss.detach().cpu().item()

                # evaluation
                evaluator = Evaluator(out.detach(),
                                      t.detach(),
                                      selection_out=None)
                loss_dict.update(evaluator())

                val_metric_dict.update(loss_dict)

        # post epoch
        print_metric_dict(ep,
                          FLAGS.num_epochs,
                          val_metric_dict.avg,
                          mode='val')

        train_logger.log(train_metric_dict.avg, step=(ep + 1))
        val_logger.log(val_metric_dict.avg, step=(ep + 1))

        scheduler.step()

    # post training
    save_model(model,
               path=os.path.join(FLAGS.log_dir,
                                 'weight_final{}.pth'.format(FLAGS.suffix)))