Example #1
def test_lamb_error():
    net = Net()
    with pytest.raises(TypeError):
        Lamb(net.get_parameters(), decay_steps=6, warmup_steps=5.0)

    with pytest.raises(TypeError):
        Lamb(net.get_parameters(), decay_steps=1.0)

    with pytest.raises(ValueError):
        Lamb(net.get_parameters(), decay_steps=0)
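These checks reflect the older MindSpore Lamb signature used throughout this page: decay_steps must be a positive int and warmup_steps an int. A minimal sketch of a construction that passes validation, assuming the same toy Net (the function name test_lamb_ok is illustrative):

def test_lamb_ok():
    # A minimal sketch, assuming the toy Net above and the older Lamb
    # signature: a positive integer decay_steps and an integer
    # warmup_steps are accepted without raising.
    net = Net()
    Lamb(net.get_parameters(), decay_steps=6, warmup_steps=5)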
Example #2
def test_edge_case():
    """ test_edge_case """
    context.set_auto_parallel_context(enable_parallel_optimizer=True)
    net = Net()
    with pytest.raises(RuntimeError):
        context.set_auto_parallel_context(parallel_mode="stand_alone")
        Lamb(net.trainable_params(), learning_rate=0.1)
    with pytest.raises(RuntimeError):
        Adam(net.trainable_params(), learning_rate=0.1)
    with pytest.raises(RuntimeError):
        context.set_auto_parallel_context(device_num=16)
        Lamb(net.trainable_params(), learning_rate=0.1)
def _get_optimizer(args_opt, network):
    """get bert optimizer, support Lamb, Momentum, AdamWeightDecay."""
    if cfg.optimizer == 'Lamb':
        lr_schedule = BertLearningRate(learning_rate=cfg.Lamb.learning_rate,
                                       end_learning_rate=cfg.Lamb.end_learning_rate,
                                       warmup_steps=cfg.Lamb.warmup_steps,
                                       decay_steps=args_opt.train_steps,
                                       power=cfg.Lamb.power)
        params = network.trainable_params()
        decay_params = list(filter(cfg.Lamb.decay_filter, params))
        other_params = list(filter(lambda x: not cfg.Lamb.decay_filter(x), params))
        group_params = [{'params': decay_params, 'weight_decay': cfg.Lamb.weight_decay},
                        {'params': other_params},
                        {'order_params': params}]
        optimizer = Lamb(group_params, learning_rate=lr_schedule, eps=cfg.Lamb.eps)
    elif cfg.optimizer == 'Momentum':
        optimizer = Momentum(network.trainable_params(), learning_rate=cfg.Momentum.learning_rate,
                             momentum=cfg.Momentum.momentum)
    elif cfg.optimizer == 'AdamWeightDecay':
        lr_schedule = BertLearningRate(learning_rate=cfg.AdamWeightDecay.learning_rate,
                                       end_learning_rate=cfg.AdamWeightDecay.end_learning_rate,
                                       warmup_steps=cfg.AdamWeightDecay.warmup_steps,
                                       decay_steps=args_opt.train_steps,
                                       power=cfg.AdamWeightDecay.power)
        params = network.trainable_params()
        decay_params = list(filter(cfg.AdamWeightDecay.decay_filter, params))
        other_params = list(filter(lambda x: not cfg.AdamWeightDecay.decay_filter(x), params))
        group_params = [{'params': decay_params, 'weight_decay': cfg.AdamWeightDecay.weight_decay},
                        {'params': other_params, 'weight_decay': 0.0},
                        {'order_params': params}]
        if args_opt.enable_lossscale == "true" and args_opt.device_target == 'GPU':
            optimizer = AdamWeightDecayForBert(group_params, learning_rate=lr_schedule, eps=cfg.AdamWeightDecay.eps)
        elif context.get_context("mode") == context.PYNATIVE_MODE and args_opt.device_target == 'GPU':
            optimizer = AdamWeightDecayOp(group_params, learning_rate=lr_schedule, eps=cfg.AdamWeightDecay.eps)
        else:
            optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=cfg.AdamWeightDecay.eps)
    elif cfg.optimizer == "Thor":
        from src.utils import get_bert_thor_lr, get_bert_thor_damping
        lr = get_bert_thor_lr(cfg.Thor.lr_max, cfg.Thor.lr_min, cfg.Thor.lr_power, cfg.Thor.lr_total_steps)
        damping = get_bert_thor_damping(cfg.Thor.damping_max, cfg.Thor.damping_min, cfg.Thor.damping_power,
                                        cfg.Thor.damping_total_steps)
        split_indices = None
        if bert_net_cfg.num_hidden_layers == 12:
            if bert_net_cfg.use_relative_positions:
                split_indices = [29, 58, 87, 116, 145, 174, 203, 217]
            else:
                split_indices = [28, 55, 82, 109, 136, 163, 190, 205]
        elif bert_net_cfg.num_hidden_layers == 24:
            if bert_net_cfg.use_relative_positions:
                split_indices = [30, 90, 150, 210, 270, 330, 390, 421]
            else:
                split_indices = [38, 93, 148, 203, 258, 313, 368, 397]
        optimizer = THOR(network, lr, damping, cfg.Thor.momentum,
                         cfg.Thor.weight_decay, cfg.Thor.loss_scale, cfg.batch_size,
                         decay_filter=lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(),
                         split_indices=split_indices)
    else:
        raise ValueError("Don't support optimizer {}, only support [Lamb, Momentum, AdamWeightDecay, Thor]".
                         format(cfg.optimizer))
    return optimizer
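cfg.Lamb.decay_filter comes from the configuration; in the BERT examples further down this page it is a predicate that exempts LayerNorm weights and biases from weight decay. A minimal sketch of such a filter (the exact config value here is an assumption):

# Assumed shape of cfg.Lamb.decay_filter: apply weight decay to every
# parameter except LayerNorm weights and biases.
decay_filter = lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower()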
Example #4
def test_bert_tdt():
    """test bert tdt"""
    context.set_context(mode=context.GRAPH_MODE)
    context.set_context(device_target="Ascend")
    context.set_context(enable_task_sink=True)
    context.set_context(enable_loop_sink=True)
    context.set_context(enable_mem_reuse=True)
    parallel_callback = ModelCallback()
    ds = me_de_train_dataset(cfg.bert_config.batch_size)
    config = cfg.bert_config
    netwithloss = BertNetworkWithLoss(config, True)
    optimizer = Lamb(netwithloss.trainable_params(),
                     decay_steps=cfg.decay_steps,
                     start_learning_rate=cfg.start_learning_rate,
                     end_learning_rate=cfg.end_learning_rate,
                     power=cfg.power,
                     warmup_steps=cfg.num_warmup_steps,
                     decay_filter=lambda x: False)
    netwithgrads = BertTrainOneStepCell(netwithloss, optimizer=optimizer)
    netwithgrads.set_train(True)
    model = Model(netwithgrads)
    config_ck = CheckpointConfig(
        save_checkpoint_steps=cfg.save_checkpoint_steps,
        keep_checkpoint_max=cfg.keep_checkpoint_max)
    ckpoint_cb = ModelCheckpoint(prefix=cfg.checkpoint_prefix,
                                 config=config_ck)
    model.train(ds.get_repeat_count(),
                ds,
                callbacks=[parallel_callback, ckpoint_cb],
                dataset_sink_mode=False)
Example #5
def _get_optimizer(args_opt, network):
    """get bert optimizer, support Lamb, Momentum, AdamWeightDecay."""
    if cfg.optimizer == 'Lamb':
        lr_schedule = BertLearningRate(
            learning_rate=cfg.Lamb.learning_rate,
            end_learning_rate=cfg.Lamb.end_learning_rate,
            warmup_steps=cfg.Lamb.warmup_steps,
            decay_steps=args_opt.train_steps,
            power=cfg.Lamb.power)
        params = network.trainable_params()
        decay_params = list(filter(cfg.Lamb.decay_filter, params))
        other_params = list(
            filter(lambda x: not cfg.Lamb.decay_filter(x), params))
        group_params = [{
            'params': decay_params,
            'weight_decay': cfg.Lamb.weight_decay
        }, {
            'params': other_params
        }, {
            'order_params': params
        }]
        optimizer = Lamb(group_params,
                         learning_rate=lr_schedule,
                         eps=cfg.Lamb.eps)
    elif cfg.optimizer == 'Momentum':
        optimizer = Momentum(network.trainable_params(),
                             learning_rate=cfg.Momentum.learning_rate,
                             momentum=cfg.Momentum.momentum)
    elif cfg.optimizer == 'AdamWeightDecay':
        lr_schedule = BertLearningRate(
            learning_rate=cfg.AdamWeightDecay.learning_rate,
            end_learning_rate=cfg.AdamWeightDecay.end_learning_rate,
            warmup_steps=cfg.AdamWeightDecay.warmup_steps,
            decay_steps=args_opt.train_steps,
            power=cfg.AdamWeightDecay.power)
        params = network.trainable_params()
        decay_params = list(filter(cfg.AdamWeightDecay.decay_filter, params))
        other_params = list(
            filter(lambda x: not cfg.AdamWeightDecay.decay_filter(x), params))
        group_params = [{
            'params': decay_params,
            'weight_decay': cfg.AdamWeightDecay.weight_decay
        }, {
            'params': other_params,
            'weight_decay': 0.0
        }, {
            'order_params': params
        }]

        optimizer = AdamWeightDecay(group_params,
                                    learning_rate=lr_schedule,
                                    eps=cfg.AdamWeightDecay.eps)
    else:
        raise ValueError(
            "Unsupported optimizer {}; only [Lamb, Momentum, AdamWeightDecay] are supported."
            .format(cfg.optimizer))
    return optimizer
def test_bert_tdt():
    """test bert tdt"""
    np.random.seed(0)
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", reserve_class_name_in_scope=False)
    context.set_context(enable_graph_kernel=True)
    ds = me_de_train_dataset()
    config = get_config(version='large', batch_size=16)
    netwithloss = BertNetworkWithLoss(config, True)
    optimizer = Lamb(netwithloss.trainable_params(), decay_steps=ds.get_dataset_size()*ds.get_repeat_count(),
                     start_learning_rate=5e-5, end_learning_rate=1e-9,
                     power=10.0, warmup_steps=0, weight_decay=0.01)
    scale_window = 3
    scale_manager = DynamicLossScaleManager(262144, 2, scale_window)
    netwithgrads = BertTrainOneStepWithLossScaleCell(netwithloss, optimizer=optimizer,
                                                     scale_update_cell=scale_manager.get_update_cell())
    netwithgrads.set_train(True)
    model = Model(netwithgrads)
    callback = ModelCallback()
    params = netwithloss.trainable_params()
    for param in params:
        param.init_data()
        value = param.default_input
        name = param.name
        if isinstance(value, Tensor):
            if name.split('.')[-1] in ['weight']:
                if name.split('.')[-3] in ['cls2']:
                    logger.info("***************** BERT param name is 1 {}".format(name))
                    param.default_input = weight_variable(value.asnumpy().shape)
                else:
                    logger.info("***************** BERT param name is 2 {}".format(name))
                    tempshape = value.asnumpy().shape
                    shape = (tempshape[1], tempshape[0])
                    weight_value = weight_variable(shape).asnumpy()
                    param.default_input = Tensor(np.transpose(weight_value, [1, 0]))
            else:
                logger.info("***************** BERT param name is 3 {}".format(name))
                param.default_input = weight_variable(value.asnumpy().shape)
    model.train(1, ds, callbacks=callback, dataset_sink_mode=False)

    # assertion occurs while the loss value, overflow state or loss_scale value is wrong
    loss_value = np.array(callback.loss_list)
    expect_loss_value = [12.559319, 12.333815, 12.339806, 12.350235, 12.343947, 12.830965, 12.375336, 12.973715,
                         12.57929, 12.7766905]
    error = loss_value - expect_loss_value
    print("loss value: {}".format(loss_value))
    print("error value: {}".format(error))
    assert np.allclose(loss_value, expect_loss_value, 0, 0.0005)

    overflow = np.array(callback.overflow_list)
    expect_overflow = [True, True, True, True, False, False, False, True, False, False]
    print("overflow: {}".format(overflow))
    assert (overflow == expect_overflow).all()

    loss_scale = np.array(callback.lossscale_list)
    expect_loss_scale = [131072.0, 65536.0, 32768.0, 16384.0, 16384.0, 16384.0, 32768.0, 16384.0, 16384.0, 16384.0]
    print("loss scale: {}".format(loss_scale))
    assert np.allclose(loss_scale, expect_loss_scale, 0, 0)
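The expected loss-scale trace follows DynamicLossScaleManager's update rule: halve the scale on overflow, double it after scale_window consecutive clean steps. A plain-Python sketch of that rule under the settings above (init 262144, factor 2, window 3), which matches the expect_loss_scale trace when the post-update scale is recorded each step:

def update_scale(scale, overflow, clean_steps, factor=2, window=3):
    # Sketch only: shrink the scale on overflow, grow it after
    # `window` consecutive overflow-free steps, tracking the count.
    if overflow:
        return scale / factor, 0
    clean_steps += 1
    if clean_steps == window:
        return scale * factor, 0
    return scale, clean_steps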
Example #7
def test_lamb_compile():
    """ test_Lamb_compile """
    inputs = Tensor(np.ones([1, 64]).astype(np.float32))
    label = Tensor(np.zeros([1, 10]).astype(np.float32))
    net = Net()
    net.set_train()
    loss = nn.SoftmaxCrossEntropyWithLogits()
    optimizer = Lamb(net.trainable_params(), decay_steps=10)

    net_with_loss = WithLossCell(net, loss)
    train_network = TrainOneStepCell(net_with_loss, optimizer)
    _executor.compile(train_network, inputs, label)
Example #8
def test_compile_fp16_overflow():
    inputs = Tensor(np.ones([16, 16]).astype(np.float32))
    label = Tensor(np.zeros([16, 16]).astype(np.float32))
    scaling_sens = Tensor(np.full((1), np.finfo(np.float32).max), dtype=mstype.float32)
    net = NetFP16(16, 16)

    loss = MSELoss()
    optimizer = Lamb(net.trainable_params(), decay_steps=10, warmup_steps=5)
    net_with_loss = WithLossCell(net, loss)
    train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer)
    train_network.set_train()
    output = train_network(inputs, label, scaling_sens)
    print("the result is ", output)
Example #9
def test_lamb_compile_dynamic_lr():
    """ test_Lamb_compile """
    inputs = Tensor(np.ones([1, 64]).astype(np.float32))
    label = Tensor(np.zeros([1, 10]).astype(np.float32))
    net = Net()
    net.set_train()
    loss = nn.SoftmaxCrossEntropyWithLogits()
    warmup_decay_lr = LambLearningRate(0.01, 0.0001, 10, 20, 1.0)
    optimizer = Lamb(net.trainable_params(), warmup_decay_lr)

    net_with_loss = WithLossCell(net, loss)
    train_network = TrainOneStepCell(net_with_loss, optimizer)
    _executor.compile(train_network, inputs, label)
def test_lamb_compile():
    """ test_Lamb_compile """
    context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2, enable_parallel_optimizer=True)
    inputs = Tensor(np.ones([32, 128]).astype(np.float32))
    label = Tensor(np.zeros([32, 768]).astype(np.float32))
    net = Net()
    net.set_train()
    loss = nn.SoftmaxCrossEntropyWithLogits()
    optimizer = Lamb(net.trainable_params(), learning_rate=0.1)

    net_with_loss = WithLossCell(net, loss)
    train_network = TrainOneStepCell(net_with_loss, optimizer)
    _executor.compile(train_network, inputs, label)
    context.reset_auto_parallel_context()
Example #11
def _get_optimizer(config, network, lr):
    """get gnmt optimizer, support Adam, Lamb, Momentum."""
    if config.optimizer.lower() == "adam":
        optimizer = Adam(network.trainable_params(), lr, beta1=0.9, beta2=0.98)
    elif config.optimizer.lower() == "lamb":
        optimizer = Lamb(network.trainable_params(),
                         learning_rate=lr,
                         eps=1e-6)
    elif config.optimizer.lower() == "momentum":
        optimizer = Momentum(network.trainable_params(), lr, momentum=0.9)
    else:
        raise ValueError(f"optimizer only support `adam` and `momentum` now.")

    return optimizer
Example #12
def _get_optimizer(config, network, lr):
    """get gnmt optimizer, support Adam, Lamb, Momentum."""
    if config.optimizer.lower() == "adam":
        optimizer = Adam(network.trainable_params(), lr, beta1=0.9, beta2=0.98)
    elif config.optimizer.lower() == "lamb":
        optimizer = Lamb(network.trainable_params(), decay_steps=12000,
                         start_learning_rate=config.lr, end_learning_rate=config.min_lr,
                         power=10.0, warmup_steps=config.warmup_steps, weight_decay=0.01,
                         eps=1e-6)
    elif config.optimizer.lower() == "momentum":
        optimizer = Momentum(network.trainable_params(), lr, momentum=0.9)
    else:
        raise ValueError(f"optimizer only support `adam` and `momentum` now.")

    return optimizer
Example #13
def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoint_path="", epoch_num=1):
    """ do train """
    if load_checkpoint_path == "":
        raise ValueError("Pretrain model missed, finetune task must load pretrain model!")
    steps_per_epoch = dataset.get_dataset_size()
    # optimizer
    if optimizer_cfg.optimizer == 'AdamWeightDecay':
        lr_schedule = BertLearningRate(learning_rate=optimizer_cfg.AdamWeightDecay.learning_rate,
                                       end_learning_rate=optimizer_cfg.AdamWeightDecay.end_learning_rate,
                                       warmup_steps=int(steps_per_epoch * epoch_num * 0.1),
                                       decay_steps=steps_per_epoch * epoch_num,
                                       power=optimizer_cfg.AdamWeightDecay.power)
        params = network.trainable_params()
        decay_params = list(filter(optimizer_cfg.AdamWeightDecay.decay_filter, params))
        other_params = list(filter(lambda x: not optimizer_cfg.AdamWeightDecay.decay_filter(x), params))
        group_params = [{'params': decay_params, 'weight_decay': optimizer_cfg.AdamWeightDecay.weight_decay},
                        {'params': other_params, 'weight_decay': 0.0}]
        optimizer = AdamWeightDecay(group_params, lr_schedule, eps=optimizer_cfg.AdamWeightDecay.eps)
    elif optimizer_cfg.optimizer == 'Lamb':
        lr_schedule = BertLearningRate(learning_rate=optimizer_cfg.Lamb.learning_rate,
                                       end_learning_rate=optimizer_cfg.Lamb.end_learning_rate,
                                       warmup_steps=int(steps_per_epoch * epoch_num * 0.1),
                                       decay_steps=steps_per_epoch * epoch_num,
                                       power=optimizer_cfg.Lamb.power)
        optimizer = Lamb(network.trainable_params(), learning_rate=lr_schedule)
    elif optimizer_cfg.optimizer == 'Momentum':
        optimizer = Momentum(network.trainable_params(), learning_rate=optimizer_cfg.Momentum.learning_rate,
                             momentum=optimizer_cfg.Momentum.momentum)
    else:
        raise Exception("Optimizer not supported. support: [AdamWeightDecay, Lamb, Momentum]")

    # load checkpoint into network
    ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=1)
    ckpoint_cb = ModelCheckpoint(prefix="ner",
                                 directory=None if save_checkpoint_path == "" else save_checkpoint_path,
                                 config=ckpt_config)
    param_dict = load_checkpoint(load_checkpoint_path)
    load_param_into_net(network, param_dict)

    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2**32, scale_factor=2, scale_window=1000)
    netwithgrads = BertFinetuneCell(network, optimizer=optimizer, scale_update_cell=update_cell)
    model = Model(netwithgrads)
    callbacks = [TimeMonitor(dataset.get_dataset_size()), LossCallBack(dataset.get_dataset_size()), ckpoint_cb]
    train_begin = time.time()
    model.train(epoch_num, dataset, callbacks=callbacks)
    train_end = time.time()
    print("latency: {:.6f} s".format(train_end - train_begin))
Example #14
def test_loss_scale_fp16_overflow():
    inputs = Tensor(np.ones([16, 16]).astype(np.float32))
    label = Tensor(np.zeros([16, 16]).astype(np.float32))
    net = NetFP16(16, 16)
    net.set_train()

    loss = MSELoss()
    optimizer = Lamb(net.trainable_params(), learning_rate=0.01)
    net_with_loss = WithLossCell(net, loss)
    net_with_loss.set_grad()
    train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer,
                                                  scale_sense=Tensor(np.full((1), np.finfo(np.float32).max),
                                                                     dtype=mstype.float32))
    output_1 = train_network(inputs, label)
    output_2 = train_network(inputs, label)
    assert output_1[0].asnumpy() == output_2[0].asnumpy()
    assert output_1[1].asnumpy() == output_2[1].asnumpy() == True
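Both steps return identical results because scale_sense starts at the float32 maximum: every step overflows, the parameter update is skipped, and the overflow flag stays True. A hedged sketch of the contrasting setup, reusing the same helpers, where a modest initial scale should let updates go through:

# Assumption: with a small constant scale the scaled gradients no
# longer overflow in fp16, so the two losses should diverge once the
# parameters are actually updated.
scale = Tensor(np.full((1), 1024.0), dtype=mstype.float32)
train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=scale)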
Example #15
def _get_optimizer(config, network, lr):
    """get mass optimizer, support Adam, Lamb, Momentum."""
    if config.optimizer.lower() == "adam":
        optimizer = Adam(network.trainable_params(), lr, beta1=0.9, beta2=0.98)
    elif config.optimizer.lower() == "lamb":
        lr = BertLearningRate(decay_steps=12000, learning_rate=config.lr, end_learning_rate=config.min_lr,
                              power=10.0, warmup_steps=config.warmup_steps)
        decay_params = list(filter(lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(),
                                   network.trainable_params()))
        other_params = list(filter(lambda x: 'layernorm' in x.name.lower() or 'bias' in x.name.lower(),
                                   network.trainable_params()))
        group_params = [{'params': decay_params, 'weight_decay': 0.01},
                        {'params': other_params}]

        optimizer = Lamb(group_params, lr, eps=1e-6)
    elif config.optimizer.lower() == "momentum":
        optimizer = Momentum(network.trainable_params(), lr, momentum=0.9)
    else:
        raise ValueError(f"optimizer only support `adam` and `momentum` now.")
    return optimizer
def test_bert_tdt():
    """test bert tdt"""
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", reserve_class_name_in_scope=False)
    context.set_context(enable_task_sink=True)
    context.set_context(enable_loop_sink=True)
    context.set_context(enable_mem_reuse=True)
    parallel_callback = ModelCallback()
    ds = me_de_train_dataset()
    version = os.getenv('VERSION', 'large')
    batch_size = int(os.getenv('BATCH_SIZE', '16'))
    config = get_config(version=version, batch_size=batch_size)
    netwithloss = BertNetworkWithLoss(config, True)
    optimizer = Lamb(netwithloss.trainable_params(), decay_steps=10000, start_learning_rate=1e-4,
                     end_learning_rate=0.0, power=10.0, warmup_steps=0, decay_filter=lambda x: False)
    netwithgrads = BertTrainOneStepCell(netwithloss, optimizer=optimizer)
    netwithgrads.set_train(True)
    model = Model(netwithgrads)
    params = netwithloss.trainable_params()
    for param in params:
        value = param.default_input
        name = param.name
        if isinstance(value, Tensor):
            if name.split('.')[-1] in ['weight']:
                if name.split('.')[-3] in ['cls2']:
                    logger.info("***************** BERT param name is 1 {}".format(name))
                    param.default_input = weight_variable(value.asnumpy().shape)
                else:
                    logger.info("***************** BERT param name is 2 {}".format(name))
                    tempshape = value.asnumpy().shape
                    shape = (tempshape[1], tempshape[0])
                    weight_value = weight_variable(shape).asnumpy()
                    param.default_input = Tensor(np.transpose(weight_value, [1, 0]))
            else:
                logger.info("***************** BERT param name is 3 {}".format(name))
                param.default_input = weight_variable(value.asnumpy().shape)
    model.train(ds.get_repeat_count(), ds, callbacks=parallel_callback, dataset_sink_mode=False)
    loss_value = np.array(parallel_callback.loss_list)
    expect_out = [12.191790, 11.739655, 11.523477, 11.320723, 11.113152, 11.203759, 10.841681, 10.826849,
                  10.616718, 10.486609]
    logger.info("expected loss value output: {}".format(expect_out))
    assert allclose(loss_value, expect_out, 0.001, 0.001)
Example #17
def test_lamb_group():
    """ test_Lamb_group_compile """
    inputs = Tensor(np.ones([1, 64]).astype(np.float32))
    label = Tensor(np.zeros([1, 10]).astype(np.float32))
    net = Net()
    net.set_train()
    loss = nn.SoftmaxCrossEntropyWithLogits()
    warmup_decay_lr = LambLearningRate(0.01, 0.0001, 10, 20, 1.0)
    all_params = net.trainable_params()
    group_params = [{
        'params': [all_params[0]],
        'lr': warmup_decay_lr,
        'weight_decay': 0.9
    }, {
        'params': [all_params[1]]
    }]
    optimizer = Lamb(group_params, 0.02)

    net_with_loss = WithLossCell(net, loss)
    train_network = TrainOneStepCell(net_with_loss, optimizer)
    _executor.compile(train_network, inputs, label)
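Per-group keys override the constructor defaults: the first group follows warmup_decay_lr with weight_decay=0.9, while the second falls back to the scalar learning rate 0.02 and the optimizer's default weight decay. As the BERT examples on this page do, an 'order_params' entry can be appended to pin the parameter order; a minimal sketch of the same grouping:

# Sketch only: the same two groups as above plus an 'order_params'
# entry, following the pattern used in the BERT examples here.
group_params = [{'params': [all_params[0]], 'lr': warmup_decay_lr, 'weight_decay': 0.9},
                {'params': [all_params[1]]},
                {'order_params': all_params}]
optimizer = Lamb(group_params, 0.02)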
Example #18
def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoint_path=""):
    """ do train """
    if load_checkpoint_path == "":
        raise ValueError("Pretrain model missed, finetune task must load pretrain model!")
    steps_per_epoch = dataset.get_dataset_size()
    epoch_num = dataset.get_repeat_count()
    # optimizer
    if optimizer_cfg.optimizer == 'AdamWeightDecayDynamicLR':
        optimizer = AdamWeightDecayDynamicLR(network.trainable_params(),
                                             decay_steps=steps_per_epoch * epoch_num,
                                             learning_rate=optimizer_cfg.AdamWeightDecayDynamicLR.learning_rate,
                                             end_learning_rate=optimizer_cfg.AdamWeightDecayDynamicLR.end_learning_rate,
                                             power=optimizer_cfg.AdamWeightDecayDynamicLR.power,
                                             warmup_steps=int(steps_per_epoch * epoch_num * 0.1),
                                             weight_decay=optimizer_cfg.AdamWeightDecayDynamicLR.weight_decay,
                                             eps=optimizer_cfg.AdamWeightDecayDynamicLR.eps)
    elif optimizer_cfg.optimizer == 'Lamb':
        optimizer = Lamb(network.trainable_params(), decay_steps=steps_per_epoch * epoch_num,
                         start_learning_rate=optimizer_cfg.Lamb.start_learning_rate,
                         end_learning_rate=optimizer_cfg.Lamb.end_learning_rate,
                         power=optimizer_cfg.Lamb.power, weight_decay=optimizer_cfg.Lamb.weight_decay,
                         warmup_steps=int(steps_per_epoch * epoch_num * 0.1),
                         decay_filter=optimizer_cfg.Lamb.decay_filter)
    elif optimizer_cfg.optimizer == 'Momentum':
        optimizer = Momentum(network.trainable_params(), learning_rate=optimizer_cfg.Momentum.learning_rate,
                             momentum=optimizer_cfg.Momentum.momentum)
    else:
        raise Exception("Optimizer not supported. support: [AdamWeightDecayDynamicLR, Lamb, Momentum]")

    # load checkpoint into network
    ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=1)
    ckpoint_cb = ModelCheckpoint(prefix="classifier", directory=save_checkpoint_path, config=ckpt_config)
    param_dict = load_checkpoint(load_checkpoint_path)
    load_param_into_net(network, param_dict)

    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2**32, scale_factor=2, scale_window=1000)
    netwithgrads = BertFinetuneCell(network, optimizer=optimizer, scale_update_cell=update_cell)
    model = Model(netwithgrads)
    callbacks = [TimeMonitor(dataset.get_dataset_size()), LossCallBack(), ckpoint_cb]
    model.train(epoch_num, dataset, callbacks=callbacks)
Example #19
from ..pipeline.gradient.check_training import pipeline_for_check_model_loss_for_case_by_case_config
from ..utils.model_util import Linreg
from ..utils.model_util import SquaredLoss

network = Linreg(2)
num_epochs = 1000

verification_set = [('Linreg', {
    'block': {
        'model':
        network,
        'loss':
        SquaredLoss(),
        'opt':
        Lamb(network.trainable_params(),
             decay_steps=num_epochs,
             warmup_steps=10,
             weight_decay=0.01),
        'num_epochs':
        num_epochs,
        'loss_upper_bound':
        0.3,
    },
    'desc_inputs': {
        'true_params': ([2, -3.4], 4.2),
        'num_samples': 100,
        'batch_size': 20,
    }
})]


@mindspore_test(pipeline_for_check_model_loss_for_case_by_case_config)
Example #20
def test_bert_precision(enable_graph_kernel=False):
    """test bert precision"""
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", reserve_class_name_in_scope=False)
    if enable_graph_kernel:
        context.set_context(enable_graph_kernel=True)
    data_set, new_repeat_count, _ = me_de_train_dataset()
    version = os.getenv('VERSION', 'large')
    config = get_config(version=version)
    netwithloss = BertNetworkWithLoss(config, True)
    lr = BertLearningRate(decay_steps=data_set.get_dataset_size() * new_repeat_count,
                          learning_rate=5e-5, end_learning_rate=1e-9,
                          power=10.0, warmup_steps=0)
    decay_filter = lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower()
    no_decay_filter = lambda x: 'layernorm' in x.name.lower() or 'bias' in x.name.lower()
    decay_params = list(filter(decay_filter, netwithloss.trainable_params()))
    other_params = list(filter(no_decay_filter, netwithloss.trainable_params()))
    group_params = [{'params': decay_params, 'weight_decay': 0.01},
                    {'params': other_params},
                    {'order_params': netwithloss.trainable_params()}]
    optimizer = Lamb(group_params, lr)
    scale_window = 3
    scale_manager = DynamicLossScaleManager(2 ** 16, 2, scale_window)
    netwithgrads = BertTrainOneStepWithLossScaleCell(netwithloss, optimizer=optimizer,
                                                     scale_update_cell=scale_manager.get_update_cell())
    netwithgrads.set_train(True)
    model = Model(netwithgrads)
    callback = ModelCallback()
    params = netwithloss.trainable_params()
    for param in params:
        value = param.data
        name = param.name
        if isinstance(value, Tensor):
            if name.split('.')[-1] in ['weight']:
                if name.split('.')[-3] in ['cls2']:
                    logger.info("***************** BERT param name is 1 {}".format(name))
                    param.set_data(weight_variable(value.asnumpy().shape))
                else:
                    logger.info("***************** BERT param name is 2 {}".format(name))
                    tempshape = value.asnumpy().shape
                    shape = (tempshape[1], tempshape[0])
                    weight_value = weight_variable(shape).asnumpy()
                    param.set_data(Tensor(np.transpose(weight_value, [1, 0])))
            else:
                logger.info("***************** BERT param name is 3 {}".format(name))
                param.set_data(weight_variable(value.asnumpy().shape))
    model.train(new_repeat_count, data_set, callbacks=callback, dataset_sink_mode=False)

    # assertion occurs while the loss value, overflow state or loss_scale value is wrong
    loss_value = np.array(callback.loss_list)

    if enable_graph_kernel:
        expect_loss_value = [12.206627, 11.840489, 11.798470, 11.796345, 11.790964, 12.366766, 11.971539, 12.576565,
                             12.185522, 12.386192]
    else:
        assert np.allclose(loss_value[0], 12.2066, 0, 0.0005)
        expect_loss_value = [12.206587, 11.966410, 11.965916, 11.975922, 11.970262, 12.608881, 12.174048, 12.840656,
                             12.407923, 12.631133]
    print("loss value: {}".format(loss_value))
    assert np.allclose(loss_value, expect_loss_value, 0, 0.0005)

    overflow = np.array(callback.overflow_list)
    expect_overflow = [False, False, False, True, False, False, False, True, False, False]
    print("overflow: {}".format(overflow))
    assert (overflow == expect_overflow).all()

    loss_scale = np.array(callback.lossscale_list)
    expect_loss_scale = [65536.0, 65536.0, 131072.0, 65536.0, 65536.0, 65536.0, 131072.0, 65536.0, 65536.0, 65536.0]
    print("loss scale: {}".format(loss_scale))
    assert np.allclose(loss_scale, expect_loss_scale, 0, 0)
def test_bert_performance():
    """test bert performance"""
    context.set_context(mode=context.GRAPH_MODE,
                        device_target="Ascend",
                        reserve_class_name_in_scope=False)
    data_set, new_repeat_count, sink_size = me_de_train_dataset(sink_mode=True)
    version = os.getenv('VERSION', 'large')
    config = get_config(version=version)
    netwithloss = BertNetworkWithLoss(config, True)

    lr = BertLearningRate(decay_steps=sink_size * new_repeat_count,
                          learning_rate=5e-5,
                          end_learning_rate=1e-9,
                          power=10.0,
                          warmup_steps=0)
    decay_filter = lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower()
    no_decay_filter = lambda x: 'layernorm' in x.name.lower() or 'bias' in x.name.lower()
    decay_params = list(filter(decay_filter, netwithloss.trainable_params()))
    other_params = list(filter(no_decay_filter,
                               netwithloss.trainable_params()))
    group_params = [{
        'params': decay_params,
        'weight_decay': 0.01
    }, {
        'params': other_params
    }, {
        'order_params': netwithloss.trainable_params()
    }]
    optimizer = Lamb(group_params, lr)

    scale_window = 3
    scale_manager = DynamicLossScaleManager(2**16, 2, scale_window)
    netwithgrads = BertTrainOneStepWithLossScaleCell(
        netwithloss,
        optimizer=optimizer,
        scale_update_cell=scale_manager.get_update_cell())
    netwithgrads.set_train(True)
    model = Model(netwithgrads)
    callback = ModelCallback()
    params = netwithloss.trainable_params()
    for param in params:
        value = param.data
        name = param.name
        if isinstance(value, Tensor) and not value.has_init:
            if name.split('.')[-1] in ['weight']:
                if name.split('.')[-3] in ['cls2']:
                    logger.info(
                        "***************** BERT param name is 1 {}".format(
                            name))
                    param.set_data(weight_variable(value.asnumpy().shape))
                else:
                    logger.info(
                        "***************** BERT param name is 2 {}".format(
                            name))
                    tempshape = value.asnumpy().shape
                    shape = (tempshape[1], tempshape[0])
                    weight_value = weight_variable(shape).asnumpy()
                    param.set_data(Tensor(np.transpose(weight_value, [1, 0])))
            else:
                logger.info(
                    "***************** BERT param name is 3 {}".format(name))
                param.set_data(weight_variable(value.asnumpy().shape))
    time_monitor_callback = TimeMonitor(sink_size)
    model.train(new_repeat_count,
                data_set,
                callbacks=[time_monitor_callback, callback],
                dataset_sink_mode=True,
                sink_size=sink_size)

    # assertion occurs while the loss value, overflow state or loss_scale value is wrong
    loss_value = np.array(callback.loss_list)
    expect_loss_value = [11.3660, 11.3265, 11.3264]
    print("loss value: {}".format(loss_value))
    assert np.allclose(loss_value, expect_loss_value, 0, 0.0005)

    overflow = np.array(callback.overflow_list)
    expect_overflow = [True, True, True]
    print("overflow: {}".format(overflow))
    assert (overflow == expect_overflow).all()

    loss_scale = np.array(callback.lossscale_list)
    expect_loss_scale = [65536.0, 65536.0, 65536.0]
    print("loss scale: {}".format(loss_scale))
    assert np.allclose(loss_scale, expect_loss_scale, 0, 0)

    epoch_mseconds = np.array(time_monitor_callback.epoch_mseconds_list)[2]
    expect_epoch_mseconds = 1400
    print("epoch mseconds: {}".format(epoch_mseconds))
    assert epoch_mseconds <= expect_epoch_mseconds + 5

    per_step_mseconds = np.array(
        time_monitor_callback.per_step_mseconds_list)[2]
    expect_per_step_mseconds = 14
    print("per step mseconds: {}".format(per_step_mseconds))
    assert per_step_mseconds <= expect_per_step_mseconds + 1
Example #22
def run_pretrain():
    """pre-train bert_clue"""
    parser = argparse.ArgumentParser(description='bert pre_training')
    parser.add_argument("--distribute",
                        type=str,
                        default="false",
                        help="Run distribute, default is false.")
    parser.add_argument("--epoch_size",
                        type=int,
                        default="1",
                        help="Epoch size, default is 1.")
    parser.add_argument("--device_id",
                        type=int,
                        default=0,
                        help="Device id, default is 0.")
    parser.add_argument("--device_num",
                        type=int,
                        default=1,
                        help="Use device nums, default is 1.")
    parser.add_argument("--enable_task_sink",
                        type=str,
                        default="true",
                        help="Enable task sink, default is true.")
    parser.add_argument("--enable_loop_sink",
                        type=str,
                        default="true",
                        help="Enable loop sink, default is true.")
    parser.add_argument("--enable_mem_reuse",
                        type=str,
                        default="true",
                        help="Enable mem reuse, default is true.")
    parser.add_argument("--enable_save_ckpt",
                        type=str,
                        default="true",
                        help="Enable save checkpoint, default is true.")
    parser.add_argument("--enable_lossscale",
                        type=str,
                        default="true",
                        help="Use lossscale or not, default is not.")
    parser.add_argument("--do_shuffle",
                        type=str,
                        default="true",
                        help="Enable shuffle for dataset, default is true.")
    parser.add_argument("--enable_data_sink",
                        type=str,
                        default="true",
                        help="Enable data sink, default is true.")
    parser.add_argument("--data_sink_steps",
                        type=int,
                        default="1",
                        help="Sink steps for each epoch, default is 1.")
    parser.add_argument("--checkpoint_path",
                        type=str,
                        default="",
                        help="Checkpoint file path")
    parser.add_argument("--save_checkpoint_steps",
                        type=int,
                        default=1000,
                        help="Save checkpoint steps, "
                        "default is 1000.")
    parser.add_argument("--save_checkpoint_num",
                        type=int,
                        default=1,
                        help="Save checkpoint numbers, default is 1.")
    parser.add_argument("--data_dir",
                        type=str,
                        default="",
                        help="Data path, it is better to use absolute path")
    parser.add_argument("--schema_dir",
                        type=str,
                        default="",
                        help="Schema path, it is better to use absolute path")

    args_opt = parser.parse_args()
    context.set_context(mode=context.GRAPH_MODE,
                        device_target="Ascend",
                        device_id=args_opt.device_id)
    context.set_context(enable_task_sink=(args_opt.enable_task_sink == "true"),
                        enable_loop_sink=(args_opt.enable_loop_sink == "true"),
                        enable_mem_reuse=(args_opt.enable_mem_reuse == "true"))
    context.set_context(reserve_class_name_in_scope=False)

    if args_opt.distribute == "true":
        device_num = args_opt.device_num
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(
            parallel_mode=ParallelMode.DATA_PARALLEL,
            mirror_mean=True,
            device_num=device_num)
        D.init()
        rank = args_opt.device_id % device_num
    else:
        rank = 0
        device_num = 1

    ds = create_bert_dataset(args_opt.epoch_size, device_num, rank,
                             args_opt.do_shuffle, args_opt.enable_data_sink,
                             args_opt.data_sink_steps, args_opt.data_dir,
                             args_opt.schema_dir)

    netwithloss = BertNetworkWithLoss(bert_net_cfg, True)

    if cfg.optimizer == 'Lamb':
        optimizer = Lamb(netwithloss.trainable_params(),
                         decay_steps=ds.get_dataset_size() *
                         ds.get_repeat_count(),
                         start_learning_rate=cfg.Lamb.start_learning_rate,
                         end_learning_rate=cfg.Lamb.end_learning_rate,
                         power=cfg.Lamb.power,
                         warmup_steps=cfg.Lamb.warmup_steps,
                         weight_decay=cfg.Lamb.weight_decay,
                         eps=cfg.Lamb.eps)
    elif cfg.optimizer == 'Momentum':
        optimizer = Momentum(netwithloss.trainable_params(),
                             learning_rate=cfg.Momentum.learning_rate,
                             momentum=cfg.Momentum.momentum)
    elif cfg.optimizer == 'AdamWeightDecayDynamicLR':
        optimizer = AdamWeightDecayDynamicLR(
            netwithloss.trainable_params(),
            decay_steps=ds.get_dataset_size() * ds.get_repeat_count(),
            learning_rate=cfg.AdamWeightDecayDynamicLR.learning_rate,
            end_learning_rate=cfg.AdamWeightDecayDynamicLR.end_learning_rate,
            power=cfg.AdamWeightDecayDynamicLR.power,
            weight_decay=cfg.AdamWeightDecayDynamicLR.weight_decay,
            eps=cfg.AdamWeightDecayDynamicLR.eps)
    else:
        raise ValueError(
            "Unsupported optimizer {}; only [Lamb, Momentum, AdamWeightDecayDynamicLR] are supported."
            .format(cfg.optimizer))
    callback = [LossCallBack()]
    if args_opt.enable_save_ckpt == "true":
        config_ck = CheckpointConfig(
            save_checkpoint_steps=args_opt.save_checkpoint_steps,
            keep_checkpoint_max=args_opt.save_checkpoint_num)
        ckpoint_cb = ModelCheckpoint(prefix='checkpoint_bert',
                                     config=config_ck)
        callback.append(ckpoint_cb)

    if args_opt.checkpoint_path:
        param_dict = load_checkpoint(args_opt.checkpoint_path)
        load_param_into_net(netwithloss, param_dict)

    if args_opt.enable_lossscale == "true":
        update_cell = DynamicLossScaleUpdateCell(
            loss_scale_value=cfg.loss_scale_value,
            scale_factor=cfg.scale_factor,
            scale_window=cfg.scale_window)
        netwithgrads = BertTrainOneStepWithLossScaleCell(
            netwithloss, optimizer=optimizer, scale_update_cell=update_cell)
    else:
        netwithgrads = BertTrainOneStepCell(netwithloss, optimizer=optimizer)

    model = Model(netwithgrads)
    model.train(ds.get_repeat_count(),
                ds,
                callbacks=callback,
                dataset_sink_mode=(args_opt.enable_data_sink == "true"))
Example #23
def run_pretrain():
    """pre-train bert_clue"""
    parser = argparse.ArgumentParser(description='bert pre_training')
    parser.add_argument(
        '--device_target',
        type=str,
        default='Ascend',
        choices=['Ascend', 'GPU'],
        help='device where the code will be implemented. (Default: Ascend)')
    parser.add_argument("--distribute",
                        type=str,
                        default="false",
                        help="Run distribute, default is false.")
    parser.add_argument("--epoch_size",
                        type=int,
                        default="1",
                        help="Epoch size, default is 1.")
    parser.add_argument("--device_id",
                        type=int,
                        default=0,
                        help="Device id, default is 0.")
    parser.add_argument("--device_num",
                        type=int,
                        default=1,
                        help="Use device nums, default is 1.")
    parser.add_argument("--enable_save_ckpt",
                        type=str,
                        default="true",
                        help="Enable save checkpoint, default is true.")
    parser.add_argument("--enable_lossscale",
                        type=str,
                        default="true",
                        help="Use lossscale or not, default is not.")
    parser.add_argument("--do_shuffle",
                        type=str,
                        default="true",
                        help="Enable shuffle for dataset, default is true.")
    parser.add_argument("--enable_data_sink",
                        type=str,
                        default="true",
                        help="Enable data sink, default is true.")
    parser.add_argument("--data_sink_steps",
                        type=int,
                        default="1",
                        help="Sink steps for each epoch, default is 1.")
    parser.add_argument("--save_checkpoint_path",
                        type=str,
                        default="",
                        help="Save checkpoint path")
    parser.add_argument("--load_checkpoint_path",
                        type=str,
                        default="",
                        help="Load checkpoint file path")
    parser.add_argument("--save_checkpoint_steps",
                        type=int,
                        default=1000,
                        help="Save checkpoint steps, "
                        "default is 1000.")
    parser.add_argument("--train_steps",
                        type=int,
                        default=-1,
                        help="Training Steps, default is -1, "
                        "meaning run all steps according to epoch number.")
    parser.add_argument("--save_checkpoint_num",
                        type=int,
                        default=1,
                        help="Save checkpoint numbers, default is 1.")
    parser.add_argument("--data_dir",
                        type=str,
                        default="",
                        help="Data path, it is better to use absolute path")
    parser.add_argument("--schema_dir",
                        type=str,
                        default="",
                        help="Schema path, it is better to use absolute path")

    args_opt = parser.parse_args()
    context.set_context(mode=context.GRAPH_MODE,
                        device_target=args_opt.device_target,
                        device_id=args_opt.device_id)
    context.set_context(reserve_class_name_in_scope=False)
    context.set_context(variable_memory_max_size="30GB")
    ckpt_save_dir = args_opt.save_checkpoint_path
    if args_opt.distribute == "true":
        if args_opt.device_target == 'Ascend':
            D.init('hccl')
            device_num = args_opt.device_num
            rank = args_opt.device_id % device_num
        else:
            D.init('nccl')
            device_num = D.get_group_size()
            rank = D.get_rank()
            ckpt_save_dir = args_opt.save_checkpoint_path + 'ckpt_' + str(
                rank) + '/'

        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(
            parallel_mode=ParallelMode.DATA_PARALLEL,
            mirror_mean=True,
            device_num=device_num)
        from mindspore.parallel._auto_parallel_context import auto_parallel_context
        if bert_net_cfg.num_hidden_layers == 12:
            if bert_net_cfg.use_relative_positions:
                auto_parallel_context().set_all_reduce_fusion_split_indices(
                    [29, 58, 87, 116, 145, 174, 203, 217])
            else:
                auto_parallel_context().set_all_reduce_fusion_split_indices(
                    [28, 55, 82, 109, 136, 163, 190, 205])
        elif bert_net_cfg.num_hidden_layers == 24:
            if bert_net_cfg.use_relative_positions:
                auto_parallel_context().set_all_reduce_fusion_split_indices(
                    [30, 90, 150, 210, 270, 330, 390, 421])
            else:
                auto_parallel_context().set_all_reduce_fusion_split_indices(
                    [38, 93, 148, 203, 258, 313, 368, 397])
    else:
        rank = 0
        device_num = 1

    if args_opt.device_target == 'GPU' and bert_net_cfg.compute_type != mstype.float32:
        logger.warning('GPU only supports fp32 for now; running with fp32.')
        bert_net_cfg.compute_type = mstype.float32

    ds, new_repeat_count = create_bert_dataset(args_opt.epoch_size, device_num,
                                               rank, args_opt.do_shuffle,
                                               args_opt.enable_data_sink,
                                               args_opt.data_sink_steps,
                                               args_opt.data_dir,
                                               args_opt.schema_dir)
    data_epoch_size = new_repeat_count // args_opt.epoch_size  # Epoch nums in one dataset.
    if args_opt.train_steps > 0:
        new_repeat_count = min(
            new_repeat_count, args_opt.train_steps // args_opt.data_sink_steps)
    netwithloss = BertNetworkWithLoss(bert_net_cfg, True)

    if cfg.optimizer == 'Lamb':
        optimizer = Lamb(netwithloss.trainable_params(),
                         decay_steps=ds.get_dataset_size() * new_repeat_count,
                         start_learning_rate=cfg.Lamb.start_learning_rate,
                         end_learning_rate=cfg.Lamb.end_learning_rate,
                         power=cfg.Lamb.power,
                         warmup_steps=cfg.Lamb.warmup_steps,
                         weight_decay=cfg.Lamb.weight_decay,
                         eps=cfg.Lamb.eps)
    elif cfg.optimizer == 'Momentum':
        optimizer = Momentum(netwithloss.trainable_params(),
                             learning_rate=cfg.Momentum.learning_rate,
                             momentum=cfg.Momentum.momentum)
    elif cfg.optimizer == 'AdamWeightDecayDynamicLR':
        optimizer = AdamWeightDecayDynamicLR(
            netwithloss.trainable_params(),
            decay_steps=ds.get_dataset_size() * new_repeat_count,
            learning_rate=cfg.AdamWeightDecayDynamicLR.learning_rate,
            end_learning_rate=cfg.AdamWeightDecayDynamicLR.end_learning_rate,
            power=cfg.AdamWeightDecayDynamicLR.power,
            weight_decay=cfg.AdamWeightDecayDynamicLR.weight_decay,
            eps=cfg.AdamWeightDecayDynamicLR.eps,
            warmup_steps=cfg.AdamWeightDecayDynamicLR.warmup_steps)
    else:
        raise ValueError(
            "Unsupported optimizer {}; only [Lamb, Momentum, AdamWeightDecayDynamicLR] are supported."
            .format(cfg.optimizer))
    callback = [
        TimeMonitor(ds.get_dataset_size()),
        LossCallBack(data_epoch_size)
    ]
    if args_opt.enable_save_ckpt == "true":
        config_ck = CheckpointConfig(
            save_checkpoint_steps=args_opt.save_checkpoint_steps,
            keep_checkpoint_max=args_opt.save_checkpoint_num)
        ckpoint_cb = ModelCheckpoint(prefix='checkpoint_bert',
                                     directory=ckpt_save_dir,
                                     config=config_ck)
        callback.append(ckpoint_cb)

    if args_opt.load_checkpoint_path:
        param_dict = load_checkpoint(args_opt.load_checkpoint_path)
        load_param_into_net(netwithloss, param_dict)

    if args_opt.enable_lossscale == "true":
        update_cell = DynamicLossScaleUpdateCell(
            loss_scale_value=cfg.loss_scale_value,
            scale_factor=cfg.scale_factor,
            scale_window=cfg.scale_window)
        netwithgrads = BertTrainOneStepWithLossScaleCell(
            netwithloss, optimizer=optimizer, scale_update_cell=update_cell)
    else:
        netwithgrads = BertTrainOneStepCell(netwithloss, optimizer=optimizer)

    model = Model(netwithgrads)
    model.train(new_repeat_count,
                ds,
                callbacks=callback,
                dataset_sink_mode=(args_opt.enable_data_sink == "true"))
def test_bert_precision():
    """test bert precision"""
    context.set_context(mode=context.GRAPH_MODE,
                        device_target="Ascend",
                        reserve_class_name_in_scope=False)
    ds, new_repeat_count, _ = me_de_train_dataset()
    version = os.getenv('VERSION', 'large')
    batch_size = 16
    config = get_config(version=version, batch_size=batch_size)
    netwithloss = BertNetworkWithLoss(config, True)
    lr = BertLearningRate(decay_steps=ds.get_dataset_size() * new_repeat_count,
                          learning_rate=5e-5,
                          end_learning_rate=1e-9,
                          power=10.0,
                          warmup_steps=0)
    decay_filter = lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower()
    no_decay_filter = lambda x: 'layernorm' in x.name.lower() or 'bias' in x.name.lower()
    decay_params = list(filter(decay_filter, netwithloss.trainable_params()))
    other_params = list(filter(no_decay_filter,
                               netwithloss.trainable_params()))
    group_params = [{
        'params': decay_params,
        'weight_decay': 0.01
    }, {
        'params': other_params
    }, {
        'order_params': netwithloss.trainable_params()
    }]
    optimizer = Lamb(group_params, lr)
    scale_window = 3
    scale_manager = DynamicLossScaleManager(2**16, 2, scale_window)
    netwithgrads = BertTrainOneStepWithLossScaleCell(
        netwithloss,
        optimizer=optimizer,
        scale_update_cell=scale_manager.get_update_cell())
    netwithgrads.set_train(True)
    model = Model(netwithgrads)
    callback = ModelCallback()
    params = netwithloss.trainable_params()
    for param in params:
        param.init_data()
        value = param.default_input
        name = param.name
        if isinstance(value, Tensor):
            if name.split('.')[-1] in ['weight']:
                if name.split('.')[-3] in ['cls2']:
                    logger.info(
                        "***************** BERT param name is 1 {}".format(
                            name))
                    param.default_input = weight_variable(
                        value.asnumpy().shape)
                else:
                    logger.info(
                        "***************** BERT param name is 2 {}".format(
                            name))
                    tempshape = value.asnumpy().shape
                    shape = (tempshape[1], tempshape[0])
                    weight_value = weight_variable(shape).asnumpy()
                    param.default_input = Tensor(
                        np.transpose(weight_value, [1, 0]))
            else:
                logger.info(
                    "***************** BERT param name is 3 {}".format(name))
                param.default_input = weight_variable(value.asnumpy().shape)
    model.train(new_repeat_count,
                ds,
                callbacks=callback,
                dataset_sink_mode=False)

    # assertion occurs while the loss value, overflow state or loss_scale value is wrong
    loss_value = np.array(callback.loss_list)
    assert np.allclose(loss_value[0], 12.206575, 0, 0.000001)

    expect_loss_value = [
        12.206575, 11.865044, 11.828129, 11.826707, 11.82108, 12.407423,
        12.005459, 12.621225, 12.222903, 12.427446
    ]
    print("loss value: {}".format(loss_value))
    assert np.allclose(loss_value, expect_loss_value, 0, 0.0005)

    overflow = np.array(callback.overflow_list)
    expect_overflow = [
        False, False, False, True, False, False, False, True, False, False
    ]
    print("overflow: {}".format(overflow))
    assert (overflow == expect_overflow).all()

    loss_scale = np.array(callback.lossscale_list)
    expect_loss_scale = [
        65536.0, 65536.0, 131072.0, 65536.0, 65536.0, 65536.0, 131072.0,
        65536.0, 65536.0, 65536.0
    ]
    print("loss scale: {}".format(loss_scale))
    assert np.allclose(loss_scale, expect_loss_scale, 0, 0)
def test_bert_performance():
    """test bert performance"""
    context.set_context(mode=context.GRAPH_MODE,
                        device_target="Ascend",
                        reserve_class_name_in_scope=False)
    ds, new_repeat_count = me_de_train_dataset(sink_mode=True)
    version = os.getenv('VERSION', 'large')
    batch_size = 16
    config = get_config(version=version, batch_size=batch_size)
    netwithloss = BertNetworkWithLoss(config, True)
    optimizer = Lamb(netwithloss.trainable_params(),
                     decay_steps=ds.get_dataset_size() * new_repeat_count,
                     start_learning_rate=5e-5,
                     end_learning_rate=1e-9,
                     power=10.0,
                     warmup_steps=0,
                     weight_decay=0.01)
    scale_window = 3
    scale_manager = DynamicLossScaleManager(2**16, 2, scale_window)
    netwithgrads = BertTrainOneStepWithLossScaleCell(
        netwithloss,
        optimizer=optimizer,
        scale_update_cell=scale_manager.get_update_cell())
    netwithgrads.set_train(True)
    model = Model(netwithgrads)
    callback = ModelCallback()
    params = netwithloss.trainable_params()
    for param in params:
        param.init_data()
        value = param.default_input
        name = param.name
        if isinstance(value, Tensor):
            if name.split('.')[-1] in ['weight']:
                if name.split('.')[-3] in ['cls2']:
                    logger.info(
                        "***************** BERT param name is 1 {}".format(
                            name))
                    param.default_input = weight_variable(
                        value.asnumpy().shape)
                else:
                    logger.info(
                        "***************** BERT param name is 2 {}".format(
                            name))
                    tempshape = value.asnumpy().shape
                    shape = (tempshape[1], tempshape[0])
                    weight_value = weight_variable(shape).asnumpy()
                    param.default_input = Tensor(
                        np.transpose(weight_value, [1, 0]))
            else:
                logger.info(
                    "***************** BERT param name is 3 {}".format(name))
                param.default_input = weight_variable(value.asnumpy().shape)
    time_monitor_callback = TimeMonitor(ds.get_dataset_size())
    model.train(new_repeat_count,
                ds,
                callbacks=[time_monitor_callback, callback],
                dataset_sink_mode=True)

    # assertions fail if the loss value, overflow state, or loss scale value is wrong
    loss_value = np.array(callback.loss_list)
    expect_loss_value = [10.235566, 10.207392, 10.206976]
    print("loss value: {}".format(loss_value))
    assert np.allclose(loss_value, expect_loss_value, 0, 0.0005)

    overflow = np.array(callback.overflow_list)
    expect_overflow = [True, True, True]
    print("overflow: {}".format(overflow))
    assert (overflow == expect_overflow).all()

    loss_scale = np.array(callback.lossscale_list)
    expect_loss_scale = [262144.0, 262144.0, 262144.0]
    print("loss scale: {}".format(loss_scale))
    assert np.allclose(loss_scale, expect_loss_scale, 0, 0)

    epoch_mseconds = np.array(time_monitor_callback.epoch_mseconds_list)[2]
    expect_epoch_mseconds = 1600
    print("epoch mseconds: {}".format(epoch_mseconds))
    assert epoch_mseconds <= expect_epoch_mseconds + 5

    per_step_mseconds = np.array(
        time_monitor_callback.per_step_mseconds_list)[2]
    expect_per_step_mseconds = 16
    print("per step mseconds: {}".format(per_step_mseconds))
    assert per_step_mseconds <= expect_per_step_mseconds + 1
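
The timing assertions above read the third epoch from TimeMonitor's lists. A framework-agnostic sketch of what such a callback measures (SimpleEpochTimer is a hypothetical stand-in, assuming wall-clock timing around each epoch):

import time

class SimpleEpochTimer:
    """Hypothetical stand-in for a TimeMonitor-style callback."""
    def __init__(self, steps_per_epoch):
        self.steps_per_epoch = steps_per_epoch
        self.epoch_mseconds_list = []
        self.per_step_mseconds_list = []

    def epoch_begin(self):
        self._start = time.time()

    def epoch_end(self):
        msecs = (time.time() - self._start) * 1000
        self.epoch_mseconds_list.append(msecs)
        self.per_step_mseconds_list.append(msecs / self.steps_per_epoch)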


def test_bert_tdt():
    """test bert tdt"""
    context.set_context(mode=context.GRAPH_MODE,
                        device_target="Ascend",
                        reserve_class_name_in_scope=False)
    ds = me_de_train_dataset()
    version = os.getenv('VERSION', 'large')
    batch_size = int(os.getenv('BATCH_SIZE', '16'))
    config = get_config(version=version, batch_size=batch_size)
    netwithloss = BertNetworkWithLoss(config, True)
    optimizer = Lamb(netwithloss.trainable_params(),
                     decay_steps=ds.get_dataset_size() * ds.get_repeat_count(),
                     start_learning_rate=5e-5,
                     end_learning_rate=1e-9,
                     power=10.0,
                     warmup_steps=0,
                     weight_decay=0.01)
    scale_window = 3
    scale_manager = DynamicLossScaleManager(2**16, 2, scale_window)
    netwithgrads = BertTrainOneStepWithLossScaleCell(
        netwithloss,
        optimizer=optimizer,
        scale_update_cell=scale_manager.get_update_cell())
    netwithgrads.set_train(True)
    model = Model(netwithgrads)
    callback = ModelCallback()
    params = netwithloss.trainable_params()
    for param in params:
        param.init_data()
        value = param.default_input
        name = param.name
        if isinstance(value, Tensor):
            if name.split('.')[-1] in ['weight']:
                if name.split('.')[-3] in ['cls2']:
                    logger.info(
                        "***************** BERT param name is 1 {}".format(
                            name))
                    param.default_input = weight_variable(
                        value.asnumpy().shape)
                else:
                    logger.info(
                        "***************** BERT param name is 2 {}".format(
                            name))
                    tempshape = value.asnumpy().shape
                    shape = (tempshape[1], tempshape[0])
                    weight_value = weight_variable(shape).asnumpy()
                    param.default_input = Tensor(
                        np.transpose(weight_value, [1, 0]))
            else:
                logger.info(
                    "***************** BERT param name is 3 {}".format(name))
                param.default_input = weight_variable(value.asnumpy().shape)
    model.train(ds.get_repeat_count(),
                ds,
                callbacks=callback,
                dataset_sink_mode=False)

    # assertions fail if the loss value, overflow state, or loss scale value is wrong
    loss_value = np.array(callback.loss_list)
    expect_loss_value = [
        12.207198, 11.980881, 11.984844, 11.879381, 11.832978, 12.411333,
        12.009284, 12.621277, 12.223178, 12.427385
    ]
    print("loss value: {}".format(loss_value))
    assert np.allclose(loss_value, expect_loss_value, 0, 0.0005)

    overflow = np.array(callback.overflow_list)
    expect_overflow = [
        True, True, False, False, False, True, False, False, False, True
    ]
    print("overflow: {}".format(overflow))
    assert (overflow == expect_overflow).all()

    loss_scale = np.array(callback.lossscale_list)
    expect_loss_scale = [
        32768.0, 16384.0, 16384.0, 16384.0, 32768.0, 16384.0, 16384.0, 16384.0,
        32768.0, 16384.0
    ]
    print("loss scale: {}".format(loss_scale))
    assert np.allclose(loss_scale, expect_loss_scale, 0, 0)
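
The re-initialization loop in these tests samples a fresh weight in the transposed shape and transposes it back, which swaps fan-in and fan-out at init time. A NumPy-only sketch, with a hypothetical stand-in for the weight_variable helper:

import numpy as np

def weight_variable_np(shape, stddev=0.02):
    # Hypothetical stand-in for the weight_variable helper used in the tests.
    return np.random.normal(0.0, stddev, size=shape).astype(np.float32)

tempshape = (768, 1024)                      # illustrative (out, in) shape
shape = (tempshape[1], tempshape[0])         # sample in the transposed shape
new_weight = np.transpose(weight_variable_np(shape), [1, 0])
assert new_weight.shape == tempshape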


def run_pretrain(args_opt):
    """pre-train bert"""
    global device_id
    global device_num
    global rank_id
    global job_id
    args_opt.device_id = device_id
    args_opt.device_num = device_num
    sync_dataset(args_opt.data_url)

    context.set_context(mode=context.GRAPH_MODE,
                        device_target=args_opt.device_target,
                        device_id=args_opt.device_id)
    context.set_context(reserve_class_name_in_scope=False)
    context.set_context(variable_memory_max_size="30GB")
    ckpt_save_dir = args_opt.save_checkpoint_path
    if args_opt.distribute == "true":
        if args_opt.device_target == 'Ascend':
            D.init('hccl')
            device_num = args_opt.device_num
            rank = args_opt.device_id % device_num
        else:
            D.init('nccl')
            device_num = D.get_group_size()
            rank = D.get_rank()
            ckpt_save_dir = args_opt.save_checkpoint_path + 'ckpt_' + str(
                rank) + '/'

        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(
            parallel_mode=ParallelMode.DATA_PARALLEL,
            mirror_mean=True,
            device_num=device_num)
        from mindspore.parallel._auto_parallel_context import auto_parallel_context
        if bert_net_cfg.num_hidden_layers == 12:
            if bert_net_cfg.use_relative_positions:
                auto_parallel_context().set_all_reduce_fusion_split_indices(
                    [29, 58, 87, 116, 145, 174, 203, 217])
            else:
                auto_parallel_context().set_all_reduce_fusion_split_indices(
                    [28, 55, 82, 109, 136, 163, 190, 205])
        elif bert_net_cfg.num_hidden_layers == 24:
            if bert_net_cfg.use_relative_positions:
                auto_parallel_context().set_all_reduce_fusion_split_indices(
                    [30, 90, 150, 210, 270, 330, 390, 421])
            else:
                auto_parallel_context().set_all_reduce_fusion_split_indices(
                    [38, 93, 148, 203, 258, 313, 368, 397])
    else:
        rank = 0
        device_num = 1

    if args_opt.device_target == 'GPU' and bert_net_cfg.compute_type != mstype.float32:
        logger.warning('GPU only supports fp32 for now; running with fp32.')
        bert_net_cfg.compute_type = mstype.float32

    ds, new_repeat_count = create_bert_dataset(args_opt.epoch_size, device_num,
                                               rank, args_opt.do_shuffle,
                                               args_opt.enable_data_sink,
                                               args_opt.data_sink_steps,
                                               args_opt.data_dir,
                                               args_opt.schema_dir)
    if args_opt.train_steps > 0:
        new_repeat_count = min(
            new_repeat_count, args_opt.train_steps // args_opt.data_sink_steps)
    netwithloss = BertNetworkWithLoss(bert_net_cfg, True)

    if cfg.optimizer == 'Lamb':
        optimizer = Lamb(netwithloss.trainable_params(),
                         decay_steps=ds.get_dataset_size() * new_repeat_count,
                         start_learning_rate=cfg.Lamb.start_learning_rate,
                         end_learning_rate=cfg.Lamb.end_learning_rate,
                         power=cfg.Lamb.power,
                         warmup_steps=cfg.Lamb.warmup_steps,
                         weight_decay=cfg.Lamb.weight_decay,
                         eps=cfg.Lamb.eps)
    elif cfg.optimizer == 'Momentum':
        optimizer = Momentum(netwithloss.trainable_params(),
                             learning_rate=cfg.Momentum.learning_rate,
                             momentum=cfg.Momentum.momentum)
    elif cfg.optimizer == 'AdamWeightDecayDynamicLR':
        optimizer = AdamWeightDecayDynamicLR(
            netwithloss.trainable_params(),
            decay_steps=ds.get_dataset_size() * new_repeat_count,
            learning_rate=cfg.AdamWeightDecayDynamicLR.learning_rate,
            end_learning_rate=cfg.AdamWeightDecayDynamicLR.end_learning_rate,
            power=cfg.AdamWeightDecayDynamicLR.power,
            weight_decay=cfg.AdamWeightDecayDynamicLR.weight_decay,
            eps=cfg.AdamWeightDecayDynamicLR.eps,
            warmup_steps=cfg.AdamWeightDecayDynamicLR.warmup_steps)
    else:
        raise ValueError(
            "Unsupported optimizer {}; only [Lamb, Momentum, AdamWeightDecayDynamicLR] "
            "are supported.".format(cfg.optimizer))
    callback = [TimeMonitor(ds.get_dataset_size()), LossCallBack()]
    print("Enable save checkpoint: ", args_opt.enable_save_ckpt)
    print("Rank ID: ", rank_id)
    if args_opt.enable_save_ckpt == "true" and rank_id % device_num == 0:
        print("Enable save checkpoint")
        config_ck = CheckpointConfig(
            save_checkpoint_steps=args_opt.save_checkpoint_steps,
            keep_checkpoint_max=args_opt.save_checkpoint_num)
        ckpoint_cb = ModelCheckpoint(prefix='checkpoint_bert',
                                     directory=ckpt_save_dir,
                                     config=config_ck)
        callback.append(ckpoint_cb)

    if args_opt.load_checkpoint_path:
        param_dict = load_checkpoint(args_opt.load_checkpoint_path)
        load_param_into_net(netwithloss, param_dict)

    if args_opt.enable_lossscale == "true":
        update_cell = DynamicLossScaleUpdateCell(
            loss_scale_value=cfg.loss_scale_value,
            scale_factor=cfg.scale_factor,
            scale_window=cfg.scale_window)
        netwithgrads = BertTrainOneStepWithLossScaleCell(
            netwithloss, optimizer=optimizer, scale_update_cell=update_cell)
    else:
        netwithgrads = BertTrainOneStepCell(netwithloss, optimizer=optimizer)

    model = Model(netwithgrads)
    model.train(new_repeat_count,
                ds,
                callbacks=callback,
                dataset_sink_mode=(args_opt.enable_data_sink == "true"))
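
The Lamb configuration above combines a warmup phase with polynomial decay from start_learning_rate down to end_learning_rate. A minimal sketch of that schedule, assuming the usual linear-warmup / polynomial-decay formula (warmup_poly_lr is illustrative, not the library API):

def warmup_poly_lr(step, start_lr, end_lr, warmup_steps, decay_steps, power):
    """Linear warmup, then polynomial decay from start_lr to end_lr."""
    if warmup_steps > 0 and step < warmup_steps:
        return start_lr * step / warmup_steps
    progress = min(step, decay_steps) / decay_steps
    return (start_lr - end_lr) * (1.0 - progress) ** power + end_lr

assert abs(warmup_poly_lr(0, 5e-5, 1e-9, 0, 10000, 10.0) - 5e-5) < 1e-12
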
Example #28
def _build_training_pipeline(config: TransformerConfig,
                             pre_training_dataset=None,
                             fine_tune_dataset=None,
                             test_dataset=None):
    """
    Build training pipeline.

    Args:
        config (TransformerConfig): Config of mass model.
        pre_training_dataset (Dataset): Pre-training dataset.
        fine_tune_dataset (Dataset): Fine-tune dataset.
        test_dataset (Dataset): Test dataset.
    """
    net_with_loss = TransformerNetworkWithLoss(config, is_training=True)
    net_with_loss.init_parameters_data()

    if config.existed_ckpt:
        if config.existed_ckpt.endswith(".npz"):
            weights = np.load(config.existed_ckpt)
        else:
            weights = load_checkpoint(config.existed_ckpt)
        for param in net_with_loss.trainable_params():
            weights_name = param.name
            if weights_name not in weights:
                raise ValueError(
                    f"Param {weights_name} is not found in ckpt file.")

            if isinstance(weights[weights_name], Parameter):
                param.default_input = weights[weights_name].default_input
            elif isinstance(weights[weights_name], Tensor):
                param.default_input = Tensor(weights[weights_name].asnumpy(),
                                             config.dtype)
            elif isinstance(weights[weights_name], np.ndarray):
                param.default_input = Tensor(weights[weights_name],
                                             config.dtype)
            else:
                param.default_input = weights[weights_name]
    else:
        for param in net_with_loss.trainable_params():
            name = param.name
            value = param.default_input
            if isinstance(value, Tensor):
                if name.endswith(".gamma"):
                    param.default_input = one_weight(value.asnumpy().shape)
                elif name.endswith(".beta") or name.endswith(".bias"):
                    param.default_input = zero_weight(value.asnumpy().shape)
                else:
                    param.default_input = weight_variable(
                        value.asnumpy().shape)

    dataset = pre_training_dataset if pre_training_dataset is not None \
        else fine_tune_dataset

    if dataset is None:
        raise ValueError(
            "At least one of pre-training dataset or fine-tuning dataset "
            "must be provided.")

    update_steps = dataset.get_repeat_count() * dataset.get_dataset_size()
    if config.lr_scheduler == "isr":
        lr = Tensor(square_root_schedule(
            lr=config.lr,
            update_num=update_steps,
            decay_start_step=config.decay_start_step,
            warmup_steps=config.warmup_steps,
            min_lr=config.min_lr),
                    dtype=mstype.float32)
    elif config.lr_scheduler == "poly":
        lr = Tensor(polynomial_decay_scheduler(
            lr=config.lr,
            min_lr=config.min_lr,
            decay_steps=config.decay_steps,
            total_update_num=update_steps,
            warmup_steps=config.warmup_steps,
            power=config.poly_lr_scheduler_power),
                    dtype=mstype.float32)
    else:
        lr = config.lr

    if config.optimizer.lower() == "adam":
        optimizer = Adam(net_with_loss.trainable_params(),
                         lr,
                         beta1=0.9,
                         beta2=0.98)
    elif config.optimizer.lower() == "lamb":
        lr = BertLearningRate(decay_steps=12000,
                              learning_rate=config.lr,
                              end_learning_rate=config.min_lr,
                              power=10.0,
                              warmup_steps=config.warmup_steps)
        decay_params = list(
            filter(
                lambda x: 'layernorm' not in x.name.lower()
                and 'bias' not in x.name.lower(),
                net_with_loss.trainable_params()))
        other_params = list(
            filter(
                lambda x: 'layernorm' in x.name.lower()
                or 'bias' in x.name.lower(),
                net_with_loss.trainable_params()))
        group_params = [{
            'params': decay_params,
            'weight_decay': 0.01
        }, {
            'params': other_params
        }]

        optimizer = Lamb(group_params, lr, eps=1e-6)
    elif config.optimizer.lower() == "momentum":
        optimizer = Momentum(net_with_loss.trainable_params(),
                             lr,
                             momentum=0.9)
    else:
        raise ValueError(f"optimizer only support `adam` and `momentum` now.")

    # Dynamic loss scale.
    scale_manager = DynamicLossScaleManager(
        init_loss_scale=config.init_loss_scale,
        scale_factor=config.loss_scale_factor,
        scale_window=config.scale_window)
    net_with_grads = TransformerTrainOneStepWithLossScaleCell(
        network=net_with_loss,
        optimizer=optimizer,
        scale_update_cell=scale_manager.get_update_cell())
    net_with_grads.set_train(True)
    model = Model(net_with_grads)
    loss_monitor = LossCallBack(config)
    ckpt_config = CheckpointConfig(
        save_checkpoint_steps=config.save_ckpt_steps,
        keep_checkpoint_max=config.keep_ckpt_max)

    rank_size = os.getenv('RANK_SIZE')
    callbacks = [loss_monitor]
    if rank_size is not None and int(
            rank_size) > 1 and MultiAscend.get_rank() % 8 == 0:
        ckpt_callback = ModelCheckpoint(
            prefix=config.ckpt_prefix,
            directory=os.path.join(config.ckpt_path,
                                   'ckpt_{}'.format(os.getenv('DEVICE_ID'))),
            config=ckpt_config)
        callbacks.append(ckpt_callback)

    if rank_size is None or int(rank_size) == 1:
        ckpt_callback = ModelCheckpoint(
            prefix=config.ckpt_prefix,
            directory=os.path.join(config.ckpt_path,
                                   'ckpt_{}'.format(os.getenv('DEVICE_ID'))),
            config=ckpt_config)
        callbacks.append(ckpt_callback)

    print(f" | ALL SET, PREPARE TO TRAIN.")
    _train(model=model,
           config=config,
           pre_training_dataset=pre_training_dataset,
           fine_tune_dataset=fine_tune_dataset,
           test_dataset=test_dataset,
           callbacks=callbacks)
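
The grouping above excludes layernorm and bias parameters from weight decay. A small self-contained sketch of the same split (a hedged illustration: FakeParam and split_decay_groups are hypothetical, standing in for any objects with a .name attribute):

class FakeParam:
    def __init__(self, name):
        self.name = name

def split_decay_groups(params, weight_decay=0.01):
    # layernorm and bias parameters get no weight decay
    no_decay = lambda p: 'layernorm' in p.name.lower() or 'bias' in p.name.lower()
    decay = [p for p in params if not no_decay(p)]
    other = [p for p in params if no_decay(p)]
    return [{'params': decay, 'weight_decay': weight_decay},
            {'params': other}]

params = [FakeParam('encoder.layer0.dense.weight'),
          FakeParam('encoder.layer0.layernorm.gamma'),
          FakeParam('encoder.layer0.dense.bias')]
groups = split_decay_groups(params)
assert len(groups[0]['params']) == 1 and len(groups[1]['params']) == 2
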
Example #29
def test_train():
    '''
    Finetune function.
    '''
    target = args_opt.device_target
    if target == "Ascend":
        devid = int(os.getenv('DEVICE_ID'))
        context.set_context(mode=context.GRAPH_MODE,
                            device_target="Ascend",
                            device_id=devid)
    elif target == "GPU":
        context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    else:
        raise Exception("Target error, GPU or Ascend is supported.")
    # BertCLSTrain for classification
    # BertNERTrain for sequence labeling
    if cfg.task == 'NER':
        if cfg.use_crf:
            netwithloss = BertNER(bert_net_cfg,
                                  True,
                                  num_labels=len(tag_to_index),
                                  use_crf=True,
                                  tag_to_index=tag_to_index,
                                  dropout_prob=0.1)
        else:
            netwithloss = BertNER(bert_net_cfg,
                                  True,
                                  num_labels=cfg.num_labels,
                                  dropout_prob=0.1)
    elif cfg.task == 'SQUAD':
        netwithloss = BertSquad(bert_net_cfg, True, 2, dropout_prob=0.1)
    else:
        netwithloss = BertCLS(bert_net_cfg,
                              True,
                              num_labels=cfg.num_labels,
                              dropout_prob=0.1)
    if cfg.task == 'SQUAD':
        dataset = get_squad_dataset(bert_net_cfg.batch_size, cfg.epoch_num)
    else:
        dataset = get_dataset(bert_net_cfg.batch_size, cfg.epoch_num)
    # optimizer
    steps_per_epoch = dataset.get_dataset_size()
    if cfg.optimizer == 'AdamWeightDecayDynamicLR':
        optimizer = AdamWeightDecayDynamicLR(
            netwithloss.trainable_params(),
            decay_steps=steps_per_epoch * cfg.epoch_num,
            learning_rate=cfg.AdamWeightDecayDynamicLR.learning_rate,
            end_learning_rate=cfg.AdamWeightDecayDynamicLR.end_learning_rate,
            power=cfg.AdamWeightDecayDynamicLR.power,
            warmup_steps=int(steps_per_epoch * cfg.epoch_num * 0.1),
            weight_decay=cfg.AdamWeightDecayDynamicLR.weight_decay,
            eps=cfg.AdamWeightDecayDynamicLR.eps)
    elif cfg.optimizer == 'Lamb':
        optimizer = Lamb(netwithloss.trainable_params(),
                         decay_steps=steps_per_epoch * cfg.epoch_num,
                         start_learning_rate=cfg.Lamb.start_learning_rate,
                         end_learning_rate=cfg.Lamb.end_learning_rate,
                         power=cfg.Lamb.power,
                         weight_decay=cfg.Lamb.weight_decay,
                         warmup_steps=int(steps_per_epoch * cfg.epoch_num *
                                          0.1),
                         decay_filter=cfg.Lamb.decay_filter)
    elif cfg.optimizer == 'Momentum':
        optimizer = Momentum(netwithloss.trainable_params(),
                             learning_rate=cfg.Momentum.learning_rate,
                             momentum=cfg.Momentum.momentum)
    else:
        raise Exception("Optimizer not supported.")
    # load checkpoint into network
    ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch,
                                   keep_checkpoint_max=1)
    ckpoint_cb = ModelCheckpoint(prefix=cfg.ckpt_prefix,
                                 directory=cfg.ckpt_dir,
                                 config=ckpt_config)
    param_dict = load_checkpoint(cfg.pre_training_ckpt)
    load_param_into_net(netwithloss, param_dict)

    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2**32,
                                             scale_factor=2,
                                             scale_window=1000)
    if cfg.task == 'SQUAD':
        netwithgrads = BertSquadCell(netwithloss,
                                     optimizer=optimizer,
                                     scale_update_cell=update_cell)
    else:
        netwithgrads = BertFinetuneCell(netwithloss,
                                        optimizer=optimizer,
                                        scale_update_cell=update_cell)
    model = Model(netwithgrads)
    model.train(cfg.epoch_num, dataset, callbacks=[LossCallBack(), ckpoint_cb])
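
The AdamWeightDecayDynamicLR and Lamb branches above derive decay_steps from the dataset size and take 10% of the total steps as warmup. A worked example with illustrative numbers:

steps_per_epoch, epoch_num = 1000, 3          # illustrative values
decay_steps = steps_per_epoch * epoch_num     # 3000 optimizer steps
warmup_steps = int(decay_steps * 0.1)         # 300 warmup steps
assert (decay_steps, warmup_steps) == (3000, 300)
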
Example #30
def run_pretrain():
    """pre-train bert_clue"""
    parser = argparse.ArgumentParser(description='bert pre_training')
    parser.add_argument(
        '--device_target',
        type=str,
        default='Ascend',
        choices=['Ascend', 'GPU'],
        help='device where the code will be implemented. (Default: Ascend)')
    parser.add_argument("--distribute",
                        type=str,
                        default="false",
                        choices=["true", "false"],
                        help="Run distribute, default is false.")
    parser.add_argument("--epoch_size",
                        type=int,
                        default="1",
                        help="Epoch size, default is 1.")
    parser.add_argument("--device_id",
                        type=int,
                        default=0,
                        help="Device id, default is 0.")
    parser.add_argument("--device_num",
                        type=int,
                        default=1,
                        help="Use device nums, default is 1.")
    parser.add_argument("--enable_save_ckpt",
                        type=str,
                        default="true",
                        choices=["true", "false"],
                        help="Enable save checkpoint, default is true.")
    parser.add_argument("--enable_lossscale",
                        type=str,
                        default="true",
                        choices=["true", "false"],
                        help="Use lossscale or not, default is not.")
    parser.add_argument("--do_shuffle",
                        type=str,
                        default="true",
                        choices=["true", "false"],
                        help="Enable shuffle for dataset, default is true.")
    parser.add_argument("--enable_data_sink",
                        type=str,
                        default="true",
                        choices=["true", "false"],
                        help="Enable data sink, default is true.")
    parser.add_argument("--data_sink_steps",
                        type=int,
                        default="1",
                        help="Sink steps for each epoch, default is 1.")
    parser.add_argument(
        "--accumulation_steps",
        type=int,
        default=1,
        help="Accumulating gradients N times before weight update, "
        "default is 1.")
    parser.add_argument("--save_checkpoint_path",
                        type=str,
                        default="",
                        help="Save checkpoint path")
    parser.add_argument("--load_checkpoint_path",
                        type=str,
                        default="",
                        help="Load checkpoint file path")
    parser.add_argument("--save_checkpoint_steps",
                        type=int,
                        default=1000,
                        help="Save checkpoint steps, "
                        "default is 1000.")
    parser.add_argument("--train_steps",
                        type=int,
                        default=-1,
                        help="Training Steps, default is -1, "
                        "meaning run all steps according to epoch number.")
    parser.add_argument("--save_checkpoint_num",
                        type=int,
                        default=1,
                        help="Save checkpoint numbers, default is 1.")
    parser.add_argument("--data_dir",
                        type=str,
                        default="",
                        help="Data path, it is better to use absolute path")
    parser.add_argument("--schema_dir",
                        type=str,
                        default="",
                        help="Schema path, it is better to use absolute path")

    args_opt = parser.parse_args()
    context.set_context(mode=context.GRAPH_MODE,
                        device_target=args_opt.device_target,
                        device_id=args_opt.device_id)
    context.set_context(reserve_class_name_in_scope=False)
    ckpt_save_dir = args_opt.save_checkpoint_path
    if args_opt.distribute == "true":
        if args_opt.device_target == 'Ascend':
            D.init()
            device_num = args_opt.device_num
            rank = args_opt.device_id % device_num
        else:
            D.init()
            device_num = D.get_group_size()
            rank = D.get_rank()
            ckpt_save_dir = args_opt.save_checkpoint_path + 'ckpt_' + str(
                rank) + '/'

        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(
            parallel_mode=ParallelMode.DATA_PARALLEL,
            gradients_mean=True,
            device_num=device_num)
        if bert_net_cfg.num_hidden_layers == 12:
            if bert_net_cfg.use_relative_positions:
                context.set_auto_parallel_context(all_reduce_fusion_config=[
                    29, 58, 87, 116, 145, 174, 203, 217
                ])
            else:
                context.set_auto_parallel_context(all_reduce_fusion_config=[
                    28, 55, 82, 109, 136, 163, 190, 205
                ])
        elif bert_net_cfg.num_hidden_layers == 24:
            if bert_net_cfg.use_relative_positions:
                context.set_auto_parallel_context(all_reduce_fusion_config=[
                    30, 90, 150, 210, 270, 330, 390, 421
                ])
            else:
                context.set_auto_parallel_context(all_reduce_fusion_config=[
                    38, 93, 148, 203, 258, 313, 368, 397
                ])
    else:
        rank = 0
        device_num = 1

    if args_opt.device_target == 'GPU' and bert_net_cfg.compute_type != mstype.float32:
        logger.warning('GPU only supports fp32 for now; running with fp32.')
        bert_net_cfg.compute_type = mstype.float32

    if args_opt.accumulation_steps > 1:
        logger.info("accumulation steps: {}".format(
            args_opt.accumulation_steps))
        logger.info("global batch size: {}".format(
            bert_net_cfg.batch_size * args_opt.accumulation_steps))
        if args_opt.enable_data_sink == "true":
            args_opt.data_sink_steps *= args_opt.accumulation_steps
            logger.info("data sink steps: {}".format(args_opt.data_sink_steps))
        if args_opt.enable_save_ckpt == "true":
            args_opt.save_checkpoint_steps *= args_opt.accumulation_steps
            logger.info("save checkpoint steps: {}".format(
                args_opt.save_checkpoint_steps))

    ds = create_bert_dataset(device_num, rank, args_opt.do_shuffle,
                             args_opt.data_dir, args_opt.schema_dir)
    net_with_loss = BertNetworkWithLoss(bert_net_cfg, True)

    new_repeat_count = (args_opt.epoch_size * ds.get_dataset_size()
                        // args_opt.data_sink_steps)
    if args_opt.train_steps > 0:
        train_steps = args_opt.train_steps * args_opt.accumulation_steps
        new_repeat_count = min(new_repeat_count,
                               train_steps // args_opt.data_sink_steps)
    else:
        args_opt.train_steps = (args_opt.epoch_size * ds.get_dataset_size()
                                // args_opt.accumulation_steps)
        logger.info("train steps: {}".format(args_opt.train_steps))

    if cfg.optimizer == 'Lamb':
        lr_schedule = BertLearningRate(
            learning_rate=cfg.Lamb.learning_rate,
            end_learning_rate=cfg.Lamb.end_learning_rate,
            warmup_steps=cfg.Lamb.warmup_steps,
            decay_steps=args_opt.train_steps,
            power=cfg.Lamb.power)
        params = net_with_loss.trainable_params()
        decay_params = list(filter(cfg.Lamb.decay_filter, params))
        other_params = list(
            filter(lambda x: not cfg.Lamb.decay_filter(x), params))
        group_params = [{
            'params': decay_params,
            'weight_decay': cfg.Lamb.weight_decay
        }, {
            'params': other_params
        }, {
            'order_params': params
        }]
        optimizer = Lamb(group_params,
                         learning_rate=lr_schedule,
                         eps=cfg.Lamb.eps)
    elif cfg.optimizer == 'Momentum':
        optimizer = Momentum(net_with_loss.trainable_params(),
                             learning_rate=cfg.Momentum.learning_rate,
                             momentum=cfg.Momentum.momentum)
    elif cfg.optimizer == 'AdamWeightDecay':
        lr_schedule = BertLearningRate(
            learning_rate=cfg.AdamWeightDecay.learning_rate,
            end_learning_rate=cfg.AdamWeightDecay.end_learning_rate,
            warmup_steps=cfg.AdamWeightDecay.warmup_steps,
            decay_steps=args_opt.train_steps,
            power=cfg.AdamWeightDecay.power)
        params = net_with_loss.trainable_params()
        decay_params = list(filter(cfg.AdamWeightDecay.decay_filter, params))
        other_params = list(
            filter(lambda x: not cfg.AdamWeightDecay.decay_filter(x), params))
        group_params = [{
            'params': decay_params,
            'weight_decay': cfg.AdamWeightDecay.weight_decay
        }, {
            'params': other_params,
            'weight_decay': 0.0
        }, {
            'order_params': params
        }]

        optimizer = AdamWeightDecay(group_params,
                                    learning_rate=lr_schedule,
                                    eps=cfg.AdamWeightDecay.eps)
    else:
        raise ValueError(
            "Unsupported optimizer {}; only [Lamb, Momentum, AdamWeightDecay] "
            "are supported.".format(cfg.optimizer))
    callback = [
        TimeMonitor(args_opt.data_sink_steps),
        LossCallBack(ds.get_dataset_size())
    ]
    if args_opt.enable_save_ckpt == "true" and args_opt.device_id % min(
            8, device_num) == 0:
        config_ck = CheckpointConfig(
            save_checkpoint_steps=args_opt.save_checkpoint_steps,
            keep_checkpoint_max=args_opt.save_checkpoint_num)
        ckpoint_cb = ModelCheckpoint(
            prefix='checkpoint_bert',
            directory=None if ckpt_save_dir == "" else ckpt_save_dir,
            config=config_ck)
        callback.append(ckpoint_cb)

    if args_opt.load_checkpoint_path:
        param_dict = load_checkpoint(args_opt.load_checkpoint_path)
        load_param_into_net(net_with_loss, param_dict)

    if args_opt.enable_lossscale == "true":
        update_cell = DynamicLossScaleUpdateCell(
            loss_scale_value=cfg.loss_scale_value,
            scale_factor=cfg.scale_factor,
            scale_window=cfg.scale_window)

        if args_opt.accumulation_steps <= 1:
            net_with_grads = BertTrainOneStepWithLossScaleCell(
                net_with_loss,
                optimizer=optimizer,
                scale_update_cell=update_cell)
        else:
            accumulation_steps = args_opt.accumulation_steps
            net_with_grads = BertTrainAccumulateStepsWithLossScaleCell(
                net_with_loss,
                optimizer=optimizer,
                scale_update_cell=update_cell,
                accumulation_steps=accumulation_steps,
                enable_global_norm=cfg.enable_global_norm)
    else:
        net_with_grads = BertTrainOneStepCell(net_with_loss,
                                              optimizer=optimizer)

    model = Model(net_with_grads)
    model.train(new_repeat_count,
                ds,
                callbacks=callback,
                dataset_sink_mode=(args_opt.enable_data_sink == "true"),
                sink_size=args_opt.data_sink_steps)
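
For reference, the gradient-accumulation bookkeeping earlier in run_pretrain stretches the sink and checkpoint intervals so they keep counting device steps while the optimizer only updates once per accumulation window. A worked example with illustrative numbers:

batch_size, accumulation_steps = 32, 4        # illustrative values
data_sink_steps, save_checkpoint_steps = 100, 1000
global_batch = batch_size * accumulation_steps          # 128 samples/update
data_sink_steps *= accumulation_steps                   # 400 device steps
save_checkpoint_steps *= accumulation_steps             # 4000 device steps
assert (global_batch, data_sink_steps, save_checkpoint_steps) == (128, 400, 4000)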