Example #1
0
    def init_weights(self, pretrained=None):
        """Initialize the weights in backbone.

        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None.
        """
        if isinstance(pretrained, str):
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    kaiming_init(m)
                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
                    constant_init(m, 1)

            if self.zero_init_residual:
                for m in self.modules():
                    if isinstance(m, Bottleneck):
                        constant_init(m.norm3, 0)
                    elif isinstance(m, BasicBlock):
                        constant_init(m.norm2, 0)
        else:
            raise TypeError('pretrained must be a str or None')
Example #2
0
    def init_weights(self, pretrained=None):
        """Initialize the weights in backbone.

        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None.
        """
        if isinstance(pretrained, str):
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    kaiming_init(m)
                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
                    constant_init(m, 1)

            if self.dcn is not None:
                for m in self.modules():
                    if isinstance(m, Bottle2neck):
                        # dcn in Res2Net bottle2neck is in ModuleList
                        for n in m.convs:
                            if hasattr(n, 'conv_offset'):
                                constant_init(n.conv_offset, 0)

            if self.zero_init_residual:
                for m in self.modules():
                    if isinstance(m, Bottle2neck):
                        constant_init(m.norm3, 0)
        else:
            raise TypeError('pretrained must be a str or None')
Example #3
0
    def init_weights(self, pretrained=None):
        """Initialise ``features``, then the extra layers and ``l2_norm``.

        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None, in which case ``self.features`` is
                initialised in place.

        Raises:
            TypeError: If ``pretrained`` is neither a str nor None.
        """
        if pretrained is None:
            for layer in self.features.modules():
                if isinstance(layer, nn.Conv2d):
                    kaiming_init(layer)
                elif isinstance(layer, nn.BatchNorm2d):
                    constant_init(layer, 1)
                elif isinstance(layer, nn.Linear):
                    normal_init(layer, std=0.01)
        elif isinstance(pretrained, str):
            # Non-strict loading: checkpoint and model keys may differ.
            root_logger = get_root_logger()
            load_checkpoint(
                self, pretrained, strict=False, logger=root_logger)
        else:
            raise TypeError('pretrained must be a str or None')

        # The extra convs are (re-)initialised whether or not a
        # checkpoint was loaded above.
        extra_convs = (layer for layer in self.extra.modules()
                       if isinstance(layer, nn.Conv2d))
        for conv in extra_convs:
            xavier_init(conv, distribution='uniform')

        # Fill l2_norm with its own configured scale value.
        constant_init(self.l2_norm, self.l2_norm.scale)
Example #4
0
 def init_weights(self, pretrained: dict = None):
     """Initialize the decoder by delegating to its transformer.

     Args:
         pretrained (dict, optional): Mapping that holds the transformer
             weight specification under the key
             ``'transformer_pretrained'``. When None (the default) the
             method does nothing.

     Raises:
         TypeError: If ``pretrained`` is neither a dict nor None.
         KeyError: If ``'transformer_pretrained'`` is missing.
     """
     if pretrained is None:
         return
     # Explicit check instead of ``assert`` so it survives ``python -O``.
     if not isinstance(pretrained, dict):
         raise TypeError('pretrained must be a dict or None, but got '
                         f'{type(pretrained).__name__}')
     logger = get_root_logger()
     print_log(f'load decoder weight from: {pretrained}', logger=logger)
     self.transformer.init_weights(
         pretrained=pretrained['transformer_pretrained'])
Example #5
0
    def init_weights(self, pretrained=None):
        """Initialize the weights in Encoder by delegating to its backbone.

        Args:
            pretrained (dict, optional): Mapping that holds the backbone
                weight specification under the key
                ``'backbone_pretrained'``. When None (the default) the
                method does nothing.

        Raises:
            TypeError: If ``pretrained`` is neither a dict nor None.
            KeyError: If ``'backbone_pretrained'`` is missing.
        """
        if pretrained is None:
            return
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if not isinstance(pretrained, dict):
            raise TypeError('pretrained must be a dict or None, but got '
                            f'{type(pretrained).__name__}')
        logger = get_root_logger()
        print_log(f'load encoder weight from: {pretrained}',
                  logger=logger)
        self.backbone.init_weights(
            pretrained=pretrained['backbone_pretrained'])
Example #6
0
 def init_weights(self, pretrained: str = None):
     """Initialize the weights of FPN module.

     Args:
         pretrained (str, optional): Checkpoint to load into this module
             (non-strict). When None (the default) every ``nn.Conv2d``
             found in ``self.modules()`` gets Xavier-uniform
             initialization instead.
     """
     if pretrained is not None:
         logger = get_root_logger()
         load_checkpoint(self,
                         pretrained,
                         strict=False,
                         logger=logger,
                         # hand every storage back unchanged
                         map_location=lambda storage, loc: storage)
         return

     for m in self.modules():
         if isinstance(m, nn.Conv2d):
             xavier_init(m, distribution='uniform')
Example #7
0
    def init_weights(self, pretrained: dict = None):
        """Initialize the weights in Captioning Model.

        Args:
            pretrained (dict, optional): Path Dict to pre-trained weights.
                Must provide the keys ``'encoder_pretrained'`` and
                ``'decoder_pretrained'``, which are forwarded to the
                encoder and decoder respectively. Defaults to None, in
                which case the method does nothing.

        Raises:
            TypeError: If ``pretrained`` is neither a dict nor None.
            KeyError: If one of the two expected keys is missing.
        """
        if pretrained is None:
            return
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if not isinstance(pretrained, dict):
            raise TypeError('pretrained must be a dict or None, but got '
                            f'{type(pretrained).__name__}')
        logger = get_root_logger()
        print_log(f'load weight from: {pretrained}', logger=logger)
        self.encoder.init_weights(pretrained['encoder_pretrained'])
        self.decoder.init_weights(pretrained['decoder_pretrained'])
Example #8
0
def train_caption_model(model,
                        dataset,
                        cfg,
                        distributed=False,
                        validate=False,
                        timestamp=None,
                        meta=None):
    """Wrap the model, build optimizer/runner/hooks and run training.

    Args:
        model: Captioning model; it is moved to GPU and wrapped in
            ``MMDistributedDataParallel`` or ``MMDataParallel``.
        dataset: A single training dataset or a list/tuple of datasets.
        cfg: Config object; read both via attributes and ``cfg.get``.
        distributed (bool): Selects the distributed wrapper, runner and
            hooks. Defaults to False.
        validate (bool): Whether to register an evaluation hook built
            from ``cfg.data.val``. Defaults to False.
        timestamp (str, optional): Stored on the runner so that the .log
            and .log.json file names match. Defaults to None.
        meta (dict, optional): Passed to the runner. Defaults to None.
    """
    logger = get_root_logger(cfg.log_level)

    # prepare data loaders (always work on a sequence of datasets)
    train_sets = dataset if isinstance(dataset, (list, tuple)) else [dataset]
    data_loaders = []
    for train_set in train_sets:
        loader = build_dataloader(
            train_set,
            cfg.data.samples_per_gpu,
            cfg.data.workers_per_gpu,
            # cfg.gpus will be ignored if distributed
            len(cfg.gpu_ids),
            dist=distributed,
            seed=cfg.seed)
        data_loaders.append(loader)

    # put model on gpus
    if not distributed:
        model = MMDataParallel(model.cuda(cfg.gpu_ids[0]),
                               device_ids=cfg.gpu_ids)
    else:
        # Sets the `find_unused_parameters` parameter in
        # torch.nn.parallel.DistributedDataParallel (True unless the
        # config overrides it)
        find_unused_parameters = cfg.get('find_unused_parameters', True)
        model = MMDistributedDataParallel(
            model.cuda(),
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False,
            find_unused_parameters=find_unused_parameters)

    n_parameters = 0
    for param in model.parameters():
        if param.requires_grad:
            n_parameters += param.numel()
    print("N_PARAMETERS", n_parameters)
    print('--------------------------------------')

    # build runner
    # AdamW Optimizer
    # TODO: implement build_optimizer
    # Two parameter groups: trainable "backbone" parameters get their own
    # learning rate, all other trainable parameters use the optimizer
    # default.
    backbone_group = {"names": [], "params": [],
                      "lr": cfg.lr_dict.lr_backbone}
    other_group = {"names": [], "params": []}
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        group = backbone_group if "backbone" in name else other_group
        group["names"].append(name)
        group["params"].append(param)
    param_dicts = [backbone_group, other_group]

    optimizer = torch.optim.AdamW(param_dicts,
                                  lr=cfg.lr_dict.lr,
                                  weight_decay=cfg.weight_decay)

    # distributed -> EpochBasedRunner
    # non-distributed -> TextGenerateRunner
    if distributed:
        runner = EpochBasedRunner(model,
                                  optimizer=optimizer,
                                  work_dir=cfg.work_dir,
                                  logger=logger,
                                  meta=meta)
    else:
        runner = TextGenerateRunner(model,
                                    optimizer=optimizer,
                                    work_dir=cfg.work_dir,
                                    logger=logger,
                                    meta=meta)
        # generation interval follows the logging interval (original note:
        # by default a sentence is generated for one sample every 50
        # batches)
        runner.set_gen_iter(cfg.log_config.interval)
        # set tokenizer for train sample generation
        runner.set_tokenizer(train_sets[0].tokenizer)
        # set decoding method for train sample generation
        runner.set_decoding_cfg(cfg.train_cfg.decoding_cfg)

    # an ugly workaround to make .log and .log.json filenames the same
    # TODO: configure the time settings in Docker
    runner.timestamp = timestamp

    # fp16 setting
    fp16_cfg = cfg.get('fp16', None)
    if fp16_cfg is not None:
        optimizer_config = Fp16OptimizerHook(**cfg.optimizer_config,
                                             **fp16_cfg,
                                             distributed=distributed)
    elif distributed and 'type' not in cfg.optimizer_config:
        optimizer_config = OptimizerHook(**cfg.optimizer_config)
    else:
        optimizer_config = cfg.optimizer_config

    # register hooks
    momentum_config = cfg.get('momentum_config', None)
    runner.register_training_hooks(cfg.lr_config, optimizer_config,
                                   cfg.checkpoint_config, cfg.log_config,
                                   momentum_config)
    if distributed:
        runner.register_hook(DistSamplerSeedHook())

    # register eval hooks
    if validate:
        # TODO : Support batch_size > 1 in validation
        val_samples_per_gpu = cfg.data.val.pop('samples_per_gpu', 1)
        if val_samples_per_gpu > 1:
            # Replace 'ImageToTensor' to 'DefaultFormatBundle'
            cfg.data.val.pipeline = replace_ImageToTensor(
                cfg.data.val.pipeline)
        val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
        val_dataloader = build_dataloader(
            val_dataset,
            samples_per_gpu=val_samples_per_gpu,
            workers_per_gpu=cfg.data.workers_per_gpu,
            dist=distributed,
            shuffle=False)
        eval_cfg = cfg.get('evaluation', {})
        hook_cls = DistEvalHook if distributed else EvalHook
        runner.register_hook(hook_cls(val_dataloader, **eval_cfg))

    # resuming takes precedence over plain checkpoint loading
    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)
    runner.run(data_loaders, cfg.workflow, cfg.total_epochs)
Example #9
0
def main():
    """Parse CLI arguments, prepare config and environment, start training.

    Steps, in order: load and patch the config, resolve ``work_dir`` and
    GPU ids, optionally initialise the distributed backend, create the
    work directory and logger, record environment/seed metadata, build
    the captioning model and dataset(s), then call
    ``train_caption_model``.
    """
    args = parse_args()

    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)
    # import modules from string list.
    if cfg.get('custom_imports', None):
        from mmcv.utils import import_modules_from_strings
        import_modules_from_strings(**cfg['custom_imports'])
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # work_dir is determined in this priority: CLI > segment in file > filename
    if args.work_dir is not None:
        # update configs according to CLI args if args.work_dir is not None
        cfg.work_dir = args.work_dir
    elif cfg.get('work_dir', None) is None:
        # Use the config filename as default work_dir only when the config
        # itself does not set one. (The check previously looked up the key
        # 'checkpoints', so a work_dir from the config file was always
        # overwritten.)
        cfg.work_dir = osp.join('./checkpoints',
                                osp.splitext(osp.basename(args.config))[0])
    if args.resume_from is not None:
        cfg.resume_from = args.resume_from
    if args.gpu_ids is not None:
        cfg.gpu_ids = args.gpu_ids
    else:
        cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus)

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)

    # create work_dir
    mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
    # dump config
    cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config)))
    # init the logger before other steps
    # TODO: change the timezone
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
    logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)

    # init the meta dict to record some important information such as
    # environment info and seed, which will be logged
    meta = dict()
    # log env info
    env_info_dict = collect_env()
    env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()])
    dash_line = '-' * 60 + '\n'
    logger.info('Environment info:\n' + dash_line + env_info + '\n' +
                dash_line)
    meta['env_info'] = env_info
    meta['config'] = cfg.pretty_text

    # log some basic info
    logger.info(f'Distributed training: {distributed}')
    logger.info(f'Config:\n{cfg.pretty_text}')

    # set random seeds
    if args.seed is not None:
        logger.info(f'Set random seed to {args.seed}, '
                    f'deterministic: {args.deterministic}')
        set_random_seed(args.seed, deterministic=args.deterministic)
    # the seed (possibly None) is recorded even when it is not applied
    cfg.seed = args.seed
    meta['seed'] = args.seed
    meta['exp_name'] = osp.basename(args.config)

    # get model by task
    model = build_caption(cfg.model)

    # Including tokenizer config in dataset config
    cfg.data.train.tokenizer = cfg.tokenizer
    cfg.data.val.tokenizer = cfg.tokenizer
    cfg.data.test.tokenizer = cfg.tokenizer

    # build dataset
    datasets = [build_dataset(cfg.data.train)]

    # NOTE(original author): marked as "not used". Only triggers for a
    # two-entry workflow, where the validation split is built with the
    # training pipeline.
    if len(cfg.workflow) == 2:
        val_dataset = copy.deepcopy(cfg.data.val)
        val_dataset.pipeline = cfg.data.train.pipeline
        datasets.append(build_dataset(val_dataset))

    if cfg.checkpoint_config is not None:
        # stamp checkpoints with package version + short git hash
        cfg.checkpoint_config.meta = dict(mmcap_version=__version__ +
                                          get_git_hash()[:7])

    # train
    train_caption_model(model,
                        datasets,
                        cfg,
                        distributed=distributed,
                        validate=(not args.no_validate),
                        timestamp=timestamp,
                        meta=meta)