예제 #1
0
def build_hooks(cfg, trainer):
    """Instantiate every hook listed in ``cfg.trainer.hooks`` and register it.

    Each hook config is built through ``HOOK_REGISTRY``. A hook whose type is
    ``"CheckpointHook"`` additionally receives ``output_dir=cfg.output_dir``.
    Hooks are registered on ``trainer`` under their config ``type`` string.
    """
    for hook_cfg in cfg.trainer.hooks:
        extra_kwargs = {}
        if hook_cfg.type == "CheckpointHook":
            # Only the checkpoint hook needs to know where to write its output.
            extra_kwargs["output_dir"] = cfg.output_dir
        hook = build_from_cfg(hook_cfg, HOOK_REGISTRY, **extra_kwargs)
        trainer.register_hook(hook_cfg.type, hook)
예제 #2
0
def get_dataloader(cfg, imsize, get_transition_value, is_train):
    """Build the prefetching dataloader for the train or validation split.

    Args:
        cfg: Experiment config. Reads ``data_train`` / ``data_val``,
            ``models.pose_size``, ``trainer.batch_size_schedule`` and
            ``trainer.progressive.enabled``.
        imsize: Image size. Forwarded to the transforms and the dataset, and
            used as the key into ``cfg.trainer.batch_size_schedule``.
        get_transition_value: Forwarded to ``progressive_decorator`` when
            progressive growing is enabled; unused otherwise.
        is_train: If true, use ``cfg.data_train`` and make the prefetcher
            infinite; otherwise use ``cfg.data_val``.

    Returns:
        A ``DataPrefetcher`` wrapping a ``torch.utils.data.DataLoader``. When
        progressive growing is enabled, its ``next`` attribute is replaced by
        ``progressive_decorator(dataloader.next, get_transition_value)``.
    """
    cfg_data = cfg.data_train if is_train else cfg.data_val
    if cfg_data.dataset.type == "MNISTDataset":
        # The MNIST dataset is only accepted together with pose_size == 0.
        assert cfg.models.pose_size == 0
    transform = build_transforms(cfg_data.transforms, imsize=imsize)
    dataset = build_from_cfg(cfg_data.dataset,
                             DATASET_REGISTRY,
                             imsize=imsize,
                             transform=transform)
    batch_size = cfg.trainer.batch_size_schedule[imsize]
    # cfg_data.loader must not repeat pin_memory / collate_fn / batch_size:
    # a keyword given both explicitly and via ** raises TypeError.
    dataloader = torch.utils.data.DataLoader(dataset,
                                             pin_memory=False,
                                             collate_fn=fast_collate,
                                             batch_size=batch_size,
                                             **cfg_data.loader)
    dataloader = DataPrefetcher(dataloader, infinite_loader=is_train)
    if not cfg.trainer.progressive.enabled:
        return dataloader
    # If progressive growing, perform GPU image interpolation.
    # NOTE(review): get_transition_value is forwarded even when it is None;
    # confirm progressive_decorator accepts that.
    dataloader.next = progressive_decorator(dataloader.next,
                                            get_transition_value)
    return dataloader
예제 #3
0
 def build_from_cfg(cfg, discriminator, generator):
     """Build a LossOptimizer from the criterion configs in ``cfg``.

     Args:
         cfg: Config. Reads ``trainer.optimizer`` (including its
             ``lazy_regularization`` field), ``discriminator_criterions``
             and ``generator_criterions``.
         discriminator: Discriminator model; passed to every criterion and
             to ``LossOptimizer``.
         generator: Generator model; passed to ``LossOptimizer``.

     Returns:
         A ``LossOptimizer`` built from the non-None criterion configs, with
         the entries of ``cfg.trainer.optimizer`` as keyword arguments.

     NOTE(review): this definition is named ``build_from_cfg`` but calls
     ``build_from_cfg(criterion, CRITERION_REGISTRY, discriminator=...)``
     below. If this function is the ``build_from_cfg`` in scope, those calls
     recurse into it and raise TypeError (``discriminator`` would be bound
     both positionally and by keyword). It presumably should carry its own
     name, distinct from the registry builder -- confirm against callers.
     """
     lazy_regularization = cfg.trainer.optimizer.lazy_regularization
     # Discriminator criterions: skip None entries; each also receives the
     # lazy_regularization flag.
     criterions_D = [
         build_from_cfg(criterion,
                        CRITERION_REGISTRY,
                        discriminator=discriminator,
                        lazy_regularization=lazy_regularization)
         for criterion in cfg.discriminator_criterions.values()
         if criterion is not None
     ]
     # Generator criterions: skip None entries; they receive the
     # discriminator but not lazy_regularization.
     criterions_G = [
         build_from_cfg(criterion,
                        CRITERION_REGISTRY,
                        discriminator=discriminator)
         for criterion in cfg.generator_criterions.values()
         if criterion is not None
     ]
     return LossOptimizer(discriminator, generator, criterions_D,
                          criterions_G, **cfg.trainer.optimizer)
예제 #4
0
def build_discriminator(cfg, data_parallel):
    """Construct the discriminator described by ``cfg.models.discriminator``.

    The model is built through ``DISCRIMINATOR_REGISTRY`` with the shared
    model settings from ``cfg.models``. It is optionally wrapped in
    ``NetworkWrapper`` when ``data_parallel`` is truthy, and then passed
    through ``extend_model``.
    """
    models_cfg = cfg.models
    net = build_from_cfg(
        models_cfg.discriminator,
        DISCRIMINATOR_REGISTRY,
        cfg=cfg,
        max_imsize=models_cfg.max_imsize,
        pose_size=models_cfg.pose_size,
        image_channels=models_cfg.image_channels,
        conv_size=models_cfg.conv_size,
    )
    net = NetworkWrapper(net) if data_parallel else net
    return extend_model(cfg, net)
예제 #5
0
def build_generator(cfg, data_parallel):
    """Construct the generator described by ``cfg.models.generator``.

    The model is built through ``GENERATOR_REGISTRY`` with the shared model
    settings from ``cfg.models``. It is optionally wrapped in
    ``NetworkWrapper`` when ``data_parallel`` is truthy, and then passed
    through ``extend_model``.
    """
    models_cfg = cfg.models
    net = build_from_cfg(
        models_cfg.generator,
        GENERATOR_REGISTRY,
        cfg=cfg,
        max_imsize=models_cfg.max_imsize,
        conv_size=models_cfg.conv_size,
        image_channels=models_cfg.image_channels,
        pose_size=models_cfg.pose_size,
    )
    net = NetworkWrapper(net) if data_parallel else net
    return extend_model(cfg, net)
예제 #6
0
def build_detector(cfg, *args, **kwargs):
    """Build a detector from ``cfg`` through ``DETECTOR_REGISTRY``.

    Any extra positional and keyword arguments are forwarded unchanged to
    ``build_from_cfg``.

    Returns:
        Whatever ``build_from_cfg`` returns for this config.
    """
    import logging

    # Replaces a leftover unconditional print(cfg): log the config at debug
    # level instead, with lazy %-formatting so str(cfg) is only computed
    # when debug logging is enabled.
    logging.getLogger(__name__).debug("Building detector from config: %s", cfg)
    return build_from_cfg(cfg, DETECTOR_REGISTRY, *args, **kwargs)
예제 #7
0
def build_transforms(transforms, imsize):
    """Compose the given transform configs into a single pipeline.

    Each entry of ``transforms`` is built through ``TRANSFORM_REGISTRY`` with
    ``imsize`` forwarded. The results are chained in order with
    ``torchvision.transforms.Compose``.
    """
    return torchvision.transforms.Compose([
        build_from_cfg(transform_cfg, TRANSFORM_REGISTRY, imsize=imsize)
        for transform_cfg in transforms
    ])