Ejemplo n.º 1
0
def calibrate(dataset,
              net_arch,
              trn_para,
              save,
              device_type,
              model_filename='model.pth',
              calibrated_filename='model_C_ts.pth',
              batch_size=256):
    """
    Applies temperature scaling to a trained model and saves the result.

    Builds a fresh ``net_arch`` model, loads the trained weights from
    ``<save>/<model_filename>``, wraps it in ``ModelWithTemperature``, tunes
    the temperature on ``dataset.valid_loader`` and writes the wrapped model's
    state dict to ``<save>/<calibrated_filename>``.

    NB: the "save" parameter references a DIRECTORY, not a file.

    Args:
        dataset - data instance; only its ``valid_loader`` is used
        net_arch - architecture identifier passed to ``get_model``
        trn_para - unused; kept for interface compatibility
        save (str) - directory holding the trained state dict; the
            calibrated state dict is written to the same directory
        device_type - device passed to ``move_to_device``
        model_filename (str) - file name of the trained state dict
        calibrated_filename (str) - file name for the calibrated state dict
        batch_size - unused; kept for interface compatibility

    Raises:
        RuntimeError - if ``<save>/<model_filename>`` does not exist
    """
    # Load model state dict
    model_path = os.path.join(save, model_filename)
    if not os.path.exists(model_path):
        raise RuntimeError('Cannot find file %s to load' % model_path)
    # Deserialize onto the CPU: load_state_dict copies the tensors into the
    # model's own parameters wherever they live, so a checkpoint written from
    # a GPU can also be calibrated on a CPU-only machine.
    state_dict = torch.load(model_path, map_location='cpu')

    # Load original model
    orig_model = move_to_device(get_model(net_arch), device_type)
    orig_model.load_state_dict(state_dict)

    # data loader
    valid_loader = dataset.valid_loader

    # Wrap the model with a module that adds temperature scaling, tune the
    # temperature on the validation set and save the wrapped state dict.
    model = ModelWithTemperature(orig_model)
    model.opt_temperature(valid_loader)
    calibrated_path = os.path.join(save, calibrated_filename)
    torch.save(model.state_dict(), calibrated_path)
    print('Temperature scaled model saved to %s' % calibrated_path)
Ejemplo n.º 2
0
def run_epoch_fast(loader,
                   model,
                   criterion,
                   optimizer,
                   device_type='cuda',
                   epoch=0,
                   n_epochs=0,
                   train=True,
                   log_every_step=True):
    """
    Run one training or evaluation pass over ``loader``.

    In training mode every batch is forwarded, and backward plus
    ``optimizer.step()`` run only when the loss is strictly positive;
    ``optimizer.n_iters`` is incremented for every batch either way. In
    evaluation mode the forward pass runs under ``torch.no_grad()``.

    Args:
        loader - iterable of ``(input, target)`` batches
        model - network to run
        criterion - loss function called as ``criterion(output, target)``
        optimizer - optimizer stepped in training mode; its learning rate is
            also read for logging when ``log_every_step`` is set (eval too)
        device_type (str) - device passed to ``move_to_device``
        epoch, n_epochs (int) - only used in the log lines
        train (bool) - train (True) or evaluate (False)
        log_every_step (bool) - print one line per batch instead of one line
            per epoch

    Returns:
        tuple - ``(time_meter.value(), loss_meter.value(), error_meter.value())``
    """
    time_meter = Meter(name='Time', cum=True)
    loss_meter = Meter(name='Loss', cum=False)
    error_meter = Meter(name='Error', cum=False)
    phase = 'Train' if train else 'Eval'

    if train:
        model.train()
        print('Training')
    else:
        model.eval()
        print('Evaluating')

    end = time.time()
    for i, (input, target) in enumerate(loader):
        input = move_to_device(input, device_type, False)
        target = move_to_device(target, device_type, False)
        if train:
            model.zero_grad()
            optimizer.zero_grad()
            output = model(input)
            loss = criterion(output, target)
            # Skip the update when the loss is not strictly positive (e.g. a
            # margin-style loss that is already satisfied for the batch).
            if loss.item() > 0:
                loss.backward()
                optimizer.step()
            optimizer.n_iters = getattr(optimizer, 'n_iters', 0) + 1
        else:
            with torch.no_grad():
                output = model(input)
                loss = criterion(output, target)

        # Accounting: top-1 error of this batch.
        _, predictions = torch.topk(output, 1)
        error = 1 - torch.eq(torch.squeeze(predictions), target).float().mean()
        batch_time = time.time() - end
        end = time.time()

        # Log errors. The loss is detached so that a meter that stores it
        # cannot keep this batch's autograd graph alive.
        time_meter.update(batch_time)
        loss_meter.update(loss.detach())
        error_meter.update(error)
        if log_every_step:
            # Learning rate of the last parameter group.
            lr_value = optimizer.param_groups[-1]['lr']
            print('  '.join([
                '%s: (Epoch %d of %d) [%04d/%04d]' %
                (phase, epoch, n_epochs, i + 1, len(loader)),
                str(time_meter),
                str(loss_meter),
                str(error_meter),
                '%.4f' % lr_value
            ]))

    if not log_every_step:
        print('  '.join([
            '%s: (Epoch %d of %d)' % (phase, epoch, n_epochs),
            str(time_meter),
            str(loss_meter),
            str(error_meter),
        ]))

    return time_meter.value(), loss_meter.value(), error_meter.value()
Ejemplo n.º 3
0
def _sample_perturbations(pt_random, perturb_degree):
    """
    Refill ``pt_random`` in place with Gaussian noise rescaled so that every
    sample (first dimension) has L2 norm ``perturb_degree``; returns it.
    """
    torch.nn.init.normal_(pt_random)
    norms = torch.flatten(pt_random, start_dim=1).norm(p=2, dim=1)
    # Shape (sample_num, 1, ..., 1) so each norm broadcasts over its sample.
    norms = norms.view(-1, *([1] * (pt_random.dim() - 1)))
    return pt_random.div_(norms).mul_(perturb_degree)


def _perturbed_agreement(model, input, target, pt, device_type):
    """
    Fraction of perturbed copies of each input that ``model`` still assigns
    to that input's target class.

    Every row of ``input`` is combined with every noise sample of ``pt``
    (same trailing shape), giving ``len(input) * len(pt)`` forward passes.
    Returns a float tensor of shape ``(len(input), 1)``.
    """
    n, s = input.size(0), pt.size(0)
    pt_input = (input.unsqueeze(1) + pt.unsqueeze(0)).reshape(
        n * s, *input.shape[1:])
    pt_input = move_to_device(pt_input, device_type, False)
    # Only the arg-max is used, so these extra forward passes need no
    # autograd graph.
    with torch.no_grad():
        p_logits = model(pt_input)
    p_outputs = torch.argmax(p_logits, dim=1).reshape(n, s)
    hits = torch.eq(p_outputs,
                    target.unsqueeze(1).expand_as(p_outputs)).float()
    return hits.sum(dim=1, keepdim=True) / s


def run_epoch_perturb(loader,
                      model,
                      perturbation,
                      onehot_criterion,
                      soft_criterion,
                      optimizer,
                      device_type='cuda',
                      epoch=0,
                      n_epochs=0,
                      train=True,
                      class_num=100,
                      perturb_degree=5,
                      sample_num=20,
                      alpha=0.5,
                      pt_input_ratio=0.25,
                      log_every_step=True):
    """
    Run one epoch; in training mode the one-hot loss is mixed with a loss
    against perturbation-smoothed soft labels.

    For the first ``int(batch_size * pt_input_ratio)`` samples of a training
    batch, ``sample_num`` random perturbations of L2 norm ``perturb_degree``
    are added to the input, and the fraction of perturbed copies the model
    still classifies as the target is turned into a soft label by
    ``smooth_label``. The training loss is
    ``alpha * onehot_loss + (1 - alpha) * pt_input_ratio * perturb_loss``
    (just ``alpha * onehot_loss`` when no sample of the batch is perturbed).
    Evaluation mode computes ``onehot_criterion`` under ``torch.no_grad()``.

    Args:
        loader - iterable of ``(input, target)`` batches; the noise buffer
            assumes inputs of shape (batch, 3, 32, 32)
        model - network to run
        perturbation - unused; kept for interface compatibility
        onehot_criterion - loss against the hard targets
        soft_criterion - loss against the smoothed soft targets
        optimizer - optimizer to step; its learning rate is also logged
        device_type (str) - device passed to ``move_to_device``
        epoch, n_epochs (int) - only used in the log lines
        train (bool) - train (True) or evaluate (False)
        class_num (int) - number of classes passed to ``smooth_label``
        perturb_degree (float) - L2 norm of every perturbation
        sample_num (int) - perturbations per perturbed input
        alpha (float) - weight of the one-hot loss
        pt_input_ratio (float) - fraction of each batch that is perturbed
        log_every_step (bool) - print one line per batch instead of one
            line per epoch

    Returns:
        tuple - ``(time_meter.value(), loss_meter.value(), error_meter.value())``
    """
    time_meter = Meter(name='Time', cum=True)
    loss_meter = Meter(name='Loss', cum=False)
    error_meter = Meter(name='Error', cum=False)
    phase = 'Train' if train else 'Eval'

    if train:
        model.train()
        print('Training')

    else:
        model.eval()
        print('Evaluating')

    # Reusable noise buffer, refilled for every training batch.
    pt_random = torch.randn([sample_num, 3, 32, 32])
    # Diagnostics of the last training batch; they stay None in eval mode.
    pt_output_mean = onehot_loss = perturb_loss = None
    end = time.time()
    for i, (input, target) in enumerate(loader):
        if train:
            model.zero_grad()
            optimizer.zero_grad()

            pt_batch = int(input.size(0) * pt_input_ratio)
            target = move_to_device(target, device_type, False)
            if pt_batch > 0:
                # Soft targets from the model's robustness to perturbations;
                # this pass runs before the clean forward pass below.
                pt = _sample_perturbations(pt_random, perturb_degree)
                pt_output_mean = _perturbed_agreement(
                    model, input[:pt_batch], target[:pt_batch], pt,
                    device_type)
                pt_target = smooth_label(target[:pt_batch].unsqueeze(1),
                                         pt_output_mean, class_num).detach()

            # Forward pass on the clean batch
            input = move_to_device(input, device_type, False)
            output = model(input)
            onehot_loss = onehot_criterion(output, target)
            loss = alpha * onehot_loss
            if pt_batch > 0:
                perturb_loss = soft_criterion(output[:pt_batch], pt_target)
                loss = loss + (1 - alpha) * pt_input_ratio * perturb_loss

            # Backward pass
            loss.backward()
            optimizer.step()
            optimizer.n_iters = getattr(optimizer, 'n_iters', 0) + 1

        else:
            with torch.no_grad():
                # Forward pass
                input = move_to_device(input, device_type, False)
                target = move_to_device(target, device_type, False)
                output = model(input)
                loss = onehot_criterion(output, target)

        # Accounting: top-1 error of this batch.
        _, predictions = torch.topk(output, 1)
        error = 1 - torch.eq(torch.squeeze(predictions), target).float().mean()
        batch_time = time.time() - end
        end = time.time()

        # Log errors. The loss is detached so that a meter that stores it
        # cannot keep this batch's autograd graph alive.
        time_meter.update(batch_time)
        loss_meter.update(loss.detach())
        error_meter.update(error)
        if log_every_step:
            # Learning rate of the last parameter group.
            lr_value = optimizer.param_groups[-1]['lr']
            print('  '.join([
                '%s: (Epoch %d of %d) [%04d/%04d]' %
                (phase, epoch, n_epochs, i + 1, len(loader)),
                str(time_meter),
                str(loss_meter),
                str(error_meter),
                '%.4f' % lr_value
            ]))

    # These diagnostics only exist after at least one training batch.
    if train and onehot_loss is not None:
        print('pt_output_mean')
        print(pt_output_mean)
        print('onehot_loss:' + str(onehot_loss))
        print('perturb_loss:' + str(perturb_loss))
    if not log_every_step:
        print('  '.join([
            '%s: (Epoch %d of %d)' % (phase, epoch, n_epochs),
            str(time_meter),
            str(loss_meter),
            str(error_meter),
        ]))

    return time_meter.value(), loss_meter.value(), error_meter.value()
Ejemplo n.º 4
0
def train(dataset,
          net_arch,
          trn_para,
          save,
          device_type='cuda',
          model_filename='model.pth'):
    """
    Train ``net_arch`` on ``dataset``, switching to perturbation-smoothed
    labels after a warm phase, and checkpoint the best validation model.

    Args:
        dataset - data instance providing ``train_loader``, ``valid_loader``
            and ``num_classes``
        net_arch - architecture identifier passed to ``get_model``
        trn_para (dict) - hyper-parameters. Keys read: 'weight_init',
            'batch_size', 'perturb_degree', 'sample_num', 'pt_input_ratio',
            'n_epochs', 'loss_fn', 'lr', 'momentum', 'run_epoch'; when
            'run_epoch' != 'fast' also 'sample_start' (fraction of epochs
            trained with the one-hot loss only) and 'alpha'. Optional:
            'lr_scheduler' ('ReduceLROnPlateau', 'MultiStepLR_150_225_300',
            otherwise decay x0.1 at 50% and 75% of the epochs)
        save (str) - directory for checkpoints (created if missing)
        device_type (str) - device passed to ``move_to_device``
        model_filename (str) - file name of the best-validation checkpoint;
            the final weights are saved as ``'last_' + model_filename``

    Raises:
        Exception - if ``save`` exists but is not a directory
    """
    # Make save directory
    if not os.path.exists(save):
        os.makedirs(save)
    if not os.path.isdir(save):
        raise Exception('%s is not a dir' % save)

    model = get_model(net_arch)
    if trn_para['weight_init'] == 'xavier':
        model = xavier_init_weights(model)
    elif trn_para['weight_init'] == 'kaiming':
        model = kaiming_normal_init_weights(model)
    model_wrapper = move_to_device(model, device_type, True)
    print('parameter_number: ' + str(get_para_num(model_wrapper)))

    perturbation = move_to_device(
        PerturbInput((trn_para['batch_size'], 3, 32, 32),
                     perturb_degree=trn_para['perturb_degree'],
                     sample_num=trn_para['sample_num'],
                     pt_input_ratio=trn_para['pt_input_ratio'],
                     num_classes=dataset.num_classes), device_type)

    n_epochs = trn_para['n_epochs']
    if trn_para['loss_fn'] == 'MarginLoss':
        onehot_criterion = MarginLoss()
    else:
        onehot_criterion = nn.CrossEntropyLoss()
    soft_criterion = SoftCrossEntropyLoss()

    # NOTE(review): SGD is built without weight decay here - confirm intended.
    optimizer = optim.SGD(model_wrapper.parameters(),
                          lr=trn_para['lr'],
                          momentum=trn_para['momentum'],
                          nesterov=True)

    # 'lr_scheduler' is optional; a missing key selects the default schedule.
    lr_scheduler = trn_para.get('lr_scheduler')
    use_plateau = lr_scheduler == 'ReduceLROnPlateau'
    if use_plateau:
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                         patience=20)
    elif lr_scheduler == 'MultiStepLR_150_225_300':
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                                   milestones=[150, 225, 300],
                                                   gamma=0.1)
    else:
        # MultiStepLR compares milestones with the integer epoch counter, so
        # fractional milestones (odd n_epochs) would never trigger. Epoch 0 is
        # dropped because a milestone at 0 would decay the LR on construction.
        milestones = sorted(
            {int(0.5 * n_epochs), int(0.75 * n_epochs)} - {0})
        scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=milestones, gamma=0.1)

    # Make dataloaders
    train_loader = dataset.train_loader
    valid_loader = dataset.valid_loader

    # Train model
    best_error = 1
    for epoch in range(1, n_epochs + 1):
        if trn_para['run_epoch'] == 'fast':
            run_epoch_fast(
                loader=train_loader,
                model=model_wrapper,
                criterion=onehot_criterion,
                optimizer=optimizer,
                device_type=device_type,
                epoch=epoch,
                n_epochs=n_epochs,
                train=True,
            )
        elif epoch < trn_para['sample_start'] * n_epochs:
            # One-hot training only, before the perturbation phase starts.
            run_epoch(
                loader=train_loader,
                model=model_wrapper,
                criterion=onehot_criterion,
                optimizer=optimizer,
                device_type=device_type,
                epoch=epoch,
                n_epochs=n_epochs,
                train=True,
            )
        else:
            run_epoch_perturb(
                loader=train_loader,
                model=model_wrapper,
                perturbation=perturbation,
                onehot_criterion=onehot_criterion,
                soft_criterion=soft_criterion,
                optimizer=optimizer,
                device_type=device_type,
                epoch=epoch,
                n_epochs=n_epochs,
                train=True,
                class_num=dataset.num_classes,
                perturb_degree=trn_para['perturb_degree'],
                sample_num=trn_para['sample_num'],
                alpha=trn_para['alpha'],
                pt_input_ratio=trn_para['pt_input_ratio'],
            )
        valid_results = run_epoch(
            loader=valid_loader,
            model=model_wrapper,
            criterion=onehot_criterion,
            optimizer=optimizer,
            device_type=device_type,
            epoch=epoch,
            n_epochs=n_epochs,
            train=False,
        )

        # Determine if model is the best
        _, _, valid_error = valid_results
        if valid_error[0] < best_error:
            best_error = valid_error[0]
            print('New best error: %.4f' % best_error)
            torch.save(model.state_dict(), os.path.join(save, model_filename))
            with open(os.path.join(save, 'model_ckpt_detail.txt'),
                      'a') as ckptf:
                ckptf.write('epoch ' + str(epoch) +
                            (' reaches new best error %.4f' % best_error) +
                            '\n')

        if use_plateau:
            scheduler.step(valid_error[0])
        else:
            scheduler.step()

    torch.save(model.state_dict(), os.path.join(save,
                                                'last_' + model_filename))
    print('Train Done!')
Ejemplo n.º 5
0
def ensbl_test(save_dirs,
               target_save_dir,
               test_cal=False,
               logit_filename='uncal_logits.pth',
               label_filename='uncal_labels.pth',
               device_type='cpu'):
    """
    Evaluate individual models and their probability-averaged ensemble.

    For every directory in ``save_dirs`` previously saved logits
    (``logit_filename``) and labels (``label_filename``) are loaded, and top-1
    error, top-5 error, NLL and ECE are computed per model. The models'
    softmax probabilities are then averaged and the same metrics computed for
    the ensemble. Mean and std across the individual models (LaTeX-ready) and
    the ensemble metrics are appended to
    ``<target_save_dir>/ensbl_results.txt``. ECE is reported x100 for both.

    Args:
        save_dirs (iterable of str) - one directory per ensemble member
        target_save_dir (str) - directory receiving ``ensbl_results.txt``
        test_cal (bool) - only selects the 'Calibrated'/'Uncalibrated' header
        logit_filename (str) - logits file inside each directory
        label_filename (str) - labels file inside each directory
        device_type (str) - device the tensors are loaded onto

    Raises:
        ValueError - if ``save_dirs`` is empty, or the label files differ
            between directories (the ensemble assumes one shared ordering)

    With a single directory the reported std is NaN (``torch.std`` is
    unbiased by default).
    """
    save_dirs = list(save_dirs)
    if not save_dirs:
        raise ValueError('ensbl_test needs at least one directory in save_dirs')

    # Individual Models
    nll_criterion = move_to_device(nn.CrossEntropyLoss(), device_type)
    ece_criterion = move_to_device(ECELoss(), device_type)
    map_location = torch.device(device_type)
    labels_list, probs_list = [], []
    err1_list, err5_list, nll_list, ece_list = [], [], [], []
    for sd in save_dirs:
        lg = torch.load(os.path.join(sd, logit_filename),
                        map_location=map_location)
        lb = torch.load(os.path.join(sd, label_filename),
                        map_location=map_location)
        labels_list.append(lb.unsqueeze(1))
        probs_list.append(torch.softmax(lg, dim=-1).unsqueeze(1))
        err1_list.append(Error_topk(lg, lb, 1).unsqueeze(0))
        err5_list.append(Error_topk(lg, lb, 5).unsqueeze(0))
        nll_list.append(nll_criterion(lg, lb).unsqueeze(0))
        ece_list.append(ece_criterion(lg, lb).unsqueeze(0))
    ERR_1_list = torch.cat(err1_list)
    ERR_5_list = torch.cat(err5_list)
    print(ERR_1_list)
    print(ERR_5_list)
    NLL_list = torch.cat(nll_list)
    ECE_list = torch.cat(ece_list) * 100

    # Ensemble: one label column per model; all columns must agree.
    labels = torch.cat(labels_list, dim=1)
    lb = labels[:, 0]
    if not bool((labels == lb.unsqueeze(1)).all()):
        raise ValueError('label files differ between save_dirs')
    avg_prob = torch.mean(torch.cat(probs_list, dim=1), dim=1)
    # CrossEntropyLoss applies log_softmax, and log_softmax(log p) == log p
    # when p sums to 1, so these "logits" give the ensemble's NLL directly.
    pseudo_avg_logit = torch.log(avg_prob)
    avg_ERR_1 = Error_topk(pseudo_avg_logit, lb, 1)
    avg_ERR_5 = Error_topk(pseudo_avg_logit, lb, 5)
    avg_NLL = nll_criterion(pseudo_avg_logit, lb)
    # Same x100 scale as the individual ECE values above.
    avg_ECE = ece_criterion(pseudo_avg_logit, lb) * 100

    cal_str = 'Calibrated' if test_cal else 'Uncalibrated'
    with open(os.path.join(target_save_dir, 'ensbl_results.txt'), 'a') as f:
        f.write('+++++++' + cal_str + '+++++++\n')
        f.write('==Individual Models==\n')
        # Backslashes are escaped so the LaTeX commands are written verbatim
        # without relying on invalid string escape sequences.
        f.write(
            ' ERR_1: %.4f \\scriptsize{$\\pm$ %.4f}\n'
            ' ERR_5: %.4f \\scriptsize{$\\pm$ %.4f}\n'
            ' NLL: %.3f \\scriptsize{$\\pm$ %.3f}\n'
            ' ECE: %.4f \\scriptsize{$\\pm$ %.4f}\n'
            % (torch.mean(ERR_1_list).item(), torch.std(ERR_1_list).item(),
               torch.mean(ERR_5_list).item(), torch.std(ERR_5_list).item(),
               torch.mean(NLL_list).item(), torch.std(NLL_list).item(),
               torch.mean(ECE_list).item(), torch.std(ECE_list).item()))
        f.write('==Ensemble==\n')
        f.write(' ERR_1: %.4f \n ERR_5: %.4f \n NLL: %.3f \n ECE: %.4f \n' %
                (avg_ERR_1.item(), avg_ERR_5.item(), avg_NLL.item(),
                 avg_ECE.item()))
Ejemplo n.º 6
0
def train(dataset,
          net_arch,
          trn_para,
          save,
          device_type='cuda',
          model_filename='model.pth'):
    """
    Train ``net_arch`` on ``dataset`` and checkpoint the best validation model.

    Args:
        dataset - data instance providing ``train_loader`` and ``valid_loader``
        net_arch - architecture identifier passed to ``get_model``
        trn_para (dict) - hyper-parameters. Keys read: 'weight_init',
            'n_epochs', 'loss_fn', 'lr', 'wd', 'momentum' (SGD and warm-up).
            Optional: 'optimizer' ('Adam', otherwise SGD), 'lr_scheduler'
            ('ReduceLROnPlateau', 'MultiStepLR_150_225_300',
            'MultiStepLR_60_120_160_200', otherwise decay x0.1 at 50% and
            75% of the epochs), 'warmup' (warm-up epochs at lr/10, default 0)
            and 'run_epoch' ('fast' selects ``run_epoch_fast``)
        save (str) - directory for checkpoints (created if missing)
        device_type (str) - device passed to ``move_to_device``
        model_filename (str) - file name of the best-validation checkpoint;
            the final weights are saved as ``'last_' + model_filename``

    Raises:
        Exception - if ``save`` exists but is not a directory
    """
    # Make save directory
    if not os.path.exists(save):
        os.makedirs(save)
    if not os.path.isdir(save):
        raise Exception('%s is not a dir' % save)

    model = get_model(net_arch)
    if trn_para['weight_init'] == 'xavier':
        model = xavier_init_weights(model)
    elif trn_para['weight_init'] == 'kaiming':
        model = kaiming_normal_init_weights(model)
    model_wrapper = move_to_device(model, device_type, True)

    n_epochs = trn_para['n_epochs']
    if trn_para['loss_fn'] == 'MarginLoss':
        criterion = MarginLoss()
    else:
        criterion = nn.CrossEntropyLoss()

    if trn_para.get('optimizer') == 'Adam':
        print('Using Adam')
        optimizer = optim.Adam(model_wrapper.parameters(),
                               lr=trn_para['lr'],
                               weight_decay=trn_para['wd'])
    else:
        print('Using SGD')
        optimizer = optim.SGD(model_wrapper.parameters(),
                              lr=trn_para['lr'],
                              weight_decay=trn_para['wd'],
                              momentum=trn_para['momentum'],
                              nesterov=True)

    # 'lr_scheduler' is optional; a missing key selects the default schedule.
    lr_scheduler = trn_para.get('lr_scheduler')
    use_plateau = lr_scheduler == 'ReduceLROnPlateau'
    if use_plateau:
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                         patience=20)
    elif lr_scheduler == 'MultiStepLR_150_225_300':
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                                   milestones=[150, 225, 300],
                                                   gamma=0.1)
    elif lr_scheduler == 'MultiStepLR_60_120_160_200':
        scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=[60, 120, 160, 200], gamma=0.2)
    else:
        # MultiStepLR compares milestones with the integer epoch counter, so
        # fractional milestones (odd n_epochs) would never trigger. Epoch 0 is
        # dropped because a milestone at 0 would decay the LR on construction.
        milestones = sorted(
            {int(0.5 * n_epochs), int(0.75 * n_epochs)} - {0})
        scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=milestones, gamma=0.1)

    # Make dataloaders
    train_loader = dataset.train_loader
    valid_loader = dataset.valid_loader

    # Warmup: a separate SGD optimizer at a tenth of the base learning rate.
    warmup_epochs = trn_para.get('warmup', 0)
    if warmup_epochs > 0:
        warmup_optimizer = optim.SGD(model_wrapper.parameters(),
                                     lr=float(trn_para['lr']) / 10.0,
                                     momentum=trn_para['momentum'],
                                     nesterov=True)
        for warmup_epoch in range(1, warmup_epochs + 1):
            run_epoch(
                loader=train_loader,
                model=model_wrapper,
                criterion=criterion,
                optimizer=warmup_optimizer,
                device_type=device_type,
                epoch=warmup_epoch,
                n_epochs=warmup_epochs,
                train=True,
            )

    # Train model
    train_epoch = (run_epoch_fast
                   if trn_para.get('run_epoch') == 'fast' else run_epoch)
    best_error = 1
    for epoch in range(1, n_epochs + 1):
        train_epoch(
            loader=train_loader,
            model=model_wrapper,
            criterion=criterion,
            optimizer=optimizer,
            device_type=device_type,
            epoch=epoch,
            n_epochs=n_epochs,
            train=True,
        )
        valid_results = run_epoch(
            loader=valid_loader,
            model=model_wrapper,
            criterion=criterion,
            optimizer=optimizer,
            device_type=device_type,
            epoch=epoch,
            n_epochs=n_epochs,
            train=False,
        )

        # Determine if model is the best
        _, _, valid_error = valid_results
        if valid_error[0] < best_error:
            best_error = valid_error[0]
            print('New best error: %.4f' % best_error)
            torch.save(model.state_dict(), os.path.join(save, model_filename))
            with open(os.path.join(save, 'model_ckpt_detail.txt'),
                      'a') as ckptf:
                ckptf.write('epoch ' + str(epoch) +
                            (' reaches new best error %.4f' % best_error) +
                            '\n')

        if use_plateau:
            scheduler.step(valid_error[0])
        else:
            scheduler.step()

    torch.save(model.state_dict(), os.path.join(save,
                                                'last_' + model_filename))
    print('Train Done!')
Ejemplo n.º 7
0
def test_train(dataset,
               net_arch,
               trn_para,
               save,
               device_type,
               model_filename='model.pth',
               batch_size=256):
    """
    Save a trained model's logits and labels on ``dataset.test_train_loader``.

    Loads ``<save>/<model_filename>`` into a fresh ``net_arch`` model, runs it
    in eval mode under ``torch.no_grad()`` and writes the concatenated logits
    and labels to ``save`` as ``test_train_logits`` / ``test_train_labels``,
    both as ``.pth`` (torch) and ``.npy`` (NumPy) files.

    Args:
        dataset - data instance; only ``test_train_loader`` is used
        net_arch - architecture identifier passed to ``get_model``
        trn_para - unused; kept for interface compatibility
        save (str) - directory holding the state dict and receiving outputs
        device_type - device passed to ``move_to_device``
        model_filename (str) - state dict file name inside ``save``
        batch_size - unused; kept for interface compatibility

    Raises:
        Exception - if ``save`` exists but is not a directory
        RuntimeError - if the state dict file cannot be found
    """
    # Make save directory
    if not os.path.exists(save):
        os.makedirs(save)
    if not os.path.isdir(save):
        raise Exception('%s is not a dir' % save)

    test_model = move_to_device(get_model(net_arch), device_type)
    # Load model state dict (onto the CPU first; load_state_dict copies it
    # into the model's parameters wherever they live).
    model_state_filename = os.path.join(save, model_filename)
    if not os.path.exists(model_state_filename):
        raise RuntimeError('Cannot find file %s to load' %
                           model_state_filename)
    state_dict = torch.load(model_state_filename, map_location='cpu')
    test_model.load_state_dict(state_dict)
    test_model.eval()

    # Collect all the logits and labels of the loader.
    logits_list = []
    labels_list = []
    with torch.no_grad():
        for input, label in dataset.test_train_loader:
            logits_list.append(test_model(move_to_device(input, device_type)))
            labels_list.append(label)
        logits = move_to_device(torch.cat(logits_list), device_type)
        labels = move_to_device(torch.cat(labels_list), device_type)

    torch.save(logits, os.path.join(save, 'test_train_logits.pth'))
    torch.save(labels, os.path.join(save, 'test_train_labels.pth'))
    np.save(os.path.join(save, 'test_train_logits.npy'), logits.cpu().numpy())
    np.save(os.path.join(save, 'test_train_labels.npy'), labels.cpu().numpy())
Ejemplo n.º 8
0
def sample(dataset,
           net_arch,
           trn_para,
           save,
           device_type,
           model_filename='model.pth',
           batch_size=256,
           sample_num=20,
           perturb_degree=5,
           num_batches=6):
    """
    Record a trained model's logits on randomly perturbed training inputs.

    For each of the first ``num_batches`` batches of ``dataset.train_loader``
    the clean logits and the logits of ``sample_num`` perturbed copies are
    collected. Each perturbation is one Gaussian sample rescaled to L2 norm
    ``perturb_degree`` and added to every input of the batch. Inputs, logits
    (batch, 1 + sample_num, ...; clean logits first along dim 1),
    perturbations and labels are saved to ``save`` as
    ``sample_{input,logits,perturbations,labels}_pdg<perturb_degree>.npy``.

    Args:
        dataset - data instance; only ``train_loader`` is used
        net_arch - architecture identifier passed to ``get_model``
        trn_para - unused; kept for interface compatibility
        save (str) - directory holding the state dict and receiving outputs
        device_type - device passed to ``move_to_device``
        model_filename (str) - state dict file name inside ``save``
        batch_size - unused; kept for interface compatibility
        sample_num (int) - perturbed copies per batch
        perturb_degree (float) - L2 norm of every perturbation
        num_batches (int) - number of leading batches to process
    """
    sample_model = move_to_device(get_model(net_arch), device_type)
    model_state_filename = os.path.join(save, model_filename)
    # Load onto the CPU first; load_state_dict copies into the model's device.
    state_dict = torch.load(model_state_filename, map_location='cpu')
    sample_model.load_state_dict(state_dict)
    sample_model.eval()
    train_loader = dataset.train_loader
    inputs_list = []
    perturbations_list = []
    logits_list = []
    labels_list = []
    with torch.no_grad():
        for batch_idx, (input, label) in enumerate(train_loader):
            if batch_idx >= num_batches:
                break
            cur_logits_list = []
            cur_perturbations_list = []
            input = move_to_device(input, device_type)
            print('input:' + str(input.size()))
            logits = sample_model(input)
            print('logits:' + str(logits.size()))
            cur_logits_list.append(logits.unsqueeze_(1))
            for _ in range(sample_num):
                # One unit-norm noise sample, scaled and shared by the batch.
                perturbation = torch.randn_like(input[0])
                p_n = torch.norm(torch.flatten(perturbation), p=2)
                print('p_n:' + str(p_n.size()))
                perturbation = perturbation.div(p_n)
                perturbation = perturbation.unsqueeze(0).expand_as(
                    input) * perturb_degree
                perturbation = move_to_device(perturbation, device_type)
                print('perturbation:' + str(perturbation.size()))
                cur_perturbations_list.append(perturbation.unsqueeze(1))
                p_logits = sample_model(input + perturbation)
                cur_logits_list.append(p_logits.unsqueeze_(1))

            logits_list.append(torch.cat(cur_logits_list, dim=1))
            labels_list.append(label)
            inputs_list.append(input)
            perturbations_list.append(torch.cat(cur_perturbations_list, dim=1))
    logits = torch.cat(logits_list)
    labels = torch.cat(labels_list)
    inputs = torch.cat(inputs_list)
    perturbations = torch.cat(perturbations_list)
    suffix = '_pdg' + str(perturb_degree) + '.npy'
    np.save(os.path.join(save, 'sample_input' + suffix), inputs.cpu().numpy())
    np.save(os.path.join(save, 'sample_logits' + suffix), logits.cpu().numpy())
    np.save(os.path.join(save, 'sample_perturbations' + suffix),
            perturbations.cpu().numpy())
    np.save(os.path.join(save, 'sample_labels' + suffix), labels.cpu().numpy())
Ejemplo n.º 9
0
def test(dataset, net_arch, trn_para, save, device_type, test_cal='cal', model_filename='model.pth', batch_size=256):
    """
    Evaluate a trained model on ``dataset.test_loader`` and record metrics.

    Args:
        dataset - data instance; only ``test_loader`` is used
        net_arch - architecture identifier passed to ``get_model``
        trn_para - unused; kept for interface compatibility
        save (str) - directory holding the state dict and receiving outputs
        device_type - device passed to ``move_to_device``
        test_cal (str) - which model is evaluated:
            'cal' - the ``ModelWithTemperature`` wrapper loaded from the
                checkpoint;
            'uncal_in_cal' - the checkpoint is loaded into the wrapper and
                its inner ``.model`` is evaluated;
            'uncal' or any other value - the plain model
        model_filename (str) - state dict file name inside ``save``
        batch_size - unused; kept for interface compatibility

    Appends top-1 accuracy/error, top-5 error, NLL, ACCE, ACE, ECCE and ECE
    (plus a LaTeX table row) to ``<save>/results.txt``. For 'cal', 'uncal'
    and 'uncal_in_cal' the test logits and labels are also saved as
    ``<test_cal>_logits`` / ``<test_cal>_labels`` in ``.pth`` and ``.npy``.

    Raises:
        Exception - if ``save`` exists but is not a directory
        RuntimeError - if the state dict file cannot be found
    """
    # Make save directory
    if not os.path.exists(save):
        os.makedirs(save)
    if not os.path.isdir(save):
        raise Exception('%s is not a dir' % save)

    wrapped = test_cal in ('cal', 'uncal_in_cal')
    test_model = move_to_device(get_model(net_arch), device_type)
    if wrapped:
        # These checkpoints hold the state dict of the temperature wrapper.
        test_model = move_to_device(ModelWithTemperature(test_model),
                                    device_type)

    # Load model state dict (onto the CPU first; load_state_dict copies it
    # into the model's parameters wherever they live).
    model_state_filename = os.path.join(save, model_filename)
    if not os.path.exists(model_state_filename):
        raise RuntimeError('Cannot find file %s to load' % model_state_filename)
    state_dict = torch.load(model_state_filename, map_location='cpu')
    test_model.load_state_dict(state_dict)
    if wrapped:
        print(test_cal + ' ' + str(test_model.temperature))
    if test_cal == 'uncal_in_cal':
        test_model = test_model.model
    test_model.eval()

    nll_criterion = move_to_device(nn.CrossEntropyLoss(), device_type)
    ece_criterion = move_to_device(ECELoss(), device_type)
    ecce_criterion = move_to_device(ECCELoss(), device_type)
    ace_criterion = move_to_device(ACELoss(), device_type)
    acce_criterion = move_to_device(ACCELoss(), device_type)

    # Collect all the logits and labels for the test set.
    logits_list = []
    labels_list = []
    with torch.no_grad():
        for input, label in dataset.test_loader:
            logits_list.append(test_model(move_to_device(input, device_type)))
            labels_list.append(label)
        logits = move_to_device(torch.cat(logits_list), device_type)
        labels = move_to_device(torch.cat(labels_list), device_type)

    # Calculate each metric once.
    error_1 = Error_topk(logits, labels, 1).item()
    error_5 = Error_topk(logits, labels, 5).item()
    nll = nll_criterion(logits, labels).item()
    ece = ece_criterion(logits, labels).item()
    ecce = ecce_criterion(logits, labels).item()
    ace = ace_criterion(logits, labels).item()
    acce = acce_criterion(logits, labels).item()

    with open(os.path.join(save, 'results.txt'), 'a') as f:
        write_text = (
            'ACC_1:%.4f, ERR_1:%.4f, ERR_5:%.4f, nll: %.4f, '
            'ACCE(e-4):%.4f, ACE(e-4):%.4f, ECCE(e-2):%.4f, ECE(e-2):%.4f, \n'
            ' %.2f & %.4f & %.2f & %.2f & %.2f & %.2f' %
            (1 - error_1, error_1, error_5, nll, acce * 1e4, ace * 1e4,
             ecce * 1e2, ece * 1e2, 100 * (1 - error_1), nll, acce * 1e4,
             ace * 1e4, ecce * 1e2, ece * 1e2))
        f.write(write_text + '\n')

    # Persist logits/labels for the named modes (e.g. 'uncal_logits.pth').
    if test_cal in ('cal', 'uncal', 'uncal_in_cal'):
        prefix = os.path.join(save, test_cal)
        torch.save(logits, prefix + '_logits.pth')
        torch.save(labels, prefix + '_labels.pth')
        np.save(prefix + '_logits.npy', logits.cpu().numpy())
        np.save(prefix + '_labels.npy', labels.cpu().numpy())