def eval_fidelity(tree: ProtoTree,
                  test_loader: DataLoader,
                  device,
                  log: Log = None,
                  progress_prefix: str = 'Fidelity') -> dict:
    tree = tree.to(device)

    # Keep an info dict about the procedure
    info = dict()

    # Make sure the model is in evaluation mode
    tree.eval()

    # Show progress on progress bar
    test_iter = tqdm(enumerate(test_loader),
                     total=len(test_loader),
                     desc=progress_prefix,
                     ncols=0)

    distr_samplemax_fidelity = 0
    distr_greedy_fidelity = 0

    # Iterate through the test set
    for i, (xs, ys) in test_iter:
        xs, ys = xs.to(device), ys.to(device)

        # Use the model to classify this batch of input data, with 3 types of routing
        out_distr, _ = tree.forward(xs, 'distributed')
        ys_pred_distr = torch.argmax(out_distr, dim=1)

        out_samplemax, _ = tree.forward(xs, 'sample_max')
        ys_pred_samplemax = torch.argmax(out_samplemax, dim=1)

        out_greedy, _ = tree.forward(xs, 'greedy')
        ys_pred_greedy = torch.argmax(out_greedy, dim=1)

        # Count the samples on which the deterministic strategies agree with distributed routing
        distr_samplemax_fidelity += torch.sum(torch.eq(ys_pred_samplemax, ys_pred_distr)).item()
        distr_greedy_fidelity += torch.sum(torch.eq(ys_pred_greedy, ys_pred_distr)).item()

        # Update the progress bar
        test_iter.set_postfix_str(f'Batch [{i + 1}/{len(test_iter)}]')

        del out_distr
        del out_samplemax
        del out_greedy

    distr_samplemax_fidelity = distr_samplemax_fidelity / float(len(test_loader.dataset))
    distr_greedy_fidelity = distr_greedy_fidelity / float(len(test_loader.dataset))

    info['distr_samplemax_fidelity'] = distr_samplemax_fidelity
    info['distr_greedy_fidelity'] = distr_greedy_fidelity

    if log is not None:
        log.log_message("Fidelity between standard distributed routing and sample_max routing: "
                        + str(distr_samplemax_fidelity))
        log.log_message("Fidelity between standard distributed routing and greedy routing: "
                        + str(distr_greedy_fidelity))
    return info

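# Editor's note: the helper below is an illustrative sketch and is NOT part of the original
# code base. It spells out what "fidelity" means in eval_fidelity above: the fraction of
# samples on which two routing strategies produce the same predicted class. The function
# name and example tensors are hypothetical.
def _fidelity_between_predictions(preds_a: torch.Tensor, preds_b: torch.Tensor) -> float:
    # e.g. preds_a = torch.tensor([0, 1, 2, 2, 1]), preds_b = torch.tensor([0, 1, 2, 0, 1])
    # agree on 4 of 5 samples, so the fidelity is 0.8
    assert preds_a.shape == preds_b.shape
    return torch.eq(preds_a, preds_b).float().mean().item()
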
def eval(tree: ProtoTree,
         test_loader: DataLoader,
         epoch,
         device,
         log: Log = None,
         sampling_strategy: str = 'distributed',
         log_prefix: str = 'log_eval_epochs',
         progress_prefix: str = 'Eval Epoch') -> dict:
    tree = tree.to(device)

    # Keep an info dict about the procedure
    info = dict()
    if sampling_strategy != 'distributed':
        info['out_leaf_ix'] = []

    # Build a confusion matrix
    cm = np.zeros((tree._num_classes, tree._num_classes), dtype=int)

    # Make sure the model is in evaluation mode
    tree.eval()

    # Show progress on progress bar
    test_iter = tqdm(enumerate(test_loader),
                     total=len(test_loader),
                     desc=progress_prefix + ' %s' % epoch,
                     ncols=0)

    # Iterate through the test set
    for i, (xs, ys) in test_iter:
        xs, ys = xs.to(device), ys.to(device)

        # Use the model to classify this batch of input data
        out, test_info = tree.forward(xs, sampling_strategy)
        ys_pred = torch.argmax(out, dim=1)

        # Update the confusion matrix
        cm_batch = np.zeros((tree._num_classes, tree._num_classes), dtype=int)
        for y_pred, y_true in zip(ys_pred, ys):
            cm[y_true][y_pred] += 1
            cm_batch[y_true][y_pred] += 1
        acc = acc_from_cm(cm_batch)
        test_iter.set_postfix_str(f'Batch [{i + 1}/{len(test_iter)}], Acc: {acc:.3f}')

        # Keep a list of the leaf indices where test samples end up when deterministic routing is used
        if sampling_strategy != 'distributed':
            info['out_leaf_ix'] += test_info['out_leaf_ix']

        del out
        del ys_pred
        del test_info

    info['confusion_matrix'] = cm
    info['test_accuracy'] = acc_from_cm(cm)
    if log is not None:
        log.log_message("\nEpoch %s - Test accuracy with %s routing: " % (epoch, sampling_strategy)
                        + str(info['test_accuracy']))
    return info

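# Editor's note: acc_from_cm is called above but not defined in this file. The sketch below is
# an assumption about its behaviour (accuracy = diagonal of the confusion matrix divided by the
# total count), provided only for reference; the actual implementation may differ.
def _acc_from_cm_sketch(cm: np.ndarray) -> float:
    assert cm.ndim == 2 and cm.shape[0] == cm.shape[1]
    total = cm.sum()
    if total == 0:
        return 1.0  # treat an empty batch as fully correct
    return float(np.trace(cm)) / float(total)
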
def train_leaves_epoch(tree: ProtoTree,
                       train_loader: DataLoader,
                       epoch: int,
                       device,
                       progress_prefix: str = 'Train Leafs Epoch'
                       ):
    # Make sure the tree is in eval mode for updating the leaves
    tree.eval()

    with torch.no_grad():
        _old_dist_params = dict()
        for leaf in tree.leaves:
            _old_dist_params[leaf] = leaf._dist_params.detach().clone()
        # Optimize class distributions in the leaves
        eye = torch.eye(tree._num_classes).to(device)

        # Show progress on progress bar
        train_iter = tqdm(enumerate(train_loader),
                          total=len(train_loader),
                          desc=progress_prefix + ' %s' % epoch,
                          ncols=0)

        # Iterate through the data set
        update_sum = dict()

        # Create an empty tensor for each leaf that will be filled with new values
        for leaf in tree.leaves:
            update_sum[leaf] = torch.zeros_like(leaf._dist_params)

        for i, (xs, ys) in train_iter:
            xs, ys = xs.to(device), ys.to(device)
            # Train the leaves without gradient descent
            out, info = tree.forward(xs)
            target = eye[ys]  # shape (batchsize, num_classes)
            for leaf in tree.leaves:
                if tree._log_probabilities:
                    # Log-space version of the update
                    update = torch.exp(torch.logsumexp(info['pa_tensor'][leaf.index]
                                                       + leaf.distribution()
                                                       + torch.log(target)
                                                       - out, dim=0))
                else:
                    update = torch.sum((info['pa_tensor'][leaf.index] * leaf.distribution() * target) / out,
                                       dim=0)
                update_sum[leaf] += update

        for leaf in tree.leaves:
            leaf._dist_params -= leaf._dist_params  # set the current dist params to zero
            leaf._dist_params += update_sum[leaf]   # give the dist params their new value

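# Editor's note: illustrative sketch, NOT part of the original code base. It isolates the
# derivative-free leaf update used above for a single leaf in non-log space: an elementwise
# sum over the batch of pa(x) * pi * onehot(y) / y_hat(x). The shapes documented below are
# assumptions made for illustration.
def _leaf_update_sketch(pa: torch.Tensor, pi: torch.Tensor,
                        target: torch.Tensor, y_hat: torch.Tensor) -> torch.Tensor:
    # pa:     (batch_size, 1)            probability of arriving at this leaf
    # pi:     (num_classes,)             current leaf class distribution
    # target: (batch_size, num_classes)  one-hot labels
    # y_hat:  (batch_size, num_classes)  tree output for the batch
    return torch.sum((pa * pi * target) / y_hat, dim=0)  # shape: (num_classes,)
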
def gen_pred_vis(tree: ProtoTree,
                 sample: torch.Tensor,
                 sample_dir: str,
                 folder_name: str,
                 args: argparse.Namespace,
                 classes: tuple,
                 pred_kwargs: dict = None,
                 ):
    pred_kwargs = pred_kwargs or dict()  # TODO -- assert deterministic routing

    # Create a directory to store the visualization
    img_name = sample_dir.split('/')[-1].split(".")[-2]
    if not os.path.exists(os.path.join(args.log_dir, folder_name)):
        os.makedirs(os.path.join(args.log_dir, folder_name))
    destination_folder = os.path.join(os.path.join(args.log_dir, folder_name), img_name)

    if not os.path.isdir(destination_folder):
        os.mkdir(destination_folder)
    if not os.path.isdir(destination_folder + '/node_vis'):
        os.mkdir(destination_folder + '/node_vis')

    # Get references to where the source files are stored
    upsample_path = os.path.join(os.path.join(args.log_dir, args.dir_for_saving_images),
                                 'pruned_and_projected')
    nodevis_path = os.path.join(args.log_dir, 'pruned_and_projected/node_vis')
    local_upsample_path = os.path.join(destination_folder, args.dir_for_saving_images)

    # Get the model prediction
    with torch.no_grad():
        pred, pred_info = tree.forward(sample, sampling_strategy='greedy', **pred_kwargs)
        probs = pred_info['ps']
        label_ix = torch.argmax(pred, dim=1)[0].item()
        assert 'out_leaf_ix' in pred_info.keys()

    # Save the input image
    sample_path = destination_folder + '/node_vis/sample.jpg'
    # save_image(sample, sample_path)
    Image.open(sample_dir).save(sample_path)

    # Save an image containing the model output
    output_path = destination_folder + '/node_vis/output.jpg'
    leaf_ix = pred_info['out_leaf_ix'][0]
    leaf = tree.nodes_by_index[leaf_ix]
    decision_path = tree.path_to(leaf)
    upsample_local(tree, sample, sample_dir, folder_name, img_name, decision_path, args)

    # The prediction graph is visualized with Graphviz
    # Build the dot string
    s = 'digraph T {margin=0;rankdir=LR\n'
    # s += "subgraph {"
    s += 'node [shape=plaintext, label=""];\n'
    s += 'edge [penwidth="0.5"];\n'

    # Create a node for the sample image
    s += f'sample[image="{sample_path}"];\n'

    # Create nodes for all decisions/branches along the decision path
    for i, node in enumerate(decision_path[:-1]):
        node_ix = node.index
        prob = probs[node_ix].item()

        s += f'node_{i+1}[image="{upsample_path}/{node_ix}_nearest_patch_of_image.png" group="{"g"+str(i)}"];\n'
        if prob > 0.5:
            s += f'node_{i+1}_original[image="{local_upsample_path}/{node_ix}_bounding_box_nearest_patch_of_image.png" imagescale=width group="{"g"+str(i)}"];\n'
            label = "Present \nSimilarity %.4f " % prob
            s += f'node_{i+1}->node_{i+1}_original [label="{label}" fontsize=10 fontname=Helvetica];\n'
        else:
            s += f'node_{i+1}_original[image="{sample_path}" group="{"g"+str(i)}"];\n'
            label = "Absent \nSimilarity %.4f " % prob
            s += f'node_{i+1}->node_{i+1}_original [label="{label}" fontsize=10 fontname=Helvetica];\n'
        # s += f'node_{i+1}_original->node_{i+1} [label="{label}" fontsize=10 fontname=Helvetica];\n'

        s += f'node_{i+1}->node_{i+2};\n'
        s += "{rank = same; " + f'node_{i+1}_original' + "; " + f'node_{i+1}' + "};"

    # Create a node for the model output
    s += f'node_{len(decision_path)}[imagepos="tc" imagescale=height image="{nodevis_path}/node_{leaf_ix}_vis.jpg" label="{classes[label_ix]}" labelloc=b fontsize=10 penwidth=0 fontname=Helvetica];\n'

    # Connect the input image to the first decision node
    s += 'sample->node_1;\n'

    s += '}\n'

    with open(os.path.join(destination_folder, 'predvis.dot'), 'w') as f:
        f.write(s)

    from_p = os.path.join(destination_folder, 'predvis.dot')
    to_pdf = os.path.join(destination_folder, 'predvis.pdf')
    check_call('dot -Tpdf -Gmargin=0 %s -o %s' % (from_p, to_pdf), shell=True)

def train_epoch(tree: ProtoTree,
                train_loader: DataLoader,
                optimizer: torch.optim.Optimizer,
                epoch: int,
                disable_derivative_free_leaf_optim: bool,
                device,
                log: Log = None,
                log_prefix: str = 'log_train_epochs',
                progress_prefix: str = 'Train Epoch'
                ) -> dict:

    tree = tree.to(device)

    # Make sure the model is in eval mode
    tree.eval()

    # Store info about the procedure
    train_info = dict()
    total_loss = 0.
    total_acc = 0.

    # Create a log if required
    log_loss = f'{log_prefix}_losses'

    nr_batches = float(len(train_loader))
    with torch.no_grad():
        _old_dist_params = dict()
        for leaf in tree.leaves:
            _old_dist_params[leaf] = leaf._dist_params.detach().clone()
        # Optimize class distributions in the leaves
        eye = torch.eye(tree._num_classes).to(device)

    # Show progress on progress bar
    train_iter = tqdm(enumerate(train_loader),
                      total=len(train_loader),
                      desc=progress_prefix + ' %s' % epoch,
                      ncols=0)

    # Iterate through the data set to update leaves, prototypes and network
    for i, (xs, ys) in train_iter:
        # Make sure the model is in train mode
        tree.train()
        # Reset the gradients
        optimizer.zero_grad()

        xs, ys = xs.to(device), ys.to(device)

        # Perform a forward pass through the network
        ys_pred, info = tree.forward(xs)

        # Learn prototypes and network with gradient descent.
        # If disable_derivative_free_leaf_optim, leaves are optimized with gradient descent as well.
        # Compute the loss
        if tree._log_probabilities:
            loss = F.nll_loss(ys_pred, ys)
        else:
            loss = F.nll_loss(torch.log(ys_pred), ys)

        # Compute the gradient
        loss.backward()
        # Update model parameters
        optimizer.step()

        if not disable_derivative_free_leaf_optim:
            # Update the leaves with the derivative-free algorithm
            # Make sure the tree is in eval mode
            tree.eval()
            with torch.no_grad():
                target = eye[ys]  # shape (batchsize, num_classes)
                for leaf in tree.leaves:
                    if tree._log_probabilities:
                        # log version
                        update = torch.exp(torch.logsumexp(info['pa_tensor'][leaf.index]
                                                           + leaf.distribution()
                                                           + torch.log(target)
                                                           - ys_pred, dim=0))
                    else:
                        update = torch.sum((info['pa_tensor'][leaf.index] * leaf.distribution() * target) / ys_pred,
                                           dim=0)
                    leaf._dist_params -= (_old_dist_params[leaf] / nr_batches)
                    # dist_params values can get slightly negative because of floating point issues,
                    # so clamp them to zero
                    F.relu_(leaf._dist_params)
                    leaf._dist_params += update

        # Count the number of correct classifications
        ys_pred_max = torch.argmax(ys_pred, dim=1)
        correct = torch.sum(torch.eq(ys_pred_max, ys))
        acc = correct.item() / float(len(xs))

        train_iter.set_postfix_str(
            f'Batch [{i + 1}/{len(train_loader)}], Loss: {loss.item():.3f}, Acc: {acc:.3f}'
        )

        # Compute metrics over this batch
        total_loss += loss.item()
        total_acc += acc

        if log is not None:
            log.log_values(log_loss, epoch, i + 1, loss.item(), acc)

    train_info['loss'] = total_loss / float(i + 1)
    train_info['train_accuracy'] = total_acc / float(i + 1)
    return train_info

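# Editor's note: minimal driver sketch, NOT part of the original code base. It only shows how
# train_epoch and eval are typically chained per epoch; the tree, data loaders, optimizer,
# device and log are assumed to be constructed elsewhere, and the function name is hypothetical.
def _run_training_sketch(tree, train_loader, test_loader, optimizer, device,
                         num_epochs: int, log: Log = None) -> dict:
    info = dict()
    for epoch in range(1, num_epochs + 1):
        train_epoch(tree, train_loader, optimizer, epoch,
                    disable_derivative_free_leaf_optim=False,
                    device=device, log=log)
        info[epoch] = eval(tree, test_loader, epoch, device, log=log)
    return info
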
def train_epoch_kontschieder(tree: ProtoTree,
                             train_loader: DataLoader,
                             optimizer: torch.optim.Optimizer,
                             epoch: int,
                             disable_derivative_free_leaf_optim: bool,
                             device,
                             log: Log = None,
                             log_prefix: str = 'log_train_epochs',
                             progress_prefix: str = 'Train Epoch'
                             ) -> dict:

    tree = tree.to(device)

    # Store info about the procedure
    train_info = dict()
    total_loss = 0.
    total_acc = 0.

    # Create a log if required
    log_loss = f'{log_prefix}_losses'
    if log is not None and epoch == 1:
        log.create_log(log_loss, 'epoch', 'batch', 'loss', 'batch_train_acc')

    # Reset the gradients
    optimizer.zero_grad()

    if disable_derivative_free_leaf_optim:
        print("WARNING: kontschieder arguments will be ignored when training leaves with gradient descent")
    else:
        if tree._kontschieder_normalization:
            # Iterate over the dataset multiple times to learn the leaves following Kontschieder's approach
            for _ in range(10):
                # Train leaves with the derivative-free algorithm using a normalization factor
                train_leaves_epoch(tree, train_loader, epoch, device)
        else:
            # Train leaves with Kontschieder's derivative-free algorithm, but using softmax
            train_leaves_epoch(tree, train_loader, epoch, device)

    # Train prototypes and network.
    # If disable_derivative_free_leaf_optim, leaves are optimized with gradient descent as well.
    # Show progress on progress bar
    train_iter = tqdm(enumerate(train_loader),
                      total=len(train_loader),
                      desc=progress_prefix + ' %s' % epoch,
                      ncols=0)
    # Make sure the model is in train mode
    tree.train()
    for i, (xs, ys) in train_iter:
        xs, ys = xs.to(device), ys.to(device)

        # Reset the gradients
        optimizer.zero_grad()
        # Perform a forward pass through the network
        ys_pred, _ = tree.forward(xs)

        # Compute the loss
        if tree._log_probabilities:
            loss = F.nll_loss(ys_pred, ys)
        else:
            loss = F.nll_loss(torch.log(ys_pred), ys)
        # Compute the gradient
        loss.backward()
        # Update model parameters
        optimizer.step()

        # Count the number of correct classifications
        ys_pred = torch.argmax(ys_pred, dim=1)
        correct = torch.sum(torch.eq(ys_pred, ys))
        acc = correct.item() / float(len(xs))

        train_iter.set_postfix_str(
            f'Batch [{i + 1}/{len(train_loader)}], Loss: {loss.item():.3f}, Acc: {acc:.3f}'
        )

        # Compute metrics over this batch
        total_loss += loss.item()
        total_acc += acc

        if log is not None:
            log.log_values(log_loss, epoch, i + 1, loss.item(), acc)

    train_info['loss'] = total_loss / float(i + 1)
    train_info['train_accuracy'] = total_acc / float(i + 1)
    return train_info