Code example #1
def get_destructive_overthinking_samples(models_path, device='cpu'):
    sdn_name = 'tinyimagenet_vgg16bn_sdn_ic_only'
    sdn_model, sdn_params = arcs.load_model(models_path, sdn_name, epoch=-1)
    sdn_model.to(device)
    dataset = af.get_dataset(sdn_params['task'])
    output_path = 'only_first'
    af.create_path(output_path)
    

    layer_correct, layer_wrong, layer_predictions, _ = mf.sdn_get_detailed_results(sdn_model, loader=dataset.test_loader, device=device)
    layers = sorted(list(layer_correct.keys()))
    
    all_correct = set()

    for layer in layers[1:]:
        all_correct = all_correct | layer_correct[layer]
    
    only_first = layer_correct[layers[0]] - all_correct
    
    for instance_id in only_first:
        instance_path = dataset.testset_paths.imgs[instance_id][0]
        filename = os.path.basename(instance_path)
        print(instance_path)
        first_predict = layer_predictions[0][instance_id][0]
        last_predict = layer_predictions[layers[-1]][instance_id][0]
        first_predict = dataset.testset_paths.classes[first_predict]
        last_predict = dataset.testset_paths.classes[last_predict]

        filename = '{}_{}_{}'.format(first_predict, last_predict, filename)
        copyfile(instance_path, output_path+'/'+filename)
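These snippets rely on module-level aliases (arcs, af, mf) plus os and shutil.copyfile. A minimal invocation sketch, assuming the aliases map to the project's network_architectures, aux_funcs, and model_funcs modules (an assumption, not shown in the snippet):

import os
from shutil import copyfile

import network_architectures as arcs  # assumed source of the `arcs` alias
import aux_funcs as af                # assumed source of the `af` alias
import model_funcs as mf              # assumed source of the `mf` alias

if __name__ == '__main__':
    # 'networks' is a hypothetical models_path; pass 'cuda' if a GPU is available
    get_destructive_overthinking_samples('networks', device='cpu')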
Code example #2
def get_simple_complex(models_path, device='cpu'):
    sdn_name = 'tinyimagenet_vgg16bn_sdn_ic_only'
    sdn_model, sdn_params = arcs.load_model(models_path, sdn_name, epoch=-1)
    sdn_model.to(device)
    dataset = af.get_dataset(sdn_params['task'])
    output_path = 'simple_complex_images'
    af.create_path(output_path)
    dog_path = output_path+'/'+'dog'
    cat_path = output_path+'/'+'cat'
    af.create_path(dog_path)
    af.create_path(cat_path)

    # n02099601 dog 26
    # n02123394 cat 31
    
    layer_correct, layer_wrong, _, _ = mf.sdn_get_detailed_results(sdn_model, loader=dataset.test_loader, device=device)
    layers = sorted(list(layer_correct.keys()))
    
    wrong_until = layer_wrong[layers[0]] | layer_correct[layers[0]]
    for layer in layers[:-1]:
        instances = layer_correct[layer] & wrong_until            
        wrong_until = wrong_until - layer_correct[layer]
        print('IC: {}, Num images: {}'.format(layer, len(instances)))
        for instance_id in instances:
            instance_path = dataset.testset_paths.imgs[instance_id][0]
            filename = '{}_{}'.format(layer, os.path.basename(instance_path))
            if 'n02099601' in instance_path:
                copyfile(instance_path, dog_path+'/'+filename)
            if 'n02123394' in instance_path:
                copyfile(instance_path, cat_path+'/'+filename)
Code example #3
def train_model(models_path, cr_params, device, num=0):
    type, mode, pruning, ics = cr_params
    model, params = arcs.create_resnet_iterative(models_path, type, mode,
                                                 pruning, ics, False)
    dataset = af.get_dataset('cifar10')
    params['name'] = params['base_model'] + '_{}_{}'.format(type, mode)
    if model.prune:
        params['name'] += "_prune_{}".format(
            [x * 100 for x in model.keep_ratio])
        print("prune: {}".format(model.keep_ratio))
    if mode == "0":
        params['epochs'] = 250
        params['milestones'] = [120, 160, 180]
        params['gammas'] = [0.1, 0.01, 0.01]

    if mode == "1":
        params['epochs'] = 300
        params['milestones'] = [100, 150, 200]
        params['gammas'] = [0.1, 0.1, 0.1]

    if "full" in type:
        params['learning_rate'] = 0.1
    print("lr: {}".format(params['learning_rate']))

    opti_param = (params['learning_rate'], params['weight_decay'],
                  params['momentum'], -1)
    lr_schedule_params = (params['milestones'], params['gammas'])

    model.to(device)
    train_params = dict(
        epochs=params['epochs'],
        epoch_growth=[25, 50, 75],
        epoch_prune=[10, 35, 60, 85, 110, 135, 160],  #[10, 35, 60, 85],
        prune_batch_size=pruning[2],
        prune_type='2',  # 0 skip layer, 1 normal full, 2 iterative
        reinit=False,
        min_ratio=[0.3, 0.1, 0.05, 0.05]  # minimum keep ratios for the iterative pruning; not needed when skipping layers
    )

    params['epoch_growth'] = train_params['epoch_growth']
    params['epoch_prune'] = train_params['epoch_prune']
    optimizer, scheduler = af.get_full_optimizer(model, opti_param,
                                                 lr_schedule_params)
    metrics, best_model = model.train_func(model, dataset, train_params,
                                           optimizer, scheduler, device)
    _link_metrics(params, metrics)

    af.print_sparsity(best_model)

    arcs.save_model(best_model, params, models_path, params['name'], epoch=-1)
    print("test acc: {}, last val: {}".format(params['test_top1_acc'],
                                              params['valid_top1_acc'][-1]))
    return best_model, params
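A hypothetical invocation sketch for this example; the cr_params values below are made up purely to show the tuple's shape (the snippet itself only unpacks them and reads pruning[2] as the pruning batch size):

# All values are illustrative assumptions, not taken from the project.
cr_params = (
    'iter_full',       # type: variant string; "full" in it triggers the lr override above
    '0',               # mode: selects the 250-epoch schedule
    (True, 0.5, 128),  # pruning: only pruning[2] (the pruning batch size) is read here
    [1, 1, 1],         # ics: internal-classifier placement
)
best_model, params = train_model('networks', cr_params, device='cpu')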
Code example #4
def destructive_overthinking_experiment(models_path, device='cpu'):
    #sdn_name = 'cifar10_vgg16bn_bd_sdn_converted'; add_trigger = True  # for the backdoored network
    
    add_trigger = False
    
    #task = 'cifar10'
    #task = 'cifar100'
    task = 'tinyimagenet'
   
    network = 'vgg16bn'
    #network = 'resnet56'
    #network = 'wideresnet32_4'
    #network = 'mobilenet'
    
    sdn_name = task + '_' + network + '_sdn_ic_only'

    sdn_model, sdn_params = arcs.load_model(models_path, sdn_name, epoch=-1)
    sdn_model.to(device)
    dataset = af.get_dataset(sdn_params['task'], add_trigger=add_trigger)
    
    top1_test, top5_test = mf.sdn_test(sdn_model, dataset.test_loader, device)
    print('Top1 Test accuracy: {}'.format(top1_test))
    print('Top5 Test accuracy: {}'.format(top5_test))

    layer_correct, layer_wrong, _, layer_confidence = mf.sdn_get_detailed_results(sdn_model, loader=dataset.test_loader, device=device)
    layers = sorted(list(layer_correct.keys()))

    end_wrong = layer_wrong[layers[-1]]
    cum_correct = set()

    for layer in layers:
        cur_correct = layer_correct[layer]
        cum_correct = cum_correct | cur_correct
        cur_overthinking = cur_correct & end_wrong
        
        print('Output: {}'.format(layer))
        print('Current correct: {}'.format(len(cur_correct)))
        print('Cumulative correct: {}'.format(len(cum_correct)))
        print('Current destructive overthinking: {}\n'.format(len(cur_overthinking)))

        total_confidence = 0.0
        for instance in cur_overthinking:
            total_confidence += layer_confidence[layer][instance]

        # the 0.1 in the denominator avoids division by zero when the set is empty
        print('Average confidence on destructive overthinking instances: {}'.format(total_confidence/(0.1 + len(cur_overthinking))))

        total_confidence = 0.0
        for instance in cur_correct:
            total_confidence += layer_confidence[layer][instance]

        print('Average confidence on correctly classified instances: {}'.format(total_confidence/(0.1 + len(cur_correct))))
Code example #5
def wasteful_overthinking_experiment(models_path, device='cpu'):
    #task = 'cifar10'
    #task = 'cifar100'
    task = 'tinyimagenet'
    
    network = 'vgg16bn'
    #network = 'resnet56'
    #network = 'wideresnet32_4'
    #network = 'mobilenet'
    
    sdn_name = task + '_' + network + '_sdn_ic_only'
    
    sdn_model, sdn_params = arcs.load_model(models_path, sdn_name, epoch=-1)
    sdn_model.to(device)
    dataset = af.get_dataset(sdn_params['task'])

    top1_test, top5_test = mf.sdn_test(sdn_model, dataset.test_loader, device)
    print('Top1 Test accuracy: {}'.format(top1_test))
    print('Top5 Test accuracy: {}'.format(top5_test))

    layer_correct, _, _, _ = mf.sdn_get_detailed_results(sdn_model, loader=dataset.test_loader, device=device)
    layers = sorted(list(layer_correct.keys()))

    end_correct = layer_correct[layers[-1]]
    total = 10000

    # to quantify the computational waste
    c_i = [0.15, 0.3, 0.45, 0.6, 0.75, 0.9]
    total_comp = 0

    cum_correct = set()
    for layer in layers:
        cur_correct = layer_correct[layer]
        unique_correct = cur_correct - cum_correct
        cum_correct = cum_correct | cur_correct

        print('Output: {}'.format(layer))
        print('Current correct: {}'.format(len(cur_correct)))
        print('Cumulative correct: {}'.format(len(cum_correct)))
        print('Unique correct: {}\n'.format(len(unique_correct)))

        if layer < layers[-1]:
            total_comp += len(unique_correct) * c_i[layer]
        else:
            # every instance not already handled by an earlier IC pays the full cost
            total_comp += total - (len(cum_correct) - len(unique_correct))
        
    print('Total Comp: {}'.format(total_comp))
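A hypothetical worked example of the waste metric above (all counts are made up; only the formula mirrors the code):

c_i = [0.15, 0.3, 0.45, 0.6, 0.75, 0.9]             # relative cost of reaching each IC
total = 10000                                        # test-set size
unique_correct = [3000, 1200, 700, 400, 250, 150]    # hypothetical newly-correct counts per IC

# early exits are charged the fractional cost of the IC that handles them...
total_comp = sum(n * c for n, c in zip(unique_correct, c_i))
# ...and every instance not handled by an earlier IC pays the full cost of the network
total_comp += total - sum(unique_correct)
print('Total Comp: {}'.format(total_comp))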
Code example #6
def train_model(models_path, device):
    _, sdn = arcs.create_resnet56(models_path, 'cifar10', save_type='d')
    print('sdn name: {}'.format(sdn))
    # train_sdn(models_path, sdn, device)
    print("Training model...")
    trained_model, model_params = arcs.load_model(models_path, sdn, 0)
    dataset = af.get_dataset(model_params['task'])
    lr = model_params['learning_rate']
    momentum = model_params['momentum']
    weight_decay = model_params['weight_decay']
    milestones = model_params['milestones']
    gammas = model_params['gammas']
    num_epochs = model_params['epochs']

    model_params['optimizer'] = 'SGD'

    opti_param = (lr, weight_decay, momentum, -1)
    lr_schedule_params = (milestones, gammas)

    optimizer, scheduler = af.get_full_optimizer(trained_model, opti_param,
                                                 lr_schedule_params)
    trained_model_name = sdn + '_training'

    print('Training: {}...'.format(trained_model_name))
    trained_model.to(device)
    metrics = trained_model.train_func(trained_model,
                                       dataset,
                                       num_epochs,
                                       optimizer,
                                       scheduler,
                                       device=device)
    model_params['train_top1_acc'] = metrics['train_top1_acc']
    model_params['test_top1_acc'] = metrics['test_top1_acc']
    model_params['train_top3_acc'] = metrics['train_top3_acc']
    model_params['test_top3_acc'] = metrics['test_top3_acc']
    model_params['epoch_times'] = metrics['epoch_times']
    model_params['lrs'] = metrics['lrs']
    total_training_time = sum(model_params['epoch_times'])
    model_params['total_time'] = total_training_time
    print('Training took {} seconds...'.format(total_training_time))
    arcs.save_model(trained_model,
                    model_params,
                    models_path,
                    trained_model_name,
                    epoch=-1)
    return trained_model, dataset
Code example #7
def check_performance(trained_models_path, device):
    add_trigger = False

    task = 'tinyimagenet'

    network = 'vgg16bn'

    sdn_name = task + '_' + network + '_sdn_sdn_training_ds'

    sdn_model, sdn_params = arcs.load_model(trained_models_path,
                                            sdn_name,
                                            epoch=-1)
    sdn_model.to(device)
    dataset = af.get_dataset(sdn_params['task'], add_trigger=add_trigger)

    policy_net_all = []
    for i in range(sum(sdn_params['add_ic'])):
        policy_net = PolicyNet('tiny_imagenet', 200).to(device)
        policy_net.load_state_dict(torch.load('./policy_%d.dump' % (i + 1)))
        policy_net_all.append(policy_net)

    predictions = list()
    stops = list()
    for i, batch in enumerate(dataset.test_loader):
        val_x, val_y = batch
        val_x = val_x.to(device)
        val_y = val_y.to(device)
        with torch.no_grad():
            xhs = sdn_model(val_x.to(device))
        predictions.append(torch.stack(xhs))

        policy_pred = list()
        for t in range(sum(sdn_params['add_ic'])):
            policy_tmp = policy_net_all[t](val_x, xhs[t])
            policy_pred.append(policy_tmp)
        policy_pred = torch.cat(policy_pred, axis=-1)
        stops.append(policy_pred)

    stops = torch.cat(stops, axis=0)
    predictions = torch.cat(predictions, axis=1)
Code example #8
def early_exit_experiments(models_path, device='cpu'):
    sdn_training_type = 'ic_only' # IC-only training
    #sdn_training_type = 'sdn_training' # SDN training


    # task = 'cifar10'
    # task = 'cifar100'
    task = 'tinyimagenet'

    #sdn_names = ['vgg16bn_sdn', 'resnet56_sdn', 'wideresnet32_4_sdn', 'mobilenet_sdn']; add_trigger = False
    
    sdn_names = ['vgg16bn_sdn']; add_trigger = False
    sdn_names = [task + '_' + sdn_name + '_' + sdn_training_type for sdn_name in sdn_names]

    
    for sdn_name in sdn_names:
        cnn_name = sdn_name.replace('sdn', 'cnn')
        cnn_name = cnn_name.replace('_ic_only', '')
        cnn_name = cnn_name.replace('_sdn_training', '')

        print(sdn_name)
        print(cnn_name)

        sdn_model, sdn_params = arcs.load_model(models_path, sdn_name, epoch=-1)
        sdn_model.to(device)

        dataset = af.get_dataset(sdn_params['task'])

        cnn_model, _ = arcs.load_model(models_path, cnn_name, epoch=-1)
        cnn_model.to(device)

        print('Get CNN results')
        top1_test, top5_test, total_time = mf.cnn_test_time(cnn_model, dataset.test_loader, device)
        total_ops, total_params = profile(cnn_model, cnn_model.input_size, device)
        print("#Ops: %f GOps"%(total_ops/1e9))
        print("#Parameters: %f M"%(total_params/1e6))
        print('Top1 Test accuracy: {}'.format(top1_test))
        #print('Top5 Test accuracy: {}'.format(top5_test))

        print('25 percent cost: {}'.format((total_ops/1e9)*0.25))
        print('50 percent cost: {}'.format((total_ops/1e9)*0.5))
        print('75 percent cost: {}'.format((total_ops/1e9)*0.75))


        # to test early-exits with the SDN
        one_batch_dataset = af.get_dataset(sdn_params['task'], 1)
        print('Get SDN early exit results')
        total_ops, total_params = profile_sdn(sdn_model, sdn_model.input_size, device)
        print("#Ops (GOps): {}".format(total_ops))
        print("#Params (mil): {}".format(total_params))
        top1_test, top5_test = mf.sdn_test(sdn_model, dataset.test_loader, device)
        print('Top1 Test accuracy: {}'.format(top1_test))
        #print('Top5 Test accuracy: {}'.format(top5_test))


        print('Calibrate confidence_thresholds')
        confidence_thresholds = [0.1, 0.15, 0.25, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95, 0.99, 0.999] # search for the confidence threshold for early exits
        sdn_model.forward = sdn_model.early_exit
        
        
        for threshold in confidence_thresholds:
            print(threshold)
            sdn_model.confidence_threshold = threshold

            # change the forward func for sdn to forward with cascade
            top1_test, top5_test, early_exit_counts, non_conf_exit_counts, total_time = mf.sdn_test_early_exits(sdn_model, one_batch_dataset.test_loader, device)
            
            average_mult_ops = 0
            total_num_instances = 0

            for output_id, output_count in enumerate(early_exit_counts):
                average_mult_ops += output_count*total_ops[output_id]
                total_num_instances += output_count

            for output_count in non_conf_exit_counts:
                total_num_instances += output_count
                # output_id still holds the last output's index from the loop above,
                # so non-confident instances are charged the cost of the full network
                average_mult_ops += output_count*total_ops[output_id]
            
            average_mult_ops /= total_num_instances

            print('Early exit Counts:')
            print(early_exit_counts)

            print('Non confident exit counts:')
            print(non_conf_exit_counts)

            print('Top1 Test accuracy: {}'.format(top1_test))
            print('Top5 Test accuracy: {}'.format(top5_test))
            print('SDN cascading took {} seconds.'.format(total_time))
            print('Average Mult-Ops: {}'.format(average_mult_ops))
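A hypothetical illustration of the average-cost bookkeeping in the threshold loop (all numbers are made up): confident exits are charged the ops of the output they exit at, and instances that never reach the threshold are charged the cost of the final output.

total_ops = {0: 0.05, 1: 0.1, 2: 0.16, 3: 0.22, 4: 0.27, 5: 0.3, 6: 0.31}  # GOps per output (made up)
early_exit_counts = [1500, 2500, 2000, 1500, 1000, 400, 100]               # confident exits per output
non_conf_exits = 1000                                                      # never confident; exit at the end

average_mult_ops = sum(c * total_ops[i] for i, c in enumerate(early_exit_counts))
average_mult_ops += non_conf_exits * total_ops[len(early_exit_counts) - 1]
average_mult_ops /= (sum(early_exit_counts) + non_conf_exits)
print('Average Mult-Ops: {}'.format(average_mult_ops))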
Code example #9
File: train_networks.py Project: xinshi-chen/l2stop
def train(models_path,
          untrained_models,
          sdn=False,
          ic_only_sdn=False,
          device='cpu',
          ds=False):
    print('Training models...')

    for base_model in untrained_models:
        trained_model, model_params = arcs.load_model(models_path, base_model,
                                                      0)
        dataset = af.get_dataset(model_params['task'])

        learning_rate = model_params['learning_rate']
        momentum = model_params['momentum']
        weight_decay = model_params['weight_decay']
        milestones = model_params['milestones']
        gammas = model_params['gammas']
        num_epochs = model_params['epochs']

        model_params['optimizer'] = 'SGD'

        if ic_only_sdn:  # IC-only training, freeze the original weights
            learning_rate = model_params['ic_only']['learning_rate']
            num_epochs = model_params['ic_only']['epochs']
            milestones = model_params['ic_only']['milestones']
            gammas = model_params['ic_only']['gammas']

            model_params['optimizer'] = 'Adam'

            trained_model.ic_only = True
        else:
            trained_model.ic_only = False

        if ds:
            trained_model.ds = True
        else:
            trained_model.ds = False

        optimization_params = (learning_rate, weight_decay, momentum)
        lr_schedule_params = (milestones, gammas)

        # pdb.set_trace()

        if sdn:
            if ic_only_sdn:
                optimizer, scheduler = af.get_sdn_ic_only_optimizer(
                    trained_model, optimization_params, lr_schedule_params)
                trained_model_name = base_model + '_ic_only_ic{}'.format(
                    np.sum(model_params['add_ic']))

            else:
                optimizer, scheduler = af.get_full_optimizer(
                    trained_model, optimization_params, lr_schedule_params)
                trained_model_name = base_model + '_sdn_training_ic{}'.format(
                    np.sum(model_params['add_ic']))

        else:
            optimizer, scheduler = af.get_full_optimizer(
                trained_model, optimization_params, lr_schedule_params)
            trained_model_name = base_model

        if ds:
            trained_model_name = trained_model_name + '_ds'
        # pdb.set_trace()
        print('Training: {}...'.format(trained_model_name))
        # trained_model = nn.DataParallel(trained_model)
        trained_model.to(device)
        metrics = trained_model.train_func(trained_model,
                                           dataset,
                                           num_epochs,
                                           optimizer,
                                           scheduler,
                                           device=device)
        model_params['train_top1_acc'] = metrics['train_top1_acc']
        model_params['test_top1_acc'] = metrics['test_top1_acc']
        model_params['train_top5_acc'] = metrics['train_top5_acc']
        model_params['test_top5_acc'] = metrics['test_top5_acc']
        model_params['epoch_times'] = metrics['epoch_times']
        model_params['lrs'] = metrics['lrs']
        total_training_time = sum(model_params['epoch_times'])
        model_params['total_time'] = total_training_time
        print('Training took {} seconds...'.format(total_training_time))
        arcs.save_model(trained_model,
                        model_params,
                        models_path,
                        trained_model_name,
                        epoch=-1)
Code example #10
def iter_training_0(model, data, params, optimizer, scheduler, device='cpu'):
    print("iter training 0")
    augment = model.augment_training
    metrics = {
        'epoch_times': [],
        'valid_top1_acc': [],
        'valid_top3_acc': [],
        'train_top1_acc': [],
        'train_top3_acc': [],
        'test_top1_acc': [],
        'test_top3_acc': [],
        'lrs': []
    }
    epochs = params['epochs']
    epoch_growth = params['epoch_growth']
    epoch_prune = params['epoch_prune']
    pruning_batch_size = params['prune_batch_size']
    pruning_type = params['prune_type']
    reinit = params['reinit']

    #epoch_growth = [25, 50, 75]  # [(i + 1) * epochs / (model.num_ics + 1) for i in range(model.num_ics)]
    print("array params: num_ics {}, epochs {}".format(model.num_ics, epochs))
    print("epochs growth: {}".format(epoch_growth))
    print("epochs prune: {}".format(epoch_prune))

    max_coeffs = calc_coeff(model)
    print('max_coeffs: {}'.format(max_coeffs))
    model.to(device)
    model.to_train()

    if model.prune:
        prune_dataset = af.get_dataset('cifar10',
                                       batch_size=pruning_batch_size)
        print("pruning_batch_size: {}, prune_type: {}, reinit: {}".format(
            pruning_batch_size, pruning_type, reinit))
        print("min_ratio: {}".format(params['min_ratio']))
        print("keep_ratio: {}".format(model.keep_ratio))

    best_model, accuracies, best_epoch = None, None, 0
    masks = []
    mask1 = None
    block_to_prune = 0
    for epoch in range(1, epochs + 1):
        print('\nEpoch: {}/{}'.format(epoch, epochs))
        if epoch in epoch_growth:
            grown_layers = model.grow()
            model.to(device)
            optimizer.add_param_group({'params': grown_layers})
            print("model grow")

        if epoch in epoch_prune and model.prune:
            loader = get_loader(prune_dataset, False)
            if pruning_type == '0':
                mask1 = prune_skip_layer(model, model.keep_ratio, loader,
                                         sdn_loss, block_to_prune, mask1,
                                         device, reinit)
                block_to_prune += 1
            elif pruning_type == '1':
                prune2(model, model.keep_ratio, loader, sdn_loss, device)
            elif pruning_type == "2":
                steps = []
                _epoch_growth = [1] + epoch_growth
                for i in range(len(_epoch_growth)):
                    if epoch < _epoch_growth[i]:
                        steps.append(0)
                    else:
                        v = [
                            0 if e < _epoch_growth[i] or e > epoch else 1
                            for e in epoch_prune
                        ]
                        steps.append(sum(v) - 1)
                mask = prune_iterative(model, model.keep_ratio,
                                       params['min_ratio'], steps, loader,
                                       sdn_loss, device, reinit)
                masks.append(mask)
                mask1 = mask

        epoch_routine(model, data, optimizer, scheduler, epoch, epochs,
                      augment, metrics, device)

        if model.num_output == model.num_ics + 1:
            if model.prune and epoch >= epoch_prune[-1]:
                print("pruning for best_model")
                best_model, accuracies, best_epoch = best_model_def(
                    best_model, model, accuracies, best_epoch, metrics, epoch)
            elif not model.prune:
                best_model, accuracies, best_epoch = best_model_def(
                    best_model, model, accuracies, best_epoch, metrics, epoch)

        af.print_sparsity(model)

    metrics['test_top1_acc'], metrics['test_top3_acc'] = sdn_test(
        best_model, data.test_loader, device)
    test_top1, _ = sdn_test(model, data.test_loader, device)
    metrics['best_model_epoch'] = best_epoch
    metrics['masks'] = masks
    print("best epoch: {}".format(best_epoch))
    print("comparison best and latest: {}/{}".format(metrics['test_top1_acc'],
                                                     test_top1))
    return metrics, best_model
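The pruning_type == "2" branch builds a steps list recording how many iterative pruning rounds each grown block has already gone through. A minimal standalone sketch of that computation, using the epoch settings from code example #3 (purely illustrative):

epoch_growth = [25, 50, 75]
epoch_prune = [10, 35, 60, 85, 110, 135, 160]
epoch = 60  # a pruning epoch after the second growth

steps = []
_epoch_growth = [1] + epoch_growth
for i in range(len(_epoch_growth)):
    if epoch < _epoch_growth[i]:
        steps.append(0)  # this block has not been grown yet
    else:
        # count the pruning epochs between this block's growth epoch and now
        v = [0 if e < _epoch_growth[i] or e > epoch else 1 for e in epoch_prune]
        steps.append(sum(v) - 1)

print(steps)  # [2, 1, 0, 0]: earlier blocks have already been pruned more times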
Code example #11
def policy_training(models_path, device='cpu'):
    #sdn_name = 'cifar10_vgg16bn_bd_sdn_converted'; add_trigger = True  # for the backdoored network

    add_trigger = False

    #task = 'cifar10'
    # task = 'cifar100'
    task = 'tinyimagenet'

    network = 'vgg16bn'
    #network = 'resnet56'
    #network = 'wideresnet32_4'
    #network = 'mobilenet'

    # sdn_name = task + '_' + network + '_sdn_ic_only'
    sdn_name = task + '_' + network + '_sdn_ic_only_ic1'
    # sdn_name = task + '_' + network + '_sdn_ic_only_ic1_ds'
    # sdn_name = task + '_' + network + '_sdn_sdn_training'
    # sdn_name = task + '_' + network + '_sdn_sdn_training_ds'
    # sdn_name = task + '_' + network + '_sdn_sdn_training_ic14_ds'

    sdn_model, sdn_params = arcs.load_model(models_path, sdn_name, epoch=-1)
    sdn_model.to(device)
    dataset = af.get_dataset(sdn_params['task'], add_trigger=add_trigger)

    # need to construct the policy network and train the policy net.
    # the architecture of the policy network need to be designed.

    ######################################
    # need to think about the model of policynet
    ######################################
    sdn_model.eval()
    p_true_all = list()
    xhs_all = list()
    y_all = list()
    for batch in dataset.val_loader:
        x, y = batch
        x = x.to(device)
        y = y.to(device)
        batch_size = y.shape[0]
        with torch.no_grad():
            xhs = sdn_model(x)
            categories = xhs[-1].shape[-1]
            # pdb.set_trace()
            # internal_fm = sdn_model.internal_fm
            # sdn_model.internal_fm = [None]*len(internal_fm)
            p_true, _ = PolicyKL.true_posterior(cmd_args, xhs, y)

        xhs_all.append(xhs)
        y_all.append(y)
        p_true_all.append(p_true)

    p_true = torch.cat(p_true_all, dim=0)
    p_det = max_onehot(p_true, dim=-1, device=device)
    p_true = torch.mean(p_true, dim=0)
    # find positions with nonzero posterior
    train_post = {}
    nz_post = {}
    i = 0
    for t in range(cmd_args.num_output):
        if p_true[t] > 0.001:
            train_post[i] = t
            nz_post[i] = t
            i += 1
    del train_post[i - 1]

    p_str = 'val p true:['
    p_str += ','.join(['%0.3f' % p_true[t] for t in nz_post.values()])
    print(p_str + ']')

    p_det = torch.mean(p_det, dim=0)
    p_str = 'val p true det:['
    p_str += ','.join(['%0.3f' % p_det[t] for t in nz_post.values()])
    print(p_str + ']')
    ######################################

    ####
    #check the performance based on confidence score
    ####
    y_all = torch.cat(y_all, dim=-1)
    xhs_all = list(zip(*xhs_all))
    for i in range(len(xhs_all)):
        xhs_all[i] = torch.cat(xhs_all[i], dim=0)
        print('The {}th classifier performance:'.format(i))
        prec1, prec5 = data.accuracy(xhs_all[i], y_all, topk=(1, 5))
        print('Top1 Test accuracy: {}'.format(prec1))
        print('Top5 Test accuracy: {}'.format(prec5))

    xhs_all = list(map(lambda x: F.softmax(x, dim=-1), xhs_all))
    max_confidences = list(map(lambda x: torch.max(x, dim=-1)[0], xhs_all))
    max_confidences = torch.stack(max_confidences, dim=-1)
    xhs_all_stack = torch.stack(xhs_all, dim=1)
    predictions = list(map(lambda x: torch.argmax(x, dim=-1), xhs_all))
    predictions = torch.stack(predictions, dim=-1)

    thresholds = [
        0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95, 0.99, 0.999,
        -1
    ]
    # thresholds = [0.8, 0.9, 0.95, 0.99, 0.999, -1]
    for threshold in thresholds:
        if threshold == -1:
            index = torch.argmax(max_confidences, dim=-1).cpu().numpy()
            # pdb.set_trace()
        else:
            mask = (max_confidences > threshold).to(int).cpu().numpy()
            mask[:, -1] = 1
            index = np.array(list(map(lambda x: list(x).index(1), list(mask))))
        # 200 is the number of TinyImageNet classes
        results = xhs_all_stack.gather(
            1,
            torch.Tensor([index] * 200).t().view(
                -1, 1, 200).long().to(device)).squeeze()

        prec1, prec5 = data.accuracy(results, y_all, topk=(1, 5))
        print('threshold: ', threshold)
        print('Top1 Test accuracy: {}'.format(prec1))
        print('Top5 Test accuracy: {}'.format(prec5))
    ####
    #confidence score check finish
    ####

    # pdb.set_trace()
    internal_fm = [torch.rand(2, 2) for i in range(cmd_args.num_output)]
    # initialize nets with nonzero posterior
    if cmd_args.model_type == 'sequential':
        score_net = MNIconfidence(cmd_args,
                                  x,
                                  internal_fm,
                                  train_post,
                                  category=categories,
                                  share=cmd_args.share,
                                  net_size=cmd_args.net_size)
        score_net.to(device)
        # print('Sequential model to be implemented')
    if cmd_args.model_type == 'multiclass':
        score_net = MulticlassNetImage(cmd_args,
                                       x,
                                       internal_fm,
                                       train_post,
                                       category=categories)
        score_net.to(device)
    if cmd_args.model_type == 'confidence':
        score_net = MNIconfidence(cmd_args,
                                  x,
                                  internal_fm,
                                  train_post,
                                  category=categories,
                                  share=cmd_args.share,
                                  net_size=cmd_args.net_size)
        score_net.to(device)
    if cmd_args.model_type == 'imiconfidence':
        score_net = Imiconfidence(cmd_args,
                                  x,
                                  internal_fm,
                                  train_post,
                                  category=categories,
                                  share=cmd_args.share,
                                  net_size=cmd_args.net_size)
        score_net.to(device)

    # train
    if cmd_args.phase == 'train':

        # start training
        optimizer = optim.Adam(list(score_net.parameters()),
                               lr=cmd_args.learning_rate,
                               weight_decay=cmd_args.weight_decay)
        milestones = [10, 20, 40, 60, 80]
        gammas = [0.4, 0.2, 0.2, 0.2, 0.2]
        scheduler = MultiStepMultiLR(optimizer,
                                     milestones=milestones,
                                     gammas=gammas)
        trainer = PolicyKL(args=cmd_args,
                           sdn_model=sdn_model,
                           score_net=score_net,
                           train_post=train_post,
                           nz_post=nz_post,
                           optimizer=optimizer,
                           data_loader=dataset,
                           device=device,
                           scheduler=scheduler,
                           sdn_name=sdn_name)
        trainer.train()
        #pdb.set_trace()
    # test
    dump = cmd_args.save_dir + '/{}_best_val_policy.dump'.format(sdn_name)
    print('Loading model...')
    score_net.load_state_dict(torch.load(dump))

    PolicyKL.test(args=cmd_args,
                  score_net=score_net,
                  sdn_model=sdn_model,
                  data_loader=dataset.test_loader,
                  nz_post=nz_post,
                  device=device)
    print(cmd_args.save_dir)