def _compute_fid(self):
    r"""Compute FID for the regular generator and, optionally, its average.

    ``self.net_G`` is put into eval mode before any FID is computed, and both
    the regular generator and the moving-average generator are evaluated in
    that mode. Each generator is called with ``random_style=True``. Every
    validation batch is handed to ``self._start_of_iteration`` (with
    ``current_iteration=0``) as the ``preprocess`` step of ``compute_fid``.

    Returns:
        When ``self.cfg.trainer.model_average`` is set, the tuple
        ``(regular_fid_value, fid_value)``, where ``fid_value`` is the FID of
        ``self.net_G.module.averaged_model``. Otherwise only the regular FID
        value (not a tuple).
    """
    self.net_G.eval()
    batch_preprocess = functools.partial(self._start_of_iteration,
                                         current_iteration=0)

    # FID of the regular (non-averaged) generator.
    regular_generator = functools.partial(self.net_G, random_style=True)
    regular_fid_value = compute_fid(
        self._get_save_path('regular_fid', 'npy'),
        self.val_data_loader,
        regular_generator,
        preprocess=batch_preprocess)
    print('Epoch {:05}, Iteration {:09}, Regular FID {}'.format(
        self.current_epoch, self.current_iteration, regular_fid_value))

    if not self.cfg.trainer.model_average:
        # NOTE(review): .float() presumably undoes a precision change made
        # during evaluation/preprocessing -- confirm.
        self.net_G.float()
        return regular_fid_value

    # FID of the averaged generator stored on the wrapped module.
    averaged_generator = functools.partial(
        self.net_G.module.averaged_model, random_style=True)
    average_fid_path = self._get_save_path('average_fid', 'npy')
    fid_value = compute_fid(average_fid_path,
                            self.val_data_loader,
                            averaged_generator,
                            preprocess=batch_preprocess)
    print('Epoch {:05}, Iteration {:09}, FID {}'.format(
        self.current_epoch, self.current_iteration, fid_value))
    self.net_G.float()
    return regular_fid_value, fid_value
def _compute_fid(self):
    r"""Compute the mean of per-class FID values.

    One FID value is computed per test (style) class of the validation
    dataset: with 30 test classes there are 30 FID values. Their mean is
    reported as the final performance number, as described in the FUNIT
    paper.

    Returns:
        The mean FID over all test classes on the master process; ``None`` on
        every other process.
    """
    self.net_G.eval()
    generator = (self.net_G.module.averaged_model
                 if self.cfg.trainer.model_average else self.net_G)

    val_dataset = self.val_data_loader.dataset
    per_class_fid = []
    for style_class in range(val_dataset.num_style_classes):
        save_path = self._get_save_path(
            os.path.join('fid', str(style_class)), 'npy')
        # Point the validation dataset at the current class before scoring.
        val_dataset.set_sample_class_idx(style_class)
        per_class_fid.append(
            compute_fid(save_path, self.val_data_loader, generator,
                        'images_style', 'images_trans'))

    # Only the master process aggregates and reports.
    if not is_master():
        return None
    mean_fid = np.mean(per_class_fid)
    print('Epoch {:05}, Iteration {:09}, Mean FID {}'.format(
        self.current_epoch, self.current_iteration, mean_fid))
    return mean_fid
def _compute_fid(self):
    r"""Compute FID for both domains.

    The domain-a FID is computed with the ``'images_a'`` / ``'images_ba'``
    keys and the domain-b FID with the ``'images_b'`` / ``'images_ab'`` keys.
    Each uses its own ``.npy`` path from ``self._get_save_path``. When
    ``self.cfg.trainer.model_average`` is set, the averaged generator
    (``self.net_G.module.averaged_model``) is evaluated instead of
    ``self.net_G``.

    Returns:
        tuple: ``(fid_value_a, fid_value_b)``.
    """
    self.net_G.eval()
    generator = self.net_G
    if self.cfg.trainer.model_average:
        generator = self.net_G.module.averaged_model

    # Both save paths are resolved before any FID computation starts.
    path_a, path_b = (self._get_save_path(tag, 'npy')
                      for tag in ('fid_a', 'fid_b'))

    fid_a = compute_fid(path_a, self.val_data_loader, generator,
                        'images_a', 'images_ba')
    fid_b = compute_fid(path_b, self.val_data_loader, generator,
                        'images_b', 'images_ab')
    print('Epoch {:05}, Iteration {:09}, FID a {}, FID b {}'.format(
        self.current_epoch, self.current_iteration, fid_a, fid_b))
    return fid_a, fid_b