Example #1
class SDIM(nn.Module):
    def __init__(self,
                 disc_classifier,
                 rep_size=64,
                 n_classes=10,
                 mi_units=64,
                 margin=5,
                 alpha=0.6,
                 beta=0.2,
                 gamma=0.2):
        super().__init__()
        self.disc_classifier = disc_classifier  # .half()  # Optionally use half precision to save memory and time.
        self.disc_classifier.requires_grad_(
            requires_grad=False)  # Freeze the pre-trained classifier (no gradient updates).
        # self.disc_classifier.eval()  # Set to eval mode.

        self.rep_size = rep_size
        self.n_classes = n_classes
        self.mi_units = mi_units
        self.margin = margin
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma

        self.feature_transformer = MLP(self.n_classes, self.rep_size)

        # 1x1 conv performed on only channel dimension
        self.local_MInet = MI1x1ConvNet(self.n_classes, self.mi_units)
        self.global_MInet = MI1x1ConvNet(self.rep_size, self.mi_units)

        self.class_conditional = ClassConditionalGaussianMixture(
            self.n_classes, self.rep_size)

        n_feat_transformer = cal_parameters(self.feature_transformer)
        n_local = cal_parameters(self.local_MInet)
        n_global = cal_parameters(self.global_MInet)
        n_class_conditional = cal_parameters(self.class_conditional)

        n_additional = n_feat_transformer + n_local + n_global + n_class_conditional

        self.cross_entropy = nn.CrossEntropyLoss()

        print('==>  # Model parameters.')
        print('==>  # discriminative classifier parameters: {}.'.format(
            cal_parameters(self.disc_classifier)))
        print('==>  # additional parameters: {}.'.format(n_additional))

        print('==>  # FeatureTransformer parameters: {}.'.format(
            n_feat_transformer))
        print('==>  # T parameters: {}.'.format(n_local + n_global))
        print('==>  # class conditional parameters: {}.'.format(
            n_class_conditional))
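
# The snippets above and below all rely on a `cal_parameters` helper to report model
# sizes. Its implementation is not shown here; the following is a minimal sketch
# consistent with how it is called (a model plus an optional `filter_func` over
# parameters), not necessarily the repo's exact code.
def cal_parameters(model, filter_func=lambda p: True):
    """Count parameter elements of `model` whose tensors pass `filter_func`."""
    return sum(p.numel() for p in model.parameters() if filter_func(p))
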
def run(args: DictConfig) -> None:
    assert torch.cuda.is_available()
    torch.manual_seed(args.seed)

    n_classes = args.get(args.dataset).n_classes
    classifier = resnet18(n_classes=n_classes).to(args.device)
    logger.info('Base classifier resnet18: # parameters {}'.format(
        cal_parameters(classifier)))

    data_dir = hydra.utils.to_absolute_path(args.data_dir)
    train_data = get_dataset(data_name=args.dataset,
                             data_dir=data_dir,
                             train=True,
                             crop_flip=True)
    test_data = get_dataset(data_name=args.dataset,
                            data_dir=data_dir,
                            train=False,
                            crop_flip=False)

    train_loader = DataLoader(dataset=train_data,
                              batch_size=args.n_batch_train,
                              shuffle=True)
    test_loader = DataLoader(dataset=test_data,
                             batch_size=args.n_batch_test,
                             shuffle=False)

    if args.inference:
        save_name = 'resnet18_wd{}.pth'.format(args.weight_decay)
        classifier.load_state_dict(
            torch.load(save_name, map_location=lambda storage, loc: storage))
        loss, acc = run_epoch(classifier, test_loader, args)
        logger.info('Inference, test loss: {:.4f}, Acc: {:.4f}'.format(
            loss, acc))
    else:
        train(classifier, train_loader, test_loader, args)
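
# `run_epoch(classifier, test_loader, args)` above returns a (loss, accuracy) pair.
# A hypothetical minimal evaluation pass consistent with that call is sketched below;
# the actual helper in this codebase may track more statistics.
import torch.nn.functional as F

def run_epoch(classifier, data_loader, args):
    classifier.eval()
    total_loss, n_correct, n_samples = 0.0, 0, 0
    with torch.no_grad():
        for x, y in data_loader:
            x, y = x.to(args.device), y.to(args.device)
            logits = classifier(x)
            total_loss += F.cross_entropy(logits, y, reduction='sum').item()
            n_correct += (logits.argmax(dim=1) == y).sum().item()
            n_samples += x.size(0)
    return total_loss / n_samples, n_correct / n_samples
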
Example #3
def get_model(model_name='resnext50_32x4d'):
    if model_name == 'resnext101_32x8d':
        m = models.resnext101_32x8d(pretrained=True)
    elif model_name == 'resnext50_32x4d':
        m = models.resnext50_32x4d(pretrained=True)
    print('Model name: {}, # parameters: {}'.format(model_name, cal_parameters(m)))
    return m
    def desc(self):
        """
        Description of this model.
        :return: tuple of descriptions of SDIM components.
        """
        n_fixed = cal_parameters(self.disc_classifier,
                                 filter_func=lambda x: not x.requires_grad)
        n_trainable = cal_parameters(self.disc_classifier,
                                     filter_func=lambda x: x.requires_grad)
        n_T = cal_parameters(self.local_MInet) + cal_parameters(
            self.global_MInet)
        n_C = cal_parameters(self.class_conditional)

        base_desc = 'Base classifier, # fixed parameters: {}, # trainable parameters: {}'.format(
            n_fixed, n_trainable)
        T_desc = 'MI evaluation network, #parameters: {}.'.format(n_T)
        class_con_desc = 'Class conditional embedding layer, #parameters: {}.'.format(
            n_C)
        return base_desc, T_desc, class_con_desc
def run(args: DictConfig) -> None:
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    data_dir = hydra.utils.to_absolute_path(args.data_dir)

    clean_test_data = get_dataset(data_name=args.dataset, data_dir=data_dir, train=False, crop_flip=False)
    #advset_at = TensorDataset(torch.load(os.path.join(data_dir, 'advset_{}_at_fast.pt'.format(args.classifier_name))))
    advset_clean = torch.load(os.path.join(data_dir, 'advset_{}_clean.pt'.format(args.classifier_name)))

    clean_loader = DataLoader(dataset=clean_test_data, batch_size=args.n_batch_test, shuffle=False)
    advset_loader = DataLoader(dataset=advset_clean, batch_size=args.n_batch_test, shuffle=False)

    results_dict = dict()
    for width in args.width_list:
        classifier_list = []
        for split_id in range(args.n_split):
            classifier = eval(args.classifier_name)(width, args.n_classes).to(args.device)
            logger.info('Classifier: {}, width: {}, # parameters: {}'
                        .format(args.classifier_name, width, cal_parameters(classifier)))
            checkpoint = '{}_w{}_split{}.pth'.format(args.classifier_name, width, split_id)
            classifier.load_state_dict(torch.load(checkpoint))
            classifier_list.append(classifier)

        results_dict['clean_on_clean_w{}'.format(width)] = eval_risk_bias_variance(classifier_list, clean_loader, args)
        results_dict['clean_on_adv_w{}'.format(width)] = eval_risk_bias_variance(classifier_list, advset_loader, args)

        del classifier_list

        classifier_list = []
        for split_id in range(args.n_split):
            classifier = eval(args.classifier_name)(width, args.n_classes).to(args.device)
            logger.info('Classifier: {}, width: {}, # parameters: {}'
                        .format(args.classifier_name, width, cal_parameters(classifier)))
            checkpoint = '{}_w{}_split{}_at_fast.pth'.format(args.classifier_name, width, split_id)
            classifier.load_state_dict(torch.load(checkpoint))
            classifier_list.append(classifier)

        results_dict['adv_on_clean_w{}'.format(width)] = eval_risk_bias_variance(classifier_list, clean_loader, args)
        results_dict['adv_on_adv_w{}'.format(width)] = eval_risk_bias_variance(classifier_list, advset_loader, args)

    torch.save(results_dict, 'adv_eval_width_results.pt')
def run(args: DictConfig) -> None:
    cuda_available = torch.cuda.is_available()
    torch.manual_seed(args.seed)
    device = "cuda" if cuda_available and args.device == 'cuda' else "cpu"

    n_classes = args.get(args.dataset).n_classes
    if args.dataset == 'tiny_imagenet':
        args.epochs = 20
        args.learning_rate = 0.001
        classifier = get_model_for_tiny_imagenet(args.classifier_name,
                                                 n_classes).to(device)
        args.data_dir = 'tiny_imagenet'

    else:
        classifier = get_model(name=args.classifier_name,
                               n_classes=n_classes).to(device)

    # if device == 'cuda' and args.n_gpu > 1:
    #     classifier = torch.nn.DataParallel(classifier, device_ids=list(range(args.n_gpu)))

    logger.info('Base classifier name: {}, # parameters: {}'.format(
        args.classifier_name, cal_parameters(classifier)))

    data_dir = hydra.utils.to_absolute_path(args.data_dir)
    train_data = get_dataset(data_name=args.dataset,
                             data_dir=data_dir,
                             train=True,
                             crop_flip=True)
    test_data = get_dataset(data_name=args.dataset,
                            data_dir=data_dir,
                            train=False,
                            crop_flip=False)

    train_loader = DataLoader(dataset=train_data,
                              batch_size=args.n_batch_train,
                              shuffle=True)
    test_loader = DataLoader(dataset=test_data,
                             batch_size=args.n_batch_test,
                             shuffle=False)

    if args.inference:
        save_name = '{}.pth'.format(args.classifier_name)
        classifier.load_state_dict(
            torch.load(save_name, map_location=lambda storage, loc: storage))
        loss, acc = run_epoch(classifier, test_loader, args)
        logger.info('Inference loss: {:.4f}, acc: {:.4f}'.format(loss, acc))
    else:
        train(classifier, train_loader, test_loader, args)
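
# The `get_dataset(data_name, data_dir, train, crop_flip)` helper used in these run()
# functions is not shown. A minimal sketch consistent with its keyword arguments follows;
# the real version may normalize inputs and support more datasets.
from torchvision import datasets, transforms

def get_dataset(data_name, data_dir, train, crop_flip):
    if crop_flip:
        transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ])
    else:
        transform = transforms.ToTensor()

    if data_name == 'cifar10':
        return datasets.CIFAR10(data_dir, train=train, transform=transform, download=True)
    elif data_name == 'cifar100':
        return datasets.CIFAR100(data_dir, train=train, transform=transform, download=True)
    raise ValueError('Unknown dataset: {}'.format(data_name))
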
Example #7
def inference(hps: DictConfig) -> None:
    # This allows Ctrl-C to exit without triggering errors
    import signal

    signal.signal(signal.SIGINT, lambda x, y: sys.exit(0))

    logger = logging.getLogger(__name__)

    cuda_available = torch.cuda.is_available()

    torch.manual_seed(hps.seed)

    device = "cuda" if cuda_available and hps.device == 'cuda' else "cpu"

    # Models
    local_channel = hps.get(hps.base_classifier).last_conv_channel
    classifier = get_model(model_name=hps.base_classifier,
                           in_size=local_channel,
                           out_size=hps.rep_size).to(hps.device)
    logger.info('Base classifier name: {}, # parameters: {}'.format(
        hps.base_classifier, cal_parameters(classifier)))

    sdim = SDIM(disc_classifier=classifier,
                mi_units=hps.mi_units,
                n_classes=hps.n_classes,
                margin=hps.margin,
                rep_size=hps.rep_size,
                local_channel=local_channel).to(hps.device)

    model_path = 'SDIM_{}.pth'.format(hps.base_classifier)
    base_dir = '/userhome/cs/u3003679/generative-classification-with-rejection'
    path = os.path.join(base_dir, model_path)
    sdim.load_state_dict(torch.load(path)['model_state'])

    # logging the SDIM desc.
    for desc in sdim.desc():
        logger.info(desc)

    eval_loader = Loader('eval', batch_size=hps.n_batch_test, device=device)

    if cuda_available and hps.n_gpu > 1:
        sdim = torch.nn.DataParallel(sdim, device_ids=list(range(hps.n_gpu)))

    torch.manual_seed(hps.seed)
    np.random.seed(hps.seed)

    n_iters = 0

    top1 = AverageMeter('Acc@1')
    top5 = AverageMeter('Acc@5')

    sdim.eval()
    for x, y in eval_loader:
        n_iters += 1
        if n_iters == len(eval_loader):
            break

        with torch.no_grad():
            log_lik = sdim.infer(x)

        acc1, acc5 = accuracy(log_lik, y, topk=(1, 5))

        top1.update(acc1, x.size(0))
        top5.update(acc5, x.size(0))

    logger.info('Test Acc@1: {:.3f}, Acc@5: {:.3f}'.format(top1.avg, top5.avg))
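
# `AverageMeter` and `accuracy` above follow the conventions of the PyTorch ImageNet
# example; minimal sketches consistent with their use here (update/avg/reset, top-k
# accuracy in percent) are given below. The repo's own versions may differ in detail.
class AverageMeter:
    """Tracks a running average of a scalar statistic."""
    def __init__(self, name):
        self.name = name
        self.reset()

    def reset(self):
        self.sum, self.count, self.avg = 0.0, 0, 0.0

    def update(self, val, n=1):
        self.sum += float(val) * n
        self.count += n
        self.avg = self.sum / self.count


def accuracy(output, target, topk=(1,)):
    """Top-k accuracy (in percent) of `output` scores against `target` labels."""
    with torch.no_grad():
        maxk = max(topk)
        _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        return [correct[:k].reshape(-1).float().sum().mul_(100.0 / target.size(0))
                for k in topk]
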
Example #8
    # prepare test set
    pair_filename = os.path.join(
        args.test_set, 'm50_{}_{}_0.txt'.format(args.n_samples_test,
                                                args.n_samples_test))
    pairs, labels = ReadPairs(pair_filename)
    test_set = CustomDataset(pairs, labels, args.test_set, transform=None)
    test_loader = DataLoader(test_set,
                             batch_size=100,
                             shuffle=False,
                             num_workers=8,
                             pin_memory=True,
                             drop_last=True)

    model = ComposedModel(residual=args.res_feature_net).cuda()

    print('Model parameters: {}'.format(cal_parameters(model)))
    if args.eval:
        if args.res_feature_net:
            state_dict = torch.load('res_model_{}.pt'.format(args.train_set))
        else:
            state_dict = torch.load('model_{}.pt'.format(args.train_set))
        model.load_state_dict(state_dict)
        model.eval()

        score_list = []
        label_list = []
        for idx, (left, right, label) in enumerate(test_loader):
            left = preprocess(left).cuda()
            right = preprocess(right).cuda()
            label = label.cuda()
Example #9
def run(args: DictConfig) -> None:
    # Load datasets
    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, padding=4)
    ])
    preprocess = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize([mean_] * 3, [std_] * 3)])
    test_transform = preprocess

    data_dir = hydra.utils.to_absolute_path(args.data_dir)
    if args.dataset == 'cifar10':
        train_data = datasets.CIFAR10(data_dir,
                                      train=True,
                                      transform=train_transform,
                                      download=True)
        test_data = datasets.CIFAR10(data_dir,
                                     train=False,
                                     transform=test_transform,
                                     download=True)
        base_c_path = os.path.join(data_dir, 'CIFAR-10-C/')
        args.n_classes = 10
    else:
        train_data = datasets.CIFAR100(data_dir,
                                       train=True,
                                       transform=train_transform,
                                       download=True)
        test_data = datasets.CIFAR100(data_dir,
                                      train=False,
                                      transform=test_transform,
                                      download=True)

        base_c_path = os.path.join(data_dir, 'CIFAR-100-C/')
        args.n_classes = 100

    train_data = AugMixDataset(train_data, preprocess, args, args.no_jsd)
    train_loader = DataLoader(train_data,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.num_workers,
                              pin_memory=True)

    test_loader = DataLoader(test_data,
                             batch_size=args.eval_batch_size,
                             shuffle=False,
                             num_workers=args.num_workers,
                             pin_memory=True)

    # Create model
    if args.model == 'densenet':
        classifier = densenet(num_classes=args.n_classes)
    elif args.model == 'wide_resnet':
        n_layers = 40
        widen_factor = 2
        droprate = 0.
        classifier = WideResNet(n_layers, args.n_classes, widen_factor,
                                droprate)
    elif args.model == 'resnext':
        classifier = resnext29(num_classes=args.n_classes)

    classifier = classifier.to(args.device)
    logger.info('Model: {}, # parameters: {}'.format(
        args.model, cal_parameters(classifier)))

    cudnn.benchmark = True
    classifier = torch.nn.DataParallel(classifier).to(args.device)

    if args.inference:
        classifier.load_state_dict(
            torch.load('{}_{}.pth'.format(args.model, args.augmentation_type)))
        test_loss, test_acc = eval_epoch(classifier,
                                         test_loader,
                                         args,
                                         adversarial=False)
        logger.info('Clean Test CE:{:.4f}, acc:{:.4f}'.format(
            test_loss, test_acc))
    else:
        optimizer = torch.optim.SGD(classifier.parameters(),
                                    args.learning_rate,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay,
                                    nesterov=True)

        best_acc = 0
        pre_adv_acc = 0
        scheduler = torch.optim.lr_scheduler.LambdaLR(
            optimizer,
            lr_lambda=lambda step: get_lr(  # pylint: disable=g-long-lambda
                step,
                args.n_epochs * len(train_loader),
                1,  # lr_lambda computes multiplicative factor
                1e-6 / args.learning_rate))

        for epoch in range(args.n_epochs):
            if args.augmentation_type:
                loss, ce_loss, js_loss, acc = train_epoch_advmix(
                    classifier, train_loader, args, optimizer, scheduler)
            else:
                loss, ce_loss, js_loss, acc = train_epoch(
                    classifier, train_loader, args, optimizer, scheduler)

            lr = scheduler.get_lr()[0]
            logger.info(
                'Epoch {}, lr:{:.4f}, loss:{:.4f}, CE:{:.4f}, JS:{:.4f}, Acc:{:.4f}'
                .format(epoch + 1, lr, loss, ce_loss, js_loss, acc))

            test_loss, test_acc = eval_epoch(classifier,
                                             test_loader,
                                             args,
                                             adversarial=False)
            logger.info('Test CE:{:.4f}, acc:{:.4f}'.format(
                test_loss, test_acc))

            adv_loss, adv_acc = eval_epoch(classifier,
                                           test_loader,
                                           args,
                                           adversarial=True)
            logger.info('Adversarial evaluation, CE:{:.4f}, acc:{:.4f}'.format(
                adv_loss, adv_acc))

            if test_acc > best_acc:
                best_acc = test_acc
                if adv_acc + 0.1 < pre_adv_acc:
                    pre_adv_acc = adv_acc
                    logger.info(
                        "Catastrophic overfitting happens, early stopping")
                    break
                logger.info('===> New optimal, save checkpoint ...')
                torch.save(
                    classifier.state_dict(),
                    '{}_{}.pth'.format(args.model, args.augmentation_type))

    test_c_acc = eval_c(classifier, base_c_path, args)
    logger.info('Mean Corruption Error:{:.4f}'.format(1 - test_c_acc))
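
# The LambdaLR scheduler above multiplies the base learning rate by
# get_lr(step, total_steps, 1, lr_min_factor). A cosine-decay sketch in the style of the
# AugMix reference implementation is shown below; treat it as an assumption rather than
# this repo's exact code.
import numpy as np

def get_lr(step, total_steps, lr_max, lr_min):
    """Cosine decay from lr_max to lr_min, returned as a multiplicative factor."""
    return lr_min + (lr_max - lr_min) * 0.5 * (1 + np.cos(step / total_steps * np.pi))
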
Example #10
def run(args: DictConfig) -> None:
    assert torch.cuda.is_available()
    torch.manual_seed(args.seed)

    n_classes = args.get(args.dataset).n_classes
    classifier = resnet18(n_classes=n_classes).to(args.device)
    logger.info('Base classifier resnet18: # parameters {}'.format(
        cal_parameters(classifier)))

    data_dir = hydra.utils.to_absolute_path(args.data_dir)
    train_data = get_dataset(data_name=args.dataset,
                             data_dir=data_dir,
                             train=True,
                             crop_flip=True)
    test_data = get_dataset(data_name=args.dataset,
                            data_dir=data_dir,
                            train=False,
                            crop_flip=False)

    train_loader = DataLoader(dataset=train_data,
                              batch_size=args.n_batch_train,
                              shuffle=True)
    test_loader = DataLoader(dataset=test_data,
                             batch_size=args.n_batch_test,
                             shuffle=False)

    if args.inference:
        classifier.load_state_dict(
            torch.load('resnet18_wd{}.pth'.format(args.weight_decay)))
        logger.info('Load classifier from checkpoint')
    else:
        optimizer = torch.optim.SGD(classifier.parameters(),
                                    args.lr_max,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay,
                                    nesterov=True)

        lr_steps = args.epochs * len(train_loader)
        if args.lr_schedule == 'cyclic':
            scheduler = torch.optim.lr_scheduler.CyclicLR(
                optimizer,
                base_lr=args.lr_min,
                max_lr=args.lr_max,
                step_size_up=lr_steps / 2,
                step_size_down=lr_steps / 2)
        elif args.lr_schedule == 'multistep':
            scheduler = torch.optim.lr_scheduler.MultiStepLR(
                optimizer,
                milestones=[lr_steps / 2, lr_steps * 3 / 4],
                gamma=0.1)
        else:
            raise Exception("scheduler not implemented.")

        optimal_loss = 1e5
        for epoch in range(1, args.epochs + 1):
            loss, acc = train_epoch(classifier, train_loader, args, optimizer,
                                    scheduler)
            lr = scheduler.get_lr()[0]
            logger.info('Epoch {}, lr:{:.4f}, loss:{:.4f}, Acc:{:.4f}'.format(
                epoch, lr, loss, acc))

            if loss < optimal_loss:
                optimal_loss = loss

                torch.save(classifier.state_dict(), 'resnet18_at.pth')

    clean_loss, clean_acc = eval_epoch(classifier,
                                       test_loader,
                                       args,
                                       adversarial=False)
    logger.info('Clean loss: {:.4f}, acc: {:.4f}'.format(
        clean_loss, clean_acc))
    adv_loss, adv_acc = eval_epoch(classifier,
                                   test_loader,
                                   args,
                                   adversarial=True)
    logger.info('Adversarial loss: {:.4f}, acc: {:.4f}'.format(
        adv_loss, adv_acc))
    else:
        train_data = dset.CIFAR100(args.data_path, train=True, transform=train_transform, download=True)
        test_data = dset.CIFAR100(args.data_path, train=False, transform=test_transform, download=True)
        args.n_classes = 100
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True,
                                               num_workers=args.prefetch, pin_memory=True)
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=args.test_batch_size, shuffle=False,
                                              num_workers=args.prefetch, pin_memory=True)

    # Init checkpoints
    if not os.path.isdir(args.save):
        os.makedirs(args.save)

    # Init model, criterion, and optimizer
    classifier = CifarResNeXt(args.cardinality, args.depth, args.n_classes, args.base_width, args.widen_factor).to(args.device)
    print('# Classifier parameters: ', cal_parameters(classifier))

    save_name = 'ResNeXt{}_{}x{}d.pth'.format(args.depth, args.cardinality, args.base_width)
    check_point = torch.load(os.path.join(args.save, save_name))
    classifier.load_state_dict(check_point['model_state'])

    train_acc = check_point['train_acc']
    test_acc = check_point['test_acc']
    print('Original Discriminative Classifier, train acc: {:.4f}, test acc: {:.4f}'.format(train_acc, test_acc))

    sdim = SDIM(disc_classifier=classifier, rep_size=args.rep_size, mi_units=args.mi_units, n_classes=args.n_classes).to(args.device)

    optimizer = torch.optim.Adam(filter(lambda param: param.requires_grad, sdim.parameters()),
                                 lr=args.learning_rate)

    if use_cuda and args.n_gpu > 1:
def run(args: DictConfig) -> None:
    # cuda_available = torch.cuda.is_available()
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    # device = "cuda" if cuda_available and args.device == 'cuda' else "cpu"

    classifier = eval(args.classifier_name)(args.width,
                                            args.n_classes).to(args.device)
    logger.info('Classifier: {}, width: {}, # parameters: {}'.format(
        args.classifier_name, args.width, cal_parameters(classifier)))

    data_dir = hydra.utils.to_absolute_path(args.data_dir)
    train_data = get_dataset(data_name=args.dataset,
                             data_dir=data_dir,
                             train=True,
                             crop_flip=True)
    test_data = get_dataset(data_name=args.dataset,
                            data_dir=data_dir,
                            train=False,
                            crop_flip=False)

    test_loader = DataLoader(dataset=test_data,
                             batch_size=args.n_batch_test,
                             shuffle=False)

    optimizer = SGD(classifier.parameters(),
                    lr=args.lr_max,
                    momentum=args.momentum,
                    weight_decay=args.weight_decay)

    def run_forward(scheduler):
        optimal_loss = 1e5
        for epoch in range(1, args.n_epochs + 1):
            loss, acc = train_epoch(classifier,
                                    train_loader,
                                    args,
                                    optimizer,
                                    scheduler=scheduler)
            if loss < optimal_loss:
                optimal_loss = loss
                torch.save(classifier.state_dict(), checkpoint)
            logger.info(
                'Epoch {}, lr: {:.4f}, loss: {:.4f}, acc: {:.4f}'.format(
                    epoch,
                    scheduler.get_lr()[0], loss, acc))

    if args.adv_generation:
        checkpoint = '{}_w{}_at_fast.pth'.format(args.classifier_name,
                                                 args.width)
        train_loader = DataLoader(dataset=train_data,
                                  batch_size=args.n_batch_train,
                                  shuffle=True)
        lr_steps = args.n_epochs * len(train_loader)
        scheduler = lr_scheduler.CyclicLR(optimizer,
                                          base_lr=args.lr_min,
                                          max_lr=args.lr_max,
                                          step_size_up=lr_steps / 2,
                                          step_size_down=lr_steps / 2)

        run_forward(scheduler)

        clean_loss, clean_acc = eval_epoch(classifier,
                                           test_loader,
                                           args,
                                           adversarial=False)
        adv_loss, adv_acc = eval_epoch(classifier,
                                       test_loader,
                                       args,
                                       adversarial=True,
                                       save=True)
        logger.info('Clean loss: {:.4f}, acc: {:.4f}'.format(
            clean_loss, clean_acc))
        logger.info('Adversarial loss: {:.4f}, acc: {:.4f}'.format(
            adv_loss, adv_acc))

    else:
        n = len(train_data)
        split_size = n // args.n_split
        lengths = [split_size] * (args.n_split - 1) + [
            n % split_size + split_size
        ]
        datasets_list = random_split(train_data, lengths=lengths)

        for split_id, dataset in enumerate(datasets_list):
            checkpoint = '{}_w{}_split{}_at_fast.pth'.format(
                args.classifier_name, args.width, split_id)
            logger.info('Running on subset {}, size: {}'.format(
                split_id + 1, len(dataset)))
            train_loader = DataLoader(dataset=dataset,
                                      batch_size=args.n_batch_train,
                                      shuffle=True)

            lr_steps = args.n_epochs * len(train_loader)
            scheduler = lr_scheduler.CyclicLR(optimizer,
                                              base_lr=args.lr_min,
                                              max_lr=args.lr_max,
                                              step_size_up=lr_steps / 2,
                                              step_size_down=lr_steps / 2)

            run_forward(scheduler)

            clean_loss, clean_acc = eval_epoch(classifier,
                                               test_loader,
                                               args,
                                               adversarial=False)
            adv_loss, adv_acc = eval_epoch(classifier,
                                           test_loader,
                                           args,
                                           adversarial=True)
            logger.info('Clean loss: {:.4f}, acc: {:.4f}'.format(
                clean_loss, clean_acc))
            logger.info('Adversarial loss: {:.4f}, acc: {:.4f}'.format(
                adv_loss, adv_acc))
def run(args: DictConfig) -> None:
    cuda_available = torch.cuda.is_available()
    torch.manual_seed(args.seed)
    device = "cuda" if cuda_available and args.device == 'cuda' else "cpu"

    # n_classes = args.n_classes
    # classifier = get_model(name=args.classifier_name, n_classes=n_classes).to(device)
    classifier = PreActResNet18().to(device)
    logger.info('Classifier: {}, # parameters: {}'.format(
        args.classifier_name, cal_parameters(classifier)))

    data_dir = hydra.utils.to_absolute_path(args.data_dir)
    train_data = get_dataset(data_name=args.dataset,
                             data_dir=data_dir,
                             train=True,
                             crop_flip=True)
    test_data = get_dataset(data_name=args.dataset,
                            data_dir=data_dir,
                            train=False,
                            crop_flip=False)

    train_loader = DataLoader(dataset=train_data,
                              batch_size=args.n_batch_train,
                              shuffle=True)
    test_loader = DataLoader(dataset=test_data,
                             batch_size=args.n_batch_test,
                             shuffle=False)

    if args.inference:
        classifier.load_state_dict(
            torch.load('{}_at.pth'.format(args.classifier_name)))
        logger.info('Load classifier from checkpoint')
    else:
        # optimizer = SGD(classifier.parameters(), lr=args.lr_max, momentum=args.momentum, weight_decay=args.weight_decay)
        # lr_steps = args.n_epochs * len(train_loader)
        # scheduler = lr_scheduler.CyclicLR(optimizer, base_lr=args.lr_min, max_lr=args.lr_max,
        #                                   step_size_up=lr_steps/2, step_size_down=lr_steps/2)
        optimizer = torch.optim.SGD(classifier.parameters(),
                                    args.learning_rate,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay,
                                    nesterov=True)

        scheduler = torch.optim.lr_scheduler.LambdaLR(
            optimizer,
            lr_lambda=lambda step: get_lr(  # pylint: disable=g-long-lambda
                step,
                args.n_epochs * len(train_loader),
                1,  # lr_lambda computes multiplicative factor
                1e-6 / args.learning_rate))

        optimal_loss = 1e5
        for epoch in range(1, args.n_epochs + 1):
            loss, acc = train_epoch(classifier,
                                    train_loader,
                                    args,
                                    optimizer,
                                    scheduler=scheduler)
            lr = scheduler.get_lr()[0]
            logger.info('Epoch {}, lr:{:.4f}, loss:{:.4f}, Acc:{:.4f}'.format(
                epoch, lr, loss, acc))

            if loss < optimal_loss:
                optimal_loss = loss
                torch.save(classifier.state_dict(),
                           '{}_at.pth'.format(args.classifier_name))

    clean_loss, clean_acc = eval_epoch(classifier,
                                       test_loader,
                                       args,
                                       adversarial=False)
    adv_loss, adv_acc = eval_epoch(classifier,
                                   test_loader,
                                   args,
                                   adversarial=True)
    logger.info('Clean loss: {:.4f}, acc: {:.4f}'.format(
        clean_loss, clean_acc))
    logger.info('Adversarial loss: {:.4f}, acc: {:.4f}'.format(
        adv_loss, adv_acc))
                       map_location=lambda storage, loc: storage))
    else:
        n_encoder_layers = int(hps.encoder_name.strip('resnet'))
        model = build_resnet_32x32(n=n_encoder_layers,
                                   fc_size=hps.n_classes,
                                   image_channel=hps.image_channel).to(
                                       hps.device)

        checkpoint_path = os.path.join(
            hps.log_dir, '{}_{}.pth'.format(hps.encoder_name, hps.problem))
        model.load_state_dict(
            torch.load(checkpoint_path,
                       map_location=lambda storage, loc: storage))

    print('Model name: {}'.format(hps.encoder_name))
    print('==>  # Model parameters: {}.'.format(cal_parameters(model)))

    if not os.path.exists(hps.log_dir):
        os.mkdir(hps.log_dir)

    if not os.path.exists(hps.attack_dir):
        os.mkdir(hps.attack_dir)

    if hps.attack == 'pgdinf':
        linfPGD_attack(model, hps)
    elif hps.attack == 'jsma':
        jsma_attack(model, hps)
    elif hps.attack == 'cw':
        cw_l2_attack(model, hps)
    elif hps.attack == 'fgsm':
        fgsm_attack(model, hps)
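
# The attack helpers dispatched above (linfPGD_attack, jsma_attack, cw_l2_attack,
# fgsm_attack) are defined elsewhere. As an illustration only, a single-step FGSM
# perturbation is sketched below under a hypothetical name; the repo's fgsm_attack
# presumably also iterates over a data loader and reports statistics.
import torch.nn.functional as F

def fgsm_perturb(model, x, y, eps):
    """Return x + eps * sign(grad_x CE(model(x), y)), clamped to the valid image range."""
    x_adv = x.clone().detach().requires_grad_(True)
    loss = F.cross_entropy(model(x_adv), y)
    loss.backward()
    return (x_adv + eps * x_adv.grad.sign()).clamp(0.0, 1.0).detach()
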
Example #15
def train(hps: DictConfig) -> None:
    # This allows Ctrl-C to exit without triggering errors
    import signal

    signal.signal(signal.SIGINT, lambda x, y: sys.exit(0))

    logger = logging.getLogger(__name__)

    cuda_available = torch.cuda.is_available()

    torch.manual_seed(hps.seed)

    device = "cuda" if cuda_available and hps.device == 'cuda' else "cpu"

    # Models
    local_channel = hps.get(hps.base_classifier).last_conv_channel
    classifier = get_model(model_name=hps.base_classifier,
                           in_size=local_channel,
                           out_size=hps.rep_size).to(hps.device)
    logger.info('Base classifier name: {}, # parameters: {}'.format(
        hps.base_classifier, cal_parameters(classifier)))

    sdim = SDIM(disc_classifier=classifier,
                mi_units=hps.mi_units,
                n_classes=hps.n_classes,
                margin=hps.margin,
                rep_size=hps.rep_size,
                local_channel=local_channel).to(hps.device)

    # logging the SDIM desc.
    for desc in sdim.desc():
        logger.info(desc)

    train_loader = Loader('train', batch_size=hps.n_batch_train, device=device)

    if cuda_available and hps.n_gpu > 1:
        sdim = torch.nn.DataParallel(sdim, device_ids=list(range(hps.n_gpu)))

    optimizer = Adam(filter(lambda param: param.requires_grad,
                            sdim.parameters()),
                     lr=hps.lr)

    torch.manual_seed(hps.seed)
    np.random.seed(hps.seed)

    # Create log dir
    logdir = os.path.abspath(hps.log_dir) + "/"
    if not os.path.exists(logdir):
        os.mkdir(logdir)

    loss_optimal = 1e5
    n_iters = 0

    losses = AverageMeter('Loss')
    MIs = AverageMeter('MI')
    nlls = AverageMeter('NLL')
    margins = AverageMeter('Margin')
    top1 = AverageMeter('Acc@1')
    top5 = AverageMeter('Acc@5')

    for x, y in train_loader:
        n_iters += 1
        if n_iters == hps.training_iters:
            break

        # backward
        optimizer.zero_grad()
        loss, mi_loss, nll_loss, ll_margin, log_lik = sdim(x, y)
        loss.mean().backward()
        optimizer.step()

        acc1, acc5 = accuracy(log_lik, y, topk=(1, 5))
        losses.update(loss.item(), x.size(0))
        top1.update(acc1, x.size(0))
        top5.update(acc5, x.size(0))

        MIs.update(mi_loss.item(), x.size(0))
        nlls.update(nll_loss.item(), x.size(0))
        margins.update(ll_margin.item(), x.size(0))

        if n_iters % hps.log_interval == hps.log_interval - 1:
            logger.info(
                'Train loss: {:.4f}, mi: {:.4f}, nll: {:.4f}, ll_margin: {:.4f}'
                .format(losses.avg, MIs.avg, nlls.avg, margins.avg))
            logger.info('Train Acc@1: {:.3f}, Acc@5: {:.3f}'.format(
                top1.avg, top5.avg))

            if losses.avg < loss_optimal:
                loss_optimal = losses.avg
                model_path = 'SDIM_{}.pth'.format(hps.base_classifier)

                if cuda_available and hps.n_gpu > 1:
                    state = sdim.module.state_dict()
                else:
                    state = sdim.state_dict()

                check_point = {
                    'model_state': state,
                    'train_acc_top1': top1.avg,
                    'train_acc_top5': top5.avg
                }

                torch.save(check_point, os.path.join(hps.log_dir, model_path))

            losses.reset()
            MIs.reset()
            nlls.reset()
            margins.reset()
            top1.reset()
            top5.reset()
Example #16
                                               num_workers=args.prefetch,
                                               pin_memory=True)
    test_loader = torch.utils.data.DataLoader(test_data,
                                              batch_size=args.test_batch_size,
                                              shuffle=False,
                                              num_workers=args.prefetch,
                                              pin_memory=True)

    # Init checkpoints
    if not os.path.isdir(args.save):
        os.makedirs(args.save)

    # Init model, criterion, and optimizer
    net = CifarResNeXt(args.cardinality, args.depth, n_classes,
                       args.base_width, args.widen_factor).to(args.device)
    print('# Classifier parameters: ', cal_parameters(net))

    if use_cuda and args.n_gpu > 1:
        net = torch.nn.DataParallel(net, device_ids=list(range(args.n_gpu)))

    optimizer = torch.optim.SGD(net.parameters(),
                                lr=args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.decay,
                                nesterov=True)

    best_train_loss = np.inf
    best_accuracy = 0.

    # train function (forward, backward, update)
    def train():
Example #17
def run(args: DictConfig) -> None:
    # Load datasets
    train_transform = transforms.Compose(
        [transforms.RandomHorizontalFlip(),
         transforms.RandomCrop(32, padding=4)])

    preprocess = transforms.ToTensor()
    test_transform = preprocess

    data_dir = hydra.utils.to_absolute_path(args.data_dir)
    if args.dataset == 'cifar10':
        train_data = datasets.CIFAR10(
            data_dir, train=True, transform=train_transform, download=True)
        test_data = datasets.CIFAR10(
            data_dir, train=False, transform=test_transform, download=True)
        base_c_path = os.path.join(data_dir, 'CIFAR-10-C/')
        # args.n_classes = 10
    else:
        train_data = datasets.CIFAR100(
            data_dir, train=True, transform=train_transform, download=True)
        test_data = datasets.CIFAR100(
            data_dir, train=False, transform=test_transform, download=True)

        base_c_path = os.path.join(data_dir, 'CIFAR-100-C/')
        # args.n_classes = 100

    train_data = AugMixDataset(train_data, preprocess, args, args.no_jsd)
    train_loader = DataLoader(
        train_data,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True)

    test_loader = DataLoader(
        test_data,
        batch_size=args.eval_batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True)

    n_classes = args.get(args.dataset).n_classes
    classifier = resnet18(n_classes=n_classes).to(args.device)
    logger.info('Model resnet18, # parameters: {}'.format(cal_parameters(classifier)))

    cudnn.benchmark = True

    if args.inference:
        classifier.load_state_dict(torch.load('resnet18_c.pth'))
        test_loss, test_acc = eval_epoch(classifier, test_loader, args)
        logger.info('Clean Test CE:{:.4f}, acc:{:.4f}'.format(test_loss, test_acc))
    else:
        optimizer = torch.optim.SGD(
            classifier.parameters(),
            args.learning_rate,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
            nesterov=True)

        best_loss = 1e5
        scheduler = torch.optim.lr_scheduler.LambdaLR(
            optimizer,
            lr_lambda=lambda step: get_lr(  # pylint: disable=g-long-lambda
                step,
                args.epochs * len(train_loader),
                1,  # lr_lambda computes multiplicative factor
                1e-6 / args.learning_rate))

        for epoch in range(args.epochs):
            loss, ce_loss, js_loss, acc = train_epoch(classifier, train_loader,  args, optimizer, scheduler)

            lr = scheduler.get_lr()[0]
            logger.info('Epoch {}, lr:{:.4f}, loss:{:.4f}, CE:{:.4f}, JS:{:.4f}, Acc:{:.4f}'
                        .format(epoch + 1, lr, loss, ce_loss, js_loss, acc))

            test_loss, test_acc = eval_epoch(classifier, test_loader, args)
            logger.info('Clean test CE:{:.4f}, acc:{:.4f}'.format(test_loss, test_acc))

            if loss < best_loss:
                best_loss = loss
                logger.info('===> New optimal, save checkpoint ...')
                torch.save(classifier.state_dict(), 'resnet18_c.pth')

    test_c_acc = eval_c(classifier, base_c_path, args)
    logger.info('Mean Corruption Error:{:.4f}'.format(1 - test_c_acc))
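
# `eval_c` above averages clean-style evaluation over the CIFAR-C corruption files
# (one .npy array per corruption plus labels.npy). A hypothetical sketch follows; the
# actual helper may normalize images differently and report per-corruption results.
import numpy as np
from torch.utils.data import TensorDataset

def eval_c(classifier, base_c_path, args):
    labels = torch.from_numpy(np.load(os.path.join(base_c_path, 'labels.npy'))).long()
    accs = []
    for fname in sorted(os.listdir(base_c_path)):
        if fname == 'labels.npy' or not fname.endswith('.npy'):
            continue
        images = np.load(os.path.join(base_c_path, fname))  # (N, 32, 32, 3) uint8
        data = TensorDataset(
            torch.from_numpy(images).permute(0, 3, 1, 2).float() / 255., labels)
        loader = DataLoader(data, batch_size=args.eval_batch_size, shuffle=False)
        _, acc = eval_epoch(classifier, loader, args)
        accs.append(acc)
    return sum(accs) / len(accs)
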
Example #18
        )

    def forward(self, x):
        return self.fc(x)


class Projection(nn.Module):
    def __init__(self, in_dim=4096, hidden_size=1024):
        super(Projection, self).__init__()

        self.fc = nn.Sequential(
            nn.Linear(in_dim, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size)
        )

    def forward(self, x):
        return self.fc(x)


if __name__ == "__main__":
    x = torch.randn(1, 1, 64, 64)
    # m = FeatureNet()
    m = ResFeatureNet()
    o = m(x)
    print(o.size())
    from utils import cal_parameters
    print(cal_parameters(m))