Example No. 1
def get_optimizer(model):
    """Set up the optimizer."""
    args = get_args()

    # Build parameter groups (weight decay and non-decay).
    while isinstance(model, (torchDDP, LocalDDP, FP16_Module)):
        model = model.module
    param_groups = get_params_for_weight_decay_optimization(model)

    # Add model parallel attribute if it is not set.
    for param_group in param_groups:
        for param in param_group['params']:
            if not hasattr(param, 'model_parallel'):
                param.model_parallel = False

    # Use Adam.
    optimizer = Adam(param_groups, lr=args.lr, weight_decay=args.weight_decay,
        betas=(args.adam_beta1, args.adam_beta2), eps=args.adam_eps)

    # Wrap into fp16 optimizer.
    if args.fp16:
        optimizer = FP16_Optimizer(optimizer,
                                   static_loss_scale=args.loss_scale,
                                   dynamic_loss_scale=args.dynamic_loss_scale,
                                   dynamic_loss_args={
                                       'scale_window': args.loss_scale_window,
                                       'min_scale': args.min_scale,
                                       'delayed_shift': args.hysteresis})

    return optimizer
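
All of these examples delegate the decay/no-decay split to a helper such as get_params_for_weight_decay_optimization, whose body is not shown on this page. The sketch below is a hypothetical, minimal version of that common pattern (biases and 1-D parameters such as LayerNorm weights are excluded from weight decay); it is not the actual Megatron/GPT-NeoX implementation.

def sketch_params_for_weight_decay_optimization(module):
    """Hypothetical helper: split an nn.Module's parameters into a
    weight-decay group and a no-decay group (biases, LayerNorm, etc.)."""
    decay, no_decay = [], []
    for name, param in module.named_parameters():
        if not param.requires_grad:
            continue
        # 1-D tensors (biases, LayerNorm weights) conventionally get no decay.
        if param.ndim == 1 or name.endswith('.bias'):
            no_decay.append(param)
        else:
            decay.append(param)
    return [
        {'params': decay},
        {'params': no_decay, 'weight_decay': 0.0},
    ]

# Usage sketch:
#   param_groups = sketch_params_for_weight_decay_optimization(model)
#   optimizer = torch.optim.Adam(param_groups, lr=1e-4, weight_decay=0.01)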
Example No. 2
def get_optimizer(model, args):
    """Set up the optimizer."""

    # Build parameter groups (weight decay and non-decay).
    while isinstance(model, (DDP, FP16_Module)):
        model = model.module
    param_groups = gpt2_get_params_for_weight_decay_optimization(model)

    # Add model parallel attribute if it is not set.
    for param_group in param_groups:
        for param in param_group['params']:
            if not hasattr(param, 'model_parallel'):
                param.model_parallel = False

    # Use FusedAdam.
    optimizer = Adam(param_groups,
                     lr=args.lr, weight_decay=args.weight_decay)

    print(f'Optimizer = {optimizer.__class__.__name__}')

    if args.deepspeed:
        return optimizer, param_groups

    # Wrap into fp16 optimizer.
    if args.fp16:
        optimizer = FP16_Optimizer(optimizer,
                                   static_loss_scale=args.loss_scale,
                                   dynamic_loss_scale=args.dynamic_loss_scale,
                                   dynamic_loss_args={
                                       'scale_window': args.loss_scale_window,
                                       'min_scale': args.min_scale,
                                       'delayed_shift': args.hysteresis})

    return optimizer, param_groups
Example No. 3
def get_megatron_optimizer(model):
    args = get_args()

    # Base optimizer.
    param_groups = _get_params_for_weight_decay_optimization(model)
    optimizer = Adam(param_groups,
                     lr=args.lr,
                     weight_decay=args.weight_decay,
                     betas=(args.adam_beta1, args.adam_beta2),
                     eps=args.adam_eps)

    if args.fp16:
        # Constant loss scale.
        if args.loss_scale:
            grad_scaler = ConstantGradScaler(args.loss_scale)
        # Dynamic loss scale.
        else:
            grad_scaler = DynamicGradScaler(
                initial_scale=args.initial_loss_scale,
                min_scale=args.min_loss_scale,
                growth_factor=2.0,
                backoff_factor=0.5,
                growth_interval=args.loss_scale_window,
                hysteresis=args.hysteresis)
        # Megatron optimizer.
        return FP16OptimizerWithFP16Params(optimizer, grad_scaler,
                                           args.clip_grad)

    # FP32.
    return FP32Optimizer(optimizer, args.clip_grad)
Example No. 4
def get_optimizer(model, neox_args):
    """Set up the optimizer."""
    if neox_args.no_load_optim:
        return None, None
    # Build parameter groups (weight decay and non-decay).
    param_groups = get_params_for_weight_decay_optimization(model, neox_args)
    print_rank_0(
        f'Configuring Optimizer type: {neox_args.optimizer_type} with params: {neox_args.optimizer["params"]}'
    )
    # Add model parallel attribute if it is not set.
    for param_group in param_groups:
        for param in param_group['params']:
            if not hasattr(param, 'model_parallel'):
                param.model_parallel = False

    if neox_args.optimizer_type.lower() in ["cpu_adam", "cpu_torch_adam"]:
        if neox_args.optimizer_type.lower() == "cpu_torch_adam":
            cpu_adam_optimizer = torch.optim.Adam
        else:
            from deepspeed.ops.adam import DeepSpeedCPUAdam
            cpu_adam_optimizer = DeepSpeedCPUAdam
        optimizer = cpu_adam_optimizer(param_groups,
                                       weight_decay=neox_args.weight_decay,
                                       **neox_args.optimizer["params"])
    elif neox_args.optimizer_type.lower() == "onebitadam":
        assert neox_args.deepspeed
        optimizer = None
        # onebitadam needs to be instantiated within the deepspeed engine to work :|
    elif neox_args.optimizer_type.lower() == "sm3":
        from .optimizers import SM3
        optimizer = SM3(param_groups, **neox_args.optimizer["params"])
    elif neox_args.optimizer_type.lower() == "madgrad_wd":
        from .optimizers import madgrad_wd
        optimizer = madgrad_wd(param_groups,
                               weight_decay=neox_args.weight_decay,
                               **neox_args.optimizer["params"])
    elif neox_args.optimizer_type.lower() == "adam":
        # Use Adam
        try:
            # default to apex as it's slightly faster
            from apex.optimizers import FusedAdam as Adam
        except ImportError:
            # if apex isn't installed, use deepspeed's FusedAdam
            print(
                "WARNING: APEX not installed - defaulting to deepspeed's fused adam"
            )
            from deepspeed.ops.adam import FusedAdam as Adam
        optimizer = Adam(param_groups,
                         weight_decay=neox_args.weight_decay,
                         **neox_args.optimizer["params"])
    else:
        raise ValueError(
            f"Optimizer type {neox_args.optimizer_type} not recognized")

    if neox_args.deepspeed:
        # fp16 wrapper is not required for DeepSpeed.
        return optimizer, param_groups
    else:
        raise ValueError("Must be using deepspeed to run neox")
Example No. 5
def get_megatron_optimizer(model):
    args = get_args()

    # Base optimizer.
    param_groups = _get_params_for_weight_decay_optimization(model)
    if args.optimizer == 'adam':
        optimizer = Adam(param_groups,
                         lr=args.lr,
                         weight_decay=args.weight_decay,
                         betas=(args.adam_beta1, args.adam_beta2),
                         eps=args.adam_eps)
    elif args.optimizer == 'sgd':
        optimizer = SGD(param_groups,
                        lr=args.lr,
                        weight_decay=args.weight_decay,
                        momentum=args.sgd_momentum)
    else:
        raise Exception('{} optimizer is not supported.'.format(
            args.optimizer))

    # Determine whether the params have main-grad field.
    params_have_main_grad = False
    if args.DDP_impl == 'local':
        params_have_main_grad = True

    if args.fp16 or args.bf16:

        # Grad scaler:
        #    if loss-scale is provided, instantiate the constant scaler.
        #    if we are using fp16 and loss-scale is not present, use a
        #       dynamic scaler.
        #    otherwise we are running in bf16 with no loss-scale so
        #       leave it as None.
        grad_scaler = None
        # Constant loss scale.
        if args.loss_scale:
            grad_scaler = ConstantGradScaler(args.loss_scale)
        # Dynamic loss scale.
        else:
            if args.fp16:
                grad_scaler = DynamicGradScaler(
                    initial_scale=args.initial_loss_scale,
                    min_scale=args.min_loss_scale,
                    growth_factor=2.0,
                    backoff_factor=0.5,
                    growth_interval=args.loss_scale_window,
                    hysteresis=args.hysteresis)

        # Megatron optimizer.
        return Float16OptimizerWithFloat16Params(optimizer, args.clip_grad,
                                                 args.log_num_zeros_in_grad,
                                                 params_have_main_grad,
                                                 args.bf16, grad_scaler)

    # FP32.
    return FP32Optimizer(optimizer, args.clip_grad, args.log_num_zeros_in_grad,
                         params_have_main_grad)
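
ConstantGradScaler and DynamicGradScaler are not shown in these listings. The class below is an illustrative stand-in for the dynamic policy that the arguments above describe (grow the scale after a window of overflow-free steps, back off after repeated overflows, never below a floor); it is not Megatron's actual DynamicGradScaler.

class SimpleDynamicLossScaler:
    """Illustrative dynamic loss scaler (not Megatron's DynamicGradScaler).

    The scale is multiplied by growth_factor after growth_interval
    consecutive overflow-free steps, and multiplied by backoff_factor
    after `hysteresis` consecutive overflowing steps, never dropping
    below min_scale.
    """

    def __init__(self, initial_scale=2.0 ** 16, min_scale=1.0,
                 growth_factor=2.0, backoff_factor=0.5,
                 growth_interval=1000, hysteresis=2):
        self.scale = initial_scale
        self.min_scale = min_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self.hysteresis = hysteresis
        self._good_steps = 0
        self._bad_steps = 0

    def update(self, found_overflow):
        """Call once per optimizer step with that step's overflow flag."""
        if found_overflow:
            self._good_steps = 0
            self._bad_steps += 1
            if self._bad_steps >= self.hysteresis:
                self.scale = max(self.scale * self.backoff_factor, self.min_scale)
                self._bad_steps = 0
        else:
            self._bad_steps = 0
            self._good_steps += 1
            if self._good_steps >= self.growth_interval:
                self.scale *= self.growth_factor
                self._good_steps = 0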
Example No. 6
def get_optimizer(model, args):
    """Set up the optimizer."""

    # Build parameter groups (weight decay and non-decay).
    while isinstance(model, (DDP, FP16_Module)):
        print ("we're dealing with a DDP/FP16_Module, extracting the module...")
        model = model.module
    print ("Getting param groups for weight decay optimization...")    
    param_groups = gpt3_get_params_for_weight_decay_optimization(model)

    # Add model parallel attribute if it is not set.
    for param_group in param_groups:
        for param in param_group['params']:
            if not hasattr(param, 'model_parallel'):
                param.model_parallel = False

    if args.cpu_optimizer:
        if args.cpu_torch_adam:
            cpu_adam_optimizer = torch.optim.Adam
        else:
            from deepspeed.ops.adam import DeepSpeedCPUAdam
            cpu_adam_optimizer = DeepSpeedCPUAdam
        optimizer = cpu_adam_optimizer(param_groups,
                                       lr=args.lr, weight_decay=args.weight_decay)
    else:
        # Use FusedAdam.
        optimizer = Adam(param_groups,
                         lr=args.lr, weight_decay=args.weight_decay)

    print(f'Optimizer = {optimizer.__class__.__name__}')
    
#     print ("let's save our optimizer...")
#     with open("/notebooks/sberbank_rugpts/our_model/optimizer.pkl", "wb") as f:
#         pickle.dump(optimizer, f)
    
    if DEEPSPEED_WRAP and args.deepspeed:
        # fp16 wrapper is not required for DeepSpeed.
        print (f"we're using deepspeed, and so returning our optimizer {optimizer}")
        return optimizer

    # Wrap into fp16 optimizer.
    if args.fp16:
        print ("Wrapping into fp16")
        optimizer = FP16_Optimizer(optimizer,
                                   static_loss_scale=args.loss_scale,
                                   dynamic_loss_scale=args.dynamic_loss_scale,
                                   dynamic_loss_args={
                                       'scale_window': args.loss_scale_window,
                                       'min_scale': args.min_scale,
                                       'delayed_shift': args.hysteresis})
        
    print (f" we've probably wrapped our optimizer in fp16, \nand now we're eturning our optimizer {optimizer}")
    return optimizer
Example No. 7
def get_optimizer(model):
    """Set up the optimizer."""
    args = get_args()

    # Build parameter groups (weight decay and non-decay).
    while isinstance(model, (torchDDP, LocalDDP, FP16_Module)):
        model = model.module
    param_groups = get_params_for_weight_decay_optimization(model)

    # Add model parallel attribute if it is not set.
    for param_group in param_groups:
        for param in param_group['params']:
            if not hasattr(param, 'model_parallel'):
                param.model_parallel = False

    if args.cpu_optimizer:
        if args.cpu_torch_adam:
            cpu_adam_optimizer = torch.optim.Adam
        else:
            from deepspeed.ops.adam import DeepSpeedCPUAdam
            cpu_adam_optimizer = DeepSpeedCPUAdam
        optimizer = cpu_adam_optimizer(param_groups,
                                       lr=args.lr,
                                       weight_decay=args.weight_decay)
    elif args.onebitadam:
        assert args.deepspeed
        optimizer = None
        # onebitadam needs to be instantiated within the deepspeed engine to work :|
    else:
        # Use Adam
        optimizer = Adam(param_groups,
                         lr=args.lr,
                         weight_decay=args.weight_decay,
                         betas=(args.adam_beta1, args.adam_beta2),
                         eps=args.adam_eps)

    if args.deepspeed:
        # fp16 wrapper is not required for DeepSpeed.
        return optimizer, param_groups

    # Wrap into fp16 optimizer.
    if args.fp16:
        optimizer = FP16_Optimizer(optimizer,
                                   static_loss_scale=args.loss_scale,
                                   dynamic_loss_scale=args.dynamic_loss_scale,
                                   dynamic_loss_args={
                                       'scale_window': args.loss_scale_window,
                                       'min_scale': args.min_scale,
                                       'delayed_shift': args.hysteresis
                                   })

    return optimizer, param_groups
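
Variants that return (optimizer, param_groups) are meant to be handed to DeepSpeed rather than wrapped in FP16_Optimizer. A hedged usage sketch, assuming the surrounding script already defines model, args (carrying the DeepSpeed config), and this get_optimizer:

import deepspeed

optimizer, param_groups = get_optimizer(model)
model_engine, optimizer, _, _ = deepspeed.initialize(
    args=args,                      # expected to carry the DeepSpeed config fields
    model=model,
    optimizer=optimizer,            # may be None (e.g. OneBitAdam is built by the engine)
    model_parameters=param_groups,  # lets DeepSpeed construct its own optimizer if needed
)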
Example No. 8
def get_optimizer(model):
    """Set up the optimizer."""
    args = get_args()

    # Build parameter groups (weight decay and non-decay).
    while isinstance(model, (torchDDP, FP16_Module)):
        model = model.module
    param_groups = get_params_for_weight_decay_optimization(model, args)

    # Add model parallel attribute if it is not set.
    for param_group in param_groups:
        for param in param_group['params']:
            if not hasattr(param, 'model_parallel'):
                param.model_parallel = False

    if args.cpu_optimizer:
        if args.cpu_torch_adam:
            cpu_adam_optimizer = torch.optim.Adam
        else:
            from deepspeed.ops.adam import DeepSpeedCPUAdam
            cpu_adam_optimizer = DeepSpeedCPUAdam
        optimizer = cpu_adam_optimizer(param_groups,
                                       lr=args.lr,
                                       weight_decay=args.weight_decay)
    elif args.onebitadam:
        assert args.deepspeed
        optimizer = None
        # onebitadam needs to be instantiated within the deepspeed engine to work :|
    elif args.sm3:
        from .optimizers import SM3
        optimizer = SM3(
            param_groups,
            lr=args.lr,
            momentum=args.momentum,
            beta=args.adam_beta1,
            eps=args.adam_eps,
        )
    else:
        # Use Adam
        optimizer = Adam(param_groups,
                         lr=args.lr,
                         weight_decay=args.weight_decay,
                         betas=(args.adam_beta1, args.adam_beta2),
                         eps=args.adam_eps,
                         adam_w_mode=not args.no_adamw)
    if args.deepspeed:
        # fp16 wrapper is not required for DeepSpeed.
        return optimizer, param_groups
    else:
        raise ValueError("Must be using deepspeed to run neox")
Example No. 9
def get_optimizer(model, args):
    """Set up the optimizer."""

    # Build parameter groups (weight decay and non-decay).
    while isinstance(model, (DDP, FP16_Module)):
        model = model.module
    param_groups = gpt2_get_params_for_weight_decay_optimization(model)

    # Add model parallel attribute if it is not set.
    for param_group in param_groups:
        for param in param_group['params']:
            if not hasattr(param, 'model_parallel'):
                param.model_parallel = False

    if args.cpu_optimizer:
        #Apex FusedAdam uses decoupled weight decay so use the same here
        if args.cpu_torch_adam:
            cpu_adam_optimizer = torch.optim.AdamW
        else:
            #TODO add option for decoupled weight decay in DeepCPUAdam
            from deepspeed.ops.adam import DeepSpeedCPUAdam
            cpu_adam_optimizer = DeepSpeedCPUAdam
        optimizer = cpu_adam_optimizer(param_groups,
                                       lr=args.lr,
                                       weight_decay=args.weight_decay)
    else:
        # Use FusedAdam.
        optimizer = Adam(param_groups,
                         lr=args.lr,
                         weight_decay=args.weight_decay)

    print(f'Optimizer = {optimizer.__class__.__name__}')
    if args.deepspeed:
        # fp16 wrapper is not required for DeepSpeed.
        return optimizer

    # Wrap into fp16 optimizer.
    if args.fp16:
        optimizer = FP16_Optimizer(optimizer,
                                   static_loss_scale=args.loss_scale,
                                   dynamic_loss_scale=args.dynamic_loss_scale,
                                   dynamic_loss_args={
                                       'scale_window': args.loss_scale_window,
                                       'min_scale': args.min_scale,
                                       'delayed_shift': args.hysteresis
                                   })

    return optimizer
Example No. 10
def get_optimizer(param_groups, args):
    """Set up the optimizer."""
    if args.cpu_optimizer:
        # Apex FusedAdam uses decoupled weight decay so use the same here
        if args.cpu_torch_adam:
            cpu_adam_optimizer = torch.optim.AdamW
        else:
            from deepspeed.ops.adam import DeepSpeedCPUAdam
            cpu_adam_optimizer = DeepSpeedCPUAdam
        optimizer = cpu_adam_optimizer(param_groups,
                                       lr=args.lr,
                                       weight_decay=args.weight_decay)
    else:
        # Use FusedAdam.
        if args.optimizer == 'adam':
            optimizer = Adam(param_groups,
                             lr=args.lr,
                             weight_decay=args.weight_decay,
                             betas=(args.adam_beta1, args.adam_beta2),
                             eps=args.adam_eps)
        elif args.optimizer == 'adafactor':
            from transformers import Adafactor
            optimizer = Adafactor(param_groups,
                                  lr=args.lr,
                                  relative_step=False,
                                  warmup_init=False)
        else:
            raise NotImplementedError

    print(f'Optimizer = {optimizer.__class__.__name__}')
    if hasattr(args, "deepspeed") and args.deepspeed:
        raise NotImplementedError
        # fp16 wrapper is not required for DeepSpeed.
        # return optimizer

    # Wrap into fp16 optimizer.
    if args.fp16:
        optimizer = FP16_Optimizer(optimizer,
                                   static_loss_scale=args.loss_scale,
                                   dynamic_loss_scale=args.dynamic_loss_scale,
                                   dynamic_loss_args={
                                       'scale_window': args.loss_scale_window,
                                       'min_scale': args.min_scale,
                                       'delayed_shift': args.hysteresis
                                   })

    return optimizer
Example No. 11
def get_optimizer(model, args):
    """Set up the optimizer."""

    # Build parameter groups (weight decay and non-decay).
    while isinstance(model, (args.DDP_type, FP16_Module)):
        model = model.module
    layers = model.model.bert.encoder.layer
    pooler = model.model.bert.pooler
    lmheads = model.model.cls.predictions
    nspheads = model.model.cls.seq_relationship
    embeddings = model.model.bert.embeddings
    param_groups = []
    param_groups += list(get_params_for_weight_decay_optimization(layers))
    param_groups += list(get_params_for_weight_decay_optimization(pooler))
    param_groups += list(get_params_for_weight_decay_optimization(nspheads))
    param_groups += list(get_params_for_weight_decay_optimization(embeddings))
    param_groups += list(
        get_params_for_weight_decay_optimization(lmheads.transform))
    param_groups[1]['params'].append(lmheads.bias)

    # Add model parallel attribute if it is not set.
    for param_group in param_groups:
        for param in param_group['params']:
            if not hasattr(param, 'model_parallel'):
                param.model_parallel = False

    # Use Adam.
    betas = (0.9, 0.999)
    optimizer = Adam(param_groups,
                     betas=betas,
                     lr=args.lr,
                     weight_decay=args.weight_decay)

    # Wrap into fp16 optimizer.
    if args.fp16:
        optimizer = FP16_Optimizer(optimizer,
                                   static_loss_scale=args.loss_scale,
                                   dynamic_loss_scale=args.dynamic_loss_scale,
                                   dynamic_loss_args={
                                       'scale_window': args.loss_scale_window,
                                       'min_scale': args.min_scale,
                                       'delayed_shift': args.hysteresis
                                   })

    return optimizer
Example No. 12
def get_optimizer(param_groups, args):
    """Set up the optimizer."""
    if args.cpu_optimizer:
        #Apex FusedAdam uses decoupled weight decay so use the same here
        if args.cpu_torch_adam:
            cpu_adam_optimizer = torch.optim.AdamW
        else:
            #TODO add option for decoupled weight decay in DeepCPUAdam
            from deepspeed.ops.adam import DeepSpeedCPUAdam
            cpu_adam_optimizer = DeepSpeedCPUAdam
        optimizer = cpu_adam_optimizer(param_groups,
                                       lr=args.lr,
                                       weight_decay=args.weight_decay)
    else:
        # Use FusedAdam.
        optimizer = Adam(param_groups,
                         lr=args.lr,
                         weight_decay=args.weight_decay)

    print(f'Optimizer = {optimizer.__class__.__name__}')
    if hasattr(args, "deepspeed") and args.deepspeed:
        raise NotImplementedError
        # fp16 wrapper is not required for DeepSpeed.
        # return optimizer

    # Wrap into fp16 optimizer.
    if args.fp16:
        optimizer = FP16_Optimizer(optimizer,
                                   static_loss_scale=args.loss_scale,
                                   dynamic_loss_scale=args.dynamic_loss_scale,
                                   dynamic_loss_args={
                                       'scale_window': args.loss_scale_window,
                                       'min_scale': args.min_scale,
                                       'delayed_shift': args.hysteresis
                                   })

    return optimizer
Example No. 13
def main():
    opt = parse_args()
    opt.sever_name = gethostname()

    # --- CUDA setting ---
    os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu_ids
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # --- Random Seed setting ---
    if opt.random_seed is None:
        opt.random_seed = random.randint(1, 10000)
    random.seed(opt.random_seed)
    np.random.seed(opt.random_seed)
    torch.manual_seed(opt.random_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(opt.random_seed)
        cudnn.deterministic = True
        cudnn.benchmark = opt.cudnn_benchmark

    # --- PATH setting ---
    save_result_dir = Path(__file__).parent / "results" / opt.experiment_name
    opt.save_model_dir = str(save_result_dir / "trained_models")
    opt.save_log_path = str(save_result_dir / "train.log")
    mkdirs(opt.save_model_dir)

    # --- Prepare DataLoader ---
    train_loader, valid_loader = get_train_val_loader(opt)
    opt.src_vocab_size = train_loader.dataset.src_vocab_size
    opt.tgt_vocab_size = train_loader.dataset.tgt_vocab_size

    # --- Prepare Model ---
    model = MultimodalTransformer(
        src_vocab_size=opt.src_vocab_size,
        tgt_vocab_size=opt.tgt_vocab_size,
        max_position_num=opt.max_position_num,
        d_model=opt.d_model,
        head_num=opt.head_num,
        d_k=opt.d_k,
        d_v=opt.d_v,
        d_inner=opt.d_inner,
        layer_num=opt.layer_num,
        dropout=opt.dropout,
        cnn_fine_tuning=opt.cnn_fine_tuning,
        shared_embedding=opt.shared_embedding,
        share_dec_input_output_embed=opt.share_dec_input_output_embed,
        init_weight=opt.init_weight,
        fused_layer_norm=opt.use_fused,
    ).to(device)

    # --- Prepare optimizer and scaler ---
    if opt.use_fused:
        from apex.optimizers import FusedAdam as Adam
    else:
        from torch.optim import Adam
    optimizer = Adam(filter(lambda x: x.requires_grad, model.parameters()),
                     betas=(0.9, 0.98),
                     eps=1e-09,
                     weight_decay=opt.weight_decay)
    scaler = GradScaler(init_scale=65536.0, enabled=opt.use_amp)

    # --- Restart setting ---
    start_cnt = 1
    steps_cnt = 0
    if opt.adapt_prop_MNMT is not None:
        ex_name, epoch_cnt = opt.adapt_prop_MNMT.split(',')
        saved_path = f"{pardir}/results/{ex_name}/trained_models/epoch_{epoch_cnt}.pth"
        saved_dict = torch.load(saved_path,
                                map_location=lambda storage, loc: storage)
        init_dir, init_epoch = saved_dict["settings"].MNMT.split(',')
        init_path = f"{pardir}/MNMT/results/{init_dir}/trained_models/epoch_{init_epoch}.pth"
        init_data = torch.load(init_path,
                               map_location=lambda storage, loc: storage)
        check_arguments(init_data["settings"], opt)
        model.load_state_dict(saved_dict["models"]["MNMT"])
        print(f"[Info]Loading complete ({saved_path})")
    elif opt.adapt_init_MNMT is not None:
        ex_name, epoch_cnt = opt.adapt_init_MNMT.split(',')
        saved_path = f"{Path(__file__).parent}/results/{ex_name}/trained_models/epoch_{epoch_cnt}.pth"
        saved_dict = torch.load(saved_path,
                                map_location=lambda storage, loc: storage)
        check_arguments(saved_dict["settings"], opt)
        model.load_state_dict(saved_dict["model"])
        print(f"[Info]Loading complete ({saved_path})")

    if opt.restart is not None:
        start_cnt = opt.restart + 1
        if opt.restart < 500:
            model_name = f"epoch_{opt.restart}.pth"
        else:
            model_name = f"step_{opt.restart}.pth"
        saved_path = f"{opt.save_model_dir}/{model_name}"
        saved_dict = torch.load(saved_path,
                                map_location=lambda storage, loc: storage)
        check_arguments(saved_dict["settings"], opt)
        model.load_state_dict(saved_dict["model"])
        optimizer.load_state_dict(saved_dict["optimizer"])
        scaler.load_state_dict(saved_dict["scaler"])
        steps_cnt = saved_dict["steps_cnt"]
        print(f"[Info]Loading complete ({saved_path})")

    scheduler = Scheduler(
        optimizer=optimizer,
        init_lr=0.,
        end_lr=opt.end_lr,
        warmup_steps=opt.warmup_steps,
        current_steps=steps_cnt,
    )

    # --- DataParallel setting ---
    gpus = [i for i in range(len(opt.gpu_ids.split(',')))]
    if len(gpus) > 1:
        model = nn.DataParallel(model, device_ids=gpus)

    # --- Prepare trainer and validator ---
    if valid_loader is not None:
        validator = ScoreCalculator(
            model=model,
            data_loader=valid_loader,
            references=valid_loader.dataset.tgt_insts,
            bpe=opt.bpe,
            cp_avg_num=opt.check_point_average,
        )
    else:
        validator = None

    trainer = MNMTTrainer(
        model=model,
        train_loader=train_loader,
        optimizer=optimizer,
        scaler=scaler,
        scheduler=scheduler,
        opt=opt,
        validator=validator,
    )

    # -- Train --
    if opt.max_epoch is not None:
        trainer.train_by_epoch(start_cnt)
    else:
        trainer.train_by_step(start_cnt)
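
The GradScaler created in main() is torch.cuda.amp.GradScaler; the MNMTTrainer that consumes it is not shown. Below is a minimal sketch of the standard AMP step such a trainer would typically perform with that scaler, for illustration only; it is not the trainer's actual code, and the assumption that the model returns a scalar loss is hypothetical.

import torch

def amp_train_step(model, batch, optimizer, scaler):
    """One mixed-precision training step with torch.cuda.amp (illustrative)."""
    optimizer.zero_grad()
    with torch.cuda.amp.autocast(enabled=scaler.is_enabled()):
        loss = model(*batch)        # assumes the model returns a scalar loss
    scaler.scale(loss).backward()   # scale the loss to avoid fp16 gradient underflow
    scaler.step(optimizer)          # unscales grads and skips the step on overflow
    scaler.update()                 # grow/shrink the loss scale for the next step
    return loss.item()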