Example #1
def initialize_for_running(output_dir, tf_manager, variable_files) -> None:
    """Restore either default variables of from configuration.

    Arguments:
       output_dir: Training output directory.
       tf_manager: TensorFlow manager.
       variable_files: Files with variables to be restored or None if the
           default variables should be used.
    """
    # pylint: disable=no-member
    log_print("")

    if variable_files is None:
        default_varfile = default_variable_file(output_dir)

        log("Default variable file '{}' will be used for loading variables.".
            format(default_varfile))

        variable_files = [default_varfile]

    for vfile in variable_files:
        if not os.path.exists("{}.index".format(vfile)):
            log("Index file for var prefix {} does not exist".format(vfile),
                color="red")
            exit(1)

    tf_manager.restore(variable_files)

    log_print("")
Example #2
def print_final_evaluation(name: str, eval_result: Evaluation) -> None:
    """Print final evaluation from a test dataset."""
    line_len = 22
    log("Evaluating model on \"{}\"".format(name))

    for name, value in eval_result.items():
        space = "".join([" " for _ in range(line_len - len(name))])
        log("... {}:{} {:.4f}".format(name, space, value))

    log_print("")
Example #3
def print_final_evaluation(eval_result: Evaluation, name: str = None) -> None:
    """Print final evaluation from a test dataset."""
    line_len = 22

    if name is not None:
        log("Model evaluated on '{}'".format(name))

    for eval_name, value in eval_result.items():
        space = "".join([" " for _ in range(line_len - len(eval_name))])
        log("... {}:{} {:.4g}".format(eval_name, space, value))

    log_print("")
Example #4
def print_dataset_evaluation(name, evaluation):
    line_len = 22
    log("Evaluating model on \"{}\"".format(name))

    log("... optimization loss:      {:.4f}".format(evaluation['opt_loss']))
    log("... runtime loss:           {:.4f}".format(evaluation['opt_loss']))

    for func in evaluation:
        if hasattr(func, '__call__'):
            name = func.__name__
            space = "".join([" " for _ in range(line_len - len(name))])
            log("... {}:{} {:.4f}".format(name, space, evaluation[func]))

    log_print("")
Example #5
def _print_examples(dataset: Dataset,
                    outputs: Dict[str, List[Any]],
                    num_examples=15) -> None:
    """Print examples of the model output."""
    log_print(colored("Examples:", attrs=['bold']))

    # for further indexing we need to make sure, all relevant
    # dataset series are lists
    target_series = {
        series_id: list(dataset.get_series(series_id))
        for series_id in outputs.keys() if dataset.has_series(series_id)
    }
    source_series = {
        series_id: list(dataset.get_series(series_id))
        for series_id in dataset.series_ids if series_id not in outputs
    }

    for i in range(min(len(dataset), num_examples)):
        log_print(
            colored("  [{}]".format(i + 1), color='magenta', attrs=['bold']))

        def print_line(prefix, color, content):
            colored_prefix = colored(prefix, color=color)
            formatted = _data_item_to_str(content)
            log_print("  {}: {}".format(colored_prefix, formatted))

        for series_id, data in sorted(source_series.items(),
                                      key=lambda x: x[0]):
            print_line(series_id, 'yellow', data[i])

        for series_id, data in sorted(outputs.items(), key=lambda x: x[0]):
            model_output = data[i]
            print_line(series_id, 'magenta', model_output)

            if series_id in target_series:
                desired_output = target_series[series_id][i]
                print_line(series_id + " (ref)", "red", desired_output)
        log_print("")
Example #6
def training_loop(sess,
                  saver,
                  epochs,
                  trainer,
                  all_coders,
                  decoder,
                  batch_size,
                  train_dataset,
                  val_dataset,
                  log_directory,
                  evaluators,
                  runner,
                  test_datasets=[],
                  save_n_best_vars=1,
                  link_best_vars="/tmp/variables.data.best",
                  vars_prefix="/tmp/variables.data",
                  initial_variables=None,
                  logging_period=20,
                  validation_period=500,
                  postprocess=None,
                  minimize_metric=False):
    """
    Performs the training loop for given graph and data.

    Args:

        sess: TF Session.

        saver: TF saver object.

        epochs: Number of epochs for which the algorithm will learn.

        trainer: The trainer object containing the TensorFlow code for computing
            the loss and optimization operation.

        decoder: The decoder object.

        train_dataset: Dataset used for training.

        val_dataset: Dataset used for validation.

        postprocess: Function that takes the output sentence as produced by the
            decoder and transforms it into a tokenized sentence.

        log_directory: Directory where the TensorBoard log will be generated.
            If None, nothing will be done.

        evaluators: List of evaluators. The last evaluator
            is used as the main. Each function accepts list of decoded sequences
            and list of reference sequences and returns a float.

        use_copynet: Flag whether the copying mechanism is used.

        use_beamsearch:

        initial_variables: Either None or a file where the variables are stored.
            Training then resumes from the loaded values.

    """

    if not postprocess:
        postprocess = lambda x: x

    evaluation_labels = [f.name for f in evaluators]
    step = 0
    seen_instances = 0

    saver = tf.train.Saver()

    if initial_variables:
        saver.restore(sess, initial_variables)

    if save_n_best_vars < 1:
        raise Exception('save_n_best_vars must be greater than zero')

    if save_n_best_vars == 1:
        variables_files = [vars_prefix]
    elif save_n_best_vars > 1:
        variables_files = [
            '{}.{}'.format(vars_prefix, i) for i in range(save_n_best_vars)
        ]

    if minimize_metric:
        saved_scores = [np.inf for _ in range(save_n_best_vars)]
        best_score = np.inf
    else:
        saved_scores = [-np.inf for _ in range(save_n_best_vars)]
        best_score = -np.inf

    saver.save(sess, variables_files[0])

    if os.path.islink(link_best_vars):
        # if overwriting output dir
        os.unlink(link_best_vars)
    os.symlink(os.path.basename(variables_files[0]), link_best_vars)

    if log_directory:
        log("Initializing TensorBoard summary writer.")
        tb_writer = tf.train.SummaryWriter(log_directory, sess.graph)
        log("TesorBoard writer initialized.")

    best_score_epoch = 0
    best_score_batch_no = 0

    val_raw_tgt_sentences = val_dataset.get_series(decoder.data_id)
    val_tgt_sentences = postprocess(val_raw_tgt_sentences)

    log("Starting training")
    try:
        for i in range(epochs):
            log_print("")
            log("Epoch {} starts".format(i + 1), color='red')

            train_dataset.shuffle()
            train_batched_datasets = train_dataset.batch_dataset(batch_size)

            for batch_n, batch_dataset in enumerate(train_batched_datasets):

                batch_feed_dict = feed_dicts(batch_dataset,
                                             all_coders,
                                             train=True)
                step += 1
                batch_sentences = batch_dataset.get_series(decoder.data_id)
                seen_instances += len(batch_sentences)
                if step % logging_period == logging_period - 1:
                    summary_str = trainer.run(sess,
                                              batch_feed_dict,
                                              summary=True)
                    _, _, train_evaluation = \
                            run_on_dataset(sess, runner, all_coders, decoder, batch_dataset,
                                           evaluators, postprocess, write_out=False)

                    process_evaluation(evaluators,
                                       tb_writer,
                                       train_evaluation,
                                       seen_instances,
                                       summary_str,
                                       None,
                                       train=True)
                else:
                    trainer.run(sess, batch_feed_dict, summary=False)

                if step % validation_period == validation_period - 1:
                    decoded_val_sentences, decoded_raw_val_sentences, \
                        val_evaluation = run_on_dataset(
                            sess, runner, all_coders, decoder, val_dataset,
                            evaluators, postprocess, write_out=False)

                    this_score = val_evaluation[evaluators[-1].name]

                    def is_better(score1, score2, minimize):
                        if minimize:
                            return score1 < score2
                        else:
                            return score1 > score2

                    def argworst(scores, minimize):
                        if minimize:
                            return np.argmax(scores)
                        else:
                            return np.argmin(scores)

                    if is_better(this_score, best_score, minimize_metric):
                        best_score = this_score
                        best_score_epoch = i + 1
                        best_score_batch_no = batch_n

                    worst_index = argworst(saved_scores, minimize_metric)
                    worst_score = saved_scores[worst_index]

                    if is_better(this_score, worst_score, minimize_metric):
                        # we need to save this score instead the worst score
                        worst_var_file = variables_files[worst_index]
                        saver.save(sess, worst_var_file)
                        saved_scores[worst_index] = this_score
                        log("Variable file saved in {}".format(worst_var_file))

                        # update symlink
                        if best_score == this_score:
                            os.unlink(link_best_vars)
                            os.symlink(os.path.basename(worst_var_file),
                                       link_best_vars)

                        log("Best scores saved so far: {}".format(
                            saved_scores))

                    log("Validation (epoch {}, batch number {}):".format(
                        i + 1, batch_n),
                        color='blue')

                    process_evaluation(evaluators,
                                       tb_writer,
                                       val_evaluation,
                                       seen_instances,
                                       summary_str,
                                       None,
                                       train=False)

                    if this_score == best_score:
                        best_score_str = colored("{:.2f}".format(best_score),
                                                 attrs=['bold'])
                    else:
                        best_score_str = "{:.2f}".format(best_score)

                    log("best {} on validation: {} (in epoch {}, "
                        "after batch number {})".format(
                            evaluation_labels[-1], best_score_str,
                            best_score_epoch, best_score_batch_no),
                        color='blue')

                    log_print("")
                    log_print("Examples:")
                    for sent, sent_raw, ref_sent, ref_sent_raw in zip(
                            decoded_val_sentences[:15],
                            decoded_raw_val_sentences, val_tgt_sentences,
                            val_raw_tgt_sentences):

                        if isinstance(sent, list):
                            log_print("      raw: {}".format(
                                " ".join(sent_raw)))
                            log_print("      out: {}".format(" ".join(sent)))
                        else:
                            # TODO does this code ever execute?
                            log_print(sent_raw)
                            log_print(sent)

                        log_print(
                            colored(" raw ref.: {}".format(
                                " ".join(ref_sent_raw)),
                                    color="magenta"))
                        log_print(
                            colored("     ref.: {}".format(" ".join(ref_sent)),
                                    color="magenta"))

                    log_print("")

    except KeyboardInterrupt:
        log("Training interrupted by user.")

    if os.path.islink(link_best_vars):
        saver.restore(sess, link_best_vars)

    log("Training finished. Maximum {} on validation data: {:.2f}, epoch {}".
        format(evaluation_labels[-1], best_score, best_score_epoch))

    for dataset in test_datasets:
        _, _, evaluation = run_on_dataset(sess,
                                          runner,
                                          all_coders,
                                          decoder,
                                          dataset,
                                          evaluators,
                                          postprocess,
                                          write_out=True)
        if evaluation:
            print_dataset_evaluation(dataset.name, evaluation)

    log("Finished.")
Example #7
def print_line(prefix, color, content):
    colored_prefix = colored(prefix, color=color)
    formatted = _data_item_to_str(content)
    log_print("  {}: {}".format(colored_prefix, formatted))
Example #8
def training_loop(tf_manager: TensorFlowManager,
                  epochs: int,
                  trainer: BaseRunner, # TODO better annotate
                  batch_size: int,
                  train_dataset: Dataset,
                  val_dataset: Dataset,
                  log_directory: str,
                  evaluators: EvalConfiguration,
                  runners: List[BaseRunner],
                  test_datasets: Optional[List[Dataset]]=None,
                  save_n_best_vars: int=1,
                  link_best_vars="/tmp/variables.data.best",
                  vars_prefix="/tmp/variables.data",
                  logging_period: int=20,
                  validation_period: int=500,
                  runners_batch_size: Optional[int]=None,
                  postprocess: Callable=None,
                  minimize_metric: bool=False):

    # TODO finish the list
    """
    Performs the training loop for given graph and data.

    Args:
        tf_manager: TensorFlowManager with initialized sessions.
        epochs: Number of epochs for which the algorithm will learn.
        trainer: The trainer object containing the TensorFlow code for computing
            the loss and optimization operation.
        train_dataset: Dataset used for training.
        val_dataset: Dataset used for validation.
        postprocess: Function that takes the output sentence as produced by the
            decoder and transforms it into a tokenized sentence.
        log_directory: Directory where the TensorBoard log will be generated.
            If None, nothing will be done.
        evaluators: List of evaluators. The last evaluator is used as the main.
            An evaluator is a tuple of the name of the generated series, the
            name of the dataset series the generated one is evaluated with and
            the evaluation function. If only one series name is provided, it
            means the generated and dataset series have the same name.
    """

    evaluators = [(e[0], e[0], e[1]) if len(e) == 2 else e
                  for e in evaluators]

    main_metric = "{}/{}".format(evaluators[-1][0], evaluators[-1][-1].name)
    step = 0
    seen_instances = 0

    if save_n_best_vars < 1:
        raise Exception('save_n_best_vars must be greater than zero')

    if save_n_best_vars == 1:
        variables_files = [vars_prefix]
    elif save_n_best_vars > 1:
        variables_files = ['{}.{}'.format(vars_prefix, i)
                           for i in range(save_n_best_vars)]

    if minimize_metric:
        saved_scores = [np.inf for _ in range(save_n_best_vars)]
        best_score = np.inf
    else:
        saved_scores = [-np.inf for _ in range(save_n_best_vars)]
        best_score = -np.inf

    tf_manager.save(variables_files[0])

    if os.path.islink(link_best_vars):
        # if overwriting output dir
        os.unlink(link_best_vars)
    os.symlink(os.path.basename(variables_files[0]), link_best_vars)

    if log_directory:
        log("Initializing TensorBoard summary writer.")
        tb_writer = tf.train.SummaryWriter(log_directory,
                                           tf_manager.sessions[0].graph)
        log("TesorBoard writer initialized.")

    best_score_epoch = 0
    best_score_batch_no = 0

    log("Starting training")
    try:
        for i in range(epochs):
            log_print("")
            log("Epoch {} starts".format(i + 1), color='red')

            train_dataset.shuffle()
            train_batched_datasets = train_dataset.batch_dataset(batch_size)

            for batch_n, batch_dataset in enumerate(train_batched_datasets):

                step += 1
                seen_instances += len(batch_dataset)
                if step % logging_period == logging_period - 1:
                    trainer_result = tf_manager.execute(
                        batch_dataset, [trainer], train=True,
                        summaries=True)
                    train_results, train_outputs = run_on_dataset(
                        tf_manager, runners, batch_dataset,
                        postprocess, write_out=False)
                    train_evaluation = evaluation(
                        evaluators, batch_dataset, runners,
                        train_results, train_outputs)

                    _log_continuons_evaluation(tb_writer, main_metric,
                                               train_evaluation,
                                               seen_instances,
                                               trainer_result,
                                               train=True)
                else:
                    tf_manager.execute(batch_dataset, [trainer],
                                       train=True, summaries=False)

                if step % validation_period == validation_period - 1:
                    val_results, val_outputs = run_on_dataset(
                        tf_manager, runners, val_dataset,
                        postprocess, write_out=False,
                        batch_size=runners_batch_size)
                    val_evaluation = evaluation(
                        evaluators, val_dataset, runners, val_results,
                        val_outputs)

                    this_score = val_evaluation[main_metric]

                    def is_better(score1, score2, minimize):
                        if minimize:
                            return score1 < score2
                        else:
                            return score1 > score2

                    def argworst(scores, minimize):
                        if minimize:
                            return np.argmax(scores)
                        else:
                            return np.argmin(scores)

                    if is_better(this_score, best_score, minimize_metric):
                        best_score = this_score
                        best_score_epoch = i + 1
                        best_score_batch_no = batch_n

                    worst_index = argworst(saved_scores, minimize_metric)
                    worst_score = saved_scores[worst_index]

                    if is_better(this_score, worst_score, minimize_metric):
                        # we need to save this score instead the worst score
                        worst_var_file = variables_files[worst_index]
                        tf_manager.save(worst_var_file)
                        saved_scores[worst_index] = this_score
                        log("Variable file saved in {}".format(worst_var_file))

                        # update symlink
                        if best_score == this_score:
                            os.unlink(link_best_vars)
                            os.symlink(os.path.basename(worst_var_file),
                                       link_best_vars)

                        log("Best scores saved so far: {}".format(saved_scores))

                    log("Validation (epoch {}, batch number {}):"
                        .format(i + 1, batch_n), color='blue')

                    _log_continuons_evaluation(tb_writer, main_metric,
                                               val_evaluation, seen_instances,
                                               val_results, train=False)

                    if this_score == best_score:
                        best_score_str = colored("{:.2f}".format(best_score),
                                                 attrs=['bold'])
                    else:
                        best_score_str = "{:.2f}".format(best_score)

                    log("best {} on validation: {} (in epoch {}, "
                        "after batch number {})"
                        .format(main_metric, best_score_str,
                                best_score_epoch, best_score_batch_no),
                        color='blue')

                    log_print("")
                    _print_examples(val_dataset, val_outputs)

    except KeyboardInterrupt:
        log("Training interrupted by user.")


    log("Training finished. Maximum {} on validation data: {:.2f}, epoch {}"
        .format(main_metric, best_score, best_score_epoch))

    if test_datasets and os.path.islink(link_best_vars):
        tf_manager.restore(link_best_vars)

    for dataset in test_datasets:
        test_results, test_outputs = run_on_dataset(
            tf_manager, runners, dataset, postprocess,
            write_out=True, batch_size=runners_batch_size)
        eval_result = evaluation(evaluators, dataset, runners,
                                 test_results, test_outputs)
        print_final_evaluation(dataset.name, eval_result)

    log("Finished.")
Example #9
def _print_examples(dataset: Dataset,
                    outputs: Dict[str, List[Any]],
                    val_preview_input_series: Optional[List[str]] = None,
                    val_preview_output_series: Optional[List[str]] = None,
                    num_examples=15) -> None:
    """Print examples of the model output.

    Arguments:
        dataset: The dataset from which to take examples
        outputs: A mapping from the output series ID to the list of its
            contents
        val_preview_input_series: An optional list of input series to include
            in the preview. An input series is a data series that is present in
            the dataset. It can be either a target series (one that is also
            present in the outputs, i.e. reference), or a source series (one
            that is not among the outputs). In the validation preview, source
            input series and preprocessed target series are yellow and target
            (reference) series are red. If None, all series are written.
        val_preview_output_series: An optional list of output series to include
            in the preview. An output series is a data series that is present
            among the outputs. In the preview, magenta is used as the font
            color for output series
    """
    log_print(colored("Examples:", attrs=["bold"]))

    source_series_names = [s for s in dataset.series_ids if s not in outputs]
    target_series_names = [s for s in dataset.series_ids if s in outputs]
    output_series_names = list(outputs.keys())

    assert outputs

    if val_preview_input_series is not None:
        target_series_names = [
            s for s in target_series_names if s in val_preview_input_series
        ]
        source_series_names = [
            s for s in source_series_names if s in val_preview_input_series
        ]

    if val_preview_output_series is not None:
        output_series_names = [
            s for s in output_series_names if s in val_preview_output_series
        ]

    # for further indexing we need to make sure, all relevant
    # dataset series are lists
    target_series = {
        series_id: list(dataset.get_series(series_id))
        for series_id in target_series_names
    }
    source_series = {
        series_id: list(dataset.get_series(series_id))
        for series_id in source_series_names
    }

    if not isinstance(dataset, LazyDataset):
        num_examples = min(len(dataset), num_examples)

    for i in range(num_examples):
        log_print(
            colored("  [{}]".format(i + 1), color="magenta", attrs=["bold"]))

        def print_line(prefix, color, content):
            colored_prefix = colored(prefix, color=color)
            formatted = _data_item_to_str(content)
            log_print("  {}: {}".format(colored_prefix, formatted))

        # Input source series = yellow
        for series_id, data in sorted(source_series.items(),
                                      key=lambda x: x[0]):
            print_line(series_id, "yellow", data[i])

        # Output series = magenta
        for series_id in sorted(output_series_names):
            data = list(outputs[series_id])
            model_output = data[i]
            print_line(series_id, "magenta", model_output)

        # Input target series (a.k.a. references) = red
        for series_id in sorted(target_series_names):
            data = outputs[series_id]
            desired_output = target_series[series_id][i]
            print_line(series_id + " (ref)", "red", desired_output)

        log_print("")
Example #10
def training_loop(
        tf_manager: TensorFlowManager,
        epochs: int,
        trainer: GenericTrainer,  # TODO better annotate
        batch_size: int,
        log_directory: str,
        evaluators: EvalConfiguration,
        runners: List[BaseRunner],
        train_dataset: Dataset,
        val_dataset: Union[Dataset, List[Dataset]],
        test_datasets: Optional[List[Dataset]] = None,
        logging_period: Union[str, int] = 20,
        validation_period: Union[str, int] = 500,
        val_preview_input_series: Optional[List[str]] = None,
        val_preview_output_series: Optional[List[str]] = None,
        val_preview_num_examples: int = 15,
        train_start_offset: int = 0,
        runners_batch_size: Optional[int] = None,
        initial_variables: Optional[Union[str, List[str]]] = None,
        postprocess: Postprocess = None) -> None:
    """Execute the training loop for given graph and data.

    Args:
        tf_manager: TensorFlowManager with initialized sessions.
        epochs: Number of epochs for which the algorithm will learn.
        trainer: The trainer object containing the TensorFlow code for computing
            the loss and optimization operation.
        batch_size: number of examples in one mini-batch
        log_directory: Directory where the TensorBoard log will be generated.
            If None, nothing will be done.
        evaluators: List of evaluators. The last evaluator is used as the main.
            An evaluator is a tuple of the name of the generated
            series, the name of the dataset series the generated one is
            evaluated with and the evaluation function. If only one
            series name is provided, it means the generated and
            dataset series have the same name.
        runners: List of runners for logging and evaluation runs
        train_dataset: Dataset used for training
        val_dataset: used for validation. Can be Dataset or a list of datasets.
            The last dataset is used as the main one for storing best results.
            When using multiple datasets, it is recommended to name them for
            better TensorBoard visualization.
        test_datasets: List of datasets used for testing
        logging_period: after how many batches should the logging happen. It
            can also be defined as a time period in format like: 3s; 4m; 6h;
            1d; 3m15s; 3seconds; 4minutes; 6hours; 1days
        validation_period: after how many batches should the validation happen.
            It can also be defined as a time period in same format as logging
        val_preview_input_series: which input series to preview in validation
        val_preview_output_series: which output series to preview in validation
        val_preview_num_examples: how many examples should be printed during
            validation
        train_start_offset: how many lines from the training dataset should be
            skipped. The training starts from the next batch.
        runners_batch_size: batch size of runners. It is the same as batch_size
            if not specified
        initial_variables: variables used for initialization, for example for
            continuation of training. Provide it with a path to your model
            directory and its checkpoint file group common prefix, e.g.
            "variables.data", or "variables.data.3" in case of multiple
            checkpoints per experiment.
        postprocess: A function which takes the dataset with its output series
            and generates additional series from them.
    """
    check_argument_types()

    if isinstance(val_dataset, Dataset):
        val_datasets = [val_dataset]
    else:
        val_datasets = val_dataset

    log_period_batch, log_period_time = _resolve_period(logging_period)
    val_period_batch, val_period_time = _resolve_period(validation_period)

    _check_series_collisions(runners, postprocess)

    _log_model_variables(var_list=trainer.var_list)

    if runners_batch_size is None:
        runners_batch_size = batch_size

    evaluators = [(e[0], e[0], e[1]) if len(e) == 2 else e for e in evaluators]

    if evaluators:
        main_metric = "{}/{}".format(evaluators[-1][0],
                                     evaluators[-1][-1].name)
    else:
        main_metric = "{}/{}".format(runners[-1].decoder_data_id,
                                     runners[-1].loss_names[0])

        if not tf_manager.minimize_metric:
            raise ValueError("minimize_metric must be set to True in "
                             "TensorFlowManager when using loss as "
                             "the main metric")

    step = 0
    seen_instances = 0
    last_seen_instances = 0

    if initial_variables is None:
        # Assume we don't look at coder checkpoints when global
        # initial variables are supplied
        tf_manager.initialize_model_parts(runners + [trainer],
                                          save=True)  # type: ignore
    else:
        try:
            tf_manager.restore(initial_variables)
        except tf.errors.NotFoundError:
            warn("Some variables were not found in checkpoint.)")

    if log_directory:
        log("Initializing TensorBoard summary writer.")
        tb_writer = tf.summary.FileWriter(log_directory,
                                          tf_manager.sessions[0].graph)
        log("TensorBoard writer initialized.")

    log("Starting training")
    last_log_time = time.process_time()
    last_val_time = time.process_time()
    interrupt = None
    try:
        for epoch_n in range(1, epochs + 1):
            log_print("")
            log("Epoch {} begins".format(epoch_n), color="red")

            train_dataset.shuffle()
            train_batched_datasets = train_dataset.batch_dataset(batch_size)

            if epoch_n == 1 and train_start_offset:
                if not isinstance(train_dataset, LazyDataset):
                    warn("Not skipping training instances with "
                         "shuffled in-memory dataset")
                else:
                    _skip_lines(train_start_offset, train_batched_datasets)

            for batch_n, batch_dataset in enumerate(train_batched_datasets):
                step += 1
                seen_instances += len(batch_dataset)
                if _is_logging_time(step, log_period_batch, last_log_time,
                                    log_period_time):
                    trainer_result = tf_manager.execute(batch_dataset,
                                                        [trainer],
                                                        train=True,
                                                        summaries=True)
                    train_results, train_outputs = run_on_dataset(
                        tf_manager,
                        runners,
                        batch_dataset,
                        postprocess,
                        write_out=False,
                        batch_size=runners_batch_size)
                    # ensure train outputs are iterable more than once
                    train_outputs = {
                        k: list(v)
                        for k, v in train_outputs.items()
                    }
                    train_evaluation = evaluation(evaluators, batch_dataset,
                                                  runners, train_results,
                                                  train_outputs)

                    _log_continuous_evaluation(tb_writer,
                                               main_metric,
                                               train_evaluation,
                                               seen_instances,
                                               epoch_n,
                                               epochs,
                                               trainer_result,
                                               train=True)
                    last_log_time = time.process_time()
                else:
                    tf_manager.execute(batch_dataset, [trainer],
                                       train=True,
                                       summaries=False)

                if _is_logging_time(step, val_period_batch, last_val_time,
                                    val_period_time):
                    log_print("")
                    val_duration_start = time.process_time()
                    val_examples = 0
                    for val_id, valset in enumerate(val_datasets):
                        val_examples += len(valset)

                        val_results, val_outputs = run_on_dataset(
                            tf_manager,
                            runners,
                            valset,
                            postprocess,
                            write_out=False,
                            batch_size=runners_batch_size)
                        # ensure val outputs are iterable more than once
                        val_outputs = {
                            k: list(v)
                            for k, v in val_outputs.items()
                        }
                        val_evaluation = evaluation(evaluators, valset,
                                                    runners, val_results,
                                                    val_outputs)

                        valheader = (
                            "Validation (epoch {}, batch number {}):".format(
                                epoch_n, batch_n))
                        log(valheader, color="blue")
                        _print_examples(valset, val_outputs,
                                        val_preview_input_series,
                                        val_preview_output_series,
                                        val_preview_num_examples)
                        log_print("")
                        log(valheader, color="blue")

                        # The last validation set is selected to be the main
                        if val_id == len(val_datasets) - 1:
                            this_score = val_evaluation[main_metric]
                            tf_manager.validation_hook(this_score, epoch_n,
                                                       batch_n)

                            if this_score == tf_manager.best_score:
                                best_score_str = colored("{:.4g}".format(
                                    tf_manager.best_score),
                                                         attrs=["bold"])

                                # store also graph parts
                                all_coders = set.union(*[
                                    rnr.all_coders
                                    for rnr in runners + [trainer]
                                ])  # type: ignore
                                for coder in all_coders:
                                    for session in tf_manager.sessions:
                                        coder.save(session)
                            else:
                                best_score_str = "{:.4g}".format(
                                    tf_manager.best_score)

                            log("best {} on validation: {} (in epoch {}, "
                                "after batch number {})".format(
                                    main_metric, best_score_str,
                                    tf_manager.best_score_epoch,
                                    tf_manager.best_score_batch),
                                color="blue")

                        v_name = valset.name if len(val_datasets) > 1 else None
                        _log_continuous_evaluation(tb_writer,
                                                   main_metric,
                                                   val_evaluation,
                                                   seen_instances,
                                                   epoch_n,
                                                   epochs,
                                                   val_results,
                                                   train=False,
                                                   dataset_name=v_name)

                    # how long was the training between validations
                    training_duration = val_duration_start - last_val_time
                    val_duration = time.process_time() - val_duration_start

                    # the training should take at least twice the time of val.
                    steptime = (training_duration /
                                (seen_instances - last_seen_instances))
                    valtime = val_duration / val_examples
                    last_seen_instances = seen_instances
                    log("Validation time: {:.2f}s, inter-validation: {:.2f}s, "
                        "per-instance (train): {:.2f}s, per-instance (val): "
                        "{:.2f}s".format(val_duration, training_duration,
                                         steptime, valtime),
                        color="blue")
                    if training_duration < 2 * val_duration:
                        notice("Validation period setting is inefficient.")

                    log_print("")
                    last_val_time = time.process_time()

    except KeyboardInterrupt as ex:
        interrupt = ex

    log("Training finished. Maximum {} on validation data: {:.4g}, epoch {}".
        format(main_metric, tf_manager.best_score,
               tf_manager.best_score_epoch))

    if test_datasets:
        tf_manager.restore_best_vars()

        for dataset in test_datasets:
            test_results, test_outputs = run_on_dataset(
                tf_manager,
                runners,
                dataset,
                postprocess,
                write_out=True,
                batch_size=runners_batch_size)
            # ensure test outputs are iterable more than once
            test_outputs = {k: list(v) for k, v in test_outputs.items()}
            eval_result = evaluation(evaluators, dataset, runners,
                                     test_results, test_outputs)
            print_final_evaluation(dataset.name, eval_result)

    log("Finished.")

    if interrupt is not None:
        raise interrupt  # pylint: disable=raising-bad-type
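
According to the docstring, `logging_period` and `validation_period` accept either batch counts or time strings. A hedged call sketch; `tf_manager`, `my_trainer`, `runners`, the datasets and the `BLEU` evaluator object are all assumed to come from a real experiment configuration:

training_loop(
    tf_manager=tf_manager,
    epochs=10,
    trainer=my_trainer,
    batch_size=64,
    log_directory="out-example/tensorboard",
    evaluators=[("target", BLEU)],        # 2-tuple: generated series + evaluator
    runners=runners,
    train_dataset=train_data,
    val_dataset=[val_small, val_main],    # the last dataset drives checkpointing
    logging_period="30s",                 # time-based period per docstring format
    validation_period="10m",
    val_preview_num_examples=5)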
Example #11
def training_loop(
        tf_manager: TensorFlowManager,
        epochs: int,
        trainer: GenericTrainer,  # TODO better annotate
        batch_size: int,
        train_dataset: Dataset,
        val_dataset: Dataset,
        log_directory: str,
        evaluators: EvalConfiguration,
        runners: List[BaseRunner],
        test_datasets: Optional[List[Dataset]] = None,
        logging_period: int = 20,
        validation_period: int = 500,
        val_preview_input_series: Optional[List[str]] = None,
        val_preview_output_series: Optional[List[str]] = None,
        val_preview_num_examples: int = 15,
        train_start_offset: int = 0,
        runners_batch_size: Optional[int] = None,
        initial_variables: Optional[Union[str, List[str]]] = None,
        postprocess: Postprocess = None) -> None:

    # TODO finish the list
    """
    Performs the training loop for given graph and data.

    Args:
        tf_manager: TensorFlowManager with initialized sessions.
        epochs: Number of epochs for which the algorithm will learn.
        trainer: The trainer object containing the TensorFlow code for computing
            the loss and optimization operation.
        train_dataset: Dataset used for training.
        val_dataset: Dataset used for validation.
        postprocess: Function that takes the output sentence as produced by the
            decoder and transforms it into a tokenized sentence.
        log_directory: Directory where the TensorBoard log will be generated.
            If None, nothing will be done.
        evaluators: List of evaluators. The last evaluator is used as the main.
            An evaluator is a tuple of the name of the generated series, the
            name of the dataset series the generated one is evaluated with and
            the evaluation function. If only one series name is provided, it
            means the generated and dataset series have the same name.
    """
    if validation_period < logging_period:
        raise AssertionError(
            "Validation period can't be smaller than logging period.")
    _check_series_collisions(runners, postprocess)

    _log_model_variables()

    if tf_manager.report_gpu_memory_consumption:
        log("GPU memory usage: {}".format(gpu_memusage()))

    # TODO DOCUMENT_THIS
    if runners_batch_size is None:
        runners_batch_size = batch_size

    evaluators = [(e[0], e[0], e[1]) if len(e) == 2 else e for e in evaluators]

    if evaluators:
        main_metric = "{}/{}".format(evaluators[-1][0],
                                     evaluators[-1][-1].name)
    else:
        main_metric = "{}/{}".format(runners[-1].decoder_data_id,
                                     runners[-1].loss_names[0])

        if not tf_manager.minimize_metric:
            raise ValueError("minimize_metric must be set to True in "
                             "TensorFlowManager when using loss as "
                             "the main metric")

    step = 0
    seen_instances = 0

    if initial_variables is None:
        # Assume we don't look at coder checkpoints when global
        # initial variables are supplied
        tf_manager.initialize_model_parts(runners + [trainer],
                                          save=True)  # type: ignore
    else:
        tf_manager.restore(initial_variables)

    if log_directory:
        log("Initializing TensorBoard summary writer.")
        tb_writer = tf.summary.FileWriter(log_directory,
                                          tf_manager.sessions[0].graph)
        log("TensorBoard writer initialized.")

    log("Starting training")
    try:
        for epoch_n in range(1, epochs + 1):
            log_print("")
            log("Epoch {} starts".format(epoch_n), color='red')

            train_dataset.shuffle()
            train_batched_datasets = train_dataset.batch_dataset(batch_size)

            if epoch_n == 1 and train_start_offset:
                if not isinstance(train_dataset, LazyDataset):
                    warn("Not skipping training instances with "
                         "shuffled in-memory dataset")
                else:
                    _skip_lines(train_start_offset, train_batched_datasets)

            for batch_n, batch_dataset in enumerate(train_batched_datasets):
                step += 1
                seen_instances += len(batch_dataset)
                if step % logging_period == logging_period - 1:
                    trainer_result = tf_manager.execute(batch_dataset,
                                                        [trainer],
                                                        train=True,
                                                        summaries=True)
                    train_results, train_outputs = run_on_dataset(
                        tf_manager,
                        runners,
                        batch_dataset,
                        postprocess,
                        write_out=False)
                    # ensure train outputs are iterable more than once
                    train_outputs = {
                        k: list(v)
                        for k, v in train_outputs.items()
                    }
                    train_evaluation = evaluation(evaluators, batch_dataset,
                                                  runners, train_results,
                                                  train_outputs)

                    _log_continuous_evaluation(tb_writer,
                                               tf_manager,
                                               main_metric,
                                               train_evaluation,
                                               seen_instances,
                                               epoch_n,
                                               epochs,
                                               trainer_result,
                                               train=True)
                else:
                    tf_manager.execute(batch_dataset, [trainer],
                                       train=True,
                                       summaries=False)

                if step % validation_period == validation_period - 1:
                    val_results, val_outputs = run_on_dataset(
                        tf_manager,
                        runners,
                        val_dataset,
                        postprocess,
                        write_out=False,
                        batch_size=runners_batch_size)
                    # ensure val outputs are iterable more than once
                    val_outputs = {k: list(v) for k, v in val_outputs.items()}
                    val_evaluation = evaluation(evaluators, val_dataset,
                                                runners, val_results,
                                                val_outputs)

                    this_score = val_evaluation[main_metric]
                    tf_manager.validation_hook(this_score, epoch_n, batch_n)

                    log("Validation (epoch {}, batch number {}):".format(
                        epoch_n, batch_n),
                        color='blue')

                    _log_continuous_evaluation(tb_writer,
                                               tf_manager,
                                               main_metric,
                                               val_evaluation,
                                               seen_instances,
                                               epoch_n,
                                               epochs,
                                               val_results,
                                               train=False)

                    if this_score == tf_manager.best_score:
                        best_score_str = colored("{:.4g}".format(
                            tf_manager.best_score),
                                                 attrs=['bold'])
                    else:
                        best_score_str = "{:.4g}".format(tf_manager.best_score)

                    log("best {} on validation: {} (in epoch {}, "
                        "after batch number {})".format(
                            main_metric, best_score_str,
                            tf_manager.best_score_epoch,
                            tf_manager.best_score_batch),
                        color='blue')

                    log_print("")
                    _print_examples(val_dataset, val_outputs,
                                    val_preview_input_series,
                                    val_preview_output_series,
                                    val_preview_num_examples)

    except KeyboardInterrupt:
        log("Training interrupted by user.")

    log("Training finished. Maximum {} on validation data: {:.4g}, epoch {}".
        format(main_metric, tf_manager.best_score,
               tf_manager.best_score_epoch))

    if test_datasets:
        tf_manager.restore_best_vars()

    for dataset in test_datasets:
        test_results, test_outputs = run_on_dataset(
            tf_manager,
            runners,
            dataset,
            postprocess,
            write_out=True,
            batch_size=runners_batch_size)
        # ensure test outputs are iterable more than once
        test_outputs = {k: list(v) for k, v in test_outputs.items()}
        eval_result = evaluation(evaluators, dataset, runners, test_results,
                                 test_outputs)
        print_final_evaluation(dataset.name, eval_result)

    log("Finished.")
Example #12
def training_loop(cfg: Namespace) -> None:
    """Execute the training loop for given graph and data.

    Arguments:
        cfg: Experiment configuration namespace.
    """
    _check_series_collisions(cfg.runners, cfg.postprocess)

    log_model_variables(cfg.trainers)

    initialize_model(cfg.tf_manager, cfg.initial_variables,
                     cfg.runners + cfg.trainers)

    log("Initializing TensorBoard summary writer.")
    tb_writer = tf.summary.FileWriter(cfg.output,
                                      cfg.tf_manager.sessions[0].graph)
    log("TensorBoard writer initialized.")

    feedables = set.union(*[ex.feedables for ex in cfg.runners + cfg.trainers])

    log("Starting training")
    profiler = TrainingProfiler()
    profiler.training_start()

    step = 0
    seen_instances = 0
    last_seen_instances = 0
    interrupt = None

    try:
        for epoch_n in range(1, cfg.epochs + 1):
            train_batches = cfg.train_dataset.batches(cfg.batching_scheme)

            if epoch_n == 1 and cfg.train_start_offset:
                if cfg.train_dataset.shuffled and not cfg.train_dataset.lazy:
                    warn("Not skipping training instances with shuffled "
                         "non-lazy dataset")
                else:
                    _skip_lines(cfg.train_start_offset, train_batches)

            log_print("")
            log("Epoch {} begins".format(epoch_n), color="red")
            profiler.epoch_start()

            for batch_n, batch in enumerate(train_batches):
                step += 1
                seen_instances += len(batch)

                if cfg.log_timer(step, profiler.last_log_time):
                    trainer_result = cfg.tf_manager.execute(batch,
                                                            feedables,
                                                            cfg.trainers,
                                                            train=True,
                                                            summaries=True)
                    train_results, train_outputs, f_batch = run_on_dataset(
                        cfg.tf_manager,
                        cfg.runners,
                        cfg.dataset_runner,
                        batch,
                        cfg.postprocess,
                        write_out=False,
                        batching_scheme=cfg.runners_batching_scheme)
                    # ensure train outputs are iterable more than once
                    train_outputs = {
                        k: list(v)
                        for k, v in train_outputs.items()
                    }
                    train_evaluation = evaluation(cfg.evaluation, f_batch,
                                                  cfg.runners, train_results,
                                                  train_outputs)

                    _log_continuous_evaluation(tb_writer,
                                               cfg.main_metric,
                                               train_evaluation,
                                               seen_instances,
                                               epoch_n,
                                               cfg.epochs,
                                               trainer_result,
                                               train=True)

                    profiler.log_done()

                else:
                    cfg.tf_manager.execute(batch,
                                           feedables,
                                           cfg.trainers,
                                           train=True,
                                           summaries=False)

                if cfg.val_timer(step, profiler.last_val_time):

                    log_print("")
                    profiler.validation_start()

                    val_examples = 0
                    for val_id, valset in enumerate(cfg.val_datasets):
                        val_examples += len(valset)

                        val_results, val_outputs, f_valset = run_on_dataset(
                            cfg.tf_manager,
                            cfg.runners,
                            cfg.dataset_runner,
                            valset,
                            cfg.postprocess,
                            write_out=False,
                            batching_scheme=cfg.runners_batching_scheme)
                        # ensure val outputs are iterable more than once
                        val_outputs = {
                            k: list(v)
                            for k, v in val_outputs.items()
                        }
                        val_evaluation = evaluation(cfg.evaluation, f_valset,
                                                    cfg.runners, val_results,
                                                    val_outputs)

                        valheader = (
                            "Validation (epoch {}, batch number {}):".format(
                                epoch_n, batch_n))
                        log(valheader, color="blue")
                        _print_examples(f_valset, val_outputs,
                                        cfg.val_preview_input_series,
                                        cfg.val_preview_output_series,
                                        cfg.val_preview_num_examples)
                        log_print("")
                        log(valheader, color="blue")

                        # The last validation set is selected to be the main
                        if val_id == len(cfg.val_datasets) - 1:
                            this_score = val_evaluation[cfg.main_metric]
                            cfg.tf_manager.validation_hook(
                                this_score, epoch_n, batch_n)

                            if this_score == cfg.tf_manager.best_score:
                                best_score_str = colored("{:.4g}".format(
                                    cfg.tf_manager.best_score),
                                                         attrs=["bold"])

                                # store also graph parts
                                rnrs = cfg.runners + cfg.trainers
                                # TODO: refactor trainers/runners so that they
                                # have the same API predecessor
                                parameterizeds = set.union(
                                    *[rnr.parameterizeds for rnr in rnrs])
                                for coder in parameterizeds:
                                    for session in cfg.tf_manager.sessions:
                                        coder.save(session)
                            else:
                                best_score_str = "{:.4g}".format(
                                    cfg.tf_manager.best_score)

                            log("best {} on validation: {} (in epoch {}, "
                                "after batch number {})".format(
                                    cfg.main_metric, best_score_str,
                                    cfg.tf_manager.best_score_epoch,
                                    cfg.tf_manager.best_score_batch),
                                color="blue")

                        v_name = "val_{}".format(val_id) if len(
                            cfg.val_datasets) > 1 else None
                        _log_continuous_evaluation(tb_writer,
                                                   cfg.main_metric,
                                                   val_evaluation,
                                                   seen_instances,
                                                   epoch_n,
                                                   cfg.epochs,
                                                   val_results,
                                                   train=False,
                                                   dataset_name=v_name)

                    profiler.validation_done()
                    profiler.log_after_validation(
                        val_examples, seen_instances - last_seen_instances)
                    last_seen_instances = seen_instances

                    log_print("")

    except KeyboardInterrupt as ex:
        interrupt = ex

    log("Training finished. Maximum {} on validation data: {:.4g}, epoch {}".
        format(cfg.main_metric, cfg.tf_manager.best_score,
               cfg.tf_manager.best_score_epoch))

    if cfg.test_datasets:
        cfg.tf_manager.restore_best_vars()

        for test_id, dataset in enumerate(cfg.test_datasets):
            test_results, test_outputs, f_testset = run_on_dataset(
                cfg.tf_manager,
                cfg.runners,
                cfg.dataset_runner,
                dataset,
                cfg.postprocess,
                write_out=True,
                batching_scheme=cfg.runners_batching_scheme)
            # ensure test outputs are iterable more than once
            test_outputs = {k: list(v) for k, v in test_outputs.items()}
            eval_result = evaluation(cfg.evaluation, f_testset, cfg.runners,
                                     test_results, test_outputs)
            print_final_evaluation(eval_result, "test_{}".format(test_id))

    log("Finished.")

    if interrupt is not None:
        raise interrupt  # pylint: disable=raising-bad-type
Example #14
def training_loop(sess, saver,
                  epochs, trainer, all_coders, decoder, batch_size,
                  train_dataset, val_dataset,
                  log_directory,
                  evaluators,
                  runner,
                  test_datasets=[],
                  save_n_best_vars=1,
                  link_best_vars="/tmp/variables.data.best",
                  vars_prefix="/tmp/variables.data",
                  initial_variables=None,
                  logging_period=20,
                  validation_period=500,
                  postprocess=None,
                  minimize_metric=False):

    """
    Performs the training loop for given graph and data.

    Args:

        sess: TF Session.

        saver: TF saver object.

        epochs: Number of epochs for which the algoritm will learn.

        trainer: The trainer object containg the TensorFlow code for computing
            the loss and optimization operation.

        decoder: The decoder object.

        train_dataset:

        val_dataset:

        postprocess: Function that takes the output sentence as produced by the
            decoder and transforms into tokenized sentence.

        log_directory: Directory where the TensordBoard log will be generated.
            If None, nothing will be done.

        evaluators: List of evaluators. The last evaluator
            is used as the main. Each function accepts list of decoded sequences
            and list of reference sequences and returns a float.

        use_copynet: Flag whether the copying mechanism is used.

        use_beamsearch:

        initial_variables: Either None or file where the variables are stored.
            Training then starts from the point the loaded values.

    """

    if not postprocess:
        postprocess = lambda x: x

    evaluation_labels = [f.name for f in evaluators]
    step = 0
    seen_instances = 0
    # summary string from the last summarized training step; initialized here
    # so the validation branch below never sees it unbound
    summary_str = None

    saver = tf.train.Saver()

    if initial_variables:
        saver.restore(sess, initial_variables)

    if save_n_best_vars < 1:
        raise Exception('save_n_best_vars must be greater than zero')

    if save_n_best_vars == 1:
        variables_files = [vars_prefix]
    elif save_n_best_vars > 1:
        variables_files = ['{}.{}'.format(vars_prefix, i)
                           for i in range(save_n_best_vars)]

    if minimize_metric:
        saved_scores = [np.inf for _ in range(save_n_best_vars)]
        best_score = np.inf
    else:
        saved_scores = [-np.inf for _ in range(save_n_best_vars)]
        best_score = -np.inf
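    # Note added for clarity (not in the original snippet): `variables_files`
    # and `saved_scores` form a pool of the N best checkpoints seen so far.
    # During validation the worst entry of the pool is overwritten whenever a
    # better score arrives, and `link_best_vars` is re-pointed at the file
    # holding the best score.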

    saver.save(sess, variables_files[0])

    if os.path.islink(link_best_vars):
        # if overwriting output dir
        os.unlink(link_best_vars)
    os.symlink(os.path.basename(variables_files[0]), link_best_vars)

    if log_directory:
        log("Initializing TensorBoard summary writer.")
        tb_writer = tf.train.SummaryWriter(log_directory, sess.graph)
        log("TesorBoard writer initialized.")

    best_score_epoch = 0
    best_score_batch_no = 0

    val_raw_tgt_sentences = val_dataset.get_series(decoder.data_id)
    val_tgt_sentences = postprocess(val_raw_tgt_sentences)

    log("Starting training")
    try:
        for i in range(epochs):
            log_print("")
            log("Epoch {} starts".format(i + 1), color='red')

            train_dataset.shuffle()
            train_batched_datasets = train_dataset.batch_dataset(batch_size)

            for batch_n, batch_dataset in enumerate(train_batched_datasets):

                batch_feed_dict = feed_dicts(batch_dataset, all_coders, train=True)
                step += 1
                batch_sentences = batch_dataset.get_series(decoder.data_id)
                seen_instances += len(batch_sentences)
                if step % logging_period == logging_period - 1:
                    summary_str = trainer.run(sess, batch_feed_dict, summary=True)
                    _, _, train_evaluation = \
                            run_on_dataset(sess, runner, all_coders, decoder, batch_dataset,
                                           evaluators, postprocess, write_out=False)

                    process_evaluation(evaluators, tb_writer, train_evaluation,
                                       seen_instances, summary_str, None, train=True)
                else:
                    trainer.run(sess, batch_feed_dict, summary=False)

                if step % validation_period == validation_period - 1:
                    decoded_val_sentences, decoded_raw_val_sentences, \
                        val_evaluation = run_on_dataset(
                            sess, runner, all_coders, decoder, val_dataset,
                            evaluators, postprocess, write_out=False)

                    this_score = val_evaluation[evaluators[-1].name]

                    def is_better(score1, score2, minimize):
                        if minimize:
                            return score1 < score2
                        else:
                            return score1 > score2

                    def argworst(scores, minimize):
                        if minimize:
                            return np.argmax(scores)
                        else:
                            return np.argmin(scores)

                    if is_better(this_score, best_score, minimize_metric):
                        best_score = this_score
                        best_score_epoch = i + 1
                        best_score_batch_no = batch_n

                    worst_index = argworst(saved_scores, minimize_metric)
                    worst_score = saved_scores[worst_index]

                    if is_better(this_score, worst_score, minimize_metric):
                        # we need to save this score in place of the current
                        # worst score
                        worst_var_file = variables_files[worst_index]
                        saver.save(sess, worst_var_file)
                        saved_scores[worst_index] = this_score
                        log("Variable file saved in {}".format(worst_var_file))

                        # update symlink
                        if best_score == this_score:
                            os.unlink(link_best_vars)
                            os.symlink(os.path.basename(worst_var_file), link_best_vars)

                        log("Best scores saved so far: {}".format(saved_scores))

                    log("Validation (epoch {}, batch number {}):"
                        .format(i + 1, batch_n), color='blue')

                    process_evaluation(evaluators, tb_writer,
                                       val_evaluation, seen_instances,
                                       summary_str, None, train=False)

                    if this_score == best_score:
                        best_score_str = colored("{:.2f}".format(best_score),
                                                 attrs=['bold'])
                    else:
                        best_score_str = "{:.2f}".format(best_score)

                    log("best {} on validation: {} (in epoch {}, "
                        "after batch number {})"
                        .format(evaluation_labels[-1], best_score_str,
                                best_score_epoch, best_score_batch_no),
                        color='blue')


                    log_print("")
                    log_print("Examples:")
                    for sent, sent_raw, ref_sent, ref_sent_raw in zip(
                            decoded_val_sentences[:15],
                            decoded_raw_val_sentences,
                            val_tgt_sentences,
                            val_raw_tgt_sentences):

                        if isinstance(sent, list):
                            log_print("      raw: {}"
                                      .format(" ".join(sent_raw)))
                            log_print("      out: {}".format(" ".join(sent)))
                        else:
                            # TODO does this code ever execute?
                            log_print(sent_raw)
                            log_print(sent)

                        log_print(colored(
                            " raw ref.: {}".format(" ".join(ref_sent_raw)),
                            color="magenta"))
                        log_print(colored(
                            "     ref.: {}".format(" ".join(ref_sent)),
                            color="magenta"))

                    log_print("")

    except KeyboardInterrupt:
        log("Training interrupted by user.")

    if os.path.islink(link_best_vars):
        saver.restore(sess, link_best_vars)

    log("Training finished. Maximum {} on validation data: {:.2f}, epoch {}"
        .format(evaluation_labels[-1], best_score, best_score_epoch))

    for dataset in test_datasets:
        _, _, evaluation = run_on_dataset(sess, runner, all_coders, decoder,
                                          dataset, evaluators,
                                          postprocess, write_out=True)
        if evaluation:
            print_dataset_evaluation(dataset.name, evaluation)

    log("Finished.")
Example #15
def training_loop(
        tf_manager: TensorFlowManager,
        epochs: int,
        trainer: GenericTrainer,  # TODO better annotate
        batch_size: int,
        train_dataset: Dataset,
        val_dataset: Dataset,
        log_directory: str,
        evaluators: EvalConfiguration,
        runners: List[BaseRunner],
        test_datasets: Optional[List[Dataset]] = None,
        link_best_vars="/tmp/variables.data.best",
        vars_prefix="/tmp/variables.data",
        logging_period: int = 20,
        validation_period: int = 500,
        val_preview_input_series: Optional[List[str]] = None,
        val_preview_output_series: Optional[List[str]] = None,
        val_preview_num_examples: int = 15,
        train_start_offset: int = 0,
        runners_batch_size: Optional[int] = None,
        initial_variables: Optional[Union[str, List[str]]] = None,
        postprocess: Postprocess = None,
        minimize_metric: bool = False):

    # TODO finish the list
    """
    Performs the training loop for given graph and data.

    Args:
        tf_manager: TensorFlowManager with initialized sessions.
        epochs: Number of epochs for which the algoritm will learn.
        trainer: The trainer object containg the TensorFlow code for computing
            the loss and optimization operation.
        train_dataset:
        val_dataset:
        postprocess: Function that takes the output sentence as produced by the
            decoder and transforms into tokenized sentence.
        log_directory: Directory where the TensordBoard log will be generated.
            If None, nothing will be done.
        evaluators: List of evaluators. The last evaluator is used as the main.
            An evaluator is a tuple of the name of the generated series, the
            name of the dataset series the generated one is evaluated with and
            the evaluation function. If only one series names is provided, it
            means the generated and dataset series have the same name.
    """
    if validation_period < logging_period:
        raise AssertionError(
            "Logging period can't smaller than validation period.")
    _check_series_collisions(runners, postprocess)

    paramstr = "Model has {} trainable parameters".format(trainer.n_parameters)
    if tf_manager.report_gpu_memory_consumption:
        paramstr += ", GPU memory usage: {}".format(gpu_memusage())

    log(paramstr)

    # TODO DOCUMENT_THIS
    if runners_batch_size is None:
        runners_batch_size = batch_size

    evaluators = [(e[0], e[0], e[1]) if len(e) == 2 else e for e in evaluators]

    main_metric = "{}/{}".format(evaluators[-1][0], evaluators[-1][-1].name)
    step = 0
    seen_instances = 0

    save_n_best_vars = tf_manager.saver_max_to_keep
    if save_n_best_vars < 1:
        raise Exception('save_n_best_vars must be greater than zero')

    if save_n_best_vars == 1:
        variables_files = [vars_prefix]
    elif save_n_best_vars > 1:
        variables_files = [
            '{}.{}'.format(vars_prefix, i) for i in range(save_n_best_vars)
        ]

    if minimize_metric:
        saved_scores = [np.inf for _ in range(save_n_best_vars)]
        best_score = np.inf
    else:
        saved_scores = [-np.inf for _ in range(save_n_best_vars)]
        best_score = -np.inf

    if initial_variables is None:
        # Assume we don't look at coder checkpoints when global
        # initial variables are supplied
        tf_manager.initialize_model_parts(runners + [trainer])  # type: ignore
        tf_manager.save(variables_files[0])
    else:
        tf_manager.restore(initial_variables)

    if os.path.islink(link_best_vars):
        # if overwriting output dir
        os.unlink(link_best_vars)
    os.symlink(os.path.basename(variables_files[0]), link_best_vars)

    if log_directory:
        log("Initializing TensorBoard summary writer.")
        tb_writer = tf.train.SummaryWriter(log_directory,
                                           tf_manager.sessions[0].graph)
        log("TensorBoard writer initialized.")

    best_score_epoch = 0
    best_score_batch_no = 0

    log("Starting training")
    try:
        for epoch_n in range(1, epochs + 1):
            log_print("")
            log("Epoch {} starts".format(epoch_n), color='red')

            train_dataset.shuffle()
            train_batched_datasets = train_dataset.batch_dataset(batch_size)

            if epoch_n == 1 and train_start_offset:
                if not isinstance(train_dataset, LazyDataset):
                    log(
                        "Warning: Not skipping training instances with "
                        "shuffled in-memory dataset",
                        color="red")
                else:
                    _skip_lines(train_start_offset, train_batched_datasets)

            for batch_n, batch_dataset in enumerate(train_batched_datasets):
                step += 1
                seen_instances += len(batch_dataset)
                if step % logging_period == logging_period - 1:
                    trainer_result = tf_manager.execute(batch_dataset,
                                                        [trainer],
                                                        train=True,
                                                        summaries=True)
                    train_results, train_outputs = run_on_dataset(
                        tf_manager,
                        runners,
                        batch_dataset,
                        postprocess,
                        write_out=False)
                    # ensure train outputs are iterable more than once
                    train_outputs = {
                        k: list(v)
                        for k, v in train_outputs.items()
                    }
                    train_evaluation = evaluation(evaluators, batch_dataset,
                                                  runners, train_results,
                                                  train_outputs)

                    _log_continuous_evaluation(tb_writer,
                                               tf_manager,
                                               main_metric,
                                               train_evaluation,
                                               seen_instances,
                                               epoch_n,
                                               epochs,
                                               trainer_result,
                                               train=True)
                else:
                    tf_manager.execute(batch_dataset, [trainer],
                                       train=True,
                                       summaries=False)

                if step % validation_period == validation_period - 1:
                    val_results, val_outputs = run_on_dataset(
                        tf_manager,
                        runners,
                        val_dataset,
                        postprocess,
                        write_out=False,
                        batch_size=runners_batch_size)
                    # ensure val outputs are iterable more than once
                    val_outputs = {k: list(v) for k, v in val_outputs.items()}
                    val_evaluation = evaluation(evaluators, val_dataset,
                                                runners, val_results,
                                                val_outputs)

                    this_score = val_evaluation[main_metric]

                    def is_better(score1, score2, minimize):
                        if minimize:
                            return score1 < score2
                        else:
                            return score1 > score2

                    def argworst(scores, minimize):
                        if minimize:
                            return np.argmax(scores)
                        else:
                            return np.argmin(scores)

                    if is_better(this_score, best_score, minimize_metric):
                        best_score = this_score
                        best_score_epoch = epoch_n
                        best_score_batch_no = batch_n

                    worst_index = argworst(saved_scores, minimize_metric)
                    worst_score = saved_scores[worst_index]

                    if is_better(this_score, worst_score, minimize_metric):
                        # we need to save this score in place of the current
                        # worst score
                        worst_var_file = variables_files[worst_index]
                        tf_manager.save(worst_var_file)
                        saved_scores[worst_index] = this_score
                        log("Variable file saved in {}".format(worst_var_file))

                        # update symlink
                        if best_score == this_score:
                            os.unlink(link_best_vars)
                            os.symlink(os.path.basename(worst_var_file),
                                       link_best_vars)

                        log("Best scores saved so far: {}".format(
                            saved_scores))

                    log("Validation (epoch {}, batch number {}):".format(
                        epoch_n, batch_n),
                        color='blue')

                    _log_continuous_evaluation(tb_writer,
                                               tf_manager,
                                               main_metric,
                                               val_evaluation,
                                               seen_instances,
                                               epoch_n,
                                               epochs,
                                               val_results,
                                               train=False)

                    if this_score == best_score:
                        best_score_str = colored("{:.4g}".format(best_score),
                                                 attrs=['bold'])
                    else:
                        best_score_str = "{:.4g}".format(best_score)

                    log("best {} on validation: {} (in epoch {}, "
                        "after batch number {})".format(
                            main_metric, best_score_str, best_score_epoch,
                            best_score_batch_no),
                        color='blue')

                    log_print("")
                    _print_examples(val_dataset, val_outputs,
                                    val_preview_input_series,
                                    val_preview_output_series,
                                    val_preview_num_examples)

    except KeyboardInterrupt:
        log("Training interrupted by user.")

    log("Training finished. Maximum {} on validation data: {:.4g}, epoch {}".
        format(main_metric, best_score, best_score_epoch))

    if test_datasets and os.path.islink(link_best_vars):
        tf_manager.restore(link_best_vars)

    # test_datasets may be None (its default), so guard the iteration
    for dataset in test_datasets or []:
        test_results, test_outputs = run_on_dataset(
            tf_manager,
            runners,
            dataset,
            postprocess,
            write_out=True,
            batch_size=runners_batch_size)
        # ensure test outputs are iterable more than once
        test_outputs = {k: list(v) for k, v in test_outputs.items()}
        eval_result = evaluation(evaluators, dataset, runners, test_results,
                                 test_outputs)
        print_final_evaluation(dataset.name, eval_result)

    log("Finished.")