Example #1
    def valid_epoch(self, valid_loader):
        self.eval()
        losses = {}

        record_embeddings = False
        if self.logger is not None:
            if self.n_epochs % self.config.logging.record_embeddings_every == 0:
                record_embeddings = True
                embedding_samples = []
                embedding_metadata = []
                embedding_images = []

        with torch.no_grad():
            for data in valid_loader:
                x = data['obs'].to(self.config.device)
                x_aug = valid_loader.dataset.get_augmented_batch(
                    data['index'], augment=True).to(self.config.device)
                # forward
                model_outputs = self.forward(x)
                model_outputs_aug = self.forward(x_aug)
                model_outputs.update(
                    {k + '_aug': v
                     for k, v in model_outputs_aug.items()})
                loss_inputs = {
                    key: model_outputs[key]
                    for key in self.loss_f.input_keys_list
                }
                batch_losses = self.loss_f(loss_inputs)
                # save losses
                for k, v in batch_losses.items():
                    if k not in losses:
                        losses[k] = [v.item()]
                    else:
                        losses[k].append(v.item())
                # record embeddings
                if record_embeddings:
                    embedding_samples.append(model_outputs["z"])
                    embedding_metadata.append(data['label'])
                    embedding_images.append(x)

        for k, v in losses.items():
            losses[k] = torch.mean(torch.tensor(v)).item()

        if record_embeddings:
            embedding_samples = torch.cat(embedding_samples)
            embedding_metadata = torch.cat(embedding_metadata)
            embedding_images = torch.cat(embedding_images)
            if len(embedding_images.shape) == 5:
                # 3D volumes: keep only the slice at middle depth
                embedding_images = embedding_images[
                    :, :, self.config.network.parameters.input_size[0] // 2, :, :]
            if embedding_images.shape[1] not in (1, 3):
                # neither grayscale nor RGB: keep the first three channels
                embedding_images = embedding_images[:, :3, ...]
            embedding_images = resize_embeddings(embedding_images)
            self.logger.add_embedding(embedding_samples,
                                      metadata=embedding_metadata,
                                      label_img=embedding_images,
                                      global_step=self.n_epochs)

        return losses
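
This first variant evaluates on paired (original, augmented) batches: both views are forwarded, the augmented outputs are merged under `*_aug` keys, and per-key losses are averaged over the epoch. A minimal driver sketch, assuming a `train_epoch` counterpart and the usual model/loader objects (none of which appear in the example above):

    # Hypothetical driver loop; `model`, `train_loader`, `valid_loader` and
    # `train_epoch` are assumptions, not part of the example above.
    n_epochs = 100
    for epoch in range(n_epochs):
        model.train_epoch(train_loader)
        valid_losses = model.valid_epoch(valid_loader)
        print(f"epoch {epoch}: " +
              ", ".join(f"{k}={v:.4f}" for k, v in valid_losses.items()))
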
Example #2
    def valid_epoch(self, valid_loader):
        self.eval()
        losses = {}

        # Prepare logging
        record_valid_images = False
        record_embeddings = False
        if self.logger is not None:
            if self.n_epochs % self.config.logging.record_valid_images_every == 0:
                record_valid_images = True
                images = []
                recon_images = []
            if self.n_epochs % self.config.logging.record_embeddings_every == 0:
                record_embeddings = True
                embeddings = []
                labels = []
                if not record_valid_images:
                    images = []
        with torch.no_grad():
            for data in valid_loader:
                x = data['obs']
                x = x.to(self.config.device)
                # forward
                model_outputs = self.forward(x)
                loss_inputs = {
                    key: model_outputs[key]
                    for key in self.loss_f.input_keys_list
                }
                batch_losses = self.loss_f(loss_inputs)
                # save losses
                for k, v in batch_losses.items():
                    if k not in losses:
                        losses[k] = v.detach().cpu().unsqueeze(-1)
                    else:
                        losses[k] = torch.vstack(
                            [losses[k],
                             v.detach().cpu().unsqueeze(-1)])

                if record_valid_images:
                    recon_x = model_outputs["recon_x"]
                    if len(images) < self.config.logging.record_memory_max:
                        images.append(x.cpu().detach())
                    if len(recon_images) < self.config.logging.record_memory_max:
                        recon_images.append(recon_x)

                if record_embeddings:
                    if len(embeddings) < self.config.logging.record_memory_max:
                        embeddings.append(
                            model_outputs["z"].cpu().detach().view(
                                x.shape[0],
                                self.config.network.parameters.n_latents))
                        labels.append(data["label"])
                    if not record_valid_images:
                        if len(images) < self.config.logging.record_memory_max:
                            images.append(x.cpu().detach())

        if record_valid_images:
            recon_images = torch.cat(recon_images)
            images = torch.cat(images)
        if record_embeddings:
            embeddings = torch.cat(embeddings)
            labels = torch.cat(labels)
            if not record_valid_images:
                images = torch.cat(images)

        # log results
        if record_valid_images:
            n_images = min(len(images), 40)
            sampled_ids = torch.randperm(len(images))[:n_images]
            input_images = images[sampled_ids].detach().cpu()
            output_images = recon_images[sampled_ids].detach().cpu()
            if self.config.loss.parameters.reconstruction_dist == "bernoulli":
                output_images = torch.sigmoid(output_images)
            vizu_tensor_list = [None] * (2 * n_images)
            vizu_tensor_list[0::2] = [input_images[n] for n in range(n_images)]
            vizu_tensor_list[1::2] = [
                output_images[n] for n in range(n_images)
            ]
            logger_add_image_list(
                self.logger,
                vizu_tensor_list,
                "reconstructions",
                global_step=self.n_epochs,
                n_channels=self.config.network.parameters.n_channels,
                spatial_dims=len(self.config.network.parameters.input_size))

        if record_embeddings:
            if len(images.shape) == 5:
                # 3D volumes: keep only the slice at middle depth
                images = images[
                    :, :, self.config.network.parameters.input_size[0] // 2, :, :]
            if images.shape[1] not in (1, 3):
                # neither grayscale nor RGB: keep the first three channels
                images = images[:, :3, ...]
            images = resize_embeddings(images)
            self.logger.add_embedding(embeddings,
                                      metadata=labels,
                                      label_img=images,
                                      global_step=self.n_epochs)

        # average loss and return
        for k, v in losses.items():
            losses[k] = v.mean().item()  # v is already a stacked tensor

        return losses
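
Every example shrinks the thumbnail images with `resize_embeddings` before calling `add_embedding`, but the helper itself is not shown. TensorBoard's embedding projector packs all label images into one sprite sheet, which is capped at 8192x8192 pixels, so the thumbnails are presumably downscaled to fit. A minimal sketch under that assumption (not the project's actual implementation):

    import math
    import torch.nn.functional as F

    def resize_embeddings(images, max_sprite_size=8192):
        # Downscale (N, C, H, W) thumbnails so that a ceil(sqrt(N))-per-row
        # sprite sheet stays within the assumed 8192-px TensorBoard limit.
        n, _, h, w = images.shape
        per_row = math.ceil(math.sqrt(n))
        target = max_sprite_size // per_row
        if max(h, w) <= target:
            return images
        return F.interpolate(images, scale_factor=target / max(h, w),
                             mode="bilinear", align_corners=False)
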
Example #3
    def valid_epoch(self, valid_loader):
        self.eval()
        losses = {}

        # Prepare logging
        record_valid_images = False
        record_embeddings = False
        if self.logger is not None:
            if self.n_epochs % self.config.logging.record_valid_images_every == 0 and hasattr(self.network, "decoder"):
                record_valid_images = True
                images = []
                recon_images = []
            if self.n_epochs % self.config.logging.record_embeddings_every == 0:
                record_embeddings = True
                embeddings = []
                labels = []
                if not record_valid_images:
                    images = []

        with torch.no_grad():
            for data in valid_loader:
                data_ref, data_a, data_b, data_c = data
                x_ref = data_ref["obs"].to(self.config.device)
                x_a = data_a["obs"].to(self.config.device)
                x_b = data_b["obs"].to(self.config.device)
                x_c = data_c["obs"].to(self.config.device)
                # forward
                model_outputs = self.forward(x_ref, x_a, x_b, x_c)
                loss_inputs = {key: model_outputs[key] for key in self.loss_f.input_keys_list}
                batch_losses = self.loss_f(loss_inputs, reduction="none")
                # save losses
                for k, v in batch_losses.items():
                    if k not in losses:
                        losses[k] = np.expand_dims(v.detach().cpu().numpy(), axis=-1)
                    else:
                        losses[k] = np.vstack([losses[k], np.expand_dims(v.detach().cpu().numpy(), axis=-1)])

                if record_valid_images:
                    recon_x_ref = model_outputs["x_ref_outputs"]["recon_x"]
                    recon_x_a = model_outputs["x_a_outputs"]["recon_x"]
                    recon_x_b = model_outputs["x_b_outputs"]["recon_x"]
                    recon_x_c = model_outputs["x_c_outputs"]["recon_x"]
                    if len(images) < self.config.logging.record_memory_max:
                        images += [x_ref, x_a, x_b, x_c]
                    if len(recon_images) < self.config.logging.record_memory_max:
                        recon_images += [recon_x_ref, recon_x_a, recon_x_b, recon_x_c]

                if record_embeddings:
                    if len(embeddings) < self.config.logging.record_memory_max:
                        embeddings += [model_outputs["x_ref_outputs"]["z"],
                                       model_outputs["x_a_outputs"]["z"],
                                       model_outputs["x_b_outputs"]["z"],
                                       model_outputs["x_c_outputs"]["z"]]
                        labels += [data_ref["label"], data_a["label"],
                                   data_b["label"], data_c["label"]]
                    if not record_valid_images:
                        if len(images) < self.config.logging.record_memory_max:
                            images += [x_ref, x_a, x_b, x_c]

        if record_valid_images:
            recon_images = torch.cat(recon_images)
            images = torch.cat(images)
        if record_embeddings:
            embeddings = torch.cat(embeddings)
            labels = torch.cat(labels)
            if not record_valid_images:
                images = torch.cat(images)

        # log results
        if record_valid_images:
            n_images = min(len(images), 40)
            sampled_ids = np.random.choice(len(images), n_images, replace=False)
            input_images = images[sampled_ids].detach().cpu()
            output_images = recon_images[sampled_ids].detach().cpu()
            if self.config.loss.parameters.reconstruction_dist == "bernoulli":
                output_images = torch.sigmoid(output_images)
            vizu_tensor_list = [None] * (2 * n_images)
            vizu_tensor_list[0::2] = [input_images[n] for n in range(n_images)]
            vizu_tensor_list[1::2] = [output_images[n] for n in range(n_images)]
            if self.config.network.parameters.n_channels in (1, 3):  # grayscale or RGB
                if len(input_images.shape) == 4:
                    img = make_grid(vizu_tensor_list, nrow=2, padding=0)
                    self.logger.add_image("reconstructions", img, self.n_epochs)
                elif len(input_images.shape) == 5:
                    self.logger.add_video("original", torch.stack(vizu_tensor_list[0::2]).transpose(1, 2),
                                          self.n_epochs)
                    self.logger.add_video("reconstructions", torch.stack(vizu_tensor_list[1::2]).transpose(1, 2),
                                          self.n_epochs)
            else:
                for channel in range(self.config.network.parameters.n_channels):
                    if len(input_images.shape) == 4:
                        img = make_grid(torch.stack(vizu_tensor_list)[:, channel, :, :].unsqueeze(1), nrow=2, padding=0)
                        self.logger.add_image(f"reconstructions_channel_{channel}", img, self.n_epochs)
                    elif len(input_images.shape) == 5:
                        self.logger.add_video(f"original_channel_{channel}",
                                              torch.stack(vizu_tensor_list[0::2])[:, channel, :, :].unsqueeze(
                                                  1).transpose(1, 2),
                                              self.n_epochs)
                        self.logger.add_video(f"reconstructions_channel_{channel}",
                                              torch.stack(vizu_tensor_list[1::2])[:, channel, :, :].unsqueeze(
                                                  1).transpose(1, 2),
                                              self.n_epochs)
        if record_embeddings:
            if len(images.shape) == 5:
                # 3D volumes: keep only the slice at middle depth
                images = images[
                    :, :, self.config.network.parameters.input_size[0] // 2, :, :]
            if images.shape[1] not in (1, 3):
                # neither grayscale nor RGB: keep the first three channels
                images = images[:, :3, ...]
            images = resize_embeddings(images)
            self.logger.add_embedding(
                embeddings,
                metadata=labels,
                label_img=images,
                global_step=self.n_epochs)

        # average loss and return
        for k, v in losses.items():
            losses[k] = np.mean(v)

        return losses
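
Example #3 inlines the logging code that Examples #2 and #4 delegate to `logger_add_image_list`: grayscale/RGB batches are logged as a single `make_grid` image (or as videos for 5D data), while higher channel counts are split into per-channel streams. Reconstructed from that inlined logic, the helper plausibly looks like the following (a sketch, not the project's verbatim source):

    import torch
    from torchvision.utils import make_grid

    def logger_add_image_list(logger, tensor_list, tag, global_step,
                              n_channels, spatial_dims):
        # `tensor_list` interleaves inputs and reconstructions as
        # [in_0, recon_0, in_1, recon_1, ...]; spatial_dims is 2 for images,
        # 3 for volumes/videos (add_video expects (N, T, C, H, W)).
        if n_channels in (1, 3):  # grayscale or RGB
            if spatial_dims == 2:
                img = make_grid(tensor_list, nrow=2, padding=0)
                logger.add_image(tag, img, global_step)
            elif spatial_dims == 3:
                logger.add_video("original",
                                 torch.stack(tensor_list[0::2]).transpose(1, 2),
                                 global_step)
                logger.add_video(tag,
                                 torch.stack(tensor_list[1::2]).transpose(1, 2),
                                 global_step)
        else:  # too many channels: log each channel separately
            for c in range(n_channels):
                if spatial_dims == 2:
                    img = make_grid(torch.stack(tensor_list)[:, c].unsqueeze(1),
                                    nrow=2, padding=0)
                    logger.add_image(f"{tag}_channel_{c}", img, global_step)
                elif spatial_dims == 3:
                    logger.add_video(
                        f"original_channel_{c}",
                        torch.stack(tensor_list[0::2])[:, c].unsqueeze(1).transpose(1, 2),
                        global_step)
                    logger.add_video(
                        f"{tag}_channel_{c}",
                        torch.stack(tensor_list[1::2])[:, c].unsqueeze(1).transpose(1, 2),
                        global_step)
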
Example #4
    def valid_epoch(self, valid_loader):
        self.eval()
        losses = {}

        # Prepare logging
        record_valid_images = False
        record_embeddings = False
        if self.logger is not None:
            if self.n_epochs % self.config.logging.record_valid_images_every == 0:
                record_valid_images = True
                images = []
                recon_images = []
            if self.n_epochs % self.config.logging.record_embeddings_every == 0:
                record_embeddings = True
                embeddings = []
                labels = []
                if not record_valid_images:
                    images = []

        with torch.no_grad():
            for data in valid_loader:
                data_pos_a, data_pos_b, data_neg_a, data_neg_b = data
                x_pos_a = data_pos_a["obs"].to(self.config.device)
                x_pos_b = data_pos_b["obs"].to(self.config.device)
                x_neg_a = data_neg_a["obs"].to(self.config.device)
                x_neg_b = data_neg_b["obs"].to(self.config.device)
                # forward
                model_outputs = self.forward(x_pos_a, x_pos_b, x_neg_a, x_neg_b)
                loss_inputs = {key: model_outputs[key] for key in self.loss_f.input_keys_list}
                batch_losses = self.loss_f(loss_inputs, reduction="none")
                # save losses
                for k, v in batch_losses.items():
                    if k not in losses:
                        losses[k] = np.expand_dims(v.detach().cpu().numpy(), axis=-1)
                    else:
                        losses[k] = np.vstack([losses[k], np.expand_dims(v.detach().cpu().numpy(), axis=-1)])

                if record_valid_images:
                    recon_x_pos_a = self.forward(x_pos_a)["recon_x"]
                    recon_x_pos_b = self.forward(x_pos_b)["recon_x"]
                    recon_x_neg_a = self.forward(x_neg_a)["recon_x"]
                    recon_x_neg_b = self.forward(x_neg_b)["recon_x"]
                    if len(images) < self.config.logging.record_memory_max:
                        images += [x_pos_a, x_pos_b, x_neg_a, x_neg_b]
                    if len(recon_images) < self.config.logging.record_memory_max:
                        recon_images += [recon_x_pos_a, recon_x_pos_b, recon_x_neg_a, recon_x_neg_b]

                if record_embeddings:
                    if len(embeddings) < self.config.logging.record_memory_max:
                        embeddings += [model_outputs["z_pos_a"], model_outputs["z_pos_b"],
                                       model_outputs["z_neg_a"], model_outputs["z_neg_b"]]
                        labels += [data_pos_a["label"], data_pos_b["label"], data_neg_a["label"], data_neg_b["label"]]
                    if not record_valid_images:
                        if len(images) < self.config.logging.record_memory_max:
                            images += [x_pos_a, x_pos_b, x_neg_a, x_neg_b]

        if record_valid_images:
            recon_images = torch.cat(recon_images)
            images = torch.cat(images)
        if record_embeddings:
            embeddings = torch.cat(embeddings)
            labels = torch.cat(labels)
            if not record_valid_images:
                images = torch.cat(images)

        # log results
        if record_valid_images:
            n_images = min(len(images), 40)
            sampled_ids = np.random.choice(len(images), n_images, replace=False)
            input_images = images[sampled_ids].detach().cpu()
            output_images = recon_images[sampled_ids].detach().cpu()
            if self.config.loss.parameters.reconstruction_dist == "bernoulli":
                output_images = torch.sigmoid(output_images)
            vizu_tensor_list = [None] * (2 * n_images)
            vizu_tensor_list[0::2] = [input_images[n] for n in range(n_images)]
            vizu_tensor_list[1::2] = [output_images[n] for n in range(n_images)]
            logger_add_image_list(self.logger, vizu_tensor_list, "reconstructions",
                                  global_step=self.n_epochs, n_channels=self.config.network.parameters.n_channels,
                                  spatial_dims=len(self.config.network.parameters.input_size))

        if record_embeddings:
            if len(images.shape) == 5:
                # 3D volumes: keep only the slice at middle depth
                images = images[
                    :, :, self.config.network.parameters.input_size[0] // 2, :, :]
            if images.shape[1] not in (1, 3):
                # neither grayscale nor RGB: keep the first three channels
                images = images[:, :3, ...]
            images = resize_embeddings(images)
            self.logger.add_embedding(
                embeddings,
                metadata=labels,
                label_img=images,
                global_step=self.n_epochs)

        # average loss and return
        for k, v in losses.items():
            losses[k] = np.mean(v)

        return losses
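
Across all four variants the reconstruction grid is built by interleaving inputs and reconstructions with extended-slice assignment, so that `make_grid(..., nrow=2)` renders each input beside its reconstruction. A self-contained illustration of just that pattern, with dummy tensors:

    import torch
    from torchvision.utils import make_grid

    # Dummy stand-ins for sampled inputs and their reconstructions.
    inputs = torch.rand(4, 3, 32, 32)
    recons = torch.rand(4, 3, 32, 32)

    # Interleave as [in_0, recon_0, in_1, recon_1, ...]; with nrow=2 every
    # grid row then shows an input next to its reconstruction.
    pairs = [None] * (2 * len(inputs))
    pairs[0::2] = list(inputs)
    pairs[1::2] = list(recons)
    grid = make_grid(pairs, nrow=2, padding=0)  # -> shape (3, 128, 64)
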