Example #1
0
def train(init_path=None, is_reverse_model=False):
    """Build the dialog model for the training corpus and train it.

    Resolves the processed train/validation corpus paths and the
    token/condition index paths, passes them to ``_look_for_saved_files``,
    loads both indexes, computes the w2v embedding matrix for the train
    corpus and hands the model returned by ``get_nn_model`` to
    ``train_model``.

    Args:
        init_path: forwarded to ``get_nn_model`` as ``model_init_path``.
        is_reverse_model: forwarded to ``get_nn_model`` unchanged.
    """
    processed_train_corpus_path = get_processed_corpus_path(TRAIN_CORPUS_NAME)
    processed_val_corpus_path = get_processed_corpus_path(
        CONTEXT_SENSITIVE_VAL_CORPUS_NAME)
    index_to_token_path = get_index_to_token_path(BASE_CORPUS_NAME)
    index_to_condition_path = get_index_to_condition_path(BASE_CORPUS_NAME)

    # Check the existence of all necessary files before compiling the model.
    # The condition index is loaded unconditionally right below, so it has to
    # be part of this pre-flight check as well; otherwise a missing file is
    # only discovered later by the loader.
    _look_for_saved_files(files_paths=[
        processed_train_corpus_path, processed_val_corpus_path,
        index_to_token_path, index_to_condition_path
    ])

    index_to_token = load_index_to_item(index_to_token_path)
    index_to_condition = load_index_to_item(index_to_condition_path)

    w2v_matrix = _get_w2v_embedding_matrix_by_corpus_path(
        processed_train_corpus_path, index_to_token)

    # get nn_model and train it
    nn_model_resolver_factory = S3FileResolver.init_resolver(
        bucket_name=S3_MODELS_BUCKET_NAME, remote_dir=S3_NN_MODEL_REMOTE_DIR)

    # The second element of the returned pair is not needed for training.
    nn_model, _ = get_nn_model(index_to_token,
                               index_to_condition,
                               model_init_path=init_path,
                               w2v_matrix=w2v_matrix,
                               resolver_factory=nn_model_resolver_factory,
                               is_reverse_model=is_reverse_model)

    train_model(nn_model)
Example #2
0
def get_w2v_model(fetch_from_s3=False,
                  corpus_name=TRAIN_CORPUS_NAME,
                  voc_size=VOCABULARY_MAX_SIZE,
                  vec_size=WORD_EMBEDDING_DIMENSION,
                  window_size=W2V_WINDOW_SIZE,
                  skip_gram=USE_SKIP_GRAM):
    """Return the word2vec model obtained from ``_get_w2v_model``.

    When ``fetch_from_s3`` is truthy an S3 resolver factory (models bucket,
    w2v remote dir) is passed to ``_get_w2v_model``; otherwise ``None`` is
    passed. The remaining arguments are forwarded unchanged.

    Raises:
        ModelLoaderException: re-raised with a user-facing hint whenever the
            resolver setup or ``_get_w2v_model`` raises ``ModelLoaderException``.
    """
    try:
        if fetch_from_s3:
            resolver_factory = S3FileResolver.init_resolver(
                bucket_name=S3_MODELS_BUCKET_NAME, remote_dir=S3_W2V_REMOTE_DIR)
        else:
            resolver_factory = None

        return _get_w2v_model(
            corpus_name=corpus_name,
            voc_size=voc_size,
            model_resolver_factory=resolver_factory,
            vec_size=vec_size,
            window_size=window_size,
            skip_gram=skip_gram)
    except ModelLoaderException:
        # Replace the loader's message with instructions on how to get a model.
        hint = ('Word2vec model is absent. Please run `tools/train_w2v.py` to get the model.'
                ' WARNING: If you compare different dialog models, be sure that they'
                ' use the same w2v model (since each run of the w2v-trainer even with the same'
                ' parameters leads to different w2v models)')
        raise ModelLoaderException(hint)
Example #3
0
def get_w2v_embedding_matrix(tokenized_dialog_lines, index_to_token, add_start_end=False):
    """Return the w2v matrix built from the model for the given tokenized lines.

    Args:
        tokenized_dialog_lines: iterable of token lists handed to
            ``get_w2v_model`` as ``tokenized_lines``.
        index_to_token: its length is passed to ``get_w2v_model`` and the
            object itself to ``transform_w2v_model_to_matrix``.
        add_start_end: when truthy, every line is lazily wrapped with the
            start and EOS special tokens before being passed on.
    """
    lines = tokenized_dialog_lines
    if add_start_end:
        # Generator expression: lines are wrapped on demand, not up front.
        lines = ([SPECIAL_TOKENS.START_TOKEN] + tokens + [SPECIAL_TOKENS.EOS_TOKEN]
                 for tokens in tokenized_dialog_lines)

    resolver_factory = S3FileResolver.init_resolver(
        bucket_name=S3_MODELS_BUCKET_NAME, remote_dir=S3_W2V_REMOTE_DIR)

    model = get_w2v_model(
        TRAIN_CORPUS_NAME,
        len(index_to_token),
        model_resolver_factory=resolver_factory,
        tokenized_lines=lines,
        vec_size=WORD_EMBEDDING_DIMENSION,
        window_size=W2V_WINDOW_SIZE,
        skip_gram=USE_SKIP_GRAM)

    return transform_w2v_model_to_matrix(model, index_to_token)
Example #4
0
def _get_index_to_token(fetch_from_s3):
    """Load the index-to-token mapping for the base corpus.

    With ``fetch_from_s3`` truthy the file is resolved through an
    ``S3FileResolver``; otherwise it only has to exist on the local disk.

    Raises:
        Exception: if the resolver reports failure or the local file is absent.
    """
    path = get_index_to_token_path(BASE_CORPUS_NAME)

    if fetch_from_s3:
        resolver = S3FileResolver(path, S3_MODELS_BUCKET_NAME, S3_TOKENS_IDX_REMOTE_DIR)
        if resolver.resolve():
            return load_index_to_item(path)
        raise Exception(
            'Can\'t get index_to_token because file does not exist at S3')

    if os.path.exists(path):
        return load_index_to_item(path)
    raise Exception(
        'Can\'t get index_to_token because file does not exist. '
        'Run tools/download_model.py first to get all required files or construct it by yourself.'
    )
Example #5
0
def _get_index_to_condition(fetch_from_s3):
    """Load the index-to-condition mapping for the base corpus.

    The file is either resolved via ``S3FileResolver`` (``fetch_from_s3``
    truthy) or expected to be present locally.

    Raises:
        FileNotFoundException: if the file cannot be resolved or does not exist.
    """
    condition_idx_path = get_index_to_condition_path(BASE_CORPUS_NAME)

    if fetch_from_s3:
        resolver = S3FileResolver(condition_idx_path,
                                  S3_MODELS_BUCKET_NAME,
                                  S3_CONDITIONS_IDX_REMOTE_DIR)
        resolved = resolver.resolve()
        if not resolved:
            raise FileNotFoundException(
                'Can\'t get index_to_condition because file does not exist on S3'
            )
    elif not os.path.exists(condition_idx_path):
        raise FileNotFoundException(
            'Can\'t get index_to_condition because file does not exist. '
            'Run tools/fetch.py first to get all required files or construct '
            'it yourself.')

    return load_index_to_item(condition_idx_path)
Example #6
0
def _get_index_to_token(fetch_from_s3):
    """Load the index-to-token mapping for the base corpus.

    Args:
        fetch_from_s3: when truthy, the file is resolved through an
            ``S3FileResolver``; otherwise only its local existence is checked.

    Returns:
        Whatever ``load_index_to_item`` returns for the index file.

    Raises:
        FileNotFoundException: if the resolver reports failure or the local
            file does not exist.
    """
    index_to_token_path = get_index_to_token_path(BASE_CORPUS_NAME)
    # Only the base name is shown to the user in error messages.
    file_name = os.path.basename(index_to_token_path)
    if fetch_from_s3:
        tokens_idx_resolver = S3FileResolver(index_to_token_path,
                                             S3_MODELS_BUCKET_NAME,
                                             S3_TOKENS_IDX_REMOTE_DIR)
        if not tokens_idx_resolver.resolve():
            raise FileNotFoundException(
                'No such file on S3: {}'.format(file_name))
    elif not os.path.exists(index_to_token_path):
        # The file name and the hint are separated explicitly; previously they
        # were glued together as "<file_name>Run ...".
        raise FileNotFoundException(
            'No such file: {}. '
            'Run "python tools/fetch.py" first to get all necessary files.'
            .format(file_name))

    return load_index_to_item(index_to_token_path)
Example #7
0
def get_trained_model(reverse=False, fetch_from_s3=True):
    """Return the model produced by ``get_nn_model`` for the base-corpus indexes.

    Args:
        reverse: forwarded to ``get_nn_model`` as ``is_reverse_model``.
        fetch_from_s3: when truthy, an S3 resolver factory is supplied and the
            index files are fetched from S3 as well; otherwise no resolver
            factory is used.

    Raises:
        Exception: if ``get_nn_model`` reports that the model does not exist.
    """
    resolver_factory = S3FileResolver.init_resolver(
        bucket_name=S3_MODELS_BUCKET_NAME,
        remote_dir=S3_NN_MODEL_REMOTE_DIR) if fetch_from_s3 else None

    # Indexes are loaded after the resolver is set up, tokens first.
    index_to_token = _get_index_to_token(fetch_from_s3)
    index_to_condition = _get_index_to_condition(fetch_from_s3)

    nn_model, model_exists = get_nn_model(
        index_to_token,
        index_to_condition,
        resolver_factory=resolver_factory,
        is_reverse_model=reverse)

    if model_exists:
        return nn_model

    raise Exception('Can\'t get the model. '
                    'Run tools/download_model.py first to get all required files or train it by yourself.')
Example #8
0
def get_trained_model(reverse=False, fetch_from_s3=True):
    """Fetch the pre-trained model via ``get_nn_model``.

    Args:
        reverse: passed to ``get_nn_model`` as ``is_reverse_model``.
        fetch_from_s3: if truthy, an S3 resolver factory for the NN model
            directory is created and the index files are resolved from S3;
            if falsy, ``None`` is used as the resolver factory.

    Raises:
        FileNotFoundException: if ``get_nn_model`` reports that no model exists.
    """
    factory = None
    if fetch_from_s3:
        factory = S3FileResolver.init_resolver(
            bucket_name=S3_MODELS_BUCKET_NAME,
            remote_dir=S3_NN_MODEL_REMOTE_DIR)

    token_index = _get_index_to_token(fetch_from_s3)
    condition_index = _get_index_to_condition(fetch_from_s3)

    model, found = get_nn_model(
        index_to_token=token_index,
        index_to_condition=condition_index,
        resolver_factory=factory,
        is_reverse_model=reverse)

    if not found:
        message = ('Can\'t get the pre-trained model. Run tools/download_model.py first '
                   'to get all required files or train it by yourself.')
        raise FileNotFoundException(message)

    return model
Example #9
0
def train(model_init_path=None,
          is_reverse_model=False,
          train_subset_size=None,
          use_pretrained_w2v=USE_PRETRAINED_W2V_EMBEDDINGS_LAYER,
          train_corpus_name=TRAIN_CORPUS_NAME,
          context_sensitive_val_corpus_name=CONTEXT_SENSITIVE_VAL_CORPUS_NAME,
          base_corpus_name=BASE_CORPUS_NAME,
          s3_models_bucket_name=S3_MODELS_BUCKET_NAME,
          s3_nn_model_remote_dir=S3_NN_MODEL_REMOTE_DIR,
          prediction_mode_for_tests=PREDICTION_MODE_FOR_TESTS):
    """Assemble a ``CakeChatModel`` and run its training.

    Args:
        model_init_path: forwarded to ``CakeChatModel`` unchanged.
        is_reverse_model: forwarded to the dataset builders and to
            ``CakeChatModel``; when falsy, a reverse model is additionally
            obtained via ``get_reverse_model`` and passed along.
        train_subset_size: forwarded to ``get_training_dataset``.
        use_pretrained_w2v: when truthy, the w2v model and its id are passed
            to ``CakeChatModel``; otherwise both are ``None``.
        train_corpus_name: corpus used for the processed train corpus path,
            the training dataset and its id.
        context_sensitive_val_corpus_name: corpus used for the processed
            validation corpus path.
        base_corpus_name: corpus used for the token and condition index paths.
        s3_models_bucket_name: bucket name for the model resolver.
        s3_nn_model_remote_dir: remote dir for the model resolver.
        prediction_mode_for_tests: passed to ``get_reverse_model``.
    """
    processed_train_corpus_path = get_processed_corpus_path(train_corpus_name)
    processed_val_corpus_path = get_processed_corpus_path(
        context_sensitive_val_corpus_name)
    index_to_token_path = get_index_to_token_path(base_corpus_name)
    index_to_condition_path = get_index_to_condition_path(base_corpus_name)

    # Check the existence of all necessary files before compiling the model.
    # The condition index is loaded unconditionally right below, so it has to
    # be part of this pre-flight check as well.
    _look_for_saved_files(files_paths=[
        processed_train_corpus_path, processed_val_corpus_path,
        index_to_token_path, index_to_condition_path
    ])

    # load essentials for building model and training
    index_to_token = load_index_to_item(index_to_token_path)
    index_to_condition = load_index_to_item(index_to_condition_path)
    # Inverted lookups (item -> index) needed by the dataset builders.
    token_to_index = {v: k for k, v in index_to_token.items()}
    condition_to_index = {v: k for k, v in index_to_condition.items()}

    training_data_param = ModelParam(
        value=get_training_dataset(train_corpus_name, token_to_index,
                                   condition_to_index, is_reverse_model,
                                   train_subset_size),
        id=train_corpus_name)

    val_sets_names = get_validation_sets_names()
    validation_data_param = ModelParam(
        value=get_validation_dataset_name_to_data(val_sets_names,
                                                  token_to_index,
                                                  condition_to_index,
                                                  is_reverse_model),
        id=get_validation_data_id(val_sets_names))

    # NOTE(review): get_w2v_model() and get_w2v_model_id() are called with
    # their defaults, i.e. independently of train_corpus_name -- confirm that
    # this is intended when a non-default corpus is passed in.
    if use_pretrained_w2v:
        w2v_model_param = ModelParam(value=get_w2v_model(),
                                     id=get_w2v_model_id())
    else:
        w2v_model_param = ModelParam(value=None, id=None)

    model_resolver_factory = S3FileResolver.init_resolver(
        bucket_name=s3_models_bucket_name, remote_dir=s3_nn_model_remote_dir)

    # A reverse model is only fetched when the model being trained is not
    # itself the reverse one.
    reverse_model = get_reverse_model(
        prediction_mode_for_tests) if not is_reverse_model else None

    # build CakeChatModel
    cakechat_model = CakeChatModel(index_to_token,
                                   index_to_condition,
                                   training_data_param=training_data_param,
                                   validation_data_param=validation_data_param,
                                   w2v_model_param=w2v_model_param,
                                   model_init_path=model_init_path,
                                   model_resolver=model_resolver_factory,
                                   is_reverse_model=is_reverse_model,
                                   reverse_model=reverse_model,
                                   horovod=hvd)

    # train model
    cakechat_model.train_model()