コード例 #1
0
    async def test_gen_batch(self, trained_policy: TEDPolicy,
                             default_domain: Domain, stories_path: Path):
        """Check the shapes of batches produced by `RasaBatchDataGenerator`.

        The trained policy featurizes the training stories and builds its
        model data. The first batch is then drawn from two fresh generators:
        an unshuffled one using the "sequence" batch strategy, and a shuffled
        one using the "balanced" batch strategy. For each batch the test
        asserts that:

        - the label ids and dialogue lengths have `batch_size` rows;
        - every sentence feature has either the longest sequence length in
          the batch or a sequence length of `0`.

        Mask shapes are only checked for the "sequence" batch.
        """
        training_trackers = await tests.core.test_policies.train_trackers(
            default_domain, stories_path, augmentation_factor=0)
        interpreter = RegexInterpreter()
        training_data, label_ids, entity_tags = trained_policy._featurize_for_training(
            training_trackers, default_domain, interpreter)
        # only the list of all labels is needed to build the model data
        _, all_labels = trained_policy._create_label_data(
            default_domain, interpreter)
        model_data = trained_policy._create_model_data(training_data,
                                                       label_ids, entity_tags,
                                                       all_labels)
        batch_size = 2

        def check_first_batch(shuffle: bool, batch_strategy: str,
                              check_masks: bool) -> None:
            """Draw the first batch from a fresh generator and check shapes."""
            data_generator = RasaBatchDataGenerator(
                model_data,
                batch_size=batch_size,
                shuffle=shuffle,
                batch_strategy=batch_strategy)
            # model data keys were sorted, so the order is alphabetical;
            # the strict unpacking also fails if the number of entries changes
            (
                (
                    batch_action_name_mask,
                    batch_action_name_sentence_indices,
                    batch_action_name_sentence_data,
                    batch_action_name_sentence_shape,
                    batch_dialogue_length,
                    batch_entities_mask,
                    batch_entities_sentence_indices,
                    batch_entities_sentence_data,
                    batch_entities_sentence_shape,
                    batch_intent_mask,
                    batch_intent_sentence_indices,
                    batch_intent_sentence_data,
                    batch_intent_sentence_shape,
                    batch_label_ids,
                    batch_slots_mask,
                    batch_slots_sentence_indices,
                    batch_slots_sentence_data,
                    batch_slots_sentence_shape,
                ),
                _,
            ) = next(iter(data_generator))

            assert (batch_label_ids.shape[0] == batch_size
                    and batch_dialogue_length.shape[0] == batch_size)
            if check_masks:
                # batch and dialogue dimensions are NOT combined for masks
                assert (batch_slots_mask.shape[0] == batch_size
                        and batch_intent_mask.shape[0] == batch_size
                        and batch_entities_mask.shape[0] == batch_size
                        and batch_action_name_mask.shape[0] == batch_size)
            # some features might be "fake", so their sequence length is `0`;
            # all other features must share the longest sequence length
            sentence_shapes = [
                batch_intent_sentence_shape,
                batch_action_name_sentence_shape,
                batch_entities_sentence_shape,
                batch_slots_sentence_shape,
            ]
            seq_len = max(shape[1] for shape in sentence_shapes)
            for shape in sentence_shapes:
                assert shape[1] == seq_len or shape[1] == 0

        check_first_batch(shuffle=False, batch_strategy="sequence",
                          check_masks=True)
        # NOTE(review): masks are not checked for the balanced strategy
        # (same as before this refactor); consider enabling if intended.
        check_first_batch(shuffle=True, batch_strategy="balanced",
                          check_masks=False)
コード例 #2
0
ファイル: test_policies.py プロジェクト: ysinjab/rasa
    async def test_gen_batch(self, trained_policy: TEDPolicy, default_domain: Domain):
        """Check the shapes of batches produced by `RasaModelData._gen_batch`.

        The trained policy featurizes the training trackers and builds its
        model data. The first batch is then drawn twice: once with the
        default `_gen_batch` settings, and once with the "balanced" batch
        strategy and shuffling. For each batch the test asserts that:

        - all masks have `batch_size` rows;
        - all sentence features agree on the value at index 1 of their third
          entry.
        """
        training_trackers = await train_trackers(default_domain, augmentation_factor=0)
        interpreter = RegexInterpreter()
        training_data, label_ids = trained_policy.featurize_for_training(
            training_trackers, default_domain, interpreter
        )
        # only the list of all labels is needed to build the model data
        _, all_labels = trained_policy._create_label_data(
            default_domain, interpreter
        )
        model_data = trained_policy._create_model_data(
            training_data, label_ids, all_labels
        )
        batch_size = 2

        def check_batch(batch) -> None:
            """Unpack one generated batch and assert its shapes."""
            # strict unpacking: fails if the number of batch entries changes
            (
                batch_label_ids,
                batch_entities_mask,
                batch_entities_sentence_1,
                batch_entities_sentence_2,
                batch_entities_sentence_3,
                batch_intent_mask,
                batch_intent_sentence_1,
                batch_intent_sentence_2,
                batch_intent_sentence_3,
                batch_slots_mask,
                batch_slots_sentence_1,
                batch_slots_sentence_2,
                batch_slots_sentence_3,
                batch_action_name_mask,
                batch_action_name_sentence_1,
                batch_action_name_sentence_2,
                batch_action_name_sentence_3,
                batch_dialogue_length,
            ) = batch

            assert (
                batch_intent_mask.shape[0] == batch_size
                and batch_action_name_mask.shape[0] == batch_size
                and batch_entities_mask.shape[0] == batch_size
                and batch_slots_mask.shape[0] == batch_size
            )
            # presumably `*_sentence_3` is the dense shape of a sparse feature,
            # so index 1 is the sequence length -- TODO confirm
            assert (
                batch_intent_sentence_3[1]
                == batch_action_name_sentence_3[1]
                == batch_entities_sentence_3[1]
                == batch_slots_sentence_3[1]
            )

        check_batch(next(model_data._gen_batch(batch_size=batch_size)))
        check_batch(
            next(
                model_data._gen_batch(
                    batch_size=batch_size, batch_strategy="balanced", shuffle=True
                )
            )
        )