def load_context_sensitive_val(token_to_index, condition_to_index):
    """Load the context-sensitive validation corpus and convert it to a Dataset of index tensors."""
    processed_val_corpus_path = get_processed_corpus_path(
        CONTEXT_SENSITIVE_VAL_CORPUS_NAME)
    context_sensitive_val_dialogs = load_processed_dialogs_from_json(
        FileTextLinesIterator(processed_val_corpus_path),
        text_field_name='text',
        condition_field_name='condition')
    alternated_context_sensitive_val_dialogs = \
        get_alternated_dialogs_lines(context_sensitive_val_dialogs)
    alternated_context_sensitive_val_lines, alternated_context_sensitive_val_conditions = \
        get_dialog_lines_and_conditions(alternated_context_sensitive_val_dialogs,
                                        text_field_name='text', condition_field_name='condition')
    tokenized_alternated_context_sensitive_val_lines = ProcessedLinesIterator(
        alternated_context_sensitive_val_lines,
        processing_callbacks=[get_tokens_sequence])

    _logger.info(
        'Transforming context-sensitive validation lines to a tensor of indices')
    x_context_sensitive_val, y_context_sensitive_val, num_context_sensitive_val_dialogs = \
        transform_lines_to_nn_input(tokenized_alternated_context_sensitive_val_lines, token_to_index)
    condition_ids_context_sensitive_val = transform_conditions_to_nn_input(
        alternated_context_sensitive_val_conditions, condition_to_index,
        num_context_sensitive_val_dialogs)
    return Dataset(x=x_context_sensitive_val,
                   y=y_context_sensitive_val,
                   condition_ids=condition_ids_context_sensitive_val)
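
# --- Illustrative sketch (not cakechat's actual implementation) ---
# For intuition, this is roughly the kind of transformation that
# transform_lines_to_nn_input performs: tokenized lines are mapped through
# token_to_index into a fixed-width matrix of token ids. The padding and
# unknown-token conventions below are assumptions made for the example.
import numpy as np

def lines_to_index_matrix(tokenized_lines, token_to_index, max_len=8,
                          pad_token='_pad_', unk_token='_unk_'):
    matrix = np.full((len(tokenized_lines), max_len),
                     token_to_index[pad_token], dtype=np.int32)
    for i, tokens in enumerate(tokenized_lines):
        for j, token in enumerate(tokens[:max_len]):
            matrix[i, j] = token_to_index.get(token, token_to_index[unk_token])
    return matrix

# Usage with a toy vocabulary: 'world' is out-of-vocabulary and maps to _unk_.
toy_token_to_index = {'_pad_': 0, '_unk_': 1, 'hello': 2, 'there': 3}
print(lines_to_index_matrix([['hello', 'there'], ['hello', 'world']], toy_token_to_index))
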
Example #2
def _load_train_lines(corpus_name=TRAIN_CORPUS_NAME):
    """Return only the alternated text lines of the corpus, discarding the conditions."""
    processed_corpus_path = get_processed_corpus_path(corpus_name)
    dialogs = load_processed_dialogs_from_json(
        FileTextLinesIterator(processed_corpus_path), text_field_name='text', condition_field_name='condition')
    train_lines, _ = get_dialog_lines_and_conditions(
        get_alternated_dialogs_lines(dialogs), text_field_name='text', condition_field_name='condition')
    return train_lines
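
# --- Illustrative sketch of the lazy-pipeline pattern used above ---
# _load_train_lines chains iterators: file lines -> parsed dialogs ->
# flattened/alternated lines. Each stage below is a generator, so the corpus
# is streamed rather than loaded into memory at once. These helpers are
# hypothetical stand-ins for FileTextLinesIterator and
# load_processed_dialogs_from_json, not the cakechat implementations.
import json

def iter_file_lines(path):
    with open(path, encoding='utf-8') as fh:
        for line in fh:
            yield line.rstrip('\n')

def iter_dialogs(json_lines):
    # Assumes one JSON-encoded dialog per line, which is how the
    # processed-corpus loaders above appear to consume the file.
    for line in json_lines:
        yield json.loads(line)

def iter_dialog_lines(dialogs, text_field_name='text'):
    for dialog in dialogs:
        for entry in dialog:
            yield entry[text_field_name]

# Chaining the stages keeps everything lazy; nothing is read until iteration:
# lines = iter_dialog_lines(iter_dialogs(iter_file_lines('corpus.jsonl')))
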
def load_conditioned_dataset(corpus_name,
                             token_to_index,
                             condition_to_index,
                             subset_size=None):
    """Load a processed corpus and return a Dataset of token-index and condition-id tensors."""
    processed_corpus_path = get_processed_corpus_path(corpus_name)
    dialogs = load_processed_dialogs_from_json(
        FileTextLinesIterator(processed_corpus_path),
        text_field_name='text',
        condition_field_name='condition')
    if subset_size:
        _logger.info(
            'Slicing dataset to the first {} entries'.format(subset_size))
        dialogs = islice(dialogs, subset_size)
    train_lines, train_conditions = get_dialog_lines_and_conditions(
        get_alternated_dialogs_lines(dialogs),
        text_field_name='text',
        condition_field_name='condition')
    tokenized_alternated_train_lines = ProcessedLinesIterator(
        train_lines, processing_callbacks=[get_tokens_sequence])

    # prepare train set
    x_train, y_train, n_dialogs = transform_lines_to_nn_input(
        tokenized_alternated_train_lines, token_to_index)

    condition_ids_train = transform_conditions_to_nn_input(
        train_conditions, condition_to_index, n_dialogs)
    return Dataset(x=x_train, y=y_train, condition_ids=condition_ids_train)
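
# --- Illustrative sketch: consuming the returned Dataset ---
# The keyword arguments above show that Dataset carries x, y and
# condition_ids. A namedtuple with those fields, plus a toy batching helper;
# the helper is an illustration, not a cakechat API.
from collections import namedtuple
import numpy as np

Dataset = namedtuple('Dataset', ('x', 'y', 'condition_ids'))

def iter_batches(dataset, batch_size):
    for start in range(0, len(dataset.x), batch_size):
        end = start + batch_size
        yield Dataset(x=dataset.x[start:end],
                      y=dataset.y[start:end],
                      condition_ids=dataset.condition_ids[start:end])

toy = Dataset(x=np.zeros((10, 8), dtype=np.int32),
              y=np.zeros((10, 8), dtype=np.int32),
              condition_ids=np.zeros(10, dtype=np.int32))
for batch in iter_batches(toy, batch_size=4):
    print(batch.x.shape)  # (4, 8), (4, 8), then the remainder (2, 8)
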
Example #5
def load_conditioned_train_set(token_to_index, condition_to_index, train_subset_size=TRAIN_SUBSET_SIZE):
    """Like load_conditioned_dataset, but fixed to the train corpus with an optional subset cap."""
    processed_corpus_path = get_processed_corpus_path(TRAIN_CORPUS_NAME)
    dialogs = load_processed_dialogs_from_json(
        FileTextLinesIterator(processed_corpus_path), text_field_name='text', condition_field_name='condition')
    if train_subset_size:
        dialogs = islice(dialogs, train_subset_size)
    train_lines, train_conditions = get_dialog_lines_and_conditions(
        get_alternated_dialogs_lines(dialogs), text_field_name='text', condition_field_name='condition')
    tokenized_alternated_train_lines = ProcessedLinesIterator(train_lines, processing_callbacks=[get_tokens_sequence])

    # prepare train set
    x_train, y_train, n_dialogs = transform_lines_to_nn_input(tokenized_alternated_train_lines, token_to_index)

    condition_ids_train = transform_conditions_to_nn_input(train_conditions, condition_to_index, n_dialogs)
    return Dataset(x=x_train, y=y_train, condition_ids=condition_ids_train)
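
# --- Illustrative sketch: why islice is used for the subset ---
# itertools.islice truncates an iterator without materializing it, so only
# the first train_subset_size dialogs are ever parsed from disk. A
# self-contained demonstration with a counting generator:
from itertools import islice

def counting_dialogs(n):
    for i in range(n):
        print('parsing dialog', i)
        yield {'text': 'line %d' % i, 'condition': 'neutral'}

subset = list(islice(counting_dialogs(1000), 3))
# Prints "parsing dialog 0" through "parsing dialog 2" only; the remaining
# 997 dialogs are never touched.
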
Example #6
import os
import sys

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from cakechat.utils.text_processing import get_processed_corpus_path, load_processed_dialogs_from_json, \
    FileTextLinesIterator, get_dialog_lines_and_conditions, ProcessedLinesIterator, get_flatten_dialogs
from cakechat.utils.w2v.model import _get_w2v_model as get_w2v_model
from cakechat.config import TRAIN_CORPUS_NAME, VOCABULARY_MAX_SIZE, WORD_EMBEDDING_DIMENSION, W2V_WINDOW_SIZE, \
    USE_SKIP_GRAM

if __name__ == '__main__':
    processed_corpus_path = get_processed_corpus_path(TRAIN_CORPUS_NAME)

    dialogs = load_processed_dialogs_from_json(
        FileTextLinesIterator(processed_corpus_path),
        text_field_name='text',
        condition_field_name='condition')

    training_dialogs_lines_for_w2v, _ = get_dialog_lines_and_conditions(
        get_flatten_dialogs(dialogs),
        text_field_name='text',
        condition_field_name='condition')

    tokenized_training_lines = ProcessedLinesIterator(
        training_dialogs_lines_for_w2v, processing_callbacks=[str.split])

    get_w2v_model(tokenized_lines=tokenized_training_lines,
                  corpus_name=TRAIN_CORPUS_NAME,
                  voc_size=VOCABULARY_MAX_SIZE,
                  vec_size=WORD_EMBEDDING_DIMENSION,
                  window_size=W2V_WINDOW_SIZE,
                  skip_gram=USE_SKIP_GRAM)
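
# --- Illustrative sketch: the equivalent direct gensim call ---
# The wrapper imported above as get_w2v_model ultimately trains a word2vec
# model; with gensim >= 4.0 its keyword arguments map roughly as below. The
# mapping is an assumption for illustration, not cakechat's exact wrapper
# behavior, and min_count=1 is set only so the toy corpus survives pruning.
from gensim.models import Word2Vec

toy_corpus = [['hello', 'there'], ['general', 'kenobi']]

w2v = Word2Vec(sentences=toy_corpus,
               vector_size=WORD_EMBEDDING_DIMENSION,  # vec_size
               window=W2V_WINDOW_SIZE,                # window_size
               sg=1 if USE_SKIP_GRAM else 0,          # skip_gram
               max_final_vocab=VOCABULARY_MAX_SIZE,   # voc_size
               min_count=1)
print(w2v.wv['hello'].shape)  # (WORD_EMBEDDING_DIMENSION,)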