Example #1
def test_visual_hook():
    with pytest.raises(AssertionError):
        VisualizationHook('temp', [1, 2, 3])
    test_dataset = ExampleDataset()
    test_dataset.evaluate = MagicMock(return_value=dict(test='success'))

    img = torch.zeros((1, 3, 10, 10))
    img[:, :, 2:9, :] = 1.
    model = ExampleModel()
    data_loader = DataLoader(test_dataset,
                             batch_size=1,
                             sampler=None,
                             num_workers=0,
                             shuffle=False)

    with tempfile.TemporaryDirectory() as tmpdir:
        visual_hook = VisualizationHook('visual', ['img'], interval=8)
        runner = mmcv.runner.IterBasedRunner(model=model,
                                             work_dir=tmpdir,
                                             logger=get_root_logger())
        runner.register_hook(visual_hook)
        runner.run([data_loader], [('train', 10)], 10)
        img_saved = mmcv.imread(osp.join(tmpdir, 'visual', 'iter_8.png'),
                                flag='unchanged')

        np.testing.assert_almost_equal(img_saved,
                                       img[0].permute(1, 2, 0) * 127 + 128)
Example #2
    def init_weights(self, pretrained=None):
        """Init weights for models.

        We just use the initialization method proposed in the original paper.

        Args:
            pretrained (str, optional): Path for pretrained weights. If given
                None, pretrained weights will not be loaded. Defaults to None.
        """
        if isinstance(pretrained, str):
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            # Following Improved-DDPM, we apply zero-initialization to:
            #   the second conv block in each ResBlock (keyword: 'conv_2')
            #   the output layer of the UNet (keyword: 'out', but
            #     not 'out_blocks')
            #   the projection layer in attention layers (keyword: 'proj')
            for n, m in self.named_modules():
                if isinstance(m, nn.Conv2d) and ('conv_2' in n or
                                                 ('out' in n
                                                  and 'out_blocks' not in n)):
                    constant_init(m, 0)
                if isinstance(m, nn.Conv1d) and 'proj' in n:
                    constant_init(m, 0)
        else:
            raise TypeError('pretrained must be a str or None but'
                            f' got {type(pretrained)} instead.')
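
A minimal, self-contained sketch of the keyword-based zero-initialization above, run on a toy nn.ModuleDict instead of the real UNet. The submodule names are hypothetical stand-ins chosen only to exercise each keyword branch; only torch and mmcv.cnn.constant_init are assumed.

import torch.nn as nn
from mmcv.cnn import constant_init

# Toy container; names are placeholders chosen to match or miss each keyword.
toy = nn.ModuleDict({
    'conv_1': nn.Conv2d(4, 4, 3, padding=1),      # untouched
    'conv_2': nn.Conv2d(4, 4, 3, padding=1),      # matched by 'conv_2'
    'out': nn.Conv2d(4, 3, 3, padding=1),         # matched by 'out'
    'out_blocks': nn.Conv2d(4, 4, 3, padding=1),  # excluded by 'out_blocks'
    'proj': nn.Conv1d(4, 4, 1),                   # matched by 'proj'
})

# Same loop as in init_weights above.
for n, m in toy.named_modules():
    if isinstance(m, nn.Conv2d) and ('conv_2' in n or
                                     ('out' in n and 'out_blocks' not in n)):
        constant_init(m, 0)
    if isinstance(m, nn.Conv1d) and 'proj' in n:
        constant_init(m, 0)

assert toy['conv_2'].weight.abs().sum() == 0      # zero-initialized
assert toy['out_blocks'].weight.abs().sum() != 0  # left at default init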
Example #3
    def init_weights(self, pretrained=None):
        if isinstance(pretrained, str):
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    normal_init(m, 0, 0.02)
                elif isinstance(m, (_BatchNorm, nn.InstanceNorm2d)):
                    constant_init(m, 1)
        else:
            raise TypeError('pretrained must be a str or None but'
                            f' got {type(pretrained)} instead.')
Example #4
    def init_weights(self, pretrained=None):
        """Initialize weights for the model.

        Args:
            pretrained (str, optional): Path for pretrained weights. If given
                None, pretrained weights will not be loaded. Default: None.
        """
        if isinstance(pretrained, str):
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            generation_init_weights(
                self, init_type=self.init_type, init_gain=self.init_gain)
        else:
            raise TypeError("'pretrained' must be a str or None. "
                            f'But received {type(pretrained)}.')
Example #5
    def check_and_load_prev_weight(self, curr_scale):
        if curr_scale == 0:
            return
        prev_ch = self.blocks[curr_scale - 1].base_channels
        curr_ch = self.blocks[curr_scale].base_channels

        prev_in_ch = self.blocks[curr_scale - 1].in_channels
        curr_in_ch = self.blocks[curr_scale].in_channels
        if prev_ch == curr_ch and prev_in_ch == curr_in_ch:
            load_state_dict(self.blocks[curr_scale],
                            self.blocks[curr_scale - 1].state_dict(),
                            logger=get_root_logger())
            print_log(
                'Successfully loaded pretrained model from the last scale.')
        else:
            print_log(
                'Cannot load the pretrained model from the last scale since'
                f' prev_ch({prev_ch}) != curr_ch({curr_ch})'
                f' or prev_in_ch({prev_in_ch}) != curr_in_ch({curr_in_ch})')
Example #6
    def init_weights(self, pretrained=None, init_type='ortho'):
        """Init weights for models.

        Args:
            pretrained (str | dict, optional): Path for the pretrained model
                or a dict describing the pretrained model, which must contain
                the key 'ckpt_path'. Besides, you can also provide 'prefix'
                to load only the generator part from the whole state dict.
                Defaults to None.
            init_type (str, optional): The name of an initialization method:
                ortho | N02 | xavier. Defaults to 'ortho'.
        """

        if isinstance(pretrained, str):
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif isinstance(pretrained, dict):
            ckpt_path = pretrained.get('ckpt_path', None)
            assert ckpt_path is not None
            prefix = pretrained.get('prefix', '')
            map_location = pretrained.get('map_location', 'cpu')
            strict = pretrained.get('strict', True)
            state_dict = _load_checkpoint_with_prefix(prefix, ckpt_path,
                                                      map_location)
            self.load_state_dict(state_dict, strict=strict)
            mmcv.print_log(f'Load pretrained model from {ckpt_path}', 'mmgen')
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, (nn.Conv2d, nn.Linear, nn.Embedding)):
                    if init_type == 'ortho':
                        nn.init.orthogonal_(m.weight)
                    elif init_type == 'N02':
                        normal_init(m, 0.0, 0.02)
                    elif init_type == 'xavier':
                        xavier_init(m)
                    else:
                        raise NotImplementedError(
                            f'{init_type} initialization is not '
                            'supported now.')
        else:
            raise TypeError('pretrained must be a str, a dict or None but'
                            f' got {type(pretrained)} instead.')
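
A hedged usage sketch for the dict form of `pretrained` described in the docstring above. Here `model` stands for an instance of a class exposing this init_weights; the checkpoint path and prefix values are placeholders, not real files or keys.

# Hypothetical usage; the path and prefix below are placeholders.
model.init_weights(
    pretrained=dict(
        ckpt_path='work_dirs/experiment/ckpt.pth',  # placeholder path
        prefix='generator_ema',                     # load only this sub-dict
        map_location='cpu',
        strict=True))

# Without a checkpoint, fall back to the built-in initializers.
model.init_weights(pretrained=None, init_type='ortho')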
Example #7
    def init_weights(self, pretrained=None):
        """Init weights for models.

        We just use the initialization method proposed in the original paper.

        Args:
            pretrained (str, optional): Path for pretrained weights. If given
                None, pretrained weights will not be loaded. Defaults to None.
        """
        if isinstance(pretrained, str):
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                    normal_init(m, 0, 0.02)
                elif isinstance(m, _BatchNorm):
                    nn.init.normal_(m.weight.data)
                    nn.init.constant_(m.bias.data, 0)
        else:
            raise TypeError('pretrained must be a str or None but'
                            f' got {type(pretrained)} instead.')
Example #8
def main():
    args = parse_args()
    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)

    rank, _ = get_dist_info()

    dirname = os.path.dirname(args.checkpoint)
    ckpt = os.path.basename(args.checkpoint)

    if 'http' in args.checkpoint:
        log_path = None
    else:
        log_name = ckpt.split('.')[0] + '_eval_log' + '.txt'
        log_path = os.path.join(dirname, log_name)

    logger = get_root_logger(
        log_file=log_path, log_level=cfg.log_level, file_mode='a')
    logger.info('evaluation')

    # set random seeds
    if args.seed is not None:
        if rank == 0:
            mmcv.print_log(f'set random seed to {args.seed}', 'mmgen')
        set_random_seed(args.seed, deterministic=args.deterministic)

    # build the model and load checkpoint
    model = build_model(
        cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
    # sanity check for models without ema
    if not model.use_ema:
        args.sample_model = 'orig'

    mmcv.print_log(f'Sampling model: {args.sample_model}', 'mmgen')

    model.eval()
    if not distributed:
        _ = load_checkpoint(model, args.checkpoint, map_location='cpu')
        model = MMDataParallel(model, device_ids=[0])

        # build metrics
        if args.eval:
            if args.eval[0] == 'none':
                # only sample images
                metrics = []
                assert args.num_samples is not None and args.num_samples > 0
            else:
                metrics = [
                    build_metric(cfg.metrics[metric]) for metric in args.eval
                ]
        else:
            metrics = [
                build_metric(cfg.metrics[metric]) for metric in cfg.metrics
            ]

        basic_table_info = dict(
            train_cfg=os.path.basename(cfg._filename),
            ckpt=ckpt,
            sample_model=args.sample_model)

        if len(metrics) == 0:
            basic_table_info['num_samples'] = args.num_samples
            data_loader = None
        else:
            basic_table_info['num_samples'] = -1
            # build the dataloader
            if cfg.data.get('test', None):
                dataset = build_dataset(cfg.data.test)
            elif cfg.data.get('val', None):
                dataset = build_dataset(cfg.data.val)
            else:
                dataset = build_dataset(cfg.data.train)
            data_loader = build_dataloader(
                dataset,
                samples_per_gpu=args.batch_size,
                workers_per_gpu=cfg.data.get('val_workers_per_gpu',
                                             cfg.data.workers_per_gpu),
                dist=distributed,
                shuffle=True)

        if args.sample_cfg is None:
            args.sample_cfg = dict()

        # online mode will not save samples
        if args.online and len(metrics) > 0:
            single_gpu_online_evaluation(model, data_loader, metrics, logger,
                                         basic_table_info, args.batch_size,
                                         **args.sample_cfg)
        else:
            single_gpu_evaluation(model, data_loader, metrics, logger,
                                  basic_table_info, args.batch_size,
                                  args.samples_path, **args.sample_cfg)
    else:
        raise NotImplementedError("We hasn't implemented multi gpu eval yet.")
Example #9
def main():
    args = parse_args()
    cfg = Config.fromfile(args.config)

    dirname = os.path.dirname(args.checkpoint)
    ckpt = os.path.basename(args.checkpoint)

    if 'http' in args.checkpoint:
        log_path = None
    else:
        log_name = ckpt.split('.')[0] + '_eval_log' + '.txt'
        log_path = os.path.join(dirname, log_name)

    logger = get_root_logger(
        log_file=log_path, log_level=cfg.log_level, file_mode='a')
    logger.info('evaluation')

    # set random seeds
    if args.seed is not None:
        set_random_seed(args.seed, deterministic=args.deterministic)

    # build the model and load checkpoint
    model = build_model(
        cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
    assert isinstance(model, BaseTranslationModel)
    # sanity check for models without ema
    if not model.use_ema:
        args.sample_model = 'orig'

    mmcv.print_log(f'Sampling model: {args.sample_model}', 'mmgen')

    model.eval()

    _ = load_checkpoint(model, args.checkpoint, map_location='cpu')
    model = MMDataParallel(model, device_ids=[0])

    # build metrics
    if args.eval:
        if args.eval[0] == 'none':
            # only sample images
            metrics = []
            assert args.num_samples is not None and args.num_samples > 0
        else:
            metrics = [
                build_metric(cfg.metrics[metric]) for metric in args.eval
            ]
    else:
        metrics = [build_metric(cfg.metrics[metric]) for metric in cfg.metrics]

    # get source domain and target domain
    target_domain = args.target_domain
    if target_domain is None:
        target_domain = model.module._default_domain
    source_domain = model.module.get_other_domains(target_domain)[0]

    basic_table_info = dict(
        train_cfg=os.path.basename(cfg._filename),
        ckpt=ckpt,
        sample_model=args.sample_model,
        source_domain=source_domain,
        target_domain=target_domain)

    # build the dataloader
    if len(metrics) == 0:
        basic_table_info['num_samples'] = args.num_samples
        data_loader = None
    else:
        basic_table_info['num_samples'] = -1
        if cfg.data.get('test', None):
            dataset = build_dataset(cfg.data.test)
        else:
            dataset = build_dataset(cfg.data.train)
        data_loader = build_dataloader(
            dataset,
            samples_per_gpu=args.batch_size,
            workers_per_gpu=cfg.data.get('val_workers_per_gpu',
                                         cfg.data.workers_per_gpu),
            dist=False,
            shuffle=True)

    if args.online:
        single_gpu_online_evaluation(model, data_loader, metrics, logger,
                                     basic_table_info, args.batch_size)
    else:
        single_gpu_evaluation(model, data_loader, metrics, logger,
                              basic_table_info, args.batch_size,
                              args.samples_path)
Example #10
def main():
    args = parse_args()
    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    if args.gpu_ids is not None:
        cfg.gpu_ids = args.gpu_ids[0:1]
        warnings.warn('`--gpu-ids` is deprecated, please use `--gpu-id`. '
                      'Because we only support single GPU mode in '
                      'non-distributed testing. Use the first GPU '
                      'in `gpu_ids` now.')
    else:
        cfg.gpu_ids = [args.gpu_id]

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
        rank = 0
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)
        rank, world_size = get_dist_info()
        cfg.gpu_ids = range(world_size)
        assert args.online or world_size == 1, (
            'We only support online mode for distributed evaluation.')

    dirname = os.path.dirname(args.checkpoint)
    ckpt = os.path.basename(args.checkpoint)

    if 'http' in args.checkpoint:
        log_path = None
    else:
        log_name = ckpt.split('.')[0] + '_eval_log' + '.txt'
        log_path = os.path.join(dirname, log_name)

    logger = get_root_logger(log_file=log_path,
                             log_level=cfg.log_level,
                             file_mode='a')
    logger.info('evaluation')

    # set random seeds
    if args.seed is not None:
        if rank == 0:
            mmcv.print_log(f'set random seed to {args.seed}', 'mmgen')
        set_random_seed(args.seed, deterministic=args.deterministic)

    # build the model and load checkpoint
    model = build_model(cfg.model,
                        train_cfg=cfg.train_cfg,
                        test_cfg=cfg.test_cfg)
    # sanity check for models without ema
    if not model.use_ema:
        args.sample_model = 'orig'

    mmcv.print_log(f'Sampling model: {args.sample_model}', 'mmgen')

    model.eval()

    if args.eval:
        if args.eval[0] == 'none':
            # only sample images
            metrics = []
            assert args.num_samples is not None and args.num_samples > 0
        else:
            metrics = [
                build_metric(cfg.metrics[metric]) for metric in args.eval
            ]
    else:
        metrics = [build_metric(cfg.metrics[metric]) for metric in cfg.metrics]

    # check metrics for dist evaluation
    if distributed and metrics:
        for metric in metrics:
            assert metric.name in _distributed_metrics, (
                f'We only support {_distributed_metrics} for multi-GPU '
                f'evaluation, but received {args.eval}.')

    _ = load_checkpoint(model, args.checkpoint, map_location='cpu')

    basic_table_info = dict(train_cfg=os.path.basename(cfg._filename),
                            ckpt=ckpt,
                            sample_model=args.sample_model)

    if len(metrics) == 0:
        basic_table_info['num_samples'] = args.num_samples
        data_loader = None
    else:
        basic_table_info['num_samples'] = -1
        # build the dataloader
        if cfg.data.get('test', None) and cfg.data.test.get('imgs_root', None):
            dataset = build_dataset(cfg.data.test)
        elif cfg.data.get('val', None) and cfg.data.val.get('imgs_root', None):
            dataset = build_dataset(cfg.data.val)
        elif cfg.data.get('train', None):
            # we assume that the train part should work well
            dataset = build_dataset(cfg.data.train)
        else:
            raise RuntimeError('There is no valid dataset config to run, '
                               'please check your dataset configs.')

        # The default loader config
        loader_cfg = dict(samples_per_gpu=args.batch_size,
                          workers_per_gpu=cfg.data.get(
                              'val_workers_per_gpu', cfg.data.workers_per_gpu),
                          num_gpus=len(cfg.gpu_ids),
                          dist=distributed,
                          shuffle=True)
        # The overall dataloader settings
        loader_cfg.update({
            k: v
            for k, v in cfg.data.items() if k not in [
                'train', 'val', 'test', 'train_dataloader', 'val_dataloader',
                'test_dataloader'
            ]
        })

        # specific config for test loader
        test_loader_cfg = {**loader_cfg, **cfg.data.get('test_dataloader', {})}

        data_loader = build_dataloader(dataset, **test_loader_cfg)
    if args.sample_cfg is None:
        args.sample_cfg = dict()

    if not distributed:
        model = MMDataParallel(model, device_ids=[0])
    else:
        find_unused_parameters = cfg.get('find_unused_parameters', False)
        model = MMDistributedDataParallel(
            model.cuda(),
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False,
            find_unused_parameters=find_unused_parameters)

    # online mode will not save samples
    if args.online and len(metrics) > 0:
        online_evaluation(model, data_loader, metrics, logger,
                          basic_table_info, args.batch_size, **args.sample_cfg)
    else:
        offline_evaluation(model, data_loader, metrics, logger,
                           basic_table_info, args.batch_size,
                           args.samples_path, **args.sample_cfg)
Example #11
def main():
    args = parse_args()
    cfg = Config.fromfile(args.config)

    dirname = os.path.dirname(args.checkpoint)
    ckpt = os.path.basename(args.checkpoint)

    if 'http' in args.checkpoint:
        log_path = None
    else:
        log_name = ckpt.split('.')[0] + '_eval_log' + '.txt'
        log_path = os.path.join(dirname, log_name)

    logger = get_root_logger(log_file=log_path,
                             log_level=cfg.log_level,
                             file_mode='a')
    logger.info('evaluation')

    # set random seeds
    if args.seed is not None:
        set_random_seed(args.seed, deterministic=args.deterministic)

    # build the model and load checkpoint
    model = build_model(cfg.model,
                        train_cfg=cfg.train_cfg,
                        test_cfg=cfg.test_cfg)
    assert isinstance(model, _supported_model)
    # sanity check for models without ema
    if not model.use_ema:
        args.sample_model = 'orig'

    mmcv.print_log(f'Sampling model: {args.sample_model}', 'mmgen')

    model.eval()

    _ = load_checkpoint(model, args.checkpoint, map_location='cpu')
    model = MMDataParallel(model, device_ids=[0])

    # build metrics
    if args.eval:
        if args.eval[0] == 'none':
            # only sample images
            metrics = []
            assert args.num_samples is not None and args.num_samples > 0
        else:
            metrics = [
                build_metric(cfg.metrics[metric]) for metric in args.eval
            ]
    else:
        metrics = [build_metric(cfg.metrics[metric]) for metric in cfg.metrics]

    basic_table_info = dict(train_cfg=os.path.basename(cfg._filename),
                            ckpt=ckpt,
                            sample_model=args.sample_model)

    # build the dataloader
    if len(metrics) == 0:
        basic_table_info['num_samples'] = args.num_samples
        data_loader = None
    else:
        basic_table_info['num_samples'] = -1
        if cfg.data.get('test', None):
            dataset = build_dataset(cfg.data.test)
        else:
            dataset = build_dataset(cfg.data.train)
        data_loader = build_dataloader(dataset,
                                       samples_per_gpu=args.batch_size,
                                       workers_per_gpu=cfg.data.get(
                                           'val_workers_per_gpu',
                                           cfg.data.workers_per_gpu),
                                       dist=False,
                                       shuffle=True)

    # decide samples path
    samples_path = args.samples_path
    delete_samples_path = False
    if samples_path:
        mmcv.mkdir_or_exist(samples_path)
    else:
        temp_path = './work_dirs/temp_samples'
        # if temp_path exists, add suffix
        suffix = 1
        samples_path = temp_path
        while os.path.exists(samples_path):
            samples_path = temp_path + '_' + str(suffix)
            suffix += 1
        os.makedirs(samples_path)
        delete_samples_path = True

    # sample images
    num_exist = len(
        list(
            mmcv.scandir(samples_path,
                         suffix=('.jpg', '.png', '.jpeg', '.JPEG'))))
    if basic_table_info['num_samples'] > 0:
        max_num_images = basic_table_info['num_samples']
    else:
        max_num_images = max(metric.num_images for metric in metrics)
    num_needed = max(max_num_images - num_exist, 0)

    if num_needed > 0:
        mmcv.print_log(f'Sample {num_needed} fake images for evaluation',
                       'mmgen')
        # define mmcv progress bar
        pbar = mmcv.ProgressBar(num_needed)
    # select key to fetch fake images
    fake_key = 'fake_b'
    if isinstance(model.module, CycleGAN):
        fake_key = 'fake_b' if model.module.test_direction == 'a2b' else \
            'fake_a'
    # if no images, `num_exist` should be zero
    for begin in range(num_exist, num_needed, args.batch_size):
        end = min(begin + args.batch_size, max_num_images)
        # for translation model, we feed them images from dataloader
        data_loader_iter = iter(data_loader)
        data_batch = next(data_loader_iter)
        output_dict = model(test_mode=True, **data_batch)
        fakes = output_dict[fake_key]
        pbar.update(end - begin)
        for i in range(end - begin):
            images = fakes[i:i + 1]
            images = ((images + 1) / 2)
            images = images[:, [2, 1, 0], ...]
            images = images.clamp_(0, 1)
            image_name = str(begin + i) + '.png'
            save_image(images, os.path.join(samples_path, image_name))

    if num_needed > 0:
        sys.stdout.write('\n')

    # return if only save sampled images
    if len(metrics) == 0:
        return

    # empty cache to release GPU memory
    torch.cuda.empty_cache()
    fake_dataloader = make_vanilla_dataloader(samples_path, args.batch_size)
    # select key to fetch real images
    if isinstance(model.module, CycleGAN):
        real_key = 'img_b' if model.module.test_direction == 'a2b' else 'img_a'
        if model.module.direction == 'b2a':
            real_key = 'img_a' if real_key == 'img_b' else 'img_b'

    if isinstance(model.module, Pix2Pix):
        real_key = 'img_b' if model.module.direction == 'a2b' else 'img_a'

    for metric in metrics:
        mmcv.print_log(f'Evaluate with {metric.name} metric.', 'mmgen')
        metric.prepare()
        # feed in real images
        for data in data_loader:
            reals = data[real_key]
            num_left = metric.feed(reals, 'reals')
            if num_left <= 0:
                break
        # feed in fake images
        for data in fake_dataloader:
            fakes = data['real_img']
            num_left = metric.feed(fakes, 'fakes')
            if num_left <= 0:
                break
        metric.summary()
    table_str = make_metrics_table(basic_table_info['train_cfg'],
                                   basic_table_info['ckpt'],
                                   basic_table_info['sample_model'], metrics)
    logger.info('\n' + table_str)
    if delete_samples_path:
        shutil.rmtree(samples_path)
Example #12
def train_model(model,
                dataset,
                cfg,
                distributed=False,
                validate=False,
                timestamp=None,
                meta=None):
    logger = get_root_logger(cfg.log_level)

    # prepare data loaders
    dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]

    data_loaders = [
        build_dataloader(
            ds,
            cfg.data.samples_per_gpu,
            cfg.data.workers_per_gpu,
            # cfg.gpus will be ignored if distributed
            len(cfg.gpu_ids),
            dist=distributed,
            seed=cfg.seed) for ds in dataset
    ]

    # put model on gpus
    if distributed:
        find_unused_parameters = cfg.get('find_unused_parameters', False)
        use_ddp_wrapper = cfg.get('use_ddp_wrapper', False)
        # Sets the `find_unused_parameters` parameter in
        # torch.nn.parallel.DistributedDataParallel
        if use_ddp_wrapper:
            mmcv.print_log('Use DDP Wrapper.', 'mmgen')
            model = DistributedDataParallelWrapper(
                model.cuda(),
                device_ids=[torch.cuda.current_device()],
                broadcast_buffers=False,
                find_unused_parameters=find_unused_parameters)
        else:
            model = MMDistributedDataParallel(
                model.cuda(),
                device_ids=[torch.cuda.current_device()],
                broadcast_buffers=False,
                find_unused_parameters=find_unused_parameters)
    else:
        model = MMDataParallel(model.cuda(cfg.gpu_ids[0]),
                               device_ids=cfg.gpu_ids)

    # build runner
    if cfg.optimizer:
        optimizer = build_optimizers(model, cfg.optimizer)
    # In GANs, we allow building the optimizer inside the GAN model itself.
    else:
        optimizer = None

    # allow users to define the runner
    if cfg.get('runner', None):
        runner = build_runner(
            cfg.runner,
            dict(model=model,
                 optimizer=optimizer,
                 work_dir=cfg.work_dir,
                 logger=logger,
                 meta=meta))
    else:
        runner = IterBasedRunner(model,
                                 optimizer=optimizer,
                                 work_dir=cfg.work_dir,
                                 logger=logger,
                                 meta=meta)
        # set if use dynamic ddp in training
        # is_dynamic_ddp=cfg.get('is_dynamic_ddp', False))
    # an ugly workaround to make the .log and .log.json filenames the same
    runner.timestamp = timestamp

    # fp16 setting
    fp16_cfg = cfg.get('fp16', None)

    # In GANs, we can directly optimize parameters in the `train_step` function.
    if cfg.get('optimizer_cfg', None) is None:
        optimizer_config = None
    elif fp16_cfg is not None:
        raise NotImplementedError('Fp16 has not been supported.')
        # optimizer_config = Fp16OptimizerHook(
        #     **cfg.optimizer_config, **fp16_cfg, distributed=distributed)
    # default to use OptimizerHook
    elif distributed and 'type' not in cfg.optimizer_config:
        optimizer_config = OptimizerHook(**cfg.optimizer_config)
    else:
        optimizer_config = cfg.optimizer_config

    # update `out_dir` in the checkpoint hook
    if cfg.checkpoint_config is not None:
        cfg.checkpoint_config['out_dir'] = os.path.join(
            cfg.work_dir, cfg.checkpoint_config.get('out_dir', 'ckpt'))

    # register hooks
    runner.register_training_hooks(cfg.lr_config, optimizer_config,
                                   cfg.checkpoint_config, cfg.log_config,
                                   cfg.get('momentum_config', None))

    # # DistSamplerSeedHook should be used with EpochBasedRunner
    # if distributed:
    #     runner.register_hook(DistSamplerSeedHook())

    # In general, we do NOT adopt the standard evaluation hook in GAN training.
    # Thus, if you want an eval hook, you need to further define the
    # 'evaluation' key in the config.
    # register eval hooks
    if validate and cfg.get('evaluation', None) is not None:
        val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
        # Support batch_size > 1 in validation
        val_loader_cfg = {
            'samples_per_gpu': 1,
            'shuffle': False,
            'workers_per_gpu': cfg.data.workers_per_gpu,
            **cfg.data.get('val_data_loader', {})
        }
        val_dataloader = build_dataloader(val_dataset,
                                          dist=distributed,
                                          **val_loader_cfg)
        eval_cfg = deepcopy(cfg.get('evaluation'))
        eval_cfg.update(dict(dist=distributed, dataloader=val_dataloader))
        eval_hook = build_from_cfg(eval_cfg, HOOKS)
        priority = eval_cfg.pop('priority', 'NORMAL')
        runner.register_hook(eval_hook, priority=priority)

    # user-defined hooks
    if cfg.get('custom_hooks', None):
        custom_hooks = cfg.custom_hooks
        assert isinstance(custom_hooks, list), \
            f'custom_hooks expect list type, but got {type(custom_hooks)}'
        for hook_cfg in cfg.custom_hooks:
            assert isinstance(hook_cfg, dict), \
                'Each item in custom_hooks expects dict type, but got ' \
                f'{type(hook_cfg)}'
            hook_cfg = hook_cfg.copy()
            priority = hook_cfg.pop('priority', 'NORMAL')
            hook = build_from_cfg(hook_cfg, HOOKS)
            runner.register_hook(hook, priority=priority)

    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)
    runner.run(data_loaders, cfg.workflow, cfg.total_iters)
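
The hook-registration logic above consumes plain config dicts. Below is an illustrative fragment showing the shape it expects; the hook type names are placeholders for classes registered in the HOOKS registry, not specific hooks from the repo.

# Illustrative config fragment; the hook types are placeholders.
custom_hooks = [
    dict(type='MyCustomHook', priority='VERY_HIGH'),  # priority is popped
    dict(type='AnotherHook'),                         # defaults to 'NORMAL'
]

# Optional eval hook; train_model injects `dist` and `dataloader` before
# building it, so only hook-specific fields appear here.
evaluation = dict(type='MyEvalHook', interval=10000)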
Example #13
def main():
    args = parse_args()

    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)
    # import modules from string list.
    if cfg.get('custom_imports', None):
        from mmcv.utils import import_modules_from_strings
        import_modules_from_strings(**cfg['custom_imports'])
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # work_dir is determined in this priority: CLI > segment in file > filename
    if args.work_dir is not None:
        # update configs according to CLI args if args.work_dir is not None
        cfg.work_dir = args.work_dir
    elif cfg.get('work_dir', None) is None:
        # use config filename as default work_dir if cfg.work_dir is None
        cfg.work_dir = osp.join('./work_dirs',
                                osp.splitext(osp.basename(args.config))[0])
    if args.resume_from is not None:
        cfg.resume_from = args.resume_from
    if args.gpu_ids is not None:
        cfg.gpu_ids = args.gpu_ids
    else:
        cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus)

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)
        # re-set gpu_ids with distributed training mode
        _, world_size = get_dist_info()
        cfg.gpu_ids = range(world_size)

    # create work_dir
    mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
    # dump config
    cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config)))
    # init the logger before other steps
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
    logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)

    # init the meta dict to record some important information such as
    # environment info and seed, which will be logged
    meta = dict()
    # log env info
    env_info_dict = collect_env()
    env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()])
    dash_line = '-' * 60 + '\n'
    logger.info('Environment info:\n' + dash_line + env_info + '\n' +
                dash_line)
    meta['env_info'] = env_info
    meta['config'] = cfg.pretty_text
    # log some basic info
    logger.info(f'Distributed training: {distributed}')
    logger.info(f'Config:\n{cfg.pretty_text}')

    # set random seeds
    if args.seed is not None:
        logger.info(f'Set random seed to {args.seed}, '
                    f'deterministic: {args.deterministic}')
        set_random_seed(args.seed, deterministic=args.deterministic)
    cfg.seed = args.seed
    meta['seed'] = args.seed
    meta['exp_name'] = osp.basename(args.config)

    model = build_model(cfg.model,
                        train_cfg=cfg.train_cfg,
                        test_cfg=cfg.test_cfg)

    datasets = [build_dataset(cfg.data.train)]
    if len(cfg.workflow) == 2:
        val_dataset = copy.deepcopy(cfg.data.val)
        val_dataset.pipeline = cfg.data.train.pipeline
        datasets.append(build_dataset(val_dataset))
    if cfg.checkpoint_config is not None:
        # save mmgen version, config file content and class names in
        # checkpoints as meta data
        cfg.checkpoint_config.meta = dict(mmgen_version=__version__ +
                                          get_git_hash()[:7])

    train_model(model,
                datasets,
                cfg,
                distributed=distributed,
                validate=(not args.no_validate),
                timestamp=timestamp,
                meta=meta)