def eval(tree: ProtoTree,
         test_loader: DataLoader,
         epoch,
         device,
         log: Log = None,
         sampling_strategy: str = 'distributed',
         log_prefix: str = 'log_eval_epochs',
         progress_prefix: str = 'Eval Epoch') -> dict:
    tree = tree.to(device)

    # Keep an info dict about the procedure
    info = dict()
    if sampling_strategy != 'distributed':
        info['out_leaf_ix'] = []

    # Build a confusion matrix
    cm = np.zeros((tree._num_classes, tree._num_classes), dtype=int)

    # Make sure the model is in evaluation mode
    tree.eval()

    # Show progress on progress bar
    test_iter = tqdm(enumerate(test_loader),
                     total=len(test_loader),
                     desc=progress_prefix + ' %s' % epoch,
                     ncols=0)

    # Iterate through the test set
    for i, (xs, ys) in test_iter:
        xs, ys = xs.to(device), ys.to(device)

        # Use the model to classify this batch of input data
        out, test_info = tree.forward(xs, sampling_strategy)
        ys_pred = torch.argmax(out, dim=1)

        # Update the confusion matrix
        cm_batch = np.zeros((tree._num_classes, tree._num_classes), dtype=int)
        for y_pred, y_true in zip(ys_pred, ys):
            cm[y_true][y_pred] += 1
            cm_batch[y_true][y_pred] += 1
        acc = acc_from_cm(cm_batch)
        test_iter.set_postfix_str(
            f'Batch [{i + 1}/{len(test_iter)}], Acc: {acc:.3f}')

        # Keep a list of leaf indices where each test sample ends up when deterministic routing is used
        if sampling_strategy != 'distributed':
            info['out_leaf_ix'] += test_info['out_leaf_ix']

        del out
        del ys_pred
        del test_info

    info['confusion_matrix'] = cm
    info['test_accuracy'] = acc_from_cm(cm)
    log.log_message("\nEpoch %s - Test accuracy with %s routing: " % (epoch, sampling_strategy) + str(info['test_accuracy']))
    return info
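# `acc_from_cm` is referenced above but defined elsewhere in the repository.
# A minimal sketch consistent with its use here (hypothetical name, assuming
# the convention cm[true_label][predicted_label] from the update loop above):
def _acc_from_cm_sketch(cm: np.ndarray) -> float:
    total = cm.sum()
    if total == 0:  # avoid division by zero on an empty batch (an assumption)
        return 1.
    # Correct classifications lie on the diagonal
    return float(np.trace(cm)) / float(total)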
def init_tree(tree: ProtoTree, optimizer, scheduler, device, args: argparse.Namespace):
    epoch = 1
    mean = 0.5
    std = 0.1
    # Load a trained ProtoTree if the flag is set.
    # NOTE: training further from a checkpoint doesn't seem to work correctly;
    # evaluating a trained ProtoTree from a checkpoint does work.
    if args.state_dict_dir_tree != '':
        if not args.disable_cuda and torch.cuda.is_available():
            device = torch.device('cuda:{}'.format(torch.cuda.current_device()))
        else:
            device = torch.device('cpu')
        if args.disable_cuda or not torch.cuda.is_available():
            # tree = load_state(args.state_dict_dir_tree, device)
            tree = torch.load(args.state_dict_dir_tree + '/model.pth', map_location=device)
        else:
            tree = torch.load(args.state_dict_dir_tree + '/model.pth')
        tree.to(device=device)
        try:
            epoch = int(args.state_dict_dir_tree.split('epoch_')[-1]) + 1
        except ValueError:  # directory name does not end in 'epoch_<n>'
            epoch = args.epochs + 1
        print("Train further from epoch: ", epoch, flush=True)
        optimizer.load_state_dict(torch.load(args.state_dict_dir_tree + '/optimizer_state.pth', map_location=device))

        if epoch > args.freeze_epochs:
            for parameter in tree._net.parameters():
                parameter.requires_grad = True
        if not args.disable_derivative_free_leaf_optim:
            for leaf in tree.leaves:
                leaf._dist_params.requires_grad = False

        if os.path.isfile(args.state_dict_dir_tree + '/scheduler_state.pth'):
            # scheduler.load_state_dict(torch.load(args.state_dict_dir_tree + '/scheduler_state.pth'))
            # print(scheduler.state_dict(), flush=True)
            scheduler.last_epoch = epoch - 1
            scheduler._step_count = epoch

    elif args.state_dict_dir_net != '':  # load a pretrained conv network
        # Initialize prototypes
        torch.nn.init.normal_(tree.prototype_layer.prototype_vectors, mean=mean, std=std)
        # strict is False so that the linear classification layer is ignored when loading the pretrained model
        tree._net.load_state_dict(torch.load(args.state_dict_dir_net + '/model_state.pth'), strict=False)
        tree._add_on.load_state_dict(torch.load(args.state_dict_dir_net + '/model_state.pth'), strict=False)

    else:
        with torch.no_grad():
            # Initialize prototypes
            torch.nn.init.normal_(tree.prototype_layer.prototype_vectors, mean=mean, std=std)
            tree._add_on.apply(init_weights_xavier)

    return tree, epoch
def analyse_output_shape(tree: ProtoTree, trainloader: DataLoader, log: Log, device):
    with torch.no_grad():
        # Get a batch of training data
        xs, ys = next(iter(trainloader))
        xs, ys = xs.to(device), ys.to(device)
        log.log_message("Image input shape: " + str(xs[0, :, :, :].shape))
        log.log_message("Features output shape (without 1x1 conv layer): " + str(tree._net(xs).shape))
        log.log_message("Convolutional output shape (with 1x1 conv layer): " + str(tree._add_on(tree._net(xs)).shape))
        log.log_message("Prototypes shape: " + str(tree.prototype_layer.prototype_vectors.shape))
def explain_local(args):
    if not args.disable_cuda and torch.cuda.is_available():
        device = torch.device('cuda:{}'.format(torch.cuda.current_device()))
    else:
        device = torch.device('cpu')

    # Log which device was actually used
    print('Device used: ', str(device))

    # Load trained ProtoTree
    tree = ProtoTree.load(args.prototree).to(device=device)

    # Obtain the dataset and dataloaders
    args.batch_size = 64  # placeholder
    args.augment = True  # placeholder
    _, _, _, classes, _ = get_dataloaders(args)

    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)
    normalize = transforms.Normalize(mean=mean, std=std)
    test_transform = transforms.Compose([
        transforms.Resize(size=(args.image_size, args.image_size)),
        transforms.ToTensor(),
        normalize
    ])
    sample = test_transform(Image.open(args.sample_dir)).unsqueeze(0).to(device)

    gen_pred_vis(tree, sample, args.sample_dir, args.results_dir, args, classes)
def train_leaves_epoch(tree: ProtoTree,
                       train_loader: DataLoader,
                       epoch: int,
                       device,
                       progress_prefix: str = 'Train Leaves Epoch'):
    # Make sure the tree is in eval mode for updating the leaves
    tree.eval()

    with torch.no_grad():
        _old_dist_params = dict()
        for leaf in tree.leaves:
            _old_dist_params[leaf] = leaf._dist_params.detach().clone()

        # Optimize class distributions in the leaves
        eye = torch.eye(tree._num_classes).to(device)

        # Show progress on progress bar
        train_iter = tqdm(enumerate(train_loader),
                          total=len(train_loader),
                          desc=progress_prefix + ' %s' % epoch,
                          ncols=0)

        # Iterate through the data set
        update_sum = dict()

        # Create an empty tensor for each leaf that will be filled with new values
        for leaf in tree.leaves:
            update_sum[leaf] = torch.zeros_like(leaf._dist_params)

        for i, (xs, ys) in train_iter:
            xs, ys = xs.to(device), ys.to(device)
            # Update leaves without gradient descent
            out, info = tree.forward(xs)
            target = eye[ys]  # shape (batchsize, num_classes)
            for leaf in tree.leaves:
                if tree._log_probabilities:
                    # log version
                    update = torch.exp(torch.logsumexp(info['pa_tensor'][leaf.index] + leaf.distribution() + torch.log(target) - out, dim=0))
                else:
                    update = torch.sum((info['pa_tensor'][leaf.index] * leaf.distribution() * target) / out, dim=0)
                update_sum[leaf] += update

        for leaf in tree.leaves:
            leaf._dist_params -= leaf._dist_params  # set current dist params to zero
            leaf._dist_params += update_sum[leaf]   # give dist params their new value
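# A minimal, self-contained sketch (not part of the original code) of the
# derivative-free leaf update that train_leaves_epoch accumulates above. It
# follows the update rule from Kontschieder et al.'s Deep Neural Decision
# Forests, which this codebase adapts:
#
#     pi_{l,c}  <-  sum_x  P_l(x) * pi_{l,c} * t_c(x) / y_c(x)
#
# where P_l(x) is the probability of sample x arriving at leaf l
# (info['pa_tensor']), pi the current leaf distribution, t the one-hot target
# and y the tree output. Parameter names below are illustrative, not repo
# identifiers.
def _leaf_update_sketch(pa: torch.Tensor, pi: torch.Tensor,
                        target: torch.Tensor, out: torch.Tensor) -> torch.Tensor:
    # pa:     (batch_size, 1)            probability of arriving at the leaf
    # pi:     (num_classes,)             current leaf class distribution
    # target: (batch_size, num_classes)  one-hot labels
    # out:    (batch_size, num_classes)  tree output probabilities
    return torch.sum((pa * pi * target) / out, dim=0)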
def save_tree(tree: ProtoTree, optimizer, scheduler, epoch: int, log: Log,
              args: argparse.Namespace):
    tree.eval()
    # Save latest model
    tree.save(f'{log.checkpoint_dir}/latest')
    tree.save_state(f'{log.checkpoint_dir}/latest')
    torch.save(optimizer.state_dict(), f'{log.checkpoint_dir}/latest/optimizer_state.pth')
    torch.save(scheduler.state_dict(), f'{log.checkpoint_dir}/latest/scheduler_state.pth')

    # Save a checkpoint every 10 epochs and at the final epoch
    if epoch == args.epochs or epoch % 10 == 0:
        tree.save(f'{log.checkpoint_dir}/epoch_{epoch}')
        tree.save_state(f'{log.checkpoint_dir}/epoch_{epoch}')
        torch.save(optimizer.state_dict(), f'{log.checkpoint_dir}/epoch_{epoch}/optimizer_state.pth')
        torch.save(scheduler.state_dict(), f'{log.checkpoint_dir}/epoch_{epoch}/scheduler_state.pth')
def get_similarity_maps(tree: ProtoTree, project_info: dict, log: Log = None):
    log.log_message("\nCalculating similarity maps (after projection)...")

    sim_maps = dict()
    for j in project_info.keys():
        nearest_x = project_info[j]['nearest_input']
        with torch.no_grad():
            _, distances_batch, _ = tree.forward_partial(nearest_x)
        sim_maps[j] = torch.exp(-distances_batch[0, j, :, :]).cpu().numpy()
        del nearest_x
        del project_info[j]['nearest_input']
    return sim_maps, project_info
def eval_fidelity(tree: ProtoTree,
                  test_loader: DataLoader,
                  device,
                  log: Log = None,
                  progress_prefix: str = 'Fidelity') -> dict:
    tree = tree.to(device)

    # Keep an info dict about the procedure
    info = dict()

    # Make sure the model is in evaluation mode
    tree.eval()

    # Show progress on progress bar
    test_iter = tqdm(enumerate(test_loader),
                     total=len(test_loader),
                     desc=progress_prefix,
                     ncols=0)

    distr_samplemax_fidelity = 0
    distr_greedy_fidelity = 0
    # Iterate through the test set
    for i, (xs, ys) in test_iter:
        xs, ys = xs.to(device), ys.to(device)

        # Use the model to classify this batch of input data, with 3 types of routing
        out_distr, _ = tree.forward(xs, 'distributed')
        ys_pred_distr = torch.argmax(out_distr, dim=1)

        out_samplemax, _ = tree.forward(xs, 'sample_max')
        ys_pred_samplemax = torch.argmax(out_samplemax, dim=1)

        out_greedy, _ = tree.forward(xs, 'greedy')
        ys_pred_greedy = torch.argmax(out_greedy, dim=1)

        # Calculate fidelity
        distr_samplemax_fidelity += torch.sum(torch.eq(ys_pred_samplemax, ys_pred_distr)).item()
        distr_greedy_fidelity += torch.sum(torch.eq(ys_pred_greedy, ys_pred_distr)).item()

        # Update the progress bar
        test_iter.set_postfix_str(f'Batch [{i + 1}/{len(test_iter)}]')

        del out_distr
        del out_samplemax
        del out_greedy

    distr_samplemax_fidelity = distr_samplemax_fidelity / float(len(test_loader.dataset))
    distr_greedy_fidelity = distr_greedy_fidelity / float(len(test_loader.dataset))
    info['distr_samplemax_fidelity'] = distr_samplemax_fidelity
    info['distr_greedy_fidelity'] = distr_greedy_fidelity
    log.log_message("Fidelity between standard distributed routing and sample_max routing: " + str(distr_samplemax_fidelity))
    log.log_message("Fidelity between standard distributed routing and greedy routing: " + str(distr_greedy_fidelity))
    return info
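# Usage note (variables illustrative): fidelity is the fraction of test samples
# for which a deterministic routing strategy yields the same prediction as soft
# (distributed) routing, so 1.0 means the explanation path never changes the class:
#   fidelity_info = eval_fidelity(tree, testloader, device, log)
#   print(fidelity_info['distr_samplemax_fidelity'], fidelity_info['distr_greedy_fidelity'])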
def save_tree_description(tree: ProtoTree, optimizer, scheduler, description: str, log: Log):
    tree.eval()
    # Save model with description
    tree.save(f'{log.checkpoint_dir}/' + description)
    tree.save_state(f'{log.checkpoint_dir}/' + description)
    torch.save(optimizer.state_dict(), f'{log.checkpoint_dir}/' + description + '/optimizer_state.pth')
    torch.save(scheduler.state_dict(), f'{log.checkpoint_dir}/' + description + '/scheduler_state.pth')
def save_best_test_tree(tree: ProtoTree, optimizer, scheduler, best_test_acc: float, test_acc: float, log: Log):
    tree.eval()
    if test_acc > best_test_acc:
        best_test_acc = test_acc
        tree.save(f'{log.checkpoint_dir}/best_test_model')
        tree.save_state(f'{log.checkpoint_dir}/best_test_model')
        torch.save(optimizer.state_dict(), f'{log.checkpoint_dir}/best_test_model/optimizer_state.pth')
        torch.save(scheduler.state_dict(), f'{log.checkpoint_dir}/best_test_model/scheduler_state.pth')
    return best_test_acc
def upsample_local(tree: ProtoTree,
                   sample: torch.Tensor,
                   sample_dir: str,
                   folder_name: str,
                   img_name: str,
                   decision_path: list,
                   args: argparse.Namespace):
    dir = os.path.join(
        os.path.join(os.path.join(args.log_dir, folder_name), img_name),
        args.dir_for_saving_images)
    if not os.path.exists(dir):
        os.makedirs(dir)
    with torch.no_grad():
        _, distances_batch, _ = tree.forward_partial(sample)
        sim_map = torch.exp(-distances_batch[0, :, :, :]).cpu().numpy()
    for i, node in enumerate(decision_path[:-1]):
        decision_node_idx = node.index
        node_id = tree._out_map[node]
        img = Image.open(sample_dir)
        x_np = np.asarray(img)
        x_np = np.float32(x_np) / 255
        if x_np.ndim == 2:  # convert grayscale to RGB
            x_np = np.stack((x_np,) * 3, axis=-1)

        img_size = x_np.shape[:2]
        similarity_map = sim_map[node_id]

        rescaled_sim_map = similarity_map - np.amin(similarity_map)
        rescaled_sim_map = rescaled_sim_map / np.amax(rescaled_sim_map)
        similarity_heatmap = cv2.applyColorMap(np.uint8(255 * rescaled_sim_map), cv2.COLORMAP_JET)
        similarity_heatmap = np.float32(similarity_heatmap) / 255
        similarity_heatmap = similarity_heatmap[..., ::-1]
        plt.imsave(fname=os.path.join(dir, '%s_heatmap_latent_similaritymap.png' % str(decision_node_idx)),
                   arr=similarity_heatmap,
                   vmin=0.0,
                   vmax=1.0)

        upsampled_act_pattern = cv2.resize(similarity_map,
                                           dsize=(img_size[1], img_size[0]),
                                           interpolation=cv2.INTER_CUBIC)
        rescaled_act_pattern = upsampled_act_pattern - np.amin(upsampled_act_pattern)
        rescaled_act_pattern = rescaled_act_pattern / np.amax(rescaled_act_pattern)
        heatmap = cv2.applyColorMap(np.uint8(255 * rescaled_act_pattern), cv2.COLORMAP_JET)
        heatmap = np.float32(heatmap) / 255
        heatmap = heatmap[..., ::-1]
        overlayed_original_img = 0.5 * x_np + 0.2 * heatmap
        plt.imsave(fname=os.path.join(dir, '%s_heatmap_original_image.png' % str(decision_node_idx)),
                   arr=overlayed_original_img,
                   vmin=0.0,
                   vmax=1.0)

        # Save the highly activated patch
        masked_similarity_map = np.ones(similarity_map.shape)
        # Mask the similarity map such that only the nearest patch z* is visualized
        masked_similarity_map[similarity_map < np.max(similarity_map)] = 0
        upsampled_prototype_pattern = cv2.resize(masked_similarity_map,
                                                 dsize=(img_size[1], img_size[0]),
                                                 interpolation=cv2.INTER_CUBIC)
        plt.imsave(fname=os.path.join(dir, '%s_masked_upsampled_heatmap.png' % str(decision_node_idx)),
                   arr=upsampled_prototype_pattern,
                   vmin=0.0,
                   vmax=1.0)

        high_act_patch_indices = find_high_activation_crop(upsampled_prototype_pattern, args.upsample_threshold)
        high_act_patch = x_np[high_act_patch_indices[0]:high_act_patch_indices[1],
                              high_act_patch_indices[2]:high_act_patch_indices[3], :]
        plt.imsave(fname=os.path.join(dir, '%s_nearest_patch_of_image.png' % str(decision_node_idx)),
                   arr=high_act_patch,
                   vmin=0.0,
                   vmax=1.0)

        # Save the original image with a bounding box showing the high activation patch
        imsave_with_bbox(fname=os.path.join(dir, '%s_bounding_box_nearest_patch_of_image.png' % str(decision_node_idx)),
                         img_rgb=x_np,
                         bbox_height_start=high_act_patch_indices[0],
                         bbox_height_end=high_act_patch_indices[1],
                         bbox_width_start=high_act_patch_indices[2],
                         bbox_width_end=high_act_patch_indices[3],
                         color=(0, 255, 255))
def gen_pred_vis(
    tree: ProtoTree,
    sample: torch.Tensor,
    sample_dir: str,
    folder_name: str,
    args: argparse.Namespace,
    classes: tuple,
    pred_kwargs: dict = None,
):
    pred_kwargs = pred_kwargs or dict()

    # TODO -- assert deterministic routing

    # Create dir to store visualization
    img_name = sample_dir.split('/')[-1].split(".")[-2]

    if not os.path.exists(os.path.join(args.log_dir, folder_name)):
        os.makedirs(os.path.join(args.log_dir, folder_name))
    destination_folder = os.path.join(os.path.join(args.log_dir, folder_name), img_name)

    if not os.path.isdir(destination_folder):
        os.mkdir(destination_folder)
    if not os.path.isdir(destination_folder + '/node_vis'):
        os.mkdir(destination_folder + '/node_vis')

    # Get references to where the source files are stored
    upsample_path = os.path.join(os.path.join(args.log_dir, args.dir_for_saving_images), 'pruned_and_projected')
    nodevis_path = os.path.join(args.log_dir, 'pruned_and_projected/node_vis')
    local_upsample_path = os.path.join(destination_folder, args.dir_for_saving_images)

    # Get the model prediction
    with torch.no_grad():
        pred, pred_info = tree.forward(sample, sampling_strategy='greedy', **pred_kwargs)
        probs = pred_info['ps']
        label_ix = torch.argmax(pred, dim=1)[0].item()
        assert 'out_leaf_ix' in pred_info.keys()

    # Save input image
    sample_path = destination_folder + '/node_vis/sample.jpg'
    # save_image(sample, sample_path)
    Image.open(sample_dir).save(sample_path)

    # Save an image containing the model output
    output_path = destination_folder + '/node_vis/output.jpg'
    leaf_ix = pred_info['out_leaf_ix'][0]
    leaf = tree.nodes_by_index[leaf_ix]
    decision_path = tree.path_to(leaf)

    upsample_local(tree, sample, sample_dir, folder_name, img_name, decision_path, args)

    # The prediction graph is visualized using Graphviz
    # Build the dot string
    s = 'digraph T {margin=0;rankdir=LR\n'
    # s += "subgraph {"
    s += 'node [shape=plaintext, label=""];\n'
    s += 'edge [penwidth="0.5"];\n'

    # Create a node for the sample image
    s += f'sample[image="{sample_path}"];\n'

    # Create nodes for all decisions/branches along the path to the leaf
    for i, node in enumerate(decision_path[:-1]):
        node_ix = node.index
        prob = probs[node_ix].item()

        s += f'node_{i+1}[image="{upsample_path}/{node_ix}_nearest_patch_of_image.png" group="{"g"+str(i)}"];\n'
        if prob > 0.5:
            s += f'node_{i+1}_original[image="{local_upsample_path}/{node_ix}_bounding_box_nearest_patch_of_image.png" imagescale=width group="{"g"+str(i)}"];\n'
            label = "Present \nSimilarity %.4f " % prob
        else:
            s += f'node_{i+1}_original[image="{sample_path}" group="{"g"+str(i)}"];\n'
            label = "Absent \nSimilarity %.4f " % prob
        s += f'node_{i+1}->node_{i+1}_original [label="{label}" fontsize=10 fontname=Helvetica];\n'
        # s += f'node_{i+1}_original->node_{i+1} [label="{label}" fontsize=10 fontname=Helvetica];\n'
        s += f'node_{i+1}->node_{i+2};\n'
        s += "{rank = same; " + f'node_{i+1}_original' + "; " + f'node_{i+1}' + "};"

    # Create a node for the model output
    s += f'node_{len(decision_path)}[imagepos="tc" imagescale=height image="{nodevis_path}/node_{leaf_ix}_vis.jpg" label="{classes[label_ix]}" labelloc=b fontsize=10 penwidth=0 fontname=Helvetica];\n'

    # Connect the input image to the first decision node
    s += 'sample->node_1;\n'

    s += '}\n'

    with open(os.path.join(destination_folder, 'predvis.dot'), 'w') as f:
        f.write(s)

    from_p = os.path.join(destination_folder, 'predvis.dot')
    to_pdf = os.path.join(destination_folder, 'predvis.pdf')
    check_call('dot -Tpdf -Gmargin=0 %s -o %s' % (from_p, to_pdf), shell=True)
def run_tree(args=None):
    args = args or get_args()

    # Create a logger
    log = Log(args.log_dir)
    print("Log dir: ", args.log_dir, flush=True)
    # Create a csv log for storing the test accuracy, mean train accuracy and mean loss for each epoch
    log.create_log('log_epoch_overview', 'epoch', 'test_acc', 'mean_train_acc', 'mean_train_crossentropy_loss_during_epoch')
    # Log the run arguments
    save_args(args, log.metadata_dir)
    if not args.disable_cuda and torch.cuda.is_available():
        # device = torch.device('cuda')
        device = torch.device('cuda:{}'.format(torch.cuda.current_device()))
    else:
        device = torch.device('cpu')

    # Log which device was actually used
    log.log_message('Device used: ' + str(device))

    # Create a log for logging the loss values
    log_prefix = 'log_train_epochs'
    log_loss = log_prefix + '_losses'
    log.create_log(log_loss, 'epoch', 'batch', 'loss', 'batch_train_acc')

    # Obtain the dataset and dataloaders
    trainloader, projectloader, testloader, classes, num_channels = get_dataloaders(args)

    # Create a convolutional network based on arguments and add a 1x1 conv layer
    features_net, add_on_layers = get_network(num_channels, args)

    # Create a ProtoTree
    tree = ProtoTree(num_classes=len(classes),
                     feature_net=features_net,
                     args=args,
                     add_on_layers=add_on_layers)
    tree = tree.to(device=device)
    # Determine which optimizer should be used to update the tree parameters
    optimizer, params_to_freeze, params_to_train = get_optimizer(tree, args)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizer, milestones=args.milestones, gamma=args.gamma)
    tree, epoch = init_tree(tree, optimizer, scheduler, device, args)

    tree.save(f'{log.checkpoint_dir}/tree_init')
    log.log_message("Max depth %s, so %s internal nodes and %s leaves" % (args.depth, tree.num_branches, tree.num_leaves))
    analyse_output_shape(tree, trainloader, log, device)

    leaf_labels = dict()
    best_train_acc = 0.
    best_test_acc = 0.
    if epoch < args.epochs + 1:
        '''
        TRAIN AND EVALUATE TREE
        '''
        for epoch in range(epoch, args.epochs + 1):
            log.log_message("\nEpoch %s" % str(epoch))
            # Freeze (part of) the network for some epochs if indicated in args
            freeze(tree, epoch, params_to_freeze, params_to_train, args, log)
            log_learning_rates(optimizer, args, log)

            # Train tree
            if tree._kontschieder_train:
                train_info = train_epoch_kontschieder(tree, trainloader, optimizer, epoch, args.disable_derivative_free_leaf_optim, device, log, log_prefix)
            else:
                train_info = train_epoch(tree, trainloader, optimizer, epoch, args.disable_derivative_free_leaf_optim, device, log, log_prefix)
            save_tree(tree, optimizer, scheduler, epoch, log, args)
            best_train_acc = save_best_train_tree(tree, optimizer, scheduler, best_train_acc, train_info['train_accuracy'], log)
            leaf_labels = analyse_leafs(tree, epoch, len(classes), leaf_labels, args.pruning_threshold_leaves, log)

            # Evaluate tree
            if args.epochs > 100:
                if epoch % 10 == 0 or epoch == args.epochs:
                    eval_info = eval(tree, testloader, epoch, device, log)
                    original_test_acc = eval_info['test_accuracy']
                    best_test_acc = save_best_test_tree(tree, optimizer, scheduler, best_test_acc, eval_info['test_accuracy'], log)
                    log.log_values('log_epoch_overview', epoch, eval_info['test_accuracy'], train_info['train_accuracy'], train_info['loss'])
                else:
                    log.log_values('log_epoch_overview', epoch, "n.a.", train_info['train_accuracy'], train_info['loss'])
            else:
                eval_info = eval(tree, testloader, epoch, device, log)
                original_test_acc = eval_info['test_accuracy']
                best_test_acc = save_best_test_tree(tree, optimizer, scheduler, best_test_acc, eval_info['test_accuracy'], log)
                log.log_values('log_epoch_overview', epoch, eval_info['test_accuracy'], train_info['train_accuracy'], train_info['loss'])
            scheduler.step()

    else:  # tree was loaded and not trained, so evaluate only
        '''
        EVALUATE TREE
        '''
        eval_info = eval(tree, testloader, epoch, device, log)
        original_test_acc = eval_info['test_accuracy']
        best_test_acc = save_best_test_tree(tree, optimizer, scheduler, best_test_acc, eval_info['test_accuracy'], log)
        log.log_values('log_epoch_overview', epoch, eval_info['test_accuracy'], "n.a.", "n.a.")

    '''
    EVALUATE AND ANALYSE TRAINED TREE
    '''
    log.log_message("Training Finished. Best training accuracy was %s, best test accuracy was %s\n" % (str(best_train_acc), str(best_test_acc)))
    trained_tree = deepcopy(tree)
    leaf_labels = analyse_leafs(tree, epoch + 1, len(classes), leaf_labels, args.pruning_threshold_leaves, log)
    analyse_leaf_distributions(tree, log)

    '''
    PRUNE
    '''
    pruned = prune(tree, args.pruning_threshold_leaves, log)
    name = "pruned"
    save_tree_description(tree, optimizer, scheduler, name, log)
    pruned_tree = deepcopy(tree)
    # Analyse and evaluate pruned tree
    leaf_labels = analyse_leafs(tree, epoch + 2, len(classes), leaf_labels, args.pruning_threshold_leaves, log)
    analyse_leaf_distributions(tree, log)
    eval_info = eval(tree, testloader, name, device, log)
    pruned_test_acc = eval_info['test_accuracy']
    pruned_tree = tree

    '''
    PROJECT
    '''
    project_info, tree = project_with_class_constraints(deepcopy(pruned_tree), projectloader, device, args, log)
    name = "pruned_and_projected"
    save_tree_description(tree, optimizer, scheduler, name, log)
    pruned_projected_tree = deepcopy(tree)

    # Analyse and evaluate pruned tree with projected prototypes
    average_distance_nearest_image(project_info, tree, log)
    leaf_labels = analyse_leafs(tree, epoch + 3, len(classes), leaf_labels, args.pruning_threshold_leaves, log)
    analyse_leaf_distributions(tree, log)
    eval_info = eval(tree, testloader, name, device, log)
    pruned_projected_test_acc = eval_info['test_accuracy']
    eval_info_samplemax = eval(tree, testloader, name, device, log, 'sample_max')
    get_avg_path_length(tree, eval_info_samplemax, log)
    eval_info_greedy = eval(tree, testloader, name, device, log, 'greedy')
    get_avg_path_length(tree, eval_info_greedy, log)
    fidelity_info = eval_fidelity(tree, testloader, device, log)

    # Upsample prototypes for visualization
    project_info = upsample(tree, project_info, projectloader, name, args, log)

    # Visualize the tree
    gen_vis(tree, name, args, classes)

    return trained_tree.to('cpu'), pruned_tree.to('cpu'), pruned_projected_tree.to('cpu'), original_test_acc, pruned_test_acc, pruned_projected_test_acc, project_info, eval_info_samplemax, eval_info_greedy, fidelity_info
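# Usage sketch (illustrative, assuming arguments come from get_args):
#   trained_tree, pruned_tree, pruned_projected_tree, original_acc, pruned_acc, \
#       projected_acc, project_info, info_samplemax, info_greedy, fidelity = run_tree()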
def project_with_class_constraints(
        tree: ProtoTree,
        project_loader: DataLoader,
        device,
        args: argparse.Namespace,
        log: Log,
        log_prefix: str = 'log_projection_with_constraints',  # TODO
        progress_prefix: str = 'Projection') -> dict:

    log.log_message("\nProjecting prototypes to nearest training patch (with class restrictions)...")

    # Set the model to evaluation mode
    tree.eval()
    torch.cuda.empty_cache()

    # The goal is to find the latent patch that minimizes the L2 distance to each prototype.
    # To do this, iterate through the train dataset and store for each prototype the closest latent patch seen so far.
    # Also store info about the image that was used for projection.
    global_min_proto_dist = {j: np.inf for j in range(tree.num_prototypes)}
    global_min_patches = {j: None for j in range(tree.num_prototypes)}
    global_min_info = {j: None for j in range(tree.num_prototypes)}

    # Get the shape of the prototypes
    W1, H1, D = tree.prototype_shape

    # Build a progress bar for showing the status
    projection_iter = tqdm(enumerate(project_loader),
                           total=len(project_loader),
                           desc=progress_prefix,
                           ncols=0)

    with torch.no_grad():
        # Get a batch of data
        xs, ys = next(iter(project_loader))
        batch_size = xs.shape[0]

        # For each internal node, collect the leaf labels in the subtree with this node as root.
        # Only images from these classes can be used for projection.
        leaf_labels_subtree = dict()
        for branch, j in tree._out_map.items():
            leaf_labels_subtree[branch.index] = set()
            for leaf in branch.leaves:
                leaf_labels_subtree[branch.index].add(torch.argmax(leaf.distribution()).item())

        for i, (xs, ys) in projection_iter:
            xs, ys = xs.to(device), ys.to(device)
            # Get the features and distances
            # - features_batch: features tensor (shared by all prototypes)
            #   shape: (batch_size, D, W, H)
            # - distances_batch: distances tensor (for all prototypes)
            #   shape: (batch_size, num_prototypes, W, H)
            # - out_map: a dict mapping decision nodes to distances (indices)
            features_batch, distances_batch, out_map = tree.forward_partial(xs)

            # Get the feature dimensions
            bs, D, W, H = features_batch.shape

            # Get a tensor containing the individual latent patches.
            # Create the patches by unfolding over both the W and H dimensions.
            # TODO -- support for strides in the prototype layer? (corresponds to the step size here)
            patches_batch = features_batch.unfold(2, W1, 1).unfold(3, H1, 1)  # Shape: (batch_size, D, W, H, W1, H1)

            # Iterate over all decision nodes/prototypes
            for node, j in out_map.items():
                leaf_labels = leaf_labels_subtree[node.index]
                # Iterate over all items in the batch.
                # Select the features/distances that are relevant to this prototype:
                # - distances: distances of the prototype to the latent patches
                #   shape: (W, H)
                # - patches: latent patches
                #   shape: (D, W, H, W1, H1)
                for batch_i, (distances, patches) in enumerate(zip(distances_batch[:, j, :, :], patches_batch)):
                    # Check if the label of this image is in one of the leaves of the subtree
                    if ys[batch_i].item() in leaf_labels:
                        # Find the index of the latent patch that is closest to the prototype
                        min_distance = distances.min()
                        min_distance_ix = distances.argmin()
                        # Use the index to get the closest latent patch
                        closest_patch = patches.view(D, W * H, W1, H1)[:, min_distance_ix, :, :]

                        # Check if this latent patch is the closest among all data samples seen so far
                        if min_distance < global_min_proto_dist[j]:
                            global_min_proto_dist[j] = min_distance
                            global_min_patches[j] = closest_patch
                            global_min_info[j] = {
                                'input_image_ix': i * batch_size + batch_i,
                                'patch_ix': min_distance_ix.item(),  # Index in a flattened array of the feature map
                                'W': W,
                                'H': H,
                                'W1': W1,
                                'H1': H1,
                                'distance': min_distance.item(),
                                'nearest_input': torch.unsqueeze(xs[batch_i], 0),
                                'node_ix': node.index,
                            }

            # Update the progress bar if required
            projection_iter.set_postfix_str(f'Batch: {i + 1}/{len(project_loader)}')

            del features_batch
            del distances_batch
            del out_map

        # Copy the patches to the prototype layer weights
        projection = torch.cat(tuple(global_min_patches[j].unsqueeze(0) for j in range(tree.num_prototypes)), dim=0, out=tree.prototype_layer.prototype_vectors)
        del projection

    return global_min_info, tree
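# A self-contained sketch (illustrative values, not repo code) of the unfold
# call used in project_with_class_constraints: sliding W1 x H1 windows over the
# spatial dimensions of the feature map, with stride 1.
def _unfold_patches_sketch():
    bs, D, W, H = 2, 8, 7, 7  # hypothetical feature map shape
    W1, H1 = 1, 1             # hypothetical prototype shape
    features = torch.randn(bs, D, W, H)
    patches = features.unfold(2, W1, 1).unfold(3, H1, 1)
    # General shape: (bs, D, W - W1 + 1, H - H1 + 1, W1, H1); with W1 = H1 = 1
    # this is (bs, D, W, H, 1, 1), i.e. one patch per spatial position, which
    # matches the (batch_size, D, W, H, W1, H1) comment above.
    assert patches.shape == (bs, D, W - W1 + 1, H - H1 + 1, W1, H1)
    return patches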
def train_epoch(tree: ProtoTree,
                train_loader: DataLoader,
                optimizer: torch.optim.Optimizer,
                epoch: int,
                disable_derivative_free_leaf_optim: bool,
                device,
                log: Log = None,
                log_prefix: str = 'log_train_epochs',
                progress_prefix: str = 'Train Epoch'
                ) -> dict:

    tree = tree.to(device)

    # Make sure the model is in eval mode
    tree.eval()
    # Store info about the procedure
    train_info = dict()
    total_loss = 0.
    total_acc = 0.
    # Create a log if required
    log_loss = f'{log_prefix}_losses'

    nr_batches = float(len(train_loader))
    with torch.no_grad():
        _old_dist_params = dict()
        for leaf in tree.leaves:
            _old_dist_params[leaf] = leaf._dist_params.detach().clone()
        # Optimize class distributions in the leaves
        eye = torch.eye(tree._num_classes).to(device)

    # Show progress on progress bar
    train_iter = tqdm(enumerate(train_loader),
                      total=len(train_loader),
                      desc=progress_prefix + ' %s' % epoch,
                      ncols=0)
    # Iterate through the data set to update leaves, prototypes and network
    for i, (xs, ys) in train_iter:
        # Make sure the model is in train mode
        tree.train()
        # Reset the gradients
        optimizer.zero_grad()

        xs, ys = xs.to(device), ys.to(device)

        # Perform a forward pass through the network
        ys_pred, info = tree.forward(xs)

        # Learn prototypes and network with gradient descent.
        # If disable_derivative_free_leaf_optim, leaves are optimized with gradient descent as well.
        # Compute the loss
        if tree._log_probabilities:
            loss = F.nll_loss(ys_pred, ys)
        else:
            loss = F.nll_loss(torch.log(ys_pred), ys)

        # Compute the gradient
        loss.backward()
        # Update model parameters
        optimizer.step()

        if not disable_derivative_free_leaf_optim:
            # Update leaves with the derivative-free algorithm.
            # Make sure the tree is in eval mode
            tree.eval()
            with torch.no_grad():
                target = eye[ys]  # shape (batchsize, num_classes)
                for leaf in tree.leaves:
                    if tree._log_probabilities:
                        # log version
                        update = torch.exp(torch.logsumexp(info['pa_tensor'][leaf.index] + leaf.distribution() + torch.log(target) - ys_pred, dim=0))
                    else:
                        update = torch.sum((info['pa_tensor'][leaf.index] * leaf.distribution() * target) / ys_pred, dim=0)
                    leaf._dist_params -= (_old_dist_params[leaf] / nr_batches)
                    # dist_params values can get slightly negative because of floating-point issues, so clamp them to zero
                    F.relu_(leaf._dist_params)
                    leaf._dist_params += update

        # Count the number of correct classifications
        ys_pred_max = torch.argmax(ys_pred, dim=1)
        correct = torch.sum(torch.eq(ys_pred_max, ys))
        acc = correct.item() / float(len(xs))

        train_iter.set_postfix_str(f'Batch [{i + 1}/{len(train_loader)}], Loss: {loss.item():.3f}, Acc: {acc:.3f}')
        # Compute metrics over this batch
        total_loss += loss.item()
        total_acc += acc

        if log is not None:
            log.log_values(log_loss, epoch, i + 1, loss.item(), acc)

    train_info['loss'] = total_loss / float(i + 1)
    train_info['train_accuracy'] = total_acc / float(i + 1)
    return train_info
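# Note on the in-place leaf update above (an interpretation, not original
# documentation): rather than resetting the leaf distributions once per epoch,
# each of the nr_batches batches subtracts 1/nr_batches of the epoch-start
# parameters (_old_dist_params), so the old distribution is annealed away
# gradually while new evidence accumulates batch by batch; F.relu_ clamps the
# tiny negative values caused by floating-point round-off.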
def train_epoch_kontschieder(tree: ProtoTree,
                             train_loader: DataLoader,
                             optimizer: torch.optim.Optimizer,
                             epoch: int,
                             disable_derivative_free_leaf_optim: bool,
                             device,
                             log: Log = None,
                             log_prefix: str = 'log_train_epochs',
                             progress_prefix: str = 'Train Epoch'
                             ) -> dict:

    tree = tree.to(device)

    # Store info about the procedure
    train_info = dict()
    total_loss = 0.
    total_acc = 0.

    # Create a log if required
    log_loss = f'{log_prefix}_losses'
    if log is not None and epoch == 1:
        log.create_log(log_loss, 'epoch', 'batch', 'loss', 'batch_train_acc')

    # Reset the gradients
    optimizer.zero_grad()

    if disable_derivative_free_leaf_optim:
        print("WARNING: Kontschieder arguments will be ignored when training leaves with gradient descent")
    else:
        if tree._kontschieder_normalization:
            # Iterate over the dataset multiple times to learn leaves following Kontschieder's approach
            for _ in range(10):
                # Train leaves with the derivative-free algorithm using a normalization factor
                train_leaves_epoch(tree, train_loader, epoch, device)
        else:
            # Train leaves with Kontschieder's derivative-free algorithm, but using softmax
            train_leaves_epoch(tree, train_loader, epoch, device)

    # Train prototypes and network.
    # If disable_derivative_free_leaf_optim, leaves are optimized with gradient descent as well.
    # Show progress on progress bar
    train_iter = tqdm(enumerate(train_loader),
                      total=len(train_loader),
                      desc=progress_prefix + ' %s' % epoch,
                      ncols=0)
    # Make sure the model is in train mode
    tree.train()
    for i, (xs, ys) in train_iter:
        xs, ys = xs.to(device), ys.to(device)

        # Reset the gradients
        optimizer.zero_grad()
        # Perform a forward pass through the network
        ys_pred, _ = tree.forward(xs)
        # Compute the loss
        if tree._log_probabilities:
            loss = F.nll_loss(ys_pred, ys)
        else:
            loss = F.nll_loss(torch.log(ys_pred), ys)
        # Compute the gradient
        loss.backward()
        # Update model parameters
        optimizer.step()

        # Count the number of correct classifications
        ys_pred = torch.argmax(ys_pred, dim=1)
        correct = torch.sum(torch.eq(ys_pred, ys))
        acc = correct.item() / float(len(xs))

        train_iter.set_postfix_str(f'Batch [{i + 1}/{len(train_loader)}], Loss: {loss.item():.3f}, Acc: {acc:.3f}')
        # Compute metrics over this batch
        total_loss += loss.item()
        total_acc += acc

        if log is not None:
            log.log_values(log_loss, epoch, i + 1, loss.item(), acc)

    train_info['loss'] = total_loss / float(i + 1)
    train_info['train_accuracy'] = total_acc / float(i + 1)
    return train_info
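# A minimal entry-point sketch, assuming this module can be run directly (the
# original repository drives run_tree from a separate main script):
if __name__ == '__main__':
    run_tree()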