Example #1
def _get_datasets(nn_model, is_reverse_model):
    train = load_conditioned_train_set(nn_model.token_to_index,
                                       nn_model.condition_to_index)
    context_free_val = load_context_free_val(nn_model.token_to_index)

    context_sensitive_val = load_context_sensitive_val(
        nn_model.token_to_index, nn_model.condition_to_index)
    if is_reverse_model:
        service_tokens = ServiceTokensIDs(nn_model.token_to_index)
        train = reverse_nn_input(train, service_tokens)
        context_free_val = reverse_nn_input(context_free_val, service_tokens)
        context_sensitive_val = reverse_nn_input(context_sensitive_val,
                                                 service_tokens)

    # Train subset of the same size as the context-free val, used for metrics calculation
    train_subset = generate_subset(train, VAL_SUBSET_SIZE)

    # Context-sensitive val subset of the same size as the context-free val, used for metrics calculation
    context_sensitive_val_subset = generate_subset(context_sensitive_val,
                                                   VAL_SUBSET_SIZE)

    datasets_collection = DatasetsCollection(
        train=train,
        train_subset=train_subset,
        context_free_val=context_free_val,
        context_sensitive_val=context_sensitive_val,
        context_sensitive_val_subset=context_sensitive_val_subset)

    return datasets_collection
def get_validation_dataset_name_to_data(validation_sets_names, token_to_index,
                                        condition_to_index, is_reverse_model):
    _logger.info('Loading validation sets...')
    factory = {
        CONTEXT_FREE_VAL_CORPUS_NAME:
        lambda: load_context_free_val(token_to_index),
        CONTEXT_SENSITIVE_VAL_CORPUS_NAME:
        lambda: load_context_sensitive_val(token_to_index, condition_to_index)
    }
    dataset_name_to_data = {
        val_set_name: factory[val_set_name]()
        for val_set_name in validation_sets_names
    }
    _logger.info('Done loading validation sets')

    if is_reverse_model:
        _logger.info('Reversing validation sets...')
        service_tokens = ServiceTokensIDs(token_to_index)
        dataset_name_to_data = {
            val_set_name: reverse_nn_input(val_set, service_tokens)
            for val_set_name, val_set in dataset_name_to_data.items()
        }
        _logger.info('Done reversing validation sets')

    return dataset_name_to_data
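
A hypothetical usage sketch for get_validation_dataset_name_to_data, assuming an nn_model object exposing token_to_index and condition_to_index as in the other examples; because the factory dict hides each loader behind a lambda, only the requested sets are actually loaded:

# Hypothetical usage: load both validation sets for a reverse model.
val_sets = get_validation_dataset_name_to_data(
    validation_sets_names=[CONTEXT_FREE_VAL_CORPUS_NAME,
                           CONTEXT_SENSITIVE_VAL_CORPUS_NAME],
    token_to_index=nn_model.token_to_index,
    condition_to_index=nn_model.condition_to_index,
    is_reverse_model=True)
context_free_val = val_sets[CONTEXT_FREE_VAL_CORPUS_NAME]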
Example #3
def _compute_likelihood_of_input_given_output(self, context, candidates,
                                              condition_id):
    # Repeat to get same context for each candidate
    repeated_context = np.repeat(context, candidates.shape[0], axis=0)
    reversed_dataset = reverse_nn_input(
        Dataset(x=repeated_context, y=candidates, condition_ids=None),
        self._service_tokens_ids)
    return get_sequence_score(self._reverse_model, reversed_dataset.x,
                              reversed_dataset.y, condition_id)
def get_training_dataset(train_corpus_name,
                         token_to_index,
                         condition_to_index,
                         is_reverse_model,
                         train_subset_size=None):
    _logger.info('Loading training dataset...')
    train_dataset = load_conditioned_dataset(train_corpus_name, token_to_index,
                                             condition_to_index,
                                             train_subset_size)

    if is_reverse_model:
        _logger.info('Reversing training dataset...')
        service_tokens = ServiceTokensIDs(token_to_index)
        train_dataset = reverse_nn_input(train_dataset, service_tokens)

    return train_dataset
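
A hedged usage sketch for get_training_dataset; the corpus name below is a placeholder and train_subset_size only illustrates the optional truncation parameter:

# Hypothetical usage: load a truncated training set for a forward model.
train_dataset = get_training_dataset(
    train_corpus_name='train_corpus',  # placeholder corpus name
    token_to_index=nn_model.token_to_index,
    condition_to_index=nn_model.condition_to_index,
    is_reverse_model=False,
    train_subset_size=10000)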
Example #5
def train_model(nn_model, is_reverse_model=False):
    """
    Main function for training. Refactoring anticipated.
    """
    validation_prediction_mode = PREDICTION_MODES.sampling if is_reverse_model else PREDICTION_MODE_FOR_TESTS

    train = load_conditioned_train_set(nn_model.token_to_index,
                                       nn_model.condition_to_index)

    context_free_val = load_context_free_val(nn_model.token_to_index)

    context_sensitive_val = load_context_sensitive_val(
        nn_model.token_to_index, nn_model.condition_to_index)
    if is_reverse_model:
        service_tokens = ServiceTokensIDs(nn_model.token_to_index)
        train = reverse_nn_input(train, service_tokens)
        context_free_val = reverse_nn_input(context_free_val, service_tokens)
        context_sensitive_val = reverse_nn_input(context_sensitive_val,
                                                 service_tokens)

    # Train subset of the same size as the context-free val, used for metrics calculation
    train_subset = generate_subset(train, VAL_SUBSET_SIZE)

    # Context-sensitive val subset of the same size as the context-free val, used for metrics calculation
    context_sensitive_val_subset = generate_subset(context_sensitive_val,
                                                   VAL_SUBSET_SIZE)

    _logger.info('Finished preprocessing! Start training')

    batch_id = 0
    avg_loss = 0
    total_training_time = 0
    best_val_perplexities = (float('inf'), float('inf'))
    batches_num = (train.x.shape[0] - 1) / BATCH_SIZE + 1
    start_time = time.time()
    cur_val_metrics = None

    try:
        for epoches_counter in xrange(1, EPOCHES_NUM + 1):
            _logger.info(
                'Starting epoch #%d; time = %0.2f s (training: %0.2f s)'
                % (epoches_counter, time.time() - start_time,
                   total_training_time))

            for train_batch in get_training_batch(
                [train.x, train.y, train.condition_ids],
                    BATCH_SIZE,
                    random_permute=SHUFFLE_TRAINING_BATCHES):
                x_train_batch, y_train_batch, condition_ids_train_batch = train_batch

                batch_id += 1
                prev_time = time.time()
                loss = nn_model.train(x_train_batch, y_train_batch,
                                      condition_ids_train_batch)

                cur_time = time.time()
                total_training_time += cur_time - prev_time
                total_time = cur_time - start_time
                avg_loss = LOG_LOSS_DECAY * avg_loss + (
                    1 - LOG_LOSS_DECAY) * loss if batch_id > 1 else loss

                progress = 100 * float(batch_id) / batches_num
                avg_time_per_batch = total_time / batch_id
                expected_time_per_epoch = avg_time_per_batch * batches_num

                # log progress in a single line for better readability
                _logger.info('batch %s / %s (%d%%) \t'
                             'loss: %.2f \t '
                             'time: epoch %.1f h | '
                             'total %0.1f h | '
                             'train %0.1f h (%.1f%%)' %
                             (batch_id, batches_num, progress, avg_loss,
                              expected_time_per_epoch / 3600,
                              total_time / 3600, total_training_time / 3600,
                              100 * total_training_time / total_time))

                if batch_id % SCREEN_LOG_FREQUENCY_PER_BATCHES == 0:
                    _log_sample_answers(
                        context_free_val.x[:SCREEN_LOG_NUM_TEST_LINES],
                        nn_model, validation_prediction_mode, is_reverse_model)

                if batch_id % LOG_FREQUENCY_PER_BATCHES == 0:
                    _calc_and_save_train_metrics(nn_model, train_subset,
                                                 avg_loss)

                    val_metrics = _calc_and_save_val_metrics(
                        nn_model,
                        context_sensitive_val_subset,
                        context_free_val,
                        prediction_mode=validation_prediction_mode)
                    _save_val_results(
                        nn_model,
                        context_free_val.x,
                        context_sensitive_val_subset.x,
                        val_metrics,
                        train_info=(start_time, batch_id, batches_num),
                        prediction_mode=validation_prediction_mode)
                    cur_val_metrics = val_metrics

                    best_val_perplexities = \
                        _update_saved_nn_model(nn_model,
                                               (val_metrics['context_free_perplexity'],
                                                val_metrics['context_sensitive_perplexity']),
                                               best_val_perplexities,
                                               is_reverse_model=is_reverse_model)

    except (KeyboardInterrupt, SystemExit):
        _logger.info('Training cycle is stopped manually')
        _save_model(nn_model, get_model_full_path(is_reverse_model) + '_final')
        _save_val_results(nn_model,
                          context_free_val.x,
                          context_sensitive_val_subset.x,
                          cur_val_metrics,
                          train_info=(start_time, batch_id, batches_num),
                          suffix='_final',
                          prediction_mode=validation_prediction_mode)
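
The smoothed loss reported above is an exponential moving average seeded by the first batch. A minimal standalone sketch of that update, assuming LOG_LOSS_DECAY is a float in (0, 1); the 0.99 default below is illustrative, not the project's actual setting:

def update_avg_loss(avg_loss, loss, batch_id, decay=0.99):
    # The first batch seeds the average; later batches blend in with weight (1 - decay).
    if batch_id <= 1:
        return loss
    return decay * avg_loss + (1 - decay) * loss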
Example #6
def train_model(nn_model, is_reverse_model=False):
    """
    Main function for training. Refactoring anticipated.
    """
    validation_prediction_mode = PREDICTION_MODES.sampling if is_reverse_model else PREDICTION_MODE_FOR_TESTS

    train = load_conditioned_train_set(nn_model.token_to_index, nn_model.condition_to_index)

    context_free_val = load_context_free_val(nn_model.token_to_index)

    context_sensitive_val = load_context_sensitive_val(nn_model.token_to_index, nn_model.condition_to_index)
    if is_reverse_model:
        service_tokens = ServiceTokensIDs(nn_model.token_to_index)
        train = reverse_nn_input(train, service_tokens)
        context_free_val = reverse_nn_input(context_free_val, service_tokens)
        context_sensitive_val = reverse_nn_input(context_sensitive_val, service_tokens)

    # Train subset of the same size as the context-free val, used for metrics calculation
    train_subset = generate_subset(train, VAL_SUBSET_SIZE)

    # Context-sensitive val subset of the same size as the context-free val, used for metrics calculation
    context_sensitive_val_subset = generate_subset(context_sensitive_val, VAL_SUBSET_SIZE)

    _logger.info('Finished preprocessing! Start training')

    batch_id = 0
    avg_loss = 0
    total_training_time = 0
    best_val_perplexities = (float('inf'), float('inf'))
    batches_num = (train.x.shape[0] - 1) / BATCH_SIZE + 1
    start_time = time.time()
    cur_val_metrics = None

    try:
        for epoches_counter in xrange(1, EPOCHES_NUM + 1):
            _logger.info('Starting epoch #%d; time = %0.2f s (training: %0.2f s)' %
                         (epoches_counter, time.time() - start_time, total_training_time))

            for train_batch in get_training_batch(
                [train.x, train.y, train.condition_ids], BATCH_SIZE, random_permute=SHUFFLE_TRAINING_BATCHES):
                x_train_batch, y_train_batch, condition_ids_train_batch = train_batch

                batch_id += 1
                prev_time = time.time()
                loss = nn_model.train(x_train_batch, y_train_batch, condition_ids_train_batch)

                cur_time = time.time()
                total_training_time += cur_time - prev_time
                total_time = cur_time - start_time
                avg_loss = LOG_LOSS_DECAY * avg_loss + (1 - LOG_LOSS_DECAY) * loss if batch_id > 1 else loss

                progress = 100 * float(batch_id) / batches_num
                avg_time_per_batch = total_time / batch_id
                expected_time_per_epoch = avg_time_per_batch * batches_num

                # log progress in a single line for better readability
                _logger.info('batch %s / %s (%d%%) \t'
                             'loss: %.2f \t '
                             'time: epoch %.1f h | '
                             'total %0.1f h | '
                             'train %0.1f h (%.1f%%)' %
                             (batch_id, batches_num, progress, avg_loss, expected_time_per_epoch / 3600,
                              total_time / 3600, total_training_time / 3600, 100 * total_training_time / total_time))

                if batch_id % SCREEN_LOG_FREQUENCY_PER_BATCHES == 0:
                    _log_sample_answers(context_free_val.x[:SCREEN_LOG_NUM_TEST_LINES], nn_model,
                                        validation_prediction_mode, is_reverse_model)

                if batch_id % LOG_FREQUENCY_PER_BATCHES == 0:
                    _calc_and_save_train_metrics(nn_model, train_subset, avg_loss)

                    val_metrics = _calc_and_save_val_metrics(
                        nn_model,
                        context_sensitive_val_subset,
                        context_free_val,
                        prediction_mode=validation_prediction_mode)
                    _save_val_results(
                        nn_model,
                        context_free_val.x,
                        context_sensitive_val_subset.x,
                        val_metrics,
                        train_info=(start_time, batch_id, batches_num),
                        prediction_mode=validation_prediction_mode)
                    cur_val_metrics = val_metrics

                    best_val_perplexities = \
                        _update_saved_nn_model(nn_model,
                                               (val_metrics['context_free_perplexity'],
                                                val_metrics['context_sensitive_perplexity']),
                                               best_val_perplexities,
                                               is_reverse_model=is_reverse_model)

    except (KeyboardInterrupt, SystemExit):
        _logger.info('Training cycle is stopped manually')
        _save_model(nn_model, get_model_full_path(is_reverse_model) + '_final')
        _save_val_results(
            nn_model,
            context_free_val.x,
            context_sensitive_val_subset.x,
            cur_val_metrics,
            train_info=(start_time, batch_id, batches_num),
            suffix='_final',
            prediction_mode=validation_prediction_mode)
Example #7
def _compute_likelihood_of_input_given_output(self, context, candidates, condition_id):
    # Repeat to get same context for each candidate
    repeated_context = np.repeat(context, candidates.shape[0], axis=0)
    reversed_dataset = reverse_nn_input(
        Dataset(x=repeated_context, y=candidates, condition_ids=None), self._service_tokens_ids)
    return get_sequence_score(self._reverse_model, reversed_dataset.x, reversed_dataset.y, condition_id)
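
The np.repeat call pairs the single encoded context with every candidate response so the reverse model can score the likelihood of the input given each output. A minimal shape-only sketch of that step, using toy arrays and no model:

import numpy as np

context = np.array([[1, 2, 3]])          # one encoded context, shape (1, 3)
candidates = np.array([[4, 5, 6],
                       [7, 8, 9]])       # two candidate responses
repeated_context = np.repeat(context, candidates.shape[0], axis=0)
assert repeated_context.shape == (2, 3)  # same context row for each candidate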