Example #1
def _create_weighted_cross_entropy_loss(affectnet_dataset):
    # weight each class by the inverse of its sample count
    _weights = 1.0 / affectnet_dataset.get_class_sizes()
    # a weight above 1.0 means class 7 is (nearly) unpopulated, so drop it
    if _weights[7] > 1.0:
        _weights[7] = 0
    _weights = _weights.astype(np.float32)
    _weights /= np.sum(_weights)
    log.info("AffectNet weights: {}".format(_weights))
    return torch.nn.CrossEntropyLoss(
        weight=torch.from_numpy(_weights).to(device))
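A minimal sketch of how this factory could be exercised in isolation. The dataset stub, device, logger, and class counts below are stand-ins invented for the example, not part of the original code:

import logging

import numpy as np
import torch

log = logging.getLogger(__name__)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class FakeAffectNet:
    """Hypothetical stand-in exposing the one method the factory needs."""
    def get_class_sizes(self):
        # illustrative per-class sample counts (8 emotion classes)
        return np.array([75000.0, 134000.0, 25000.0, 14000.0,
                         6400.0, 3800.0, 25000.0, 3800.0])

criterion = _create_weighted_cross_entropy_loss(FakeAffectNet())
logits = torch.randn(4, 8, device=device)
labels = torch.randint(0, 8, (4,), device=device)
print(criterion(logits, labels))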
    def _print_epoch_summary(self, epoch_stats, epoch_starttime):
        means = pd.DataFrame(epoch_stats).mean().to_dict()
        try:
            ssim_scores = np.concatenate([
                stats['ssim'] for stats in self.epoch_stats if 'ssim' in stats
            ])
        except ValueError:
            # no iteration recorded an 'ssim' entry; fall back to a zero score
            ssim_scores = np.zeros(1)
        duration = int(time.time() - epoch_starttime)

        log.info("{}".format('-' * 140))
        str_stats = [
            'Train:         '
            'l={avg_loss:.3f} '
            'l_rec={avg_loss_recon:.3f} '
            'l_ssim={avg_ssim_torch:.3f}({avg_ssim:.3f}) '
            'l_lmrec={avg_lms_recon:.3f} '
            'l_lmssim={avg_lms_ssim:.3f} '
            # 'l_lmcs={avg_lms_cs:.3f} '
            # 'l_lmncc={avg_lms_ncc:.3f} '
            'z_mu={avg_z_recon_mean:.3f} '
        ]
        str_stats[0] += [
            'l_D_z={avg_loss_D_z:.4f} '
            'l_E={avg_loss_E:.4f}  '
            'l_D={avg_loss_D:.4f} '
            'l_G={avg_loss_G:.4f} '
            '\tT: {epoch_time} ({total_time})'
        ][0]
        log.info(str_stats[0].format(
            iters_per_epoch=self.iters_per_epoch,
            avg_loss=means.get('loss', -1),
            avg_loss_recon=means.get('loss_recon', -1),
            avg_lms_recon=means.get('landmark_recon_errors', -1),
            avg_lms_ssim=means.get('landmark_ssim_scores', -1),
            avg_lms_ncc=means.get('landmark_ncc_errors', -1),
            avg_lms_cs=means.get('landmark_cs_errors', -1),
            avg_ssim=ssim_scores.mean(),
            avg_ssim_torch=means.get('ssim_torch', -1),
            avg_loss_E=means.get('loss_E', -1),
            avg_loss_D_z=means.get('loss_D_z', -1),
            avg_loss_D=means.get('loss_D', -1),
            avg_loss_G=means.get('loss_G', -1),
            avg_loss_D_real=means.get('err_real', -1),
            avg_loss_D_fake=means.get('err_fake', -1),
            avg_z_recon_mean=means.get('z_recon_mean', -1),
            t=means['iter_time'],
            t_data=means['time_dataloading'],
            t_proc=means['time_processing'],
            total_iter=self.total_iter + 1,
            total_time=str(datetime.timedelta(seconds=self._training_time())),
            epoch_time=str(datetime.timedelta(seconds=duration))))
    def _print_iter_stats(self, stats):
        means = pd.DataFrame(stats).mean().to_dict()
        current = stats[-1]
        ssim_scores = current['ssim']  # per-image SSIM scores of the latest batch

        str_stats = [
            '[{ep}][{i}/{iters_per_epoch}] '
            'l={avg_loss:.3f}  '
            'l_rec={avg_loss_recon:.3f} '
            'l_ssim={avg_ssim_torch:.3f}({avg_ssim:.2f}) '
            'l_lmrec={avg_lms_recon:.3f} '
            'l_lmssim={avg_lms_ssim:.2f} '
            # 'l_lmcs={avg_lms_cs:.2f} '
            # 'l_lmncc={avg_lms_ncc:.2f} '
            # 'l_act={avg_loss_activations:.3f} '
            'z_mu={avg_z_recon_mean: .3f} '
        ]
        str_stats[0] += [
            'l_D_z={avg_loss_D_z:.3f} '
            'l_E={avg_loss_E:.3f} '
            'l_D={avg_loss_D:.3f}({avg_loss_D_rec:.3f}/{avg_loss_D_gen:.3f}) '
            'l_G={avg_loss_G:.3f}({avg_loss_G_rec:.3f}/{avg_loss_G_gen:.3f}) '
            '{t_data:.2f}/{t_proc:.2f}/{t:.2f}s ({total_iter:06d} {epoch_time})'
        ][0]
        log.info(str_stats[0].format(
            ep=current['epoch'] + 1,
            i=current['iter'] + 1,
            iters_per_epoch=self.iters_per_epoch,
            avg_loss=means.get('loss', -1),
            avg_loss_recon=means.get('loss_recon', -1),
            avg_lms_recon=means.get('landmark_recon_errors', -1),
            avg_lms_ssim=means.get('landmark_ssim_scores', -1),
            avg_lms_ncc=means.get('landmark_ncc_errors', -1),
            avg_lms_cs=means.get('landmark_cs_errors', -1),
            avg_ssim=ssim_scores.mean(),
            avg_ssim_torch=means.get('ssim_torch', -1),
            avg_loss_activations=means.get('loss_activations', -1),
            avg_loss_F=means.get('loss_F', -1),
            avg_loss_E=means.get('loss_E', -1),
            avg_loss_D_z=means.get('loss_D_z', -1),
            avg_loss_D=means.get('loss_D', -1),
            avg_loss_D_rec=means.get('loss_D_rec', -1),
            avg_loss_D_gen=means.get('loss_D_gen', -1),
            avg_loss_G=means.get('loss_G', -1),
            avg_loss_G_rec=means.get('loss_G_rec', -1),
            avg_loss_G_gen=means.get('loss_G_gen', -1),
            avg_loss_D_real=means.get('err_real', -1),
            avg_loss_D_fake=means.get('err_fake', -1),
            avg_z_recon_mean=means.get('z_recon_mean', -1),
            t=means['iter_time'],
            t_data=means['time_dataloading'],
            t_proc=means['time_processing'],
            total_iter=self.total_iter + 1,
            epoch_time=str(datetime.timedelta(seconds=self._training_time()))))
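Both printers assume self.epoch_stats is a list of per-iteration dicts whose keys match the format placeholders. A sketch of one such record, with made-up values, appended inside the training loop (the key set is inferred from the format calls above, not an official schema):

iter_stats = {
    'epoch': 0, 'iter': 41,
    'loss': 0.92, 'loss_recon': 0.31,
    'ssim': np.random.rand(16),   # per-image SSIM scores for the batch
    'ssim_torch': 0.78, 'z_recon_mean': 0.01,
    'loss_E': 0.50, 'loss_D_z': 0.40, 'loss_D': 0.60, 'loss_G': 0.70,
    'iter_time': 0.21, 'time_dataloading': 0.05, 'time_processing': 0.16,
}
self.epoch_stats.append(iter_stats)
self._print_iter_stats(self.epoch_stats[-50:])   # stats over the last 50 iterations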
Example #4
def extract_features(split, st=None, nd=None):
    """ Extract facial features (landmarks, pose,...) from images """
    import glob
    assert(split in ['train', 'test'])
    person_dirs = sorted(glob.glob(os.path.join(VGGFACE2_ROOT, split, 'imgs', '*')))[st:nd]
    # print(os.path.join(cfg.VGGFACE2_ROOT, split, 'imgs', '*'))
    for cnt, img_dir in enumerate(person_dirs):
        folder_name = os.path.split(img_dir)[1]
        out_dir = os.path.join(VGGFACE2_ROOT_LOCAL, split, 'features', folder_name)
        log.info("{}/{}".format(cnt, len(person_dirs)))
        cropping.run_open_face(img_dir, out_dir, is_sequence=False)
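Because st and nd slice the sorted list of person directories, extraction is easy to shard across independent runs. A hypothetical driver (the shard size is illustrative; VGGFACE2_ROOT is the same constant the function uses):

import glob
import os

person_dirs = glob.glob(os.path.join(VGGFACE2_ROOT, 'train', 'imgs', '*'))
SHARD = 500
for st in range(0, len(person_dirs), SHARD):
    extract_features('train', st=st, nd=st + SHARD)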
    def eval_epoch(self):
        log.info("")
        log.info("Evaluating '{}'...".format(self.session_name))
        # log.info("")

        epoch_starttime = time.time()
        self.epoch_stats = []
        self.saae.eval()

        self._run_epoch(self.datasets[VAL], eval=True)
        # print average loss and accuracy over epoch
        self._print_epoch_summary(self.epoch_stats, epoch_starttime)
        return self.epoch_stats
Example #6
    def _print_epoch_summary(self, epoch_stats, epoch_starttime):
        means = pd.DataFrame(epoch_stats).mean().to_dict()
        duration = int(time.time() - epoch_starttime)
        log.info("{}".format('-' * 100))
        str_stats = '          loss={avg_loss:.4f} PR={avg_PR:.3f} \tT: {time_epoch}'
        log.info(
            str_stats.format(
                iters_per_epoch=self.iters_per_epoch,
                avg_loss=means.get('loss', -1),
                avg_PR=means.get('avg_PR', -1),
                t=means['iter_time'],
                t_data=means['time_dataloading'],
                t_proc=means['time_processing'],
                total_iter=self.total_iter + 1,
                total_time=str(
                    datetime.timedelta(seconds=self._training_time())),
                time_epoch=str(datetime.timedelta(seconds=duration))))
Example #7
    def _print_iter_stats(self, stats):
        means = pd.DataFrame(stats).mean().to_dict()
        current = stats[-1]
        str_stats = ('[{ep}][{i}/{iters_per_epoch}] loss={avg_loss:.4f} '
                     'PR={avg_PR:.3f} {t_data:.2f}/{t_proc:.3f}/{t:.2f}s '
                     '({total_iter:06d} {total_time})')
        log.info(
            str_stats.format(
                ep=current['epoch'] + 1,
                i=current['iter'] + 1,
                iters_per_epoch=self.iters_per_epoch,
                avg_loss=means.get('loss', -1),
                avg_PR=means.get('avg_PR', -1),
                t=means['iter_time'],
                t_data=means['time_dataloading'],
                t_proc=means['time_processing'],
                total_iter=self.total_iter + 1,
                total_time=str(
                    datetime.timedelta(seconds=self._training_time()))))
Example #8
    def _load_snapshot(self, snapshot_name, data_dir=None):
        if data_dir is None:
            data_dir = self.snapshot_dir

        model_snap_dir = os.path.join(data_dir, snapshot_name)
        try:
            nn.read_model(model_snap_dir, 'saae', self.net)
        except KeyError as e:
            # log and continue if the snapshot lacks expected entries
            log.warning(e)

        meta = nn.read_meta(model_snap_dir)
        self.epoch = meta['epoch']
        self.total_iter = meta['total_iter']
        self.total_training_time_previous = meta.get('total_time', 0)
        self.total_images = meta.get('total_images', 0)
        self.best_score = meta['best_score']
        self.net.total_iter = self.total_iter
        str_training_time = str(datetime.timedelta(seconds=self.total_training_time()))
        log.info("Model {} trained for {} iterations ({}).".format(snapshot_name, self.total_iter, str_training_time))
    def _print_iter_stats(self, stats):
        means = pd.DataFrame(stats).mean().to_dict()
        current = stats[-1]
        # fallback must be 2-D because nmes is column-indexed below
        nmes = current.get('nmes', np.zeros((1, 100)))

        str_stats = [
            '[{ep}][{i}/{iters_per_epoch}] '
            'l_rec={avg_loss_recon:.3f} '
            # 'ssim={avg_ssim:.3f} '
            # 'ssim_torch={avg_ssim_torch:.3f} '
            # 'z_mu={avg_z_recon_mean: .3f} '
            'l_lms={avg_loss_lms:.4f} '
            'err_lms={avg_err_lms:.2f}/{avg_err_lms_outline:.2f}/{avg_err_lms_all:.2f} '
            '{t_data:.2f}/{t_proc:.2f}/{t:.2f}s ({total_iter:06d} {total_time})'
        ][0]
        log.info(
            str_stats.format(
                ep=current['epoch'] + 1,
                i=current['iter'] + 1,
                iters_per_epoch=self.iters_per_epoch,
                avg_loss=means.get('loss', -1),
                avg_loss_recon=means.get('loss_recon', -1),
                avg_ssim=1.0 - means.get('ssim', -1),
                avg_ssim_torch=means.get('ssim_torch', -1),
                avg_loss_activations=means.get('loss_activations', -1),
                avg_loss_lms=means.get('loss_lms', -1),
                avg_z_l1=means.get('z_l1', -1),
                avg_z_recon_mean=means.get('z_recon_mean', -1),
                t=means['iter_time'],
                t_data=means['time_dataloading'],
                t_proc=means['time_processing'],
                avg_err_lms=nmes[:, self.landmarks_no_outline].mean(),
                avg_err_lms_outline=nmes[:, self.landmarks_only_outline].mean(),
                avg_err_lms_all=nmes[:, self.all_landmarks].mean(),
                total_iter=self.total_iter + 1,
                total_time=str(
                    datetime.timedelta(seconds=self._training_time()))))
Example #10
    def _print_epoch_summary(self, epoch_stats, epoch_starttime, eval=False):
        means = pd.DataFrame(epoch_stats).mean().to_dict()

        try:
            nmes = np.concatenate(
                [s['nmes'] for s in self.epoch_stats if 'nmes' in s])
        except ValueError:
            # no iteration produced NMEs; use a harmless 2-D placeholder
            nmes = np.zeros((1, 100))

        duration = int(time.time() - epoch_starttime)
        log.info("{}".format('-' * 100))
        str_stats = [
            '           '
            'l_rec={avg_loss_recon:.3f} '
            # 'ssim={avg_ssim:.3f} '
            # 'ssim_torch={avg_ssim_torch:.3f} '
            # 'z_mu={avg_z_recon_mean:.3f} '
            'l_lms={avg_loss_lms:.4f} '
            'err_lms={avg_err_lms:.2f}/{avg_err_lms_outline:.2f}/{avg_err_lms_all:.2f} '
            '\tT: {time_epoch}'
        ][0]
        log.info(
            str_stats.format(
                iters_per_epoch=self.iters_per_epoch,
                avg_loss=means.get('loss', -1),
                avg_loss_recon=means.get('loss_recon', -1),
                avg_ssim=1.0 - means.get('ssim', -1),
                avg_ssim_torch=means.get('ssim_torch', -1),
                avg_loss_lms=means.get('loss_lms', -1),
                avg_loss_lms_cnn=means.get('loss_lms_cnn', -1),
                avg_err_lms=nmes[:, self.landmarks_no_outline].mean(),
                avg_err_lms_outline=nmes[:, self.landmarks_only_outline].mean(),
                avg_err_lms_all=nmes[:, self.all_landmarks].mean(),
                avg_z_recon_mean=means.get('z_recon_mean', -1),
                t=means['iter_time'],
                t_data=means['time_dataloading'],
                t_proc=means['time_processing'],
                total_iter=self.total_iter + 1,
                total_time=str(
                    datetime.timedelta(seconds=self._training_time())),
                time_epoch=str(datetime.timedelta(seconds=duration))))
        try:
            recon_errors = np.concatenate(
                [stats['l1_recon_errors'] for stats in self.epoch_stats])
            rmse = np.sqrt(np.mean(recon_errors**2))
            log.info("RMSE: {} ".format(rmse))
        except KeyError:
            # print("no l1_recon_error")
            pass

        if self.args.eval and nmes is not None:
            # benchmark_mode = hasattr(self.args, 'benchmark')
            # self.print_eval_metrics(nmes, show=benchmark_mode)
            self.print_eval_metrics(nmes, show=False)
Example #11
    def __init__(self, datasets, args, snapshot_dir=cfg.SNAPSHOT_DIR):

        self.args = args
        self.session_name = args.sessionname
        self.datasets = datasets
        self.net = self._get_network(pretrained=False)

        log.info("Learning rate: {}".format(self.args.lr))

        self.snapshot_dir = snapshot_dir
        self.total_iter = 0
        self.total_images = 0
        self.iter_in_epoch = 0
        self.epoch = 0
        self.best_score = 999
        self.epoch_stats = []

        if ENCODING_DISTRIBUTION == 'normal':
            self.enc_rand = torch.randn
            self.enc_rand_like = torch.randn_like
        elif ENCODING_DISTRIBUTION == 'uniform':
            self.enc_rand = torch.rand
            self.enc_rand_like = torch.rand_like
        else:
            raise ValueError(
                "Unknown ENCODING_DISTRIBUTION: {}".format(ENCODING_DISTRIBUTION))

        self.total_training_time_previous = 0
        self.time_start_training = time.time()

        snapshot = args.resume
        if snapshot is not None:
            log.info("Resuming session {} from snapshot {}...".format(self.session_name, snapshot))
            self._load_snapshot(snapshot)

        # self.net = self.net.cuda()

        log.info("Total model params: {:,}".format(count_parameters(self.net)))

        n_fixed_images = 10
        self.fixed_batch = {}
        for phase in datasets.keys():
            self.fixed_batch[phase] = get_fixed_samples(datasets[phase], n_fixed_images, gpu=self.args.gpu)
    def print_eval_metrics(self, nmes, show=False):
        def ced_curve(_nmes):
            y = []
            x = np.linspace(0, 10, 50)
            for th in x:
                recall = 1.0 - lmutils.calc_landmark_failure_rate(_nmes, th)
                # fold the 1/len(x) normalization into y so the AUC is a plain sum
                recall *= 1 / len(x)
                y.append(recall)
            return x, y

        def auc(recalls):
            return np.sum(recalls)

        # for err_scale in np.linspace(0.1, 1, 10):
        for err_scale in [1.0]:
            # print('\nerr_scale', err_scale)
            # print(np.clip(lm_errs_max_all, a_min=0, a_max=10).mean())

            fr = lmutils.calc_landmark_failure_rate(nmes * err_scale)
            X, Y = ced_curve(nmes)

            log.info('NME:   {:>6.3f}'.format(nmes.mean() * err_scale))
            log.info('FR@10: {:>6.3f} ({})'.format(
                fr * 100, np.sum(nmes.mean(axis=1) > 10)))
            log.info('AUC:   {:>6.4f}'.format(auc(Y)))
            # log.info('NME:   {nme:>6.3f}, FR@10: {fr:>6.3f} ({fc}), AUC:   {auc:>6.4f}'.format(
            #     nme=nmes.mean()*err_scale,
            #     fr=fr*100,
            #     fc=np.sum(nmes.mean(axis=1) > 10),
            #     auc=auc(Y)))

            if show:
                import matplotlib.pyplot as plt
                fig, axes = plt.subplots(1, 2)
                axes[0].plot(X, Y)
                print(nmes.mean(axis=1).shape)
                print(nmes.mean(axis=1).max())
                axes[1].hist(nmes.mean(axis=1), bins=20)
                plt.show()
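A quick way to exercise this report outside a full run is to feed it synthetic per-landmark NMEs. The shape (200 images x 68 landmarks) and the gamma noise below are purely illustrative, and trainer stands for any instance with this method:

rng = np.random.default_rng(0)
fake_nmes = rng.gamma(shape=2.0, scale=1.5, size=(200, 68))
trainer.print_eval_metrics(fake_nmes, show=False)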
Example #13
                        help='how to normalize landmark errors', choices=['pupil', 'outer', 'none'])

    args = parser.parse_args()

    if args.resume is None:
        raise ValueError("Please specify the model to be evaluated: '-r MODELNAME'")

    args.dataset_train = args.dataset
    args.dataset_val = args.dataset

    args.eval = True
    args.batchsize_eval = 10
    args.wait = 0
    args.workers = 0
    args.print_freq_eval = 1
    args.epochs = 1

    if args.benchmark:
        log.info('Switching to benchmark mode...')
        args.batchsize_eval = 50
        args.wait = 10
        args.workers = 4
        args.print_freq_eval = 20
        args.epochs = 1
        args.val_count = None

    if args.sessionname is None:
        args.sessionname = args.resume

    run(args)
Example #14
    def __init__(self,
                 datasets,
                 args,
                 session_name='debug',
                 snapshot_dir=cfg.SNAPSHOT_DIR,
                 snapshot_interval=5,
                 workers=6,
                 macro_batch_size=20,
                 wait=10):

        self.args = args
        self.session_name = session_name
        self.datasets = datasets
        self.macro_batch_size = macro_batch_size
        self.workers = workers
        self.ssim = pytorch_msssim.SSIM(window_size=31)
        self.wait = wait
        self.saae = self._get_network(pretrained=False)

        print("Learning rate: {}".format(self.args.lr))

        self.snapshot_dir = snapshot_dir
        self.total_iter = 0
        self.total_images = 0
        self.iter_in_epoch = 0
        self.epoch = 0
        self.best_score = 999
        self.epoch_stats = []

        self.snapshot_interval = snapshot_interval

        if ENCODING_DISTRIBUTION == 'normal':
            self.enc_rand = torch.randn
            self.enc_rand_like = torch.randn_like
        elif ENCODING_DISTRIBUTION == 'uniform':
            self.enc_rand = torch.rand
            self.enc_rand_like = torch.rand_like
        else:
            raise ValueError(
                "Unknown ENCODING_DISTRIBUTION: {}".format(ENCODING_DISTRIBUTION))

        self.total_training_time_previous = 0
        self.time_start_training = time.time()

        snapshot = args.resume
        if snapshot is not None:
            log.info("Resuming session {} from snapshot {}...".format(
                self.session_name, snapshot))
            self._load_snapshot(snapshot)

        # reset discriminator
        if args.reset:
            self.saae.D.apply(weights_init)

        # set up the optimizers
        betas = (self.args.beta1, self.args.beta2)
        Q_params = list(
            filter(lambda p: p.requires_grad, self.saae.Q.parameters()))
        self.optimizer_E = optim.Adam(Q_params, lr=args.lr, betas=betas)
        self.optimizer_G = optim.Adam(self.saae.P.parameters(),
                                      lr=args.lr,
                                      betas=betas)
        self.optimizer_D_z = optim.Adam(self.saae.D_z.parameters(),
                                        lr=args.lr,
                                        betas=betas)
        self.optimizer_D = optim.Adam(self.saae.D.parameters(),
                                      lr=args.lr * 0.5,
                                      betas=betas)

        n_fixed_images = 10
        self.fixed_batch = {}
        for phase in datasets.keys():
            self.fixed_batch[phase] = get_fixed_samples(
                datasets[phase], n_fixed_images)
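The enc_rand / enc_rand_like samplers chosen by the ENCODING_DISTRIBUTION switch above differ only in the prior they draw from. A minimal illustration of the two cases (the tensor shape is arbitrary):

z = torch.empty(4, 99)
z_normal = torch.randn_like(z)    # 'normal': samples from N(0, 1)
z_uniform = torch.rand_like(z)    # 'uniform': samples from U[0, 1)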
    def train(self, num_epochs):

        log.info("")
        log.info("Starting training session '{}'...".format(self.session_name))
        log.info("")

        while num_epochs is None or self.epoch < num_epochs:
            log.info('')
            log.info('=' * 5 +
                     ' Epoch {}/{}'.format(self.epoch + 1, num_epochs))

            self.epoch_stats = []
            epoch_starttime = time.time()
            self.saae.train(True)

            self._run_epoch(self.datasets[TRAIN])

            # save model every few epochs
            if (self.epoch + 1) % self.snapshot_interval == 0:
                log.info("*** saving snapshot *** ")
                self._save_snapshot(is_best=False)

            # print average loss and accuracy over epoch
            self._print_epoch_summary(self.epoch_stats, epoch_starttime)

            if self._is_eval_epoch() and self.args.input_size < 512:
                self.eval_epoch()

            # save visualizations to disk
            if (self.epoch + 1) % 1 == 0:
                self.reconstruct_fixed_samples()

            self.epoch += 1

        time_elapsed = time.time() - self.time_start_training
        log.info('Training completed in {:.0f}m {:.0f}s'.format(
            time_elapsed // 60, time_elapsed % 60))
Example #16
    def train(self, num_epochs):

        log.info("")
        log.info("Starting training session '{}'...".format(self.session_name))
        # log.info("")

        while num_epochs is None or self.epoch < num_epochs:
            log.info('')
            log.info('=' * 5 +
                     ' Epoch {}/{}'.format(self.epoch + 1, num_epochs))

            self.epoch_stats = []
            epoch_starttime = time.time()

            self.net.train()
            self._run_epoch(self.dataloaders[TRAIN])

            # save model every few epochs
            if (self.epoch + 1) % self.args.save_freq == 0:
                log.info("*** saving snapshot *** ")
                self._save_snapshot(is_best=False)

            # print average loss and accuracy over epoch
            self._print_epoch_summary(self.epoch_stats, epoch_starttime)

            if self._is_eval_epoch():
                self.evaluate()

            self.epoch += 1

        time_elapsed = time.time() - self.time_start_training
        log.info('Training completed in {:.0f}m {:.0f}s'.format(
            time_elapsed // 60, time_elapsed % 60))
Example #17
    def evaluate(self):
        log.info("")
        log.info("Evaluating '{}'...".format(self.session_name))
        # log.info("")

        self.iters_per_epoch = len(self.fixed_val_data)
        self.iter_in_epoch = 0
        self.iter_starttime = time.time()
        epoch_starttime = time.time()
        epoch_stats = []

        self.net.eval()

        for data in self.fixed_val_data:
            batch = Batch(data, eval=True)
            targets = batch.masks.float()

            time_proc_start = time.time()
            time_dataloading = time.time() - self.iter_starttime
            with torch.no_grad():
                X_vessels = self.net(batch.images)
            loss = F.binary_cross_entropy(X_vessels, targets)

            iter_stats = {
                'loss': loss.item(),
                'epoch': self.epoch,
                'timestamp': time.time(),
                'time_dataloading': time_dataloading,
                'time_processing': time.time() - time_proc_start,
                'iter_time': time.time() - self.iter_starttime,
                'iter': self.iter_in_epoch,
                'total_iter': self.total_iter,
                'batch_size': len(batch)
            }
            epoch_stats.append(iter_stats)

            if self._is_printout_iter(eval=True):
                nimgs = 1
                avg_PR = eval_vessels.calculate_metrics(X_vessels,
                                                        targets)['PR']
                PRs = get_image_PRs(X_vessels[:nimgs], targets[:nimgs])
                iter_stats.update({'avg_PR': avg_PR})
                self._print_iter_stats(
                    epoch_stats[-self._print_interval(True):])

                #
                # Batch visualization
                #
                if self.args.show:
                    retina_vis.visualize_vessels(batch.images,
                                                 batch.images,
                                                 vessel_hm=targets,
                                                 scores=PRs,
                                                 pred_vessel_hm=X_vessels,
                                                 wait=self.args.wait,
                                                 f=1.0,
                                                 overlay_heatmaps_recon=True,
                                                 nimgs=nimgs,
                                                 horizontal=True)

            self.iter_starttime = time.time()
            self.iter_in_epoch += 1

        # print average loss and accuracy over epoch
        self._print_epoch_summary(epoch_stats, epoch_starttime)

        # update scheduler
        means = pd.DataFrame(epoch_stats).mean().to_dict()
        val_loss = means['loss']
        val_PR = means['avg_PR']
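The snippet breaks off right after the scheduler comment. A plausible continuation, assuming a ReduceLROnPlateau scheduler stored as self.scheduler (the attribute name is a guess; it does not appear in the snippet):

        self.scheduler.step(val_loss)   # hypothetical: reduce LR when val loss plateaus
        log.info("Val loss: {:.4f}  PR: {:.3f}".format(val_loss, val_PR))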