Example #1
    def validate(self):
        self.net.eval()
        agg_loss_dict = []
        for batch_idx, batch in enumerate(self.test_dataloader):
            x, y = batch

            with torch.no_grad():
                x, y = x.to(self.device).float(), y.to(self.device).float()
                y_hat, latent = self.net.forward(x, return_latents=True)
                loss, cur_loss_dict, id_logs = self.calc_loss(
                    x, y, y_hat, latent)
            agg_loss_dict.append(cur_loss_dict)

            # Logging related
            self.parse_and_log_images(id_logs,
                                      x,
                                      y,
                                      y_hat,
                                      title='images/test',
                                      subscript='{:04d}'.format(batch_idx))

            # On the first step, only run a quick sanity check on a few batches
            if self.global_step == 0 and batch_idx >= 4:
                self.net.train()
                return None  # Do not log; the first batches are not representative

        loss_dict = train_utils.aggregate_loss_dict(agg_loss_dict)
        self.log_metrics(loss_dict, prefix='test')
        self.print_metrics(loss_dict, prefix='test')

        self.net.train()
        return loss_dict
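
Each per-batch loss dictionary is collected in agg_loss_dict and then reduced to a single dictionary of averages by train_utils.aggregate_loss_dict, which is not shown in these snippets. A minimal sketch of such a helper, assuming every batch reports plain numeric scalars under the same keys, could look like this:

def aggregate_loss_dict(agg_loss_dict):
    # Collect every value reported for each key across all validation batches.
    all_vals = {}
    for batch_losses in agg_loss_dict:
        for key, value in batch_losses.items():
            all_vals.setdefault(key, []).append(value)
    # Reduce each key to the mean over the batches that reported it.
    return {key: sum(values) / len(values) for key, values in all_vals.items()}
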
Example #2
    def validate(self):
        self.net.eval()
        agg_loss_dict = []
        for batch_idx, batch in enumerate(self.test_dataloader):
            cur_loss_dict = {}
            if self.is_training_discriminator():
                cur_loss_dict = self.validate_discriminator(batch)
            with torch.no_grad():
                x, y, y_hat, latent = self.forward(batch)
                loss, cur_encoder_loss_dict, id_logs = self.calc_loss(
                    x, y, y_hat, latent)
                cur_loss_dict = {**cur_loss_dict, **cur_encoder_loss_dict}
            agg_loss_dict.append(cur_loss_dict)

            # Logging related
            self.parse_and_log_images(id_logs,
                                      x,
                                      y,
                                      y_hat,
                                      title='images/test/faces',
                                      subscript='{:04d}'.format(batch_idx))

            # On the first step, only run a quick sanity check on a few batches
            if self.global_step == 0 and batch_idx >= 4:
                self.net.train()
                return None  # Do not log; the first batches are not representative

        loss_dict = train_utils.aggregate_loss_dict(agg_loss_dict)
        self.log_metrics(loss_dict, prefix='test')
        self.print_metrics(loss_dict, prefix='test')

        self.net.train()
        return loss_dict
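
Example #2 merges the discriminator and encoder loss dictionaries with dict unpacking, so on a key collision the right-hand operand (the encoder losses) wins. A short illustration with made-up key names:

disc_losses = {'loss_d_real': 0.3, 'loss_d_fake': 0.4}
enc_losses = {'loss_l2': 0.1, 'loss_lpips': 0.2}

merged = {**disc_losses, **enc_losses}
# {'loss_d_real': 0.3, 'loss_d_fake': 0.4, 'loss_l2': 0.1, 'loss_lpips': 0.2}
# If a key appeared in both dicts, the value from enc_losses (the right-hand dict) would win.
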
Example #3
    def validate(self):
        self.net.eval()
        agg_loss_dict = []
        for batch_idx, batch in enumerate(self.test_dataloader):
            target_latent = None
            labels = None
            if self.labels_path is not None:
                x, y, labels = batch
                labels = labels.to(self.device)
            elif self.opts.latent_lambda > 0:
                x, y, target_latent = batch
                target_latent = target_latent.to(self.device)
            else:
                x, y = batch

            with torch.no_grad():
                x, y = x.to(self.device).float(), y.to(self.device).float()
                y_hat, latent = self.net.forward(
                    x, labels=None, return_latents=True
                )
                loss, cur_loss_dict, id_logs = self.calc_loss(
                    x, y, y_hat, latent, target_latent
                )
            agg_loss_dict.append(cur_loss_dict)

            # Logging related
            self.parse_and_log_images(
                id_logs,
                x,
                y,
                y_hat,
                title="images/test/faces",
                subscript="{:04d}".format(batch_idx),
            )

            # On the first step, only run a quick sanity check on a few batches
            if self.global_step == 0 and batch_idx >= 4:
                self.net.train()
                return None  # Do not log; the first batches are not representative

        loss_dict = train_utils.aggregate_loss_dict(agg_loss_dict)
        self.log_metrics(loss_dict, prefix="test")
        self.print_metrics(loss_dict, prefix="test")

        self.net.train()
        return loss_dict
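
Example #3 expects the test dataloader to yield two- or three-element batches depending on configuration. Below is a minimal sketch of a dataset producing the (x, y, target_latent) triples assumed by the opts.latent_lambda > 0 branch; the class name and constructor arguments are hypothetical, not part of the code above:

import torch
from torch.utils.data import Dataset

class LatentSupervisedDataset(Dataset):
    """Hypothetical dataset returning (source, target, target_latent) triples."""

    def __init__(self, images, targets, latents):
        assert len(images) == len(targets) == len(latents)
        self.images = images    # tensors of shape (3, H, W)
        self.targets = targets  # tensors of shape (3, H, W)
        self.latents = latents  # e.g. tensors of shape (n_styles, 512)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        # Matches the "x, y, target_latent = batch" unpacking in the validate() above.
        return self.images[index], self.targets[index], self.latents[index]
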
Example #4
    def validate(self):
        self.net.eval()
        agg_loss_dict = []
        for batch_idx, batch in enumerate(self.test_dataloader):
            x, y = batch
            x, y = x.to(self.device).float(), y.to(self.device).float()

            # validate discriminator on batch
            avg_image_for_batch = self.avg_image.unsqueeze(0).repeat(
                x.shape[0], 1, 1, 1)
            x_input = torch.cat([x, avg_image_for_batch], dim=1)
            cur_disc_loss_dict = {}
            if self.is_training_discriminator():
                cur_disc_loss_dict = self.validate_discriminator(x_input)

            # validate encoder on batch
            with torch.no_grad():
                y_hats, cur_enc_loss_dict, id_logs = self.perform_val_iteration_on_batch(
                    x, y)

            cur_loss_dict = {**cur_disc_loss_dict, **cur_enc_loss_dict}
            agg_loss_dict.append(cur_loss_dict)

            # Logging related
            self.parse_and_log_images(id_logs,
                                      x,
                                      y,
                                      y_hats,
                                      title='images/test',
                                      subscript='{:04d}'.format(batch_idx))

            # On the first step, only run a quick sanity check on a few batches
            if self.global_step == 0 and batch_idx >= 4:
                self.net.train()
                return None  # Do not log; the first batches are not representative

        loss_dict = train_utils.aggregate_loss_dict(agg_loss_dict)
        self.log_metrics(loss_dict, prefix='test')
        self.print_metrics(loss_dict, prefix='test')

        self.net.train()
        return loss_dict
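
Example #4 feeds the encoder a six-channel input: the source image concatenated with a repeated average image along the channel dimension. A short, self-contained shape check, assuming three-channel images at 256x256 resolution:

import torch

x = torch.randn(8, 3, 256, 256)       # a batch of source images
avg_image = torch.randn(3, 256, 256)  # a single "average" image kept by the trainer

avg_image_for_batch = avg_image.unsqueeze(0).repeat(x.shape[0], 1, 1, 1)
x_input = torch.cat([x, avg_image_for_batch], dim=1)

print(avg_image_for_batch.shape)  # torch.Size([8, 3, 256, 256])
print(x_input.shape)              # torch.Size([8, 6, 256, 256])
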
Example #5
    def validate(self):
        self.net.eval()
        agg_loss_dict = []
        for batch_idx, batch in enumerate(self.test_dataloader):
            x, y = batch
            with torch.no_grad():
                x, y = x.to(self.device).float(), y.to(self.device).float()

                # skip aging for roughly one third of the batches
                no_aging = random.random() <= (1. / 3)
                if no_aging:
                    x_input = self.__set_target_to_source(x)
                else:
                    x_input = [
                        self.age_transformer(img.cpu()).to(self.device)
                        for img in x
                    ]
                x_input = torch.stack(x_input)
                target_ages = x_input[:, -1, 0, 0]

                # forward pass on the real images (no backward pass during validation)
                y_hat, latent = self.perform_forward_pass(x_input)
                _, cur_loss_dict, id_logs = self.calc_loss(
                    x,
                    y,
                    y_hat,
                    latent,
                    target_ages=target_ages,
                    no_aging=no_aging,
                    data_type="real")

                # cycle the generated images by reversing the aging direction
                y_hat_inverse = self.__set_target_to_source(y_hat)
                y_hat_inverse = torch.stack(y_hat_inverse)
                y_recovered, latent_cycle = self.perform_forward_pass(
                    y_hat_inverse)
                reverse_target_ages = y_hat_inverse[:, -1, 0, 0]
                loss, cycle_loss_dict, cycle_id_logs = self.calc_loss(
                    x,
                    y,
                    y_recovered,
                    latent_cycle,
                    target_ages=reverse_target_ages,
                    no_aging=no_aging,
                    data_type="cycle")

                # combine the logs and losses of both forward passes
                for idx, cycle_log in enumerate(cycle_id_logs):
                    id_logs[idx].update(cycle_log)
                cur_loss_dict.update(cycle_loss_dict)
                cur_loss_dict["loss"] = (cur_loss_dict["loss_real"] +
                                         cur_loss_dict["loss_cycle"])

            agg_loss_dict.append(cur_loss_dict)

            # Logging related
            self.parse_and_log_images(id_logs,
                                      x,
                                      y,
                                      y_hat,
                                      y_recovered,
                                      title='images/test/faces',
                                      subscript='{:04d}'.format(batch_idx))

            # On the first step, only run a quick sanity check on a few batches
            if self.global_step == 0 and batch_idx >= 4:
                self.net.train()
                return None  # Do not log; the first batches are not representative

        loss_dict = train_utils.aggregate_loss_dict(agg_loss_dict)
        self.log_metrics(loss_dict, prefix='test')
        self.print_metrics(loss_dict, prefix='test')

        self.net.train()
        return loss_dict
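
Example #5 reads the target age from the last channel of the network input (x_input[:, -1, 0, 0]), which suggests the age is appended to each image as an extra constant-valued channel. A small sketch of that encoding, assuming ages are already normalized scalars; append_age_channel is a hypothetical helper, not part of the code above:

import torch

def append_age_channel(images, ages):
    # images: (N, 3, H, W); ages: (N,) normalized target ages
    n, _, h, w = images.shape
    age_channel = ages.view(n, 1, 1, 1).expand(n, 1, h, w)
    return torch.cat([images, age_channel], dim=1)  # (N, 4, H, W)

images = torch.randn(4, 3, 256, 256)
ages = torch.tensor([0.2, 0.5, 0.7, 0.9])
x_input = append_age_channel(images, ages)

# Reading the last channel at any pixel recovers the target age, as in the example above.
print(x_input[:, -1, 0, 0])  # tensor([0.2000, 0.5000, 0.7000, 0.9000])
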