def valid_epoch(self, valid_loader):
    """Run one validation epoch and return the mean of each loss term.

    For every batch, an augmented view of the same samples is fetched from
    the dataset and both views are passed through the model; the augmented
    outputs are merged into the loss inputs under a ``_aug`` key suffix.
    Optionally logs latent embeddings (with labels and thumbnail images)
    every ``config.logging.record_embeddings_every`` epochs.

    Args:
        valid_loader: iterable of dicts with at least ``obs``, ``index``
            and ``label`` entries; its dataset must expose
            ``get_augmented_batch``.

    Returns:
        dict mapping each loss name to its mean value (float) over batches.
    """
    self.eval()
    losses = {}

    # Decide whether to log embeddings this epoch.
    record_embeddings = False
    if self.logger is not None:
        if self.n_epochs % self.config.logging.record_embeddings_every == 0:
            record_embeddings = True
            embedding_samples = []
            embedding_metadata = []
            embedding_images = []

    with torch.no_grad():
        for data in valid_loader:
            x = data['obs'].to(self.config.device)
            # Second (augmented) view of the same batch, matched by index.
            x_aug = valid_loader.dataset.get_augmented_batch(
                data['index'], augment=True).to(self.config.device)

            # forward
            model_outputs = self.forward(x)
            model_outputs_aug = self.forward(x_aug)
            model_outputs.update(
                {k + '_aug': v for k, v in model_outputs_aug.items()})
            loss_inputs = {
                key: model_outputs[key]
                for key in self.loss_f.input_keys_list
            }
            batch_losses = self.loss_f(loss_inputs)

            # save losses
            for k, v in batch_losses.items():
                if k not in losses:
                    losses[k] = [v.data.item()]
                else:
                    losses[k].append(v.data.item())

            # record embeddings
            if record_embeddings:
                embedding_samples.append(model_outputs["z"])
                embedding_metadata.append(data['label'])
                embedding_images.append(x)

    # Average each loss over the epoch's batches.
    for k, v in losses.items():
        losses[k] = torch.mean(torch.tensor(v)).item()

    if record_embeddings:
        embedding_samples = torch.cat(embedding_samples)
        embedding_metadata = torch.cat(embedding_metadata)
        embedding_images = torch.cat(embedding_images)
        if len(embedding_images.shape) == 5:
            # 3D volumes (N, C, D, H, W): we take slice at middle depth only.
            embedding_images = embedding_images[
                :, :, self.config.network.parameters.input_size[0] // 2, :, :]
        # BUGFIX: original used `or`, which is always true for any channel
        # count; trim to 3 channels only when neither grayscale nor RGB.
        if (embedding_images.shape[1] != 1) and (embedding_images.shape[1] != 3):
            embedding_images = embedding_images[:, :3, ...]
        embedding_images = resize_embeddings(embedding_images)
        self.logger.add_embedding(embedding_samples,
                                  metadata=embedding_metadata,
                                  label_img=embedding_images,
                                  global_step=self.n_epochs)

    return losses
def valid_epoch(self, valid_loader):
    """Run one validation epoch and return the mean of each loss term.

    Optionally logs input/reconstruction image pairs every
    ``config.logging.record_valid_images_every`` epochs and latent
    embeddings every ``config.logging.record_embeddings_every`` epochs.

    Args:
        valid_loader: iterable of dicts with at least ``obs`` and
            ``label`` entries.

    Returns:
        dict mapping each loss name to its mean value (float) over batches.
    """
    self.eval()
    losses = {}

    # Prepare logging
    record_valid_images = False
    record_embeddings = False
    if self.logger is not None:
        if self.n_epochs % self.config.logging.record_valid_images_every == 0:
            record_valid_images = True
            images = []
            recon_images = []
        if self.n_epochs % self.config.logging.record_embeddings_every == 0:
            record_embeddings = True
            embeddings = []
            labels = []
    # BUGFIX: original tested `if images is None`, which raises NameError
    # whenever record_valid_images is False (`images` was never bound).
    # Inputs are still collected here because they serve as label_img
    # thumbnails for embedding logging.
    if not record_valid_images:
        images = []

    with torch.no_grad():
        for data in valid_loader:
            x = data['obs']
            x = x.to(self.config.device)

            # forward
            model_outputs = self.forward(x)
            loss_inputs = {
                key: model_outputs[key]
                for key in self.loss_f.input_keys_list
            }
            batch_losses = self.loss_f(loss_inputs)

            # save losses, stacked on CPU so GPU memory is freed per batch
            for k, v in batch_losses.items():
                if k not in losses:
                    losses[k] = v.detach().cpu().unsqueeze(-1)
                else:
                    losses[k] = torch.vstack(
                        [losses[k], v.detach().cpu().unsqueeze(-1)])

            if record_valid_images:
                recon_x = model_outputs["recon_x"]
                if len(images) < self.config.logging.record_memory_max:
                    images.append(x.cpu().detach())
                if len(recon_images) < self.config.logging.record_memory_max:
                    recon_images.append(recon_x)

            if record_embeddings:
                if len(embeddings) < self.config.logging.record_memory_max:
                    embeddings.append(
                        model_outputs["z"].cpu().detach().view(
                            x.shape[0],
                            self.config.network.parameters.n_latents))
                    labels.append(data["label"])

            if not record_valid_images:
                if len(images) < self.config.logging.record_memory_max:
                    images.append(x.cpu().detach())

    if record_valid_images:
        recon_images = torch.cat(recon_images)
        images = torch.cat(images)
    if record_embeddings:
        embeddings = torch.cat(embeddings)
        labels = torch.cat(labels)
        if not record_valid_images:
            images = torch.cat(images)

    # log results
    if record_valid_images:
        n_images = min(len(images), 40)
        sampled_ids = torch.randperm(len(images))[:n_images]
        input_images = images[sampled_ids].detach().cpu()
        output_images = recon_images[sampled_ids].detach().cpu()
        if self.config.loss.parameters.reconstruction_dist == "bernoulli":
            # Bernoulli decoder outputs logits; map to [0, 1] for display.
            output_images = torch.sigmoid(output_images)
        # Interleave input / reconstruction pairs for side-by-side display.
        vizu_tensor_list = [None] * (2 * n_images)
        vizu_tensor_list[0::2] = [input_images[n] for n in range(n_images)]
        vizu_tensor_list[1::2] = [
            output_images[n] for n in range(n_images)
        ]
        logger_add_image_list(
            self.logger,
            vizu_tensor_list,
            "reconstructions",
            global_step=self.n_epochs,
            n_channels=self.config.network.parameters.n_channels,
            spatial_dims=len(self.config.network.parameters.input_size))

    if record_embeddings:
        if len(images.shape) == 5:
            # 3D volumes: we take slice at middle depth only
            images = images[
                :, :, self.config.network.parameters.input_size[0] // 2, :, :]
        # BUGFIX: original used `or`, which is always true for any channel
        # count; trim to 3 channels only when neither grayscale nor RGB.
        if (images.shape[1] != 1) and (images.shape[1] != 3):
            images = images[:, :3, ...]
        images = resize_embeddings(images)
        self.logger.add_embedding(embeddings,
                                  metadata=labels,
                                  label_img=images,
                                  global_step=self.n_epochs)

    # average loss and return (v is already a stacked tensor; no copy needed)
    for k, v in losses.items():
        losses[k] = torch.mean(v).item()

    return losses
def valid_epoch(self, valid_loader):
    """Run one validation epoch over quadruplet batches and return mean losses.

    Each batch is a 4-tuple ``(data_ref, data_a, data_b, data_c)`` of dicts;
    all four observations go through one joint forward pass. Optionally logs
    input/reconstruction images (only when the network has a decoder) and
    latent embeddings at the intervals set in ``config.logging``.

    Args:
        valid_loader: iterable of 4-tuples of dicts with ``obs`` and
            ``label`` entries.

    Returns:
        dict mapping each loss name to its mean value over all samples.
    """
    self.eval()
    losses = {}

    # Prepare logging
    record_valid_images = False
    record_embeddings = False
    if self.logger is not None:
        # Reconstructions can only be logged when the network has a decoder.
        if (self.n_epochs % self.config.logging.record_valid_images_every == 0
                and hasattr(self.network, "decoder")):
            record_valid_images = True
            images = []
            recon_images = []
        if self.n_epochs % self.config.logging.record_embeddings_every == 0:
            record_embeddings = True
            embeddings = []
            labels = []
    # Inputs are still collected as label_img thumbnails for embeddings.
    if not record_valid_images:
        images = []

    with torch.no_grad():
        for data in valid_loader:
            data_ref, data_a, data_b, data_c = data
            x_ref = data_ref["obs"].to(self.config.device)
            x_a = data_a["obs"].to(self.config.device)
            x_b = data_b["obs"].to(self.config.device)
            x_c = data_c["obs"].to(self.config.device)

            # forward
            model_outputs = self.forward(x_ref, x_a, x_b, x_c)
            loss_inputs = {key: model_outputs[key]
                           for key in self.loss_f.input_keys_list}
            batch_losses = self.loss_f(loss_inputs, reduction="none")

            # save per-sample losses as a growing (n_samples, 1) array per key
            for k, v in batch_losses.items():
                if k not in losses:
                    losses[k] = np.expand_dims(v.detach().cpu().numpy(),
                                               axis=-1)
                else:
                    losses[k] = np.vstack(
                        [losses[k],
                         np.expand_dims(v.detach().cpu().numpy(), axis=-1)])

            if record_valid_images:
                recon_x_ref = model_outputs["x_ref_outputs"]["recon_x"]
                recon_x_a = model_outputs["x_a_outputs"]["recon_x"]
                recon_x_b = model_outputs["x_b_outputs"]["recon_x"]
                recon_x_c = model_outputs["x_c_outputs"]["recon_x"]
                if len(images) < self.config.logging.record_memory_max:
                    images += [x_ref, x_a, x_b, x_c]
                if len(recon_images) < self.config.logging.record_memory_max:
                    recon_images += [recon_x_ref, recon_x_a,
                                     recon_x_b, recon_x_c]

            if record_embeddings:
                if len(embeddings) < self.config.logging.record_memory_max:
                    embeddings += [model_outputs["x_ref_outputs"]["z"],
                                   model_outputs["x_a_outputs"]["z"],
                                   model_outputs["x_b_outputs"]["z"],
                                   model_outputs["x_c_outputs"]["z"]]
                    labels += [data_ref["label"], data_a["label"],
                               data_b["label"], data_c["label"]]

            if not record_valid_images:
                if len(images) < self.config.logging.record_memory_max:
                    images += [x_ref, x_a, x_b, x_c]

    if record_valid_images:
        recon_images = torch.cat(recon_images)
        images = torch.cat(images)
    if record_embeddings:
        embeddings = torch.cat(embeddings)
        labels = torch.cat(labels)
        if not record_valid_images:
            images = torch.cat(images)

    # log results
    if record_valid_images:
        n_images = min(len(images), 40)
        sampled_ids = np.random.choice(len(images), n_images, replace=False)
        input_images = images[sampled_ids].detach().cpu()
        output_images = recon_images[sampled_ids].detach().cpu()
        if self.config.loss.parameters.reconstruction_dist == "bernoulli":
            # Bernoulli decoder outputs logits; map to [0, 1] for display.
            output_images = torch.sigmoid(output_images)
        # Interleave input / reconstruction pairs for side-by-side display.
        vizu_tensor_list = [None] * (2 * n_images)
        vizu_tensor_list[0::2] = [input_images[n] for n in range(n_images)]
        vizu_tensor_list[1::2] = [output_images[n] for n in range(n_images)]
        n_channels = self.config.network.parameters.n_channels
        if n_channels == 1 or n_channels == 3:
            # grey scale or RGB
            if len(input_images.shape) == 4:
                img = make_grid(vizu_tensor_list, nrow=2, padding=0)
                self.logger.add_image("reconstructions", img, self.n_epochs)
            elif len(input_images.shape) == 5:
                # 3D data: log the depth axis as the video time dimension
                self.logger.add_video(
                    "original",
                    torch.stack(vizu_tensor_list[0::2]).transpose(1, 2),
                    self.n_epochs)
                self.logger.add_video(
                    "reconstructions",
                    torch.stack(vizu_tensor_list[1::2]).transpose(1, 2),
                    self.n_epochs)
        else:
            # other channel counts: log each channel separately
            for channel in range(n_channels):
                if len(input_images.shape) == 4:
                    img = make_grid(
                        torch.stack(vizu_tensor_list)[:, channel, :, :]
                        .unsqueeze(1),
                        nrow=2, padding=0)
                    self.logger.add_image(
                        f"reconstructions_channel_{channel}", img,
                        self.n_epochs)
                elif len(input_images.shape) == 5:
                    self.logger.add_video(
                        f"original_channel_{channel}",
                        torch.stack(vizu_tensor_list[0::2])[:, channel, :, :]
                        .unsqueeze(1).transpose(1, 2),
                        self.n_epochs)
                    self.logger.add_video(
                        f"reconstructions_channel_{channel}",
                        torch.stack(vizu_tensor_list[1::2])[:, channel, :, :]
                        .unsqueeze(1).transpose(1, 2),
                        self.n_epochs)

    if record_embeddings:
        if len(images.shape) == 5:
            # we take slice at middle depth only
            images = images[
                :, :, self.config.network.parameters.input_size[0] // 2, :, :]
        # BUGFIX: original used `or`, which is always true for any channel
        # count; trim to 3 channels only when neither grayscale nor RGB.
        if (images.shape[1] != 1) and (images.shape[1] != 3):
            images = images[:, :3, ...]
        images = resize_embeddings(images)
        self.logger.add_embedding(
            embeddings,
            metadata=labels,
            label_img=images,
            global_step=self.n_epochs)

    # average loss and return
    for k, v in losses.items():
        losses[k] = np.mean(v)

    return losses
def valid_epoch(self, valid_loader):
    """Run one validation epoch over pos/neg pair batches and return mean losses.

    Each batch is a 4-tuple ``(data_pos_a, data_pos_b, data_neg_a,
    data_neg_b)`` of dicts; all four observations go through one joint
    forward pass for the loss, with extra single-input forwards only when
    reconstructions are being logged. Optionally logs input/reconstruction
    images and latent embeddings at the intervals set in ``config.logging``.

    Args:
        valid_loader: iterable of 4-tuples of dicts with ``obs`` and
            ``label`` entries.

    Returns:
        dict mapping each loss name to its mean value over all samples.
    """
    self.eval()
    losses = {}

    # Prepare logging
    record_valid_images = False
    record_embeddings = False
    if self.logger is not None:
        if self.n_epochs % self.config.logging.record_valid_images_every == 0:
            record_valid_images = True
            images = []
            recon_images = []
        if self.n_epochs % self.config.logging.record_embeddings_every == 0:
            record_embeddings = True
            embeddings = []
            labels = []
    # BUGFIX: original tested `if images is None`, which raises NameError
    # whenever record_valid_images is False (`images` was never bound).
    # Inputs are still collected as label_img thumbnails for embeddings.
    if not record_valid_images:
        images = []

    with torch.no_grad():
        for data in valid_loader:
            data_pos_a, data_pos_b, data_neg_a, data_neg_b = data
            x_pos_a = data_pos_a["obs"].to(self.config.device)
            x_pos_b = data_pos_b["obs"].to(self.config.device)
            x_neg_a = data_neg_a["obs"].to(self.config.device)
            x_neg_b = data_neg_b["obs"].to(self.config.device)

            # forward
            model_outputs = self.forward(x_pos_a, x_pos_b, x_neg_a, x_neg_b)
            loss_inputs = {key: model_outputs[key]
                           for key in self.loss_f.input_keys_list}
            batch_losses = self.loss_f(loss_inputs, reduction="none")

            # save per-sample losses as a growing (n_samples, 1) array per key
            for k, v in batch_losses.items():
                if k not in losses:
                    losses[k] = np.expand_dims(v.detach().cpu().numpy(),
                                               axis=-1)
                else:
                    losses[k] = np.vstack(
                        [losses[k],
                         np.expand_dims(v.detach().cpu().numpy(), axis=-1)])

            if record_valid_images:
                # The joint forward does not expose reconstructions, so run
                # single-input forwards for display purposes only.
                recon_x_pos_a = self.forward(x_pos_a)["recon_x"]
                recon_x_pos_b = self.forward(x_pos_b)["recon_x"]
                recon_x_neg_a = self.forward(x_neg_a)["recon_x"]
                recon_x_neg_b = self.forward(x_neg_b)["recon_x"]
                if len(images) < self.config.logging.record_memory_max:
                    images += [x_pos_a, x_pos_b, x_neg_a, x_neg_b]
                if len(recon_images) < self.config.logging.record_memory_max:
                    recon_images += [recon_x_pos_a, recon_x_pos_b,
                                     recon_x_neg_a, recon_x_neg_b]

            if record_embeddings:
                if len(embeddings) < self.config.logging.record_memory_max:
                    embeddings += [model_outputs["z_pos_a"],
                                   model_outputs["z_pos_b"],
                                   model_outputs["z_neg_a"],
                                   model_outputs["z_neg_b"]]
                    labels += [data_pos_a["label"], data_pos_b["label"],
                               data_neg_a["label"], data_neg_b["label"]]

            if not record_valid_images:
                if len(images) < self.config.logging.record_memory_max:
                    images += [x_pos_a, x_pos_b, x_neg_a, x_neg_b]

    if record_valid_images:
        recon_images = torch.cat(recon_images)
        images = torch.cat(images)
    if record_embeddings:
        embeddings = torch.cat(embeddings)
        labels = torch.cat(labels)
        if not record_valid_images:
            images = torch.cat(images)

    # log results
    if record_valid_images:
        n_images = min(len(images), 40)
        sampled_ids = np.random.choice(len(images), n_images, replace=False)
        input_images = images[sampled_ids].detach().cpu()
        output_images = recon_images[sampled_ids].detach().cpu()
        if self.config.loss.parameters.reconstruction_dist == "bernoulli":
            # Bernoulli decoder outputs logits; map to [0, 1] for display.
            output_images = torch.sigmoid(output_images)
        # Interleave input / reconstruction pairs for side-by-side display.
        vizu_tensor_list = [None] * (2 * n_images)
        vizu_tensor_list[0::2] = [input_images[n] for n in range(n_images)]
        vizu_tensor_list[1::2] = [output_images[n] for n in range(n_images)]
        logger_add_image_list(
            self.logger,
            vizu_tensor_list,
            "reconstructions",
            global_step=self.n_epochs,
            n_channels=self.config.network.parameters.n_channels,
            spatial_dims=len(self.config.network.parameters.input_size))

    if record_embeddings:
        if len(images.shape) == 5:
            # we take slice at middle depth only
            images = images[
                :, :, self.config.network.parameters.input_size[0] // 2, :, :]
        # BUGFIX: original used `or`, which is always true for any channel
        # count; trim to 3 channels only when neither grayscale nor RGB.
        if (images.shape[1] != 1) and (images.shape[1] != 3):
            images = images[:, :3, ...]
        images = resize_embeddings(images)
        self.logger.add_embedding(
            embeddings,
            metadata=labels,
            label_img=images,
            global_step=self.n_epochs)

    # average loss and return
    for k, v in losses.items():
        losses[k] = np.mean(v)

    return losses