def main():
    """MNIST experiment: train or load a GSENN model, then measure the
    stability of its relevance coefficients (thetas) under additive input
    noise, collecting per-class distances between clean and perturbed thetas.
    """
    args = parse_args()
    args.nclasses = 10
    args.theta_dim = args.nclasses

    model_path, log_path, results_path = generate_dir_names('mnist', args)

    train_loader, valid_loader, test_loader, train_tds, test_tds = load_mnist_data(
        batch_size=args.batch_size, num_workers=args.num_workers
    )

    # Choose the concept encoder h. With h_type == 'input' every pixel is its
    # own concept (plus an optional bias concept).
    if args.h_type == 'input':
        conceptizer = input_conceptizer()
        args.nconcepts = 28 * 28 + int(not args.nobias)
    elif args.h_type == 'cnn':
        conceptizer = image_cnn_conceptizer(28 * 28, args.nconcepts,
                                            args.concept_dim)
    else:
        conceptizer = image_fcc_conceptizer(28 * 28, args.nconcepts,
                                            args.concept_dim)

    # Initialize model: theta network, aggregator, and the GSENN wrapper.
    parametrizer = image_parametrizer(28 * 28, args.nconcepts, args.theta_dim,
                                      only_positive=args.positive_theta)
    aggregator = additive_scalar_aggregator(args.concept_dim, args.nclasses)
    model = GSENN(conceptizer, parametrizer, aggregator)

    # Optionally restore a previously trained model (stored whole in the
    # checkpoint dict under 'model').
    if args.load_model:
        checkpoint = torch.load(os.path.join(model_path, 'model_best.pth.tar'),
                                map_location=lambda storage, loc: storage)
        model = checkpoint['model']

    # Pick the trainer matching the requested theta regularization.
    if args.theta_reg_type in ['unreg', 'none', None]:
        trainer = VanillaClassTrainer(model, args)
    elif args.theta_reg_type == 'grad1':
        trainer = GradPenaltyTrainer(model, args, typ=1)
    elif args.theta_reg_type == 'grad2':
        trainer = GradPenaltyTrainer(model, args, typ=2)
    elif args.theta_reg_type == 'grad3':
        trainer = GradPenaltyTrainer(model, args, typ=3)
    elif args.theta_reg_type == 'crosslip':
        trainer = CLPenaltyTrainer(model, args)
    else:
        raise ValueError('Unrecognized theta_reg_type')

    if not args.load_model and args.train:
        trainer.train(train_loader, valid_loader, epochs=args.epochs,
                      save_path=model_path)
        trainer.plot_losses(save_path=results_path)
    else:
        # Nothing else specified: load the best saved model and fall back to a
        # vanilla trainer for evaluation.
        checkpoint = torch.load(os.path.join(model_path, 'model_best.pth.tar'),
                                map_location=lambda storage, loc: storage)
        model = checkpoint['model']
        trainer = VanillaClassTrainer(model, args)

    print("Done training/ loading model")

    # Evaluation
    ### 1. Single point lipschitz estimate via black box optim
    # All methods tested with BB optim for fair comparison.
    features = None
    classes = [str(i) for i in range(10)]
    model.eval()

    expl = gsenn_wrapper(model,
                         mode='classification',
                         input_type='image',
                         multiclass=True,
                         feature_names=features,
                         class_names=classes,
                         train_data=train_loader,
                         skip_bias=True,
                         verbose=False)

    print("Results_path", results_path)

    # Noise-stability experiment: for each test point, compare the relevance
    # vector of the predicted class with the one obtained on a noisy copy.
    distances = defaultdict(list)
    scale = 0.5
    for i in tqdm(range(1000)):
        x = Variable(test_tds[i][0].view(1, 1, 28, 28), volatile=True)
        true_class = test_tds[i][1][0].item()

        pred = model(x)
        theta = model.thetas.data.cpu().numpy().squeeze()
        klass = pred.data.max(1)[1]
        deps = theta[:, klass].squeeze()

        # Add noise to the sample and repeat. Fixed: the original fed the raw
        # noise tensor to the model (model(noise)) instead of the perturbed
        # sample, so it never measured stability around x.
        noise = Variable(scale * torch.randn(x.size()), volatile=True)
        pred = model(x + noise)
        theta = model.thetas.data.cpu().numpy().squeeze()
        klass_noise = pred.data.max(1)[1]
        # Deliberately index with the clean prediction's class (klass, not
        # klass_noise) so both theta vectors refer to the same output class.
        deps_noise = theta[:, klass].squeeze()

        dist = np.linalg.norm(deps - deps_noise)
        distances[true_class].append(dist)

    print(distances)
def main():
    """CIFAR-10 experiment: train or load a GSENN model, then run black-box
    Lipschitz stability analyses of its explanations and pickle the results.
    """
    args = parse_args()
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    args.nclasses = 10
    args.theta_dim = args.nclasses

    # Simple/VGG parametrizers work at native CIFAR resolution; other
    # torchvision architectures need 224x224 inputs.
    if (args.theta_arch == 'simple') or ('vgg' in args.theta_arch):
        H, W = 32, 32
    else:
        # Need to resize to have access to torchvision's models
        H, W = 224, 224
    args.input_dim = H * W

    model_path, log_path, results_path = generate_dir_names('cifar', args)

    train_loader, valid_loader, test_loader, train_tds, test_tds = load_cifar_data(
        batch_size=args.batch_size, num_workers=args.num_workers, resize=(H, W))

    # Concept encoder. input_conceptizer treats the (optional) bias like any
    # other concept, hence the extra "+ int(not args.nobias)".
    if args.h_type == 'input':
        conceptizer = input_conceptizer()
        args.nconcepts = args.input_dim + int(not args.nobias)
    elif args.h_type == 'cnn':
        conceptizer = image_cnn_conceptizer(args.input_dim, args.nconcepts,
                                            args.concept_dim, nchannel=3)
    else:
        conceptizer = image_fcc_conceptizer(args.input_dim, args.nconcepts,
                                            args.concept_dim, nchannel=3)

    # Theta network, selected by architecture name.
    if args.theta_arch == 'simple':
        parametrizer = image_parametrizer(args.input_dim, args.nconcepts,
                                          args.theta_dim, nchannel=3,
                                          only_positive=args.positive_theta)
    elif 'vgg' in args.theta_arch:
        parametrizer = vgg_parametrizer(args.input_dim, args.nconcepts,
                                        args.theta_dim, arch=args.theta_arch,
                                        nchannel=3,
                                        only_positive=args.positive_theta)
    else:
        parametrizer = torchvision_parametrizer(args.input_dim, args.nconcepts,
                                                args.theta_dim,
                                                arch=args.theta_arch, nchannel=3,
                                                only_positive=args.positive_theta)

    aggregator = additive_scalar_aggregator(args.concept_dim, args.nclasses)
    model = GSENN(conceptizer, parametrizer, aggregator)

    # Pick the trainer matching the requested theta regularization.
    if args.theta_reg_type in ['unreg', 'none', None]:
        trainer = VanillaClassTrainer(model, args)
    elif args.theta_reg_type == 'grad1':
        trainer = GradPenaltyTrainer(model, args, typ=1)
    elif args.theta_reg_type == 'grad2':
        trainer = GradPenaltyTrainer(model, args, typ=2)
    elif args.theta_reg_type == 'grad3':
        trainer = GradPenaltyTrainer(model, args, typ=3)
    elif args.theta_reg_type == 'crosslip':
        trainer = CLPenaltyTrainer(model, args)
    else:
        raise ValueError('Unrecognized theta_reg_type')

    # Train unless we were explicitly asked to load AND a checkpoint exists.
    if args.train or not args.load_model or (not os.path.isfile(
            os.path.join(model_path, 'model_best.pth.tar'))):
        trainer.train(train_loader, valid_loader, epochs=args.epochs,
                      save_path=model_path)
        trainer.plot_losses(save_path=results_path)
    else:
        checkpoint = torch.load(os.path.join(model_path, 'model_best.pth.tar'),
                                map_location=lambda storage, loc: storage)
        model = checkpoint['model']
        # Arbitrary trainer choice; only needed to compute validation accuracy.
        trainer = VanillaClassTrainer(model, args)

    model.eval()

    All_Results = {}

    ### 1. Single point lipschitz estimate via black box optim (for fair
    # comparison with other methods in which we have to use BB optimization).
    features = None
    classes = [str(i) for i in range(10)]

    expl = gsenn_wrapper(model,
                         mode='classification',
                         input_type='image',
                         multiclass=True,
                         feature_names=features,
                         class_names=classes,
                         train_data=train_loader,
                         skip_bias=True,
                         verbose=False)

    # Argmax theta-stability plot (only meaningful when concepts are raw inputs).
    if args.h_type == 'input':
        x = next(iter(test_tds))[0].numpy()
        y = next(iter(test_tds))[0].numpy()
        x_raw = (test_tds.test_data[0].float() / 255).numpy()
        y_raw = revert_to_raw(x)
        att_x = expl(x, show_plot=False)
        att_y = expl(y, show_plot=False)
        lip = 1
        lipschitz_argmax_plot(x_raw, y_raw, att_x, att_y, lip)

    ### 2. Single example lipschitz estimate with Black Box
    do_bb_stability_example = True
    if do_bb_stability_example:
        print('**** Performing lipschitz estimation for a single point ****')
        idx = 0
        print('Example index: {}'.format(idx))
        x = next(iter(test_tds))[0].numpy()
        x_raw = (test_tds.test_data[0] / 255)

        Results = {}
        lip, argmax = expl.local_lipschitz_estimate(x,
                                                    bound_type='box_std',
                                                    optim=args.optim,
                                                    eps=args.lip_eps,
                                                    n_calls=4 * args.lip_calls,
                                                    njobs=1,
                                                    verbose=2)
        Results['lip_argmax'] = (x, argmax, lip)
        att = expl(x, None, show_plot=False)
        att_argmax = expl(argmax, None, show_plot=False)

        Argmax_dict = {'lip': lip, 'argmax': argmax, 'x': x}
        fpath = os.path.join(results_path, 'argmax_lip_gp_senn.pdf')
        if args.h_type == 'input':
            lipschitz_argmax_plot(x_raw, revert_to_raw(argmax), att,
                                  att_argmax, lip, save_path=fpath)
        # Close the pickle file deterministically (was an unclosed open()).
        with open(results_path + '/argmax_lip_gp_senn.pkl', "wb") as f:
            pickle.dump(Argmax_dict, f)

    ### 3. Local lipschitz estimate over multiple samples with Black Box optim
    do_bb_stability = True
    if do_bb_stability:
        print(
            '**** Performing black-box lipschitz estimation over subset of dataset ****'
        )
        maxpoints = 20
        # valid_loader is shuffled, so taking its first batch amounts to a
        # random choice of points.
        mini_test = next(iter(valid_loader))[0][:maxpoints].numpy()
        lips = expl.estimate_dataset_lipschitz(mini_test,
                                               n_jobs=-1,
                                               bound_type='box_std',
                                               eps=args.lip_eps,
                                               optim=args.optim,
                                               n_calls=args.lip_calls,
                                               verbose=2)
        Stability_dict = {'lips': lips}
        with open(results_path + '_stability_blackbox.pkl', "wb") as f:
            pickle.dump(Stability_dict, f)
        All_Results['stability_blackbox'] = lips

    # Fixed: dropped a stray ".format(dataname)" — 'dataname' was undefined
    # here, so the original line raised NameError at runtime.
    with open(results_path + '_combined_metrics.pkl', "wb") as f:
        pickle.dump(All_Results, f)
def main():
    """MNIST experiment (new_wrapper variant): train or load a GSENN model,
    build an explanation wrapper, optionally run single-point and dataset-wide
    black-box Lipschitz estimation, and save a concept grid plus metrics.
    """
    args = parse_args()
    args.nclasses = 10
    args.theta_dim = args.nclasses

    model_path, log_path, results_path = generate_dir_names('mnist', args)

    train_loader, valid_loader, test_loader, train_tds, test_tds = load_mnist_data(
        batch_size=args.batch_size, num_workers=args.num_workers)

    # Concept encoder selection (see the CIFAR variant for the same pattern).
    if args.h_type == 'input':
        conceptizer = input_conceptizer()
        args.nconcepts = 28 * 28 + int(not args.nobias)
    elif args.h_type == 'cnn':
        conceptizer = image_cnn_conceptizer(28 * 28, args.nconcepts,
                                            args.concept_dim)
    else:
        conceptizer = image_fcc_conceptizer(28 * 28, args.nconcepts,
                                            args.concept_dim)

    parametrizer = image_parametrizer(28 * 28, args.nconcepts, args.theta_dim,
                                      only_positive=args.positive_theta)
    aggregator = additive_scalar_aggregator(args.concept_dim, args.nclasses)
    model = GSENN(conceptizer, parametrizer, aggregator)

    if args.load_model:
        checkpoint = torch.load(os.path.join(model_path, 'model_best.pth.tar'),
                                map_location=lambda storage, loc: storage)
        model = checkpoint['model']

    # Pick the trainer matching the requested theta regularization.
    if args.theta_reg_type in ['unreg', 'none', None]:
        trainer = VanillaClassTrainer(model, args)
    elif args.theta_reg_type == 'grad1':
        trainer = GradPenaltyTrainer(model, args, typ=1)
    elif args.theta_reg_type == 'grad2':
        trainer = GradPenaltyTrainer(model, args, typ=2)
    elif args.theta_reg_type == 'grad3':
        trainer = GradPenaltyTrainer(model, args, typ=3)
    elif args.theta_reg_type == 'crosslip':
        trainer = CLPenaltyTrainer(model, args)
    else:
        raise ValueError('Unrecognized theta_reg_type')

    if not args.load_model and args.train:
        trainer.train(train_loader, valid_loader, epochs=args.epochs,
                      save_path=model_path)
        trainer.plot_losses(save_path=results_path)
    else:
        checkpoint = torch.load(os.path.join(model_path, 'model_best.pth.tar'),
                                map_location=lambda storage, loc: storage)
        model = checkpoint['model']
        trainer = VanillaClassTrainer(model, args)

    trainer.validate(test_loader, fold='test')

    All_Results = {}

    ### 1. Single point lipschitz estimate via black box optim
    # All methods tested with BB optim for fair comparison.
    features = None
    classes = [str(i) for i in range(10)]
    model.eval()

    expl = new_wrapper(model,
                       mode='classification',
                       input_type='image',
                       multiclass=True,
                       feature_names=features,
                       class_names=classes,
                       train_data=train_loader,
                       skip_bias=True,
                       verbose=False)

    # Argmax theta-stability plot (only meaningful when concepts are raw inputs).
    if args.h_type == 'input':
        x = next(iter(test_tds))[0].numpy()
        y = next(iter(test_tds))[0].numpy()
        x_raw = (test_tds.test_data[0].float() / 255).numpy()
        y_raw = revert_to_raw(x)
        att_x = expl(x, show_plot=False)
        att_y = expl(y, show_plot=False)
        lip = 1
        lipschitz_argmax_plot(x_raw, y_raw, att_x, att_y, lip)

    ### 2. Single example lipschitz estimate with Black Box
    do_bb_stability_example = False  # toggle: disabled in this variant
    if do_bb_stability_example:
        print('**** Performing lipschitz estimation for a single point ****')
        idx = 0
        print('Example index: {}'.format(idx))
        x = next(iter(test_tds))[0].numpy()
        x_raw = (test_tds.test_data[0].float() / 255).numpy()

        Results = {}
        lip, argmax = expl.local_lipschitz_estimate(x,
                                                    bound_type='box_std',
                                                    optim=args.optim,
                                                    eps=args.lip_eps,
                                                    n_calls=4 * args.lip_calls,
                                                    njobs=1,
                                                    verbose=2)
        Results['lip_argmax'] = (x, argmax, lip)
        att = expl(x, None, show_plot=False)
        att_argmax = expl(argmax, None, show_plot=False)

        Argmax_dict = {'lip': lip, 'argmax': argmax, 'x': x}
        fpath = os.path.join(results_path, 'argmax_lip_gp_senn.pdf')
        if args.h_type == 'input':
            lipschitz_argmax_plot(x_raw, revert_to_raw(argmax), att,
                                  att_argmax, lip, save_path=fpath)
        # Fixed: removed a leftover pdb.set_trace() breakpoint here that would
        # hang any non-interactive run.
        pickle.dump(Argmax_dict,
                    open(results_path + '/argmax_lip_gp_senn.pkl', "wb"))

    # NOTE(review): indentation was ambiguous in the original; this call sits
    # at top level so the plots are produced even with the toggles above off —
    # confirm against the intended experiment setup.
    noise_stability_plots(model, test_tds, cuda=args.cuda,
                          save_path=results_path)

    ### 3. Local lipschitz estimate over multiple samples with Black Box optim
    do_bb_stability = False  # toggle: disabled in this variant (was True)
    if do_bb_stability:
        print(
            '**** Performing black-box lipschitz estimation over subset of dataset ****'
        )
        maxpoints = 20
        # valid_loader is shuffled, so taking its first batch amounts to a
        # random choice of points.
        mini_test = next(iter(valid_loader))[0][:maxpoints].numpy()
        lips = expl.estimate_dataset_lipschitz(mini_test,
                                               n_jobs=-1,
                                               bound_type='box_std',
                                               eps=args.lip_eps,
                                               optim=args.optim,
                                               n_calls=args.lip_calls,
                                               verbose=2)
        # Fixed: removed a leftover pdb.set_trace() breakpoint here.
        Stability_dict = {'lips': lips}
        pickle.dump(Stability_dict,
                    open(results_path + '_stability_blackbox.pkl', "wb"))
        All_Results['stability_blackbox'] = lips

    # Add concept plot.
    concept_grid(model, test_loader, top_k=10,
                 save_path=results_path + '/concept_grid.pdf')

    pickle.dump(All_Results,
                open(results_path + '_combined_metrics.pkl', "wb"))
def main():
    """MNIST faithfulness experiment: train or load a GSENN model, compute
    faithfulness correlations over the test set, plot correlation box plots,
    theta/h distribution histograms, noise-stability curves, and a concept grid.
    """
    args = parse_args()
    args.nclasses = 10
    args.theta_dim = args.nclasses

    model_path, log_path, results_path = generate_dir_names('mnist', args)

    train_loader, valid_loader, test_loader, train_tds, test_tds = load_mnist_data(
        batch_size=args.batch_size, num_workers=args.num_workers)

    # Concept encoder selection.
    if args.h_type == 'input':
        conceptizer = input_conceptizer()
        args.nconcepts = 28 * 28 + int(not args.nobias)
    elif args.h_type == 'cnn':
        conceptizer = image_cnn_conceptizer(28 * 28, args.nconcepts,
                                            args.concept_dim)
    else:
        conceptizer = image_fcc_conceptizer(28 * 28, args.nconcepts,
                                            args.concept_dim)

    parametrizer = image_parametrizer(28 * 28, args.nconcepts, args.theta_dim,
                                      only_positive=args.positive_theta)
    aggregator = additive_scalar_aggregator(args.concept_dim, args.nclasses)
    model = GSENN(conceptizer, parametrizer, aggregator)

    if args.load_model:
        checkpoint = torch.load(os.path.join(model_path, 'model_best.pth.tar'),
                                map_location=lambda storage, loc: storage)
        model = checkpoint['model']

    # Pick the trainer matching the requested theta regularization.
    if args.theta_reg_type in ['unreg', 'none', None]:
        trainer = VanillaClassTrainer(model, args)
    elif args.theta_reg_type == 'grad1':
        trainer = GradPenaltyTrainer(model, args, typ=1)
    elif args.theta_reg_type == 'grad2':
        trainer = GradPenaltyTrainer(model, args, typ=2)
    elif args.theta_reg_type == 'grad3':
        trainer = GradPenaltyTrainer(model, args, typ=3)
    elif args.theta_reg_type == 'crosslip':
        trainer = CLPenaltyTrainer(model, args)
    else:
        raise ValueError('Unrecognized theta_reg_type')

    if not args.load_model and args.train:
        trainer.train(train_loader, valid_loader, epochs=args.epochs,
                      save_path=model_path)
        trainer.plot_losses(save_path=results_path)
    else:
        checkpoint = torch.load(os.path.join(model_path, 'model_best.pth.tar'),
                                map_location=lambda storage, loc: storage)
        model = checkpoint['model']
        trainer = VanillaClassTrainer(model, args)

    All_Results = {}

    features = None
    classes = [str(i) for i in range(10)]
    model.eval()

    expl = new_wrapper(model,
                       mode='classification',
                       input_type='image',
                       multiclass=True,
                       feature_names=features,
                       class_names=classes,
                       train_data=train_loader,
                       skip_bias=True,
                       verbose=False)

    ## Faithfulness analysis. Generates faithfulness plots, dependency plots
    ## and correlation scores.
    print("Performing faithfulness analysis...")
    correlations = np.array([])
    altcorrelations = np.array([])
    for i, (inputs, targets) in enumerate(tqdm(test_loader)):
        if args.demo:
            # Demo mode: only process batch 0, then re-purpose args.demo as a
            # variant flag for downstream plotting (1 for 5 concepts, 2 for 22).
            if args.nconcepts == 5:
                if i != 0:
                    continue
                else:
                    args.demo = 0 + 1
            elif args.nconcepts == 22:
                if i != 0:
                    continue
                else:
                    args.demo = 1 + 1
        if args.cuda:
            inputs, targets = inputs.cuda(), targets.cuda()
        input_var = torch.autograd.Variable(inputs, volatile=True)
        target_var = torch.autograd.Variable(targets)

        if not args.noplot:
            save_path = results_path + '/faithfulness' + str(i) + '/'
            if not os.path.isdir(save_path):
                os.mkdir(save_path)
        else:
            save_path = None

        corrs, altcorrs = expl.compute_dataset_consistency(
            input_var, targets=target_var, inputs_are_concepts=False,
            save_path=save_path, demo_mode=args.demo)
        correlations = np.append(correlations, corrs)
        altcorrelations = np.append(altcorrelations, altcorrs)

    # np.mean replaces the manual np.sum(...) / len(...) division.
    average_correlation = np.mean(correlations)
    std_correlation = np.std(correlations)
    average_alt_correlation = np.mean(altcorrelations)
    std_alt_correlation = np.std(altcorrelations)

    print("Average correlation:", average_correlation)
    print("Standard deviation of correlations: ", std_correlation)
    print("Average alternative correlation:", average_alt_correlation)
    print("Standard deviation of alternative correlations: ",
          std_alt_correlation)

    print("Generating Faithfulness correlation box plot..")
    box_plot_values = [correlations, altcorrelations]
    box = plt.boxplot(box_plot_values, patch_artist=True,
                      labels=['theta(x)', 'theta(x) h(x)'])
    colors = ['blue', 'purple']
    for patch, color in zip(box['boxes'], colors):
        patch.set_facecolor(color)
    # Fixed: dropped the invalid 'verbose=True' kwarg — plt.savefig has no such
    # parameter and raises TypeError on matplotlib >= 3.x.
    plt.savefig(results_path + '/faithfulness_box_plot.png', format="png",
                dpi=300)

    if not args.noplot and args.nconcepts == 5:
        print("Generating theta, h and theta*h distribution histograms...")
        plot_distribution_h(test_loader, expl, 'thetaxhx', fig=0,
                            results_path=results_path)
        plot_distribution_h(test_loader, expl, 'thetax', fig=1,
                            results_path=results_path)
        plot_distribution_h(test_loader, expl, 'hx', fig=2,
                            results_path=results_path)

    print("Generating stability metrics...")
    # Compute stabilities for a grid of noise magnitudes; the boolean flag
    # switches between the theta(x) and theta(x)^T h(x) variants.
    noises = np.arange(0, 0.21, 0.02)
    dist_dict, dist_dict_2 = {}, {}
    for noise in noises:
        dist_dict[noise] = eval_stability_2(test_loader, expl, noise, False)
        dist_dict_2[noise] = eval_stability_2(test_loader, expl, noise, True)

    print("Generating stability plot...")
    means = [np.mean(dist_dict[noise]) for noise in noises]
    stds = [np.std(dist_dict[noise]) for noise in noises]
    means_min = [means[i] - stds[i] for i in range(len(means))]
    means_max = [means[i] + stds[i] for i in range(len(means))]
    means_2 = [np.mean(dist_dict_2[noise]) for noise in noises]
    stds_2 = [np.std(dist_dict_2[noise]) for noise in noises]
    means_min_2 = [means_2[i] - stds_2[i] for i in range(len(means_2))]
    means_max_2 = [means_2[i] + stds_2[i] for i in range(len(means_2))]

    fig, ax = plt.subplots(1)
    ax.plot(noises, means, lw=2, label='theta(x)', color='blue')
    ax.plot(noises, means_2, lw=2, label='theta(x)^T h(x)', color='purple')
    ax.fill_between(noises, means_max, means_min, facecolor='blue', alpha=0.3)
    ax.fill_between(noises, means_max_2, means_min_2, facecolor='purple',
                    alpha=0.3)
    ax.set_title('Stability')
    ax.legend(loc='upper left')
    ax.set_xlabel('Added noise')
    ax.set_ylabel('Norm of relevance coefficients')
    ax.grid()
    fig.savefig(results_path + '/stability' + '.png', format="png", dpi=300)

    if (not args.demo) or (args.demo and args.nconcepts == 5):
        concept_grid(model, test_loader, cuda=args.cuda, top_k=10,
                     save_path=results_path + '/concept_grid.png')
    print("Finished")
def main():
    """COMPAS experiment: train a GSENN binary classifier on tabular data,
    report accuracies, estimate discrete local g-Lipschitz values over the
    test set, and plot the relevance comparison for the argmax pair.
    """
    args = parse_args()
    args.nclasses = 1          # binary task, single output unit
    args.theta_dim = args.nclasses
    args.print_freq = 100
    args.epochs = 10

    train_loader, valid_loader, test_loader, train, valid, test, data, feat_names = load_compas_data()

    layer_sizes = (10, 10, 5)
    input_dim = 11             # number of tabular features

    if args.h_type == 'input':
        conceptizer = input_conceptizer()
        args.nconcepts = 11 + int(not args.nobias)
    elif args.h_type == 'fcc':
        args.nconcepts += int(not args.nobias)
        conceptizer = image_fcc_conceptizer(11, args.nconcepts,
                                            args.concept_dim)
    else:
        raise ValueError('Unrecognized h_type')

    model_path, log_path, results_path = generate_dir_names('compas', args)

    parametrizer = dfc_parametrizer(input_dim, *layer_sizes, args.nconcepts,
                                    args.theta_dim)
    aggregator = additive_scalar_aggregator(args.concept_dim, args.nclasses)
    model = GSENN(conceptizer, parametrizer, aggregator)

    # Pick the trainer matching the requested theta regularization.
    if args.theta_reg_type == 'unreg':
        trainer = VanillaClassTrainer(model, args)
    elif args.theta_reg_type == 'grad1':
        trainer = GradPenaltyTrainer(model, args, typ=1)
    elif args.theta_reg_type == 'grad2':
        trainer = GradPenaltyTrainer(model, args, typ=2)
    elif args.theta_reg_type == 'grad3':
        trainer = GradPenaltyTrainer(model, args, typ=3)
    elif args.theta_reg_type == 'crosslip':
        trainer = CLPenaltyTrainer(model, args)
    else:
        raise ValueError('Unrecognized theta_reg_type')

    trainer.train(train_loader, valid_loader, epochs=args.epochs,
                  save_path=model_path)
    trainer.plot_losses(save_path=results_path)

    # Load the best checkpoint. map_location added for consistency with the
    # other experiment scripts so CPU-only machines can load GPU checkpoints.
    checkpoint = torch.load(os.path.join(model_path, 'model_best.pth.tar'),
                            map_location=lambda storage, loc: storage)
    model = checkpoint['model']

    results = {}
    train_acc = trainer.validate(train_loader, fold='train')
    valid_acc = trainer.validate(valid_loader, fold='valid')
    test_acc = trainer.validate(test_loader, fold='test')
    results['train_accuracy'] = train_acc
    results['valid_accuracy'] = valid_acc
    results['test_accuracy'] = test_acc
    print('Train accuracy: {:8.2f}'.format(train_acc))
    print('Valid accuracy: {:8.2f}'.format(valid_acc))
    print('Test accuracy: {:8.2f}'.format(test_acc))

    # Discrete local Lipschitz estimates over the test set; argmaxes holds,
    # per point, the neighbors attaining the max ratio.
    lips, argmaxes = sample_local_lipschitz(model, test, mode=2, top_k=10,
                                            max_distance=3)
    max_lip = lips.max()
    imax = np.unravel_index(np.argmax(lips), lips.shape)[0]
    jmax = argmaxes[imax][0][0]
    print('Max Lip value: {}, attained for pair ({},{})'.format(
        max_lip, imax, jmax))

    x = test.tensors[0][imax]
    argmax = test.tensors[0][jmax]
    # Each forward pass populates model.thetas, which we read for the plots.
    pred_x = model(Variable(x.view(1, -1), volatile=True)).data
    att_x = model.thetas.data.squeeze().numpy().squeeze()
    pred_argmax = model(Variable(argmax.view(1, -1), volatile=True)).data
    att_argmax = model.thetas.data.squeeze().numpy().squeeze()

    results['x_max'] = x
    results['x_argmax'] = argmax
    results['test_discrete_glip'] = lips
    results['test_discrete_glip_argmaxes'] = argmaxes

    print('Local g-Lipschitz estimate: {:8.2f}'.format(lips.mean()))

    fpath = os.path.join(results_path, 'discrete_lip_gsenn')
    ppath = os.path.join(results_path, 'relevance_argmax_gsenn')

    # Formerly "model_metrics"; file handle now closed deterministically.
    with open(fpath + '.pkl', "wb") as f:
        pickle.dump(results, f)

    print(ppath)
    lipschitz_feature_argmax_plot(x, argmax, att_x, att_argmax,
                                  feat_names=feat_names,
                                  digits=2, figsize=(5, 6), widths=(2, 3),
                                  save_path=ppath + '.pdf')