Example #1
0
def train_multi(**kwargs):
    """Generate one job script per adversarial-training setting.

    Sweeps rejection cost x attack method x attack norm x epsilon and, for
    each combination, hands a ``python train.py ...`` command line to
    ``generate_script``, which writes it to
    ``<script_root>/<ex_id>/<suffix without leading '_'>.sh``.

    Args:
        **kwargs: options forwarded to ``FlagHolder.initialize``. When
            ``cost`` is falsy the default cost sweep below is used.
    """
    FLAGS = FlagHolder()
    FLAGS.initialize(**kwargs)
    FLAGS.summary()

    # create script output dir
    script_dir = os.path.join(FLAGS.script_root, FLAGS.ex_id)
    os.makedirs(script_dir, exist_ok=True)

    # every generated job logs into the same directory, so create it once
    log_dir = os.path.join(FLAGS.log_dir, FLAGS.ex_id)
    os.makedirs(log_dir, exist_ok=True)

    costs = [1.00, 0.95, 0.90, 0.85, 0.80, 0.75, 0.70] if not FLAGS.cost else [FLAGS.cost]
    ats = ['pgd']
    at_norms = ['linf', 'l2']

    # epsilon grid per '<attack>-<norm>' key; the values must be ints
    # because the suffix formats them with ':d'
    EPS = {
        'pgd-linf': [0, 1, 2, 4, 8, 16],
        'pgd-l2':   [0, 40, 80, 160, 320, 640],
    }

    for cost in sorted(costs):
        for at in ats:
            for at_norm in at_norms:
                key = at + '-' + at_norm
                for at_eps in EPS[key]:

                    suffix = '_cost-{cost:0.2f}_{at}-{at_norm}_eps-{at_eps:d}'.format(
                        cost=cost, at=at, at_norm=at_norm, at_eps=at_eps)

                    cmd = 'python train.py \
                          -d {dataset} \
                          --dataroot {dataroot} \
                          --num_epochs {num_epochs} \
                          --batch_size {batch_size} \
                          --cost {cost} \
                          --at {at} \
                          --nb_its {nb_its} \
                          --at_eps {at_eps} \
                          --at_norm {at_norm} \
                          -s {suffix} \
                          -l {log_dir}'.format(
                            dataset=FLAGS.dataset,
                            dataroot=FLAGS.dataroot,
                            num_epochs=FLAGS.num_epochs,
                            batch_size=FLAGS.batch_size,
                            cost=cost,
                            at=at,
                            nb_its=FLAGS.nb_its,
                            at_eps=at_eps,
                            at_norm=at_norm,
                            suffix=suffix,
                            log_dir=log_dir)

                    # script file name is the suffix without its leading '_'
                    script_basename = suffix.lstrip('_') + '.sh'
                    script_path = os.path.join(script_dir, script_basename)
                    generate_script(cmd, script_path, FLAGS.run_dir, FLAGS.abci_log_dir, FLAGS.ex_id, FLAGS.user, FLAGS.env)
Example #2
0
def stats(**kwargs):
    """
    Summarize the error / rejection metrics of one model-and-attack setting.

    Rows of the csv at ``target_path`` are kept only when their ``at``,
    ``at_norm``, ``attack``, ``attack_norm``, ``at_eps``, ``attack_eps`` and
    ``cost`` columns equal the flag of the same name. The result is
    ``DataFrame.describe()`` of the "error", "rejection rate" and
    "rejection precision" columns, as a ``{column: {statistic: value}}`` dict.
    """
    FLAGS = FlagHolder()
    FLAGS.initialize(**kwargs)
    # flag summary intentionally not printed (FLAGS.summary())

    table = pd.read_csv(FLAGS.target_path)

    # narrow the table one condition at a time; each column is compared
    # against the flag that carries the same name
    condition_columns = ('at', 'at_norm', 'attack', 'attack_norm',
                         'at_eps', 'attack_eps', 'cost')
    for column in condition_columns:
        table = table[table[column] == getattr(FLAGS, column)]

    metric_table = table[["error", "rejection rate", "rejection precision"]]
    summary = metric_table.describe()
    return summary.to_dict()
Example #3
0
def test_abci(**kwargs):
    """Generate one ABCI job script per (attack method, norm, epsilon) setting.

    Attack methods and their norms come from ``AttackerBuilder.ATTACK_CONFIG``;
    the epsilons for each ``'<method>-<norm>'`` pair and dataset come from
    ``AT_PARAMS``. Each command runs ``experiments/test_abci.py`` and is
    written by ``generate_script`` to ``<script_root>/<ex_id>/<suffix>.sh``.

    Args:
        **kwargs: options forwarded to ``FlagHolder.initialize``.
    """
    FLAGS = FlagHolder()
    FLAGS.initialize(**kwargs)
    FLAGS.summary()

    # create script output dir
    script_dir = os.path.join(FLAGS.script_root, FLAGS.ex_id)
    os.makedirs(script_dir, exist_ok=True)

    # every job logs into the same directory, so create it once
    log_dir = os.path.join(FLAGS.log_dir, FLAGS.ex_id)
    os.makedirs(log_dir, exist_ok=True)

    # loop for attack method
    for attack_method in AttackerBuilder.ATTACK_CONFIG.keys():
        for attack_norm in AttackerBuilder.ATTACK_CONFIG[attack_method].norms:

            attack_method_with_norm = attack_method + '-' + attack_norm
            for attack_eps in AT_PARAMS[attack_method_with_norm][FLAGS.dataset]:

                suffix = '_{attack_method}-{attack_norm}_eps-{attack_eps:d}'.format(
                    attack_method=attack_method, attack_norm=attack_norm, attack_eps=attack_eps)

                # NOTE(review): this command template names --cost / --at /
                # --at_eps / --at_norm, but no local variables with those
                # names exist in this function (they were previously unbound
                # and raised NameError). They are filled from FLAGS.cost and
                # the attack loop variables; confirm this matches the CLI of
                # experiments/test_abci.py.
                cmd = 'python experiments/test_abci.py \
                      -d {dataset} \
                      --dataroot {dataroot} \
                      --num_epochs {num_epochs} \
                      --batch_size {batch_size} \
                      --cost {cost} \
                      --at {at} \
                      --nb_its {nb_its} \
                      --at_eps {at_eps} \
                      --at_norm {at_norm} \
                      -s {suffix} \
                      -l {log_dir}'.format(
                        dataset=FLAGS.dataset,
                        dataroot=FLAGS.dataroot,
                        num_epochs=FLAGS.num_epochs,
                        batch_size=FLAGS.batch_size,
                        cost=FLAGS.cost,
                        at=attack_method,
                        nb_its=FLAGS.nb_its,
                        at_eps=attack_eps,
                        at_norm=attack_norm,
                        suffix=suffix,
                        log_dir=log_dir)

                # script file name is the suffix without its leading '_'
                script_basename = suffix.lstrip('_') + '.sh'
                script_path = os.path.join(script_dir, script_basename)
                generate_script(cmd, script_path, FLAGS.run_dir, FLAGS.abci_log_dir, FLAGS.ex_id, FLAGS.user, FLAGS.env)
Example #4
0
def stats_multi(**kwargs):
    """Print one metric summary line per (training eps, attack eps) pair.

    Adversarial training is fixed to PGD with the linf norm; the attack method
    and norm come from the flags. For every training epsilon and every attack
    epsilon of the chosen attack norm, ``stats`` is called and mean +- std of
    error, rejection rate and rejection precision are printed in percent.

    Args:
        **kwargs: options forwarded to ``FlagHolder.initialize``; reads
            ``target_path``, ``cost``, ``attack`` and ``attack_norm``.
    """
    FLAGS = FlagHolder()
    FLAGS.initialize(**kwargs)
    FLAGS.summary()

    at = 'pgd'
    at_eps_by_norm = {'linf': [0, 4, 8]}
    attack_eps_by_norm = {'linf': [0, 4, 8, 16], 'l2': [0, 80, 160]}

    attack = FLAGS.attack
    attack_norm = FLAGS.attack_norm

    line_format = ' | '.join([
        'AT:{at}-{at_norm:<5s}: {at_eps:>3d}',
        'Att: {attack}-{attack_norm:<5s}: {attack_eps:>3d}',
        'Err: {err_mean:>4.1f}+-{err_std:>4.2f}',
        'Rej: {rjc_mean:>4.1f}+-{rjc_std:>4.2f}',
        'PR:  {pr_mean:>4.1f} +-{pr_std:>4.2f}',
    ])

    for at_norm in ['linf']:
        for at_eps in at_eps_by_norm[at_norm]:
            for attack_eps in attack_eps_by_norm[attack_norm]:
                summary = stats(target_path=FLAGS.target_path,
                                cost=FLAGS.cost,
                                at=at,
                                attack=attack,
                                at_norm=at_norm,
                                attack_norm=attack_norm,
                                at_eps=at_eps,
                                attack_eps=attack_eps)

                error = summary['error']
                rejection = summary['rejection rate']
                precision = summary['rejection precision']

                print(line_format.format(
                    at=at, at_norm=at_norm, at_eps=at_eps,
                    attack=attack, attack_norm=attack_norm, attack_eps=attack_eps,
                    err_mean=error['mean'] * 100, err_std=error['std'] * 100,
                    rjc_mean=rejection['mean'] * 100, rjc_std=rejection['std'] * 100,
                    pr_mean=precision['mean'] * 100, pr_std=precision['std'] * 100))
Example #5
0
def plot_multi(**kwargs):
    """Run ``plot.py`` once for every csv file found under ``target_dir``.

    Files named ``test*.csv`` are skipped. For each remaining csv, the figure
    is written to ``<csv dir>/plot/<csv stem>.png`` by invoking ``plot.py``
    from ``../scripts``.

    Args:
        **kwargs: options forwarded to ``FlagHolder.initialize``; reads
            ``target_dir``, ``x``, ``y`` and ``plot_all``.
    """
    import fnmatch

    FLAGS = FlagHolder()
    FLAGS.initialize(**kwargs)
    FLAGS.summary()

    run_dir = '../scripts'
    target_path = os.path.join(FLAGS.target_dir, '**/*.csv')
    csv_paths = sorted(glob.glob(target_path, recursive=True),
                       key=os.path.basename)

    for csv_path in csv_paths:
        basename = os.path.basename(csv_path)
        # skip test result files; 'test*.csv' is a wildcard pattern, so an
        # exact string comparison would never match
        if fnmatch.fnmatch(basename, 'test*.csv'):
            continue

        log_dir = os.path.join(os.path.dirname(csv_path), 'plot')
        os.makedirs(log_dir, exist_ok=True)

        stem, _ = os.path.splitext(basename)
        log_path = os.path.join(log_dir, stem) + '.png'

        # argv is built as a list so paths containing spaces stay intact;
        # paths are made absolute because the subprocess runs in run_dir,
        # not in the directory the glob was resolved from
        cmd = ['python', 'plot.py',
               '-t', os.path.abspath(csv_path),
               '-x', str(FLAGS.x),
               '-s',
               '-l', os.path.abspath(log_path)]

        # add y
        if FLAGS.y != '':
            cmd += ['-y', str(FLAGS.y)]

        # add flag command
        if FLAGS.plot_all:
            cmd.append('--plot_all')

        subprocess.run(cmd, cwd=run_dir)
Example #6
0
def _plot_columns(df, x, column_names, echo_columns=False):
    """Draw each of ``column_names`` against ``x`` in a single row of subplots.

    Each subplot is a ``sns.lineplot`` with a standard-deviation band
    (``ci="sd"``); ``plt.tight_layout()`` is applied at the end.

    Args:
        df: data frame containing ``x`` and every column in ``column_names``.
        x: column name used for the x axis.
        column_names: columns to plot, one subplot each, left to right.
        echo_columns: print each column name before plotting it.

    Returns:
        the created matplotlib figure.
    """
    num_columns = len(column_names)
    fig = plt.figure(figsize=(4 * num_columns, 3))
    for i, column_name in enumerate(column_names):
        if echo_columns:
            print(column_name)
        ax = fig.add_subplot(1, num_columns, i + 1)
        sns.lineplot(x=x, y=column_name, ci="sd", data=df, ax=ax)
    plt.tight_layout()
    return fig


def plot(**kwargs):
    """
    Plot columns of a csv log / result file with seaborn.

    The first enabled mode wins, checked in this order:
    ``plot_all`` (every column of a training log), ``plot_test`` (every
    column of test.csv), ``plot_test_trans`` (test.csv restricted to one
    ``at`` / ``at_norm``), ``plot_test_adv`` (test_adv.csv restricted to one
    training/attack setting); otherwise the single column ``y`` is plotted.
    The figure is saved to ``log_path`` when ``save`` is set and shown
    interactively otherwise.

    reference
    - https://qiita.com/ryo111/items/bf24c8cf508ad90cfe2e (how to make make block)
    - https://heavywatal.github.io/python/matplotlib.html

    Raises:
        ValueError: if both ``plot_all`` and ``plot_test`` are True; if the
            ``plot_test_adv`` cost / at_eps / attack_eps flags match none of
            the supported combinations; if ``y`` is empty in single-column
            mode; or if ``save`` is set while ``log_path`` is empty.
    """
    FLAGS = FlagHolder()
    FLAGS.initialize(**kwargs)
    FLAGS.summary()

    if (FLAGS.plot_all is True) and (FLAGS.plot_test is True):
        raise ValueError(
            'invalid option. either "plot_all" or "plot_test" should be True.')

    # load csv file and plot
    df = pd.read_csv(FLAGS.target_path)

    # columns never used as y values when plotting test.csv
    test_ignore_columns = [
        'Unnamed: 0', 'time stamp', 'arch', 'path', 'loss', FLAGS.x
    ]

    if FLAGS.plot_all:
        # plot all variables; mainly used to visualize training logs
        ignore_columns = ['Unnamed: 0', 'time stamp', 'step', FLAGS.x]
        column_names = [c for c in df.columns if c not in ignore_columns]
        fig = _plot_columns(df, FLAGS.x, column_names)

    elif FLAGS.plot_test:
        # plot every metric column of test.csv
        column_names = [c for c in df.columns if c not in test_ignore_columns]
        fig = _plot_columns(df, FLAGS.x, column_names)

    elif FLAGS.plot_test_trans:
        # plot test.csv restricted to one adversarial-training setting
        df = df[df['at'] == FLAGS.at]  # pgd
        df = df[df['at_norm'] == FLAGS.at_norm]  # linf / l2
        column_names = [c for c in df.columns if c not in test_ignore_columns]
        fig = _plot_columns(df, FLAGS.x, column_names)

    elif FLAGS.plot_test_adv:
        # plot test_adv.csv: fix the training and attack method / norm, then
        # fix two of cost / at_eps / attack_eps so the remaining one varies
        df = df[df['at'] == FLAGS.at]  # pgd
        df = df[df['at_norm'] == FLAGS.at_norm]  # linf / l2
        df = df[df['attack'] == FLAGS.attack]  # pgd
        df = df[df['attack_norm'] == FLAGS.attack_norm]  # linf / l2

        # NOTE(review): these checks use truthiness, so a cost or eps of 0
        # counts as "unspecified" and eps=0 rows cannot be selected here.
        # Switching to `is None` depends on the flag defaults, which are not
        # visible from this function; confirm before changing.
        if not FLAGS.cost and (FLAGS.at_eps and FLAGS.attack_eps):
            df = df[df['at_eps'] == FLAGS.at_eps]
            df = df[df['attack_eps'] == FLAGS.attack_eps]

        elif not FLAGS.at_eps and (FLAGS.attack_eps and FLAGS.cost):
            df = df[df['attack_eps'] == FLAGS.attack_eps]
            df = df[df['cost'] == FLAGS.cost]

        elif not FLAGS.attack_eps and (FLAGS.cost and FLAGS.at_eps):
            df = df[df['cost'] == FLAGS.cost]
            df = df[df['at_eps'] == FLAGS.at_eps]
        else:
            raise ValueError

        adv_ignore_columns = [
            'Unnamed: 0', 'time stamp', 'path', 'loss', 'maxhinge_loss', 'at',
            'at_norm', 'at_eps', 'attack_trg_loss', 'attack', 'attack_norm',
            'attack_eps', 'cost'
        ]
        column_names = [c for c in df.columns if c not in adv_ignore_columns]
        fig = _plot_columns(df, FLAGS.x, column_names, echo_columns=True)

    else:
        # plot specified variable
        if FLAGS.y == '':
            raise ValueError('please specify "y"')
        fig = plt.figure()
        ax = fig.subplots()
        sns.lineplot(x=FLAGS.x, y=FLAGS.y, ci="sd", data=df, ax=ax)

    # show and save
    if FLAGS.save:
        try:
            if FLAGS.log_path == '':
                raise ValueError('please specify "log_path"')
            # a bare file name has no directory part; os.makedirs('') fails
            log_dir = os.path.dirname(FLAGS.log_path)
            if log_dir:
                os.makedirs(log_dir, exist_ok=True)
            fig.savefig(FLAGS.log_path)
        finally:
            # release the figure whether or not saving succeeded
            plt.close(fig)
    else:
        plt.show()
Example #7
0
def train_multi(**kwargs):
    """Generate one job script per binary adversarial-training setting.

    Sweeps rejection cost x attack method x attack norm x epsilon (each can be
    pinned to a single value via the flags) and, for each combination, hands
    a ``python train_binary.py ...`` command line to ``generate_script``,
    which writes it to ``<script_root>/<ex_id>/<run name>.sh``. The run name
    (the suffix without its leading '_') is also used as the wandb run name.

    Args:
        **kwargs: options forwarded to ``FlagHolder.initialize``. A falsy
            ``cost`` / ``at_norm`` or a ``None`` ``at_eps`` selects the
            default sweep for that axis.
    """
    FLAGS = FlagHolder()
    FLAGS.initialize(**kwargs)
    FLAGS.summary()

    # create script output dir
    script_dir = os.path.join(FLAGS.script_root, FLAGS.ex_id)
    os.makedirs(script_dir, exist_ok=True)

    # every generated job logs into the same directory, so create it once
    log_dir = os.path.join(FLAGS.log_dir, FLAGS.ex_id)
    os.makedirs(log_dir, exist_ok=True)

    costs = [0.05, 0.10, 0.15, 0.20, 0.25, 0.30, 0.35, 0.40, 0.45, 0.50] if not FLAGS.cost else [FLAGS.cost]
    ats = ['pgd']
    at_norms = ['linf', 'l2'] if not FLAGS.at_norm else [FLAGS.at_norm]

    # epsilon grid per '<attack>-<norm>' key
    EPS = {
        'pgd-linf': [0., 1., 2., 4., 8., 16.],
        'pgd-l2':   [0., 40., 80., 160., 320., 640.],
    }

    for cost in sorted(costs):
        for at in ats:
            for at_norm in at_norms:
                key = at + '-' + at_norm
                epses = EPS[key] if FLAGS.at_eps is None else [FLAGS.at_eps]

                for at_eps in epses:

                    suffix = '_cost-{cost:0.2f}_{at}-{at_norm}_eps-{at_eps:0.1f}'.format(
                        cost=cost, at=at, at_norm=at_norm, at_eps=at_eps)
                    # shared by the wandb run name and the script file name
                    run_name = suffix.lstrip('_')

                    cmd = 'python train_binary.py \
                          -d {dataset} \
                          --dataroot {dataroot} \
                          --num_epochs {num_epochs} \
                          --batch_size {batch_size} \
                          --cost {cost} \
                          --at {at} \
                          --nb_its {nb_its} \
                          --at_eps {at_eps} \
                          --at_norm {at_norm} \
                          -s {suffix} \
                          -l {log_dir} \
                          {use_wandb} \
                          --wandb_project {wandb_project} \
                          --wandb_name {wandb_name}'.format(
                            dataset=FLAGS.dataset,
                            dataroot=FLAGS.dataroot,
                            num_epochs=FLAGS.num_epochs,
                            batch_size=FLAGS.batch_size,
                            cost=cost,
                            at=at,
                            nb_its=FLAGS.nb_its,
                            at_eps=at_eps,
                            at_norm=at_norm,
                            suffix=suffix,
                            log_dir=log_dir,
                            use_wandb='--use_wandb' if FLAGS.use_wandb else '',
                            wandb_project=FLAGS.wandb_project,
                            wandb_name=run_name)

                    script_path = os.path.join(script_dir, run_name + '.sh')
                    generate_script(cmd, script_path, FLAGS.run_dir, FLAGS.abci_log_dir, FLAGS.ex_id, FLAGS.user, FLAGS.env, hour=FLAGS.hour, wandb_api_key=FLAGS.wandb_api_key)
Example #8
0
def test(**kwargs):
    """
    test model on specific cost and specific adversarial perturbation.

    Loads ``weight`` into a ``DeepLinearSvmWithRejector`` on a VGG16-variant
    backbone (a single output when ``binary_target_class`` is set,
    ``num_classes`` outputs otherwise). It then evaluates the model on the
    test split of ``dataset``, optionally replacing each batch with PGD
    adversarial examples when ``attack`` is set and ``attack_eps`` > 0.

    Returns:
        the per-batch loss / evaluator metrics averaged over the test set
        (``MetricDict.avg``).

    Raises:
        AssertionError: if ``nb_its`` <= 0, ``attack_eps`` < 0, or the
            resolved ``step_size`` < 0.
        NotImplementedError: if ``attack`` is set to anything but 'pgd'.
    """
    FLAGS = FlagHolder()
    FLAGS.initialize(**kwargs)
    FLAGS.summary()

    assert FLAGS.nb_its > 0
    assert FLAGS.attack_eps >= 0

    is_binary = FLAGS.binary_target_class is not None

    # dataset
    dataset_builder = DatasetBuilder(name=FLAGS.dataset,
                                     root_path=FLAGS.dataroot)
    test_dataset = dataset_builder(
        train=False,
        normalize=FLAGS.normalize,
        binary_classification_target=FLAGS.binary_target_class)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=FLAGS.batch_size,
                                              shuffle=False,
                                              num_workers=FLAGS.num_workers,
                                              pin_memory=True)

    # model: a binary task uses a single output unit
    features = vgg16_variant(dataset_builder.input_size,
                             FLAGS.dropout_prob).cuda()
    num_outputs = 1 if is_binary else dataset_builder.num_classes
    model = DeepLinearSvmWithRejector(features, FLAGS.dim_features,
                                      num_outputs).cuda()
    load_model(model, FLAGS.weight)

    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)

    # loss
    if is_binary:
        criterion = MaxHingeLossBinaryWithRejection(FLAGS.cost)
    else:
        criterion = MaxHingeLossWithRejection(FLAGS.cost)

    # adversarial attack
    if FLAGS.attack:
        # derive step_size from eps and iteration count when not given
        if not FLAGS.step_size:
            FLAGS.step_size = get_step_size(FLAGS.attack_eps, FLAGS.nb_its)
        assert FLAGS.step_size >= 0

        if FLAGS.attack != 'pgd':
            raise NotImplementedError('invalid attack method.')
        attacker = PGDAttackVariant(
            FLAGS.nb_its,
            FLAGS.attack_eps,
            FLAGS.step_size,
            dataset=FLAGS.dataset,
            cost=FLAGS.cost,
            norm=FLAGS.attack_norm,
            num_classes=dataset_builder.num_classes,
            is_binary_classification=is_binary)

    # pre epoch
    test_metric_dict = MetricDict()

    # test
    for x, t in test_loader:
        # eval mode is re-set per batch because the attacker's effect on the
        # model's train/eval mode is not visible from here
        model.eval()
        x = x.to('cuda', non_blocking=True)
        t = t.to('cuda', non_blocking=True)

        # replace the clean batch with adversarial samples
        if FLAGS.attack and FLAGS.attack_eps > 0:
            model.zero_grad()
            x = attacker(model, x.detach(), t.detach())

        with torch.no_grad():
            model.zero_grad()
            # forward
            out_class, out_reject = model(x)

            # compute selective loss
            maxhinge_loss, loss_dict = criterion(out_class, out_reject, t)
            loss_dict['maxhinge_loss'] = maxhinge_loss.detach().cpu().item()

            # total loss (currently the max-hinge loss only)
            loss = maxhinge_loss
            loss_dict['loss'] = loss.detach().cpu().item()

            # evaluation; binary outputs are flattened to 1-D first
            if is_binary:
                evaluator = Evaluator(out_class.detach().view(-1),
                                      t.detach().view(-1),
                                      out_reject.detach().view(-1), FLAGS.cost)
            else:
                evaluator = Evaluator(out_class.detach(), t.detach(),
                                      out_reject.detach(), FLAGS.cost)

            loss_dict.update(evaluator())

        test_metric_dict.update(loss_dict)

    # post epoch
    print_metric_dict(None, None, test_metric_dict.avg, mode='test')

    return test_metric_dict.avg
Example #9
0
def test_multi_adv(**kwargs):
    """
    Test every 'weight_final*.pth' under ``target_dir`` against a sweep of attack epsilons.

    Weight files are found recursively and optionally narrowed to those whose
    path contains ``cost-<cost:.2f>`` (when ``cost`` is given) and
    ``<at>-<at_norm>`` (when ``at`` is given). For each weight file and each
    epsilon in ``EPS['<attack>_<attack_norm>']``, ``test`` is run, and one row
    is logged to ``<target_dir>/test<suffix>.csv``. Each row holds the training
    settings parsed from the file name, the attack settings, the path and the
    returned metrics. Identical settings found in different sub-directories
    each produce their own row.

    'target_dir' is expected to look like

    ~/target_dir/XXXX/weight_final_cost-0.10_pgd-linf_eps-0.pth
                     ...
                     /weight_final_cost-0.10_pgd-linf_eps-8.pth
                     /weight_final_cost-0.10_pgd-linf_eps-16.pth
                     ...
                /YYYY/weight_final_cost-0.10_pgd-linf_eps-0.pth
                     ...
                     /weight_final_cost-0.10_pgd-linf_eps-8.pth
                     /weight_final_cost-0.10_pgd-linf_eps-16.pth
                     ...
    """
    # flags
    FLAGS = FlagHolder()
    FLAGS.initialize(**kwargs)
    FLAGS.summary()

    # specify target weight path
    target_path = os.path.join(FLAGS.target_dir, '**/weight_final*.pth')
    weight_paths = sorted(glob.glob(target_path, recursive=True),
                          key=os.path.basename)

    if FLAGS.cost is not None:
        cost_tag = 'cost-{cost:0.2f}'.format(cost=FLAGS.cost)
        weight_paths = [wpath for wpath in weight_paths if cost_tag in wpath]
    if FLAGS.at is not None:
        at_tag = '{at}-{at_norm}'.format(at=FLAGS.at, at_norm=FLAGS.at_norm)
        weight_paths = [wpath for wpath in weight_paths if at_tag in wpath]

    log_path = os.path.join(FLAGS.target_dir,
                            'test{}.csv'.format(FLAGS.suffix))

    # logging
    logger = Logger(path=log_path, mode='test', use_wandb=False, flags=FLAGS)

    # get epses
    # NOTE(review): EPS is a module-level table not visible here. Other
    # scripts key their eps grids as '<attack>-<norm>' (hyphen), while this
    # looks up '<attack>_<norm>'; confirm the key format.
    key = FLAGS.attack + '_' + FLAGS.attack_norm
    attack_epses = EPS[key]

    for weight_path in weight_paths:
        # training settings encoded in the file name; they do not depend on
        # the attack eps, so parse once per weight file
        ret_dict = parse_weight_basename(os.path.basename(weight_path))

        for attack_eps in attack_epses:

            # keyword args for test function
            # variable args
            kw_args = {}
            kw_args['weight'] = weight_path
            kw_args['dataset'] = FLAGS.dataset
            kw_args['dataroot'] = FLAGS.dataroot
            kw_args['binary_target_class'] = FLAGS.binary_target_class
            kw_args['cost'] = ret_dict['cost']
            kw_args['attack'] = FLAGS.attack
            kw_args['nb_its'] = FLAGS.nb_its
            kw_args['step_size'] = None
            kw_args['attack_eps'] = attack_eps
            kw_args['attack_norm'] = FLAGS.attack_norm

            # default args
            kw_args['dim_features'] = 512
            kw_args['dropout_prob'] = 0.3
            kw_args['num_workers'] = 8
            kw_args['batch_size'] = 128
            kw_args['normalize'] = True
            kw_args['alpha'] = 0.5

            # run test
            out_dict = test(**kw_args)

            metric_dict = OrderedDict()
            metric_dict['cost'] = ret_dict['cost']
            metric_dict['binary_target_class'] = FLAGS.binary_target_class
            # at
            metric_dict['at'] = ret_dict['at']
            metric_dict['at_norm'] = ret_dict['at_norm']
            metric_dict['at_eps'] = ret_dict['at_eps']
            # attack
            metric_dict['attack'] = FLAGS.attack
            metric_dict['attack_norm'] = FLAGS.attack_norm
            metric_dict['attack_eps'] = attack_eps
            # path
            metric_dict['path'] = weight_path
            metric_dict.update(out_dict)

            # log
            logger.log(metric_dict)
Example #10
0
def train(**kwargs):
    """Train a SelectiveNet on a VGG16-variant backbone and save the final weights.

    Each epoch runs one optimization pass over the training split and one
    evaluation pass over the test split, which serves as validation. The
    objective is ``alpha * selective_loss + (1 - alpha) * auxiliary CE``.
    Averaged per-epoch metrics are logged to ``train_log<suffix>.csv`` and
    ``val_log<suffix>.csv`` in ``log_dir``, and the weights are saved to
    ``weight_final<suffix>.pth`` after the last epoch. The learning rate is
    halved every 25 epochs (StepLR).

    Args:
        **kwargs: options forwarded to ``FlagHolder.initialize``.
    """
    FLAGS = FlagHolder()
    FLAGS.initialize(**kwargs)
    FLAGS.summary()
    FLAGS.dump(
        path=os.path.join(FLAGS.log_dir, 'flags{}.json'.format(FLAGS.suffix)))

    # dataset
    dataset_builder = DatasetBuilder(name=FLAGS.dataset,
                                     root_path=FLAGS.dataroot)
    train_dataset = dataset_builder(train=True, normalize=FLAGS.normalize)
    val_dataset = dataset_builder(train=False, normalize=FLAGS.normalize)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=FLAGS.batch_size,
                                               shuffle=True,
                                               num_workers=FLAGS.num_workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=FLAGS.batch_size,
                                             shuffle=False,
                                             num_workers=FLAGS.num_workers,
                                             pin_memory=True)

    # model
    features = vgg16_variant(dataset_builder.input_size,
                             FLAGS.dropout_prob).cuda()
    model = SelectiveNet(features, FLAGS.dim_features,
                         dataset_builder.num_classes).cuda()
    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)

    # optimizer
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=FLAGS.lr,
                                momentum=FLAGS.momentum,
                                weight_decay=FLAGS.wd)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=25,
                                                gamma=0.5)

    # loss
    base_loss = torch.nn.CrossEntropyLoss(reduction='none')
    SelectiveCELoss = SelectiveLoss(base_loss, coverage=FLAGS.coverage)
    aux_ce_loss = torch.nn.CrossEntropyLoss()

    def compute_loss(out_class, out_select, out_aux, t):
        """Return the weighted total loss and a dict of scalar loss values.

        The dict starts as SelectiveCELoss's own log dict (presumably the
        empirical risk / coverage / penalty terms; confirm the key names in
        SelectiveLoss) and gains 'selective_loss', 'ce_loss' and 'loss'.
        """
        selective_loss, loss_dict = SelectiveCELoss(out_class, out_select, t)
        selective_loss = selective_loss * FLAGS.alpha
        loss_dict['selective_loss'] = selective_loss.detach().cpu().item()
        # standard cross entropy on the auxiliary head
        ce_loss = aux_ce_loss(out_aux, t) * (1.0 - FLAGS.alpha)
        loss_dict['ce_loss'] = ce_loss.detach().cpu().item()
        # total loss
        loss = selective_loss + ce_loss
        loss_dict['loss'] = loss.detach().cpu().item()
        return loss, loss_dict

    # logger
    train_logger = Logger(
        path=os.path.join(FLAGS.log_dir, 'train_log{}.csv'.format(FLAGS.suffix)),
        mode='train')
    val_logger = Logger(
        path=os.path.join(FLAGS.log_dir, 'val_log{}.csv'.format(FLAGS.suffix)),
        mode='val')

    for ep in range(FLAGS.num_epochs):
        # pre epoch
        train_metric_dict = MetricDict()
        val_metric_dict = MetricDict()

        # train
        model.train()
        for x, t in train_loader:
            x = x.to('cuda', non_blocking=True)
            t = t.to('cuda', non_blocking=True)

            out_class, out_select, out_aux = model(x)
            loss, loss_dict = compute_loss(out_class, out_select, out_aux, t)

            # backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_metric_dict.update(loss_dict)

        # validation
        model.eval()
        with torch.no_grad():
            for x, t in val_loader:
                x = x.to('cuda', non_blocking=True)
                t = t.to('cuda', non_blocking=True)

                out_class, out_select, out_aux = model(x)
                _, loss_dict = compute_loss(out_class, out_select, out_aux, t)

                # evaluation
                evaluator = Evaluator(out_class.detach(), t.detach(),
                                      out_select.detach())
                loss_dict.update(evaluator())

                val_metric_dict.update(loss_dict)

        # post epoch (only validation metrics are printed)
        print_metric_dict(ep,
                          FLAGS.num_epochs,
                          val_metric_dict.avg,
                          mode='val')

        train_logger.log(train_metric_dict.avg, step=(ep + 1))
        val_logger.log(val_metric_dict.avg, step=(ep + 1))

        scheduler.step()

    # post training
    save_model(model,
               path=os.path.join(FLAGS.log_dir,
                                 'weight_final{}.pth'.format(FLAGS.suffix)))
Example #11
0
def test_adv(**kwargs):
    """
    Test weight files under different ``num_divide`` settings.

    ``target_dir`` may be a directory or a single '.pth' file. For a
    directory, every 'weight_final*.pth' below it is tested and results go
    to ``<target_dir>/test<suffix>.csv``. For a single file, only that file
    is tested and results go to ``test<suffix>.csv`` next to it. Each weight
    file is run through ``test`` (without any adversarial attack) once per
    ``num_divide`` value, and one row per run is logged.

    'target_dir' should look like

    ~/target_dir/XXXX/weight_final_pgd-linf_eps-0.pth
                     ...
                     /weight_final_pgd-linf_eps-8.pth
                     /weight_final_pgd-linf_eps-16.pth
                     ...
                /YYYY/weight_final_pgd-linf_eps-0.pth
                     ...
                     /weight_final_pgd-linf_eps-8.pth
                     /weight_final_pgd-linf_eps-16.pth
                     ...
    """
    # flags
    FLAGS = FlagHolder()
    FLAGS.initialize(**kwargs)
    FLAGS.summary()

    # paths
    if os.path.splitext(FLAGS.target_dir)[-1] != '.pth':
        target_path = os.path.join(FLAGS.target_dir, '**/weight_final*.pth')
        weight_paths = sorted(glob.glob(target_path, recursive=True),
                              key=os.path.basename)
        log_path = os.path.join(FLAGS.target_dir,
                                'test{}.csv'.format(FLAGS.suffix))
    else:
        # a single weight file; list() on a str would split it into characters
        weight_paths = [FLAGS.target_dir]
        log_path = os.path.join(os.path.dirname(FLAGS.target_dir),
                                'test{}.csv'.format(FLAGS.suffix))

    # logging
    logger = Logger(path=log_path, mode='test', use_wandb=False, flags=FLAGS)

    # num_divide may be unset (use the default sweep), a single int, or an
    # iterable of ints
    if not FLAGS.num_divide:
        num_divides = [0, 2, 4, 8, 16]
    elif isinstance(FLAGS.num_divide, int):
        num_divides = [FLAGS.num_divide]
    else:
        num_divides = list(FLAGS.num_divide)

    for weight_path in weight_paths:
        # training settings encoded in the file name; parse once per file
        ret_dict = parse_weight_basename(os.path.basename(weight_path))

        for num_divide in num_divides:

            # keyword args for test function
            # variable args
            kw_args = {}
            kw_args['arch'] = FLAGS.arch
            kw_args['weight'] = weight_path
            kw_args['dataset'] = FLAGS.dataset
            kw_args['dataroot'] = FLAGS.dataroot
            kw_args['batch_size'] = FLAGS.batch_size
            kw_args['attack'] = None
            kw_args['attack_eps'] = 0
            kw_args['attack_norm'] = None
            kw_args['nb_its'] = 0
            kw_args['step_size'] = None
            kw_args['num_divide'] = num_divide

            # default args
            kw_args['num_workers'] = 8
            kw_args['normalize'] = True

            # run test
            out_dict = test(**kw_args)

            metric_dict = OrderedDict()
            # model
            metric_dict['arch'] = FLAGS.arch
            # at
            metric_dict['at'] = ret_dict['at']
            metric_dict['at_norm'] = ret_dict['at_norm']
            metric_dict['at_eps'] = ret_dict['at_eps']
            # transform
            metric_dict['num_divide'] = num_divide
            # path
            metric_dict['path'] = weight_path
            metric_dict.update(out_dict)

            # log
            logger.log(metric_dict)
Example #12
0
def test_multi(**kwargs):
    """
    Run ``test`` on every 'weight_final*.pth' file found under 'kwargs.target_dir'.

    The coverage is parsed from the last '_'-separated token of each file's
    stem (e.g. 'weight_final_coverage_0.90.pth' -> 0.90). One row per weight
    file (coverage, path and the metrics returned by ``test``) is logged to
    '<target_dir>/test.csv' through ``Logger``.

    NOTE(review): an earlier version of this docstring said identical files
    are averaged; this function logs one row per file, so any averaging would
    have to happen in ``Logger`` or downstream — confirm.

    'target_dir' should look like the following:

    ~/target_dir/XXXX/weight_final_coverage_0.10.pth
                     /weight_final_coverage_0.95.pth
                     /weight_final_coverage_0.90.pth
                     ...
                /YYYY/weight_final_coverage_0.10.pth
                     /weight_final_coverage_0.95.pth
                     /weight_final_coverage_0.90.pth
                     ...
    """
    # flags
    FLAGS = FlagHolder()
    FLAGS.initialize(**kwargs)
    FLAGS.summary()

    # paths: collect weight files recursively and order them by basename
    # (not by full path), so files from different sub-directories interleave.
    target_path = os.path.join(FLAGS.target_dir, '**/weight_final*.pth')
    weight_paths = sorted(glob.glob(target_path, recursive=True),
                          key=os.path.basename)
    log_path = os.path.join(FLAGS.target_dir, 'test.csv')

    # logging
    logger = Logger(path=log_path, mode='test')

    for weight_path in weight_paths:
        # file name should look like '~_coverage_{}.pth'; the coverage is the
        # last '_'-separated token of the stem.
        stem = os.path.splitext(os.path.basename(weight_path))[0]
        coverage = float(stem.split('_')[-1])

        # keyword args for test function
        kw_args = {
            # variable args
            'weight': weight_path,
            'dataset': FLAGS.dataset,
            'dataroot': FLAGS.dataroot,
            'coverage': coverage,
            # default args
            'dim_features': 512,
            'dropout_prob': 0.3,
            'num_workers': 8,
            'batch_size': 128,
            'normalize': True,
            'alpha': 0.5,
        }

        # run test
        out_dict = test(**kw_args)

        metric_dict = OrderedDict()
        metric_dict['coverage'] = coverage
        metric_dict['path'] = weight_path
        metric_dict.update(out_dict)

        # log
        logger.log(metric_dict)
Example #13
0
def train(**kwargs):
    """
    Execute standard training, or adversarial training when ``at`` is set and
    ``at_eps > 0``.

    Flags are read through ``FlagHolder`` and dumped to
    '<log_dir>/flags<suffix>.json'. Per-epoch train / val metrics are logged to
    'train_log<suffix>.csv' / 'val_log<suffix>.csv' in ``log_dir`` and the
    final weights are saved to 'weight_final<suffix>.pth'.

    Learning-rate schedule (``ms`` milestones, ``gamma`` decay factor):
      * exactly one milestone -> StepLR with ``step_size=ms[0]``;
      * otherwise (0 or 2+)   -> MultiStepLR with the sorted milestones
        (an empty list keeps the learning rate constant).
    """
    # flags
    FLAGS = FlagHolder()
    FLAGS.initialize(**kwargs)
    FLAGS.summary()
    os.makedirs(FLAGS.log_dir, exist_ok=True)
    FLAGS.dump(
        path=os.path.join(FLAGS.log_dir, 'flags{}.json'.format(FLAGS.suffix)))

    # dataset
    dataset_builder = DatasetBuilder(name=FLAGS.dataset,
                                     root_path=FLAGS.dataroot)
    train_dataset = dataset_builder(train=True, normalize=FLAGS.normalize)
    val_dataset = dataset_builder(train=False, normalize=FLAGS.normalize)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=FLAGS.batch_size,
                                               shuffle=True,
                                               num_workers=FLAGS.num_workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=FLAGS.batch_size,
                                             shuffle=False,
                                             num_workers=FLAGS.num_workers,
                                             pin_memory=True)

    # model
    num_classes = dataset_builder.num_classes
    model = ModelBuilder(num_classes=num_classes,
                         pretrained=False)[FLAGS.arch].cuda()
    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)

    # optimizer
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=FLAGS.lr,
                                momentum=FLAGS.momentum,
                                weight_decay=FLAGS.wd)

    # scheduler
    # A single milestone means "decay every ms[0] epochs"; any other length
    # (including zero) is treated as an explicit list of milestones.
    if len(FLAGS.ms) == 1:
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                    step_size=FLAGS.ms[0],
                                                    gamma=FLAGS.gamma)
    else:
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=sorted(FLAGS.ms), gamma=FLAGS.gamma)

    # attacker
    use_at = FLAGS.at and FLAGS.at_eps > 0
    if use_at:
        # get step_size (an explicit flag wins over the derived value)
        step_size = get_step_size(
            FLAGS.at_eps,
            FLAGS.nb_its) if not FLAGS.step_size else FLAGS.step_size
        # record the effective step size in the flags handed to the loggers
        FLAGS._dict['step_size'] = step_size
        assert step_size >= 0

        # create attacker
        attacker = AttackerBuilder()(method=FLAGS.at,
                                     norm=FLAGS.at_norm,
                                     eps=FLAGS.at_eps,
                                     **FLAGS._dict)

    # logger
    train_logger = Logger(path=os.path.join(
        FLAGS.log_dir, 'train_log{}.csv'.format(FLAGS.suffix)),
                          mode='train',
                          use_wandb=False,
                          flags=FLAGS._dict)
    val_logger = Logger(path=os.path.join(
        FLAGS.log_dir, 'val_log{}.csv'.format(FLAGS.suffix)),
                        mode='val',
                        use_wandb=FLAGS.use_wandb,
                        flags=FLAGS._dict)

    for ep in range(FLAGS.num_epochs):
        # pre epoch
        train_metric_dict = MetricDict()
        val_metric_dict = MetricDict()

        # train
        for x, t in train_loader:
            x = x.to('cuda', non_blocking=True)
            t = t.to('cuda', non_blocking=True)

            # adversarial attack (examples are crafted with the model in eval mode)
            if use_at:
                model.eval()
                model.zero_grad()
                x = attacker(model, x.detach(), t.detach())

            # forward
            model.train()
            model.zero_grad()
            out = model(x)

            # loss: plain cross entropy
            loss_dict = OrderedDict()
            loss = torch.nn.CrossEntropyLoss()(out, t)
            loss_dict['loss'] = loss.detach().cpu().item()

            # backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_metric_dict.update(loss_dict)

        # validation
        for x, t in val_loader:
            x = x.to('cuda', non_blocking=True)
            t = t.to('cuda', non_blocking=True)

            # adversarial attack: run outside no_grad, since a gradient-based
            # attack (e.g. PGD) needs input gradients.
            if use_at:
                model.eval()
                model.zero_grad()
                x = attacker(model, x.detach(), t.detach())

            with torch.autograd.no_grad():
                # forward
                model.eval()
                model.zero_grad()
                out = model(x)

                # loss: plain cross entropy
                loss_dict = OrderedDict()
                loss = torch.nn.CrossEntropyLoss()(out, t)
                loss_dict['loss'] = loss.detach().cpu().item()

                # evaluation
                evaluator = Evaluator(out.detach(),
                                      t.detach(),
                                      selection_out=None)
                loss_dict.update(evaluator())

                val_metric_dict.update(loss_dict)

        # post epoch
        print_metric_dict(ep,
                          FLAGS.num_epochs,
                          val_metric_dict.avg,
                          mode='val')

        train_logger.log(train_metric_dict.avg, step=(ep + 1))
        val_logger.log(val_metric_dict.avg, step=(ep + 1))

        scheduler.step()

    # post training
    save_model(model,
               path=os.path.join(FLAGS.log_dir,
                                 'weight_final{}.pth'.format(FLAGS.suffix)))
Example #14
0
def train(**kwargs):
    """
    Train a binary linear SVM with a rejector (``DeepLinearSvmWithRejector``
    on a ``vgg16_variant`` backbone) using ``MaxHingeLossBinaryWithRejection``.

    When ``at`` == 'pgd' and ``at_eps > 0``, training and validation inputs
    are replaced by adversarial examples from ``PGDAttackVariant``; any other
    ``at`` value with ``at_eps > 0`` raises ``NotImplementedError``.

    Flags are dumped to '<log_dir>/flags<suffix>.json'; per-epoch metrics go
    to 'train_log<suffix>.csv' / 'val_log<suffix>.csv' and the final weights
    to 'weight_final<suffix>.pth' in ``log_dir``. The learning rate is halved
    every 25 epochs (StepLR, step_size=25, gamma=0.5).
    """
    FLAGS = FlagHolder()
    FLAGS.initialize(**kwargs)
    FLAGS.summary()
    os.makedirs(FLAGS.log_dir, exist_ok=True)
    FLAGS.dump(
        path=os.path.join(FLAGS.log_dir, 'flags{}.json'.format(FLAGS.suffix)))

    # dataset
    dataset_builder = DatasetBuilder(name=FLAGS.dataset,
                                     root_path=FLAGS.dataroot)
    train_dataset = dataset_builder(
        train=True,
        normalize=FLAGS.normalize,
        binary_classification_target=FLAGS.binary_target_class)
    val_dataset = dataset_builder(
        train=False,
        normalize=FLAGS.normalize,
        binary_classification_target=FLAGS.binary_target_class)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=FLAGS.batch_size,
                                               shuffle=True,
                                               num_workers=FLAGS.num_workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=FLAGS.batch_size,
                                             shuffle=False,
                                             num_workers=FLAGS.num_workers,
                                             pin_memory=True)

    # model (single output for binary classification)
    features = vgg16_variant(dataset_builder.input_size,
                             FLAGS.dropout_prob).cuda()
    model = DeepLinearSvmWithRejector(features,
                                      FLAGS.dim_features,
                                      num_classes=1).cuda()
    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)

    # optimizer
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=FLAGS.lr,
                                momentum=FLAGS.momentum,
                                weight_decay=FLAGS.wd)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=25,
                                                gamma=0.5)

    # loss
    MHBRLoss = MaxHingeLossBinaryWithRejection(FLAGS.cost)

    # attacker
    use_at = FLAGS.at and FLAGS.at_eps > 0
    if use_at:
        # get step_size (an explicit flag wins over the derived value)
        step_size = FLAGS.step_size if FLAGS.step_size else get_step_size(
            FLAGS.at_eps, FLAGS.nb_its)
        # record the effective step size in the flags handed to the loggers,
        # as the other training/testing entry points do
        FLAGS._dict['step_size'] = step_size
        assert step_size >= 0

        # create attacker
        if FLAGS.at == 'pgd':
            attacker = PGDAttackVariant(
                FLAGS.nb_its,
                FLAGS.at_eps,
                step_size,
                dataset=FLAGS.dataset,
                cost=FLAGS.cost,
                norm=FLAGS.at_norm,
                num_classes=dataset_builder.num_classes,
                is_binary_classification=True)
        else:
            raise NotImplementedError('invalid at method.')

    # logger
    train_logger = Logger(path=os.path.join(
        FLAGS.log_dir, 'train_log{}.csv'.format(FLAGS.suffix)),
                          mode='train',
                          use_wandb=False,
                          flags=FLAGS._dict)
    val_logger = Logger(path=os.path.join(
        FLAGS.log_dir, 'val_log{}.csv'.format(FLAGS.suffix)),
                        mode='val',
                        use_wandb=FLAGS.use_wandb,
                        flags=FLAGS._dict)

    for ep in range(FLAGS.num_epochs):
        # pre epoch
        train_metric_dict = MetricDict()
        val_metric_dict = MetricDict()

        # train
        for x, t in train_loader:
            x = x.to('cuda', non_blocking=True)
            t = t.to('cuda', non_blocking=True)

            # adversarial attack (examples are crafted with the model in eval mode)
            if use_at:
                model.eval()
                model.zero_grad()
                x = attacker(model, x.detach(), t.detach())

            # forward
            model.train()
            model.zero_grad()
            out_class, out_reject = model(x)

            # selective loss; the returned dict includes 'A mean' / 'B mean'
            # (per the original comment) and is extended with the scalars below
            maxhinge_loss, loss_dict = MHBRLoss(out_class, out_reject, t)
            loss_dict['maxhinge_loss'] = maxhinge_loss.detach().cpu().item()

            # total loss (no regularization term is currently added)
            loss = maxhinge_loss
            loss_dict['loss'] = loss.detach().cpu().item()

            # backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_metric_dict.update(loss_dict)

        # validation
        for x, t in val_loader:
            x = x.to('cuda', non_blocking=True)
            t = t.to('cuda', non_blocking=True)

            # adversarial attack: run outside no_grad, since a gradient-based
            # attack (PGD) needs input gradients.
            if use_at:
                model.eval()
                model.zero_grad()
                x = attacker(model, x.detach(), t.detach())

            with torch.autograd.no_grad():
                # forward
                model.eval()
                model.zero_grad()
                out_class, out_reject = model(x)

                # selective loss; see the training loop for the dict contents
                maxhinge_loss, loss_dict = MHBRLoss(out_class, out_reject, t)
                loss_dict['maxhinge_loss'] = maxhinge_loss.detach().cpu().item()

                # total loss (no regularization term is currently added)
                loss = maxhinge_loss
                loss_dict['loss'] = loss.detach().cpu().item()

                # evaluation on flattened outputs / targets
                evaluator = Evaluator(out_class.detach().view(-1),
                                      t.detach().view(-1),
                                      out_reject.detach().view(-1))
                loss_dict.update(evaluator())

                val_metric_dict.update(loss_dict)

        # post epoch
        print_metric_dict(ep,
                          FLAGS.num_epochs,
                          val_metric_dict.avg,
                          mode='val')

        train_logger.log(train_metric_dict.avg, step=(ep + 1))
        val_logger.log(val_metric_dict.avg, step=(ep + 1))

        scheduler.step()

    # post training
    save_model(model,
               path=os.path.join(FLAGS.log_dir,
                                 'weight_final{}.pth'.format(FLAGS.suffix)))
Example #15
0
def test_fourier(**kwargs):
    """
    Run ``test`` with Fourier-basis noise on trained weights.

    If 'kwargs.target_dir' does not end in '.pth', every 'weight_final*.pth'
    file under it (searched recursively) is tested and results are logged to
    '<target_dir>/test<suffix>.csv'. Otherwise 'target_dir' is treated as a
    single weight file and the csv is written next to it.

    For each weight file, ``test`` is called once per frequency pair
    (index_h, index_w) with index_h in [-fn_max_index_h, fn_max_index_h] and
    index_w in [-fn_max_index_w, fn_max_index_w]; pairs where either index is
    0 are skipped. Each call logs one row with the noise settings, the
    adversarial-training settings parsed from the file name
    ('at' / 'at_norm' / 'at_eps' via ``parse_weight_basename``), the path and
    the metrics returned by ``test``.

    'target_dir' should look like the following
    (.pth file names should look like "weight_final_pgd-linf_eps-{}.pth"):

    ~/target_dir/XXXX/weight_final_pgd-linf_eps-0.pth
                     ...
                     /weight_final_pgd-linf_eps-8.pth
                     /weight_final_pgd-linf_eps-16.pth
                     ...
                /YYYY/weight_final_pgd-linf_eps-0.pth
                     ...
                     /weight_final_pgd-linf_eps-8.pth
                     /weight_final_pgd-linf_eps-16.pth
                     ...
    """
    # flags
    FLAGS = FlagHolder()
    FLAGS.initialize(**kwargs)
    FLAGS.summary()

    # paths
    if os.path.splitext(FLAGS.target_dir)[-1] != '.pth':
        # directory mode: collect all final weights recursively, ordered by
        # basename rather than by full path
        target_path = os.path.join(FLAGS.target_dir, '**/weight_final*.pth')
        weight_paths = sorted(glob.glob(target_path, recursive=True),
                              key=os.path.basename)
        log_path = os.path.join(FLAGS.target_dir,
                                'test{}.csv'.format(FLAGS.suffix))
    else:
        # single-file mode
        weight_paths = [FLAGS.target_dir]
        log_path = os.path.join(os.path.dirname(FLAGS.target_dir),
                                'test{}.csv'.format(FLAGS.suffix))

    # logging
    logger = Logger(path=log_path, mode='test', use_wandb=False, flags=FLAGS)

    for weight_path in weight_paths:
        # the settings encoded in the file name do not depend on the noise
        # indices, so parse them once per weight file
        ret_dict = parse_weight_basename(os.path.basename(weight_path))

        for index_h in range(-FLAGS.fn_max_index_h, FLAGS.fn_max_index_h + 1):
            for index_w in range(-FLAGS.fn_max_index_w,
                                 FLAGS.fn_max_index_w + 1):
                # skip any pair where either index is 0
                # NOTE(review): the original comment read "continue when
                # indices are 0"; confirm whether only (0, 0) was meant.
                if index_h == 0 or index_w == 0:
                    continue

                # keyword args for test function
                kw_args = {}
                # variable args
                kw_args['arch'] = FLAGS.arch
                kw_args['weight'] = weight_path
                kw_args['dataset'] = FLAGS.dataset
                kw_args['dataroot'] = FLAGS.dataroot
                kw_args['batch_size'] = FLAGS.batch_size
                kw_args['fn_eps'] = FLAGS.fn_eps
                kw_args['fn_index_h'] = index_h
                kw_args['fn_index_w'] = index_w
                # default args
                kw_args['num_workers'] = 8
                kw_args['normalize'] = True

                # run test
                out_dict = test(**kw_args)

                metric_dict = OrderedDict()
                # model
                metric_dict['arch'] = FLAGS.arch
                # Fourier noise
                metric_dict['fn_eps'] = FLAGS.fn_eps
                metric_dict['fn_index_h'] = index_h
                metric_dict['fn_index_w'] = index_w
                # at
                metric_dict['at'] = ret_dict['at']
                metric_dict['at_norm'] = ret_dict['at_norm']
                metric_dict['at_eps'] = ret_dict['at_eps']
                # path
                metric_dict['path'] = weight_path
                metric_dict.update(out_dict)

                # log
                logger.log(metric_dict)
Example #16
0
def test(**kwargs):
    """
    Evaluate one checkpoint on the test split under at most one perturbation.

    The perturbation is selected by the flags: an adversarial attack
    (``attack`` set and ``attack_eps > 0``), Patch Shuffle
    (``ps_num_divide > 0``) or Fourier Noise (``fn_eps > 0``). Requesting two
    of them at once raises ``ValueError``.

    Returns:
        ``MetricDict.avg`` over all test batches: the cross-entropy 'loss'
        plus whatever ``Evaluator`` reports.
    """
    kwargs = set_default(**kwargs)

    flags = FlagHolder()
    flags.initialize(**kwargs)
    flags.summary()

    # sanity checks on the perturbation settings
    assert flags.nb_its >= 0
    assert flags.attack_eps >= 0
    assert flags.ps_num_divide >= 0
    attack_requested = flags.attack_eps > 0
    shuffle_requested = flags.ps_num_divide > 0
    # fn_eps is only compared when needed, mirroring short-circuit evaluation
    if attack_requested and shuffle_requested:
        raise ValueError(
            'Adversarial Attack and Patch Shuffle should not be used at same time'
        )
    if attack_requested and flags.fn_eps > 0:
        raise ValueError(
            'Adversarial Attack and Fourier Noise should not be used at same time'
        )
    if shuffle_requested and flags.fn_eps > 0:
        raise ValueError(
            'Patch Shuffle and Fourier Noise should not be used at same time')

    # optional input transforms (at most one ends up enabled)
    extra_transforms = []
    if flags.ps_num_divide:
        extra_transforms.append(PatchSuffle(flags.ps_num_divide))
    if flags.fn_eps:
        extra_transforms.append(
            FourierNoise(flags.fn_index_h, flags.fn_index_w, flags.fn_eps))

    # test data
    builder = DatasetBuilder(name=flags.dataset, root_path=flags.dataroot)
    dataset = builder(train=False,
                      normalize=flags.normalize,
                      optional_transform=extra_transforms)
    loader = torch.utils.data.DataLoader(dataset,
                                         batch_size=flags.batch_size,
                                         shuffle=False,
                                         num_workers=flags.num_workers,
                                         pin_memory=True)

    # model restored from the checkpoint
    net = ModelBuilder(num_classes=builder.num_classes,
                       pretrained=False)[flags.arch].cuda()
    load_model(net, flags.weight)
    if torch.cuda.device_count() > 1:
        net = torch.nn.DataParallel(net)

    # adversary
    attack_enabled = flags.attack and flags.attack_eps > 0
    if attack_enabled:
        # an explicit step size wins; otherwise derive it from eps / iterations
        if flags.step_size:
            step_size = flags.step_size
        else:
            step_size = get_step_size(flags.attack_eps, flags.nb_its)
        flags._dict['step_size'] = step_size
        assert step_size >= 0

        attacker = AttackerBuilder()(method=flags.attack,
                                     norm=flags.attack_norm,
                                     eps=flags.attack_eps,
                                     **flags._dict)

    meter = MetricDict()

    for inputs, targets in loader:
        net.eval()
        inputs = inputs.to('cuda', non_blocking=True)
        targets = targets.to('cuda', non_blocking=True)
        batch_metrics = OrderedDict()

        # the attack runs outside no_grad because it needs input gradients
        if attack_enabled:
            net.zero_grad()
            inputs = attacker(net, inputs.detach(), targets.detach())

        with torch.autograd.no_grad():
            net.zero_grad()
            logits = net(inputs)

            criterion = torch.nn.CrossEntropyLoss()
            batch_metrics['loss'] = criterion(logits,
                                              targets).detach().cpu().item()

            scores = Evaluator(logits.detach(),
                               targets.detach(),
                               selection_out=None)()
            batch_metrics.update(scores)

        meter.update(batch_metrics)

    print_metric_dict(None, None, meter.avg, mode='test')

    return meter.avg
Example #17
0
def test(**kwargs):
    """
    Evaluate a SelectiveNet checkpoint on the test split.

    The loss mirrors training: ``alpha`` * selective loss (``SelectiveLoss``
    over per-sample cross entropy at the requested ``coverage``) plus
    ``(1 - alpha)`` * cross entropy on the auxiliary head.

    Returns:
        ``MetricDict.avg`` over all test batches, holding the entries returned
        by ``SelectiveLoss``, 'selective_loss', 'ce_loss', 'loss' and the
        ``Evaluator`` metrics.
    """
    flags = FlagHolder()
    flags.initialize(**kwargs)
    flags.summary()

    # test data
    builder = DatasetBuilder(name=flags.dataset, root_path=flags.dataroot)
    loader = torch.utils.data.DataLoader(
        builder(train=False, normalize=flags.normalize),
        batch_size=flags.batch_size,
        shuffle=False,
        num_workers=flags.num_workers,
        pin_memory=True)

    # model restored from the checkpoint
    backbone = vgg16_variant(builder.input_size, flags.dropout_prob).cuda()
    net = SelectiveNet(backbone, flags.dim_features,
                       builder.num_classes).cuda()
    load_model(net, flags.weight)

    if torch.cuda.device_count() > 1:
        net = torch.nn.DataParallel(net)

    # selective criterion built on per-sample cross entropy
    selective_criterion = SelectiveLoss(
        torch.nn.CrossEntropyLoss(reduction='none'), coverage=flags.coverage)

    meter = MetricDict()

    with torch.autograd.no_grad():
        for inputs, targets in loader:
            net.eval()
            inputs = inputs.to('cuda', non_blocking=True)
            targets = targets.to('cuda', non_blocking=True)

            logits, selection, aux_logits = net(inputs)

            # per the original comment, the returned dict holds
            # 'empirical_risk' / 'emprical_coverage' / 'penulty'
            sel_loss, batch_metrics = selective_criterion(
                logits, selection, targets)
            sel_loss *= flags.alpha
            batch_metrics['selective_loss'] = sel_loss.detach().cpu().item()

            aux_loss = torch.nn.CrossEntropyLoss()(aux_logits, targets)
            aux_loss *= (1.0 - flags.alpha)
            batch_metrics['ce_loss'] = aux_loss.detach().cpu().item()

            combined = sel_loss + aux_loss
            batch_metrics['loss'] = combined.detach().cpu().item()

            batch_metrics.update(
                Evaluator(logits.detach(), targets.detach(),
                          selection.detach())())

            meter.update(batch_metrics)

    print_metric_dict(None, None, meter.avg, mode='test')

    return meter.avg