Example no. 1
def main(cfg) -> None:
    logging.info("\n\n************** Experiment configuration ***********")
    logging.info(f'\n{OmegaConf.to_yaml(cfg)}')

    plugins = [NLPDDPPlugin(num_nodes=cfg.trainer.num_nodes)]
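    # Add a precision plugin matching trainer.precision: native fp16 AMP, bf16, or the
    # default full-precision plugin.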
    if cfg.trainer.precision == 16:
        plugins.append(
            NLPNativeMixedPrecisionPlugin(
                init_scale=cfg.model.get('native_amp_init_scale', 2**32),
                growth_interval=cfg.model.get('native_amp_growth_interval',
                                              1000),
            ))
    elif cfg.trainer.precision == 'bf16':
        plugins.append(NLPNativeBfloat16PrecisionPlugin())
    else:
        plugins.append(NLPPrecisionPlugin())

    if cfg.get('cluster_type', None) == 'BCP':
        plugins.append(TorchElasticEnvironment())

    trainer = Trainer(plugins=plugins, **cfg.trainer)

    exp_manager(trainer, cfg.exp_manager)

    # update resume from checkpoint found by exp_manager
    resume_from_checkpoint = trainer.resume_from_checkpoint
    if resume_from_checkpoint is not None:
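        # Tensor-parallel training writes one checkpoint per model-parallel rank under
        # mp_rank_XX/ subdirectories, so point the resume path at this rank's own shard.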
        mp_rank = compute_model_parallel_rank(
            trainer.local_rank, cfg.model.tensor_model_parallel_size)
        resume_from_checkpoint = Path(resume_from_checkpoint)
        resume_from_checkpoint = resume_from_checkpoint.parent.parent.joinpath(
            f'mp_rank_{mp_rank:02d}').joinpath(resume_from_checkpoint.name)
        resume_from_checkpoint = str(resume_from_checkpoint)
        logging.info(
            f'Resuming training from checkpoint: {resume_from_checkpoint}')

    trainer.checkpoint_connector = CheckpointConnector(
        trainer, resume_from_checkpoint=resume_from_checkpoint)
    # Override timer callback to a stateless one
    for idx, callback in enumerate(trainer.callbacks):
        if isinstance(callback, Timer):
            trainer.callbacks[idx] = StatelessTimer(cfg.trainer.max_time)

    model = MegatronGPTModel(cfg.model, trainer)

    trainer.fit(model)
Example no. 2
def main(cfg) -> None:
    logging.info("\n\n************** Experiment configuration ***********")
    logging.info(f'\n{OmegaConf.to_yaml(cfg)}')

    megatron_amp_o2 = cfg.model.get('megatron_amp_O2', False)
    fp32_grad_accum = cfg.model.get('fp32_grad_accum', False)
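    # With O2 AMP and fp32 gradient accumulation the default DDP communication hook is
    # skipped; gradient allreduce is handled outside DDP in that configuration.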
    plugins = [
        NLPDDPPlugin(
            num_nodes=cfg.trainer.num_nodes,
            no_ddp_communication_hook=(megatron_amp_o2 and fp32_grad_accum),
            gradient_as_bucket_view=cfg.model.gradient_as_bucket_view,
        )
    ]
    if cfg.trainer.precision in [16, 'bf16']:
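        # A GradScaler (dynamic loss scaling) is only needed for fp16; bf16 keeps fp32's
        # exponent range and trains without one.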
        scaler = None
        if cfg.trainer.precision == 16:
            scaler = GradScaler(
                init_scale=cfg.model.get('native_amp_init_scale', 2**32),
                growth_interval=cfg.model.get('native_amp_growth_interval',
                                              1000),
                hysteresis=cfg.model.get('hysteresis', 2),
            )
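        # megatron_amp_O2 selects Megatron's own half-precision plugin; otherwise
        # Lightning's native mixed-precision plugin handles AMP.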
        if megatron_amp_o2:
            plugins.append(
                MegatronHalfPrecisionPlugin(precision=cfg.trainer.precision,
                                            device='cuda',
                                            scaler=scaler))
        else:
            plugins.append(
                NativeMixedPrecisionPlugin(precision=cfg.trainer.precision,
                                           device='cuda',
                                           scaler=scaler))

    if cfg.get('cluster_type', None) == 'BCP':
        plugins.append(TorchElasticEnvironment())

    trainer = Trainer(plugins=plugins, **cfg.trainer)

    exp_manager(trainer, cfg.exp_manager)

    # update resume from checkpoint found by exp_manager
    resume_from_checkpoint = trainer.checkpoint_connector.resume_from_checkpoint_fit_path
    if resume_from_checkpoint is not None:
        # inject mp_rank into resume_from_checkpoint
        if cfg.model.tensor_model_parallel_size is not None and cfg.model.tensor_model_parallel_size > 1:
            mp_rank = compute_model_parallel_rank(
                trainer.local_rank, cfg.model.tensor_model_parallel_size)
            resume_from_checkpoint = Path(resume_from_checkpoint)
            resume_from_checkpoint = resume_from_checkpoint.parent.parent.joinpath(
                f'mp_rank_{mp_rank:02d}').joinpath(resume_from_checkpoint.name)
            resume_from_checkpoint = str(resume_from_checkpoint)
        logging.info(
            f'Resuming training from checkpoint: {resume_from_checkpoint}')

    trainer.checkpoint_connector = CheckpointConnector(
        trainer, resume_from_checkpoint=resume_from_checkpoint)
    # Override timer callback to a stateless one
    for idx, callback in enumerate(trainer.callbacks):
        if isinstance(callback, Timer):
            trainer.callbacks[idx] = StatelessTimer(cfg.trainer.max_time)

    # hydra interpolation does not work here as the interpolation key is lost when PTL saves hparams
    with open_dict(cfg):
        cfg.model.precision = cfg.trainer.precision

    model = MegatronGPTModel(cfg.model, trainer)

    trainer.fit(model)
Example no. 3
def main(cfg) -> None:
    logging.info("\n\n************** Experiment configuration ***********")
    logging.info(f'\n{OmegaConf.to_yaml(cfg)}')

    megatron_amp_o2 = cfg.model.get('megatron_amp_O2', False)

    plugins = [
        NLPDDPPlugin(
            no_ddp_communication_hook=True,  # we don't use DDP for async grad allreduce
            gradient_as_bucket_view=cfg.model.gradient_as_bucket_view,
            find_unused_parameters=False,
        )
    ]
    if cfg.trainer.precision in [16, 'bf16']:
        scaler = None
        if cfg.trainer.precision == 16:
            scaler = GradScaler(
                init_scale=cfg.model.get('native_amp_init_scale', 2**32),
                growth_interval=cfg.model.get('native_amp_growth_interval',
                                              1000),
                hysteresis=cfg.model.get('hysteresis', 2),
            )
        if megatron_amp_o2:
            plugins.append(
                MegatronHalfPrecisionPlugin(precision=cfg.trainer.precision,
                                            device='cuda',
                                            scaler=scaler))
        else:
            plugins.append(
                PipelineMixedPrecisionPlugin(precision=cfg.trainer.precision,
                                             device='cuda',
                                             scaler=scaler))

    if cfg.get('cluster_type', None) == 'BCP':
        plugins.append(TorchElasticEnvironment())

    trainer = Trainer(plugins=plugins, **cfg.trainer)

    exp_manager(trainer, cfg.exp_manager)

    # update resume from checkpoint found by exp_manager
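    # Note: this variant goes through Lightning's private _checkpoint_connector attribute
    # (the earlier examples used trainer.checkpoint_connector / trainer.resume_from_checkpoint).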
    if cfg.model.resume_from_checkpoint is not None:
        resume_from_checkpoint = cfg.model.resume_from_checkpoint
    else:
        resume_from_checkpoint = trainer._checkpoint_connector.resume_from_checkpoint_fit_path

    logging.info(
        f'Resuming training from checkpoint: {resume_from_checkpoint}')

    trainer._checkpoint_connector = CheckpointConnector(
        trainer, resume_from_checkpoint=resume_from_checkpoint)
    # Override timer callback to a stateless one
    for idx, callback in enumerate(trainer.callbacks):
        if isinstance(callback, Timer):
            trainer.callbacks[idx] = StatelessTimer(cfg.trainer.max_time)

    # hydra interpolation does not work here as the interpolation key is lost when PTL saves hparams
    with open_dict(cfg):
        cfg.model.precision = cfg.trainer.precision

    model = MegatronGPTModel(cfg.model, trainer)

    trainer.fit(model)
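
All three entry points assume the usual NeMo / PyTorch Lightning imports (NLPDDPPlugin, exp_manager, MegatronGPTModel, GradScaler, and the precision plugins, whose exact module paths differ between NeMo releases) and are normally driven by a Hydra-loaded YAML configuration. As a rough orientation only, the sketch below builds a minimal OmegaConf object covering the fields these functions actually read; the values, and anything beyond those fields, are illustrative assumptions rather than NeMo's real defaults.

from omegaconf import OmegaConf

# Minimal stand-in for the Hydra config consumed by main(); all values are illustrative.
cfg = OmegaConf.create({
    'trainer': {
        'num_nodes': 1,
        'precision': 'bf16',        # 16, 'bf16', or full precision in the branches above
        'max_time': '00:04:00:00',  # consumed by the (Stateless)Timer callback
    },
    'model': {
        'tensor_model_parallel_size': 1,
        'gradient_as_bucket_view': True,
        'megatron_amp_O2': False,
        'fp32_grad_accum': False,
        'resume_from_checkpoint': None,
    },
    'exp_manager': {},              # forwarded as-is to exp_manager()
    'cluster_type': None,           # 'BCP' adds TorchElasticEnvironment
})

print(OmegaConf.to_yaml(cfg))       # the same dump main() logs at startup
# main(cfg)                         # would build the Trainer and call trainer.fit(model)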