# Example 1
def test_case_calculate_fid_stat_CIFAR():
    from detectron2.utils import comm  # used below; missing from the original imports
    from detectron2.utils import logger
    from template_lib.d2.data import build_dataset_mapper
    from template_lib.d2template.trainer.base_trainer import build_detection_test_loader
    from template_lib.v2.GAN.evaluation import build_GAN_metric
    from template_lib.d2.utils.d2_utils import D2Utils
    from template_lib.v2.config_cfgnode import global_cfg

    logger.setup_logger(name='d2')  # the first positional arg of setup_logger is an output dir, not a name

    cfg = D2Utils.create_cfg()
    cfg.update(global_cfg)
    global_cfg.merge_from_dict(cfg)
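    # Presumably: overlay the user-supplied global_cfg onto the detectron2
    # defaults, then sync the merged result back into global_cfg.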

    # fmt: off
    dataset_name                 = cfg.dataset_name
    IMS_PER_BATCH                = cfg.IMS_PER_BATCH
    img_size                     = cfg.img_size
    dataset_mapper_cfg           = cfg.dataset_mapper_cfg
    GAN_metric                   = cfg.GAN_metric
    # fmt: on

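    # Despite the name, num_workers here is the distributed world size; the
    # global batch is split evenly across processes (one per GPU).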
    num_workers = comm.get_world_size()
    batch_size = IMS_PER_BATCH // num_workers

    dataset_mapper = build_dataset_mapper(dataset_mapper_cfg, img_size=img_size)
    data_loader = build_detection_test_loader(
      cfg, dataset_name=dataset_name, batch_size=batch_size, mapper=dataset_mapper)

    FID_IS_tf = build_GAN_metric(GAN_metric)
    FID_IS_tf.calculate_fid_stat_of_dataloader(data_loader=data_loader)

    comm.synchronize()

# Example 2
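    # Trainer-class method (note `self`); build_dataset_mapper and
    # build_detection_test_loader are assumed to be imported at module level
    # as in Example 1.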
    def build_test_loader(self, cfg, dataset_name, batch_size, dataset_mapper):
        # If a mapper config is given, instantiate the mapper from it.
        if dataset_mapper is not None:
            dataset_mapper = build_dataset_mapper(dataset_mapper)

        data_loader = build_detection_test_loader(cfg,
                                                  dataset_name=dataset_name,
                                                  batch_size=batch_size,
                                                  mapper=dataset_mapper)
        return data_loader
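
    # Usage sketch (illustrative; `trainer` and the dataset name are assumed):
    # loader = trainer.build_test_loader(cfg, dataset_name='cifar10_test',
    #                                    batch_size=8, dataset_mapper=None)
    # With dataset_mapper=None, build_detection_test_loader presumably falls
    # back to its default test-time mapper.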
# Example 3
    def test_case_calculate_fid_stat_CIFAR10():
        from template_lib.d2.data import build_dataset_mapper
        from template_lib.d2template.trainer.base_trainer import build_detection_test_loader
        from template_lib.gans.evaluation import build_GAN_metric
        from template_lib.utils.detection2_utils import D2Utils
        from template_lib.d2.data import build_cifar10  # imported for its side effect: registering the cifar10 datasets
        # Additional imports this snippet relies on (TFFIDISScore's path is assumed):
        import yaml
        from easydict import EasyDict
        from detectron2.utils import comm
        from template_lib.gans.evaluation import TFFIDISScore

        cfg_str = """
                  update_cfg: true
                  dataset_name: "cifar10_train"
                  IMS_PER_BATCH: 32
                  img_size: 32
                  NUM_WORKERS: 0
                  dataset_mapper_cfg:
                    name: CIFAR10DatasetMapper
                  GAN_metric:
                    tf_fid_stat: "datasets/fid_stats_tf_cifar10.npz"
              """
        config = EasyDict(yaml.safe_load(cfg_str))
        config = TFFIDISScore.update_cfg(config)
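        # update_cfg presumably fills in the metric's remaining default fields.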

        cfg = D2Utils.create_cfg()
        cfg = D2Utils.cfg_merge_from_easydict(cfg, config)

        # fmt: off
        dataset_name = cfg.dataset_name
        IMS_PER_BATCH = cfg.IMS_PER_BATCH
        img_size = cfg.img_size
        NUM_WORKERS = cfg.NUM_WORKERS
        dataset_mapper_cfg = cfg.dataset_mapper_cfg
        GAN_metric = cfg.GAN_metric
        # fmt: on

        cfg.defrost()
        cfg.DATALOADER.NUM_WORKERS = NUM_WORKERS
        cfg.GAN_metric.tf_fid_stat = cfg.GAN_metric.tf_fid_stat.format(
            dataset_name=dataset_name, img_size=img_size)
        cfg.freeze()

        num_workers = comm.get_world_size()
        batch_size = IMS_PER_BATCH // num_workers

        dataset_mapper = build_dataset_mapper(dataset_mapper_cfg,
                                              img_size=img_size)
        data_loader = build_detection_test_loader(cfg,
                                                  dataset_name=dataset_name,
                                                  batch_size=batch_size,
                                                  mapper=dataset_mapper)

        FID_IS_tf = build_GAN_metric(GAN_metric)
        FID_IS_tf.calculate_fid_stat_of_dataloader(data_loader=data_loader)

        comm.synchronize()

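# Example 4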
def compute_fid_stats(cfg, args, myargs):
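    # Compute FID statistics (feature mean mu and covariance sigma) for a
    # dataset and cache them at the paths given in cfg.GAN_metric.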

    dataset_name = cfg.start.dataset_name
    IMS_PER_BATCH = cfg.start.IMS_PER_BATCH
    img_size = cfg.start.img_size
    NUM_WORKERS = cfg.start.NUM_WORKERS

    cfg.defrost()
    cfg.DATALOADER.NUM_WORKERS = NUM_WORKERS
    cfg.GAN_metric.torch_fid_stat = cfg.GAN_metric.torch_fid_stat.format(
        dataset_name=dataset_name, img_size=img_size)
    cfg.GAN_metric.tf_fid_stat = cfg.GAN_metric.tf_fid_stat.format(
        dataset_name=dataset_name, img_size=img_size)
    cfg.freeze()

    num_workers = comm.get_world_size()
    batch_size = IMS_PER_BATCH // num_workers

    dataset_mapper = build_dataset_mapper(cfg.dataset_mapper,
                                          img_size=img_size)
    data_loader = build_detection_test_loader(cfg,
                                              dataset_name=dataset_name,
                                              batch_size=batch_size,
                                              mapper=dataset_mapper)

    metric_dict = build_GAN_metric_dict(cfg)
    if "PyTorchFIDISScore" in metric_dict:
        FID_IS_torch = metric_dict['PyTorchFIDISScore']
        FID_IS_torch.calculate_fid_stat_of_dataloader(data_loader=data_loader,
                                                      stdout=myargs.stdout)
    if "TFFIDISScore" in metric_dict:
        FID_IS_tf = metric_dict['TFFIDISScore']
        FID_IS_tf.calculate_fid_stat_of_dataloader(data_loader=data_loader,
                                                   stdout=myargs.stdout)

    comm.synchronize()
    return
# Example 5
def do_train(cfg, args):
    # fmt: off
    run_func = cfg.start.get('run_func', 'train_func')
    dataset_name = cfg.start.dataset_name
    IMS_PER_BATCH = cfg.start.IMS_PER_BATCH * comm.get_world_size()
    NUM_WORKERS = cfg.start.NUM_WORKERS
    dataset_mapper = cfg.start.dataset_mapper

    max_epoch = cfg.start.max_epoch
    checkpoint_period = cfg.start.checkpoint_period

    resume_cfg = get_attr_kwargs(cfg.start, 'resume_cfg', default=None)

    cfg.defrost()
    cfg.DATASETS.TRAIN = (dataset_name, )
    cfg.SOLVER.IMS_PER_BATCH = IMS_PER_BATCH
    cfg.DATALOADER.NUM_WORKERS = NUM_WORKERS
    cfg.freeze()
    # fmt: on

    # build dataset
    mapper = build_dataset_mapper(dataset_mapper)
    data_loader = build_detection_train_loader(cfg, mapper=mapper)
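    # Epoch length follows from the sample count recorded in the dataset metadata.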
    metadata = MetadataCatalog.get(dataset_name)
    num_samples = metadata.get('num_samples')
    iter_every_epoch = num_samples // IMS_PER_BATCH
    max_iter = iter_every_epoch * max_epoch

    model = build_trainer(cfg=cfg,
                          args=args,
                          iter_every_epoch=iter_every_epoch,
                          batch_size=IMS_PER_BATCH,
                          max_iter=max_iter,
                          metadata=metadata,
                          max_epoch=max_epoch,
                          data_loader=data_loader)
    model.train()

    # optimizer = build_optimizer(cfg, model)
    optims_dict = model.build_optimizer()
    scheduler = model.build_lr_scheduler()

    checkpointer = DetectionCheckpointer(model.get_saved_model(),
                                         cfg.OUTPUT_DIR, **optims_dict,
                                         **scheduler)
    if resume_cfg and resume_cfg.resume:
        resume_ckpt_dir = model._get_ckpt_path(
            ckpt_dir=resume_cfg.ckpt_dir,
            ckpt_epoch=resume_cfg.ckpt_epoch,
            iter_every_epoch=resume_cfg.iter_every_epoch)
        start_iter = (
            checkpointer.resume_or_load(resume_ckpt_dir).get("iteration", -1) +
            1)
        if get_attr_kwargs(resume_cfg, 'finetune', default=False):
            start_iter = 0
        model.after_resume()
    else:
        start_iter = 0

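    # Dispatch to an alternative entry point named in the config (e.g. an
    # evaluation routine) instead of running the training loop below.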
    if run_func != 'train_func':
        eval(f'model.{run_func}()')
        exit(0)

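    # checkpoint_period is a string expression evaluated with iter_every_epoch
    # in scope, e.g. "iter_every_epoch * 10" (illustrative value).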
    checkpoint_period = eval(checkpoint_period,
                             dict(iter_every_epoch=iter_every_epoch))
    periodic_checkpointer = PeriodicCheckpointer(checkpointer,
                                                 checkpoint_period,
                                                 max_iter=max_iter)
    logger.info("Starting training from iteration {}".format(start_iter))
    # modelarts_utils.modelarts_sync_results(args=myargs.args, myargs=myargs, join=True, end=False)
    with EventStorage(start_iter) as storage:
        pbar = zip(data_loader, range(start_iter, max_iter))
        if comm.is_main_process():
            pbar = tqdm.tqdm(
                pbar,
                desc=f'do_train, {args.tl_time_str}, '
                f'iters {iter_every_epoch} * bs {IMS_PER_BATCH} = '
                f'imgs {iter_every_epoch*IMS_PER_BATCH}',
                initial=start_iter,
                total=max_iter)

        for data, iteration in pbar:
            comm.synchronize()
            iteration = iteration + 1
            storage.step()

            model.train_func(data, iteration - 1, pbar=pbar)

            periodic_checkpointer.step(iteration)
    # modelarts_utils.modelarts_sync_results(args=myargs.args, myargs=myargs, join=True, end=True)
    comm.synchronize()
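
# Sketch of the resume_cfg block consumed above (field names taken from the
# attribute accesses; all values are illustrative):
#   resume_cfg:
#     resume: true
#     ckpt_dir: "output/ckpt"      # hypothetical path
#     ckpt_epoch: 10
#     iter_every_epoch: 390
#     finetune: false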
# Example 6
def do_train(cfg, args, myargs):
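    # Variant of Example 5's training loop that also configures aspect-ratio
    # grouping and resumes from explicit checkpoint fields in cfg.start.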
    run_func = cfg.start.get('run_func', 'train_func')
    dataset_name = cfg.start.dataset_name
    IMS_PER_BATCH = cfg.start.IMS_PER_BATCH
    max_epoch = cfg.start.max_epoch
    ASPECT_RATIO_GROUPING = cfg.start.ASPECT_RATIO_GROUPING
    NUM_WORKERS = cfg.start.NUM_WORKERS
    checkpoint_period = cfg.start.checkpoint_period
    dataset_mapper = cfg.start.dataset_mapper
    resume_ckpt_dir = get_attr_kwargs(cfg.start,
                                      'resume_ckpt_dir',
                                      default=None)
    resume_ckpt_epoch = get_attr_kwargs(cfg.start,
                                        'resume_ckpt_epoch',
                                        default=0)
    resume_ckpt_iter_every_epoch = get_attr_kwargs(
        cfg.start, 'resume_ckpt_iter_every_epoch', default=0)

    cfg.defrost()
    cfg.DATASETS.TRAIN = (dataset_name, )
    cfg.SOLVER.IMS_PER_BATCH = IMS_PER_BATCH
    cfg.DATALOADER.ASPECT_RATIO_GROUPING = ASPECT_RATIO_GROUPING
    cfg.DATALOADER.NUM_WORKERS = NUM_WORKERS
    cfg.freeze()

    # build dataset
    mapper = build_dataset_mapper(dataset_mapper)
    data_loader = build_detection_train_loader(cfg, mapper=mapper)
    metadata = MetadataCatalog.get(dataset_name)
    num_images = metadata.get('num_images')
    iter_every_epoch = num_images // IMS_PER_BATCH
    max_iter = iter_every_epoch * max_epoch

    model = build_trainer(cfg,
                          myargs=myargs,
                          iter_every_epoch=iter_every_epoch,
                          img_size=dataset_mapper.img_size,
                          dataset_name=dataset_name,
                          train_bs=IMS_PER_BATCH,
                          max_iter=max_iter)
    model.train()

    # optimizer = build_optimizer(cfg, model)
    optims_dict = model.build_optimizer()
    # scheduler = build_lr_scheduler(cfg, optimizer)

    checkpointer = DetectionCheckpointer(model.get_saved_model(),
                                         cfg.OUTPUT_DIR, **optims_dict)
    if args.resume:
        resume_ckpt_dir = model._get_ckpt_path(
            ckpt_dir=resume_ckpt_dir,
            ckpt_epoch=resume_ckpt_epoch,
            iter_every_epoch=resume_ckpt_iter_every_epoch)
        start_iter = (
            checkpointer.resume_or_load(resume_ckpt_dir).get("iteration", -1) +
            1)
        if get_attr_kwargs(args, 'finetune', default=False):
            start_iter = 0
    else:
        start_iter = 0

    model.after_resume()

    if run_func != 'train_func':
        eval(f'model.{run_func}()')
        exit(0)

    checkpoint_period = eval(checkpoint_period,
                             dict(iter_every_epoch=iter_every_epoch))
    periodic_checkpointer = PeriodicCheckpointer(checkpointer,
                                                 checkpoint_period,
                                                 max_iter=max_iter)
    logger.info("Starting training from iteration {}".format(start_iter))

    with EventStorage(start_iter) as storage:
        pbar = zip(data_loader, range(start_iter, max_iter))
        if comm.is_main_process():
            pbar = tqdm.tqdm(
                pbar,
                desc=f'do_train, {myargs.args.time_str_suffix}, '
                f'iters {iter_every_epoch} * bs {IMS_PER_BATCH} = imgs {iter_every_epoch*IMS_PER_BATCH}',
                file=myargs.stdout,
                initial=start_iter,
                total=max_iter)

        for data, iteration in pbar:
            comm.synchronize()
            iteration = iteration + 1
            storage.step()

            model.train_func(data, iteration - 1, pbar=pbar)

            periodic_checkpointer.step(iteration)

    comm.synchronize()
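
# Example 7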
def compute_fid_stats_per_class(cfg, args, myargs):
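    # Compute and cache FID statistics (mu, sigma) separately for every class,
    # writing one .npz file per class index under each stat directory.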

    imagenet_root_dir = cfg.start.imagenet_root_dir
    dataset_name = cfg.start.dataset_name
    IMS_PER_BATCH = cfg.start.IMS_PER_BATCH
    img_size = cfg.start.img_size
    NUM_WORKERS = cfg.start.NUM_WORKERS
    torch_fid_stat = cfg.GAN_metric.torch_fid_stat
    tf_fid_stat = cfg.GAN_metric.tf_fid_stat

    if dataset_name.startswith('cifar10_train_per_class'):
        from template_lib.d2.data.build_cifar10_per_class import find_classes
    elif dataset_name.startswith('cifar100_train_per_class'):
        from template_lib.d2.data.build_cifar100_per_class import find_classes
    else:
        # Without this guard, find_classes would be undefined below.
        raise ValueError(f'Unsupported dataset_name: {dataset_name}')

    torch_fid_stat = torch_fid_stat.format(dataset_name=dataset_name,
                                           img_size=img_size)
    tf_fid_stat = tf_fid_stat.format(dataset_name=dataset_name,
                                     img_size=img_size)
    cfg.defrost()
    cfg.DATALOADER.NUM_WORKERS = NUM_WORKERS
    cfg.GAN_metric.torch_fid_stat = torch_fid_stat
    cfg.GAN_metric.tf_fid_stat = tf_fid_stat
    cfg.freeze()

    num_workers = comm.get_world_size()
    batch_size = IMS_PER_BATCH // num_workers

    metric_dict = build_GAN_metric_dict(cfg)
    if "PyTorchFIDISScore" in metric_dict:
        os.makedirs(torch_fid_stat, exist_ok=True)
        FID_IS_torch = metric_dict['PyTorchFIDISScore']
    if "TFFIDISScore" in metric_dict:
        os.makedirs(tf_fid_stat, exist_ok=True)
        FID_IS_tf = metric_dict['TFFIDISScore']
        # FID_IS_tf.calculate_fid_stat_of_dataloader(data_loader=data_loader, stdout=myargs.stdout)

    classes, class_to_idx = find_classes(imagenet_root_dir)
    dataset_mapper = build_dataset_mapper(cfg.dataset_mapper,
                                          img_size=img_size)
    comm.synchronize()

    for class_path, idx in tqdm.tqdm(class_to_idx.items(),
                                     desc="compute_fid_stats_per_class"):
        registered_name = f'{dataset_name}_{class_path}'
        data_loader = build_detection_test_loader(cfg,
                                                  dataset_name=registered_name,
                                                  batch_size=batch_size,
                                                  mapper=dataset_mapper)

        if "PyTorchFIDISScore" in metric_dict:
            mu, sigma = FID_IS_torch.calculate_fid_stat_of_dataloader(
                data_loader=data_loader,
                return_fid_stat=True,
                stdout=myargs.stdout)
            if comm.is_main_process():
                np.savez(os.path.join(torch_fid_stat, f'{idx}.npz'),
                         mu=mu, sigma=sigma)

        if "TFFIDISScore" in metric_dict:
            mu, sigma = FID_IS_tf.calculate_fid_stat_of_dataloader(
                data_loader=data_loader,
                return_fid_stat=True,
                stdout=myargs.stdout)
            if comm.is_main_process():
                np.savez(os.path.join(tf_fid_stat, f'{idx}.npz'),
                         mu=mu, sigma=sigma)

    comm.synchronize()
    return
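
# Minimal sketch (not part of the original snippets): load back one of the
# per-class stat files written above and unpack the cached moments.
import os
import numpy as np

def load_fid_stat(stat_dir, class_idx):
    # Each file stores the Inception-feature mean and covariance for one class.
    stats = np.load(os.path.join(stat_dir, f'{class_idx}.npz'))
    return stats['mu'], stats['sigma']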
# Example 8
def train(cfg, args, myargs):
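    # Training entry point that resumes from cfg.MODEL.WEIGHTS and syncs
    # results via modelarts_utils before and after the loop.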
    dataset_name = cfg.start.dataset_name
    IMS_PER_BATCH = cfg.start.IMS_PER_BATCH
    max_epoch = cfg.start.max_epoch
    ASPECT_RATIO_GROUPING = cfg.start.ASPECT_RATIO_GROUPING
    NUM_WORKERS = cfg.start.NUM_WORKERS
    checkpoint_period = cfg.start.checkpoint_period

    cfg.defrost()
    cfg.DATASETS.TRAIN = (dataset_name, )
    cfg.SOLVER.IMS_PER_BATCH = IMS_PER_BATCH
    cfg.DATALOADER.ASPECT_RATIO_GROUPING = ASPECT_RATIO_GROUPING
    cfg.DATALOADER.NUM_WORKERS = NUM_WORKERS
    cfg.freeze()

    # build dataset
    mapper = build_dataset_mapper(cfg)
    data_loader = build_detection_train_loader(cfg, mapper=mapper)
    metadata = MetadataCatalog.get(dataset_name)
    num_images = metadata.get('num_images')
    iter_every_epoch = num_images // IMS_PER_BATCH
    max_iter = iter_every_epoch * max_epoch

    model = build_trainer(cfg,
                          myargs=myargs,
                          iter_every_epoch=iter_every_epoch)
    model.train()

    logger.info("Model:\n{}".format(model))

    # optimizer = build_optimizer(cfg, model)
    optims_dict = model.build_optimizer()
    # scheduler = build_lr_scheduler(cfg, optimizer)

    checkpointer = DetectionCheckpointer(model.get_saved_model(),
                                         cfg.OUTPUT_DIR, **optims_dict)
    start_iter = (checkpointer.resume_or_load(
        cfg.MODEL.WEIGHTS, resume=args.resume).get("iteration", -1) + 1)

    checkpoint_period = eval(checkpoint_period,
                             dict(iter_every_epoch=iter_every_epoch))
    periodic_checkpointer = PeriodicCheckpointer(checkpointer,
                                                 checkpoint_period,
                                                 max_iter=max_iter)

    logger.info("Starting training from iteration {}".format(start_iter))
    modelarts_utils.modelarts_sync_results(args=myargs.args,
                                           myargs=myargs,
                                           join=True,
                                           end=False)
    with EventStorage(start_iter) as storage:
        pbar = zip(data_loader, range(start_iter, max_iter))
        if comm.is_main_process():
            pbar = tqdm.tqdm(
                pbar,
                desc=f'train, {myargs.args.time_str_suffix}, '
                f'iters {iter_every_epoch} * bs {IMS_PER_BATCH} = imgs {iter_every_epoch*IMS_PER_BATCH}',
                file=myargs.stdout,
                initial=start_iter,
                total=max_iter)

        for data, iteration in pbar:
            comm.synchronize()
            iteration = iteration + 1
            storage.step()

            model.train_func(data, iteration - 1, pbar=pbar)

            periodic_checkpointer.step(iteration)
    modelarts_utils.modelarts_sync_results(args=myargs.args,
                                           myargs=myargs,
                                           join=True,
                                           end=True)
    comm.synchronize()