Example #1
# Only the constructor is shown in this excerpt. `Hparams` and `get_initializer`
# are helpers from the surrounding aispace package; the base class is assumed
# (it is not shown in the original snippet) to be tf.keras.layers.Layer.
import tensorflow as tf


class ErnieEmbedding(tf.keras.layers.Layer):

    def __init__(self, hparams: Hparams, **kwargs):
        super(ErnieEmbedding, self).__init__(**kwargs)
        self.vocab_size = hparams.vocab_size
        self.hidden_size = hparams.hidden_size
        self.initializer_range = hparams.initializer_range
        self.use_task_id = hparams.use_task_id

        # Learned absolute position embeddings.
        self.position_embeddings = tf.keras.layers.Embedding(
            hparams.max_position_embeddings,
            hparams.hidden_size,
            embeddings_initializer=get_initializer(self.initializer_range),
            name="position_embeddings")

        # Token-type (segment) embeddings; the config key differs across ERNIE
        # releases, hence the fallback to "sent_type_vocab_size".
        self.token_type_embeddings = tf.keras.layers.Embedding(
            hparams.get("type_vocab_size",
                        hparams.get("sent_type_vocab_size")),
            hparams.hidden_size,
            embeddings_initializer=get_initializer(self.initializer_range),
            name="token_type_embeddings")

        # Optional task-type embeddings, enabled via use_task_id.
        if self.use_task_id:
            self.task_embeddings = tf.keras.layers.Embedding(
                hparams.task_type_vocab_size,
                hparams.hidden_size,
                embeddings_initializer=get_initializer(self.initializer_range),
                name="task_type_embeddings")

        # Normalization and dropout for the embedding output.
        self.layer_norm = tf.keras.layers.LayerNormalization(
            epsilon=hparams.layer_norm_eps, name="LayerNorm")
        self.dropout = tf.keras.layers.Dropout(hparams.hidden_dropout_prob)
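
A brief, hypothetical usage sketch; it assumes the surrounding aispace context, i.e. that `Hparams`, `get_initializer`, and a populated `hparams` config are available, none of which are defined in this excerpt:

# Hypothetical usage: instantiate the layer from a populated Hparams object.
embedding_layer = ErnieEmbedding(hparams, name="embeddings")
# Sub-layers are created eagerly in __init__, so e.g. the position table can be
# queried directly; the result has shape (1, 8, hidden_size).
pos_emb = embedding_layer.position_embeddings(tf.range(8)[tf.newaxis, :])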
Example #2
import tensorflow as tf

# `Hparams`, `get_dataset_split`, `build_dataset`, and `logger` are provided by
# the surrounding aispace package; their import paths are not part of this excerpt.


def load_dataset(hparams: Hparams,
                 ret_train=True,
                 ret_dev=True,
                 ret_test=True,
                 ret_info=True):
    # Not referenced below; presumably imported for its dataset-registration side effects.
    from aispace import datasets

    train_split, validation_split, test_split = get_dataset_split(hparams)
    dataset_info = None  # ensure defined even when ret_train is False
    if ret_train:
        train_datasets, dataset_info = build_dataset(hparams,
                                                     train_split,
                                                     with_info=True)
    if ret_dev:
        dev_datasets, dev_dataset_info = build_dataset(hparams,
                                                       validation_split,
                                                       with_info=True)
        if dev_dataset_info is not None:
            dataset_info = dev_dataset_info
    if ret_test:
        test_datasets, test_dataset_info = build_dataset(hparams,
                                                         test_split,
                                                         with_info=True)
        if test_dataset_info is not None:
            dataset_info = test_dataset_info

    # Check that the tokenizer used to build the dataset matches the one configured now.
    if hparams.get("dataset", {}).get("tokenizer", {}).get("name", "") != "":
        if dataset_info is None or dataset_info.metadata is None:
            logger.warning("dataset_info is missing or has no metadata; "
                           "skipping the tokenizer consistency check.")
        elif hparams.get("dataset", {}).get("tokenizer", {}).get("name", "") \
                != dataset_info.metadata.get("tokenizer", ""):
            raise ValueError(
                f'The dataset was built with tokenizer {dataset_info.metadata.get("tokenizer", "")}, '
                f'but the current configuration uses '
                f'{hparams.get("dataset", {}).get("tokenizer", {}).get("name", "")}; '
                f'please remove/rebuild the data and restart!')
        elif hparams.get("pretrained", {}).get("config", {}).get("vocab_size", 0) \
                != dataset_info.metadata.get("vocab_size", 0):
            raise ValueError(
                f'The dataset was built with tokenizer {dataset_info.metadata.get("tokenizer", "")}, '
                f'whose vocab size is {dataset_info.metadata.get("vocab_size", "xx")}, '
                f'but the current pretrained config has vocab size '
                f'{hparams.get("pretrained", {}).get("config", {}).get("vocab_size", 0)}; '
                f'please remove/rebuild the data and restart!')

    # data mapping
    def build_generator(fields):
        input_names = [itm.get('name') for itm in hparams.dataset.inputs]
        output_names = [itm.get('name') for itm in hparams.dataset.outputs]
        output_name2column = {
            itm.get('name'): itm.get('column')
            for itm in hparams.dataset.outputs
        }
        inputs, outputs = {}, {}
        for k, v in fields.items():
            if k in input_names:
                inputs[k] = v
            elif k in output_names:
                # Output fields are also exposed in the inputs dict under their
                # configured column name, in addition to appearing in outputs.
                inputs[output_name2column.get(k, k)] = v
                outputs[k] = v
            else:
                raise ValueError(f"{k} is neither a configured input nor output.")
        return inputs, outputs

    training_hparams = hparams.training
    # reset some hparams
    if ret_info:
        print(dataset_info)
        # train_data_size = dataset_info.splits.get("train").num_examples
        # validation_data_size = dataset_info.splits.get("validation").num_examples
        # test_data_size = dataset_info.splits.get("test").num_examples
        # steps_per_epoch = int(train_data_size / training_hparams.batch_size)
        # num_warmup_steps = \
        #     int(
        #         training_hparams.max_epochs * train_data_size * training_hparams.warmup_factor / training_hparams.batch_size)
        # num_warmup_steps = min(steps_per_epoch, num_warmup_steps)

        # if validation_data_size is not None:
        #     validation_steps = validation_data_size // training_hparams.batch_size
        # else:
        #     validation_steps = None
        #
        # if test_data_size is not None:
        #     test_steps = test_data_size // training_hparams.batch_size
        # else:
        #     test_steps = None

    for i in range(len(train_split)):
        # build batch
        if ret_train:
            if train_datasets is not None and train_datasets[i] is not None:
                # get train_steps and reset training hparams
                logger.info(
                    "Reset training hparams according to real training data info."
                )
                # Count examples by iterating the unbatched dataset, then derive
                # the number of batches per epoch.
                steps_per_epoch = 0
                for _ in train_datasets[i]:
                    steps_per_epoch += 1
                steps_per_epoch //= training_hparams.batch_size
                num_warmup_steps = \
                    int(training_hparams.max_epochs * steps_per_epoch * training_hparams.warmup_factor)
                if "num_warmup_steps" not in training_hparams or training_hparams.num_warmup_steps <= 0:
                    hparams.cascade_set('training.num_warmup_steps',
                                        num_warmup_steps)
                    logger.info(
                        f"Set training.num_warmup_steps to {num_warmup_steps}")
                else:
                    logger.info(
                        f"Get training.num_warmup_steps is {hparams.training.num_warmup_steps}"
                    )
                if "steps_per_epoch" not in training_hparams or training_hparams.steps_per_epoch <= 0:
                    hparams.cascade_set('training.steps_per_epoch',
                                        steps_per_epoch)
                    logger.info(
                        f"Set training.steps_per_epoch to {steps_per_epoch}")
                else:
                    logger.info(
                        f"Get training.steps_per_epoch is {hparams.training.steps_per_epoch}"
                    )

                # prepare train dataset
                train_dataset = train_datasets[i].\
                    map(build_generator, num_parallel_calls=tf.data.experimental.AUTOTUNE). \
                    prefetch(buffer_size=tf.data.experimental.AUTOTUNE). \
                    shuffle(hparams.training.shuffle_size).\
                    repeat(). \
                    batch(hparams.training.batch_size)

                logger.info("Train dataset has loaded.")
            else:
                train_dataset = None
                logger.info("Train dateset get None.")
        if ret_dev:
            if dev_datasets is not None and dev_datasets[i] is not None:
                logger.info(
                    "Reset validation hparams according to real validation data info."
                )
                validation_steps = 0
                for _ in dev_datasets[i]:
                    validation_steps += 1
                validation_steps //= training_hparams.batch_size
                if "validation_steps" not in training_hparams or training_hparams.validation_steps <= 0:
                    hparams.cascade_set('training.validation_steps',
                                        validation_steps)
                    logger.info(
                        f"Set training.validation_steps to {validation_steps}")
                else:
                    logger.info(
                        f"Get training.validation_steps is {hparams.training.validation_steps}"
                    )

                dev_dataset = dev_datasets[i].\
                    map(build_generator, num_parallel_calls=tf.data.experimental.AUTOTUNE). \
                    prefetch(buffer_size=tf.data.experimental.AUTOTUNE). \
                    repeat(). \
                    batch(hparams.training.batch_size)

                logger.info("Validation dataset has loaded.")
            else:
                dev_dataset = None
                logger.info("Validation dataset get None.")
        if ret_test:
            if test_datasets is not None and test_datasets[i] is not None:
                logger.info(
                    "Reset test hparams according to real test data info.")
                test_steps = 0
                for _ in test_datasets[i]:
                    test_steps += 1
                test_steps //= training_hparams.batch_size
                if "test_steps" not in training_hparams or training_hparams.test_steps <= 0:
                    hparams.cascade_set('training.test_steps', test_steps)
                    logger.info(f"Set training.test_steps to {test_steps}")
                else:
                    logger.info(
                        f"Get training.test_steps is {hparams.training.test_steps}"
                    )

                test_dataset = test_datasets[i]. \
                    map(build_generator, num_parallel_calls=tf.data.experimental.AUTOTUNE). \
                    prefetch(buffer_size=tf.data.experimental.AUTOTUNE). \
                    batch(hparams.training.batch_size)

                logger.info("Test dataset has loaded.")
            else:
                test_dataset = None
                logger.info("Test dataset get None.")

        result = ()
        if ret_train:
            result += (train_dataset, )
        if ret_dev:
            result += (dev_dataset, )
        if ret_test:
            result += (test_dataset, )

        if ret_info:
            result += (dataset_info, )

        yield result
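
A brief, hypothetical usage sketch of the generator above. `hparams` is assumed to be a fully populated aispace `Hparams` and `model` a compiled tf.keras.Model built elsewhere; neither is part of this excerpt.

# Hypothetical usage: with ret_train/ret_dev/ret_test/ret_info all True, each
# yielded tuple is (train_dataset, dev_dataset, test_dataset, dataset_info).
for train_ds, dev_ds, test_ds, info in load_dataset(hparams):
    model.fit(train_ds,
              epochs=hparams.training.max_epochs,
              steps_per_epoch=hparams.training.steps_per_epoch,
              validation_data=dev_ds,
              validation_steps=hparams.training.validation_steps)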