def calibrate(dataset,
              net_arch,
              trn_para,
              save,
              device_type,
              model_filename='model.pth',
              calibrated_filename='model_C_ts.pth',
              batch_size=256):
    """
    Applies temperature scaling to a trained model.

    Takes a pretrained DenseNet-CIFAR100 model, and a validation set
    (parameterized by indices on train set).
    Applies temperature scaling, and saves a temperature scaled version.

    NB: the "save" parameter references a DIRECTORY, not a file.
    In that directory, there should be two files:
    - model.pth (model state dict)
    - valid_indices.pth (a list of indices corresponding to the validation set).

    data (str) - path to directory where data should be loaded from/downloaded
    save (str) - directory with necessary files (see above)
    """
    # Load model state dict
    model_filename = os.path.join(save, model_filename)
    if not os.path.exists(model_filename):
        raise RuntimeError('Cannot find file %s to load' % model_filename)
    state_dict = torch.load(model_filename)

    # Load original model
    orig_model = get_model(net_arch)
    orig_model = move_to_device(orig_model, device_type)
    orig_model.load_state_dict(state_dict)

    # data loader
    valid_loader = dataset.valid_loader

    # wrap the model with a decorator that adds temperature scaling
    model = ModelWithTemperature(orig_model)

    # Tune the model temperature, and save the results
    model.opt_temperature(valid_loader)
    model_filename = os.path.join(save, calibrated_filename)
    torch.save(model.state_dict(), model_filename)
    print('Temperature scaled model saved to %s' % model_filename)
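
# A minimal usage sketch for calibrate(), assuming train() has already written
# model.pth and valid_indices.pth into the output directory. `ds`, `arch`, and
# `tps` are illustrative names for the dataset, architecture, and
# hyperparameter objects used elsewhere in this file, not fixed identifiers.
#
#     calibrate(ds, arch, tps, save='./outputs', device_type='cuda')
#     # afterwards ./outputs additionally contains model_C_ts.pth
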
def train(dataset,
          net_arch,
          trn_para,
          save,
          device_type='cuda',
          model_filename='model.pth'):
    """
    A function to train a DenseNet-BC on CIFAR-100.

    Args:
        data (class Data) - data instance
        save (str) - path to save the model to (default /outputs)
        depth (int) - depth of the network (number of convolution layers) (default 40)
        growth_rate (int) - number of features added per DenseNet layer (default 12)
        n_epochs (int) - number of epochs for training (default 300)
        lr (float) - initial learning rate
        wd (float) - weight decay
        momentum (float) - momentum
    """
    # Make save directory
    if not os.path.exists(save):
        os.makedirs(save)
    if not os.path.isdir(save):
        raise Exception('%s is not a dir' % save)

    model = get_model(net_arch)
    if trn_para['weight_init'] == 'xavier':
        model = xavier_init_weights(model)
    elif trn_para['weight_init'] == 'kaiming':
        model = kaiming_normal_init_weights(model)
    model_wrapper = move_to_device(model, device_type, True)
    print('parameter_number: ' + str(get_para_num(model_wrapper)))

    perturbation = move_to_device(
        PerturbInput((trn_para['batch_size'], 3, 32, 32),
                     perturb_degree=trn_para['perturb_degree'],
                     sample_num=trn_para['sample_num'],
                     pt_input_ratio=trn_para['pt_input_ratio'],
                     num_classes=dataset.num_classes), device_type)

    n_epochs = trn_para['n_epochs']
    if trn_para['loss_fn'] == 'MarginLoss':
        onehot_criterion = MarginLoss()
    else:
        onehot_criterion = nn.CrossEntropyLoss()
    soft_criterion = SoftCrossEntropyLoss()

    optimizer = optim.SGD(model_wrapper.parameters(),
                          lr=trn_para['lr'],
                          momentum=trn_para['momentum'],
                          nesterov=True)
    lr_scheduler = trn_para.get('lr_scheduler')
    if lr_scheduler == 'ReduceLROnPlateau':
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                         patience=20)
    elif lr_scheduler == 'MultiStepLR_150_225_300':
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                                   milestones=[150, 225, 300],
                                                   gamma=0.1)
    else:
        # milestones must be ints to match the integer epoch counter
        scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=[int(0.5 * n_epochs), int(0.75 * n_epochs)],
            gamma=0.1)

    # Make dataloaders
    train_loader = dataset.train_loader
    valid_loader = dataset.valid_loader

    # Train model
    best_error = 1
    for epoch in range(1, n_epochs + 1):

        if trn_para['run_epoch'] == 'fast':
            run_epoch_fast(
                loader=train_loader,
                model=model_wrapper,
                criterion=onehot_criterion,
                optimizer=optimizer,
                device_type=device_type,
                epoch=epoch,
                n_epochs=n_epochs,
                train=True,
            )
        else:
            if epoch < trn_para['sample_start'] * n_epochs:
                run_epoch(
                    loader=train_loader,
                    model=model_wrapper,
                    criterion=onehot_criterion,
                    optimizer=optimizer,
                    device_type=device_type,
                    epoch=epoch,
                    n_epochs=n_epochs,
                    train=True,
                )
            else:
                run_epoch_perturb(
                    loader=train_loader,
                    model=model_wrapper,
                    perturbation=perturbation,
                    onehot_criterion=onehot_criterion,
                    soft_criterion=soft_criterion,
                    optimizer=optimizer,
                    device_type=device_type,
                    epoch=epoch,
                    n_epochs=n_epochs,
                    train=True,
                    class_num=dataset.num_classes,
                    perturb_degree=trn_para['perturb_degree'],
                    sample_num=trn_para['sample_num'],
                    alpha=trn_para['alpha'],
                    pt_input_ratio=trn_para['pt_input_ratio'],
                )
        valid_results = run_epoch(
            loader=valid_loader,
            model=model_wrapper,
            criterion=onehot_criterion,
            optimizer=optimizer,
            device_type=device_type,
            epoch=epoch,
            n_epochs=n_epochs,
            train=False,
        )

        # Determine if model is the best
        _, _, valid_error = valid_results
        if valid_error[0] < best_error:
            best_error = valid_error[0]
            print('New best error: %.4f' % best_error)
            torch.save(model.state_dict(), os.path.join(save, model_filename))
            with open(os.path.join(save, 'model_ckpt_detail.txt'),
                      'a') as ckptf:
                ckptf.write('epoch ' + str(epoch) +
                            (' reaches new best error %.4f' % best_error) +
                            '\n')

        if trn_para.get('lr_scheduler') == 'ReduceLROnPlateau':
            scheduler.step(valid_error[0])
        else:
            scheduler.step()

    torch.save(model.state_dict(), os.path.join(save,
                                                'last_' + model_filename))
    print('Train Done!')
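
# For reference, a minimal trn_para sketch covering every key this variant of
# train() reads. The values are illustrative, not the original settings.
#
#     trn_para = {
#         'weight_init': 'kaiming',   # or 'xavier'; any other value skips re-init
#         'batch_size': 64,           # also sizes the PerturbInput buffer
#         'perturb_degree': 5,
#         'sample_num': 20,
#         'pt_input_ratio': 0.5,
#         'n_epochs': 300,
#         'loss_fn': 'CrossEntropy',  # anything but 'MarginLoss' uses CrossEntropyLoss
#         'lr': 0.1,
#         'momentum': 0.9,
#         'lr_scheduler': 'MultiStepLR_150_225_300',
#         'run_epoch': 'normal',      # 'fast' uses run_epoch_fast for every epoch
#         'sample_start': 0.5,        # fraction of epochs trained before perturbation
#         'alpha': 0.5,
#     }
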
Example #3
def train(dataset,
          net_arch,
          trn_para,
          save,
          device_type='cuda',
          model_filename='model.pth'):
    """
    A function to train a DenseNet-BC on CIFAR-100.

    Args:
        data (class Data) - data instance
        save (str) - path to save the model to (default /outputs)
        depth (int) - depth of the network (number of convolution layers) (default 40)
        growth_rate (int) - number of features added per DenseNet layer (default 12)
        n_epochs (int) - number of epochs for training (default 300)
        lr (float) - initial learning rate
        wd (float) - weight decay
        momentum (float) - momentum
    """
    # Make save directory
    if not os.path.exists(save):
        os.makedirs(save)
    if not os.path.isdir(save):
        raise Exception('%s is not a dir' % save)

    model = get_model(net_arch)
    if trn_para['weight_init'] == 'xavier':
        model = xavier_init_weights(model)
    elif trn_para['weight_init'] == 'kaiming':
        model = kaiming_normal_init_weights(model)
    model_wrapper = move_to_device(model, device_type, True)

    n_epochs = trn_para['n_epochs']
    if trn_para['loss_fn'] == 'MarginLoss':
        criterion = MarginLoss()
    else:
        criterion = nn.CrossEntropyLoss()

    if trn_para.get('optimizer') == 'Adam':
        print('Using Adam')
        optimizer = optim.Adam(model_wrapper.parameters(),
                               lr=trn_para['lr'],
                               weight_decay=trn_para['wd'])
    else:  # trn_para['optimizer'] == 'SGD':
        print('Using SGD')
        optimizer = optim.SGD(model_wrapper.parameters(),
                              lr=trn_para['lr'],
                              weight_decay=trn_para['wd'],
                              momentum=trn_para['momentum'],
                              nesterov=True)
    lr_scheduler = trn_para.get('lr_scheduler')
    if lr_scheduler == 'ReduceLROnPlateau':
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                         patience=20)
    elif lr_scheduler == 'MultiStepLR_150_225_300':
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                                   milestones=[150, 225, 300],
                                                   gamma=0.1)
    elif lr_scheduler == 'MultiStepLR_60_120_160_200':
        scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=[60, 120, 160, 200], gamma=0.2)
    else:
        # milestones must be ints to match the integer epoch counter
        scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=[int(0.5 * n_epochs), int(0.75 * n_epochs)],
            gamma=0.1)

    # Make dataloaders
    train_loader = dataset.train_loader
    valid_loader = dataset.valid_loader

    # Warmup
    if trn_para['warmup'] > 0:
        warmup_optimizer = optim.SGD(model_wrapper.parameters(),
                                     lr=float(trn_para['lr']) / 10.0,
                                     momentum=trn_para['momentum'],
                                     nesterov=True)
        for warmup_epoch in range(1, trn_para['warmup'] + 1):
            run_epoch(
                loader=train_loader,
                model=model_wrapper,
                criterion=criterion,
                optimizer=warmup_optimizer,
                device_type=device_type,
                epoch=warmup_epoch,
                n_epochs=trn_para['warmup'],
                train=True,
            )

    # Train model
    best_error = 1
    for epoch in range(1, n_epochs + 1):
        if trn_para['run_epoch'] == 'fast':
            run_epoch_fast(
                loader=train_loader,
                model=model_wrapper,
                criterion=criterion,
                optimizer=optimizer,
                device_type=device_type,
                epoch=epoch,
                n_epochs=n_epochs,
                train=True,
            )
        else:
            run_epoch(
                loader=train_loader,
                model=model_wrapper,
                criterion=criterion,
                optimizer=optimizer,
                device_type=device_type,
                epoch=epoch,
                n_epochs=n_epochs,
                train=True,
            )
        valid_results = run_epoch(
            loader=valid_loader,
            model=model_wrapper,
            criterion=criterion,
            optimizer=optimizer,
            device_type=device_type,
            epoch=epoch,
            n_epochs=n_epochs,
            train=False,
        )

        # Determine if model is the best
        _, _, valid_error = valid_results
        if valid_error[0] < best_error:
            best_error = valid_error[0]
            print('New best error: %.4f' % best_error)
            torch.save(model.state_dict(), os.path.join(save, model_filename))
            with open(os.path.join(save, 'model_ckpt_detail.txt'),
                      'a') as ckptf:
                ckptf.write('epoch ' + str(epoch) +
                            (' reaches new best error %.4f' % best_error) +
                            '\n')

        if trn_para.get('lr_scheduler') == 'ReduceLROnPlateau':
            scheduler.step(valid_error[0])
        else:
            scheduler.step()

    torch.save(model.state_dict(), os.path.join(save,
                                                'last_' + model_filename))
    print('Train Done!')
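
# This second variant reads three additional trn_para keys on top of the
# sketch shown after the first train(); values again illustrative.
#
#     trn_para.update({
#         'wd': 1e-4,          # weight decay for both optimizers
#         'optimizer': 'SGD',  # or 'Adam'
#         'warmup': 5,         # warmup epochs at lr / 10; 0 disables the warmup loop
#     })
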
Example #4
        f.write(' '.join(sys.argv) + '\n')
        f.write('--------------------------------\n')

    if not os.path.isdir(args.destination_dir):
        print('destination directory does not exist. Exiting...')
        raise RuntimeError('destination directory does not exist.')

    path = os.path.join(args.destination_dir, 'cfg', 'net_arch.pickle')
    with open(path, 'rb') as handle:
        network_architecture = pickle.load(handle)

    path = os.path.join(args.destination_dir, 'cfg', 'trn_para.pickle')
    with open(path, 'rb') as handle:
        tps = pickle.load(handle)

    model = get_model(network_architecture)
    print(summary(model, torch.zeros(1, 3, 32, 32), show_input=False))

    # initialize loaders
    torch.manual_seed(tps['seed'])
    #if tps['data_name'] == 'cifar100':
    #    ds = DS_cifar100(tps['data_name'], tps['data_path'], tps['batch_size'], tps['valid_size'], 'cfg/data', tps['data_indices_path'])
    #elif tps['data_name'] == 'cifar10':
    #    ds = DS_cifar10(tps['data_name'], tps['data_path'], tps['batch_size'], tps['valid_size'], 'cfg/data', tps['data_indices_path'])
    ds = get_dataset(tps['data_name'], tps['data_path'], tps['batch_size'],
                     tps['valid_size'], 'cfg/data', tps['data_indices_path'])
    # Begin train
    train(ds, network_architecture, tps, args.destination_dir + '/outputs',
          args.device_type)
    # Begin calibrate
    calibrate(ds, network_architecture, tps, args.destination_dir + '/outputs',
              args.device_type)
Example #5
def test_train(dataset,
               net_arch,
               trn_para,
               save,
               device_type,
               model_filename='model.pth',
               batch_size=256):
    """
    A function to test a DenseNet-BC on data.

    Args:
        data (class Data) - data instance
        save (str) - path to save the model to (default /outputs)
        depth (int) - depth of the network (number of convolution layers) (default 40)
        growth_rate (int) - number of features added per DenseNet layer (default 12)
        n_epochs (int) - number of epochs for training (default 300)
        lr (float) - initial learning rate
        wd (float) - weight decay
        momentum (float) - momentum
    """
    # Make save directory
    if not os.path.exists(save):
        os.makedirs(save)
    if not os.path.isdir(save):
        raise Exception('%s is not a dir' % save)

    model = get_model(net_arch)
    model_uncalibrated = move_to_device(model, device_type)
    test_model = model_uncalibrated
    # Load model state dict
    model_state_filename = os.path.join(save, model_filename)
    if not os.path.exists(model_state_filename):
        raise RuntimeError('Cannot find file %s to load' %
                           model_state_filename)
    state_dict = torch.load(model_state_filename)
    test_model.load_state_dict(state_dict)
    test_model.eval()

    # Make dataloader
    test_loader = dataset.test_train_loader

    nll_criterion = move_to_device(nn.CrossEntropyLoss(), device_type)
    ece_criterion = move_to_device(ECELoss(), device_type)
    # First: collect all the logits and labels for the test set
    logits_list = []
    labels_list = []
    with torch.no_grad():
        for input, label in test_loader:
            input = move_to_device(input, device_type)
            logits = test_model(input)
            logits_list.append(logits)
            labels_list.append(label)
        logits = move_to_device(torch.cat(logits_list), device_type)
        labels = move_to_device(torch.cat(labels_list), device_type)
    # Calculate Error
    error_1 = Error_topk(logits, labels, 1).item()
    error_5 = Error_topk(logits, labels, 5).item()
    # Calculate NLL and ECE before temperature scaling
    result_nll = nll_criterion(logits, labels).item()
    result_ece = ece_criterion(logits, labels).item()
    #result_ece_2 = get_ece(logits, labels)
    torch.save(logits, os.path.join(save, 'test_train_logits.pth'))
    torch.save(labels, os.path.join(save, 'test_train_labels.pth'))
    np.save(os.path.join(save, 'test_train_logits.npy'), logits.cpu().numpy())
    np.save(os.path.join(save, 'test_train_labels.npy'), labels.cpu().numpy())
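
# The dumped arrays can be reloaded for offline analysis; a minimal sketch
# using the file names written above:
#
#     logits = np.load(os.path.join(save, 'test_train_logits.npy'))
#     labels = np.load(os.path.join(save, 'test_train_labels.npy'))
#     print(logits.shape, labels.shape)  # (N, num_classes), (N,)
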
Example #6
def sample(dataset,
           net_arch,
           trn_para,
           save,
           device_type,
           model_filename='model.pth',
           batch_size=256):
    """
    Draws random input perturbations for a few training batches and saves
    the inputs, perturbations, clean and perturbed logits, and labels as
    .npy files under `save`.
    """
    model = get_model(net_arch)
    sample_model = move_to_device(model, device_type)
    model_state_filename = os.path.join(save, model_filename)
    state_dict = torch.load(model_state_filename)
    sample_model.load_state_dict(state_dict)
    sample_model.eval()
    train_loader = dataset.train_loader
    sample_num = 20
    perturb_degree = 5
    inputs_list = []
    perturbations_list = []
    logits_list = []
    labels_list = []
    with torch.no_grad():
        cnt = 0
        for input, label in train_loader:
            if cnt > 5:
                break
            cnt += 1
            cur_logits_list = []
            cur_perturbations_list = []
            input = move_to_device(input, device_type)
            print('input:' + str(input.size()))
            logits = sample_model(input)
            print('logits:' + str(logits.size()))
            cur_logits_list.append(logits.unsqueeze_(1))
            for i in range(sample_num):
                perturbation = torch.randn_like(input[0])
                perturbation_flatten = torch.flatten(perturbation)
                p_n = torch.norm(perturbation_flatten, p=2)
                print('p_n:' + str(p_n.size()))
                perturbation = perturbation.div(p_n)
                perturbation = perturbation.unsqueeze(0).expand_as(
                    input) * perturb_degree
                perturbation = move_to_device(perturbation, device_type)
                print('perturbation:' + str(perturbation.size()))
                cur_perturbations_list.append(perturbation.unsqueeze(1))
                perturb_input = input + perturbation
                p_logits = sample_model(perturb_input)
                cur_logits_list.append(p_logits.unsqueeze_(1))

            cur_logits_list = torch.cat(cur_logits_list, dim=1)
            cur_perturbations_list = torch.cat(cur_perturbations_list, dim=1)
            logits_list.append(cur_logits_list)
            labels_list.append(label)
            inputs_list.append(input)
            perturbations_list.append(cur_perturbations_list)
    logits = torch.cat(logits_list)
    labels = torch.cat(labels_list)
    inputs = torch.cat(inputs_list)
    perturbations = torch.cat(perturbations_list)
    np.save(
        os.path.join(save, 'sample_input_pdg' + str(perturb_degree) + '.npy'),
        inputs.cpu().numpy())
    np.save(
        os.path.join(save, 'sample_logits_pdg' + str(perturb_degree) + '.npy'),
        logits.cpu().numpy())
    np.save(
        os.path.join(save, 'sample_perturbations_pdg' + str(perturb_degree) +
                     '.npy'),
        perturbations.cpu().numpy())
    np.save(
        os.path.join(save, 'sample_labels_pdg' + str(perturb_degree) + '.npy'),
        labels.cpu().numpy())
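
# Each perturbation above is Gaussian noise normalized to unit L2 norm and
# then scaled by perturb_degree, so every perturbed image sits at a fixed L2
# distance from the clean one. A quick standalone check:
#
#     p = torch.randn(3, 32, 32)
#     p = p / torch.norm(torch.flatten(p), p=2)   # unit L2 norm
#     p = p * 5                                   # perturb_degree = 5
#     print(torch.norm(torch.flatten(p)).item())  # ~5.0
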
Example #7
def test(dataset,
         net_arch,
         trn_para,
         save,
         device_type,
         test_cal='cal',
         model_filename='model.pth',
         batch_size=256):
    """
    A function to test a DenseNet-BC on data.

    Args:
        data (class Data) - data instance
        save (str) - path to save the model to (default /outputs)
        depth (int) - depth of the network (number of convolution layers) (default 40)
        growth_rate (int) - number of features added per DenseNet layer (default 12)
        n_epochs (int) - number of epochs for training (default 300)
        lr (float) - initial learning rate
        wd (float) - weight decay
        momentum (float) - momentum
    """
    # Make save directory
    if not os.path.exists(save):
        os.makedirs(save)
    if not os.path.isdir(save):
        raise Exception('%s is not a dir' % save)

    model = get_model(net_arch)
    if test_cal in ('cal', 'uncal_in_cal'):
        # the checkpoint was saved with a ModelWithTemperature wrapper,
        # so wrap the model before loading the state dict
        model_uncalibrated = move_to_device(model, device_type)
        model_calibrated = ModelWithTemperature(model_uncalibrated)
        model_calibrated = move_to_device(model_calibrated, device_type)
        test_model = model_calibrated
    else:  # 'uncal' or any other value tests the raw model
        model_uncalibrated = move_to_device(model, device_type)
        test_model = model_uncalibrated
    # Load model state dict
    model_state_filename = os.path.join(save, model_filename)
    if not os.path.exists(model_state_filename):
        raise RuntimeError('Cannot find file %s to load' % model_state_filename)
    state_dict = torch.load(model_state_filename)
    test_model.load_state_dict(state_dict)
    if test_cal in ['cal', 'uncal_in_cal']:
        print(test_cal + ' ' + str(test_model.temperature))
    if test_cal == 'uncal_in_cal':
        test_model = test_model.model
    test_model.eval()

    # Make dataloader
    test_loader = dataset.test_loader
    
    nll_criterion = move_to_device(nn.CrossEntropyLoss(), device_type)
    NLL_criterion = move_to_device(nn.NLLLoss(), device_type)
    logmax = move_to_device(nn.LogSoftmax(dim=1), device_type)
    ece_criterion = move_to_device(ECELoss(), device_type)
    tce_criterion = move_to_device(TCELoss(), device_type)
    ecce_criterion = move_to_device(ECCELoss(), device_type)
    ace_criterion = move_to_device(ACELoss(), device_type)
    acce_criterion = move_to_device(ACCELoss(), device_type)
    # First: collect all the logits and labels for the test set
    logits_list = []
    labels_list = []
    with torch.no_grad():
        for input, label in test_loader:
            input = move_to_device(input, device_type)
            logits = test_model(input)
            logits_list.append(logits)
            labels_list.append(label)
        logits = move_to_device(torch.cat(logits_list), device_type)
        labels = move_to_device(torch.cat(labels_list), device_type)
    # Calculate top-1 / top-5 error
    error_1 = Error_topk(logits, labels, 1).item()
    error_5 = Error_topk(logits, labels, 5).item()
    # Calculate NLL and the calibration metrics, each criterion run once
    nll = nll_criterion(logits, labels).item()
    NLL = NLL_criterion(logmax(logits), labels).item()
    ece = ece_criterion(logits, labels).item()
    tce = tce_criterion(logits, labels).item()
    ecce = ecce_criterion(logits, labels).item()
    ace = ace_criterion(logits, labels).item()
    acce = acce_criterion(logits, labels).item()
    #result_ece_2 = get_ece(logits, labels)
    cal_str = 'Calibrated'
    if test_cal == 'uncal': 
        cal_str = 'Uncalibrated'
    elif test_cal == 'uncal_in_cal':
        cal_str = 'Uncal in Calibrated'
    elif test_cal == 'uncal_last':
        cal_str = 'Uncal Last'
    with open(os.path.join(save, 'results.txt'), 'a') as f:
        #f.write(cal_str + ' || ERR_1: %.4f, ERR_5: %.4f, NLL: %.4f, ECE: %.4f, TCE: %.4f' % (error_1*100, error_5*100, result_nll, result_ece*100, result_tce*100)+'\n')
        write_text = (
            'ACC_1:%.4f, ERR_1:%.4f, ERR_5:%.4f, nll: %.4f, '
            'ACCE(e-4):%.4f, ACE(e-4):%.4f, ECCE(e-2):%.4f, ECE(e-2):%.4f, \n'
            ' %.2f & %.4f & %.2f & %.2f & %.2f & %.2f' %
            (1 - error_1, error_1, error_5, nll, acce * 1e4, ace * 1e4,
             ecce * 1e2, ece * 1e2, 100 * (1 - error_1), nll, acce * 1e4,
             ace * 1e4, ecce * 1e2, ece * 1e2))
        f.write(write_text + '\n')
    if test_cal == 'cal':
        torch.save(logits, os.path.join(save, 'cal_logits.pth'))
        torch.save(labels, os.path.join(save, 'cal_labels.pth'))
        np.save(os.path.join(save, 'cal_logits.npy'), logits.cpu().numpy())
        np.save(os.path.join(save, 'cal_labels.npy'), labels.cpu().numpy())
    elif test_cal == 'uncal':
        torch.save(logits, os.path.join(save, 'uncal_logits.pth'))
        torch.save(labels, os.path.join(save, 'uncal_labels.pth'))
        np.save(os.path.join(save, 'uncal_logits.npy'), logits.cpu().numpy())
        np.save(os.path.join(save, 'uncal_labels.npy'), labels.cpu().numpy())
    elif test_cal == 'uncal_in_cal':
        torch.save(logits, os.path.join(save, 'uncal_in_cal_logits.pth'))
        torch.save(labels, os.path.join(save, 'uncal_in_cal_labels.pth'))
        np.save(os.path.join(save, 'uncal_in_cal_logits.npy'), logits.cpu().numpy())
        np.save(os.path.join(save, 'uncal_in_cal_labels.npy'), labels.cpu().numpy())
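
# test_cal selects which checkpoint and which logits are evaluated: 'uncal'
# loads a plain state dict, 'cal' loads a temperature-scaled state dict and
# tests the scaled logits, and 'uncal_in_cal' loads a temperature-scaled state
# dict but unwraps and tests the inner model. A hedged driver sketch, with
# filenames following the defaults used above (`ds`, `arch`, `tps`, and `save`
# are illustrative names):
#
#     test(ds, arch, tps, save, 'cuda', test_cal='uncal',
#          model_filename='model.pth')
#     test(ds, arch, tps, save, 'cuda', test_cal='cal',
#          model_filename='model_C_ts.pth')
#     test(ds, arch, tps, save, 'cuda', test_cal='uncal_in_cal',
#          model_filename='model_C_ts.pth')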