import argparse
import os
from copy import deepcopy

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm

# NOTE: the functions below also rely on project-internal helpers (ProtoTree, Log, get_args,
# save_args, get_optimizer, get_dataloaders, get_network, freeze, prune,
# project_with_class_constraints, train_leaves_epoch, init_weights_xavier, upsample, gen_vis
# and the various save_*/analyse_* utilities) imported from elsewhere in this repository.


def init_tree(tree: ProtoTree, optimizer, scheduler, device, args: argparse.Namespace):
    epoch = 1
    mean = 0.5
    std = 0.1

    # Load a trained ProtoTree if the flag is set.
    # NOTE: TRAINING FURTHER FROM A CHECKPOINT DOESN'T SEEM TO WORK CORRECTLY.
    # EVALUATING A TRAINED PROTOTREE FROM A CHECKPOINT DOES WORK.
    if args.state_dict_dir_tree != '':
        if not args.disable_cuda and torch.cuda.is_available():
            device = torch.device('cuda:{}'.format(torch.cuda.current_device()))
        else:
            device = torch.device('cpu')

        if args.disable_cuda or not torch.cuda.is_available():
            # tree = load_state(args.state_dict_dir_tree, device)
            tree = torch.load(args.state_dict_dir_tree + '/model.pth', map_location=device)
        else:
            tree = torch.load(args.state_dict_dir_tree + '/model.pth')
        tree.to(device=device)

        # Resume from the epoch encoded in the checkpoint directory name, if possible
        try:
            epoch = int(args.state_dict_dir_tree.split('epoch_')[-1]) + 1
        except ValueError:
            epoch = args.epochs + 1
        print("Train further from epoch: ", epoch, flush=True)
        optimizer.load_state_dict(torch.load(args.state_dict_dir_tree + '/optimizer_state.pth', map_location=device))

        if epoch > args.freeze_epochs:
            for parameter in tree._net.parameters():
                parameter.requires_grad = True
        if not args.disable_derivative_free_leaf_optim:
            for leaf in tree.leaves:
                leaf._dist_params.requires_grad = False

        if os.path.isfile(args.state_dict_dir_tree + '/scheduler_state.pth'):
            # scheduler.load_state_dict(torch.load(args.state_dict_dir_tree + '/scheduler_state.pth'))
            # print(scheduler.state_dict(), flush=True)
            scheduler.last_epoch = epoch - 1
            scheduler._step_count = epoch

    elif args.state_dict_dir_net != '':  # load pretrained conv network
        # Initialize prototypes
        torch.nn.init.normal_(tree.prototype_layer.prototype_vectors, mean=mean, std=std)
        # strict is False so the linear classification layer of the pretrained model is ignored
        tree._net.load_state_dict(torch.load(args.state_dict_dir_net + '/model_state.pth'), strict=False)
        tree._add_on.load_state_dict(torch.load(args.state_dict_dir_net + '/model_state.pth'), strict=False)

    else:
        with torch.no_grad():
            # Initialize prototypes
            torch.nn.init.normal_(tree.prototype_layer.prototype_vectors, mean=mean, std=std)
            tree._add_on.apply(init_weights_xavier)

    return tree, epoch
def eval_fidelity(tree: ProtoTree,
                  test_loader: DataLoader,
                  device,
                  log: Log = None,
                  progress_prefix: str = 'Fidelity'
                  ) -> dict:
    tree = tree.to(device)

    # Keep an info dict about the procedure
    info = dict()
    # Make sure the model is in evaluation mode
    tree.eval()

    # Show progress on progress bar
    test_iter = tqdm(enumerate(test_loader),
                     total=len(test_loader),
                     desc=progress_prefix,
                     ncols=0)

    distr_samplemax_fidelity = 0
    distr_greedy_fidelity = 0
    # Iterate through the test set
    for i, (xs, ys) in test_iter:
        xs, ys = xs.to(device), ys.to(device)

        # Use the model to classify this batch of input data, with 3 types of routing
        out_distr, _ = tree.forward(xs, 'distributed')
        ys_pred_distr = torch.argmax(out_distr, dim=1)

        out_samplemax, _ = tree.forward(xs, 'sample_max')
        ys_pred_samplemax = torch.argmax(out_samplemax, dim=1)

        out_greedy, _ = tree.forward(xs, 'greedy')
        ys_pred_greedy = torch.argmax(out_greedy, dim=1)

        # Calculate fidelity
        distr_samplemax_fidelity += torch.sum(torch.eq(ys_pred_samplemax, ys_pred_distr)).item()
        distr_greedy_fidelity += torch.sum(torch.eq(ys_pred_greedy, ys_pred_distr)).item()

        # Update the progress bar
        test_iter.set_postfix_str(f'Batch [{i + 1}/{len(test_iter)}]')

        del out_distr
        del out_samplemax
        del out_greedy

    distr_samplemax_fidelity = distr_samplemax_fidelity / float(len(test_loader.dataset))
    distr_greedy_fidelity = distr_greedy_fidelity / float(len(test_loader.dataset))
    info['distr_samplemax_fidelity'] = distr_samplemax_fidelity
    info['distr_greedy_fidelity'] = distr_greedy_fidelity

    log.log_message("Fidelity between standard distributed routing and sample_max routing: " + str(distr_samplemax_fidelity))
    log.log_message("Fidelity between standard distributed routing and greedy routing: " + str(distr_greedy_fidelity))
    return info
def eval(tree: ProtoTree,
         test_loader: DataLoader,
         epoch,
         device,
         log: Log = None,
         sampling_strategy: str = 'distributed',
         log_prefix: str = 'log_eval_epochs',
         progress_prefix: str = 'Eval Epoch'
         ) -> dict:
    tree = tree.to(device)

    # Keep an info dict about the procedure
    info = dict()
    if sampling_strategy != 'distributed':
        info['out_leaf_ix'] = []
    # Build a confusion matrix
    cm = np.zeros((tree._num_classes, tree._num_classes), dtype=int)

    # Make sure the model is in evaluation mode
    tree.eval()

    # Show progress on progress bar
    test_iter = tqdm(enumerate(test_loader),
                     total=len(test_loader),
                     desc=progress_prefix + ' %s' % epoch,
                     ncols=0)

    # Iterate through the test set
    for i, (xs, ys) in test_iter:
        xs, ys = xs.to(device), ys.to(device)

        # Use the model to classify this batch of input data
        out, test_info = tree.forward(xs, sampling_strategy)
        ys_pred = torch.argmax(out, dim=1)

        # Update the confusion matrix
        cm_batch = np.zeros((tree._num_classes, tree._num_classes), dtype=int)
        for y_pred, y_true in zip(ys_pred, ys):
            cm[y_true][y_pred] += 1
            cm_batch[y_true][y_pred] += 1
        acc = acc_from_cm(cm_batch)
        test_iter.set_postfix_str(f'Batch [{i + 1}/{len(test_iter)}], Acc: {acc:.3f}')

        # Keep a list of leaf indices where test samples end up when deterministic routing is used
        if sampling_strategy != 'distributed':
            info['out_leaf_ix'] += test_info['out_leaf_ix']

        del out
        del ys_pred
        del test_info

    info['confusion_matrix'] = cm
    info['test_accuracy'] = acc_from_cm(cm)
    log.log_message("\nEpoch %s - Test accuracy with %s routing: " % (epoch, sampling_strategy) + str(info['test_accuracy']))
    return info
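# acc_from_cm is used above but not defined in this section. A minimal sketch consistent with
# its usage here (accuracy computed from a square confusion matrix) is given below; this is an
# assumption for completeness, not necessarily the repository's exact implementation.
def acc_from_cm(cm: np.ndarray) -> float:
    """Compute accuracy from a confusion matrix: correct (diagonal) count over total count."""
    assert len(cm.shape) == 2 and cm.shape[0] == cm.shape[1]
    total = np.sum(cm)
    if total == 0:
        return 1.
    return np.trace(cm) / total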
def run_tree(args=None):
    args = args or get_args()

    # Create a logger
    log = Log(args.log_dir)
    print("Log dir: ", args.log_dir, flush=True)
    # Create a csv log for storing the test accuracy, mean train accuracy and mean loss for each epoch
    log.create_log('log_epoch_overview', 'epoch', 'test_acc', 'mean_train_acc', 'mean_train_crossentropy_loss_during_epoch')
    # Log the run arguments
    save_args(args, log.metadata_dir)

    if not args.disable_cuda and torch.cuda.is_available():
        # device = torch.device('cuda')
        device = torch.device('cuda:{}'.format(torch.cuda.current_device()))
    else:
        device = torch.device('cpu')

    # Log which device was actually used
    log.log_message('Device used: ' + str(device))

    # Create a log for logging the loss values
    log_prefix = 'log_train_epochs'
    log_loss = log_prefix + '_losses'
    log.create_log(log_loss, 'epoch', 'batch', 'loss', 'batch_train_acc')

    # Obtain the dataset and dataloaders
    trainloader, projectloader, testloader, classes, num_channels = get_dataloaders(args)

    # Create a convolutional network based on arguments and add 1x1 conv layer
    features_net, add_on_layers = get_network(num_channels, args)

    # Create a ProtoTree
    tree = ProtoTree(num_classes=len(classes),
                     feature_net=features_net,
                     args=args,
                     add_on_layers=add_on_layers)
    tree = tree.to(device=device)

    # Determine which optimizer should be used to update the tree parameters
    optimizer, params_to_freeze, params_to_train = get_optimizer(tree, args)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizer, milestones=args.milestones, gamma=args.gamma)
    tree, epoch = init_tree(tree, optimizer, scheduler, device, args)

    tree.save(f'{log.checkpoint_dir}/tree_init')
    log.log_message("Max depth %s, so %s internal nodes and %s leaves" % (args.depth, tree.num_branches, tree.num_leaves))
    analyse_output_shape(tree, trainloader, log, device)

    leaf_labels = dict()
    best_train_acc = 0.
    best_test_acc = 0.
    if epoch < args.epochs + 1:
        '''
            TRAIN AND EVALUATE TREE
        '''
        for epoch in range(epoch, args.epochs + 1):
            log.log_message("\nEpoch %s" % str(epoch))

            # Freeze (part of) network for some epochs if indicated in args
            freeze(tree, epoch, params_to_freeze, params_to_train, args, log)
            log_learning_rates(optimizer, args, log)

            # Train tree
            if tree._kontschieder_train:
                train_info = train_epoch_kontschieder(tree, trainloader, optimizer, epoch, args.disable_derivative_free_leaf_optim, device, log, log_prefix)
            else:
                train_info = train_epoch(tree, trainloader, optimizer, epoch, args.disable_derivative_free_leaf_optim, device, log, log_prefix)
            save_tree(tree, optimizer, scheduler, epoch, log, args)
            best_train_acc = save_best_train_tree(tree, optimizer, scheduler, best_train_acc, train_info['train_accuracy'], log)
            leaf_labels = analyse_leafs(tree, epoch, len(classes), leaf_labels, args.pruning_threshold_leaves, log)

            # Evaluate tree
            if args.epochs > 100:
                if epoch % 10 == 0 or epoch == args.epochs:
                    eval_info = eval(tree, testloader, epoch, device, log)
                    original_test_acc = eval_info['test_accuracy']
                    best_test_acc = save_best_test_tree(tree, optimizer, scheduler, best_test_acc, eval_info['test_accuracy'], log)
                    log.log_values('log_epoch_overview', epoch, eval_info['test_accuracy'], train_info['train_accuracy'], train_info['loss'])
                else:
                    log.log_values('log_epoch_overview', epoch, "n.a.", train_info['train_accuracy'], train_info['loss'])
            else:
                eval_info = eval(tree, testloader, epoch, device, log)
                original_test_acc = eval_info['test_accuracy']
                best_test_acc = save_best_test_tree(tree, optimizer, scheduler, best_test_acc, eval_info['test_accuracy'], log)
                log.log_values('log_epoch_overview', epoch, eval_info['test_accuracy'], train_info['train_accuracy'], train_info['loss'])

            scheduler.step()

    else:  # tree was loaded and not trained, so evaluate only
        '''
            EVALUATE TREE
        '''
        eval_info = eval(tree, testloader, epoch, device, log)
        original_test_acc = eval_info['test_accuracy']
        best_test_acc = save_best_test_tree(tree, optimizer, scheduler, best_test_acc, eval_info['test_accuracy'], log)
        log.log_values('log_epoch_overview', epoch, eval_info['test_accuracy'], "n.a.", "n.a.")

    '''
        EVALUATE AND ANALYSE TRAINED TREE
    '''
    log.log_message("Training Finished. Best training accuracy was %s, best test accuracy was %s\n" % (str(best_train_acc), str(best_test_acc)))
    trained_tree = deepcopy(tree)
    leaf_labels = analyse_leafs(tree, epoch + 1, len(classes), leaf_labels, args.pruning_threshold_leaves, log)
    analyse_leaf_distributions(tree, log)

    '''
        PRUNE
    '''
    pruned = prune(tree, args.pruning_threshold_leaves, log)
    name = "pruned"
    save_tree_description(tree, optimizer, scheduler, name, log)
    pruned_tree = deepcopy(tree)
    # Analyse and evaluate the pruned tree
    leaf_labels = analyse_leafs(tree, epoch + 2, len(classes), leaf_labels, args.pruning_threshold_leaves, log)
    analyse_leaf_distributions(tree, log)
    eval_info = eval(tree, testloader, name, device, log)
    pruned_test_acc = eval_info['test_accuracy']

    pruned_tree = tree

    '''
        PROJECT
    '''
    project_info, tree = project_with_class_constraints(deepcopy(pruned_tree), projectloader, device, args, log)
    name = "pruned_and_projected"
    save_tree_description(tree, optimizer, scheduler, name, log)
    pruned_projected_tree = deepcopy(tree)
    # Analyse and evaluate the pruned tree with projected prototypes
    average_distance_nearest_image(project_info, tree, log)
    leaf_labels = analyse_leafs(tree, epoch + 3, len(classes), leaf_labels, args.pruning_threshold_leaves, log)
    analyse_leaf_distributions(tree, log)
    eval_info = eval(tree, testloader, name, device, log)
    pruned_projected_test_acc = eval_info['test_accuracy']
    eval_info_samplemax = eval(tree, testloader, name, device, log, 'sample_max')
    get_avg_path_length(tree, eval_info_samplemax, log)
    eval_info_greedy = eval(tree, testloader, name, device, log, 'greedy')
    get_avg_path_length(tree, eval_info_greedy, log)
    fidelity_info = eval_fidelity(tree, testloader, device, log)

    # Upsample prototypes for visualization
    project_info = upsample(tree, project_info, projectloader, name, args, log)
    # Visualize tree
    gen_vis(tree, name, args, classes)

    return (trained_tree.to('cpu'), pruned_tree.to('cpu'), pruned_projected_tree.to('cpu'),
            original_test_acc, pruned_test_acc, pruned_projected_test_acc,
            project_info, eval_info_samplemax, eval_info_greedy, fidelity_info)
def train_epoch(tree: ProtoTree,
                train_loader: DataLoader,
                optimizer: torch.optim.Optimizer,
                epoch: int,
                disable_derivative_free_leaf_optim: bool,
                device,
                log: Log = None,
                log_prefix: str = 'log_train_epochs',
                progress_prefix: str = 'Train Epoch'
                ) -> dict:
    tree = tree.to(device)

    # Make sure the model is in eval mode
    tree.eval()
    # Store info about the procedure
    train_info = dict()
    total_loss = 0.
    total_acc = 0.

    # Create a log if required
    log_loss = f'{log_prefix}_losses'

    nr_batches = float(len(train_loader))

    with torch.no_grad():
        _old_dist_params = dict()
        for leaf in tree.leaves:
            _old_dist_params[leaf] = leaf._dist_params.detach().clone()
        # Optimize class distributions in leaves
        eye = torch.eye(tree._num_classes).to(device)

    # Show progress on progress bar
    train_iter = tqdm(enumerate(train_loader),
                      total=len(train_loader),
                      desc=progress_prefix + ' %s' % epoch,
                      ncols=0)
    # Iterate through the data set to update leaves, prototypes and network
    for i, (xs, ys) in train_iter:
        # Make sure the model is in train mode
        tree.train()
        # Reset the gradients
        optimizer.zero_grad()

        xs, ys = xs.to(device), ys.to(device)

        # Perform a forward pass through the network
        ys_pred, info = tree.forward(xs)

        # Learn prototypes and network with gradient descent.
        # If disable_derivative_free_leaf_optim, leaves are optimized with gradient descent as well.
        # Compute the loss
        if tree._log_probabilities:
            loss = F.nll_loss(ys_pred, ys)
        else:
            loss = F.nll_loss(torch.log(ys_pred), ys)

        # Compute the gradient
        loss.backward()
        # Update model parameters
        optimizer.step()

        if not disable_derivative_free_leaf_optim:
            # Update leaves with the derivative-free algorithm: each leaf accumulates a batch
            # contribution and gives up 1/nr_batches of its parameters from the previous epoch.
            # Make sure the tree is in eval mode
            tree.eval()
            with torch.no_grad():
                target = eye[ys]  # shape (batchsize, num_classes)
                for leaf in tree.leaves:
                    if tree._log_probabilities:
                        # log version
                        update = torch.exp(torch.logsumexp(info['pa_tensor'][leaf.index] + leaf.distribution() + torch.log(target) - ys_pred, dim=0))
                    else:
                        update = torch.sum((info['pa_tensor'][leaf.index] * leaf.distribution() * target) / ys_pred, dim=0)
                    leaf._dist_params -= (_old_dist_params[leaf] / nr_batches)
                    # dist_params values can get slightly negative because of floating point issues, so clamp them to zero
                    F.relu_(leaf._dist_params)
                    leaf._dist_params += update

        # Count the number of correct classifications
        ys_pred_max = torch.argmax(ys_pred, dim=1)
        correct = torch.sum(torch.eq(ys_pred_max, ys))
        acc = correct.item() / float(len(xs))

        train_iter.set_postfix_str(
            f'Batch [{i + 1}/{len(train_loader)}], Loss: {loss.item():.3f}, Acc: {acc:.3f}'
        )
        # Compute metrics over this batch
        total_loss += loss.item()
        total_acc += acc

        if log is not None:
            log.log_values(log_loss, epoch, i + 1, loss.item(), acc)

    train_info['loss'] = total_loss / float(i + 1)
    train_info['train_accuracy'] = total_acc / float(i + 1)
    return train_info
def train_epoch_kontschieder(tree: ProtoTree,
                             train_loader: DataLoader,
                             optimizer: torch.optim.Optimizer,
                             epoch: int,
                             disable_derivative_free_leaf_optim: bool,
                             device,
                             log: Log = None,
                             log_prefix: str = 'log_train_epochs',
                             progress_prefix: str = 'Train Epoch'
                             ) -> dict:
    tree = tree.to(device)

    # Store info about the procedure
    train_info = dict()
    total_loss = 0.
    total_acc = 0.

    # Create a log if required
    log_loss = f'{log_prefix}_losses'
    if log is not None and epoch == 1:
        log.create_log(log_loss, 'epoch', 'batch', 'loss', 'batch_train_acc')

    # Reset the gradients
    optimizer.zero_grad()

    if disable_derivative_free_leaf_optim:
        print("WARNING: kontschieder arguments will be ignored when training leaves with gradient descent")
    else:
        if tree._kontschieder_normalization:
            # Iterate over the dataset multiple times to learn leaves following Kontschieder's approach
            for _ in range(10):
                # Train leaves with derivative-free algorithm using a normalization factor
                train_leaves_epoch(tree, train_loader, epoch, device)
        else:
            # Train leaves with Kontschieder's derivative-free algorithm, but using softmax
            train_leaves_epoch(tree, train_loader, epoch, device)

    # Train prototypes and network.
    # If disable_derivative_free_leaf_optim, leaves are optimized with gradient descent as well.
    # Show progress on progress bar
    train_iter = tqdm(enumerate(train_loader),
                      total=len(train_loader),
                      desc=progress_prefix + ' %s' % epoch,
                      ncols=0)
    # Make sure the model is in train mode
    tree.train()
    for i, (xs, ys) in train_iter:
        xs, ys = xs.to(device), ys.to(device)

        # Reset the gradients
        optimizer.zero_grad()
        # Perform a forward pass through the network
        ys_pred, _ = tree.forward(xs)

        # Compute the loss
        if tree._log_probabilities:
            loss = F.nll_loss(ys_pred, ys)
        else:
            loss = F.nll_loss(torch.log(ys_pred), ys)

        # Compute the gradient
        loss.backward()
        # Update model parameters
        optimizer.step()

        # Count the number of correct classifications
        ys_pred = torch.argmax(ys_pred, dim=1)
        correct = torch.sum(torch.eq(ys_pred, ys))
        acc = correct.item() / float(len(xs))

        train_iter.set_postfix_str(
            f'Batch [{i + 1}/{len(train_loader)}], Loss: {loss.item():.3f}, Acc: {acc:.3f}'
        )
        # Compute metrics over this batch
        total_loss += loss.item()
        total_acc += acc

        if log is not None:
            log.log_values(log_loss, epoch, i + 1, loss.item(), acc)

    train_info['loss'] = total_loss / float(i + 1)
    train_info['train_accuracy'] = total_acc / float(i + 1)
    return train_info
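# Usage sketch (an assumption, not part of the functions above): run_tree() falls back to the
# project's get_args() when called without arguments, so a minimal command-line entry point
# could look like the following.
if __name__ == '__main__':
    run_tree()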