示例#1
0
def test(**kwargs):
    """
    Evaluate a checkpointed model on the test split under at most one
    perturbation: adversarial attack, Patch Shuffle or Fourier Noise.

    Args:
        **kwargs: raw options; they are completed by ``set_default`` and then
            loaded into a ``FlagHolder``.

    Returns:
        ``test_metric_dict.avg``, the averaged metrics collected over all
        test batches.

    Raises:
        AssertionError: if ``nb_its``, ``attack_eps``, ``ps_num_divide`` or
            the resolved attack step size is negative.
        ValueError: if two of the three perturbations are enabled together.
    """
    options = set_default(**kwargs)

    FLAGS = FlagHolder()
    FLAGS.initialize(**options)
    FLAGS.summary()

    # sanity checks on the numeric options
    assert FLAGS.nb_its >= 0
    assert FLAGS.attack_eps >= 0
    assert FLAGS.ps_num_divide >= 0

    # the three perturbations are mutually exclusive
    if FLAGS.attack_eps > 0 and FLAGS.ps_num_divide > 0:
        raise ValueError(
            'Adversarial Attack and Patch Shuffle should not be used at same time'
        )
    if FLAGS.attack_eps > 0 and FLAGS.fn_eps > 0:
        raise ValueError(
            'Adversarial Attack and Fourier Noise should not be used at same time'
        )
    if FLAGS.ps_num_divide > 0 and FLAGS.fn_eps > 0:
        raise ValueError(
            'Patch Shuffle and Fourier Noise should not be used at same time')

    # input-side transforms, added only when their option is truthy
    optional_transform = []
    if FLAGS.ps_num_divide:
        optional_transform.append(PatchSuffle(FLAGS.ps_num_divide))
    if FLAGS.fn_eps:
        optional_transform.append(
            FourierNoise(FLAGS.fn_index_h, FLAGS.fn_index_w, FLAGS.fn_eps))

    # test data
    dataset_builder = DatasetBuilder(name=FLAGS.dataset,
                                     root_path=FLAGS.dataroot)
    test_dataset = dataset_builder(train=False,
                                   normalize=FLAGS.normalize,
                                   optional_transform=optional_transform)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=FLAGS.batch_size,
                                              shuffle=False,
                                              num_workers=FLAGS.num_workers,
                                              pin_memory=True)

    # network: build on GPU, restore the checkpoint, then wrap for multi-GPU
    architectures = ModelBuilder(num_classes=dataset_builder.num_classes,
                                 pretrained=False)
    model = architectures[FLAGS.arch].cuda()
    load_model(model, FLAGS.weight)
    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)

    # the attack is active only with a method name and a positive budget
    use_attack = bool(FLAGS.attack and FLAGS.attack_eps > 0)
    if use_attack:
        # an explicitly given step size wins over the derived one
        if FLAGS.step_size:
            step_size = FLAGS.step_size
        else:
            step_size = get_step_size(FLAGS.attack_eps, FLAGS.nb_its)
        # stored in the flag dict because that dict is forwarded to the builder
        FLAGS._dict['step_size'] = step_size
        assert step_size >= 0

        attacker = AttackerBuilder()(method=FLAGS.attack,
                                     norm=FLAGS.attack_norm,
                                     eps=FLAGS.attack_eps,
                                     **FLAGS._dict)

    test_metric_dict = MetricDict()
    cross_entropy = torch.nn.CrossEntropyLoss()

    # evaluation loop
    for x, t in test_loader:
        model.eval()
        x = x.to('cuda', non_blocking=True)
        t = t.to('cuda', non_blocking=True)
        batch_metrics = OrderedDict()

        if use_attack:
            # craft adversarial inputs (runs outside no_grad)
            model.zero_grad()
            x = attacker(model, x.detach(), t.detach())

        with torch.autograd.no_grad():
            model.zero_grad()
            logit = model(x)

            # cross entropy of the batch as a python float
            batch_loss = cross_entropy(logit, t)
            batch_metrics['loss'] = batch_loss.detach().cpu().item()

            # remaining metrics come from the evaluator
            evaluator = Evaluator(logit.detach(),
                                  t.detach(),
                                  selection_out=None)
            batch_metrics.update(evaluator())

        test_metric_dict.update(batch_metrics)

    # report and return the averaged metrics
    averaged = test_metric_dict.avg
    print_metric_dict(None, None, averaged, mode='test')

    return averaged
示例#2
0
def train(**kwargs):
    """
    Train a binary ``DeepLinearSvmWithRejector`` with the max-hinge rejection
    loss, optionally on adversarial examples, and save the final weights.

    Args:
        **kwargs: options loaded into a ``FlagHolder``.

    Side effects:
        Creates ``FLAGS.log_dir``, dumps the flags there, writes train/val
        logs through ``Logger`` and saves ``weight_final{suffix}.pth``.

    Raises:
        AssertionError: if the resolved attack step size is negative.
        NotImplementedError: if adversarial training is requested with a
            method other than ``'pgd'``.
    """
    FLAGS = FlagHolder()
    FLAGS.initialize(**kwargs)
    FLAGS.summary()
    os.makedirs(FLAGS.log_dir, exist_ok=True)
    flags_path = os.path.join(FLAGS.log_dir,
                              'flags{}.json'.format(FLAGS.suffix))
    FLAGS.dump(path=flags_path)

    # data: train and validation splits share normalization and target class
    dataset_builder = DatasetBuilder(name=FLAGS.dataset,
                                     root_path=FLAGS.dataroot)
    train_dataset = dataset_builder(
        train=True,
        normalize=FLAGS.normalize,
        binary_classification_target=FLAGS.binary_target_class)
    val_dataset = dataset_builder(
        train=False,
        normalize=FLAGS.normalize,
        binary_classification_target=FLAGS.binary_target_class)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=FLAGS.batch_size,
                                               shuffle=True,
                                               num_workers=FLAGS.num_workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=FLAGS.batch_size,
                                             shuffle=False,
                                             num_workers=FLAGS.num_workers,
                                             pin_memory=True)

    # network: feature extractor plus a single-output classifier/rejector
    features = vgg16_variant(dataset_builder.input_size,
                             FLAGS.dropout_prob).cuda()
    model = DeepLinearSvmWithRejector(features,
                                      FLAGS.dim_features,
                                      num_classes=1).cuda()
    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)

    # SGD with the learning rate halved every 25 epochs
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=FLAGS.lr,
                                momentum=FLAGS.momentum,
                                weight_decay=FLAGS.wd)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=25,
                                                gamma=0.5)

    # criterion
    MHBRLoss = MaxHingeLossBinaryWithRejection(FLAGS.cost)

    # adversarial training needs a method name and a positive budget
    adversarial = bool(FLAGS.at and FLAGS.at_eps > 0)
    if adversarial:
        # derive the step size only when none was given
        if not FLAGS.step_size:
            FLAGS.step_size = get_step_size(FLAGS.at_eps, FLAGS.nb_its)
        assert FLAGS.step_size >= 0

        if FLAGS.at != 'pgd':
            raise NotImplementedError('invalid at method.')
        attacker = PGDAttackVariant(
            FLAGS.nb_its,
            FLAGS.at_eps,
            FLAGS.step_size,
            dataset=FLAGS.dataset,
            cost=FLAGS.cost,
            norm=FLAGS.at_norm,
            num_classes=dataset_builder.num_classes,
            is_binary_classification=True)

    # loggers (the train logger never uses wandb)
    train_log_path = os.path.join(FLAGS.log_dir,
                                  'train_log{}.csv'.format(FLAGS.suffix))
    train_logger = Logger(path=train_log_path,
                          mode='train',
                          use_wandb=False,
                          flags=FLAGS._dict)
    val_log_path = os.path.join(FLAGS.log_dir,
                                'val_log{}.csv'.format(FLAGS.suffix))
    val_logger = Logger(path=val_log_path,
                        mode='val',
                        use_wandb=FLAGS.use_wandb,
                        flags=FLAGS._dict)

    for ep in range(FLAGS.num_epochs):
        # fresh accumulators for this epoch
        train_metric_dict = MetricDict()
        val_metric_dict = MetricDict()

        # ---- training pass ----
        for x, t in train_loader:
            x = x.to('cuda', non_blocking=True)
            t = t.to('cuda', non_blocking=True)

            if adversarial:
                # the attack is run with the model in eval mode
                model.eval()
                model.zero_grad()
                x = attacker(model, x.detach(), t.detach())

            model.train()
            model.zero_grad()
            out_class, out_reject = model(x)

            # the criterion returns the loss tensor and a dict of statistics
            loss_dict = OrderedDict()
            maxhinge_loss, loss_dict = MHBRLoss(out_class, out_reject, t)
            loss_dict['maxhinge_loss'] = maxhinge_loss.detach().cpu().item()

            # the total loss is the max-hinge loss alone
            loss = maxhinge_loss
            loss_dict['loss'] = loss.detach().cpu().item()

            # parameter update
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_metric_dict.update(loss_dict)

        # ---- validation pass ----
        for x, t in val_loader:
            x = x.to('cuda', non_blocking=True)
            t = t.to('cuda', non_blocking=True)

            if adversarial:
                # validation is attacked the same way as training
                model.eval()
                model.zero_grad()
                x = attacker(model, x.detach(), t.detach())

            with torch.autograd.no_grad():
                model.eval()
                model.zero_grad()
                out_class, out_reject = model(x)

                loss_dict = OrderedDict()
                maxhinge_loss, loss_dict = MHBRLoss(out_class, out_reject, t)
                loss_dict['maxhinge_loss'] = (
                    maxhinge_loss.detach().cpu().item())

                loss = maxhinge_loss
                loss_dict['loss'] = loss.detach().cpu().item()

                # evaluator receives flattened outputs and targets
                evaluator = Evaluator(out_class.detach().view(-1),
                                      t.detach().view(-1),
                                      out_reject.detach().view(-1))
                loss_dict.update(evaluator())

                val_metric_dict.update(loss_dict)

        # ---- end of epoch: print validation metrics, log both splits ----
        print_metric_dict(ep,
                          FLAGS.num_epochs,
                          val_metric_dict.avg,
                          mode='val')

        step = ep + 1
        train_logger.log(train_metric_dict.avg, step=step)
        val_logger.log(val_metric_dict.avg, step=step)

        scheduler.step()

    # persist the final weights
    weight_path = os.path.join(FLAGS.log_dir,
                               'weight_final{}.pth'.format(FLAGS.suffix))
    save_model(model, path=weight_path)
示例#3
0
文件: test.py 项目: MasaKat0/ATRO
def test(**kwargs):
    """
    Evaluate a checkpointed ``DeepLinearSvmWithRejector`` on the test split
    for a given rejection cost, optionally under a PGD attack.

    The binary setting is selected when ``FLAGS.binary_target_class`` is not
    ``None``; it switches the model output size, the criterion, the attacker
    mode and how outputs are passed to the evaluator.

    Args:
        **kwargs: options loaded into a ``FlagHolder``.

    Returns:
        ``test_metric_dict.avg``, the averaged metrics collected over all
        test batches.

    Raises:
        AssertionError: if ``nb_its`` is not positive, ``attack_eps`` is
            negative, or the resolved step size is negative.
        NotImplementedError: if an attack other than ``'pgd'`` is requested.
    """
    FLAGS = FlagHolder()
    FLAGS.initialize(**kwargs)
    FLAGS.summary()

    assert FLAGS.nb_its > 0
    assert FLAGS.attack_eps >= 0

    # test data
    dataset_builder = DatasetBuilder(name=FLAGS.dataset,
                                     root_path=FLAGS.dataroot)
    test_dataset = dataset_builder(
        train=False,
        normalize=FLAGS.normalize,
        binary_classification_target=FLAGS.binary_target_class)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=FLAGS.batch_size,
                                              shuffle=False,
                                              num_workers=FLAGS.num_workers,
                                              pin_memory=True)

    # one flag drives every binary/multiclass decision below
    is_binary = FLAGS.binary_target_class is not None

    # network: the binary model has one output, otherwise one per class
    features = vgg16_variant(dataset_builder.input_size,
                             FLAGS.dropout_prob).cuda()
    num_outputs = 1 if is_binary else dataset_builder.num_classes
    model = DeepLinearSvmWithRejector(features, FLAGS.dim_features,
                                      num_outputs).cuda()
    load_model(model, FLAGS.weight)

    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)

    # criterion matching the setting
    if is_binary:
        criterion = MaxHingeLossBinaryWithRejection(FLAGS.cost)
    else:
        criterion = MaxHingeLossWithRejection(FLAGS.cost)

    # attacker is built whenever an attack method is named
    if FLAGS.attack:
        # derive the step size only when none was given
        if not FLAGS.step_size:
            FLAGS.step_size = get_step_size(FLAGS.attack_eps, FLAGS.nb_its)
        assert FLAGS.step_size >= 0

        if FLAGS.attack != 'pgd':
            raise NotImplementedError('invalid attack method.')
        attacker = PGDAttackVariant(
            FLAGS.nb_its,
            FLAGS.attack_eps,
            FLAGS.step_size,
            dataset=FLAGS.dataset,
            cost=FLAGS.cost,
            norm=FLAGS.attack_norm,
            num_classes=dataset_builder.num_classes,
            is_binary_classification=is_binary)

    test_metric_dict = MetricDict()

    # inputs are perturbed only when the budget is positive
    perturb = bool(FLAGS.attack and FLAGS.attack_eps > 0)

    for x, t in test_loader:
        model.eval()
        x = x.to('cuda', non_blocking=True)
        t = t.to('cuda', non_blocking=True)
        loss_dict = OrderedDict()

        if perturb:
            # craft adversarial inputs (runs outside no_grad)
            model.zero_grad()
            x = attacker(model, x.detach(), t.detach())

        with torch.autograd.no_grad():
            model.zero_grad()
            out_class, out_reject = model(x)

            # the criterion returns the loss tensor and a dict of statistics
            maxhinge_loss, loss_dict = criterion(out_class, out_reject, t)
            loss_dict['maxhinge_loss'] = maxhinge_loss.detach().cpu().item()

            # the total loss is the max-hinge loss alone
            loss = maxhinge_loss
            loss_dict['loss'] = loss.detach().cpu().item()

            # the binary evaluator receives flattened tensors
            if is_binary:
                evaluator = Evaluator(out_class.detach().view(-1),
                                      t.detach().view(-1),
                                      out_reject.detach().view(-1),
                                      FLAGS.cost)
            else:
                evaluator = Evaluator(out_class.detach(), t.detach(),
                                      out_reject.detach(), FLAGS.cost)

            loss_dict.update(evaluator())

        test_metric_dict.update(loss_dict)

    # report and return the averaged metrics
    averaged = test_metric_dict.avg
    print_metric_dict(None, None, averaged, mode='test')

    return averaged
示例#4
0
def train(**kwargs):
    """
    Train a ``SelectiveNet`` with a weighted sum of the selective loss and an
    auxiliary cross-entropy loss, then save the final weights.

    The total loss per batch is
    ``alpha * selective_loss + (1 - alpha) * cross_entropy(out_aux, t)``.

    Args:
        **kwargs: options loaded into a ``FlagHolder``.

    Side effects:
        Creates ``FLAGS.log_dir`` if needed, dumps the flags there, writes
        train/val logs through ``Logger`` and saves
        ``weight_final{suffix}.pth``.
    """
    FLAGS = FlagHolder()
    FLAGS.initialize(**kwargs)
    FLAGS.summary()
    # make sure the output directory exists before anything is written to it
    os.makedirs(FLAGS.log_dir, exist_ok=True)
    FLAGS.dump(
        path=os.path.join(FLAGS.log_dir, 'flags{}.json'.format(FLAGS.suffix)))

    # dataset
    dataset_builder = DatasetBuilder(name=FLAGS.dataset,
                                     root_path=FLAGS.dataroot)
    train_dataset = dataset_builder(train=True, normalize=FLAGS.normalize)
    val_dataset = dataset_builder(train=False, normalize=FLAGS.normalize)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=FLAGS.batch_size,
                                               shuffle=True,
                                               num_workers=FLAGS.num_workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=FLAGS.batch_size,
                                             shuffle=False,
                                             num_workers=FLAGS.num_workers,
                                             pin_memory=True)

    # model
    features = vgg16_variant(dataset_builder.input_size,
                             FLAGS.dropout_prob).cuda()
    model = SelectiveNet(features, FLAGS.dim_features,
                         dataset_builder.num_classes).cuda()
    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)

    # optimizer: SGD with the learning rate halved every 25 epochs
    params = model.parameters()
    optimizer = torch.optim.SGD(params,
                                lr=FLAGS.lr,
                                momentum=FLAGS.momentum,
                                weight_decay=FLAGS.wd)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=25,
                                                gamma=0.5)

    # loss: per-sample cross entropy wrapped by the selective loss
    base_loss = torch.nn.CrossEntropyLoss(reduction='none')
    SelectiveCELoss = SelectiveLoss(base_loss, coverage=FLAGS.coverage)

    # logger
    train_logger = Logger(path=os.path.join(
        FLAGS.log_dir, 'train_log{}.csv'.format(FLAGS.suffix)),
                          mode='train')
    val_logger = Logger(path=os.path.join(
        FLAGS.log_dir, 'val_log{}.csv'.format(FLAGS.suffix)),
                        mode='val')

    for ep in range(FLAGS.num_epochs):
        # pre epoch
        train_metric_dict = MetricDict()
        val_metric_dict = MetricDict()

        # train
        for x, t in train_loader:
            model.train()
            x = x.to('cuda', non_blocking=True)
            t = t.to('cuda', non_blocking=True)

            # forward
            out_class, out_select, out_aux = model(x)

            # selective loss; the criterion also returns a dict of statistics
            selective_loss, loss_dict = SelectiveCELoss(
                out_class, out_select, t)
            selective_loss *= FLAGS.alpha
            loss_dict['selective_loss'] = selective_loss.detach().cpu().item()
            # standard cross entropy on the auxiliary output
            ce_loss = torch.nn.CrossEntropyLoss()(out_aux, t)
            ce_loss *= (1.0 - FLAGS.alpha)
            loss_dict['ce_loss'] = ce_loss.detach().cpu().item()

            # total loss
            loss = selective_loss + ce_loss
            loss_dict['loss'] = loss.detach().cpu().item()

            # backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_metric_dict.update(loss_dict)

        # validation
        with torch.autograd.no_grad():
            for x, t in val_loader:
                model.eval()
                x = x.to('cuda', non_blocking=True)
                t = t.to('cuda', non_blocking=True)

                # forward
                out_class, out_select, out_aux = model(x)

                # same weighted loss as in training
                selective_loss, loss_dict = SelectiveCELoss(
                    out_class, out_select, t)
                selective_loss *= FLAGS.alpha
                loss_dict['selective_loss'] = selective_loss.detach().cpu(
                ).item()
                ce_loss = torch.nn.CrossEntropyLoss()(out_aux, t)
                ce_loss *= (1.0 - FLAGS.alpha)
                loss_dict['ce_loss'] = ce_loss.detach().cpu().item()

                # total loss
                loss = selective_loss + ce_loss
                loss_dict['loss'] = loss.detach().cpu().item()

                # evaluation
                evaluator = Evaluator(out_class.detach(), t.detach(),
                                      out_select.detach())
                loss_dict.update(evaluator())

                val_metric_dict.update(loss_dict)

        # post epoch: only the validation metrics are printed
        print_metric_dict(ep,
                          FLAGS.num_epochs,
                          val_metric_dict.avg,
                          mode='val')

        train_logger.log(train_metric_dict.avg, step=(ep + 1))
        val_logger.log(val_metric_dict.avg, step=(ep + 1))

        scheduler.step()

    # post training
    save_model(model,
               path=os.path.join(FLAGS.log_dir,
                                 'weight_final{}.pth'.format(FLAGS.suffix)))
示例#5
0
def train(**kwargs):
    """
    Run standard or adversarial training of a classifier and save the final
    weights.

    Adversarial training is active when ``FLAGS.at`` is set and
    ``FLAGS.at_eps`` is positive; the same attacker is then also applied to
    the validation inputs.

    Learning-rate schedule, chosen from ``FLAGS.ms``:
        * exactly one entry: ``StepLR`` with that value as the step size;
        * otherwise: ``MultiStepLR`` with the sorted entries as milestones
          (an empty list gives ``MultiStepLR`` no milestones).

    Args:
        **kwargs: options loaded into a ``FlagHolder``.

    Side effects:
        Creates ``FLAGS.log_dir``, dumps the flags there, writes train/val
        logs through ``Logger`` and saves ``weight_final{suffix}.pth``.

    Raises:
        AssertionError: if the resolved attack step size is negative.
    """
    # flags
    FLAGS = FlagHolder()
    FLAGS.initialize(**kwargs)
    FLAGS.summary()
    os.makedirs(FLAGS.log_dir, exist_ok=True)
    FLAGS.dump(
        path=os.path.join(FLAGS.log_dir, 'flags{}.json'.format(FLAGS.suffix)))

    # dataset
    dataset_builder = DatasetBuilder(name=FLAGS.dataset,
                                     root_path=FLAGS.dataroot)
    train_dataset = dataset_builder(train=True, normalize=FLAGS.normalize)
    val_dataset = dataset_builder(train=False, normalize=FLAGS.normalize)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=FLAGS.batch_size,
                                               shuffle=True,
                                               num_workers=FLAGS.num_workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=FLAGS.batch_size,
                                             shuffle=False,
                                             num_workers=FLAGS.num_workers,
                                             pin_memory=True)

    # model
    num_classes = dataset_builder.num_classes
    model = ModelBuilder(num_classes=num_classes,
                         pretrained=False)[FLAGS.arch].cuda()
    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)

    # optimizer
    params = model.parameters()
    optimizer = torch.optim.SGD(params,
                                lr=FLAGS.lr,
                                momentum=FLAGS.momentum,
                                weight_decay=FLAGS.wd)

    # scheduler
    # No length assertion here: requiring an empty ``FLAGS.ms`` would make
    # the single-milestone branch unreachable and reject every configured
    # milestone.
    if len(FLAGS.ms) == 1:
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                    step_size=FLAGS.ms[0],
                                                    gamma=FLAGS.gamma)
    else:
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=sorted(FLAGS.ms), gamma=FLAGS.gamma)

    # attacker
    if FLAGS.at and FLAGS.at_eps > 0:
        # an explicitly given step size wins over the derived one
        step_size = get_step_size(
            FLAGS.at_eps,
            FLAGS.nb_its) if not FLAGS.step_size else FLAGS.step_size
        # stored in the flag dict because that dict is forwarded to the builder
        FLAGS._dict['step_size'] = step_size
        assert step_size >= 0

        # create attacker
        attacker = AttackerBuilder()(method=FLAGS.at,
                                     norm=FLAGS.at_norm,
                                     eps=FLAGS.at_eps,
                                     **FLAGS._dict)

    # logger (the train logger never uses wandb)
    train_logger = Logger(path=os.path.join(
        FLAGS.log_dir, 'train_log{}.csv'.format(FLAGS.suffix)),
                          mode='train',
                          use_wandb=False,
                          flags=FLAGS._dict)
    val_logger = Logger(path=os.path.join(
        FLAGS.log_dir, 'val_log{}.csv'.format(FLAGS.suffix)),
                        mode='val',
                        use_wandb=FLAGS.use_wandb,
                        flags=FLAGS._dict)

    for ep in range(FLAGS.num_epochs):
        # pre epoch
        train_metric_dict = MetricDict()
        val_metric_dict = MetricDict()

        # train
        for x, t in train_loader:
            x = x.to('cuda', non_blocking=True)
            t = t.to('cuda', non_blocking=True)

            # adversarial attack, run with the model in eval mode
            if FLAGS.at and FLAGS.at_eps > 0:
                model.eval()
                model.zero_grad()
                x = attacker(model, x.detach(), t.detach())

            # forward
            model.train()
            model.zero_grad()
            out = model(x)

            # the total loss is the cross entropy alone
            loss_dict = OrderedDict()
            ce_loss = torch.nn.CrossEntropyLoss()(out, t)
            loss = ce_loss
            loss_dict['loss'] = loss.detach().cpu().item()

            # backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_metric_dict.update(loss_dict)

        # validation
        for x, t in val_loader:
            x = x.to('cuda', non_blocking=True)
            t = t.to('cuda', non_blocking=True)

            # adversarial attack (needs gradients, so it stays outside no_grad)
            if FLAGS.at and FLAGS.at_eps > 0:
                model.eval()
                model.zero_grad()
                x = attacker(model, x.detach(), t.detach())

            with torch.autograd.no_grad():
                # forward
                model.eval()
                model.zero_grad()
                out = model(x)

                # cross entropy
                loss_dict = OrderedDict()
                ce_loss = torch.nn.CrossEntropyLoss()(out, t)
                loss = ce_loss
                loss_dict['loss'] = loss.detach().cpu().item()

                # evaluation
                evaluator = Evaluator(out.detach(),
                                      t.detach(),
                                      selection_out=None)
                loss_dict.update(evaluator())

                val_metric_dict.update(loss_dict)

        # post epoch: only the validation metrics are printed
        print_metric_dict(ep,
                          FLAGS.num_epochs,
                          val_metric_dict.avg,
                          mode='val')

        train_logger.log(train_metric_dict.avg, step=(ep + 1))
        val_logger.log(val_metric_dict.avg, step=(ep + 1))

        scheduler.step()

    # post training
    save_model(model,
               path=os.path.join(FLAGS.log_dir,
                                 'weight_final{}.pth'.format(FLAGS.suffix)))
示例#6
0
def test(**kwargs):
    """
    Evaluate a checkpointed ``SelectiveNet`` on the test split.

    Per batch the reported loss is
    ``alpha * selective_loss + (1 - alpha) * cross_entropy(out_aux, t)``,
    together with whatever metrics the ``Evaluator`` returns.

    Args:
        **kwargs: options loaded into a ``FlagHolder``.

    Returns:
        ``test_metric_dict.avg``, the averaged metrics collected over all
        test batches.
    """
    FLAGS = FlagHolder()
    FLAGS.initialize(**kwargs)
    FLAGS.summary()

    # test data
    dataset_builder = DatasetBuilder(name=FLAGS.dataset,
                                     root_path=FLAGS.dataroot)
    test_dataset = dataset_builder(train=False, normalize=FLAGS.normalize)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=FLAGS.batch_size,
                                              shuffle=False,
                                              num_workers=FLAGS.num_workers,
                                              pin_memory=True)

    # network: build on GPU, restore the checkpoint, then wrap for multi-GPU
    features = vgg16_variant(dataset_builder.input_size,
                             FLAGS.dropout_prob).cuda()
    model = SelectiveNet(features, FLAGS.dim_features,
                         dataset_builder.num_classes).cuda()
    load_model(model, FLAGS.weight)

    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)

    # per-sample cross entropy wrapped by the selective loss
    base_loss = torch.nn.CrossEntropyLoss(reduction='none')
    SelectiveCELoss = SelectiveLoss(base_loss, coverage=FLAGS.coverage)

    test_metric_dict = MetricDict()

    # the whole evaluation runs without gradient tracking
    with torch.autograd.no_grad():
        for x, t in test_loader:
            model.eval()
            x = x.to('cuda', non_blocking=True)
            t = t.to('cuda', non_blocking=True)

            out_class, out_select, out_aux = model(x)

            # selective part; the criterion also returns a dict of statistics
            loss_dict = OrderedDict()
            selective_loss, loss_dict = SelectiveCELoss(
                out_class, out_select, t)
            selective_loss *= FLAGS.alpha
            loss_dict['selective_loss'] = selective_loss.detach().cpu().item()

            # auxiliary cross-entropy part
            ce_loss = torch.nn.CrossEntropyLoss()(out_aux, t)
            ce_loss *= (1.0 - FLAGS.alpha)
            loss_dict['ce_loss'] = ce_loss.detach().cpu().item()

            # weighted sum of both parts
            loss = selective_loss + ce_loss
            loss_dict['loss'] = loss.detach().cpu().item()

            # remaining metrics come from the evaluator
            evaluator = Evaluator(out_class.detach(), t.detach(),
                                  out_select.detach())
            loss_dict.update(evaluator())

            test_metric_dict.update(loss_dict)

    # report and return the averaged metrics
    averaged = test_metric_dict.avg
    print_metric_dict(None, None, averaged, mode='test')

    return averaged