Ejemplo n.º 1
0
def do_test_beyesian(distortion):
    """Evaluate an MC-dropout model on distorted test images.

    Builds the network with ``config['params']['dropout']``, loads the
    ``best.pth`` checkpoint from ``config['exp']['path']`` and calls
    ``apply_mc_dropout`` on every sub-module.  For every batch of
    ``test_loader`` the inputs are passed through ``_distort_image`` and the
    softmax outputs of ``config['params']['num_infer']`` forward passes are
    averaged.  Accuracy and ECE of the averaged probabilities are printed and
    the per-sample maximum probabilities are handed to ``draw_histogram``.

    The (misspelled) function name is kept unchanged for existing callers.
    Depends on names defined at module level (``config``, ``num_classes``,
    ``device``, ``net_factory``, ``utils``, ``test_loader``,
    ``_distort_image``, ``apply_mc_dropout``, ``ece_criterion``,
    ``draw_histogram``, ``tqdm``).

    Args:
        distortion: Distortion identifier forwarded to ``_distort_image``;
            also used as the label in the printed results and the histogram.

    Raises:
        ValueError: If ``config['params']['num_infer']`` is less than 1.
    """
    num_infer = config['params']['num_infer']
    if num_infer < 1:
        # Without this check torch.stack() fails later on an empty list
        # with a much less helpful message.
        raise ValueError('num_infer must be >= 1, got %r' % (num_infer,))

    # Model
    net = net_factory.load_model(config=config, num_classes=num_classes, dropout=config['params']['dropout'])
    net = net.to(device)
    ckpt = torch.load(os.path.join(config['exp']['path'], 'best.pth'), map_location=device)
    weights = utils._load_weights(ckpt['net'])
    # strict=True: load_state_dict raises on any key mismatch instead of
    # silently skipping it.
    missing_keys = net.load_state_dict(weights, strict=True)
    print(missing_keys)

    # Total number of parameter elements (a 0-dim parameter counts as 1).
    num_parameters = sum(param.numel() for param in net.parameters())
    print("num. of parameters : " + str(num_parameters))

    # Inference: eval() first, then apply_mc_dropout, which presumably puts
    # the dropout layers back into training mode so every pass samples a new
    # mask (TODO confirm against apply_mc_dropout).
    net.eval()
    net.apply(apply_mc_dropout)

    probs_list = []
    targets_list = []
    with torch.no_grad():
        for inputs, targets in tqdm(test_loader):
            inputs = _distort_image(distortion, inputs).to(device)
            # One softmax output per stochastic pass, each presumably
            # (batch, num_classes); their mean over passes is the MC estimate.
            # Only the mean is used, so the variance is not computed (the
            # unbiased variance of a single pass would be NaN anyway).
            samples = [net(inputs).softmax(dim=1).cpu() for _ in range(num_infer)]
            probs_list.append(torch.stack(samples).mean(dim=0))
            targets_list.append(targets)

    probs = torch.cat(probs_list)
    targets = torch.cat(targets_list)
    # is_logits=False: the averaged values are probabilities, not logits.
    ece_loss = ece_criterion(probs, targets, is_logits=False).item()

    max_probs, max_ind = probs.max(dim=1)
    accuracy = max_ind.eq(targets).float().sum().item() / probs.shape[0]

    print('%-3s (accuracy) : %.5f' % (distortion, accuracy))
    print('%-3s (ece) : %.5f' % (distortion, ece_loss))

    draw_histogram(max_probs.tolist(), distortion, config['exp']['path'])
Ejemplo n.º 2
0
def do_test(distortion, is_scaling=False):
    """Evaluate the trained (optionally temperature-scaled) model on distorted test images.

    Without scaling, the ``'net'`` entry of ``best.pth`` is loaded into the
    network.  With scaling, the network is wrapped in
    ``ModelWithTemperature`` and ``model_with_temperature.pth`` (a bare state
    dict, no ``'net'`` key) is loaded into the wrapper.  Both files live in
    ``config['exp']['path']``.  Logits for every distorted test batch are
    collected on the CPU, and then ECE, accuracy and the histogram of maximum
    softmax probabilities are produced.

    Depends on names defined at module level (``config``, ``num_classes``,
    ``device``, ``net_factory``, ``utils``, ``ModelWithTemperature``,
    ``test_loader``, ``_distort_image``, ``ece_criterion``,
    ``draw_histogram``, ``tqdm``).

    Args:
        distortion: Distortion identifier forwarded to ``_distort_image``;
            also used as the label in the printed results and the histogram.
        is_scaling: When truthy, evaluate the temperature-scaled model.
    """
    # Model
    net = net_factory.load_model(config=config, num_classes=num_classes)
    net = net.to(device)
    if is_scaling:
        ckpt = torch.load(os.path.join(config['exp']['path'], 'model_with_temperature.pth'), map_location=device)
        weights = utils._load_weights(ckpt)
        # The checkpoint holds the wrapper's state dict, so wrap before loading.
        net = ModelWithTemperature(net).to(device)
    else:
        ckpt = torch.load(os.path.join(config['exp']['path'], 'best.pth'), map_location=device)
        weights = utils._load_weights(ckpt['net'])
    # strict=True: load_state_dict raises on any key mismatch.
    missing_keys = net.load_state_dict(weights, strict=True)
    print(missing_keys)

    # Total number of parameter elements (a 0-dim parameter counts as 1).
    num_parameters = sum(param.numel() for param in net.parameters())
    print("num. of parameters : " + str(num_parameters))

    # Inference
    net.eval()

    logits_list = []
    targets_list = []
    with torch.no_grad():
        for inputs, targets in tqdm(test_loader):
            inputs = _distort_image(distortion, inputs).to(device)
            logits_list.append(net(inputs).cpu())
            targets_list.append(targets)

    logits = torch.cat(logits_list)
    targets = torch.cat(targets_list)
    # Raw logits are passed here, relying on ece_criterion's default mode.
    ece_loss = ece_criterion(logits, targets).item()

    probs = logits.softmax(dim=1)
    max_probs, max_ind = probs.max(dim=1)
    accuracy = max_ind.eq(targets).float().sum().item() / probs.shape[0]

    print('%-3s (accuracy) : %.5f' % (distortion, accuracy))
    print('%-3s (ece) : %.5f' % (distortion, ece_loss))

    draw_histogram(max_probs.tolist(), distortion, config['exp']['path'])
Ejemplo n.º 3
0
def do_test_ensemble(distortion):
    """Evaluate a deep ensemble on distorted test images.

    Loads ``config['params']['num_ensembles']`` networks from
    ``best_0.pth``, ``best_1.pth``, ... in ``config['exp']['path']`` and puts
    each in eval mode.  For every batch of ``test_loader`` the inputs are
    passed through ``_distort_image`` and the members' softmax outputs are
    averaged.  Accuracy and ECE of the averaged probabilities are printed and
    the per-sample maximum probabilities are handed to ``draw_histogram``.

    Depends on names defined at module level (``config``, ``num_classes``,
    ``device``, ``net_factory``, ``utils``, ``test_loader``,
    ``_distort_image``, ``ece_criterion``, ``draw_histogram``, ``tqdm``).

    Args:
        distortion: Distortion identifier forwarded to ``_distort_image``;
            also used as the label in the printed results and the histogram.

    Raises:
        ValueError: If ``config['params']['num_ensembles']`` is less than 1.
    """
    num_ensemble = config['params']['num_ensembles']
    if num_ensemble < 1:
        # Without this check torch.stack() fails later on an empty list.
        raise ValueError('num_ensembles must be >= 1, got %r' % (num_ensemble,))

    nets = []
    for iter_idx in range(num_ensemble):
        net = net_factory.load_model(config=config, num_classes=num_classes)
        net = net.to(device)

        weight_file = 'best_%d.pth' % iter_idx
        ckpt = torch.load(os.path.join(config['exp']['path'], weight_file), map_location=device)
        weights = utils._load_weights(ckpt['net'])
        # strict=True: load_state_dict raises on any key mismatch.
        missing_keys = net.load_state_dict(weights, strict=True)
        print(missing_keys)

        net.eval()
        nets.append(net)

    probs_list = []
    targets_list = []
    with torch.no_grad():
        for inputs, targets in tqdm(test_loader):
            inputs = _distort_image(distortion, inputs).to(device)
            # One softmax output per member, each presumably
            # (batch, num_classes); only their mean is used, so the variance
            # is not computed (it would be NaN for a single member anyway).
            member_probs = [member(inputs).softmax(dim=1).cpu() for member in nets]
            probs_list.append(torch.stack(member_probs).mean(dim=0))
            targets_list.append(targets)

    probs = torch.cat(probs_list)
    targets = torch.cat(targets_list)
    # is_logits=False: the averaged values are probabilities, not logits.
    ece_loss = ece_criterion(probs, targets, is_logits=False).item()

    max_probs, max_ind = probs.max(dim=1)
    accuracy = max_ind.eq(targets).float().sum().item() / probs.shape[0]

    print('%-3s (accuracy) : %.5f' % (distortion, accuracy))
    print('%-3s (ece) : %.5f' % (distortion, ece_loss))

    draw_histogram(max_probs.tolist(), distortion, config['exp']['path'])
Ejemplo n.º 4
0
def do_test_beyesian():
    """Collect MC-dropout confidences over ``data_loader`` and plot them.

    Builds the network with ``config['params']['dropout']``, loads the
    ``best.pth`` checkpoint from ``config['exp']['path']``, calls
    ``apply_mc_dropout`` on every sub-module and prints the network.  For
    every batch the softmax outputs of ``config['params']['num_infer']``
    forward passes are averaged.  The maximum averaged probability of each
    sample is gathered and passed to ``draw_histogram``.

    The (misspelled) function name is kept unchanged for existing callers.
    Depends on names defined at module level (``config``, ``num_classes``,
    ``device``, ``net_factory``, ``utils``, ``data_loader``,
    ``apply_mc_dropout``, ``draw_histogram``, ``tqdm``).

    Raises:
        ValueError: If ``config['params']['num_infer']`` is less than 1.
    """
    num_infer = config['params']['num_infer']
    if num_infer < 1:
        # Without this check torch.stack() fails later on an empty list.
        raise ValueError('num_infer must be >= 1, got %r' % (num_infer,))

    # Model
    net = net_factory.load_model(config=config,
                                 num_classes=num_classes,
                                 dropout=config['params']['dropout'])
    net = net.to(device)
    ckpt = torch.load(os.path.join(config['exp']['path'], 'best.pth'),
                      map_location=device)
    weights = utils._load_weights(ckpt['net'])
    # strict=True: load_state_dict raises on any key mismatch.
    missing_keys = net.load_state_dict(weights, strict=True)
    print(missing_keys)

    # Total number of parameter elements (a 0-dim parameter counts as 1).
    num_parameters = sum(param.numel() for param in net.parameters())
    print("num. of parameters : " + str(num_parameters))

    # Inference: eval() first, then apply_mc_dropout, which presumably puts
    # the dropout layers back into training mode (TODO confirm).
    net.eval()
    net.apply(apply_mc_dropout)
    print(net)

    certainties = []
    with torch.no_grad():
        for inputs, _targets in tqdm(data_loader):
            inputs = inputs.to(device)
            # Average of num_infer stochastic softmax outputs; the variance
            # was never used, so it is no longer computed.
            samples = [net(inputs).softmax(dim=1).cpu() for _ in range(num_infer)]
            mean_probs = torch.stack(samples).mean(dim=0)
            max_probs, _ = mean_probs.max(dim=1)
            certainties.extend(max_probs.tolist())

    draw_histogram(certainties, config['exp']['path'])
Ejemplo n.º 5
0
def do_test(is_scaling=False):
    """Collect softmax confidences of the (optionally temperature-scaled) model and plot them.

    Without scaling, the ``'net'`` entry of ``best.pth`` is loaded into the
    network.  With scaling, the network is wrapped in
    ``ModelWithTemperature`` and ``model_with_temperature.pth`` (a bare state
    dict, no ``'net'`` key) is loaded into the wrapper.  Both files live in
    ``config['exp']['path']``.  The maximum softmax probability of every
    sample in ``data_loader`` is gathered and passed to ``draw_histogram``.

    Depends on names defined at module level (``config``, ``num_classes``,
    ``device``, ``net_factory``, ``utils``, ``ModelWithTemperature``,
    ``data_loader``, ``draw_histogram``, ``tqdm``).

    Args:
        is_scaling: When truthy, evaluate the temperature-scaled model.
    """
    # Model
    net = net_factory.load_model(config=config, num_classes=num_classes)
    net = net.to(device)
    if is_scaling:
        ckpt = torch.load(os.path.join(config['exp']['path'],
                                       'model_with_temperature.pth'),
                          map_location=device)
        weights = utils._load_weights(ckpt)
        # The checkpoint holds the wrapper's state dict, so wrap before loading.
        net = ModelWithTemperature(net).to(device)
    else:
        ckpt = torch.load(os.path.join(config['exp']['path'], 'best.pth'),
                          map_location=device)
        weights = utils._load_weights(ckpt['net'])
    # strict=True: load_state_dict raises on any key mismatch.
    missing_keys = net.load_state_dict(weights, strict=True)
    print(missing_keys)

    # Total number of parameter elements (a 0-dim parameter counts as 1).
    num_parameters = sum(param.numel() for param in net.parameters())
    print("num. of parameters : " + str(num_parameters))

    # Inference
    net.eval()

    certainties = []
    with torch.no_grad():
        for inputs, _targets in tqdm(data_loader):
            inputs = inputs.to(device)
            probs = net(inputs).softmax(dim=1)
            max_probs, _ = probs.cpu().max(dim=1)
            certainties.extend(max_probs.tolist())

    draw_histogram(certainties, config['exp']['path'])
Ejemplo n.º 6
0
def do_test_ensemble():
    """Collect deep-ensemble confidences over ``data_loader`` and plot them.

    Loads ``config['params']['num_ensembles']`` networks from
    ``best_0.pth``, ``best_1.pth``, ... in ``config['exp']['path']`` and puts
    each in eval mode.  For every batch the members' softmax outputs are
    averaged.  The maximum averaged probability of each sample is gathered
    and passed to ``draw_histogram``.

    Depends on names defined at module level (``config``, ``num_classes``,
    ``device``, ``net_factory``, ``utils``, ``data_loader``,
    ``draw_histogram``, ``tqdm``).

    Raises:
        ValueError: If ``config['params']['num_ensembles']`` is less than 1.
    """
    num_ensemble = config['params']['num_ensembles']
    if num_ensemble < 1:
        # Without this check torch.stack() fails later on an empty list.
        raise ValueError('num_ensembles must be >= 1, got %r' % (num_ensemble,))

    nets = []
    for iter_idx in range(num_ensemble):
        net = net_factory.load_model(config=config, num_classes=num_classes)
        net = net.to(device)

        weight_file = 'best_%d.pth' % iter_idx
        ckpt = torch.load(os.path.join(config['exp']['path'], weight_file),
                          map_location=device)
        weights = utils._load_weights(ckpt['net'])
        # strict=True: load_state_dict raises on any key mismatch.
        missing_keys = net.load_state_dict(weights, strict=True)
        print(missing_keys)

        net.eval()
        nets.append(net)

    certainties = []
    with torch.no_grad():
        for inputs, _targets in tqdm(data_loader):
            inputs = inputs.to(device)
            # Average of the members' softmax outputs; the variance was never
            # used, so it is no longer computed.
            member_probs = [member(inputs).softmax(dim=1).cpu() for member in nets]
            mean_probs = torch.stack(member_probs).mean(dim=0)
            max_probs, _ = mean_probs.max(dim=1)
            certainties.extend(max_probs.tolist())

    draw_histogram(certainties, config['exp']['path'])
        num_train = num_train - num_valid

        train_dataset, valid_dataset = torch.utils.data.random_split(train_data, [num_train, num_valid])
    else:
        raise NotImplementedError('Unsupported Dataset: ' + str(config['data']['name']))

    assert valid_dataset

    valid_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=config['params']['batch_size'],
        shuffle=False, num_workers=config['params']['workers'],
        collate_fn=collate_fn_test,
        pin_memory=True)

    ''' Load Model'''
    net = net_factory.load_model(config=config, num_classes=num_classes)
    net = net.to(device)
    ckpt = torch.load(config['basenet']['path'], map_location=device)
    weights = utils._load_weights(ckpt['net'])
    missing_keys = net.load_state_dict(weights, strict=False)
    print(missing_keys)

    # Now we're going to wrap the model with a decorator that adds temperature scaling
    temp_model = ModelWithTemperature(net)
    temp_model = temp_model.to(device)

    # Tune the model temperature, and save the results
    temp_model.set_temperature(valid_loader, device)
    model_filename = os.path.join(config['exp']['path'], 'model_with_temperature.pth')
    torch.save(temp_model.state_dict(), model_filename)
    print('Temperature scaled model save to %s' % model_filename)
# Data loaders: training batches are shuffled, validation batches keep
# dataset order.
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=config['params']['batch_size'],
    shuffle=True, num_workers=config['params']['workers'],
    collate_fn=collate_fn_train,
    pin_memory=True)
valid_loader = torch.utils.data.DataLoader(
    valid_dataset, batch_size=config['params']['batch_size'],
    shuffle=False, num_workers=config['params']['workers'],
    collate_fn=collate_fn_test,
    pin_memory=True)

dataloaders = {'train': train_loader, 'valid': valid_loader}

# Model: when config['params'] has no 'dropout' entry, dropout=None is passed.
net = net_factory.load_model(config=config, num_classes=num_classes,
                             dropout=config['params'].get('dropout'))
net = net.to(device)

# Print the architecture and its total number of parameter elements
# (a 0-dim parameter counts as 1).
num_parameters = sum(param.numel() for param in net.parameters())
print(net)
print("num. of parameters : " + str(num_parameters))