Example #1
def test_bert_tdt():
    """test bert tdt"""
    context.set_context(mode=context.GRAPH_MODE)
    context.set_context(device_target="Ascend")
    context.set_context(enable_task_sink=True)
    context.set_context(enable_loop_sink=True)
    context.set_context(enable_mem_reuse=True)
    parallel_callback = ModelCallback()
    ds = me_de_train_dataset(cfg.bert_config.batch_size)
    config = cfg.bert_config
    netwithloss = BertNetworkWithLoss(config, True)
    optimizer = Lamb(netwithloss.trainable_params(),
                     decay_steps=cfg.decay_steps,
                     start_learning_rate=cfg.start_learning_rate,
                     end_learning_rate=cfg.end_learning_rate,
                     power=cfg.power,
                     warmup_steps=cfg.num_warmup_steps,
                     decay_filter=lambda x: False)
    netwithgrads = BertTrainOneStepCell(netwithloss, optimizer=optimizer)
    netwithgrads.set_train(True)
    model = Model(netwithgrads)
    config_ck = CheckpointConfig(
        save_checkpoint_steps=cfg.save_checkpoint_steps,
        keep_checkpoint_max=cfg.keep_checkpoint_max)
    ckpoint_cb = ModelCheckpoint(prefix=cfg.checkpoint_prefix,
                                 config=config_ck)
    model.train(ds.get_repeat_count(),
                ds,
                callbacks=[parallel_callback, ckpoint_cb],
                dataset_sink_mode=False)
Example #2
def test_bert_tdt():
    """test bert tdt"""
    context.set_context(mode=context.GRAPH_MODE,
                        device_target="Ascend",
                        reserve_class_name_in_scope=False)
    context.set_context(enable_task_sink=True)
    context.set_context(enable_loop_sink=True)
    context.set_context(enable_mem_reuse=True)
    parallel_callback = ModelCallback()
    ds = me_de_train_dataset()
    version = os.getenv('VERSION', 'large')
    batch_size = int(os.getenv('BATCH_SIZE', '16'))
    config = get_config(version=version, batch_size=batch_size)
    netwithloss = BertNetworkWithLoss(config, True)
    optimizer = Momentum(netwithloss.trainable_params(),
                         learning_rate=2e-5,
                         momentum=0.9)
    netwithgrads = BertTrainOneStepCell(netwithloss, optimizer=optimizer)
    netwithgrads.set_train(True)
    model = Model(netwithgrads)
    params = netwithloss.trainable_params()
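    # Re-initialize the trainable parameters before training: '*.weight' params outside
    # 'cls2' are drawn in transposed shape and transposed back, every other Tensor
    # parameter is re-drawn with its original shape.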
    for param in params:
        value = param.default_input
        name = param.name
        if isinstance(value, Tensor):
            if name.split('.')[-1] in ['weight']:
                if name.split('.')[-3] in ['cls2']:
                    logger.info(
                        "***************** BERT param name is 1 {}".format(
                            name))
                    param.default_input = weight_variable(
                        value.asnumpy().shape)
                else:
                    logger.info(
                        "***************** BERT param name is 2 {}".format(
                            name))
                    tempshape = value.asnumpy().shape
                    shape = (tempshape[1], tempshape[0])
                    weight_value = weight_variable(shape).asnumpy()
                    param.default_input = Tensor(
                        np.transpose(weight_value, [1, 0]))
            else:
                logger.info(
                    "***************** BERT param name is 3 {}".format(name))
                param.default_input = weight_variable(value.asnumpy().shape)
    model.train(ds.get_repeat_count(),
                ds,
                callbacks=parallel_callback,
                dataset_sink_mode=False)
    loss_value = np.array(parallel_callback.loss_list)
    expect_out = [
        12.19179, 11.965041, 11.969687, 11.97815, 11.969171, 12.603289,
        12.165594, 12.824818, 12.38842, 12.604046
    ]
    logger.info("expected loss value output: {}".format(expect_out))
    assert allclose(loss_value, expect_out, 0.00001, 0.00001)
Example #3
    class ModelBert(nn.Cell):
        """
        ModelBert definition
        """
        def __init__(self, network, optimizer=None):
            super(ModelBert, self).__init__()
            self.optimizer = optimizer
            self.train_network = BertTrainOneStepCell(network, self.optimizer)
            self.train_network.set_train()

        def construct(self, arg0, arg1, arg2, arg3, arg4, arg5, arg6):
            return self.train_network(arg0, arg1, arg2, arg3, arg4, arg5, arg6)
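
A minimal usage sketch for the ModelBert cell above (not part of the original listing): it assumes the same surrounding test module as the other examples, so get_config, BertNetworkWithLoss and Momentum are already available, and the order and shapes of the seven input tensors are assumptions based on the usual BERT pretraining inputs.

import numpy as np
from mindspore import Tensor

# Hypothetical wiring; helper names, version string and input shapes are assumptions.
config = get_config(version='base', batch_size=1)
netwithloss = BertNetworkWithLoss(config, True)
optimizer = Momentum(netwithloss.trainable_params(), learning_rate=2e-5, momentum=0.9)
train_net = ModelBert(netwithloss, optimizer=optimizer)

# Placeholder batch; seq_length is taken from the config, 20 masked positions assumed.
batch, seq_len, max_preds = 1, config.seq_length, 20
input_ids = Tensor(np.ones((batch, seq_len), dtype=np.int32))
input_mask = Tensor(np.ones((batch, seq_len), dtype=np.int32))
token_type_id = Tensor(np.zeros((batch, seq_len), dtype=np.int32))
next_sentence_labels = Tensor(np.zeros((batch, 1), dtype=np.int32))
masked_lm_positions = Tensor(np.zeros((batch, max_preds), dtype=np.int32))
masked_lm_ids = Tensor(np.zeros((batch, max_preds), dtype=np.int32))
masked_lm_weights = Tensor(np.ones((batch, max_preds), dtype=np.float32))

# One training step; the wrapped BertTrainOneStepCell returns the loss for this batch.
loss = train_net(input_ids, input_mask, token_type_id, next_sentence_labels,
                 masked_lm_positions, masked_lm_ids, masked_lm_weights)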
Example #4
def test_bert_tdt():
    """test bert tdt"""
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", reserve_class_name_in_scope=False)
    context.set_context(enable_task_sink=True)
    context.set_context(enable_loop_sink=True)
    context.set_context(enable_mem_reuse=True)
    parallel_callback = ModelCallback()
    ds = me_de_train_dataset()
    version = os.getenv('VERSION', 'large')
    batch_size = int(os.getenv('BATCH_SIZE', '16'))
    config = get_config(version=version, batch_size=batch_size)
    netwithloss = BertNetworkWithLoss(config, True)
    optimizer = Lamb(netwithloss.trainable_params(), decay_steps=10000, start_learning_rate=1e-4,
                     end_learning_rate=0.0, power=10.0, warmup_steps=0, decay_filter=lambda x: False)
    netwithgrads = BertTrainOneStepCell(netwithloss, optimizer=optimizer)
    netwithgrads.set_train(True)
    model = Model(netwithgrads)
    params = netwithloss.trainable_params()
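    # Re-initialize weights as in Example #2: transposed draw for '*.weight' params
    # outside 'cls2', plain re-draw for everything else.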
    for param in params:
        value = param.default_input
        name = param.name
        if isinstance(value, Tensor):
            if name.split('.')[-1] in ['weight']:
                if name.split('.')[-3] in ['cls2']:
                    logger.info("***************** BERT param name is 1 {}".format(name))
                    param.default_input = weight_variable(value.asnumpy().shape)
                else:
                    logger.info("***************** BERT param name is 2 {}".format(name))
                    tempshape = value.asnumpy().shape
                    shape = (tempshape[1], tempshape[0])
                    weight_value = weight_variable(shape).asnumpy()
                    param.default_input = Tensor(np.transpose(weight_value, [1, 0]))
            else:
                logger.info("***************** BERT param name is 3 {}".format(name))
                param.default_input = weight_variable(value.asnumpy().shape)
    model.train(ds.get_repeat_count(), ds, callbacks=parallel_callback, dataset_sink_mode=False)
    loss_value = np.array(parallel_callback.loss_list)
    expect_out = [12.191790, 11.739655, 11.523477, 11.320723, 11.113152, 11.203759, 10.841681, 10.826849,
                  10.616718, 10.486609]
    logger.info("expected loss value output: {}".format(expect_out))
    assert allclose(loss_value, expect_out, 0.001, 0.001)
Example #5
def run_pretrain():
    """pre-train bert_clue"""
    parser = argparse.ArgumentParser(description='bert pre_training')
    parser.add_argument("--distribute",
                        type=str,
                        default="false",
                        help="Run distribute, default is false.")
    parser.add_argument("--epoch_size",
                        type=int,
                        default="1",
                        help="Epoch size, default is 1.")
    parser.add_argument("--device_id",
                        type=int,
                        default=0,
                        help="Device id, default is 0.")
    parser.add_argument("--device_num",
                        type=int,
                        default=1,
                        help="Use device nums, default is 1.")
    parser.add_argument("--enable_task_sink",
                        type=str,
                        default="true",
                        help="Enable task sink, default is true.")
    parser.add_argument("--enable_loop_sink",
                        type=str,
                        default="true",
                        help="Enable loop sink, default is true.")
    parser.add_argument("--enable_mem_reuse",
                        type=str,
                        default="true",
                        help="Enable mem reuse, default is true.")
    parser.add_argument("--enable_save_ckpt",
                        type=str,
                        default="true",
                        help="Enable save checkpoint, default is true.")
    parser.add_argument("--enable_lossscale",
                        type=str,
                        default="true",
                        help="Use lossscale or not, default is not.")
    parser.add_argument("--do_shuffle",
                        type=str,
                        default="true",
                        help="Enable shuffle for dataset, default is true.")
    parser.add_argument("--enable_data_sink",
                        type=str,
                        default="true",
                        help="Enable data sink, default is true.")
    parser.add_argument("--data_sink_steps",
                        type=int,
                        default="1",
                        help="Sink steps for each epoch, default is 1.")
    parser.add_argument("--checkpoint_path",
                        type=str,
                        default="",
                        help="Checkpoint file path")
    parser.add_argument("--save_checkpoint_steps",
                        type=int,
                        default=1000,
                        help="Save checkpoint steps, "
                        "default is 1000.")
    parser.add_argument("--save_checkpoint_num",
                        type=int,
                        default=1,
                        help="Save checkpoint numbers, default is 1.")
    parser.add_argument("--data_dir",
                        type=str,
                        default="",
                        help="Data path, it is better to use absolute path")
    parser.add_argument("--schema_dir",
                        type=str,
                        default="",
                        help="Schema path, it is better to use absolute path")

    args_opt = parser.parse_args()
    context.set_context(mode=context.GRAPH_MODE,
                        device_target="Ascend",
                        device_id=args_opt.device_id)
    context.set_context(enable_task_sink=(args_opt.enable_task_sink == "true"),
                        enable_loop_sink=(args_opt.enable_loop_sink == "true"),
                        enable_mem_reuse=(args_opt.enable_mem_reuse == "true"))
    context.set_context(reserve_class_name_in_scope=False)

    if args_opt.distribute == "true":
        device_num = args_opt.device_num
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(
            parallel_mode=ParallelMode.DATA_PARALLEL,
            mirror_mean=True,
            device_num=device_num)
        D.init()
        rank = args_opt.device_id % device_num
    else:
        rank = 0
        device_num = 1

    ds = create_bert_dataset(args_opt.epoch_size, device_num, rank,
                             args_opt.do_shuffle, args_opt.enable_data_sink,
                             args_opt.data_sink_steps, args_opt.data_dir,
                             args_opt.schema_dir)

    netwithloss = BertNetworkWithLoss(bert_net_cfg, True)

    if cfg.optimizer == 'Lamb':
        optimizer = Lamb(netwithloss.trainable_params(),
                         decay_steps=ds.get_dataset_size() *
                         ds.get_repeat_count(),
                         start_learning_rate=cfg.Lamb.start_learning_rate,
                         end_learning_rate=cfg.Lamb.end_learning_rate,
                         power=cfg.Lamb.power,
                         warmup_steps=cfg.Lamb.warmup_steps,
                         weight_decay=cfg.Lamb.weight_decay,
                         eps=cfg.Lamb.eps)
    elif cfg.optimizer == 'Momentum':
        optimizer = Momentum(netwithloss.trainable_params(),
                             learning_rate=cfg.Momentum.learning_rate,
                             momentum=cfg.Momentum.momentum)
    elif cfg.optimizer == 'AdamWeightDecayDynamicLR':
        optimizer = AdamWeightDecayDynamicLR(
            netwithloss.trainable_params(),
            decay_steps=ds.get_dataset_size() * ds.get_repeat_count(),
            learning_rate=cfg.AdamWeightDecayDynamicLR.learning_rate,
            end_learning_rate=cfg.AdamWeightDecayDynamicLR.end_learning_rate,
            power=cfg.AdamWeightDecayDynamicLR.power,
            weight_decay=cfg.AdamWeightDecayDynamicLR.weight_decay,
            eps=cfg.AdamWeightDecayDynamicLR.eps)
    else:
        raise ValueError(
            "Don't support optimizer {}, only support [Lamb, Momentum, AdamWeightDecayDynamicLR]"
            .format(cfg.optimizer))
    callback = [LossCallBack()]
    if args_opt.enable_save_ckpt == "true":
        config_ck = CheckpointConfig(
            save_checkpoint_steps=args_opt.save_checkpoint_steps,
            keep_checkpoint_max=args_opt.save_checkpoint_num)
        ckpoint_cb = ModelCheckpoint(prefix='checkpoint_bert',
                                     config=config_ck)
        callback.append(ckpoint_cb)

    if args_opt.checkpoint_path:
        param_dict = load_checkpoint(args_opt.checkpoint_path)
        load_param_into_net(netwithloss, param_dict)

    if args_opt.enable_lossscale == "true":
        update_cell = DynamicLossScaleUpdateCell(
            loss_scale_value=cfg.loss_scale_value,
            scale_factor=cfg.scale_factor,
            scale_window=cfg.scale_window)
        netwithgrads = BertTrainOneStepWithLossScaleCell(
            netwithloss, optimizer=optimizer, scale_update_cell=update_cell)
    else:
        netwithgrads = BertTrainOneStepCell(netwithloss, optimizer=optimizer)

    model = Model(netwithgrads)
    model.train(ds.get_repeat_count(),
                ds,
                callbacks=callback,
                dataset_sink_mode=(args_opt.enable_data_sink == "true"))
Example #6
    def __init__(self, network, optimizer=None):
        super(ModelBert, self).__init__()
        self.optimizer = optimizer
        self.train_network = BertTrainOneStepCell(network, self.optimizer)
        self.train_network.set_train()