Exemplo n.º 1
0
def eval(tree: ProtoTree,
         test_loader: DataLoader,
         epoch,
         device,
         log: Log = None,
         sampling_strategy: str = 'distributed',
         log_prefix: str = 'log_eval_epochs',
         progress_prefix: str = 'Eval Epoch') -> dict:
    """Evaluate ``tree`` on ``test_loader`` and collect accuracy statistics.

    NOTE: the function name shadows the ``eval`` builtin; it is kept because
    callers refer to it by this name.

    Args:
        tree: Model to evaluate. It is moved to ``device`` and switched to
            evaluation mode.
        test_loader: Iterable yielding ``(xs, ys)`` batches.
        epoch: Label shown in the progress bar and in the log message
            (callers in this file pass either an epoch number or a name).
        device: Device the model and every batch are moved to.
        log: Optional logger. When ``None`` the summary message is skipped.
        sampling_strategy: Routing strategy forwarded to ``tree.forward``.
        log_prefix: Unused; kept for interface compatibility.
        progress_prefix: Prefix of the progress-bar description.

    Returns:
        A dict with ``'confusion_matrix'`` (rows indexed by true label,
        columns by predicted label) and ``'test_accuracy'``. For any strategy
        other than ``'distributed'`` it also holds ``'out_leaf_ix'``, the
        concatenation of the per-batch ``out_leaf_ix`` lists reported by the
        model.
    """
    tree = tree.to(device)

    # Keep an info dict about the procedure
    info = dict()
    collect_leaves = sampling_strategy != 'distributed'
    if collect_leaves:
        info['out_leaf_ix'] = []

    # Build a confusion matrix
    num_classes = tree._num_classes
    cm = np.zeros((num_classes, num_classes), dtype=int)

    # Make sure the model is in evaluation mode
    tree.eval()

    # Show progress on progress bar
    test_iter = tqdm(enumerate(test_loader),
                     total=len(test_loader),
                     desc=progress_prefix + ' %s' % epoch,
                     ncols=0)

    # Iterate through the test set
    for i, (xs, ys) in test_iter:
        xs, ys = xs.to(device), ys.to(device)

        # Evaluation never backpropagates, so skip building the autograd graph.
        with torch.no_grad():
            out, test_info = tree.forward(xs, sampling_strategy)
            ys_pred = torch.argmax(out, dim=1)

        # Update the global and the per-batch confusion matrix. Converting to
        # Python ints once per batch avoids indexing NumPy with tensors.
        cm_batch = np.zeros((num_classes, num_classes), dtype=int)
        for y_pred, y_true in zip(ys_pred.tolist(), ys.tolist()):
            cm[y_true][y_pred] += 1
            cm_batch[y_true][y_pred] += 1
        acc = acc_from_cm(cm_batch)
        test_iter.set_postfix_str(
            f'Batch [{i + 1}/{len(test_iter)}], Acc: {acc:.3f}')

        # Keep the leaf indices where test samples end up when deterministic
        # routing is used.
        if collect_leaves:
            info['out_leaf_ix'] += test_info['out_leaf_ix']
        del out, ys_pred, test_info

    info['confusion_matrix'] = cm
    info['test_accuracy'] = acc_from_cm(cm)
    # ``log`` defaults to None, so only report when a logger was supplied.
    if log is not None:
        log.log_message("\nEpoch %s - Test accuracy with %s routing: " %
                        (epoch, sampling_strategy) + str(info['test_accuracy']))
    return info
Exemplo n.º 2
0
def init_tree(tree: ProtoTree, optimizer, scheduler, device, args: argparse.Namespace):
    """Initialise ``tree`` from a checkpoint, a pretrained backbone, or from scratch.

    Three mutually exclusive paths, selected by ``args``:

    * ``args.state_dict_dir_tree`` set: load a complete tree and the optimizer
      state from that directory and resume at the epoch encoded in the
      directory name (``..._epoch_<n>``).
    * ``args.state_dict_dir_net`` set: draw the prototypes from a normal
      distribution and load pretrained weights into ``tree._net`` and
      ``tree._add_on`` (non-strict, so unmatched keys are ignored).
    * otherwise: draw the prototypes from a normal distribution and apply
      ``init_weights_xavier`` to ``tree._add_on``.

    Args:
        tree: Tree to initialise; replaced by the loaded tree on the first path.
        optimizer: Optimizer whose state is restored on the first path.
        scheduler: LR scheduler; its counters are advanced on the first path
            when a ``scheduler_state.pth`` file exists.
        device: Target device. On the first path it is re-derived from
            ``args.disable_cuda`` and CUDA availability.
        args: Parsed run arguments.

    Returns:
        ``(tree, epoch)`` where ``epoch`` is the epoch training starts from.
    """
    epoch = 1
    mean = 0.5
    std = 0.1

    # NOTE: TRAINING FURTHER FROM A CHECKPOINT DOESN'T SEEM TO WORK CORRECTLY.
    # EVALUATING A TRAINED PROTOTREE FROM A CHECKPOINT DOES WORK.
    if args.state_dict_dir_tree != '':
        # load trained prototree if flag is set
        if not args.disable_cuda and torch.cuda.is_available():
            device = torch.device('cuda:{}'.format(torch.cuda.current_device()))
        else:
            device = torch.device('cpu')

        if device.type == 'cpu':
            # Remap storages so a checkpoint saved on GPU loads on CPU.
            tree = torch.load(args.state_dict_dir_tree+'/model.pth', map_location=device)
        else:
            tree = torch.load(args.state_dict_dir_tree+'/model.pth')
        tree.to(device=device)

        # The checkpoint directory is expected to end in 'epoch_<n>'. Anything
        # else (e.g. 'latest') is treated as a finished run, so the caller
        # only evaluates.
        try:
            epoch = int(args.state_dict_dir_tree.split('epoch_')[-1]) + 1
        except ValueError:
            epoch = args.epochs + 1
        print("Train further from epoch: ", epoch, flush=True)
        optimizer.load_state_dict(torch.load(args.state_dict_dir_tree+'/optimizer_state.pth', map_location=device))

        if epoch > args.freeze_epochs:
            for parameter in tree._net.parameters():
                parameter.requires_grad = True
        if not args.disable_derivative_free_leaf_optim:
            for leaf in tree.leaves:
                leaf._dist_params.requires_grad = False

        if os.path.isfile(args.state_dict_dir_tree+'/scheduler_state.pth'):
            # The saved scheduler state is not loaded; only the scheduler's
            # epoch counters are advanced to the resume point.
            scheduler.last_epoch = epoch - 1
            scheduler._step_count = epoch

    elif args.state_dict_dir_net != '':  # load pretrained conv network
        # initialize prototypes
        torch.nn.init.normal_(tree.prototype_layer.prototype_vectors, mean=mean, std=std)
        # Read the checkpoint once, mapped onto the target device so that a
        # GPU-saved file can be loaded on a CPU-only machine.
        pretrained_state = torch.load(args.state_dict_dir_net+'/model_state.pth', map_location=device)
        # strict is False so when loading pretrained model, ignore the linear classification layer
        tree._net.load_state_dict(pretrained_state, strict=False)
        tree._add_on.load_state_dict(pretrained_state, strict=False)
    else:
        with torch.no_grad():
            # initialize prototypes
            torch.nn.init.normal_(tree.prototype_layer.prototype_vectors, mean=mean, std=std)
            tree._add_on.apply(init_weights_xavier)
    return tree, epoch
Exemplo n.º 3
0
def analyse_output_shape(tree: ProtoTree, trainloader: DataLoader, log: Log,
                         device):
    """Log the tensor shapes seen at each stage of the tree for one batch.

    One batch is drawn from ``trainloader`` and pushed, without gradient
    tracking, through ``tree._net`` and ``tree._add_on``. The shapes of a
    single input image, of both intermediate outputs and of the prototype
    vectors are written to ``log``.
    """
    with torch.no_grad():
        # Get a batch of training data
        batch_inputs, batch_labels = next(iter(trainloader))
        batch_inputs = batch_inputs.to(device)
        batch_labels = batch_labels.to(device)

        single_image_shape = batch_inputs[0, :, :, :].shape
        log.log_message(f"Image input shape: {single_image_shape}")

        backbone_shape = tree._net(batch_inputs).shape
        log.log_message(
            f"Features output shape (without 1x1 conv layer): {backbone_shape}")

        add_on_shape = tree._add_on(tree._net(batch_inputs)).shape
        log.log_message(
            f"Convolutional output shape (with 1x1 conv layer): {add_on_shape}")

        prototype_shape = tree.prototype_layer.prototype_vectors.shape
        log.log_message(f"Prototypes shape: {prototype_shape}")
Exemplo n.º 4
0
def explain_local(args):
    """Produce a local explanation for the single image at ``args.sample_dir``.

    Loads the tree stored at ``args.prototree``, preprocesses the image
    (resize to ``args.image_size``, tensor conversion, normalisation) and
    hands everything to ``gen_pred_vis``, which writes the visualisation
    under ``args.results_dir``.

    Side effect: ``args.batch_size`` and ``args.augment`` are overwritten
    with placeholder values before the data loaders are requested.
    """
    cuda_requested = not args.disable_cuda
    if cuda_requested and torch.cuda.is_available():
        device = torch.device(f'cuda:{torch.cuda.current_device()}')
    else:
        device = torch.device('cpu')

    # Log which device was actually used
    print('Device used: ', str(device))

    # Load trained ProtoTree
    tree = ProtoTree.load(args.prototree).to(device=device)

    # Only the class names are needed from the data loaders; the two values
    # below are placeholders required by get_dataloaders.
    args.batch_size = 64
    args.augment = True
    _train, _project, _test, classes, _channels = get_dataloaders(args)

    # Same preprocessing as a non-augmented test image.
    preprocessing_steps = [
        transforms.Resize(size=(args.image_size, args.image_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.485, 0.456, 0.406),
                             std=(0.229, 0.224, 0.225)),
    ]
    test_transform = transforms.Compose(preprocessing_steps)

    image = Image.open(args.sample_dir)
    # Add a leading batch dimension before moving the sample to the device.
    sample = test_transform(image).unsqueeze(0).to(device)

    gen_pred_vis(tree, sample, args.sample_dir, args.results_dir, args,
                 classes)
0
def train_leaves_epoch(tree: ProtoTree,
                        train_loader: DataLoader,
                        epoch: int,
                        device,
                        progress_prefix: str = 'Train Leafs Epoch'
                        ) -> None:
    """Re-estimate every leaf's distribution parameters without gradients.

    The whole data set is passed through the tree once in eval mode. For each
    leaf, a per-class update is accumulated over all batches from the path
    tensor reported by the model (``info['pa_tensor']``), the current leaf
    distribution, the one-hot targets and the model output. After the pass
    each leaf's ``_dist_params`` is overwritten with its accumulated sum.

    Args:
        tree: Tree whose leaves are updated in place.
        train_loader: Iterable yielding ``(xs, ys)`` batches.
        epoch: Epoch number, shown in the progress bar only.
        device: Device the batches are moved to.
        progress_prefix: Prefix of the progress-bar description.

    Returns:
        None. The leaves of ``tree`` are modified in place.
    """
    # Make sure the tree is in eval mode for updating leafs
    tree.eval()

    with torch.no_grad():
        # One-hot lookup table: eye[ys] has shape (batchsize, num_classes)
        eye = torch.eye(tree._num_classes).to(device)

        # Show progress on progress bar
        train_iter = tqdm(enumerate(train_loader),
                          total=len(train_loader),
                          desc=progress_prefix+' %s'%epoch,
                          ncols=0)

        # Accumulator per leaf that will be filled with the new values
        update_sum = {leaf: torch.zeros_like(leaf._dist_params)
                      for leaf in tree.leaves}

        # Iterate through the data set
        for _, (xs, ys) in train_iter:
            xs, ys = xs.to(device), ys.to(device)
            # Train leafs without gradient descent
            out, info = tree.forward(xs)
            target = eye[ys]
            pa_tensor = info['pa_tensor']
            for leaf in tree.leaves:
                if tree._log_probabilities:
                    # log version
                    update = torch.exp(torch.logsumexp(
                        pa_tensor[leaf.index] + leaf.distribution()
                        + torch.log(target) - out, dim=0))
                else:
                    update = torch.sum(
                        (pa_tensor[leaf.index] * leaf.distribution() * target)
                        / out, dim=0)
                update_sum[leaf] += update

        # Overwrite the parameters directly. Zeroing via ``x -= x`` followed
        # by ``x += new`` would leave NaN wherever the old value was inf/NaN.
        for leaf in tree.leaves:
            leaf._dist_params.copy_(update_sum[leaf])
Exemplo n.º 6
0
def save_tree(tree: ProtoTree, optimizer, scheduler, epoch: int, log: Log,
              args: argparse.Namespace):
    """Checkpoint the tree together with optimizer and scheduler state.

    The ``latest`` checkpoint under ``log.checkpoint_dir`` is always
    refreshed. An additional ``epoch_<n>`` checkpoint is written when
    ``epoch`` is a multiple of 10 or equals ``args.epochs``.
    """
    tree.eval()

    # 'latest' is written every call; the per-epoch snapshot only periodically
    # and at the final epoch.
    checkpoint_names = ['latest']
    if epoch == args.epochs or epoch % 10 == 0:
        checkpoint_names.append(f'epoch_{epoch}')

    for checkpoint_name in checkpoint_names:
        target_dir = f'{log.checkpoint_dir}/{checkpoint_name}'
        tree.save(target_dir)
        tree.save_state(target_dir)
        torch.save(optimizer.state_dict(),
                   f'{target_dir}/optimizer_state.pth')
        torch.save(scheduler.state_dict(),
                   f'{target_dir}/scheduler_state.pth')
Exemplo n.º 7
0
def get_similarity_maps(tree: ProtoTree, project_info: dict, log: Log = None):
    """Compute the similarity map of every prototype on its nearest input.

    For each prototype index ``j`` in ``project_info`` the stored
    ``'nearest_input'`` tensor is passed through ``tree.forward_partial`` and
    the distances of prototype ``j`` are converted to similarities with
    ``exp(-distance)``.

    Args:
        tree: Tree providing ``forward_partial``.
        project_info: Mapping from prototype index to an info dict holding a
            ``'nearest_input'`` entry. That entry is removed from each info
            dict as a side effect.
        log: Optional logger. When ``None`` no progress message is written.

    Returns:
        ``(sim_maps, project_info)`` where ``sim_maps`` maps each prototype
        index to a NumPy similarity map and ``project_info`` is the input
        mapping without the ``'nearest_input'`` entries.
    """
    # ``log`` defaults to None, so only report when a logger was supplied.
    if log is not None:
        log.log_message("\nCalculating similarity maps (after projection)...")

    sim_maps = dict()
    for j, prototype_info in project_info.items():
        nearest_x = prototype_info['nearest_input']
        with torch.no_grad():
            _, distances_batch, _ = tree.forward_partial(nearest_x)
            sim_maps[j] = torch.exp(-distances_batch[0, j, :, :]).cpu().numpy()
        # Drop the stored input tensor; it is not needed once the similarity
        # map has been computed.
        del nearest_x
        del prototype_info['nearest_input']
    return sim_maps, project_info
Exemplo n.º 8
0
def eval_fidelity(tree: ProtoTree,
                  test_loader: DataLoader,
                  device,
                  log: Log = None,
                  progress_prefix: str = 'Fidelity') -> dict:
    """Measure how often deterministic routing agrees with distributed routing.

    Every batch is classified three times, with the ``'distributed'``,
    ``'sample_max'`` and ``'greedy'`` strategies. Fidelity is the fraction of
    samples, relative to ``len(test_loader.dataset)``, for which a
    deterministic strategy predicts the same class as distributed routing.

    Args:
        tree: Model to evaluate. It is moved to ``device`` and switched to
            evaluation mode.
        test_loader: Iterable yielding ``(xs, ys)`` batches; the labels are
            not used.
        device: Device the model and every batch are moved to.
        log: Optional logger. When ``None`` the summary messages are skipped.
        progress_prefix: Description shown on the progress bar.

    Returns:
        A dict with ``'distr_samplemax_fidelity'`` and
        ``'distr_greedy_fidelity'``.
    """
    tree = tree.to(device)

    # Keep an info dict about the procedure
    info = dict()

    # Make sure the model is in evaluation mode
    tree.eval()
    # Show progress on progress bar
    test_iter = tqdm(enumerate(test_loader),
                     total=len(test_loader),
                     desc=progress_prefix,
                     ncols=0)

    distr_samplemax_fidelity = 0
    distr_greedy_fidelity = 0
    # Iterate through the test set
    for i, (xs, ys) in test_iter:
        xs, ys = xs.to(device), ys.to(device)

        # Classify this batch with 3 types of routing. Nothing is
        # backpropagated here, so skip building the autograd graph.
        with torch.no_grad():
            out_distr, _ = tree.forward(xs, 'distributed')
            ys_pred_distr = torch.argmax(out_distr, dim=1)

            out_samplemax, _ = tree.forward(xs, 'sample_max')
            ys_pred_samplemax = torch.argmax(out_samplemax, dim=1)

            out_greedy, _ = tree.forward(xs, 'greedy')
            ys_pred_greedy = torch.argmax(out_greedy, dim=1)

        # Count the samples on which the predictions agree
        distr_samplemax_fidelity += torch.sum(
            torch.eq(ys_pred_samplemax, ys_pred_distr)).item()
        distr_greedy_fidelity += torch.sum(
            torch.eq(ys_pred_greedy, ys_pred_distr)).item()
        # Update the progress bar
        test_iter.set_postfix_str(f'Batch [{i + 1}/{len(test_iter)}]')
        del out_distr
        del out_samplemax
        del out_greedy

    num_samples = float(len(test_loader.dataset))
    distr_samplemax_fidelity = distr_samplemax_fidelity / num_samples
    distr_greedy_fidelity = distr_greedy_fidelity / num_samples
    info['distr_samplemax_fidelity'] = distr_samplemax_fidelity
    info['distr_greedy_fidelity'] = distr_greedy_fidelity
    # ``log`` defaults to None, so only report when a logger was supplied.
    if log is not None:
        log.log_message(
            "Fidelity between standard distributed routing and sample_max routing: "
            + str(distr_samplemax_fidelity))
        log.log_message(
            "Fidelity between standard distributed routing and greedy routing: " +
            str(distr_greedy_fidelity))
    return info
Exemplo n.º 9
0
def save_tree_description(tree: ProtoTree, optimizer, scheduler,
                          description: str, log: Log):
    """Checkpoint the tree, optimizer and scheduler under a named directory.

    Everything is written to ``<log.checkpoint_dir>/<description>``.
    """
    tree.eval()

    # Save model with description
    target_dir = f'{log.checkpoint_dir}/' + description
    tree.save(target_dir)
    tree.save_state(target_dir)

    optimizer_path = target_dir + '/optimizer_state.pth'
    scheduler_path = target_dir + '/scheduler_state.pth'
    torch.save(optimizer.state_dict(), optimizer_path)
    torch.save(scheduler.state_dict(), scheduler_path)
Exemplo n.º 10
0
def save_best_test_tree(tree: ProtoTree, optimizer, scheduler,
                        best_test_acc: float, test_acc: float, log: Log):
    """Checkpoint the tree when ``test_acc`` beats the best value so far.

    The tree is always switched to eval mode. If ``test_acc`` is strictly
    greater than ``best_test_acc``, the tree, optimizer and scheduler are
    saved under ``<log.checkpoint_dir>/best_test_model``.

    Returns:
        The best test accuracy after taking ``test_acc`` into account.
    """
    tree.eval()

    # Nothing to save unless this is a strict improvement.
    if not test_acc > best_test_acc:
        return best_test_acc

    target_dir = f'{log.checkpoint_dir}/best_test_model'
    tree.save(target_dir)
    tree.save_state(target_dir)
    torch.save(optimizer.state_dict(),
               f'{target_dir}/optimizer_state.pth')
    torch.save(scheduler.state_dict(),
               f'{target_dir}/scheduler_state.pth')
    return test_acc
Exemplo n.º 11
0
def _rescale_to_unit_range(values):
    """Min-max rescale ``values`` to [0, 1]; a constant array maps to zeros."""
    shifted = values - np.amin(values)
    peak = np.amax(shifted)
    if peak > 0:
        return shifted / peak
    # Constant input: dividing by zero would fill the map with NaN.
    return np.zeros_like(shifted)


def upsample_local(tree: ProtoTree, sample: torch.Tensor, sample_dir: str,
                   folder_name: str, img_name: str, decision_path: list,
                   args: argparse.Namespace):
    """Save similarity visualisations for each decision node on a sample's path.

    For every node in ``decision_path`` except the last one, the following
    images are written to
    ``<args.log_dir>/<folder_name>/<img_name>/<args.dir_for_saving_images>``,
    each prefixed with the node index:

    * ``_heatmap_latent_similaritymap.png``: heatmap of the latent map
    * ``_heatmap_original_image.png``: upsampled heatmap blended on the image
    * ``_masked_upsampled_heatmap.png``: upsampled mask of the best patch
    * ``_nearest_patch_of_image.png``: crop of the highest-activation patch
    * ``_bounding_box_nearest_patch_of_image.png``: image with that patch boxed

    Args:
        tree: Tree providing ``forward_partial`` and ``_out_map``.
        sample: Preprocessed input passed to ``tree.forward_partial``.
        sample_dir: Path of the original image file (despite the name, it is
            opened as a file).
        folder_name: Sub-directory of ``args.log_dir`` to write into.
        img_name: Sub-directory name for this image.
        decision_path: Nodes on the path taken by the sample; the final
            element is skipped.
        args: Run arguments; ``log_dir``, ``dir_for_saving_images`` and
            ``upsample_threshold`` are read.
    """
    out_dir = os.path.join(args.log_dir, folder_name, img_name,
                           args.dir_for_saving_images)
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(out_dir, exist_ok=True)

    with torch.no_grad():
        _, distances_batch, _ = tree.forward_partial(sample)
        sim_map = torch.exp(-distances_batch[0, :, :, :]).cpu().numpy()

    for node in decision_path[:-1]:
        decision_node_idx = node.index
        node_id = tree._out_map[node]

        # Load the image fresh for every node and close the file right away.
        with Image.open(sample_dir) as img:
            x_np = np.float32(np.asarray(img)) / 255
        if x_np.ndim == 2:  # convert grayscale to RGB
            x_np = np.stack((x_np, ) * 3, axis=-1)

        img_size = x_np.shape[:2]
        similarity_map = sim_map[node_id]

        # Heatmap of the latent similarity map
        rescaled_sim_map = _rescale_to_unit_range(similarity_map)
        similarity_heatmap = cv2.applyColorMap(
            np.uint8(255 * rescaled_sim_map), cv2.COLORMAP_JET)
        similarity_heatmap = np.float32(similarity_heatmap) / 255
        # Reverse the channel order (OpenCV colour maps are BGR).
        similarity_heatmap = similarity_heatmap[..., ::-1]
        plt.imsave(fname=os.path.join(
            out_dir,
            '%s_heatmap_latent_similaritymap.png' % str(decision_node_idx)),
                   arr=similarity_heatmap,
                   vmin=0.0,
                   vmax=1.0)

        # Heatmap upsampled to the image size and blended with the image
        upsampled_act_pattern = cv2.resize(similarity_map,
                                           dsize=(img_size[1], img_size[0]),
                                           interpolation=cv2.INTER_CUBIC)
        rescaled_act_pattern = _rescale_to_unit_range(upsampled_act_pattern)
        heatmap = cv2.applyColorMap(np.uint8(255 * rescaled_act_pattern),
                                    cv2.COLORMAP_JET)
        heatmap = np.float32(heatmap) / 255
        heatmap = heatmap[..., ::-1]
        overlayed_original_img = 0.5 * x_np + 0.2 * heatmap
        plt.imsave(fname=os.path.join(
            out_dir, '%s_heatmap_original_image.png' % str(decision_node_idx)),
                   arr=overlayed_original_img,
                   vmin=0.0,
                   vmax=1.0)

        # Mask the similarity map such that only the nearest patch z* is
        # visualized.
        masked_similarity_map = np.ones(similarity_map.shape)
        masked_similarity_map[similarity_map < np.max(similarity_map)] = 0

        upsampled_prototype_pattern = cv2.resize(masked_similarity_map,
                                                 dsize=(img_size[1],
                                                        img_size[0]),
                                                 interpolation=cv2.INTER_CUBIC)
        plt.imsave(fname=os.path.join(
            out_dir, '%s_masked_upsampled_heatmap.png' % str(decision_node_idx)),
                   arr=upsampled_prototype_pattern,
                   vmin=0.0,
                   vmax=1.0)

        # Save the highly activated patch. The indices are used as
        # (row_start, row_end, col_start, col_end).
        high_act_patch_indices = find_high_activation_crop(
            upsampled_prototype_pattern, args.upsample_threshold)
        high_act_patch = x_np[
            high_act_patch_indices[0]:high_act_patch_indices[1],
            high_act_patch_indices[2]:high_act_patch_indices[3], :]
        plt.imsave(fname=os.path.join(
            out_dir, '%s_nearest_patch_of_image.png' % str(decision_node_idx)),
                   arr=high_act_patch,
                   vmin=0.0,
                   vmax=1.0)

        # Save the original image with bounding box showing high activation patch
        imsave_with_bbox(fname=os.path.join(
            out_dir, '%s_bounding_box_nearest_patch_of_image.png' %
            str(decision_node_idx)),
                         img_rgb=x_np,
                         bbox_height_start=high_act_patch_indices[0],
                         bbox_height_end=high_act_patch_indices[1],
                         bbox_width_start=high_act_patch_indices[2],
                         bbox_width_end=high_act_patch_indices[3],
                         color=(0, 255, 255))
Exemplo n.º 12
0
def gen_pred_vis(
    tree: ProtoTree,
    sample: torch.Tensor,
    sample_dir: str,
    folder_name: str,
    args: argparse.Namespace,
    classes: tuple,
    pred_kwargs: dict = None,
):
    """Render the greedy decision path of one sample as a Graphviz PDF.

    The sample is classified with greedy routing, per-node visualisations are
    produced by ``upsample_local``, and a dot graph linking the sample, every
    decision node on the path and the predicted leaf is written to
    ``<args.log_dir>/<folder_name>/<image name>/predvis.dot`` and rendered to
    ``predvis.pdf`` by the ``dot`` executable.

    Args:
        tree: Trained tree used for the prediction.
        sample: Preprocessed input passed to ``tree.forward``.
        sample_dir: Path of the original image file.
        folder_name: Sub-directory of ``args.log_dir`` to write into.
        args: Run arguments; ``log_dir`` and ``dir_for_saving_images`` are
            read here.
        classes: Class names, indexed by the predicted label.
        pred_kwargs: Extra keyword arguments forwarded to ``tree.forward``.

    Raises:
        subprocess.CalledProcessError: if ``dot`` exits with a non-zero status.
    """
    pred_kwargs = pred_kwargs or dict()  # TODO -- assert deterministic routing

    # Name of the image file without directory: the piece before the last dot.
    img_name = sample_dir.split('/')[-1].split(".")[-2]

    # Create dir to store visualization; makedirs also creates the parents.
    destination_folder = os.path.join(os.path.join(args.log_dir, folder_name),
                                      img_name)
    os.makedirs(destination_folder + '/node_vis', exist_ok=True)

    # Get references to where source files are stored
    upsample_path = os.path.join(
        os.path.join(args.log_dir, args.dir_for_saving_images),
        'pruned_and_projected')
    nodevis_path = os.path.join(args.log_dir, 'pruned_and_projected/node_vis')
    local_upsample_path = os.path.join(destination_folder,
                                       args.dir_for_saving_images)

    # Get the model prediction
    with torch.no_grad():
        pred, pred_info = tree.forward(sample,
                                       sampling_strategy='greedy',
                                       **pred_kwargs)
        probs = pred_info['ps']
        label_ix = torch.argmax(pred, dim=1)[0].item()
        assert 'out_leaf_ix' in pred_info.keys()

    # Save input image (closing the source file handle afterwards)
    sample_path = destination_folder + '/node_vis/sample.jpg'
    with Image.open(sample_dir) as original_image:
        original_image.save(sample_path)

    leaf_ix = pred_info['out_leaf_ix'][0]
    leaf = tree.nodes_by_index[leaf_ix]
    decision_path = tree.path_to(leaf)

    upsample_local(tree, sample, sample_dir, folder_name, img_name,
                   decision_path, args)

    # Prediction graph is visualized using Graphviz
    # Build dot string
    s = 'digraph T {margin=0;rankdir=LR\n'
    s += 'node [shape=plaintext, label=""];\n'
    s += 'edge [penwidth="0.5"];\n'

    # Create a node for the sample image
    s += f'sample[image="{sample_path}"];\n'

    # Create nodes for all decisions/branches on the path (the final element,
    # the leaf, is handled separately below).
    for i, node in enumerate(decision_path[:-1]):
        node_ix = node.index
        prob = probs[node_ix].item()
        n = i + 1
        group = f'g{i}'

        s += f'node_{n}[image="{upsample_path}/{node_ix}_nearest_patch_of_image.png" group="{group}"];\n'
        if prob > 0.5:
            # Prototype present: show where it was found in the sample.
            s += f'node_{n}_original[image="{local_upsample_path}/{node_ix}_bounding_box_nearest_patch_of_image.png" imagescale=width group="{group}"];\n'
            label = "Present      \nSimilarity %.4f                   " % prob
        else:
            # Prototype absent: show the plain sample image.
            s += f'node_{n}_original[image="{sample_path}" group="{group}"];\n'
            label = "Absent      \nSimilarity %.4f                   " % prob
        s += f'node_{n}->node_{n}_original [label="{label}" fontsize=10 fontname=Helvetica];\n'

        s += f'node_{n}->node_{n + 1};\n'
        s += f'{{rank = same; node_{n}_original; node_{n}}};'

    # Create a node for the model output
    s += f'node_{len(decision_path)}[imagepos="tc" imagescale=height image="{nodevis_path}/node_{leaf_ix}_vis.jpg" label="{classes[label_ix]}" labelloc=b fontsize=10 penwidth=0 fontname=Helvetica];\n'

    # Connect the input image to the first decision node
    s += 'sample->node_1;\n'

    s += '}\n'

    from_p = os.path.join(destination_folder, 'predvis.dot')
    to_pdf = os.path.join(destination_folder, 'predvis.pdf')
    with open(from_p, 'w') as f:
        f.write(s)

    # Pass an argument list instead of a shell string so that paths containing
    # spaces or shell metacharacters are handled safely.
    check_call(['dot', '-Tpdf', '-Gmargin=0', from_p, '-o', to_pdf])
Exemplo n.º 13
0
def run_tree(args=None):
    """Run the full ProtoTree pipeline: train/load, evaluate, prune, project.

    Steps, in order:

    1. Set up logging, the device, data loaders, network, tree, optimizer and
       LR scheduler, then initialise the tree via ``init_tree``.
    2. If the start epoch is within ``args.epochs``, train epoch by epoch.
       The test set is evaluated every epoch when ``args.epochs <= 100``,
       otherwise every 10th epoch and at the final epoch. If the start epoch
       is already past ``args.epochs``, only evaluate the loaded tree.
    3. Prune the tree, save and evaluate it.
    4. Project the prototypes of a copy of the pruned tree, save and evaluate
       it with distributed, sample_max and greedy routing, and measure
       fidelity.
    5. Upsample the prototypes and generate the tree visualisation.

    Args:
        args: Parsed run arguments; when falsy, ``get_args()`` is used.

    Returns:
        A tuple ``(trained_tree, pruned_tree, pruned_projected_tree,
        original_test_acc, pruned_test_acc, pruned_projected_test_acc,
        project_info, eval_info_samplemax, eval_info_greedy, fidelity_info)``
        with the three trees moved to the CPU.
    """
    args = args or get_args()
    # Create a logger
    log = Log(args.log_dir)
    print("Log dir: ", args.log_dir, flush=True)
    # Create a csv log for storing the test accuracy, mean train accuracy and mean loss for each epoch
    log.create_log('log_epoch_overview', 'epoch', 'test_acc', 'mean_train_acc', 'mean_train_crossentropy_loss_during_epoch')
    # Log the run arguments
    save_args(args, log.metadata_dir)
    if not args.disable_cuda and torch.cuda.is_available():
        device = torch.device('cuda:{}'.format(torch.cuda.current_device()))
    else:
        device = torch.device('cpu')

    # Log which device was actually used
    log.log_message('Device used: '+str(device))

    # Create a log for logging the loss values
    log_prefix = 'log_train_epochs'
    log_loss = log_prefix+'_losses'
    log.create_log(log_loss, 'epoch', 'batch', 'loss', 'batch_train_acc')

    # Obtain the dataset and dataloaders
    trainloader, projectloader, testloader, classes, num_channels = get_dataloaders(args)
    # Create a convolutional network based on arguments and add 1x1 conv layer
    features_net, add_on_layers = get_network(num_channels, args)
    # Create a ProtoTree
    tree = ProtoTree(num_classes=len(classes),
                     feature_net=features_net,
                     args=args,
                     add_on_layers=add_on_layers)
    tree = tree.to(device=device)
    # Determine which optimizer should be used to update the tree parameters
    optimizer, params_to_freeze, params_to_train = get_optimizer(tree, args)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizer, milestones=args.milestones, gamma=args.gamma)
    tree, epoch = init_tree(tree, optimizer, scheduler, device, args)

    tree.save(f'{log.checkpoint_dir}/tree_init')
    log.log_message("Max depth %s, so %s internal nodes and %s leaves"%(args.depth, tree.num_branches, tree.num_leaves))
    analyse_output_shape(tree, trainloader, log, device)

    leaf_labels = dict()
    best_train_acc = 0.
    best_test_acc = 0.

    if epoch < args.epochs + 1:
        # ---- TRAIN AND EVALUATE TREE ----
        for epoch in range(epoch, args.epochs + 1):
            log.log_message("\nEpoch %s"%str(epoch))
            # Freeze (part of) network for some epochs if indicated in args
            freeze(tree, epoch, params_to_freeze, params_to_train, args, log)
            log_learning_rates(optimizer, args, log)

            # Train tree
            if tree._kontschieder_train:
                train_info = train_epoch_kontschieder(tree, trainloader, optimizer, epoch, args.disable_derivative_free_leaf_optim, device, log, log_prefix)
            else:
                train_info = train_epoch(tree, trainloader, optimizer, epoch, args.disable_derivative_free_leaf_optim, device, log, log_prefix)
            save_tree(tree, optimizer, scheduler, epoch, log, args)
            best_train_acc = save_best_train_tree(tree, optimizer, scheduler, best_train_acc, train_info['train_accuracy'], log)
            leaf_labels = analyse_leafs(tree, epoch, len(classes), leaf_labels, args.pruning_threshold_leaves, log)

            # Evaluate tree: every epoch for short runs, otherwise every 10th
            # epoch and at the final epoch.
            evaluate_now = (args.epochs <= 100
                            or epoch % 10 == 0
                            or epoch == args.epochs)
            if evaluate_now:
                eval_info = eval(tree, testloader, epoch, device, log)
                original_test_acc = eval_info['test_accuracy']
                best_test_acc = save_best_test_tree(tree, optimizer, scheduler, best_test_acc, eval_info['test_accuracy'], log)
                logged_test_acc = eval_info['test_accuracy']
            else:
                logged_test_acc = "n.a."
            log.log_values('log_epoch_overview', epoch, logged_test_acc, train_info['train_accuracy'], train_info['loss'])

            scheduler.step()

    else:  # tree was loaded and not trained, so evaluate only
        # ---- EVALUATE TREE ----
        eval_info = eval(tree, testloader, epoch, device, log)
        original_test_acc = eval_info['test_accuracy']
        best_test_acc = save_best_test_tree(tree, optimizer, scheduler, best_test_acc, eval_info['test_accuracy'], log)
        log.log_values('log_epoch_overview', epoch, eval_info['test_accuracy'], "n.a.", "n.a.")

    # ---- EVALUATE AND ANALYSE TRAINED TREE ----
    log.log_message("Training Finished. Best training accuracy was %s, best test accuracy was %s\n"%(str(best_train_acc), str(best_test_acc)))
    # Snapshot the trained tree before pruning is applied to ``tree``.
    trained_tree = deepcopy(tree)
    leaf_labels = analyse_leafs(tree, epoch+1, len(classes), leaf_labels, args.pruning_threshold_leaves, log)
    analyse_leaf_distributions(tree, log)

    # ---- PRUNE ----
    prune(tree, args.pruning_threshold_leaves, log)
    name = "pruned"
    save_tree_description(tree, optimizer, scheduler, name, log)
    # Analyse and evaluate pruned tree
    leaf_labels = analyse_leafs(tree, epoch+2, len(classes), leaf_labels, args.pruning_threshold_leaves, log)
    analyse_leaf_distributions(tree, log)
    eval_info = eval(tree, testloader, name, device, log)
    pruned_test_acc = eval_info['test_accuracy']

    # Projection below works on a deep copy, so this object stays the pruned
    # (unprojected) tree; no extra copy is needed here.
    pruned_tree = tree

    # ---- PROJECT ----
    project_info, tree = project_with_class_constraints(deepcopy(pruned_tree), projectloader, device, args, log)
    name = "pruned_and_projected"
    save_tree_description(tree, optimizer, scheduler, name, log)
    pruned_projected_tree = deepcopy(tree)
    # Analyse and evaluate pruned tree with projected prototypes
    average_distance_nearest_image(project_info, tree, log)
    leaf_labels = analyse_leafs(tree, epoch+3, len(classes), leaf_labels, args.pruning_threshold_leaves, log)
    analyse_leaf_distributions(tree, log)
    eval_info = eval(tree, testloader, name, device, log)
    pruned_projected_test_acc = eval_info['test_accuracy']
    eval_info_samplemax = eval(tree, testloader, name, device, log, 'sample_max')
    get_avg_path_length(tree, eval_info_samplemax, log)
    eval_info_greedy = eval(tree, testloader, name, device, log, 'greedy')
    get_avg_path_length(tree, eval_info_greedy, log)
    fidelity_info = eval_fidelity(tree, testloader, device, log)

    # Upsample prototype for visualization
    project_info = upsample(tree, project_info, projectloader, name, args, log)
    # visualize tree
    gen_vis(tree, name, args, classes)

    return trained_tree.to('cpu'), pruned_tree.to('cpu'), pruned_projected_tree.to('cpu'), original_test_acc, pruned_test_acc, pruned_projected_test_acc, project_info, eval_info_samplemax, eval_info_greedy, fidelity_info
Exemplo n.º 14
0
def project_with_class_constraints(
        tree: ProtoTree,
        project_loader: DataLoader,
        device,
        args: argparse.Namespace,
        log: Log,
        log_prefix: str = 'log_projection_with_constraints',  # TODO
        progress_prefix: str = 'Projection') -> tuple:
    """Replace every prototype by its nearest latent patch from allowed images.

    For each decision node, only images whose label equals the most likely
    class of at least one leaf in that node's subtree are considered. Over
    all such images in ``project_loader``, the latent patch with the smallest
    distance to the node's prototype is kept and finally written into
    ``tree.prototype_layer.prototype_vectors``.

    Args:
        tree: Tree whose prototypes are overwritten in place.
        project_loader: Iterable yielding ``(xs, ys)`` batches.
        device: Device the batches are moved to.
        args: Not used in this function.
        log: Logger that receives a start message.
        log_prefix: Not used in this function.
        progress_prefix: Description shown on the progress bar.

    Returns:
        ``(global_min_info, tree)``: a dict mapping each prototype index to
        details about its nearest patch (image index, patch index, feature
        map and prototype sizes, distance, the input image tensor and the
        node index), and the modified tree.
    """

    log.log_message(
        "\nProjecting prototypes to nearest training patch (with class restrictions)..."
    )
    # Set the model to evaluation mode
    tree.eval()
    torch.cuda.empty_cache()
    # The goal is to find the latent patch that minimizes the L2 distance to each prototype
    # To do this we iterate through the train dataset and store for each prototype the closest latent patch seen so far
    # Also store info about the image that was used for projection
    global_min_proto_dist = {j: np.inf for j in range(tree.num_prototypes)}
    global_min_patches = {j: None for j in range(tree.num_prototypes)}
    global_min_info = {j: None for j in range(tree.num_prototypes)}

    # Get the shape of the prototypes
    # (D is rebound below from the shape of the feature tensor)
    W1, H1, D = tree.prototype_shape

    # Build a progress bar for showing the status
    projection_iter = tqdm(enumerate(project_loader),
                           total=len(project_loader),
                           desc=progress_prefix,
                           ncols=0)

    with torch.no_grad():
        # Get a batch of data; it is only used to read the batch size, which
        # is needed to compute the global image index further down.
        xs, ys = next(iter(project_loader))
        batch_size = xs.shape[0]
        # For each internal node, collect the leaf labels in the subtree with this node as root.
        # Only images from these classes can be used for projection.
        leaf_labels_subtree = dict()

        for branch, j in tree._out_map.items():
            leaf_labels_subtree[branch.index] = set()
            for leaf in branch.leaves:
                # The label of a leaf is the class with the highest value in
                # its distribution.
                leaf_labels_subtree[branch.index].add(
                    torch.argmax(leaf.distribution()).item())

        for i, (xs, ys) in projection_iter:
            xs, ys = xs.to(device), ys.to(device)
            # Get the features and distances
            # - features_batch: features tensor (shared by all prototypes)
            #   shape: (batch_size, D, W, H)
            # - distances_batch: distances tensor (for all prototypes)
            #   shape: (batch_size, num_prototypes, W, H)
            # - out_map: a dict mapping decision nodes to distances (indices)
            features_batch, distances_batch, out_map = tree.forward_partial(xs)

            # Get the features dimensions
            bs, D, W, H = features_batch.shape

            # Get a tensor containing the individual latent patches
            # Create the patches by unfolding over both the W and H dimensions
            # TODO -- support for strides in the prototype layer? (corresponds to step size here)
            patches_batch = features_batch.unfold(2, W1, 1).unfold(
                3, H1, 1)  # Shape: (batch_size, D, W, H, W1, H1)

            # Iterate over all decision nodes/prototypes
            for node, j in out_map.items():
                leaf_labels = leaf_labels_subtree[node.index]
                # Iterate over all items in the batch
                # Select the features/distances that are relevant to this prototype
                # - distances: distances of the prototype to the latent patches
                #   shape: (W, H)
                # - patches: latent patches
                #   shape: (D, W, H, W1, H1)
                for batch_i, (distances, patches) in enumerate(
                        zip(distances_batch[:, j, :, :], patches_batch)):
                    #Check if label of this image is in one of the leaves of the subtree
                    if ys[batch_i].item() in leaf_labels:
                        # Find the index of the latent patch that is closest to the prototype
                        min_distance = distances.min()
                        min_distance_ix = distances.argmin()
                        # Use the index to get the closest latent patch
                        # NOTE(review): the view to W * H positions appears to
                        # match the unfolded tensor only when W1 == H1 == 1 --
                        # confirm before using larger prototypes.
                        closest_patch = patches.view(D, W * H, W1,
                                                     H1)[:,
                                                         min_distance_ix, :, :]

                        # Check if the latent patch is closest for all data samples seen so far
                        if min_distance < global_min_proto_dist[j]:
                            global_min_proto_dist[j] = min_distance
                            global_min_patches[j] = closest_patch
                            global_min_info[j] = {
                                'input_image_ix': i * batch_size + batch_i,
                                'patch_ix': min_distance_ix.item(
                                ),  # Index in a flattened array of the feature map
                                'W': W,
                                'H': H,
                                'W1': W1,
                                'H1': H1,
                                'distance': min_distance.item(),
                                'nearest_input':
                                torch.unsqueeze(xs[batch_i], 0),
                                'node_ix': node.index,
                            }

            # Update the progress bar if required
            projection_iter.set_postfix_str(
                f'Batch: {i + 1}/{len(project_loader)}')

            del features_batch
            del distances_batch
            del out_map

        # Copy the patches to the prototype layer weights; ``out=`` makes
        # torch.cat write its result into the prototype tensor.
        # NOTE(review): a prototype that never saw an eligible image still has
        # None in global_min_patches, so ``.unsqueeze`` would raise here.
        projection = torch.cat(tuple(global_min_patches[j].unsqueeze(0)
                                     for j in range(tree.num_prototypes)),
                               dim=0,
                               out=tree.prototype_layer.prototype_vectors)
        del projection

    return global_min_info, tree
Exemplo n.º 15
0
def train_epoch(tree: ProtoTree,
                train_loader: DataLoader,
                optimizer: torch.optim.Optimizer,
                epoch: int,
                disable_derivative_free_leaf_optim: bool,
                device,
                log: Log = None,
                log_prefix: str = 'log_train_epochs',
                progress_prefix: str = 'Train Epoch'
                ) -> dict:
    """Train the tree for one pass over ``train_loader``.

    For every batch the parameters registered with ``optimizer`` are updated
    by gradient descent on the negative log-likelihood loss. Unless
    ``disable_derivative_free_leaf_optim`` is set, each leaf's
    ``_dist_params`` is additionally updated in place, without gradients,
    right after the optimizer step.

    Args:
        tree: The model to train; it is moved to ``device``.
        train_loader: Iterable of ``(xs, ys)`` batches. Must support ``len()``.
        optimizer: Optimizer stepped once per batch.
        epoch: Epoch number, used in the progress bar label and in log rows.
        disable_derivative_free_leaf_optim: When True, the derivative-free
            leaf update is skipped.
        device: Device the model and every batch are moved to.
        log: Optional logger. When given, one row per batch is written to the
            log named ``f'{log_prefix}_losses'``. This function does not
            create that log itself.
        log_prefix: Prefix of the per-batch loss log name.
        progress_prefix: Prefix of the progress bar description.

    Returns:
        A dict with ``'loss'`` and ``'train_accuracy'``, each the mean of the
        per-batch values over the epoch.

    Raises:
        ValueError: If ``train_loader`` yields no batches.
    """
    tree = tree.to(device)
    # Start in eval mode; the loop switches to train mode for every batch
    tree.eval()
    # Store info about the procedure
    train_info = dict()
    total_loss = 0.
    total_acc = 0.
    # Counted explicitly so the averages never rely on the loop variable
    batches_seen = 0
    # Name of the per-batch loss log (not created here)
    log_loss = f'{log_prefix}_losses'

    nr_batches = float(len(train_loader))
    with torch.no_grad():
        # Snapshot of every leaf's parameters as they are at the start of the epoch
        _old_dist_params = dict()
        for leaf in tree.leaves:
            _old_dist_params[leaf] = leaf._dist_params.detach().clone()
        # Identity matrix used to one-hot encode the labels for the leaf update
        eye = torch.eye(tree._num_classes).to(device)

    # Show progress on progress bar
    train_iter = tqdm(enumerate(train_loader),
                    total=len(train_loader),
                    desc=progress_prefix+' %s'%epoch,
                    ncols=0)
    # Iterate through the data set to update leaves, prototypes and network
    for i, (xs, ys) in train_iter:
        # Make sure the model is in train mode
        tree.train()
        # Reset the gradients
        optimizer.zero_grad()

        xs, ys = xs.to(device), ys.to(device)

        # Perform a forward pass through the network
        ys_pred, info = tree.forward(xs)

        # Learn prototypes and network with gradient descent.
        # If disable_derivative_free_leaf_optim, leaves are optimized with gradient descent as well.
        # Compute the loss
        if tree._log_probabilities:
            loss = F.nll_loss(ys_pred, ys)
        else:
            loss = F.nll_loss(torch.log(ys_pred), ys)

        # Compute the gradient
        loss.backward()
        # Update model parameters
        optimizer.step()

        if not disable_derivative_free_leaf_optim:
            # Update leaves with derivative-free algorithm
            # Make sure the tree is in eval mode
            tree.eval()
            with torch.no_grad():
                target = eye[ys]  # shape (batchsize, num_classes)
                for leaf in tree.leaves:
                    if tree._log_probabilities:
                        # log version
                        update = torch.exp(torch.logsumexp(info['pa_tensor'][leaf.index] + leaf.distribution() + torch.log(target) - ys_pred, dim=0))
                    else:
                        update = torch.sum((info['pa_tensor'][leaf.index] * leaf.distribution() * target)/ys_pred, dim=0)
                    # Remove a 1/nr_batches share of the epoch-start snapshot, then add this batch's update
                    leaf._dist_params -= (_old_dist_params[leaf]/nr_batches)
                    # dist_params values can get slightly negative because of floating point issues; clamp to zero
                    F.relu_(leaf._dist_params)
                    leaf._dist_params += update

        # Read the scalar loss once; every .item() call copies from the device
        loss_value = loss.item()

        # Count the number of correct classifications
        ys_pred_max = torch.argmax(ys_pred, dim=1)

        correct = torch.sum(torch.eq(ys_pred_max, ys))
        acc = correct.item() / float(len(xs))

        train_iter.set_postfix_str(
            f'Batch [{i + 1}/{len(train_loader)}], Loss: {loss_value:.3f}, Acc: {acc:.3f}'
        )
        # Compute metrics over this batch
        total_loss += loss_value
        total_acc += acc
        batches_seen += 1

        if log is not None:
            log.log_values(log_loss, epoch, i + 1, loss_value, acc)

    if batches_seen == 0:
        raise ValueError('train_loader yielded no batches; cannot compute epoch metrics')

    train_info['loss'] = total_loss/float(batches_seen)
    train_info['train_accuracy'] = total_acc/float(batches_seen)
    return train_info
Exemplo n.º 16
0
def train_epoch_kontschieder(tree: ProtoTree,
                train_loader: DataLoader,
                optimizer: torch.optim.Optimizer,
                epoch: int,
                disable_derivative_free_leaf_optim: bool,
                device,
                log: Log = None,
                log_prefix: str = 'log_train_epochs',
                progress_prefix: str = 'Train Epoch'
                ) -> dict:
    """Train the tree for one epoch, learning the leaves in a separate phase first.

    Unless ``disable_derivative_free_leaf_optim`` is set, the leaves are
    trained first by calling ``train_leaves_epoch`` (defined elsewhere in this
    file): ten times when ``tree._kontschieder_normalization`` is truthy, once
    otherwise. Afterwards one pass over ``train_loader`` updates the
    parameters registered with ``optimizer`` by gradient descent on the
    negative log-likelihood loss.

    Args:
        tree: The model to train; it is moved to ``device``.
        train_loader: Iterable of ``(xs, ys)`` batches. Must support ``len()``.
        optimizer: Optimizer stepped once per batch.
        epoch: Epoch number, used in the progress bar label and in log rows.
            When it equals 1 and ``log`` is given, the per-batch loss log is
            created.
        disable_derivative_free_leaf_optim: When True, the separate leaf
            training phase is skipped and a warning is printed.
        device: Device the model and every batch are moved to.
        log: Optional logger receiving one row per batch.
        log_prefix: Prefix of the per-batch loss log name.
        progress_prefix: Prefix of the progress bar description.

    Returns:
        A dict with ``'loss'`` and ``'train_accuracy'``, each the mean of the
        per-batch values over the epoch.

    Raises:
        ValueError: If ``train_loader`` yields no batches.
    """
    tree = tree.to(device)

    # Store info about the procedure
    train_info = dict()
    total_loss = 0.
    total_acc = 0.
    # Counted explicitly so the averages never rely on the loop variable
    batches_seen = 0

    # Create a log if required
    log_loss = f'{log_prefix}_losses'
    if log is not None and epoch==1:
        log.create_log(log_loss, 'epoch', 'batch', 'loss', 'batch_train_acc')

    # Reset the gradients
    optimizer.zero_grad()

    if disable_derivative_free_leaf_optim:
        print("WARNING: kontschieder arguments will be ignored when training leaves with gradient descent")
    else:
        if tree._kontschieder_normalization:
            # Iterate over the dataset multiple times to learn leaves following Kontschieder's approach
            leaf_training_passes = 10
            for _ in range(leaf_training_passes):
                # Train leaves with derivative-free algorithm using normalization factor
                train_leaves_epoch(tree, train_loader, epoch, device)
        else:
            # Train leaves with Kontschieder's derivative-free algorithm, but using softmax
            train_leaves_epoch(tree, train_loader, epoch, device)
    # Train prototypes and network.
    # If disable_derivative_free_leaf_optim, leafs are optimized with gradient descent as well.
    # Show progress on progress bar
    train_iter = tqdm(enumerate(train_loader),
                        total=len(train_loader),
                        desc=progress_prefix+' %s'%epoch,
                        ncols=0)
    # Make sure the model is in train mode
    tree.train()
    for i, (xs, ys) in train_iter:
        xs, ys = xs.to(device), ys.to(device)

        # Reset the gradients
        optimizer.zero_grad()
        # Perform a forward pass through the network
        ys_pred, _ = tree.forward(xs)
        # Compute the loss
        if tree._log_probabilities:
            loss = F.nll_loss(ys_pred, ys)
        else:
            loss = F.nll_loss(torch.log(ys_pred), ys)
        # Compute the gradient
        loss.backward()
        # Update model parameters
        optimizer.step()

        # Read the scalar loss once; every .item() call copies from the device
        loss_value = loss.item()

        # Count the number of correct classifications
        # (kept in its own name so the model output is not overwritten)
        ys_pred_max = torch.argmax(ys_pred, dim=1)

        correct = torch.sum(torch.eq(ys_pred_max, ys))
        acc = correct.item() / float(len(xs))

        train_iter.set_postfix_str(
            f'Batch [{i + 1}/{len(train_loader)}], Loss: {loss_value:.3f}, Acc: {acc:.3f}'
        )
        # Compute metrics over this batch
        total_loss += loss_value
        total_acc += acc
        batches_seen += 1

        if log is not None:
            log.log_values(log_loss, epoch, i + 1, loss_value, acc)

    if batches_seen == 0:
        raise ValueError('train_loader yielded no batches; cannot compute epoch metrics')

    train_info['loss'] = total_loss/float(batches_seen)
    train_info['train_accuracy'] = total_acc/float(batches_seen)
    return train_info