コード例 #1
0
ファイル: spade.py プロジェクト: zzf2014/imaginaire
    def _compute_fid(self):
        r"""Evaluate FID for the generator and, if enabled, its running average.

        The generator is switched to eval mode first. FID is always computed
        for ``self.net_G`` called with ``random_style=True``. When
        ``cfg.trainer.model_average`` is truthy, a second FID is computed for
        ``self.net_G.module.averaged_model`` with the same keyword. Both runs
        share one ``preprocess`` callable, which forwards to
        ``self._start_of_iteration`` with ``current_iteration=0``. Each run
        uses its own ``.npy`` path from ``_get_save_path``. The results are
        printed, and ``self.net_G.float()`` is called before returning.

        Returns:
            The regular FID when model averaging is disabled; otherwise the
            tuple ``(regular_fid, average_fid)``.
        """
        self.net_G.eval()
        # Shared hook applied by compute_fid: reuse the trainer's
        # per-iteration preprocessing, pinned to iteration 0.
        preprocess = functools.partial(self._start_of_iteration,
                                       current_iteration=0)

        regular_generator = functools.partial(self.net_G, random_style=True)
        regular_fid_value = compute_fid(
            self._get_save_path('regular_fid', 'npy'),
            self.val_data_loader,
            regular_generator,
            preprocess=preprocess)
        print('Epoch {:05}, Iteration {:09}, Regular FID {}'.format(
            self.current_epoch, self.current_iteration, regular_fid_value))

        if not self.cfg.trainer.model_average:
            # No averaged model to evaluate: report the single FID.
            self.net_G.float()
            return regular_fid_value

        average_generator = functools.partial(
            self.net_G.module.averaged_model, random_style=True)
        average_fid_value = compute_fid(
            self._get_save_path('average_fid', 'npy'),
            self.val_data_loader,
            average_generator,
            preprocess=preprocess)
        print('Epoch {:05}, Iteration {:09}, FID {}'.format(
            self.current_epoch, self.current_iteration, average_fid_value))
        self.net_G.float()
        return regular_fid_value, average_fid_value
コード例 #2
0
ファイル: funit.py プロジェクト: yejees/ObjectSwap
    def _compute_fid(self):
        r"""Report the mean of per-class FID values (FUNIT protocol).

        One FID is computed for each style class of the validation dataset;
        there are ``dataset.num_style_classes`` of them. Before each run the
        dataset is restricted to that class via ``set_sample_class_idx``.
        Each run uses a per-class ``.npy`` path under ``fid/<class_idx>`` and
        passes the keys ``'images_style'`` and ``'images_trans'`` to
        ``compute_fid``. The averaged generator is evaluated when
        ``cfg.trainer.model_average`` is truthy; otherwise ``net_G`` itself
        is used. The mean over classes is the reported number, as described
        in the FUNIT paper.

        Returns:
            On the master process, ``np.mean`` of the per-class FIDs (also
            printed); on every other process, ``None``.
        """
        self.net_G.eval()
        generator = (self.net_G.module.averaged_model
                     if self.cfg.trainer.model_average else self.net_G)
        dataset = self.val_data_loader.dataset

        per_class_fids = []
        for class_idx in range(dataset.num_style_classes):
            save_path = self._get_save_path(
                os.path.join('fid', str(class_idx)), 'npy')
            # Point the validation sampler at the current style class.
            dataset.set_sample_class_idx(class_idx)
            per_class_fids.append(
                compute_fid(save_path, self.val_data_loader, generator,
                            'images_style', 'images_trans'))

        # Only the master process aggregates and reports.
        if not is_master():
            return None
        mean_fid = np.mean(per_class_fids)
        print('Epoch {:05}, Iteration {:09}, Mean FID {}'.format(
            self.current_epoch, self.current_iteration, mean_fid))
        return mean_fid
コード例 #3
0
 def _compute_fid(self):
     r"""Compute one FID value per domain (A and B).

     The generator is put into eval mode. The averaged model is evaluated
     when ``cfg.trainer.model_average`` is truthy; otherwise ``net_G``
     itself is used. Domain A passes the keys ``'images_a'`` /
     ``'images_ba'`` to ``compute_fid``, and domain B passes ``'images_b'``
     / ``'images_ab'``. Each domain has its own ``.npy`` path from
     ``_get_save_path``. Both values are printed.

     Returns:
         Tuple ``(fid_value_a, fid_value_b)``.
     """
     self.net_G.eval()
     generator = (self.net_G.module.averaged_model
                  if self.cfg.trainer.model_average else self.net_G)
     # Resolve both save paths before running either evaluation.
     path_a, path_b = (self._get_save_path(tag, 'npy')
                       for tag in ('fid_a', 'fid_b'))
     fid_value_a = compute_fid(path_a, self.val_data_loader, generator,
                               'images_a', 'images_ba')
     fid_value_b = compute_fid(path_b, self.val_data_loader, generator,
                               'images_b', 'images_ab')
     print('Epoch {:05}, Iteration {:09}, FID a {}, FID b {}'.format(
         self.current_epoch, self.current_iteration,
         fid_value_a, fid_value_b))
     return fid_value_a, fid_value_b