# Module-level imports assumed by this method (expected at the top of the
# file, together with the Classifier module and the weight_reset helper):
#   import time
#   import numpy as np
#   import sklearn.ensemble
#   import torch
#   import wandb
#   from sklearn import decomposition, linear_model
#   from tqdm import tqdm

def _disentanglement_metric(self,
                            dataset,
                            method_names,
                            sample_size,
                            n_epochs=6000,
                            dataset_size=1000,
                            hidden_dim=256,
                            use_non_linear=False):
    """Compute the Higgins et al. disentanglement metric for the given methods.

    For every method (the trained VAE plus PCA/ICA baselines), a dataset of
    latent difference vectors z_b_diff labelled with the index y of the fixed
    generative factor is generated via self._compute_z_b_diff_y, and several
    classifiers are trained to predict y from z_b_diff. Returns their test
    accuracies as test_acc[model_class][method].
    """
    # Train models for all concerned methods and store them in a dict.
    methods = {}
    runtimes = {}
    for method_name in tqdm(
            method_names,
            desc="Iterating over methods for the Higgins disentanglement metric"):
        if method_name == "VAE":
            methods["VAE"] = self.model
        elif method_name == "PCA":
            start = time.time()
            print("Training PCA...")
            pca = decomposition.PCA(n_components=self.model.latent_dim,
                                    whiten=True,
                                    random_state=self.seed)
            if dataset.imgs.ndim == 4:
                data_imgs = dataset.imgs
                print(f"Shape of data images: {data_imgs.shape}")
                # Flatten (N, H, W, C) images to (N, C * H * W) row vectors.
                imgs_pca = np.reshape(
                    data_imgs,
                    (data_imgs.shape[0],
                     data_imgs.shape[3] * data_imgs.shape[1]**2))
            else:
                data_imgs = dataset.imgs
                imgs_pca = np.reshape(
                    data_imgs, (data_imgs.shape[0], data_imgs.shape[1]**2))
            # Multi-channel images need a smaller subset to fit into memory.
            size = min(
                3500 if data_imgs.ndim == 4 and data_imgs.shape[3] > 1 else 25000,
                len(imgs_pca))
            # Not enough memory for the full dataset -> fit on a random subset
            # (drawn with replacement); repeat with different random subsets.
            idx = np.random.randint(len(imgs_pca), size=size)
            imgs_pca = imgs_pca[idx, :]
            pca.fit(imgs_pca)
            methods["PCA"] = pca
            self.logger.info("Done")
            runtimes[method_name] = time.time() - start
        elif method_name == "ICA":
            start = time.time()
            print("Training ICA...")
            ica = decomposition.FastICA(n_components=self.model.latent_dim,
                                        max_iter=400,
                                        random_state=self.seed)
            if dataset.imgs.ndim == 4:
                data_imgs = dataset.imgs
                print(f"Shape of data images: {data_imgs.shape}")
                imgs_ica = np.reshape(
                    data_imgs,
                    (data_imgs.shape[0],
                     data_imgs.shape[3] * data_imgs.shape[1]**2))
            else:
                data_imgs = dataset.imgs
                imgs_ica = np.reshape(
                    data_imgs, (data_imgs.shape[0], data_imgs.shape[1]**2))
            size = min(
                1000 if data_imgs.ndim == 4 and data_imgs.shape[3] > 1 else 2500,
                len(imgs_ica))
            # Not enough memory for the full dataset -> fit on a random subset
            # (drawn with replacement); repeat with different random subsets.
            idx = np.random.randint(len(imgs_ica), size=size)
            imgs_ica = imgs_ica[idx, :]
            ica.fit(imgs_ica)
            methods["ICA"] = ica
            self.logger.info("Done")
            runtimes[method_name] = time.time() - start
        else:
            raise ValueError("Unknown method: {}".format(method_name))

    if self.use_wandb:
        try:
            wandb.log(runtimes)
        except Exception:
            pass

    data_train, data_test = {}, {}
    for method in methods:
        data_train[method] = [], []
        data_test[method] = [], []

    # latent_dim = length of z_b_diff for an arbitrary method
    #            = input dimension of the classifiers below
    latent_dim = self.model.latent_dim

    # Generate dataset_size training points and roughly half as many
    # test points.
    for i in tqdm(range(dataset_size),
                  desc="Generating datasets for Higgins metric"):
        data = self._compute_z_b_diff_y(methods, sample_size, dataset)
        for method in methods:
            data_train[method][0].append(data[method][0])
            data_train[method][1].append(data[method][1])
        if i <= int(dataset_size * 0.5):
            data = self._compute_z_b_diff_y(methods, sample_size, dataset)
            for method in methods:
                data_test[method][0].append(data[method][0])
                data_test[method][1].append(data[method][1])

    test_acc = {"logreg": {}, "linear": {}, "nonlinear": {}, "rf": {}}
    for model_class in ["linear", "nonlinear", "logreg", "rf"]:
        if model_class in ["linear", "nonlinear"]:
            model = Classifier(latent_dim,
                               hidden_dim,
                               len(dataset.lat_sizes),
                               use_non_linear=(model_class == "nonlinear"))
            model.to(self.device)
            model.train()
            # The classifier ends in a log-softmax, hence NLL loss.
            criterion = torch.nn.NLLLoss()
            optim = torch.optim.Adagrad(
                model.parameters(),
                lr=0.01 if model_class == "linear" else 0.001,
                weight_decay=0 if model_class == "linear" else 1e-4)
            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                optim, 'min', patience=5000, min_lr=0.00001)
            for method in tqdm(
                    methods.keys(),
                    desc="Training classifiers for the Higgins metric"):
                if method == "ICA":
                    # Switch to Adam for the ICA representation.
                    optim = torch.optim.Adam(
                        model.parameters(),
                        lr=1 if model_class == "linear" else 0.001,
                        weight_decay=0 if model_class == "linear" else 1e-4)
                    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                        optim, 'min', patience=5000, min_lr=0.00001)
                X_train, Y_train = data_train[method]
                X_train = torch.as_tensor(np.asarray(X_train),
                                          dtype=torch.float32).to(self.device)
                Y_train = torch.as_tensor(np.asarray(Y_train),
                                          dtype=torch.long).to(self.device)
                X_test, Y_test = data_test[method]
                X_test = torch.as_tensor(np.asarray(X_test),
                                         dtype=torch.float32).to(self.device)
                Y_test = torch.as_tensor(np.asarray(Y_test),
                                         dtype=torch.long).to(self.device)
                print(f'Training the classifier for model {method}')
                # The non-linear classifier is trained for half as many epochs.
                total_epochs = (n_epochs if model_class == "linear" else
                                round(n_epochs / 2))
                for e in tqdm(
                        range(total_epochs),
                        desc="Iterating over epochs while training the Higgins classifier"):
                    model.train()
                    optim.zero_grad()
                    scores_train = model(X_train)
                    loss = criterion(scores_train, Y_train)
                    loss.backward()
                    optim.step()
                    scheduler.step(loss)
                    if (e + 1) % 2000 == 0:
                        # Periodically report loss and accuracy on both splits.
                        model.eval()
                        with torch.no_grad():
                            scores_train = model(X_train)
                            scores_test = model(X_test)
                            test_loss = criterion(scores_test, Y_test)
                            tqdm.write(
                                f'Epoch {e + 1}/{total_epochs}, '
                                f'Training loss: {loss.item():.4f}, '
                                f'Test loss: {test_loss.item():.4f}')
                            _, prediction_train = scores_train.max(1)
                            _, prediction_test = scores_test.max(1)
                            train_acc = (prediction_train ==
                                         Y_train).sum().float() / len(X_train)
                            test_acc[model_class][method] = (
                                prediction_test ==
                                Y_test).sum().float() / len(X_test)
                            tqdm.write(
                                f'Accuracy of {method} on training set: '
                                f'{train_acc.item():.4f}, test set: '
                                f'{test_acc[model_class][method].item():.4f}')
                        model.train()
                # Final evaluation for this method.
                model.eval()
                with torch.no_grad():
                    scores_train = model(X_train)
                    scores_test = model(X_test)
                    _, prediction_train = scores_train.max(1)
                    _, prediction_test = scores_test.max(1)
                    train_acc = (prediction_train ==
                                 Y_train).sum().float() / len(X_train)
                    test_acc[model_class][method] = (
                        prediction_test == Y_test).sum().float() / len(X_test)
                    print(f'Accuracy of {method} on training set: '
                          f'{train_acc.item():.4f}, test set: '
                          f'{test_acc[model_class][method].item():.4f}')
                # Reset the weights before training on the next method.
                model.apply(weight_reset)
        elif model_class in ["logreg", "rf"]:
            for method in tqdm(
                    methods.keys(),
                    desc="Training classifiers for the Higgins metric"):
                if model_class == "logreg":
                    classifier = linear_model.LogisticRegression(
                        max_iter=500, random_state=self.seed)
                elif model_class == "rf":
                    # Seeded like the other classifiers for reproducibility.
                    classifier = sklearn.ensemble.RandomForestClassifier(
                        n_estimators=150, random_state=self.seed)
                X_train, Y_train = data_train[method]
                X_test, Y_test = data_test[method]
                classifier.fit(X_train, Y_train)
                train_acc = np.mean(classifier.predict(X_train) == Y_train)
                test_acc[model_class][method] = np.mean(
                    classifier.predict(X_test) == Y_test)
                print(f'Accuracy of {method} on training set: '
                      f'{train_acc:.4f}, test set: '
                      f'{test_acc[model_class][method]:.4f}')
    return test_acc
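

# ---------------------------------------------------------------------------
# The method above depends on a `Classifier` module defined elsewhere in this
# repository. The sketch below captures only what the call site implies -- it
# takes (latent_dim, hidden_dim, n_classes, use_non_linear) and must end in a
# log-softmax, since it is trained with torch.nn.NLLLoss -- and is not
# necessarily the repository's actual definition.
import torch.nn as nn


class Classifier(nn.Module):
    """Linear probe, or a one-hidden-layer MLP when use_non_linear is set."""

    def __init__(self, latent_dim, hidden_dim, n_classes, use_non_linear=False):
        super().__init__()
        if use_non_linear:
            self.net = nn.Sequential(nn.Linear(latent_dim, hidden_dim),
                                     nn.ReLU(),
                                     nn.Linear(hidden_dim, n_classes))
        else:
            self.net = nn.Linear(latent_dim, n_classes)

    def forward(self, x):
        # Log-probabilities, as expected by NLLLoss.
        return nn.functional.log_softmax(self.net(x), dim=1)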
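

# `weight_reset` is likewise assumed from the surrounding module; it is applied
# via model.apply(weight_reset) to re-initialize the classifier between
# methods. A minimal sketch consistent with that usage:
def weight_reset(m):
    # nn.Module.apply calls this on every submodule; reset trainable layers.
    if isinstance(m, (nn.Linear, nn.Conv2d)):
        m.reset_parameters()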
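

# For reference, the contract that self._compute_z_b_diff_y must satisfy for
# the dataset-generation loop above (inferred from how its result is consumed,
# not from its actual definition): it returns a dict mapping each method name
# to a pair (z_b_diff, y), where z_b_diff is a latent_dim-sized vector (the
# difference of latent codes for two sample_size-sized batches that share one
# fixed generative factor) and y is the index of that fixed factor, so
# 0 <= y < len(dataset.lat_sizes).
#
# Hypothetical usage, assuming `evaluator` is an instance of the surrounding
# class holding a trained model:
#
#   test_acc = evaluator._disentanglement_metric(dataset,
#                                                ["VAE", "PCA", "ICA"],
#                                                sample_size=64)
#   print(test_acc["linear"]["VAE"])  # Higgins score of the VAE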