def test_case_calculate_fid_stat_CIFAR():
    from template_lib.d2.data import build_dataset_mapper
    from template_lib.d2template.trainer.base_trainer import build_detection_test_loader
    from template_lib.v2.GAN.evaluation import build_GAN_metric
    from template_lib.d2.utils.d2_utils import D2Utils
    from template_lib.v2.config_cfgnode import global_cfg
    # comm is used below for world size and synchronization; detectron2's comm is assumed here
    from detectron2.utils import comm

    from detectron2.utils import logger
    logger.setup_logger('d2')

    cfg = D2Utils.create_cfg()
    cfg.update(global_cfg)
    global_cfg.merge_from_dict(cfg)

    # fmt: off
    dataset_name                 = cfg.dataset_name
    IMS_PER_BATCH                = cfg.IMS_PER_BATCH
    img_size                     = cfg.img_size
    dataset_mapper_cfg           = cfg.dataset_mapper_cfg
    GAN_metric                   = cfg.GAN_metric
    # fmt: on

    num_workers = comm.get_world_size()
    batch_size = IMS_PER_BATCH // num_workers

    dataset_mapper = build_dataset_mapper(dataset_mapper_cfg, img_size=img_size)
    data_loader = build_detection_test_loader(
      cfg, dataset_name=dataset_name, batch_size=batch_size, mapper=dataset_mapper)

    FID_IS_tf = build_GAN_metric(GAN_metric)
    FID_IS_tf.calculate_fid_stat_of_dataloader(data_loader=data_loader)

    comm.synchronize()

    pass
  def __init__(self, cfg, **kwargs):

    # fmt: off
    self.tf_inception_model_dir       = cfg.tf_inception_model_dir
    self.tf_fid_stat                  = cfg.tf_fid_stat
    self.num_inception_images         = getattr(cfg, 'num_inception_images', 50000)
    self.IS_splits                    = getattr(cfg, 'IS_splits', 10)
    # fmt: on

    self.logger = logging.getLogger('tl')
    ws = comm.get_world_size()
    self.num_inception_images = self.num_inception_images // ws
    self.tf_graph_name = 'FID_IS_Inception_Net'
    if os.path.isfile(self.tf_fid_stat):
      self.logger.info(f'Loading tf_fid_stat: {self.tf_fid_stat}')
      f = np.load(self.tf_fid_stat)
      self.mu_data, self.sigma_data = f['mu'][:], f['sigma'][:]
      f.close()
    else:
      self.logger.warning(f"tf_fid_stat does not exist: {self.tf_fid_stat}")

    self.tf_inception_model_dir = os.path.expanduser(self.tf_inception_model_dir)
    inception_path = self._check_or_download_inception(self.tf_inception_model_dir)
    self.logger.info('Loading tf inception model from %s', inception_path)
    self._create_inception_graph(inception_path, name=self.tf_graph_name)
    self._create_inception_net()
    comm.synchronize()
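
A minimal sketch of the kind of cfg namespace this constructor reads; the attribute names come from the snippet above, while the EasyDict wrapper and the values are illustrative assumptions.

from easydict import EasyDict

# Hypothetical config; only the field names are taken from __init__ above.
cfg_sketch = EasyDict(
    tf_inception_model_dir='~/.cache/tf_inception_model',  # cache dir for the frozen Inception graph
    tf_fid_stat='datasets/fid_stats_tf_cifar10.npz',        # precomputed mu/sigma of the reference data
    num_inception_images=50000,                             # total images for FID/IS, split across ranks
    IS_splits=10,                                           # number of splits for the Inception Score
)
# metric = TFFIDISScore(cfg_sketch)  # class name taken from the CIFAR10 test snippet below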
Example #3
    def calculate_fid_stat_of_dataloader(self,
                                         data_loader,
                                         sample_func=None,
                                         return_fid_stat=False,
                                         num_images=float('inf'),
                                         save_fid_stat=True):
        if sample_func is None:

            class SampleClass(object):
                def __init__(self, data_loader):
                    self.data_iter = iter(data_loader)

                def __call__(self, *args, **kwargs):
                    """
          :return: images: [-1, 1]
          """
                    inputs = next(self.data_iter)
                    # images = [x["image"].to('cuda') for x in inputs]
                    # images = torch.stack(images)
                    images, labels = inputs
                    images = images.to('cuda')
                    return images

            sample_func = SampleClass(data_loader)

        data, label = next(iter(data_loader))
        num_inception_images = len(data) * len(data_loader)
        num_inception_images = min(num_images, num_inception_images)
        pool, logits = self._accumulate_inception_activations(
            sample_func,
            net=self.inception_net,
            num_inception_images=num_inception_images,
            as_numpy=True)

        pool = self._gather_data(pool[:num_inception_images], is_numpy=True)
        logits = self._gather_data(logits[:num_inception_images],
                                   is_numpy=True)

        if return_fid_stat:
            if comm.is_main_process():
                self.logger.info(f"Num of images: {len(pool)}")
                mu, sigma = self._get_FID_stat(pool=pool)
            else:
                mu, sigma = 0, 0
            return mu, sigma

        if comm.is_main_process():
            self.logger.info(f"Num of images: {len(pool)}")
            IS_mean, IS_std = calculate_inception_score(logits, self.IS_splits)
            self.logger.info(f'dataset IS_mean: {IS_mean:.3f} +- {IS_std:.3f}')

            if save_fid_stat:
                mu, sigma = self._get_FID_stat(pool=pool)
                self.logger.info(
                    f'Saving torch_fid_stat to {self.torch_fid_stat}')
                os.makedirs(os.path.dirname(self.torch_fid_stat),
                            exist_ok=True)
                np.savez(self.torch_fid_stat, **{'mu': mu, 'sigma': sigma})
        comm.synchronize()
Example #4
    def __call__(self,
                 sample_func,
                 return_fid_stat=False,
                 num_inception_images=None,
                 return_fid_logit=False,
                 stdout=sys.stdout):
        import torch

        class SampleClass(object):
            def __init__(self, sample_func):
                self.sample_func = sample_func

            def __call__(self, *args, **kwargs):
                """
        :return: images: [0, 255]
        """
                images = self.sample_func()
                images = images.mul_(127.5).add_(127.5).clamp_(
                    0.0, 255.0).permute(0, 2, 3, 1).type(torch.uint8)
                images = images.cpu().numpy()
                return images

        sample_func = SampleClass(sample_func)

        if num_inception_images is None:
            num_inception_images = self.num_inception_images
        pred_FIDs, pred_ISs = self._get_activations_with_sample_func(
            sample_func=sample_func,
            num_inception_images=num_inception_images,
            stdout=stdout)

        if return_fid_stat:
            if comm.is_main_process():
                self.logger.info(f"Num of images: {len(pred_FIDs)}")
                mu, sigma = self._calculate_fid_stat(pred_FIDs=pred_FIDs)
            else:
                mu, sigma = 0, 0
            if return_fid_logit:
                return mu, sigma, pred_FIDs
            else:
                return mu, sigma

        if comm.is_main_process():
            self.logger.info(f"Num of images: {len(pred_FIDs)}")
            IS_mean_tf, IS_std_tf = self._calculate_IS(
                pred_ISs=pred_ISs, IS_splits=self.IS_splits)

            # calculate FID stat
            mu = np.mean(pred_FIDs, axis=0)
            sigma = np.cov(pred_FIDs, rowvar=False)
            FID_tf = calculate_frechet_distance(mu, sigma, self.mu_data,
                                                self.sigma_data)

        else:
            FID_tf = IS_mean_tf = IS_std_tf = 0

        del pred_FIDs, pred_ISs
        comm.synchronize()
        return FID_tf, IS_mean_tf, IS_std_tf
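
A minimal usage sketch for the __call__ above: the wrapped sample_func must return generated images in [-1, 1], which the metric rescales to uint8 [0, 255] internally; G, z_dim and batch_size are hypothetical placeholders.

import torch

@torch.no_grad()
def my_sample_func():
    # z_dim, batch_size and the generator G are assumed to exist elsewhere
    z = torch.randn(batch_size, z_dim, device='cuda')
    return G(z)  # images in [-1, 1], shape [N, C, H, W]

# FID_tf, IS_mean_tf, IS_std_tf = FID_IS_tf(my_sample_func, num_inception_images=10000)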
Example #5
    def test_case_calculate_fid_stat_CIFAR10():
        import yaml
        from easydict import EasyDict
        # comm is used below for world size and synchronization; detectron2's comm is assumed here
        from detectron2.utils import comm
        from template_lib.d2.data import build_dataset_mapper
        from template_lib.d2template.trainer.base_trainer import build_detection_test_loader
        from template_lib.gans.evaluation import build_GAN_metric
        from template_lib.utils.detection2_utils import D2Utils
        from template_lib.d2.data import build_cifar10

        cfg_str = """
                  update_cfg: true
                  dataset_name: "cifar10_train"
                  IMS_PER_BATCH: 32
                  img_size: 32
                  NUM_WORKERS: 0
                  dataset_mapper_cfg:
                    name: CIFAR10DatasetMapper
                  GAN_metric:
                    tf_fid_stat: "datasets/fid_stats_tf_cifar10.npz"
              """
        config = EasyDict(yaml.safe_load(cfg_str))
        config = TFFIDISScore.update_cfg(config)

        cfg = D2Utils.create_cfg()
        cfg = D2Utils.cfg_merge_from_easydict(cfg, config)

        # fmt: off
        dataset_name = cfg.dataset_name
        IMS_PER_BATCH = cfg.IMS_PER_BATCH
        img_size = cfg.img_size
        NUM_WORKERS = cfg.NUM_WORKERS
        dataset_mapper_cfg = cfg.dataset_mapper_cfg
        GAN_metric = cfg.GAN_metric
        # fmt: on

        cfg.defrost()
        cfg.DATALOADER.NUM_WORKERS = NUM_WORKERS
        cfg.GAN_metric.tf_fid_stat = cfg.GAN_metric.tf_fid_stat.format(
            dataset_name=dataset_name, img_size=img_size)
        cfg.freeze()

        num_workers = comm.get_world_size()
        batch_size = IMS_PER_BATCH // num_workers

        dataset_mapper = build_dataset_mapper(dataset_mapper_cfg,
                                              img_size=img_size)
        data_loader = build_detection_test_loader(cfg,
                                                  dataset_name=dataset_name,
                                                  batch_size=batch_size,
                                                  mapper=dataset_mapper)

        FID_IS_tf = build_GAN_metric(GAN_metric)
        FID_IS_tf.calculate_fid_stat_of_dataloader(data_loader=data_loader)

        comm.synchronize()

        pass
Example #6
    def calculate_fid_stat_of_dataloader(self,
                                         data_loader,
                                         sample_func=None,
                                         return_fid_stat=False,
                                         num_images=float('inf'),
                                         stdout=sys.stdout):
        import torch

        if sample_func is None:

            class SampleClass(object):
                def __init__(self, data_loader):
                    self.data_iter = iter(data_loader)

                def __call__(self, *args, **kwargs):
                    inputs = next(self.data_iter)
                    images = [x["image"].to('cuda') for x in inputs]
                    images = torch.stack(images)
                    images = images.mul_(127.5).add_(127.5).clamp_(
                        0.0, 255.0).permute(0, 2, 3, 1).type(torch.uint8)
                    images = images.cpu().numpy()
                    return images

            sample_func = SampleClass(data_loader)

        num_inception_images = len(next(iter(data_loader))) * len(data_loader)
        num_inception_images = min(num_images, num_inception_images)
        pred_FIDs, pred_ISs = self._get_activations_with_sample_func(
            sample_func=sample_func,
            num_inception_images=num_inception_images,
            stdout=stdout)

        if return_fid_stat:
            if comm.is_main_process():
                self.logger.warning(f"Num of images: {len(pred_FIDs)}")
                mu, sigma = self._calculate_fid_stat(pred_FIDs=pred_FIDs)
            else:
                mu, sigma = 0, 0
            return mu, sigma

        if comm.is_main_process():
            self.logger.info(f"Num of images: {len(pred_FIDs)}")
            IS_mean, IS_std = self._calculate_IS(pred_ISs=pred_ISs,
                                                 IS_splits=self.IS_splits)
            self.logger.info(f'dataset IS_mean: {IS_mean:.3f} +- {IS_std:.3f}')

            # calculate FID stat
            mu, sigma = self._calculate_fid_stat(pred_FIDs=pred_FIDs)
            self.logger.info(f'Saving tf_fid_stat to {self.tf_fid_stat}')
            os.makedirs(os.path.dirname(self.tf_fid_stat), exist_ok=True)
            np.savez(self.tf_fid_stat, **{'mu': mu, 'sigma': sigma})
        comm.synchronize()
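
The saved .npz can be inspected directly; the keys match what __init__ loads above ('mu', 'sigma'). For Inception pool3 features mu is typically of shape (2048,) and sigma of shape (2048, 2048); the path below is the one used in the CIFAR10 test config.

import numpy as np

with np.load('datasets/fid_stats_tf_cifar10.npz') as f:
    mu, sigma = f['mu'], f['sigma']
print(mu.shape, sigma.shape)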
Example #7
    def _get_activations_with_sample_func(self,
                                          sample_func,
                                          num_inception_images,
                                          stdout=sys.stdout,
                                          verbose=True):

        pred_FIDs = []
        pred_ISs = []
        count = 0

        while count < num_inception_images:
            if verbose and comm.is_main_process():
                print('\r',
                      end=f'TF FID IS Score forwarding: [{count}/{num_inception_images}]',
                      file=stdout,
                      flush=True)
            try:
                batch = sample_func()
                # batch_list = comm.gather(data=batch)
                # if len(batch_list) > 0:
                #   batch = np.concatenate(batch_list, axis=0)
            except StopIteration:
                break
            try:
                pred_FID, pred_IS = self.sess.run(
                    [self.FID_pool3, self.IS_softmax],
                    {f'{self.tf_graph_name}/ExpandDims:0': batch})
            except KeyboardInterrupt:
                exit(-1)
            except Exception:
                print(traceback.format_exc())
                continue
            count += len(batch)
            pred_FIDs.append(pred_FID)
            pred_ISs.append(pred_IS)
        if verbose:
            print(f'rank: {comm.get_rank()}', file=stdout)

        pred_FIDs = np.concatenate(pred_FIDs, 0).squeeze()
        pred_ISs = np.concatenate(pred_ISs, 0)
        # sess.close()

        pred_FIDs = self._gather_numpy_array(pred_FIDs[:num_inception_images])
        pred_ISs = self._gather_numpy_array(pred_ISs[:num_inception_images])
        comm.synchronize()
        return pred_FIDs, pred_ISs
Example #8
    def __call__(self,
                 sample_func,
                 return_fid_stat=False,
                 num_inception_images=None,
                 stdout=sys.stdout):
        start_time = time.time()

        if num_inception_images is None:
            num_inception_images = self.num_inception_images
        pool, logits = self._accumulate_inception_activations(
            sample_func,
            net=self.inception_net,
            num_inception_images=num_inception_images,
            as_numpy=True,
            stdout=stdout)

        pool = self._gather_data(pool[:num_inception_images], is_numpy=True)
        logits = self._gather_data(logits[:num_inception_images],
                                   is_numpy=True)

        if return_fid_stat:
            if comm.is_main_process():
                self.logger.info(f"Num of images: {len(pool)}")
                mu, sigma = self._get_FID_stat(pool=pool)
            else:
                mu, sigma = 0, 0
            return mu, sigma

        if comm.is_main_process():
            self.logger.info(f"Num of images: {len(pool)}")
            IS_mean_torch, IS_std_torch = calculate_inception_score(
                logits, num_splits=self.IS_splits)

            FID_torch = self._calculate_FID(
                pool=pool,
                no_fid=self.no_FID,
                use_torch=self.calculate_FID_use_torch)
        else:
            IS_mean_torch = IS_std_torch = FID_torch = 0

        elapsed_time = time.time() - start_time
        time_str = time.strftime('%H:%M:%S', time.gmtime(elapsed_time))
        self.logger.info('Elapsed time: %s' % (time_str))
        del pool, logits
        comm.synchronize()
        return FID_torch, IS_mean_torch, IS_std_torch
Example #9
    def __call__(self, images, labels, z, iteration, **kwargs):
        """

    :param images:
    :param labels:
    :param z: z.sample()
    :param iteration:
    :param kwargs:
    :return:
    """

        if self.dummy:
            return

        summary_d = collections.defaultdict(dict)

        real = images
        dy = labels
        gy = dy

        self.G.train()
        self.D.train()
        self.D.zero_grad()

        d_real = self.D(real, dy)
        d_real_mean = d_real.mean()
        summary_d['d_logit_mean']['d_real_mean'] = d_real_mean.item()

        z_sample = z.sample()
        z_sample = z_sample.to(self.device)
        fake = self.G(z_sample, y=gy, **kwargs)
        d_fake = self.D(fake.detach(), gy, **kwargs)
        d_fake_mean = d_fake.mean()
        summary_d['d_logit_mean']['d_fake_mean'] = d_fake_mean.item()

        gp = self.wgan_gp_gradient_penalty_cond(x=real,
                                                G_z=fake,
                                                gy=gy,
                                                f=self.D,
                                                backward=False,
                                                gp_lambda=self.gp_lambda)
        summary_d['gp']['gp'] = gp.item()

        wd = d_real_mean - d_fake_mean
        summary_d['wd']['wd'] = wd.item()
        d_loss = -wd + gp

        d_loss.backward()
        self.D_optim.step()
        summary_d['d_loss']['d_loss'] = d_loss.item()

        ############################
        # (2) Update G network
        ###########################
        if iteration % self.n_critic == 0:
            self.G.zero_grad()
            z_sample = z.sample()
            z_sample = z_sample.to(self.device)
            gy = dy

            fake = self.G(z_sample, y=gy, **kwargs)
            d_fake_g = self.D(fake, gy, **kwargs)
            d_fake_g_mean = d_fake_g.mean()

            g_loss = -d_fake_g_mean
            g_loss.backward()
            summary_d['d_logit_mean']['d_fake_g_mean'] = d_fake_g_mean.item()
            summary_d['g_loss']['g_loss'] = g_loss.item()

            if self.child_grad_bound > 0:
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    self.G.parameters(), self.child_grad_bound)
                summary_d['grad_norm']['grad_norm'] = grad_norm

            self.G_optim.step()

        if iteration % self.log_every == 0:
            Trainer.summary_defaultdict2txtfig(
                default_dict=summary_d,
                prefix='WGANGPCond',
                step=iteration,
                textlogger=self.myargs.textlogger)

        comm.synchronize()
        return
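
For reference, a standard conditional WGAN-GP gradient penalty in the spirit of the wgan_gp_gradient_penalty_cond call above; the project's own helper may differ in details, so treat this as an illustrative sketch rather than its source.

import torch

def wgan_gp_gradient_penalty_cond_sketch(x, G_z, gy, f, gp_lambda, backward=False):
    # interpolate between real and generated samples
    alpha = torch.rand(x.size(0), 1, 1, 1, device=x.device)
    interp = (alpha * x + (1 - alpha) * G_z.detach()).requires_grad_(True)
    d_interp = f(interp, gy)  # critic score on interpolated images
    grads = torch.autograd.grad(outputs=d_interp.sum(), inputs=interp,
                                create_graph=True)[0]
    # penalize deviation of the per-sample gradient norm from 1
    gp = gp_lambda * ((grads.flatten(1).norm(2, dim=1) - 1.0) ** 2).mean()
    if backward:
        gp.backward(retain_graph=True)
    return gp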
Example #10
    def train_controller(self, searched_cnn, valid_dataset_iter,
                         preprocess_image_func, controller, controller_optim,
                         iteration, pbar):
        """

    :param controller: for ddp training
    :return:
    """
        if comm.is_main_process() and iteration % 1000 == 0:
            pbar.set_postfix_str("ClsControllerRLAlphaFair")

        meter_dict = {}

        controller.train()
        controller.zero_grad()

        sampled_arcs = controller()
        sample_entropy = get_ddp_attr(controller, 'sample_entropy')
        sample_log_prob = get_ddp_attr(controller, 'sample_log_prob')

        val_data = next(valid_dataset_iter)
        bs = len(val_data)
        batched_arcs = sampled_arcs.repeat(bs, 1)

        top1 = AverageMeter()
        for i in range(self.num_aggregate):
            val_data = next(valid_dataset_iter)
            val_X, val_y = preprocess_image_func(val_data, device=self.device)
            val_X = val_X.tensor
            with torch.set_grad_enabled(False):
                logits = searched_cnn(val_X, batched_arcs=batched_arcs)
                prec1, = top_accuracy(output=logits, target=val_y, topk=(1, ))
                top1.update(prec1.item(), bs)

        reward_g = top1.avg
        meter_dict['reward_g'] = reward_g

        # detach to make sure that gradients aren't backpropped through the reward
        reward = torch.tensor(reward_g).cuda()
        sample_entropy_mean = sample_entropy.mean()
        meter_dict['sample_entropy'] = sample_entropy_mean.item()
        reward += self.entropy_weight * sample_entropy_mean

        if self.baseline is None:
            baseline = torch.tensor(reward_g)
        else:
            baseline = self.baseline - (1 - self.bl_dec) * (self.baseline -
                                                            reward)
            # detach to make sure that gradients are not backpropped through the baseline
            baseline = baseline.detach()

        sample_log_prob_mean = sample_log_prob.mean()
        meter_dict['sample_log_prob'] = sample_log_prob_mean.item()
        loss = -1 * sample_log_prob_mean * (reward - baseline)

        meter_dict['reward'] = reward.item()
        meter_dict['baseline'] = baseline.item()
        meter_dict['loss'] = loss.item()

        loss.backward(retain_graph=False)

        grad_norm = torch.nn.utils.clip_grad_norm_(controller.parameters(),
                                                   self.child_grad_bound)
        meter_dict['grad_norm'] = grad_norm

        controller_optim.step()

        baseline_list = comm.all_gather(baseline)
        baseline_mean = sum(map(lambda v: v.item(),
                                baseline_list)) / len(baseline_list)
        baseline.fill_(baseline_mean)
        self.baseline = baseline

        if iteration % self.log_every_iter == 0:
            self.print_distribution(iteration=iteration,
                                    log_prob=False,
                                    print_interval=10)
            default_dicts = collections.defaultdict(dict)
            for meter_k, meter in meter_dict.items():
                if meter_k in ['reward', 'baseline']:
                    default_dicts['reward_baseline'][meter_k] = meter
                else:
                    default_dicts[meter_k][meter_k] = meter
            summary_defaultdict2txtfig(default_dict=default_dicts,
                                       prefix='train_controller',
                                       step=iteration,
                                       textlogger=self.myargs.textlogger)
        comm.synchronize()
        return
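
Schematically, the controller update above is REINFORCE with an entropy bonus and an exponential-moving-average baseline; the helper below mirrors the snippet's fields (entropy_weight, bl_dec), but its packaging into a standalone function is illustrative.

import torch

def reinforce_loss_sketch(sample_log_prob, sample_entropy, reward_g, baseline,
                          entropy_weight=1e-4, bl_dec=0.99):
    reward = reward_g + entropy_weight * sample_entropy.mean()  # reward with entropy bonus
    if baseline is None:
        baseline = torch.tensor(float(reward_g))                # initialize on the first update
    else:
        baseline = (baseline - (1 - bl_dec) * (baseline - reward)).detach()  # EMA toward the reward
    loss = -sample_log_prob.mean() * (reward - baseline)        # policy-gradient surrogate
    return loss, baseline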
Example #11
    def __call__(self, images, labels, z, iteration, ema=None, **kwargs):
        """

    :param images:
    :param labels:
    :param z: z.sample()
    :param iteration:
    :param kwargs:
    :return:
    """

        if self.dummy:
            return

        summary_d = collections.defaultdict(dict)

        real = images
        dy = labels
        gy = dy

        self.G.train()
        self.D.train()
        self.D.zero_grad()

        d_real = self.D(real, y=dy, **kwargs)

        z_sample = z.sample()
        z_sample = z_sample.to(self.device)
        fake = self.G(z_sample, y=gy, **kwargs)
        d_fake = self.D(fake.detach(), y=gy, **kwargs)

        r_logit_mean, f_logit_mean, d_loss = self.hinge_loss_discriminator(
            r_logit=d_real, f_logit=d_fake)
        summary_d['d_logit_mean']['r_logit_mean'] = r_logit_mean.item()
        summary_d['d_logit_mean']['f_logit_mean'] = f_logit_mean.item()

        d_loss.backward()
        self.D_optim.step()
        summary_d['d_loss']['d_loss'] = d_loss.item()

        ############################
        # (2) Update G network
        ###########################
        if iteration % self.n_critic == 0:
            self.G.zero_grad()
            z_sample = z.sample()
            z_sample = z_sample.to(self.device)
            gy = dy

            fake = self.G(z_sample, y=gy, **kwargs)
            d_fake_g = self.D(fake, y=gy, **kwargs)

            G_f_logit_mean, g_loss = self.hinge_loss_generator(
                f_logit=d_fake_g)
            summary_d['d_logit_mean']['G_f_logit_mean'] = G_f_logit_mean.item()
            summary_d['g_loss']['g_loss'] = g_loss.item()

            g_loss.backward()
            self.G_optim.step()

            if ema is not None:
                ema.update(iteration)

        if iteration % self.log_every == 0:
            Trainer.summary_defaultdict2txtfig(default_dict=summary_d,
                                               prefix='HingeLossCond',
                                               step=iteration,
                                               textlogger=global_textlogger)

        comm.synchronize()
        return
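
The hinge_loss_discriminator / hinge_loss_generator helpers called above are not shown; the standard hinge GAN losses below match their return signatures and are given as an assumption, not as the project's exact implementation.

import torch.nn.functional as F

def hinge_loss_discriminator_sketch(r_logit, f_logit):
    d_loss_real = F.relu(1.0 - r_logit).mean()  # push real logits above +1
    d_loss_fake = F.relu(1.0 + f_logit).mean()  # push fake logits below -1
    return r_logit.mean(), f_logit.mean(), d_loss_real + d_loss_fake

def hinge_loss_generator_sketch(f_logit):
    g_loss = -f_logit.mean()  # generator raises the critic's logits on fakes
    return f_logit.mean(), g_loss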
Example #12
    def train_controller(self, G, z, y, controller, controller_optim,
                         iteration, pbar):
        """

    :param controller: for ddp training
    :return:
    """
        if comm.is_main_process() and iteration % 1000 == 0:
            pbar.set_postfix_str("ControllerRLAlpha")

        meter_dict = {}

        G.eval()
        controller.train()

        controller.zero_grad()

        sampled_arcs = controller(iteration)

        sample_entropy = get_ddp_attr(controller, 'sample_entropy')
        sample_log_prob = get_ddp_attr(controller, 'sample_log_prob')

        pool_list, logits_list = [], []
        for i in range(self.num_aggregate):
            z_samples = z.sample().to(self.device)
            y_samples = y.sample().to(self.device)
            with torch.set_grad_enabled(False):
                batched_arcs = sampled_arcs[y_samples]
                x = G(z=z_samples, y=y_samples, batched_arcs=batched_arcs)

            pool, logits = self.FID_IS.get_pool_and_logits(x)

            # pool_list.append(pool)
            logits_list.append(logits)

        # pool = np.concatenate(pool_list, 0)
        logits = np.concatenate(logits_list, 0)

        reward_g, _ = self.FID_IS.calculate_IS(logits)
        meter_dict['reward_g'] = reward_g

        # detach to make sure that gradients aren't backpropped through the reward
        reward = torch.tensor(reward_g).cuda()
        sample_entropy_mean = sample_entropy.mean()
        meter_dict['sample_entropy'] = sample_entropy_mean.item()
        reward += self.entropy_weight * sample_entropy_mean

        if self.baseline is None:
            baseline = torch.tensor(reward_g)
        else:
            baseline = self.baseline - (1 - self.bl_dec) * (self.baseline -
                                                            reward)
            # detach to make sure that gradients are not backpropped through the baseline
            baseline = baseline.detach()

        sample_log_prob_mean = sample_log_prob.mean()
        meter_dict['sample_log_prob'] = sample_log_prob_mean.item()
        loss = -1 * sample_log_prob_mean * (reward - baseline)

        meter_dict['reward'] = reward.item()
        meter_dict['baseline'] = baseline.item()
        meter_dict['loss'] = loss.item()

        loss.backward(retain_graph=False)

        grad_norm = torch.nn.utils.clip_grad_norm_(controller.parameters(),
                                                   self.child_grad_bound)
        meter_dict['grad_norm'] = grad_norm

        controller_optim.step()

        baseline_list = comm.all_gather(baseline)
        baseline_mean = sum(map(lambda v: v.item(),
                                baseline_list)) / len(baseline_list)
        baseline.fill_(baseline_mean)
        self.baseline = baseline

        if iteration % self.log_every_iter == 0:
            self.print_distribution(iteration=iteration, print_interval=10)
            if len(sampled_arcs) <= 10:
                self.logger.info('\nsampled arcs: \n%s' %
                                 sampled_arcs.cpu().numpy())
            self.myargs.textlogger.logstr(
                iteration,
                sampled_arcs='\n' +
                np.array2string(sampled_arcs.cpu().numpy(), threshold=np.inf))
            default_dicts = collections.defaultdict(dict)
            for meter_k, meter in meter_dict.items():
                if meter_k in ['reward', 'baseline']:
                    default_dicts['reward_baseline'][meter_k] = meter
                else:
                    default_dicts[meter_k][meter_k] = meter
            summary_defaultdict2txtfig(default_dict=default_dicts,
                                       prefix='train_controller',
                                       step=iteration,
                                       textlogger=self.myargs.textlogger)
        comm.synchronize()
        return