gpus=1, auto_lr_find=True, deterministic=True, check_val_every_n_epoch=1, default_root_dir=base_dir + '_xc', weights_save_path=base_dir, callbacks=[checkpoint_callback_xc]) model_xc = Explainer(n_concepts=x.shape[1], n_classes=c.shape[1], l1=0, lr=0.01, explainer_hidden=[100, 50], temperature=5000, loss=torch.nn.BCEWithLogitsLoss()) trainer_xc.fit(model_xc, train_loader_xc, val_loader_xc) model_xc.freeze() c_train_pred = model_xc.model(x_train) c_val_pred = model_xc.model(x_val) c_test_pred = model_xc.model(x_test) # train C->Y train_data = TensorDataset(c_train_pred.squeeze(), y_train) val_data = TensorDataset(c_val_pred.squeeze(), y_val) test_data = TensorDataset(c_test_pred.squeeze(), y_test) train_loader = DataLoader(train_data, batch_size=train_size) val_loader = DataLoader(val_data, batch_size=val_size) test_loader = DataLoader(test_data, batch_size=test_size) checkpoint_callback = ModelCheckpoint(dirpath=base_dir, monitor='val_loss', save_top_k=1)
# Train the C->Y explainer once per seed, report how many concepts each class
# keeps, and collect the test accuracy of every run in `results_list`.
#
# NOTE(review): in PyTorch Lightning >= 1.0, `auto_lr_find=True` is only acted
# on by `trainer.tune()`; `trainer.fit()` alone does not run the LR finder, so
# the `lr=0.01` passed to Explainer is what is actually used. Confirm against
# the installed Lightning version.

# These loaders are full-batch and unshuffled, so they do not depend on the
# seed; build them once rather than on every iteration.
train_loader = DataLoader(train_data, batch_size=len(train_data))
val_loader = DataLoader(val_data, batch_size=len(val_data))
test_loader = DataLoader(test_data, batch_size=len(test_data))

for seed in range(n_seeds):
    seed_everything(seed)
    print(f'Seed [{seed + 1}/{n_seeds}]')

    # Fresh checkpoint callback / trainer / model for every seed.
    checkpoint_callback = ModelCheckpoint(dirpath=base_dir, monitor='val_loss', save_top_k=1)
    trainer = Trainer(
        max_epochs=500,
        gpus=1,
        auto_lr_find=True,
        deterministic=True,
        check_val_every_n_epoch=1,
        default_root_dir=base_dir,
        weights_save_path=base_dir,
        callbacks=[checkpoint_callback],
    )
    model = Explainer(n_concepts=n_concepts, n_classes=n_classes, l1=0, lr=0.01,
                      explainer_hidden=[10])
    trainer.fit(model, train_loader, val_loader)

    print(f"Concept mask: {model.model[0].concept_mask}")
    model.freeze()
    # NOTE(review): passing `model` explicitly evaluates its current in-memory
    # weights, which are not necessarily those of the best `val_loss`
    # checkpoint saved by `checkpoint_callback` — confirm this is intended.
    model_results = trainer.test(model, test_dataloaders=test_loader)

    concept_mask = model.model[0].concept_mask
    for j in range(n_classes):
        # Reduce on the tensor itself: the builtin sum() iterated the row
        # element by element in Python and produced a 0-d tensor (printed as
        # "tensor(k)") rather than a plain integer count.
        n_used_concepts = int((concept_mask[j] > 0.5).sum())
        print(f"Extracted concepts: {n_used_concepts}")

    results_list.append({'model_accuracy': model_results[0]['test_acc']})

    # Persist after every seed so partial results survive an interrupted run.
    results_df = pd.DataFrame(results_list)
    results_df.to_csv(os.path.join(base_dir, 'results_aware_cub.csv'))

# Final write (also covers n_seeds == 0) and notebook display of the table.
results_df = pd.DataFrame(results_list)
results_df.to_csv(os.path.join(base_dir, 'results_aware_cub.csv'))
results_df