Code example #1
def run(args: DictConfig) -> None:
    cuda_available = torch.cuda.is_available()
    torch.manual_seed(args.seed)

    device = "cuda" if cuda_available and args.device == 'cuda' else "cpu"

    n_classes = args.get(args.dataset).n_classes
    rep_size = args.get(args.dataset).rep_size
    margin = args.get(args.dataset).margin

    classifier = load_pretrained_model(args)
    if args.dataset == 'tiny_imagenet':
        args.data_dir = 'tiny_imagenet'

    sdim = SDIM(disc_classifier=classifier,
                n_classes=n_classes,
                rep_size=rep_size,
                mi_units=args.mi_units,
                margin=margin,
                alpha=args.alpha,
                beta=args.beta,
                gamma=args.gamma).to(device)  # move to the resolved device, not the raw config string

    optimizer = Adam(sdim.parameters(), lr=args.learning_rate)

    if args.inference:
        save_name = 'SDIM_{}.pth'.format(args.classifier_name)
        sdim.load_state_dict(
            torch.load(save_name, map_location=lambda storage, loc: storage))

        thresholds1, thresholds2 = extract_thresholds(sdim, args)
        clean_eval(sdim, args, thresholds1, thresholds2)
    else:
        train(sdim, optimizer, args)
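The run() entry points in these examples take a Hydra/OmegaConf DictConfig. As a minimal smoke-test sketch of how such a config could be assembled by hand (the key names mirror the accesses in the snippet; the concrete values are illustrative assumptions only):

from omegaconf import OmegaConf

# Hypothetical config; keys follow run()'s accesses, values are assumptions.
cfg = OmegaConf.create({
    'seed': 1234,
    'device': 'cuda',
    'dataset': 'cifar10',
    'cifar10': {'n_classes': 10, 'rep_size': 64, 'margin': 6.0},
    'mi_units': 256,
    'alpha': 1.0, 'beta': 1.0, 'gamma': 1.0,
    'learning_rate': 1e-4,
    'inference': False,
})

run(cfg)  # falls through to train(sdim, optimizer, cfg) since cfg.inference is False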
Code example #2
def run(args: DictConfig) -> None:
    assert torch.cuda.is_available()
    torch.manual_seed(args.seed)

    n_classes = args.get(args.dataset).n_classes
    rep_size = args.get(args.dataset).rep_size
    margin = args.get(args.dataset).margin

    classifier = load_pretrained_model(args)

    sdim = SDIM(disc_classifier=classifier,
                n_classes=n_classes,
                rep_size=rep_size,
                mi_units=args.mi_units,
                margin=margin,
                alpha=args.alpha,
                beta=args.beta,
                gamma=args.gamma).to(args.device)

    optimizer = Adam(sdim.parameters(), lr=args.learning_rate)

    if args.inference:
        save_name = 'SDIM_resnet18{}.pth'.format(suffix_dict[args.base_type])
        sdim.load_state_dict(
            torch.load(save_name, map_location=lambda storage, loc: storage))

        thresholds1, thresholds2 = extract_thresholds(sdim, args)
        clean_eval(sdim, args, thresholds1, thresholds2)
    else:
        train(sdim, optimizer, args)
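In both examples above, map_location=lambda storage, loc: storage deserializes every checkpoint tensor onto the CPU regardless of the device it was saved from; the model is then moved explicitly with .to(...). The equivalent and more readable modern spelling is:

sdim.load_state_dict(torch.load(save_name, map_location='cpu'))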
Code example #3
    if hps.problem == 'cifar10':  # assumed opening branch; the excerpt begins mid-chain
        hps.image_channel = 3
    elif hps.problem == 'mnist':
        hps.image_channel = 1
    elif hps.problem == 'fashion':
        hps.image_channel = 1

    model = SDIM(rep_size=hps.rep_size,
                 mi_units=hps.mi_units,
                 encoder_name=hps.encoder_name,
                 image_channel=hps.image_channel,
                 margin=hps.margin,
                 alpha=hps.alpha,
                 beta=hps.beta,
                 gamma=hps.gamma
                 ).to(hps.device)
    optimizer = Adam(model.parameters(), lr=hps.lr)

    print('==>  # Model parameters: {}.'.format(cal_parameters(model)))

    if hps.noise_attack:
        noise_attack(model, hps)
    elif hps.inference:
        inference(model, hps)
    elif hps.ood_inference:
        ood_inference(model, hps)
    elif hps.rejection_inference:
        inference_rejection(model, hps)
    elif hps.noise_ood_inference:
        noise_ood_inference(model, hps)
    else:
        train(model, optimizer, hps)
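cal_parameters is a repository helper that is not shown in these excerpts. A minimal sketch of what such a parameter counter typically looks like (the exact implementation here is an assumption):

import torch

def cal_parameters(model: torch.nn.Module) -> int:
    # Count trainable parameters; the repo's helper may also include frozen ones.
    return sum(p.numel() for p in model.parameters() if p.requires_grad)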
Code example #4
                             pin_memory=True, num_workers=16)

    # Models
    print('Classifier name: {}'.format(hps.classifier_name))
    classifier = get_model(model_name=hps.classifier_name).to(hps.device)

    sdim = SDIM(disc_classifier=classifier,
                rep_size=hps.rep_size,
                mi_units=hps.mi_units,
                n_classes=hps.n_classes,
                margin=3).to(hps.device)

    if use_cuda and hps.n_gpu > 1:
        sdim = torch.nn.DataParallel(sdim, device_ids=list(range(hps.n_gpu)))

    optimizer = Adam(filter(lambda param: param.requires_grad, sdim.parameters()), lr=hps.lr)

    torch.manual_seed(hps.seed)
    np.random.seed(hps.seed)

    # Create log dir
    logdir = os.path.abspath(hps.log_dir) + "/"
    if not os.path.exists(logdir):
        os.mkdir(logdir)


    def train_epoch():
        """
        One epoch training.
        """
        sdim.train()
Code example #5
    # Init model, criterion, and optimizer
    classifier = CifarResNeXt(args.cardinality, args.depth, args.n_classes,
                              args.base_width, args.widen_factor).to(args.device)
    print('# Classifier parameters: ', cal_parameters(classifier))

    save_name = 'ResNeXt{}_{}x{}d.pth'.format(args.depth, args.cardinality, args.base_width)
    check_point = torch.load(os.path.join(args.save, save_name))
    classifier.load_state_dict(check_point['model_state'])

    train_acc = check_point['train_acc']
    test_acc = check_point['test_acc']
    print('Original Discriminative Classifier, train acc: {:.4f}, test acc: {:.4f}'.format(train_acc, test_acc))

    sdim = SDIM(disc_classifier=classifier, rep_size=args.rep_size, mi_units=args.mi_units, n_classes=args.n_classes).to(args.device)

    optimizer = torch.optim.Adam(filter(lambda param: param.requires_grad, sdim.parameters()),
                                 lr=args.learning_rate)

    if use_cuda and args.n_gpu > 1:
        sdim = torch.nn.DataParallel(sdim, device_ids=list(range(args.n_gpu)))

    best_train_loss = np.inf

    # train function (forward, backward, update)
    def train():
        sdim.train()
        loss_list = []
        mi_list = []
        nll_list = []
        margin_list = []
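Code example #5 breaks off right after the per-metric lists are initialized. A minimal sketch of how such a loop typically continues, assuming sdim(x, y) returns the same five outputs as in code example #6 and that a train_loader is in scope (both are assumptions about the omitted code):

        for x, y in train_loader:
            x, y = x.to(args.device), y.to(args.device)

            optimizer.zero_grad()
            loss, mi_loss, nll_loss, ll_margin, log_lik = sdim(x, y)
            loss.mean().backward()
            optimizer.step()

            # Accumulate scalars for epoch-level reporting.
            loss_list.append(loss.mean().item())
            mi_list.append(mi_loss.mean().item())
            nll_list.append(nll_loss.mean().item())
            margin_list.append(ll_margin.mean().item())

        print('Train loss: {:.4f}'.format(np.mean(loss_list)))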
Code example #6
def train(hps: DictConfig) -> None:
    # This enables a Ctrl-C exit without triggering errors
    import signal

    signal.signal(signal.SIGINT, lambda x, y: sys.exit(0))

    logger = logging.getLogger(__name__)

    cuda_available = torch.cuda.is_available()

    torch.manual_seed(hps.seed)

    device = "cuda" if cuda_available and hps.device == 'cuda' else "cpu"

    # Models
    local_channel = hps.get(hps.base_classifier).last_conv_channel
    classifier = get_model(model_name=hps.base_classifier,
                           in_size=local_channel,
                           out_size=hps.rep_size).to(device)  # use the resolved device
    logger.info('Base classifier name: {}, # parameters: {}'.format(
        hps.base_classifier, cal_parameters(classifier)))

    sdim = SDIM(disc_classifier=classifier,
                mi_units=hps.mi_units,
                n_classes=hps.n_classes,
                margin=hps.margin,
                rep_size=hps.rep_size,
                local_channel=local_channel).to(device)

    # logging the SDIM desc.
    for desc in sdim.desc():
        logger.info(desc)

    train_loader = Loader('train', batch_size=hps.n_batch_train, device=device)

    if cuda_available and hps.n_gpu > 1:
        sdim = torch.nn.DataParallel(sdim, device_ids=list(range(hps.n_gpu)))

    optimizer = Adam(filter(lambda param: param.requires_grad,
                            sdim.parameters()),
                     lr=hps.lr)

    np.random.seed(hps.seed)  # torch RNG already seeded near the top of train()

    # Create log dir
    logdir = os.path.abspath(hps.log_dir) + "/"
    if not os.path.exists(logdir):
        os.mkdir(logdir)

    loss_optimal = 1e5
    n_iters = 0

    losses = AverageMeter('Loss')
    MIs = AverageMeter('MI')
    nlls = AverageMeter('NLL')
    margins = AverageMeter('Margin')
    top1 = AverageMeter('Acc@1')
    top5 = AverageMeter('Acc@5')

    for x, y in train_loader:
        n_iters += 1
        if n_iters > hps.training_iters:  # run exactly training_iters iterations
            break

        # forward, backward, update
        optimizer.zero_grad()
        loss, mi_loss, nll_loss, ll_margin, log_lik = sdim(x, y)
        loss.mean().backward()
        optimizer.step()

        acc1, acc5 = accuracy(log_lik, y, topk=(1, 5))
        losses.update(loss.mean().item(), x.size(0))  # .mean() handles per-GPU loss vectors under DataParallel
        top1.update(acc1, x.size(0))
        top5.update(acc5, x.size(0))

        MIs.update(mi_loss.mean().item(), x.size(0))
        nlls.update(nll_loss.mean().item(), x.size(0))
        margins.update(ll_margin.mean().item(), x.size(0))

        if n_iters % hps.log_interval == hps.log_interval - 1:
            logger.info(
                'Train loss: {:.4f}, mi: {:.4f}, nll: {:.4f}, ll_margin: {:.4f}'
                .format(losses.avg, MIs.avg, nlls.avg, margins.avg))
            logger.info('Train Acc@1: {:.3f}, Acc@5: {:.3f}'.format(
                top1.avg, top5.avg))

            if losses.avg < loss_optimal:
                loss_optimal = losses.avg
                model_path = 'SDIM_{}.pth'.format(hps.base_classifier)

                if cuda_available and hps.n_gpu > 1:
                    state = sdim.module.state_dict()
                else:
                    state = sdim.state_dict()

                check_point = {
                    'model_state': state,
                    'train_acc_top1': top1.avg,
                    'train_acc_top5': top5.avg
                }

                torch.save(check_point, os.path.join(hps.log_dir, model_path))

            losses.reset()
            MIs.reset()
            nlls.reset()
            margins.reset()
            top1.reset()
            top5.reset()
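AverageMeter and accuracy are not defined in this excerpt; their usage matches the well-known helpers from the official PyTorch ImageNet example. A self-contained sketch consistent with the calls above (names and behavior assumed from that convention):

import torch

class AverageMeter:
    """Track the running average of a scalar metric (assumed helper)."""
    def __init__(self, name):
        self.name = name
        self.reset()

    def reset(self):
        self.sum, self.count, self.avg = 0.0, 0, 0.0

    def update(self, val, n=1):
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def accuracy(output, target, topk=(1,)):
    """Top-k accuracy in percent, as in the PyTorch ImageNet example."""
    with torch.no_grad():
        maxk = max(topk)
        _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        return [correct[:k].reshape(-1).float().sum().mul_(100.0 / target.size(0))
                for k in topk]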