예제 #1
0
    def _run_batch(self, data, eval=False, ds=None):
        """Run one mini-batch through the vessel segmentation network.

        The prediction ``self.net(images)`` is scored against the batch masks
        with binary cross-entropy. In training mode (``eval=False``) gradients
        are back-propagated and ``self.optimizer`` is stepped. In evaluation
        mode only the forward pass and the loss are computed. Per-iteration
        statistics are appended to ``self.epoch_stats``. On print-out
        iterations the 'PR' metric is computed and recent stats are printed,
        and if ``self.args.show`` is set the batch is visualized.

        Args:
            data: one item produced by the data loader; wrapped in ``Batch``.
            eval: if True, run without gradients and without updating weights.
            ds: dataset forwarded to the visualization helper.
        """
        time_dataloading = time.time() - self.iter_starttime
        time_proc_start = time.time()
        iter_stats = {'time_dataloading': time_dataloading}

        batch = Batch(data, eval=eval, gpu=self.args.gpu)

        targets = batch.masks.float()
        images = batch.images

        self.net.zero_grad()

        with torch.set_grad_enabled(not eval):
            X_vessels = self.net(images)

        loss = F.binary_cross_entropy(X_vessels, targets)

        # Only update the network while training. In eval mode the forward
        # pass ran with gradients disabled, so `loss` has no grad_fn and
        # backward() would raise. Stepping the optimizer would also change the
        # weights during evaluation.
        if not eval:
            loss.backward()
            self.optimizer.step()

        # statistics
        iter_stats.update({
            'loss': loss.item(),
            'epoch': self.epoch,
            'timestamp': time.time(),
            'iter_time': time.time() - self.iter_starttime,
            'time_processing': time.time() - time_proc_start,
            'iter': self.iter_in_epoch,
            'total_iter': self.total_iter,
            'batch_size': len(batch)
        })
        self.iter_starttime = time.time()
        self.epoch_stats.append(iter_stats)

        # print stats every N mini-batches
        if self._is_printout_iter(eval):
            nimgs = 5
            avg_PR = eval_vessels.calculate_metrics(X_vessels, targets)['PR']
            PRs = get_image_PRs(X_vessels[:nimgs], targets[:nimgs])
            # iter_stats is the same dict object already appended to
            # epoch_stats, so this update is visible in the printout below.
            iter_stats.update({'avg_PR': avg_PR})

            self._print_iter_stats(
                self.epoch_stats[-self._print_interval(eval):])

            # Batch visualization
            if self.args.show:
                retina_vis.visualize_vessels(images,
                                             images,
                                             vessel_hm=targets,
                                             scores=PRs,
                                             pred_vessel_hm=X_vessels,
                                             ds=ds,
                                             wait=self.args.wait,
                                             f=1.0,
                                             overlay_heatmaps_recon=True,
                                             nimgs=1,
                                             horizontal=True)
예제 #2
0
    # Demo: load a filtered subset of the WFLW dataset (train=False split) and
    # display each batch of 10 images with its ground-truth landmarks drawn.
    from csl_common.utils.common import init_random
    from csl_common.utils.ds_utils import build_transform
    from csl_common.vis import vis
    import config

    # presumably seeds the random number generators so runs are repeatable
    init_random(3)

    path = config.get_dataset_paths('wflw')[0]
    # NOTE(review): the dataset is created with deterministic=True, but the
    # transform is built with deterministic=False — confirm this is intended.
    ds = WFLW(root=path,
              train=False,
              deterministic=True,
              use_cache=False,
              daug=0,
              image_size=256,
              transform=build_transform(deterministic=False, daug=0))
    # Restrict the dataset by WFLW attribute labels; presumably keeps samples
    # with the pose and make-up flags set and without occlusion — verify
    # against WFLW.filter_labels.
    ds.filter_labels({'pose': 1, 'occlusion': 0, 'make-up': 1})
    # Unshuffled, single-process loader so batches come in dataset order.
    dl = td.DataLoader(ds, batch_size=10, shuffle=False, num_workers=0)
    print(ds)

    for data in dl:
        batch = Batch(data, gpu=False)
        # Convert the batch tensors into displayable images (denormalized).
        images = vis.to_disp_images(batch.images, denorm=True)
        # Alternative: convert 98-point WFLW landmarks to the 68-point layout.
        # lms = lmutils.convert_landmarks(to_numpy(batch.landmarks), lmutils.LM98_TO_LM68)
        lms = batch.landmarks
        images = vis.add_landmarks_to_images(images,
                                             lms,
                                             draw_wireframe=False,
                                             color=(0, 255, 0),
                                             radius=3)
        # Show the whole batch as a single grid with 10 columns.
        vis.vis_square(images, nCols=10, fx=1., fy=1., normalize=False)
예제 #3
0
def get_fixed_samples(ds, num):
    """Return the first ``num`` samples of ``ds`` wrapped in one ``Batch``.

    Samples are read in dataset order (no shuffling) in the main process,
    via a single DataLoader batch of size ``num``.
    """
    loader = td.DataLoader(
        ds,
        batch_size=num,
        shuffle=False,
        num_workers=0,
    )
    return Batch(next(iter(loader)), n=num)
예제 #4
0
    def evaluate(self):
        """Evaluate ``self.net`` on ``self.fixed_val_data`` without training it.

        Every batch of the fixed validation set is run through the network in
        eval mode with gradients disabled, and its binary cross-entropy loss
        against the batch masks is recorded. On print-out iterations the 'PR'
        metric is added and recent stats are printed (with optional
        visualization when ``self.args.show`` is set). An epoch summary is
        printed at the end.
        """
        log.info("")
        log.info("Evaluating '{}'...".format(self.session_name))

        self.iters_per_epoch = len(self.fixed_val_data)
        self.iter_in_epoch = 0
        self.iter_starttime = time.time()
        epoch_starttime = time.time()
        epoch_stats = []

        self.net.eval()

        for data in self.fixed_val_data:
            # Take the timing measurements before building the Batch, in the
            # same order as `_run_batch`, so batch construction counts as
            # processing time rather than data loading.
            time_dataloading = time.time() - self.iter_starttime
            time_proc_start = time.time()

            batch = Batch(data, eval=True)
            targets = batch.masks.float()

            with torch.no_grad():
                X_vessels = self.net(batch.images)
                loss = F.binary_cross_entropy(X_vessels, targets)

            iter_stats = {
                'loss': loss.item(),
                'epoch': self.epoch,
                'timestamp': time.time(),
                'time_dataloading': time_dataloading,
                'time_processing': time.time() - time_proc_start,
                'iter_time': time.time() - self.iter_starttime,
                'iter': self.iter_in_epoch,
                'total_iter': self.total_iter,
                'batch_size': len(batch)
            }
            epoch_stats.append(iter_stats)

            if self._is_printout_iter(eval=True):
                nimgs = 1
                avg_PR = eval_vessels.calculate_metrics(X_vessels,
                                                        targets)['PR']
                PRs = get_image_PRs(X_vessels[:nimgs], targets[:nimgs])
                # iter_stats is already in epoch_stats; mutate it in place.
                iter_stats.update({'avg_PR': avg_PR})
                self._print_iter_stats(
                    epoch_stats[-self._print_interval(True):])

                # Batch visualization
                if self.args.show:
                    retina_vis.visualize_vessels(batch.images,
                                                 batch.images,
                                                 vessel_hm=targets,
                                                 scores=PRs,
                                                 pred_vessel_hm=X_vessels,
                                                 wait=self.args.wait,
                                                 f=1.0,
                                                 overlay_heatmaps_recon=True,
                                                 nimgs=nimgs,
                                                 horizontal=True)

            self.iter_starttime = time.time()
            self.iter_in_epoch += 1

        # print average loss and accuracy over epoch
        self._print_epoch_summary(epoch_stats, epoch_starttime)

        # update scheduler
        # DataFrame.mean() skips NaN, so rows without 'avg_PR' are ignored.
        means = pd.DataFrame(epoch_stats).mean().to_dict()
        val_loss = means['loss']
        # 'avg_PR' is only recorded on print-out iterations. If none occurred,
        # the column is absent entirely, so fall back to NaN instead of raising
        # KeyError.
        val_PR = means.get('avg_PR', float('nan'))
        # NOTE(review): val_loss / val_PR are not used in the visible code;
        # the scheduler update announced above appears to be missing.
예제 #5
0
    def _run_batch(self, data, eval=False, ds=None):
        """Run one mini-batch through the auto-encoder ``self.saae``.

        ``batch.images`` are encoded with ``self.saae.Q`` and decoded with
        ``self.saae.P``. The reconstruction is scored against the target
        images (``batch.target_images`` when present, else ``batch.images``).
        The loss combines the weighted reconstruction loss with an optional
        structural (SSIM) loss, an optional adversarial loss, and an optional
        z-regularization term on the encoding. Backward and optimizer steps
        happen only in training mode, and only for the parts enabled by
        ``self.args.train_encoder`` / ``self.args.train_decoder``.

        Args:
            data: one item produced by the data loader; wrapped in ``Batch``.
            eval: evaluation mode — no backward pass and no optimizer steps.
            ds: dataset forwarded to ``visualize_batch``.
        """
        time_dataloading = time.time() - self.iter_starttime
        time_proc_start = time.time()
        iter_stats = {'time_dataloading': time_dataloading}

        batch = Batch(data, eval=eval)
        X_target = batch.target_images if batch.target_images is not None else batch.images

        self.saae.zero_grad()
        # Encoder-side loss terms, kept separate so they survive the
        # reassignment of `loss` in the decoding stage. This tensor is created
        # on the batch's device rather than hard-wired to CUDA.
        loss_enc = torch.zeros(1, device=batch.images.device)

        #######################
        # Encoding
        #######################
        with torch.set_grad_enabled(self.args.train_encoder):

            z_sample = self.saae.Q(batch.images)

            ###########################
            # Encoding regularization
            ###########################
            if (not eval or self._is_printout_iter(eval)
                ) and self.args.with_zgan and self.args.train_encoder:
                if WITH_LOSS_ZREG:
                    loss_zreg = torch.abs(z_sample).mean()
                    loss_enc = loss_enc + loss_zreg
                    iter_stats.update({'loss_zreg': loss_zreg.item()})
                encoding = self.update_encoding(z_sample)
                iter_stats.update(encoding)

        iter_stats['z_recon_mean'] = z_sample.mean().item()
        iter_stats['z_recon_std'] = z_sample.std().item()

        #######################
        # Decoding
        #######################

        # Stop gradients from reaching the encoder when it is frozen.
        if not self.args.train_encoder:
            z_sample = z_sample.detach()

        with torch.set_grad_enabled(self.args.train_decoder):

            # reconstruct images
            X_recon = self.saae.P(z_sample)

            #######################
            # Reconstruction loss
            #######################
            loss_recon = aae_training.loss_recon(X_target, X_recon)
            loss = loss_recon * self.args.w_rec
            iter_stats['loss_recon'] = loss_recon.item()

            #######################
            # Structural loss
            #######################
            cs_error_maps = None
            if self.args.with_ssim_loss or eval:
                # error maps are only needed for visualization / evaluation
                store_cs_maps = self._is_printout_iter(eval) or eval
                loss_ssim, cs_error_maps = aae_training.loss_struct(
                    X_target,
                    X_recon,
                    self.ssim,
                    calc_error_maps=store_cs_maps)
                loss_ssim = loss_ssim * self.args.w_ssim
                loss = 0.5 * loss + 0.5 * loss_ssim
                iter_stats['ssim_torch'] = loss_ssim.item()

            #######################
            # Adversarial loss
            #######################
            if self.args.with_gan and self.args.train_decoder:
                gan_stats, loss_G = self.update_gan(
                    X_target,
                    X_recon,
                    z_sample,
                    train=not eval,
                    with_gen_loss=self.args.with_gen_loss)
                loss = loss + loss_G
                iter_stats.update(gan_stats)

            # Include the encoder regularization, which the original assignment
            # `loss = loss_recon * w_rec` used to discard. It is zero when
            # z-regularization is disabled.
            loss = loss + loss_enc
            iter_stats['loss'] = loss.item()

            # No backward pass in eval mode: no optimizer step follows, and the
            # gradients would only be cleared by the next zero_grad().
            if self.args.train_decoder and not eval:
                loss.backward()

            # Update auto-encoder
            if not eval:
                if self.args.train_encoder:
                    self.optimizer_E.step()
                if self.args.train_decoder:
                    self.optimizer_G.step()

            if eval or self._is_printout_iter(eval):
                iter_stats['ssim'] = aae_training.calc_ssim(X_target, X_recon)

        # statistics
        iter_stats.update({
            'epoch': self.epoch,
            'timestamp': time.time(),
            'iter_time': time.time() - self.iter_starttime,
            'time_processing': time.time() - time_proc_start,
            'iter': self.iter_in_epoch,
            'total_iter': self.total_iter,
            'batch_size': len(batch)
        })
        self.iter_starttime = time.time()
        self.epoch_stats.append(iter_stats)

        # print stats every N mini-batches
        if self._is_printout_iter(eval):
            self._print_iter_stats(
                self.epoch_stats[-self._print_interval(eval):])

            #
            # Batch visualization
            #
            if self.args.show:
                num_sample_images = {
                    128: 8,
                    256: 7,
                    512: 2,
                    1024: 1,
                }
                # Fall back to a single image for unlisted input sizes instead
                # of failing visualization with a KeyError.
                nimgs = num_sample_images.get(self.args.input_size, 1)
                self.visualize_random_images(nimgs, z_real=z_sample)
                self.visualize_interpolations(z_sample, nimgs=2)
                self.visualize_batch(batch,
                                     X_recon,
                                     nimgs=nimgs,
                                     ssim_maps=cs_error_maps,
                                     ds=ds,
                                     wait=self.wait)
예제 #6
0
    def _run_batch(self, data, eval=False, ds=None):
        """Run one mini-batch through the landmark head of ``self.saae``.

        The input images (``batch.target_images`` when present, else
        ``batch.images``) are encoded with ``self.saae.Q`` and decoded with
        ``self.saae.P``. The landmark head ``self.saae.LMH`` then predicts
        heatmaps. In training mode the MSE loss against ``batch.lm_heatmaps``
        (when available) updates the landmark head, plus the encoder if
        ``self.args.train_encoder`` is set. On eval / print-out iterations,
        landmarks are decoded from the heatmaps and NMEs are recorded. Recent
        stats are printed and the batch is visualized every print-out
        iteration.

        Args:
            data: one item produced by the data loader; wrapped in ``Batch``.
            eval: evaluation mode — landmark head frozen, no optimizer steps.
            ds: dataset forwarded to ``lmvis.visualize_batch``.
        """
        time_dataloading = time.time() - self.iter_starttime
        time_proc_start = time.time()
        iter_stats = {'time_dataloading': time_dataloading}

        batch = Batch(data, eval=eval)

        self.saae.zero_grad()
        self.saae.eval()

        input_images = batch.target_images if batch.target_images is not None else batch.images

        with torch.set_grad_enabled(self.args.train_encoder):
            z_sample = self.saae.Q(input_images)

        iter_stats.update({'z_recon_mean': z_sample.mean().item()})

        #######################
        # Reconstruction phase
        #######################
        with torch.set_grad_enabled(self.args.train_encoder and not eval):
            X_recon = self.saae.P(z_sample)

        # calculate reconstruction error for debugging and reporting
        # NOTE(review): the reconstruction comes from `input_images` but is
        # compared with `batch.images`; these differ when target images exist.
        with torch.no_grad():
            # Store a plain float, like every other stat (and like the
            # 'loss_recon' entry in the auto-encoder trainer).
            iter_stats['loss_recon'] = aae_training.loss_recon(
                batch.images, X_recon).item()

        #######################
        # Landmark predictions
        #######################
        train_lmhead = not eval
        loss_lms = None
        lm_preds_max = None
        with torch.set_grad_enabled(train_lmhead):
            self.saae.LMH.train(train_lmhead)
            # The head receives the decoder module rather than X_recon —
            # presumably it reads activations from the decoder pass above;
            # verify against the LMH implementation.
            X_lm_hm = self.saae.LMH(self.saae.P)
            if batch.lm_heatmaps is not None:
                loss_lms = F.mse_loss(batch.lm_heatmaps, X_lm_hm) * 100 * 3
                iter_stats.update({'loss_lms': loss_lms.item()})

            if eval or self._is_printout_iter(eval):
                # expensive, so only calculated every N iterations
                X_lm_hm = lmutils.smooth_heatmaps(X_lm_hm)
                lm_preds_max = self.saae.heatmaps_to_landmarks(X_lm_hm)

                lm_gt = to_numpy(batch.landmarks)
                nmes = lmutils.calc_landmark_nme(
                    lm_gt,
                    lm_preds_max,
                    ocular_norm=self.args.ocular_norm,
                    image_size=self.args.input_size)
                iter_stats.update({'nmes': nmes})

        # Without ground-truth heatmaps there is no loss to optimize;
        # previously this raised NameError on `loss_lms`.
        if train_lmhead and loss_lms is not None:
            loss_lms.backward()
            self.optimizer_lm_head.step()
            if self.args.train_encoder:
                self.optimizer_E.step()

        # statistics
        iter_stats.update({
            'epoch': self.epoch,
            'timestamp': time.time(),
            'iter_time': time.time() - self.iter_starttime,
            'time_processing': time.time() - time_proc_start,
            'iter': self.iter_in_epoch,
            'total_iter': self.total_iter,
            'batch_size': len(batch)
        })
        self.iter_starttime = time.time()

        self.epoch_stats.append(iter_stats)

        # print stats every N mini-batches
        if self._is_printout_iter(eval):
            self._print_iter_stats(
                self.epoch_stats[-self._print_interval(eval):])

            lmvis.visualize_batch(
                batch.images,
                batch.landmarks,
                X_recon,
                X_lm_hm,
                lm_preds_max,
                lm_heatmaps=batch.lm_heatmaps,
                target_images=batch.target_images,
                ds=ds,
                ocular_norm=self.args.ocular_norm,
                clean=False,
                overlay_heatmaps_input=False,
                overlay_heatmaps_recon=False,
                landmarks_only_outline=self.landmarks_only_outline,
                landmarks_no_outline=self.landmarks_no_outline,
                f=1.0,
                wait=self.wait)