Exemplo n.º 1
0
def test_bert_tdt():
    """Train BERT on Ascend from ``cfg`` settings and save checkpoints.

    The network, Lamb optimizer and checkpoint callback are all configured
    from ``cfg``; training runs for ``get_repeat_count()`` epochs without
    dataset sinking.
    """
    context.set_context(mode=context.GRAPH_MODE)
    context.set_context(device_target="Ascend")
    # Task sink, loop sink and memory reuse are all switched on.
    context.set_context(enable_task_sink=True,
                        enable_loop_sink=True,
                        enable_mem_reuse=True)
    loss_recorder = ModelCallback()
    bert_config = cfg.bert_config
    dataset = me_de_train_dataset(bert_config.batch_size)
    network = BertNetworkWithLoss(bert_config, True)
    lamb = Lamb(network.trainable_params(),
                decay_steps=cfg.decay_steps,
                start_learning_rate=cfg.start_learning_rate,
                end_learning_rate=cfg.end_learning_rate,
                power=cfg.power,
                warmup_steps=cfg.num_warmup_steps,
                decay_filter=lambda _param: False)
    train_cell = BertTrainOneStepCell(network, optimizer=lamb)
    train_cell.set_train(True)
    model = Model(train_cell)
    checkpoint_cb = ModelCheckpoint(
        prefix=cfg.checkpoint_prefix,
        config=CheckpointConfig(
            save_checkpoint_steps=cfg.save_checkpoint_steps,
            keep_checkpoint_max=cfg.keep_checkpoint_max))
    model.train(dataset.get_repeat_count(),
                dataset,
                callbacks=[loss_recorder, checkpoint_cb],
                dataset_sink_mode=False)
Exemplo n.º 2
0
def bert_withlossscale_manager_train_feed():
    """Build and execute one BERT training step with dynamic loss scaling.

    ``VERSION`` (default ``'base'``) and ``BATCH_SIZE`` (default ``'1'``)
    select the configuration. The test inputs are extended with an all-ones
    float32 tensor of shape ``[1]`` (the loss-scaling sensitivity) before
    being fed to ``build_construct_graph`` with ``execute=True``.
    """
    class ModelBert(nn.Cell):
        # Wraps BertTrainOneStepWithLossScaleCell, whose scale update cell is
        # built from a default DynamicLossScaleManager.
        def __init__(self, network, optimizer=None):
            super(ModelBert, self).__init__()
            self.optimizer = optimizer
            update_cell = LossScaleUpdateCell(DynamicLossScaleManager())
            self.train_network = BertTrainOneStepWithLossScaleCell(
                network, self.optimizer, scale_update_cell=update_cell)
            self.train_network.set_train()

        # NOTE(review): explicit positional arguments kept on purpose; graph
        # mode may not accept *args in construct.
        def construct(self, arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7):
            return self.train_network(arg0, arg1, arg2, arg3, arg4, arg5,
                                      arg6, arg7)

    model_version = os.getenv('VERSION', 'base')
    batch = int(os.getenv('BATCH_SIZE', '1'))
    sens = Tensor(np.ones([1], dtype=np.float32))
    feed = load_test_data(batch) + (sens, )

    bert_config = get_config(version=model_version, batch_size=batch)
    network = BertNetworkWithLoss(bert_config, True)
    net = ModelBert(
        network,
        optimizer=AdamWeightDecayDynamicLR(network.trainable_params(), 10))
    net.set_train()
    build_construct_graph(net, *feed, execute=True)
Exemplo n.º 3
0
def test_bert_train():
    """Build (without executing) the BERT training-step graph.

    ``VERSION`` (default ``'large'``) and ``BATCH_SIZE`` (default ``'1'``)
    select the configuration; the graph is built with ``execute=False``.
    """
    class ModelBert(nn.Cell):
        """Cell that forwards seven inputs to a BertTrainOneStepCell."""
        def __init__(self, network, optimizer=None):
            super(ModelBert, self).__init__()
            self.optimizer = optimizer
            self.train_network = BertTrainOneStepCell(network, self.optimizer)
            self.train_network.set_train()

        # NOTE(review): explicit positional arguments kept on purpose; graph
        # mode may not accept *args in construct.
        def construct(self, arg0, arg1, arg2, arg3, arg4, arg5, arg6):
            return self.train_network(arg0, arg1, arg2, arg3, arg4, arg5, arg6)

    model_version = os.getenv('VERSION', 'large')
    batch = int(os.getenv('BATCH_SIZE', '1'))
    feed = load_test_data(batch)

    network = BertNetworkWithLoss(
        get_config(version=model_version, batch_size=batch), True)
    adam = AdamWeightDecayDynamicLR(network.trainable_params(), 10)
    net = ModelBert(network, optimizer=adam)
    net.set_train()
    build_construct_graph(net, *feed, execute=False)
Exemplo n.º 4
0
def test_bert_tdt():
    """Train BERT with Momentum on Ascend and check the per-step losses.

    Every trainable Tensor parameter is re-initialised with
    ``weight_variable`` before training. Dense weights (name ending in
    ``weight`` and not under ``cls2``) are drawn in transposed shape and
    transposed back. The losses recorded by ``ModelCallback`` must match
    ``expect_out`` within rtol = atol = 1e-5.
    """
    context.set_context(mode=context.GRAPH_MODE,
                        device_target="Ascend",
                        reserve_class_name_in_scope=False)
    context.set_context(enable_task_sink=True)
    context.set_context(enable_loop_sink=True)
    context.set_context(enable_mem_reuse=True)
    parallel_callback = ModelCallback()
    # NOTE(review): the dataset is built without BATCH_SIZE, while the
    # network config below uses it; confirm the two batch sizes agree.
    ds = me_de_train_dataset()
    version = os.getenv('VERSION', 'large')
    batch_size = int(os.getenv('BATCH_SIZE', '16'))
    config = get_config(version=version, batch_size=batch_size)
    netwithloss = BertNetworkWithLoss(config, True)
    optimizer = Momentum(netwithloss.trainable_params(),
                         learning_rate=2e-5,
                         momentum=0.9)
    netwithgrads = BertTrainOneStepCell(netwithloss, optimizer=optimizer)
    netwithgrads.set_train(True)
    model = Model(netwithgrads)
    params = netwithloss.trainable_params()
    for param in params:
        value = param.default_input
        name = param.name
        if not isinstance(value, Tensor):
            continue
        name_parts = name.split('.')
        if name_parts[-1] == 'weight':
            if name_parts[-3] == 'cls2':
                logger.info(
                    "***************** BERT param name is 1 {}".format(
                        name))
                param.default_input = weight_variable(
                    value.asnumpy().shape)
            else:
                logger.info(
                    "***************** BERT param name is 2 {}".format(
                        name))
                # Draw the weight in (out, in) layout, then transpose back so
                # the stored parameter keeps its original shape.
                tempshape = value.asnumpy().shape
                shape = (tempshape[1], tempshape[0])
                weight_value = weight_variable(shape).asnumpy()
                param.default_input = Tensor(
                    np.transpose(weight_value, [1, 0]))
        else:
            logger.info(
                "***************** BERT param name is 3 {}".format(name))
            param.default_input = weight_variable(value.asnumpy().shape)
    model.train(ds.get_repeat_count(),
                ds,
                callbacks=parallel_callback,
                dataset_sink_mode=False)
    loss_value = np.array(parallel_callback.loss_list)
    expect_out = [
        12.19179, 11.965041, 11.969687, 11.97815, 11.969171, 12.603289,
        12.165594, 12.824818, 12.38842, 12.604046
    ]
    logger.info("expected loss value output: {}".format(expect_out))
    # np.allclose(a, b, rtol, atol) -- same as used by the other BERT tests.
    assert np.allclose(loss_value, expect_out, 0.00001, 0.00001)
Exemplo n.º 5
0
def test_bert_tdt():
    """Train BERT with Lamb on Ascend and check the per-step losses.

    Every trainable Tensor parameter is re-initialised with
    ``weight_variable`` before training. Dense weights (name ending in
    ``weight`` and not under ``cls2``) are drawn in transposed shape and
    transposed back. The losses recorded by ``ModelCallback`` must match
    ``expect_out`` within rtol = atol = 1e-3.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", reserve_class_name_in_scope=False)
    context.set_context(enable_task_sink=True)
    context.set_context(enable_loop_sink=True)
    context.set_context(enable_mem_reuse=True)
    parallel_callback = ModelCallback()
    # NOTE(review): the dataset is built without BATCH_SIZE, while the
    # network config below uses it; confirm the two batch sizes agree.
    ds = me_de_train_dataset()
    version = os.getenv('VERSION', 'large')
    batch_size = int(os.getenv('BATCH_SIZE', '16'))
    config = get_config(version=version, batch_size=batch_size)
    netwithloss = BertNetworkWithLoss(config, True)
    optimizer = Lamb(netwithloss.trainable_params(), decay_steps=10000, start_learning_rate=1e-4,
                     end_learning_rate=0.0, power=10.0, warmup_steps=0, decay_filter=lambda x: False)
    netwithgrads = BertTrainOneStepCell(netwithloss, optimizer=optimizer)
    netwithgrads.set_train(True)
    model = Model(netwithgrads)
    params = netwithloss.trainable_params()
    for param in params:
        value = param.default_input
        name = param.name
        if not isinstance(value, Tensor):
            continue
        name_parts = name.split('.')
        if name_parts[-1] == 'weight':
            if name_parts[-3] == 'cls2':
                logger.info("***************** BERT param name is 1 {}".format(name))
                param.default_input = weight_variable(value.asnumpy().shape)
            else:
                logger.info("***************** BERT param name is 2 {}".format(name))
                # Draw the weight in (out, in) layout, then transpose back so
                # the stored parameter keeps its original shape.
                tempshape = value.asnumpy().shape
                shape = (tempshape[1], tempshape[0])
                weight_value = weight_variable(shape).asnumpy()
                param.default_input = Tensor(np.transpose(weight_value, [1, 0]))
        else:
            logger.info("***************** BERT param name is 3 {}".format(name))
            param.default_input = weight_variable(value.asnumpy().shape)
    model.train(ds.get_repeat_count(), ds, callbacks=parallel_callback, dataset_sink_mode=False)
    loss_value = np.array(parallel_callback.loss_list)
    expect_out = [12.191790, 11.739655, 11.523477, 11.320723, 11.113152, 11.203759, 10.841681, 10.826849,
                  10.616718, 10.486609]
    logger.info("expected loss value output: {}".format(expect_out))
    # np.allclose(a, b, rtol, atol) -- same as used by the other BERT tests.
    assert np.allclose(loss_value, expect_out, 0.001, 0.001)
Exemplo n.º 6
0
def _parse_pretrain_args():
    """Parse the command-line options used by ``run_pretrain``.

    Switch-like options (``--distribute``, ``--enable_*``, ``--do_shuffle``)
    stay plain strings. ``run_pretrain`` treats exactly ``"true"`` as enabled.

    Returns:
        argparse.Namespace: the parsed options.
    """
    parser = argparse.ArgumentParser(description='bert pre_training')
    parser.add_argument("--distribute", type=str, default="false",
                        help="Run distribute, default is false.")
    parser.add_argument("--epoch_size", type=int, default=1,
                        help="Epoch size, default is 1.")
    parser.add_argument("--device_id", type=int, default=0,
                        help="Device id, default is 0.")
    parser.add_argument("--device_num", type=int, default=1,
                        help="Use device nums, default is 1.")
    parser.add_argument("--enable_task_sink", type=str, default="true",
                        help="Enable task sink, default is true.")
    parser.add_argument("--enable_loop_sink", type=str, default="true",
                        help="Enable loop sink, default is true.")
    parser.add_argument("--enable_mem_reuse", type=str, default="true",
                        help="Enable mem reuse, default is true.")
    parser.add_argument("--enable_save_ckpt", type=str, default="true",
                        help="Enable save checkpoint, default is true.")
    # The help text must state the real default ("true").
    parser.add_argument("--enable_lossscale", type=str, default="true",
                        help="Use lossscale or not, default is true.")
    parser.add_argument("--do_shuffle", type=str, default="true",
                        help="Enable shuffle for dataset, default is true.")
    parser.add_argument("--enable_data_sink", type=str, default="true",
                        help="Enable data sink, default is true.")
    parser.add_argument("--data_sink_steps", type=int, default=1,
                        help="Sink steps for each epoch, default is 1.")
    parser.add_argument("--checkpoint_path", type=str, default="",
                        help="Checkpoint file path")
    parser.add_argument("--save_checkpoint_steps", type=int, default=1000,
                        help="Save checkpoint steps, default is 1000.")
    parser.add_argument("--save_checkpoint_num", type=int, default=1,
                        help="Save checkpoint numbers, default is 1.")
    parser.add_argument("--data_dir", type=str, default="",
                        help="Data path, it is better to use absolute path")
    parser.add_argument("--schema_dir", type=str, default="",
                        help="Schema path, it is better to use absolute path")
    return parser.parse_args()


def _build_pretrain_optimizer(netwithloss, ds):
    """Create the optimizer named by ``cfg.optimizer`` for ``netwithloss``.

    Lamb and AdamWeightDecayDynamicLR decay over
    ``ds.get_dataset_size() * ds.get_repeat_count()`` steps, with
    hyper-parameters taken from ``cfg.Lamb`` / ``cfg.AdamWeightDecayDynamicLR``.
    Momentum uses ``cfg.Momentum``.

    Raises:
        ValueError: if ``cfg.optimizer`` is not one of the supported names.
    """
    if cfg.optimizer == 'Lamb':
        return Lamb(netwithloss.trainable_params(),
                    decay_steps=ds.get_dataset_size() * ds.get_repeat_count(),
                    start_learning_rate=cfg.Lamb.start_learning_rate,
                    end_learning_rate=cfg.Lamb.end_learning_rate,
                    power=cfg.Lamb.power,
                    warmup_steps=cfg.Lamb.warmup_steps,
                    weight_decay=cfg.Lamb.weight_decay,
                    eps=cfg.Lamb.eps)
    if cfg.optimizer == 'Momentum':
        return Momentum(netwithloss.trainable_params(),
                        learning_rate=cfg.Momentum.learning_rate,
                        momentum=cfg.Momentum.momentum)
    if cfg.optimizer == 'AdamWeightDecayDynamicLR':
        return AdamWeightDecayDynamicLR(
            netwithloss.trainable_params(),
            decay_steps=ds.get_dataset_size() * ds.get_repeat_count(),
            learning_rate=cfg.AdamWeightDecayDynamicLR.learning_rate,
            end_learning_rate=cfg.AdamWeightDecayDynamicLR.end_learning_rate,
            power=cfg.AdamWeightDecayDynamicLR.power,
            weight_decay=cfg.AdamWeightDecayDynamicLR.weight_decay,
            eps=cfg.AdamWeightDecayDynamicLR.eps)
    raise ValueError(
        "Don't support optimizer {}, only support [Lamb, Momentum, AdamWeightDecayDynamicLR]"
        .format(cfg.optimizer))


def run_pretrain():
    """pre-train bert_clue

    Steps:

    1. Parse the command line.
    2. Configure the Ascend graph-mode context, plus data-parallel training
       when ``--distribute true``.
    3. Build the dataset, network, optimizer and callbacks, optionally
       restoring ``--checkpoint_path``.
    4. Train, with or without dynamic loss scaling.
    """
    args_opt = _parse_pretrain_args()
    context.set_context(mode=context.GRAPH_MODE,
                        device_target="Ascend",
                        device_id=args_opt.device_id)
    context.set_context(enable_task_sink=(args_opt.enable_task_sink == "true"),
                        enable_loop_sink=(args_opt.enable_loop_sink == "true"),
                        enable_mem_reuse=(args_opt.enable_mem_reuse == "true"))
    context.set_context(reserve_class_name_in_scope=False)

    if args_opt.distribute == "true":
        device_num = args_opt.device_num
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(
            parallel_mode=ParallelMode.DATA_PARALLEL,
            mirror_mean=True,
            device_num=device_num)
        D.init()
        rank = args_opt.device_id % device_num
    else:
        rank = 0
        device_num = 1

    ds = create_bert_dataset(args_opt.epoch_size, device_num, rank,
                             args_opt.do_shuffle, args_opt.enable_data_sink,
                             args_opt.data_sink_steps, args_opt.data_dir,
                             args_opt.schema_dir)

    netwithloss = BertNetworkWithLoss(bert_net_cfg, True)
    optimizer = _build_pretrain_optimizer(netwithloss, ds)

    callback = [LossCallBack()]
    if args_opt.enable_save_ckpt == "true":
        config_ck = CheckpointConfig(
            save_checkpoint_steps=args_opt.save_checkpoint_steps,
            keep_checkpoint_max=args_opt.save_checkpoint_num)
        ckpoint_cb = ModelCheckpoint(prefix='checkpoint_bert',
                                     config=config_ck)
        callback.append(ckpoint_cb)

    if args_opt.checkpoint_path:
        param_dict = load_checkpoint(args_opt.checkpoint_path)
        load_param_into_net(netwithloss, param_dict)

    if args_opt.enable_lossscale == "true":
        update_cell = DynamicLossScaleUpdateCell(
            loss_scale_value=cfg.loss_scale_value,
            scale_factor=cfg.scale_factor,
            scale_window=cfg.scale_window)
        netwithgrads = BertTrainOneStepWithLossScaleCell(
            netwithloss, optimizer=optimizer, scale_update_cell=update_cell)
    else:
        netwithgrads = BertTrainOneStepCell(netwithloss, optimizer=optimizer)

    model = Model(netwithgrads)
    model.train(ds.get_repeat_count(),
                ds,
                callbacks=callback,
                dataset_sink_mode=(args_opt.enable_data_sink == "true"))
Exemplo n.º 7
0

def load_test_data():
    """Return the next item of the dataset produced by ``get_dataset()``."""
    return get_dataset().next()


# Unpack the seven BERT pre-training inputs from the first test batch.
(input_ids, input_mask, token_type_id, next_sentence_labels,
 masked_lm_positions, masked_lm_ids, masked_lm_weights) = load_test_data()

test_sets = [
    ('BertNetworkWithLoss_1', {
        'block':
        BertNetworkWithLoss(BertConfig(batch_size=1),
                            False,
                            use_one_hot_embeddings=True),
        'desc_inputs': [
            input_ids, input_mask, token_type_id, next_sentence_labels,
            masked_lm_positions, masked_lm_ids, masked_lm_weights
        ],
        'desc_bprop': [[1]]
    }),
    ('BertNetworkWithLoss_2', {
        'block':
        BertNetworkWithLoss(BertConfig(batch_size=1), False, True),
        'desc_inputs': [
            input_ids, input_mask, token_type_id, next_sentence_labels,
            masked_lm_positions, masked_lm_ids, masked_lm_weights
        ],
        'desc_bprop': [[1]]
Exemplo n.º 8
0
def test_bert_tdt():
    """Train BERT with Momentum and dynamic loss scaling, then check the run.

    The loss-scale manager starts at 2**16 with factor 2 and window 3. After
    every trainable Tensor parameter is re-initialised with
    ``weight_variable``, three recorded sequences are compared against
    hard-coded references: the per-step losses, the overflow flags and the
    loss-scale values.
    """
    context.set_context(mode=context.GRAPH_MODE,
                        device_target="Ascend",
                        reserve_class_name_in_scope=False)
    context.set_context(enable_loop_sink=True)
    context.set_context(enable_mem_reuse=True)
    dataset = me_de_train_dataset()
    bert_version = os.getenv('VERSION', 'large')
    bert_batch = int(os.getenv('BATCH_SIZE', '16'))
    network = BertNetworkWithLoss(
        get_config(version=bert_version, batch_size=bert_batch), True)
    momentum_opt = Momentum(network.trainable_params(),
                            learning_rate=2e-5,
                            momentum=0.9)
    scale_window = 3
    scale_manager = DynamicLossScaleManager(2**16, 2, scale_window)
    train_cell = BertTrainOneStepWithLossScaleCell(
        network,
        optimizer=momentum_opt,
        scale_update_cell=scale_manager.get_update_cell())
    train_cell.set_train(True)
    model = Model(train_cell)
    recorder = ModelCallback()
    for param in network.trainable_params():
        param.init_data()
        current = param.default_input
        pname = param.name
        if not isinstance(current, Tensor):
            continue
        pieces = pname.split('.')
        if pieces[-1] != 'weight':
            logger.info(
                "***************** BERT param name is 3 {}".format(pname))
            param.default_input = weight_variable(current.asnumpy().shape)
        elif pieces[-3] == 'cls2':
            logger.info(
                "***************** BERT param name is 1 {}".format(pname))
            param.default_input = weight_variable(current.asnumpy().shape)
        else:
            logger.info(
                "***************** BERT param name is 2 {}".format(pname))
            # Sample with swapped dimensions, then transpose back to the
            # parameter's own shape.
            rows, cols = current.asnumpy().shape[0], current.asnumpy().shape[1]
            sampled = weight_variable((cols, rows)).asnumpy()
            param.default_input = Tensor(np.transpose(sampled, [1, 0]))
    model.train(dataset.get_repeat_count(),
                dataset,
                callbacks=recorder,
                dataset_sink_mode=False)

    # The assertions below fail if the loss values, overflow states or
    # loss-scale values deviate from the reference run.
    losses = np.array(recorder.loss_list)
    expect_loss_value = [
        12.1918125, 11.966035, 11.972114, 11.982188, 11.974092, 12.610916,
        12.17565, 12.840416, 12.40291, 12.621661
    ]
    print("loss value: {}".format(losses))
    assert np.allclose(losses, expect_loss_value, 0.00001, 0.00001)

    overflows = np.array(recorder.overflow_list)
    expect_overflow = [
        True, True, False, False, False, True, False, False, False, True
    ]
    print("overflow: {}".format(overflows))
    assert (overflows == expect_overflow).all()

    scales = np.array(recorder.lossscale_list)
    expect_loss_scale = [
        32768.0, 16384.0, 16384.0, 16384.0, 32768.0, 16384.0, 16384.0, 16384.0,
        32768.0, 16384.0
    ]
    print("loss scale: {}".format(scales))
    assert np.allclose(scales, expect_loss_scale, 0.00001, 0.00001)
Exemplo n.º 9
0
 {
     'id':
     'BertNetworkWithLoss',
     'group':
     'bert',
     'block':
     BertNetworkWithLoss(config=BertConfig(
         batch_size=1,
         seq_length=128,
         vocab_size=21128,
         hidden_size=1024,
         num_hidden_layers=2,
         num_attention_heads=16,
         intermediate_size=4096,
         hidden_act="gelu",
         hidden_dropout_prob=0.0,
         attention_probs_dropout_prob=0.0,
         max_position_embeddings=512,
         type_vocab_size=2,
         initializer_range=0.02,
         use_relative_positions=True,
         input_mask_from_dataset=True,
         token_type_ids_from_dataset=True,
         dtype=mstype.float32,
         compute_type=mstype.float32),
                         is_training=True),
     'reduce_output':
     False
 },
 {
     'id':
     'BertAttentionQueryKeyMul_CICase',
Exemplo n.º 10
0
def test_bert_tdt():
    """Train BERT with Momentum and dynamic loss scaling, then check the scale.

    The loss-scale manager starts at 2**32 with factor 2 and window 3. After
    training, each recorded loss-scale value is checked against the previous
    step:

    * at an overflow step (except step 0) it must be half the previous value;
    * after ``scale_window`` consecutive non-overflow steps it must be
      double the previous value.
    """
    context.set_context(mode=context.GRAPH_MODE,
                        device_target="Ascend",
                        reserve_class_name_in_scope=False)
    context.set_context(enable_task_sink=True)
    context.set_context(enable_loop_sink=True)
    context.set_context(enable_mem_reuse=True)
    dataset = me_de_train_dataset()
    bert_version = os.getenv('VERSION', 'large')
    bert_batch = int(os.getenv('BATCH_SIZE', '16'))
    network = BertNetworkWithLoss(
        get_config(version=bert_version, batch_size=bert_batch), True)
    momentum_opt = Momentum(network.trainable_params(),
                            learning_rate=2e-5,
                            momentum=0.9)
    scale_window = 3
    scale_manager = DynamicLossScaleManager(2**32, 2, scale_window)
    train_cell = BertTrainOneStepWithLossScaleCell(
        network,
        optimizer=momentum_opt,
        scale_update_cell=scale_manager.get_update_cell())
    train_cell.set_train(True)
    model = Model(train_cell)
    recorder = ModelCallback()
    for param in network.trainable_params():
        current = param.default_input
        pname = param.name
        if not isinstance(current, Tensor):
            continue
        pieces = pname.split('.')
        if pieces[-1] != 'weight':
            logger.info(
                "***************** BERT param name is 3 {}".format(pname))
            param.default_input = weight_variable(current.asnumpy().shape)
        elif pieces[-3] == 'cls2':
            logger.info(
                "***************** BERT param name is 1 {}".format(pname))
            param.default_input = weight_variable(current.asnumpy().shape)
        else:
            logger.info(
                "***************** BERT param name is 2 {}".format(pname))
            # Sample with swapped dimensions, then transpose back to the
            # parameter's own shape.
            rows, cols = current.asnumpy().shape[0], current.asnumpy().shape[1]
            sampled = weight_variable((cols, rows)).asnumpy()
            param.default_input = Tensor(np.transpose(sampled, [1, 0]))
    model.train(dataset.get_repeat_count(),
                dataset,
                callbacks=recorder,
                dataset_sink_mode=False)

    # The assertions below fail if a loss-scale update is wrong.
    scales = recorder.lossscale_list
    clean_run = 0
    for step, overflowed in enumerate(recorder.overflow_list):
        if overflowed == Tensor(True, mstype.bool_) and step > 0:
            clean_run = 0
            assert scales[step] == scales[step - 1] * Tensor(0.5, mstype.float32)
        if overflowed == Tensor(False, mstype.bool_):
            clean_run += 1
            if clean_run == scale_window:
                clean_run = 0
                assert scales[step] == scales[step - 1] * Tensor(2.0, mstype.float32)