def train_network(data_dict=None,
                  n_epochs=500,
                  batch_size=512,
                  dropout_rate=0.2,
                  preprocess=True,
                  learning_rate=0.001,
                  gpus=1,
                  max_size=50000,
                  early_stopping_limit=50,
                  ):
    """Train a ``trvae.FaceNet`` on the dataset described by ``data_dict``.

    ``data_dict['name']`` is required. ``'width'``, ``'height'``, ``'n_channels'``,
    ``'attribute'`` and, for CelebA, ``'gender'`` are optional and default to ``None``.

    * ``"celeba"``: loaded with ``trvae.prepare_and_load_celeba`` from
      ``../data/celeba/``. A dense ``X`` is enforced. An extra ``obs['label']``
      column encodes the ``(obs['labels'], obs['condition'])`` pair as
      (-1, -1) -> 0, (-1, 1) -> 1, (1, -1) -> 2 and (1, 1) -> 3.
    * any other name: read from ``../data/<name>/<name>.h5ad``.

    When ``preprocess`` is true, ``X`` is divided by 255 in place. The rows are
    shuffled with the global NumPy RNG; the first 85% are used for training and
    the rest for validation. The model is saved under ``../models/``.

    NOTE(review): this module later defines another ``train_network``, which
    replaces this one at import time.
    """
    dataset = data_dict['name']
    width = data_dict.get("width", None)
    height = data_dict.get("height", None)
    channels = data_dict.get("n_channels", None)
    attr = data_dict.get('attribute', None)

    if dataset == "celeba":
        data = trvae.prepare_and_load_celeba(
            file_path="../data/celeba/img_align_celeba.zip",
            attr_path="../data/celeba/list_attr_celeba.txt",
            landmark_path="../data/celeba/list_landmarks_align_celeba.txt",
            gender=data_dict.get('gender', None),
            attribute=attr,
            max_n_images=max_size,
            img_width=width,
            img_height=height,
            restore=True,
            save=True,
        )
        if sparse.issparse(data.X):
            data.X = data.X.A
        # (labels, condition) -> class code, assigned in this fixed order.
        class_codes = (((-1, -1), 0), ((-1, 1), 1), ((1, -1), 2), ((1, 1), 3))
        for (label_value, condition_value), code in class_codes:
            selected = (data.obs['labels'] == label_value) & (data.obs['condition'] == condition_value)
            data.obs.loc[selected, 'label'] = code
    else:
        data = sc.read(f"../data/{dataset}/{dataset}.h5ad")

    if preprocess:
        data.X /= 255.0

    n_rows = data.shape[0]
    cut = int(n_rows * 0.85)
    order = np.arange(n_rows)
    np.random.shuffle(order)
    train_part = data[order[:cut], :]
    valid_part = data[order[cut:], :]

    model = trvae.FaceNet(
        x_dimension=(width, height, channels),
        learning_rate=learning_rate,
        model_path="../models/",
        gpus=gpus,
        dropout_rate=dropout_rate,
    )
    model.train(
        train_part,
        use_validation=True,
        valid_adata=valid_part,
        n_epochs=n_epochs,
        batch_size=batch_size,
        verbose=2,
        early_stop_limit=early_stopping_limit,
        shuffle=True,
        save=True,
    )
    print("Model has been trained")
def data():
    """Load CelebA and return ``(train_data, valid_data, data_name)`` for trVAE training.

    The configuration is hard-coded: CelebA, gender ``"Male"``, attribute
    ``"Smiling"``, source key -1, target key 1, 64x64x3 images, at most 50000
    images, files under ``./data/celeba/``.

    Processing steps:

    * Rows with ``obs['condition'] == source_key`` and ``== target_key`` are
      stacked (source rows first) and scaled by 1/255.
    * They are flattened into a new AnnData whose ``obs['condition']`` is 0.0
      for source rows and 1.0 for target rows.
    * If the loaded data has an ``obs['labels']`` column, it is carried over in
      the same stacked order, so every row keeps its own label.
    * The rows are shuffled with the global NumPy RNG; the first 85% form the
      training split.
    * When labels exist, target rows with ``labels == -1`` are removed from both
      splits.
    """
    DATASETS = {
        "CelebA": {
            "name": 'celeba', "gender": "Male", 'attribute': "Smiling",
            'source_key': -1, "target_key": 1,
            "width": 64, 'height': 64, "n_channels": 3,
        },
    }
    data_dict = DATASETS["CelebA"]
    data_name = data_dict['name']
    source_key = data_dict.get('source_key', None)
    target_key = data_dict.get('target_key', None)
    img_width = data_dict.get("width", None)
    img_height = data_dict.get("height", None)
    n_channels = data_dict.get("n_channels", None)
    attribute = data_dict.get('attribute', None)
    gender = data_dict.get('gender', None)

    data = trvae.prepare_and_load_celeba(
        file_path="./data/celeba/img_align_celeba.zip",
        attr_path="./data/celeba/list_attr_celeba.txt",
        landmark_path="./data/celeba/list_landmarks_align_celeba.txt",
        gender=gender,
        attribute=attribute,
        max_n_images=50000,
        img_width=img_width,
        img_height=img_height,
        restore=True,
        save=True)
    if sparse.issparse(data.X):
        data.X = data.X.A

    image_shape = (img_width, img_height, n_channels)
    pixels = np.asarray(data.X)
    source_mask = (data.obs['condition'] == source_key).values
    target_mask = (data.obs['condition'] == target_key).values
    # Boolean indexing yields fresh arrays, so scaling them never touches ``data``.
    source_images = np.reshape(pixels[source_mask], (-1, *image_shape)) / 255.0
    target_images = np.reshape(pixels[target_mask], (-1, *image_shape)) / 255.0

    train_labels = np.concatenate(
        [np.zeros(shape=source_images.shape[0]), np.ones(shape=target_images.shape[0])], axis=0)
    train_images = np.concatenate([source_images, target_images], axis=0)
    train_images = np.reshape(train_images, (-1, np.prod(image_shape)))

    preprocessed_data = anndata.AnnData(X=train_images)
    preprocessed_data.obs['condition'] = train_labels
    has_labels = 'labels' in data.obs.columns
    if has_labels:
        # Select labels with the same masks, in the same order, as the images
        # were stacked, so each row keeps its own label.
        labels = np.asarray(data.obs['labels'])
        preprocessed_data.obs['labels'] = np.concatenate(
            [labels[source_mask], labels[target_mask]], axis=0)
    data = preprocessed_data.copy()

    train_size = int(data.shape[0] * 0.85)
    indices = np.arange(data.shape[0])
    np.random.shuffle(indices)
    data_train = data[indices[:train_size], :]
    data_valid = data[indices[train_size:], :]
    print(data_train.shape, data_valid.shape)

    if has_labels:
        # Target rows carry the encoded condition 1.0 from ``train_labels``.
        target_code = 1
        train_data = data_train.copy()[~(
            (data_train.obs['labels'] == -1) & (data_train.obs['condition'] == target_code))]
        valid_data = data_valid.copy()[~(
            (data_valid.obs['labels'] == -1) & (data_valid.obs['condition'] == target_code))]
    else:
        train_data = data_train.copy()
        valid_data = data_valid.copy()
    return train_data, valid_data, data_name
def visualize_trained_network_results(data_dict, z_dim=100, arch_style=1, preprocess=True, max_size=80000):
    """Restore a trained DCtrVAE and save UMAP plots of the data and of its latent spaces.

    Plots are written by scanpy into ``sc.settings.figdir``, which is set to
    ``../results/RCCVAE/<name>-<width>x<height>-<preprocess>/<arch_style>-<z_dim>/<source_key> to <target_key>/UMAPs/``.

    Five embeddings are plotted:

    * the input data;
    * ``network.to_z_latent`` with the encoded true labels;
    * ``network.to_z_latent`` with all-ones ("fake") labels;
    * ``network.to_mmd_layer`` with ``feed_fake=False``;
    * ``network.to_mmd_layer`` with ``feed_fake=True``.

    Args:
        data_dict: dataset description. Keys read: 'name', 'source_key',
            'target_key', 'width', 'height', 'n_channels', 'train_digits',
            'test_digits', 'attribute' and, for CelebA, 'gender'.
        z_dim: latent size. Together with ``arch_style`` and ``preprocess``, it
            is part of the model path restored here, so it must match the
            values used when the model was trained.
        arch_style: architecture variant passed to ``trvae.DCtrVAE``.
        preprocess: if true, pixel values are divided by 255 before encoding.
        max_size: maximum number of CelebA images to load.
    """
    plt.close("all")
    data_name = data_dict.get('name', None)
    source_key = data_dict.get('source_key', None)
    target_key = data_dict.get('target_key', None)
    img_width = data_dict.get('width', None)
    img_height = data_dict.get('height', None)
    n_channels = data_dict.get('n_channels', None)
    train_digits = data_dict.get('train_digits', None)
    test_digits = data_dict.get('test_digits', None)
    attribute = data_dict.get('attribute', None)

    path_to_save = f"../results/RCCVAE/{data_name}-{img_width}x{img_height}-{preprocess}/{arch_style}-{z_dim}/{source_key} to {target_key}/UMAPs/"
    os.makedirs(path_to_save, exist_ok=True)
    # scanpy saves every ``sc.pl.*(save=...)`` figure into this directory.
    sc.settings.figdir = os.path.abspath(path_to_save)

    if data_name == "celeba":
        gender = data_dict.get('gender', None)
        data = trvae.prepare_and_load_celeba(
            file_path="../data/celeba/img_align_celeba.zip",
            attr_path="../data/celeba/list_attr_celeba.txt",
            landmark_path="../data/celeba/list_landmarks_align_celeba.txt",
            gender=gender,
            attribute=attribute,
            max_n_images=max_size,
            img_width=img_width,
            img_height=img_height,
            restore=True,
            save=False)

        if sparse.issparse(data.X):
            data.X = data.X.A

        # Alias, not a copy: the in-place division below also rescales data.X.
        train_images = data.X
        # NOTE(review): the whole AnnData object is passed as X here; possibly
        # data.X / train_images was intended. Confirm the constructor behaviour.
        train_data = anndata.AnnData(X=data)
        # Replace the numeric condition codes with readable names for the plots.
        train_data.obs['condition'] = data.obs['condition'].values
        train_data.obs.loc[train_data.obs['condition'] == 1, 'condition'] = f'with {attribute}'
        train_data.obs.loc[train_data.obs['condition'] == -1, 'condition'] = f'without {attribute}'

        # Same for labels: 1 -> 'Male', -1 -> 'Female'.
        train_data.obs['labels'] = data.obs['labels'].values
        train_data.obs.loc[train_data.obs['labels'] == 1, 'labels'] = f'Male'
        train_data.obs.loc[train_data.obs['labels'] == -1, 'labels'] = f'Female'

        if preprocess:
            train_images /= 255.0
    else:
        train_data = sc.read(f"../data/{data_name}/{data_name}.h5ad")
        # np.reshape returns a view when it can, so the in-place division below
        # may also rescale train_data.X (which is used for the UMAP further down).
        train_images = np.reshape(train_data.X, (-1, img_width, img_height, n_channels))

        if preprocess:
            train_images /= 255.0

    train_labels, _ = trvae.label_encoder(train_data)
    # "Fake" labels: every cell is encoded as condition 1.
    fake_labels = np.ones(train_labels.shape)

    network = trvae.DCtrVAE(
        x_dimension=(img_width, img_height, n_channels),
        z_dimension=z_dim,
        arch_style=arch_style,
        model_path=
        f"../models/RCCVAE/{data_name}-{img_width}x{img_height}-{preprocess}/{arch_style}-{z_dim}/",
    )

    network.restore_model()

    train_data_feed = np.reshape(train_images, (-1, img_width, img_height, n_channels))

    latent_with_true_labels = network.to_z_latent(train_data_feed, train_labels)
    latent_with_fake_labels = network.to_z_latent(train_data_feed, fake_labels)
    mmd_latent_with_true_labels = network.to_mmd_layer(network, train_data_feed, train_labels, feed_fake=False)
    # NOTE(review): train_labels (not fake_labels) is passed together with
    # feed_fake=True; confirm that to_mmd_layer substitutes the labels itself.
    mmd_latent_with_fake_labels = network.to_mmd_layer(network, train_data_feed, train_labels, feed_fake=True)

    # Wrap every embedding in an AnnData carrying the cell conditions for colouring.
    latent_with_true_labels = sc.AnnData(X=latent_with_true_labels)
    latent_with_true_labels.obs['condition'] = pd.Categorical(
        train_data.obs['condition'].values)

    latent_with_fake_labels = sc.AnnData(X=latent_with_fake_labels)
    latent_with_fake_labels.obs['condition'] = pd.Categorical(
        train_data.obs['condition'].values)

    mmd_latent_with_true_labels = sc.AnnData(X=mmd_latent_with_true_labels)
    mmd_latent_with_true_labels.obs['condition'] = train_data.obs[
        'condition'].values

    mmd_latent_with_fake_labels = sc.AnnData(X=mmd_latent_with_fake_labels)
    mmd_latent_with_fake_labels.obs['condition'] = train_data.obs[
        'condition'].values

    if data_name.__contains__("mnist") or data_name == "celeba":
        latent_with_true_labels.obs['labels'] = pd.Categorical(
            train_data.obs['labels'].values)
        latent_with_fake_labels.obs['labels'] = pd.Categorical(
            train_data.obs['labels'].values)
        mmd_latent_with_true_labels.obs['labels'] = pd.Categorical(
            train_data.obs['labels'].values)
        mmd_latent_with_fake_labels.obs['labels'] = pd.Categorical(
            train_data.obs['labels'].values)

        color = ['condition', 'labels']
    else:
        color = ['condition']

    if train_digits is not None:
        # Source cells of every digit and target cells of train digits are
        # 'training'; target cells of test digits are 'heldout'.
        train_data.obs.loc[(train_data.obs['condition'] == source_key) & (
            train_data.obs['labels'].isin(train_digits)), 'type'] = 'training'
        train_data.obs.loc[(train_data.obs['condition'] == source_key) & (
            train_data.obs['labels'].isin(test_digits)), 'type'] = 'training'
        train_data.obs.loc[(train_data.obs['condition'] == target_key) & (
            train_data.obs['labels'].isin(train_digits)), 'type'] = 'training'
        train_data.obs.loc[(train_data.obs['condition'] == target_key) & (
            train_data.obs['labels'].isin(test_digits)), 'type'] = 'heldout'

    sc.pp.neighbors(train_data)
    sc.tl.umap(train_data)
    sc.pl.umap(train_data, color=color,
               save=f'_{data_name}_train_data.png',
               show=False,
               wspace=0.5)

    if train_digits is not None:
        # Recomputes the embedding already computed just above before plotting by 'type'.
        sc.tl.umap(train_data)
        sc.pl.umap(train_data, color=['type', 'labels'],
                   save=f'_{data_name}_data_type.png',
                   show=False)

    sc.pp.neighbors(latent_with_true_labels)
    sc.tl.umap(latent_with_true_labels)
    sc.pl.umap(latent_with_true_labels, color=color,
               save=f"_{data_name}_latent_with_true_labels.png",
               wspace=0.5,
               show=False)

    sc.pp.neighbors(latent_with_fake_labels)
    sc.tl.umap(latent_with_fake_labels)
    sc.pl.umap(latent_with_fake_labels, color=color,
               save=f"_{data_name}_latent_with_fake_labels.png",
               wspace=0.5,
               show=False)

    sc.pp.neighbors(mmd_latent_with_true_labels)
    sc.tl.umap(mmd_latent_with_true_labels)
    sc.pl.umap(mmd_latent_with_true_labels, color=color,
               save=f"_{data_name}_mmd_latent_with_true_labels.png",
               wspace=0.5,
               show=False)

    sc.pp.neighbors(mmd_latent_with_fake_labels)
    sc.tl.umap(mmd_latent_with_fake_labels)
    sc.pl.umap(mmd_latent_with_fake_labels, color=color,
               save=f"_{data_name}_mmd_latent_with_fake_labels.png",
               wspace=0.5,
               show=False)
    plt.close("all")
def train_network(data_dict=None,
                  z_dim=100,
                  mmd_dimension=256,
                  alpha=0.001,
                  beta=100,
                  gamma=1.0,
                  kernel='multi-scale-rbf',
                  n_epochs=500,
                  batch_size=512,
                  dropout_rate=0.2,
                  arch_style=1,
                  preprocess=True,
                  learning_rate=0.001,
                  gpus=1,
                  max_size=50000,
                  early_stopping_limit=50,
                  ):
    """Train a ``trvae.archs.DCtrVAE`` that maps source-condition images to the target condition.

    Data loading:

    * ``"celeba"`` is loaded with ``trvae.prepare_and_load_celeba`` from
      ``../data/celeba/``.
    * Any other name is read from ``../data/<name>/<name>.h5ad``.
    * In both cases a sparse ``X`` is densified first.

    Data preparation:

    * Rows with ``obs['condition'] == source_key`` and ``== target_key`` are
      stacked (source rows first).
    * If ``preprocess`` is true, they are scaled by 1/255.
    * They are flattened into a new AnnData whose ``obs['condition']`` is 0.0
      for source rows and 1.0 for target rows.
    * An existing ``obs['labels']`` column is carried over in the same stacked
      order.

    The rows are shuffled with the global NumPy RNG, and 85% go to the training
    split. Removed from both splits:

    * with ``train_digits`` set: target rows whose label is in ``test_digits``;
    * for CelebA: target rows whose label is -1;
    * otherwise nothing.

    The model is saved under
    ``../models/RCCVAE/<name>-<width>x<height>-<preprocess>/<arch_style>-<z_dim>/``.
    """
    data_name = data_dict['name']
    source_key = data_dict.get('source_key', None)
    target_key = data_dict.get('target_key', None)
    img_width = data_dict.get("width", None)
    img_height = data_dict.get("height", None)
    n_channels = data_dict.get("n_channels", None)
    train_digits = data_dict.get("train_digits", None)
    test_digits = data_dict.get("test_digits", None)
    attribute = data_dict.get('attribute', None)

    if data_name == "celeba":
        gender = data_dict.get('gender', None)
        data = trvae.prepare_and_load_celeba(
            file_path="../data/celeba/img_align_celeba.zip",
            attr_path="../data/celeba/list_attr_celeba.txt",
            landmark_path="../data/celeba/list_landmarks_align_celeba.txt",
            gender=gender,
            attribute=attribute,
            max_n_images=max_size,
            img_width=img_width,
            img_height=img_height,
            restore=True,
            save=True)
    else:
        data = sc.read(f"../data/{data_name}/{data_name}.h5ad")

    # Densify for both sources: np.reshape cannot turn a sparse matrix into 4-D images.
    if sparse.issparse(data.X):
        data.X = data.X.A

    image_shape = (img_width, img_height, n_channels)
    pixels = np.asarray(data.X)
    source_mask = (data.obs['condition'] == source_key).values
    target_mask = (data.obs['condition'] == target_key).values
    # Boolean indexing yields fresh arrays, so scaling them never touches ``data``.
    source_images = np.reshape(pixels[source_mask], (-1, *image_shape))
    target_images = np.reshape(pixels[target_mask], (-1, *image_shape))
    if preprocess:
        source_images = source_images / 255.0
        target_images = target_images / 255.0

    train_labels = np.concatenate(
        [np.zeros(shape=source_images.shape[0]), np.ones(shape=target_images.shape[0])], axis=0)
    train_images = np.concatenate([source_images, target_images], axis=0)
    train_images = np.reshape(train_images, (-1, np.prod(source_images.shape[1:])))

    preprocessed_data = anndata.AnnData(X=train_images)
    preprocessed_data.obs['condition'] = train_labels
    if 'labels' in data.obs.columns:
        # Select labels with the same masks, in the same order, as the images
        # were stacked, so each row keeps its own label.
        labels = np.asarray(data.obs['labels'])
        preprocessed_data.obs['labels'] = np.concatenate(
            [labels[source_mask], labels[target_mask]], axis=0)
    data = preprocessed_data.copy()

    train_size = int(data.shape[0] * 0.85)
    indices = np.arange(data.shape[0])
    np.random.shuffle(indices)
    data_train = data[indices[:train_size], :]
    data_valid = data[indices[train_size:], :]
    print(data_train.shape, data_valid.shape)

    # Target rows carry the encoded condition 1.0 from ``train_labels``,
    # whatever the dataset's raw ``target_key`` is.
    target_code = 1
    if train_digits is not None:
        held_out_train = (data_train.obs['labels'].isin(test_digits)
                          & (data_train.obs['condition'] == target_code))
        held_out_valid = (data_valid.obs['labels'].isin(test_digits)
                          & (data_valid.obs['condition'] == target_code))
    elif data_name == "celeba":
        held_out_train = ((data_train.obs['labels'] == -1)
                          & (data_train.obs['condition'] == target_code))
        held_out_valid = ((data_valid.obs['labels'] == -1)
                          & (data_valid.obs['condition'] == target_code))
    else:
        held_out_train = held_out_valid = None

    if held_out_train is None:
        train_data = data_train.copy()
        valid_data = data_valid.copy()
    else:
        train_data = data_train.copy()[~held_out_train]
        valid_data = data_valid.copy()[~held_out_valid]

    network = trvae.archs.DCtrVAE(
        x_dimension=source_images.shape[1:],
        z_dimension=z_dim,
        mmd_dimension=mmd_dimension,
        alpha=alpha,
        beta=beta,
        gamma=gamma,
        kernel=kernel,
        arch_style=arch_style,
        train_with_fake_labels=False,
        learning_rate=learning_rate,
        model_path=f"../models/RCCVAE/{data_name}-{img_width}x{img_height}-{preprocess}/{arch_style}-{z_dim}/",
        gpus=gpus,
        dropout_rate=dropout_rate)

    print(train_data.shape, valid_data.shape)
    network.train(train_data,
                  use_validation=True,
                  valid_adata=valid_data,
                  n_epochs=n_epochs,
                  batch_size=batch_size,
                  verbose=2,
                  early_stop_limit=early_stopping_limit,
                  shuffle=True,
                  save=True)

    print("Model has been trained")
def _sample_one_pair_per_digit(adata, digits, source_key, target_key, rescale):
    """Draw one random (source, target) pair of rows from ``adata`` for each digit in ``digits``.

    For every digit, the rows with that ``obs['labels']`` value are split by
    ``obs['condition']`` into a source subset and a target subset. The same
    random row position is taken from both subsets.

    NOTE(review): the pairs only correspond if the two subsets are row-aligned
    per digit. Confirm against how the dataset was built.

    When ``rescale`` is true, the source subset's ``.X`` is divided by 255
    through the AnnData view before sampling.

    Returns:
        Two lists (source rows, target rows), each holding one ``X[idx]``
        selection of length 1 per digit.
    """
    source_rows, target_rows = [], []
    for digit in digits:
        is_digit = adata.obs['labels'] == digit
        source_subset = adata[is_digit & (adata.obs['condition'] == source_key)]
        target_subset = adata[is_digit & (adata.obs['condition'] == target_key)]
        if rescale:
            source_subset.X /= 255.0
        picked = np.random.choice(source_subset.shape[0], 1, replace=False)
        source_rows.append(source_subset.X[picked])
        target_rows.append(target_subset.X[picked])
    return source_rows, target_rows


def _save_sample_grid(source_images, pred_images, target_images, n_cols, path):
    """Plot one row per source image, save the figure to ``path`` and close it.

    The arrays are indexed as ``(n, width, height, channels)``.

    * Column 0 always shows the source image.
    * Multi-channel predictions go in column 1.
    * Single-channel predictions with ``target_images`` and three columns put
      the ground-truth target in column 1 and the prediction in column 2.
    * Any other single-channel prediction goes in column 1.
    """
    n_rows = source_images.shape[0]
    plt.close("all")
    # squeeze=False keeps ``ax`` two-dimensional even for a single row.
    fig, ax = plt.subplots(n_rows, n_cols, figsize=(n_rows, 6), squeeze=False)
    for i in range(n_rows):
        for col in range(n_cols):
            ax[i, col].axis('off')
        if source_images.shape[-1] > 1:
            ax[i, 0].imshow(source_images[i])
        else:
            ax[i, 0].imshow(source_images[i, :, :, 0], cmap='Greys')
        if pred_images.shape[-1] > 1:
            ax[i, 1].imshow(pred_images[i])
        elif target_images is not None and n_cols > 2:
            ax[i, 1].imshow(target_images[i, :, :, 0], cmap='Greys')
            ax[i, 2].imshow(pred_images[i, :, :, 0], cmap='Greys')
        else:
            ax[i, 1].imshow(pred_images[i, :, :, 0], cmap='Greys')
    plt.savefig(path)
    plt.close(fig)


def evaluate_network(data_dict=None, z_dim=100, n_files=5, k=5, arch_style=1, preprocess=True, max_size=80000):
    """Restore a trained DCtrVAE and save grids of source images next to their translations.

    In each of ``n_files`` rounds, source-condition images are sampled from a
    "train" subset and a "valid" subset. They are translated with
    ``network.predict`` (encoder labels 0, decoder labels 1, one per sampled
    image). The results are saved as ``sample_images_<j>.pdf`` under
    ``../results/RCCVAE/<name>-<width>x<height>-<preprocess>/<arch_style>-<z_dim>/<source_key> to <target_key>/``
    in ``train/`` and ``valid/``.

    * CelebA: train = rows with ``labels == +1``, valid = rows with
      ``labels == -1``. ``k`` random source images are drawn per subset.
    * Other datasets are read from ``../data/<name>/<name>.h5ad``. With
      ``test_digits`` set, train/valid are the rows whose label is in
      ``train_digits`` / ``test_digits``, and one random (source, target) pair
      is drawn per digit; MNIST grids show source, ground truth and prediction.
      Without ``test_digits``, both subsets are the full dataset and ``k``
      random source images are drawn.

    NOTE(review): with ``test_digits``, the per-digit source subsets are
    rescaled through an AnnData view only in the first round, and targets are
    never rescaled. Whether later rounds see rescaled data depends on whether
    the installed anndata writes view assignments back to the parent. Confirm.

    Args:
        data_dict: dataset description. Keys read: 'name', 'source_key',
            'target_key', 'width', 'height', 'n_channels', 'train_digits',
            'test_digits', 'attribute' and, for CelebA, 'gender'.
        z_dim, arch_style, preprocess: part of the restored model path, so they
            must match the values used when the model was trained.
            ``preprocess`` also divides the source images by 255.
        n_files: number of sampling rounds (one train and one valid PDF each).
        k: number of random source images per grid when ``test_digits`` is not given.
        max_size: maximum number of CelebA images to load.
    """
    data_name = data_dict['name']
    source_key = data_dict.get('source_key', None)
    target_key = data_dict.get('target_key', None)
    img_width = data_dict.get("width", None)
    img_height = data_dict.get("height", None)
    n_channels = data_dict.get('n_channels', None)
    train_digits = data_dict.get('train_digits', None)
    test_digits = data_dict.get('test_digits', None)
    attribute = data_dict.get('attribute', None)
    image_shape = (img_width, img_height, n_channels)
    is_mnist = "mnist" in data_name

    if data_name == "celeba":
        gender = data_dict.get('gender', None)
        data = trvae.prepare_and_load_celeba(
            file_path="../data/celeba/img_align_celeba.zip",
            attr_path="../data/celeba/list_attr_celeba.txt",
            landmark_path="../data/celeba/list_landmarks_align_celeba.txt",
            gender=gender,
            attribute=attribute,
            max_n_images=max_size,
            img_width=img_width,
            img_height=img_height,
            restore=True,
            save=False)
        # Densify the full dataset before splitting, so both splits are dense
        # (np.reshape below cannot produce 4-D images from a sparse matrix).
        if sparse.issparse(data.X):
            data.X = data.X.A

        valid_data = data.copy()[data.obs['labels'] == -1]  # get females (Male = -1)
        train_data = data.copy()[data.obs['labels'] == +1]  # get males (Male = 1)
    else:
        data = sc.read(f"../data/{data_name}/{data_name}.h5ad")
        if sparse.issparse(data.X):
            data.X = data.X.A
        if train_digits is not None:
            train_data = data[data.obs['labels'].isin(train_digits)]
            valid_data = data[data.obs['labels'].isin(test_digits)]
        else:
            train_data = data.copy()
            valid_data = data.copy()

    source_images_train = np.reshape(
        train_data[train_data.obs["condition"] == source_key].X, (-1, *image_shape))
    source_images_valid = np.reshape(
        valid_data[valid_data.obs["condition"] == source_key].X, (-1, *image_shape))
    if preprocess:
        source_images_train /= 255.0
        source_images_valid /= 255.0

    n_pixels = np.prod(image_shape)
    source_data_train = anndata.AnnData(X=np.reshape(source_images_train, (-1, n_pixels)))
    source_data_valid = anndata.AnnData(X=np.reshape(source_images_valid, (-1, n_pixels)))

    network = trvae.DCtrVAE(
        x_dimension=image_shape,
        z_dimension=z_dim,
        arch_style=arch_style,
        model_path=f"../models/RCCVAE/{data_name}-{img_width}x{img_height}-{preprocess}/{arch_style}-{z_dim}/"
    )
    network.restore_model()

    results_path_train = f"../results/RCCVAE/{data_name}-{img_width}x{img_height}-{preprocess}/{arch_style}-{z_dim}/{source_key} to {target_key}/train/"
    results_path_valid = f"../results/RCCVAE/{data_name}-{img_width}x{img_height}-{preprocess}/{arch_style}-{z_dim}/{source_key} to {target_key}/valid/"
    os.makedirs(results_path_train, exist_ok=True)
    os.makedirs(results_path_valid, exist_ok=True)

    n_cols = 3 if is_mnist else 2
    for j in range(n_files):
        if test_digits is not None:
            # RNG order: all test digits first, then all train digits.
            source_sample_valid, target_sample_valid = _sample_one_pair_per_digit(
                valid_data, test_digits, source_key, target_key, rescale=(j == 0))
            source_sample_train, target_sample_train = _sample_one_pair_per_digit(
                train_data, train_digits, source_key, target_key, rescale=(j == 0))
        else:
            random_samples_train = np.random.choice(source_data_train.shape[0], k, replace=False)
            random_samples_valid = np.random.choice(source_data_valid.shape[0], k, replace=False)
            source_sample_train = source_data_train.X[random_samples_train]
            source_sample_valid = source_data_valid.X[random_samples_valid]
            target_sample_train = target_sample_valid = None

        source_sample_train = np.reshape(np.array(source_sample_train), (-1, n_pixels))
        source_sample_valid = np.reshape(np.array(source_sample_valid), (-1, n_pixels))
        source_sample_train_reshaped = np.reshape(source_sample_train, (-1, *image_shape))
        source_sample_valid_reshaped = np.reshape(source_sample_valid, (-1, *image_shape))

        target_sample_train_reshaped = target_sample_valid_reshaped = None
        if is_mnist and target_sample_train is not None:
            target_sample_train_reshaped = np.reshape(np.array(target_sample_train), (-1, *image_shape))
            target_sample_valid_reshaped = np.reshape(np.array(target_sample_valid), (-1, *image_shape))

        source_sample_train = anndata.AnnData(X=source_sample_train)
        source_sample_valid = anndata.AnnData(X=source_sample_valid)

        # One encoder/decoder label per sampled image: the train and valid
        # subsets can hold different numbers of images (e.g. train vs test digits).
        n_train_samples = source_sample_train.shape[0]
        n_valid_samples = source_sample_valid.shape[0]
        pred_sample_train = network.predict(adata=source_sample_train,
                                            encoder_labels=np.zeros((n_train_samples, 1)),
                                            decoder_labels=np.ones((n_train_samples, 1)))
        pred_sample_train = np.reshape(pred_sample_train, (-1, *image_shape))

        pred_sample_valid = network.predict(adata=source_sample_valid,
                                            encoder_labels=np.zeros((n_valid_samples, 1)),
                                            decoder_labels=np.ones((n_valid_samples, 1)))
        pred_sample_valid = np.reshape(pred_sample_valid, (-1, *image_shape))

        print(source_sample_train.shape, source_sample_train_reshaped.shape, pred_sample_train.shape)
        _save_sample_grid(source_sample_train_reshaped, pred_sample_train, target_sample_train_reshaped,
                          n_cols, os.path.join(results_path_train, f"sample_images_{j}.pdf"))

        print(source_sample_valid.shape, source_sample_valid_reshaped.shape, pred_sample_valid.shape)
        _save_sample_grid(source_sample_valid_reshaped, pred_sample_valid, target_sample_valid_reshaped,
                          n_cols, os.path.join(results_path_valid, f"./sample_images_{j}.pdf"))