Example #1
    class ModelBert(nn.Cell):
        def __init__(self, network, optimizer=None):
            super(ModelBert, self).__init__()
            self.optimizer = optimizer
            self.train_network = BertTrainOneStepWithLossScaleCell(network, self.optimizer)
            self.train_network.set_train()

        def construct(self, arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7):
            return self.train_network(arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7)
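A minimal usage sketch for this wrapper (hypothetical wiring: `BertNetworkWithLoss`, the optimizer setup, and the eight input tensors stand in for whatever the surrounding training script actually provides; `BertTrainOneStepWithLossScaleCell` is assumed to come from the repository's BERT model code):

    # Hypothetical wiring; the names below are placeholders for the real
    # objects built elsewhere in the training script.
    netwithloss = BertNetworkWithLoss(config, True)
    optimizer = Momentum(netwithloss.trainable_params(),
                         learning_rate=2e-5, momentum=0.9)
    model = ModelBert(netwithloss, optimizer=optimizer)
    # One call runs one training step; the eight arguments are the BERT
    # pre-training input Tensors (ids, masks, labels, ...).
    loss = model(arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7)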
Example #2
    class ModelBert(nn.Cell):
        def __init__(self, network, optimizer=None):
            super(ModelBert, self).__init__()
            self.optimizer = optimizer
            manager = DynamicLossScaleManager()
            update_cell = LossScaleUpdateCell(manager)
            self.train_network = BertTrainOneStepWithLossScaleCell(
                network, self.optimizer, scale_update_cell=update_cell)
            self.train_network.set_train()

        def construct(self, arg0, arg1, arg2, arg3, arg4, arg5, arg6):
            return self.train_network(arg0, arg1, arg2, arg3, arg4, arg5, arg6)
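The difference from Example #1 is the `scale_update_cell` argument: with it attached, the loss scale is adjusted at run time instead of staying fixed. In MindSpore the dynamics are controlled by the manager's three constructor arguments (a sketch with the framework's documented defaults; the `LossScaleUpdateCell(manager)` construction above is specific to this code base):

    # DynamicLossScaleManager(init_loss_scale, scale_factor, scale_window):
    # halve the scale when a step overflows, multiply it by scale_factor
    # after scale_window consecutive overflow-free steps.
    manager = DynamicLossScaleManager(init_loss_scale=2**24,
                                      scale_factor=2,
                                      scale_window=2000)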
Example #3
def test_bert_tdt():
    """test bert tdt"""
    context.set_context(mode=context.GRAPH_MODE,
                        device_target="Ascend",
                        reserve_class_name_in_scope=False)
    context.set_context(enable_loop_sink=True)
    context.set_context(enable_mem_reuse=True)
    ds = me_de_train_dataset()
    version = os.getenv('VERSION', 'large')
    batch_size = int(os.getenv('BATCH_SIZE', '16'))
    config = get_config(version=version, batch_size=batch_size)
    netwithloss = BertNetworkWithLoss(config, True)
    optimizer = Momentum(netwithloss.trainable_params(),
                         learning_rate=2e-5,
                         momentum=0.9)
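    # Dynamic loss scaling: start at 2**16, halve the scale when a step
    # overflows, and double it after scale_window (3) consecutive
    # overflow-free steps.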
    scale_window = 3
    scale_manager = DynamicLossScaleManager(2**16, 2, scale_window)
    netwithgrads = BertTrainOneStepWithLossScaleCell(
        netwithloss,
        optimizer=optimizer,
        scale_update_cell=scale_manager.get_update_cell())
    netwithgrads.set_train(True)
    model = Model(netwithgrads)
    callback = ModelCallback()
    params = netwithloss.trainable_params()
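    # Re-initialize every trainable weight to a fresh random init so the
    # expected loss values below are reproducible: cls2 weights are drawn
    # directly, other dense weights are drawn in transposed shape and then
    # transposed back, and all remaining parameters keep their own shape.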
    for param in params:
        param.init_data()
        value = param.default_input
        name = param.name
        if isinstance(value, Tensor):
            if name.split('.')[-1] in ['weight']:
                if name.split('.')[-3] in ['cls2']:
                    logger.info(
                        "***************** BERT param name is 1 {}".format(
                            name))
                    param.default_input = weight_variable(
                        value.asnumpy().shape)
                else:
                    logger.info(
                        "***************** BERT param name is 2 {}".format(
                            name))
                    tempshape = value.asnumpy().shape
                    shape = (tempshape[1], tempshape[0])
                    weight_value = weight_variable(shape).asnumpy()
                    param.default_input = Tensor(
                        np.transpose(weight_value, [1, 0]))
            else:
                logger.info(
                    "***************** BERT param name is 3 {}".format(name))
                param.default_input = weight_variable(value.asnumpy().shape)
    model.train(ds.get_repeat_count(),
                ds,
                callbacks=callback,
                dataset_sink_mode=False)

    # an assertion error is raised if the loss value, overflow state, or
    # loss_scale value is wrong
    loss_value = np.array(callback.loss_list)
    expect_loss_value = [
        12.1918125, 11.966035, 11.972114, 11.982188, 11.974092, 12.610916,
        12.17565, 12.840416, 12.40291, 12.621661
    ]
    print("loss value: {}".format(loss_value))
    assert np.allclose(loss_value, expect_loss_value, 0.00001, 0.00001)

    overflow = np.array(callback.overflow_list)
    expect_overflow = [
        True, True, False, False, False, True, False, False, False, True
    ]
    print("overflow: {}".format(overflow))
    assert (overflow == expect_overflow).all()

    loss_scale = np.array(callback.lossscale_list)
    expect_loss_scale = [
        32768.0, 16384.0, 16384.0, 16384.0, 32768.0, 16384.0, 16384.0, 16384.0,
        32768.0, 16384.0
    ]
    print("loss scale: {}".format(loss_scale))
    assert np.allclose(loss_scale, expect_loss_scale, 0.00001, 0.00001)
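The expected loss-scale sequence follows from the update rule and the expected overflow flags alone; a quick sanity check in plain Python (re-implementing the halve-on-overflow, double-after-`scale_window`-clean-steps rule that `DynamicLossScaleManager` applies):

    scale, count, window = 2.0**16, 0, 3
    overflow = [True, True, False, False, False, True, False, False, False, True]
    scales = []
    for ovf in overflow:
        if ovf:                  # overflow: halve the scale, reset the counter
            scale /= 2
            count = 0
        else:                    # clean step: count it ...
            count += 1
            if count == window:  # ... and double the scale after `window` of them
                scale *= 2
                count = 0
        scales.append(scale)
    assert scales == [32768.0, 16384.0, 16384.0, 16384.0, 32768.0,
                      16384.0, 16384.0, 16384.0, 32768.0, 16384.0]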
Example #4
def test_bert_tdt():
    """test bert tdt"""
    context.set_context(mode=context.GRAPH_MODE,
                        device_target="Ascend",
                        reserve_class_name_in_scope=False)
    context.set_context(enable_task_sink=True)
    context.set_context(enable_loop_sink=True)
    context.set_context(enable_mem_reuse=True)
    ds = me_de_train_dataset()
    version = os.getenv('VERSION', 'large')
    batch_size = int(os.getenv('BATCH_SIZE', '16'))
    config = get_config(version=version, batch_size=batch_size)
    netwithloss = BertNetworkWithLoss(config, True)
    optimizer = Momentum(netwithloss.trainable_params(),
                         learning_rate=2e-5,
                         momentum=0.9)
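    # The initial scale of 2**32 is deliberately very high, so the first
    # steps are all but guaranteed to overflow and the halving branch of
    # the update rule gets exercised.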
    scale_window = 3
    scale_manager = DynamicLossScaleManager(2**32, 2, scale_window)
    netwithgrads = BertTrainOneStepWithLossScaleCell(
        netwithloss,
        optimizer=optimizer,
        scale_update_cell=scale_manager.get_update_cell())
    netwithgrads.set_train(True)
    model = Model(netwithgrads)
    callback = ModelCallback()
    params = netwithloss.trainable_params()
    for param in params:
        value = param.default_input
        name = param.name
        if isinstance(value, Tensor):
            if name.split('.')[-1] in ['weight']:
                if name.split('.')[-3] in ['cls2']:
                    logger.info(
                        "***************** BERT param name is 1 {}".format(
                            name))
                    param.default_input = weight_variable(
                        value.asnumpy().shape)
                else:
                    logger.info(
                        "***************** BERT param name is 2 {}".format(
                            name))
                    tempshape = value.asnumpy().shape
                    shape = (tempshape[1], tempshape[0])
                    weight_value = weight_variable(shape).asnumpy()
                    param.default_input = Tensor(
                        np.transpose(weight_value, [1, 0]))
            else:
                logger.info(
                    "***************** BERT param name is 3 {}".format(name))
                param.default_input = weight_variable(value.asnumpy().shape)
    model.train(ds.get_repeat_count(),
                ds,
                callbacks=callback,
                dataset_sink_mode=False)

    # an assertion error is raised if the loss_scale value is wrong
    count = 0
    for i in range(len(callback.overflow_list)):
        if callback.overflow_list[i] == Tensor(True, mstype.bool_) and i > 0:
            count = 0
            assert callback.lossscale_list[i] == callback.lossscale_list[
                i - 1] * Tensor(0.5, mstype.float32)
        if callback.overflow_list[i] == Tensor(False, mstype.bool_):
            count = count + 1
            if count == scale_window:
                count = 0
                assert callback.lossscale_list[i] == callback.lossscale_list[
                    i - 1] * Tensor(2.0, mstype.float32)
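`ModelCallback` is not shown in these snippets; a minimal sketch of what it presumably looks like, assuming the loss-scale train cell returns a (loss, overflow, loss_scale) tuple as `cb_params.net_outputs` and using MindSpore's standard Callback hook (the real class lives elsewhere in the test code):

    from mindspore.train.callback import Callback

    class ModelCallback(Callback):
        """Assumed shape of ModelCallback: record per-step loss, overflow
        flag and loss scale from the train cell's outputs."""

        def __init__(self):
            super(ModelCallback, self).__init__()
            self.loss_list = []
            self.overflow_list = []
            self.lossscale_list = []

        def step_end(self, run_context):
            cb_params = run_context.original_args()
            loss, overflow, loss_scale = cb_params.net_outputs
            self.loss_list.append(loss.asnumpy())    # numpy scalar per step
            self.overflow_list.append(overflow)      # Tensor(bool_)
            self.lossscale_list.append(loss_scale)   # Tensor(float32)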