예제 #1
0
def _get_lr(config, update_steps):
    """
    Generate the learning rate selected by `config.lr_scheduler`.

    Args:
        config: Config object; reads `lr_scheduler`, `lr` and, depending on
            the scheduler, `min_lr`, `warmup_steps`, `decay_start_step`,
            `decay_steps` and `poly_lr_scheduler_power`.
        update_steps (int): Total number of update steps handed to the
            scheduler.

    Returns:
        A float32 `Tensor` built from the schedule for the "isr" and "poly"
        schedulers; otherwise `config.lr` unchanged.
    """
    scheduler = config.lr_scheduler

    if scheduler == "isr":
        isr_schedule = square_root_schedule(
            lr=config.lr,
            update_num=update_steps,
            decay_start_step=config.decay_start_step,
            warmup_steps=config.warmup_steps,
            min_lr=config.min_lr)
        return Tensor(isr_schedule, dtype=mstype.float32)

    if scheduler == "poly":
        poly_schedule = polynomial_decay_scheduler(
            lr=config.lr,
            min_lr=config.min_lr,
            decay_steps=config.decay_steps,
            total_update_num=update_steps,
            warmup_steps=config.warmup_steps,
            power=config.poly_lr_scheduler_power)
        return Tensor(poly_schedule, dtype=mstype.float32)

    # Any other scheduler name falls back to the configured constant rate.
    return config.lr
예제 #2
0
def _build_training_pipeline(config: TransformerConfig,
                             pre_training_dataset=None,
                             fine_tune_dataset=None,
                             test_dataset=None):
    """
    Build training pipeline.

    Creates the network with loss, restores its trainable parameters from
    `config.existed_ckpt` (or re-initialises them when no checkpoint is
    configured), builds the learning rate, optimizer, loss-scale wrapper and
    callbacks, and finally calls `_train`.

    Args:
        config (TransformerConfig): Config of mass model.
        pre_training_dataset (Dataset): Pre-training dataset.
        fine_tune_dataset (Dataset): Fine-tune dataset.
        test_dataset (Dataset): Test dataset.

    Raises:
        ValueError: If a trainable parameter is missing from the checkpoint,
            if neither a pre-training nor a fine-tuning dataset is given, or
            if `config.optimizer` is not `adam`, `lamb` or `momentum`.
    """
    net_with_loss = TransformerNetworkWithLoss(config, is_training=True)
    net_with_loss.init_parameters_data()

    if config.existed_ckpt:
        if config.existed_ckpt.endswith(".npz"):
            weights = np.load(config.existed_ckpt)
        else:
            weights = load_checkpoint(config.existed_ckpt)
        for param in net_with_loss.trainable_params():
            weights_name = param.name
            if weights_name not in weights:
                raise ValueError(
                    f"Param {weights_name} is not found in ckpt file.")

            # Fetch the entry once; an `.npz` archive loads the array from
            # disk again on every indexing operation.
            weight = weights[weights_name]
            if isinstance(weight, Parameter):
                param.default_input = weight.default_input
            elif isinstance(weight, Tensor):
                param.default_input = Tensor(weight.asnumpy(), config.dtype)
            elif isinstance(weight, np.ndarray):
                param.default_input = Tensor(weight, config.dtype)
            else:
                param.default_input = weight
    else:
        # No checkpoint: re-initialise every Tensor-valued parameter by name
        # suffix (gamma -> one_weight, beta/bias -> zero_weight, anything
        # else -> weight_variable). Non-Tensor values are left untouched.
        for param in net_with_loss.trainable_params():
            name = param.name
            value = param.default_input
            if not isinstance(value, Tensor):
                continue
            shape = value.asnumpy().shape
            if name.endswith(".gamma"):
                param.default_input = one_weight(shape)
            elif name.endswith((".beta", ".bias")):
                param.default_input = zero_weight(shape)
            else:
                param.default_input = weight_variable(shape)

    # The pre-training dataset takes precedence when both are supplied.
    dataset = pre_training_dataset if pre_training_dataset is not None \
        else fine_tune_dataset

    if dataset is None:
        raise ValueError(
            "pre-training dataset or fine-tuning dataset must be provided one."
        )

    update_steps = dataset.get_repeat_count() * dataset.get_dataset_size()
    if config.lr_scheduler == "isr":
        lr = Tensor(square_root_schedule(
            lr=config.lr,
            update_num=update_steps,
            decay_start_step=config.decay_start_step,
            warmup_steps=config.warmup_steps,
            min_lr=config.min_lr),
                    dtype=mstype.float32)
    elif config.lr_scheduler == "poly":
        lr = Tensor(polynomial_decay_scheduler(
            lr=config.lr,
            min_lr=config.min_lr,
            decay_steps=config.decay_steps,
            total_update_num=update_steps,
            warmup_steps=config.warmup_steps,
            power=config.poly_lr_scheduler_power),
                    dtype=mstype.float32)
    else:
        lr = config.lr

    optimizer_name = config.optimizer.lower()
    if optimizer_name == "adam":
        optimizer = Adam(net_with_loss.trainable_params(),
                         lr,
                         beta1=0.9,
                         beta2=0.98)
    elif optimizer_name == "lamb":
        # Lamb replaces the schedule computed above with its own
        # BertLearningRate.
        lr = BertLearningRate(decay_steps=12000,
                              learning_rate=config.lr,
                              end_learning_rate=config.min_lr,
                              power=10.0,
                              warmup_steps=config.warmup_steps)
        # Parameters whose name mentions layernorm or bias go into the group
        # without an explicit weight decay; all others use 0.01.
        decay_params = []
        other_params = []
        for param in net_with_loss.trainable_params():
            lowered_name = param.name.lower()
            if 'layernorm' in lowered_name or 'bias' in lowered_name:
                other_params.append(param)
            else:
                decay_params.append(param)
        group_params = [{
            'params': decay_params,
            'weight_decay': 0.01
        }, {
            'params': other_params
        }]

        optimizer = Lamb(group_params, lr, eps=1e-6)
    elif optimizer_name == "momentum":
        optimizer = Momentum(net_with_loss.trainable_params(),
                             lr,
                             momentum=0.9)
    else:
        raise ValueError(
            "optimizer only support `adam`, `lamb` and `momentum` now.")

    # Dynamic loss scale.
    scale_manager = DynamicLossScaleManager(
        init_loss_scale=config.init_loss_scale,
        scale_factor=config.loss_scale_factor,
        scale_window=config.scale_window)
    net_with_grads = TransformerTrainOneStepWithLossScaleCell(
        network=net_with_loss,
        optimizer=optimizer,
        scale_update_cell=scale_manager.get_update_cell())
    net_with_grads.set_train(True)
    model = Model(net_with_grads)
    loss_monitor = LossCallBack(config)
    ckpt_config = CheckpointConfig(
        save_checkpoint_steps=config.save_ckpt_steps,
        keep_checkpoint_max=config.keep_ckpt_max)

    callbacks = [loss_monitor]

    # Checkpoints are saved when RANK_SIZE is unset or 1, or, when RANK_SIZE
    # is greater than 1, only on ranks that are a multiple of 8. Any other
    # RANK_SIZE value (e.g. 0) saves no checkpoints.
    rank_size = os.getenv('RANK_SIZE')
    device_num = None if rank_size is None else int(rank_size)
    if device_num is None or device_num == 1:
        save_ckpt = True
    elif device_num > 1:
        save_ckpt = MultiAscend.get_rank() % 8 == 0
    else:
        save_ckpt = False

    if save_ckpt:
        ckpt_callback = ModelCheckpoint(
            prefix=config.ckpt_prefix,
            directory=os.path.join(config.ckpt_path,
                                   'ckpt_{}'.format(os.getenv('DEVICE_ID'))),
            config=ckpt_config)
        callbacks.append(ckpt_callback)

    print(" | ALL SET, PREPARE TO TRAIN.")
    _train(model=model,
           config=config,
           pre_training_dataset=pre_training_dataset,
           fine_tune_dataset=fine_tune_dataset,
           test_dataset=test_dataset,
           callbacks=callbacks)