예제 #1
0
def main():
    """Run a loaded model on the test datasets given on the command line.

    Expects exactly two command-line arguments: the run INI file, which is
    handed to ``initialize_for_running``, and a test-datasets file, which is
    loaded through a ``Configuration`` declaring the single argument
    ``test_datasets``.

    Every dataset is first passed to ``check_dataset_and_coders`` together
    with the encoders; if that raises, the message is logged in red and the
    process exits with status 1. Afterwards ``run_on_dataset`` is called for
    each dataset (with ``write_out=True``) and a truthy evaluation result is
    printed with ``print_dataset_evaluation``.
    """
    # pylint: disable=no-member,broad-except
    if len(sys.argv) != 3:
        print("Usage: run.py <run_ini_file> <test_datasets>")
        # sys.exit rather than the builtin exit(): the latter is injected by
        # the ``site`` module for interactive use and may be unavailable.
        sys.exit(1)

    test_datasets = Configuration()
    test_datasets.add_argument('test_datasets')

    args, sess = initialize_for_running(sys.argv[1])

    datasets_args = test_datasets.load_file(sys.argv[2])
    print("")

    # Validate all datasets up front so that nothing is run when one of
    # them is rejected.
    try:
        for dataset in datasets_args.test_datasets:
            check_dataset_and_coders(dataset, args.encoders)
    except Exception as exc:
        log(str(exc), color='red')
        sys.exit(1)

    for dataset in datasets_args.test_datasets:
        _, _, evaluation = run_on_dataset(sess,
                                          args.runner,
                                          args.encoders + [args.decoder],
                                          args.decoder,
                                          dataset,
                                          args.evaluation,
                                          args.postprocess,
                                          write_out=True)
        if evaluation:
            print_dataset_evaluation(dataset.name, evaluation)
예제 #2
0
def main():
    """Run a loaded model on the test datasets given on the command line.

    Takes two command-line arguments: the run INI file (passed to
    ``initialize_for_running``) and a test-datasets file (loaded through a
    ``Configuration`` that declares only ``test_datasets``).

    All datasets are checked with ``check_dataset_and_coders`` before any
    of them is run; a failing check is logged in red and terminates the
    process with status 1. Each dataset is then run through
    ``run_on_dataset`` with ``write_out=True`` and its evaluation is
    printed when it is truthy.
    """
    # pylint: disable=no-member,broad-except
    if len(sys.argv) != 3:
        print("Usage: run.py <run_ini_file> <test_datasets>")
        # The builtin exit() comes from the ``site`` module and is meant
        # for the interactive shell; sys.exit is always available.
        sys.exit(1)

    test_datasets = Configuration()
    test_datasets.add_argument('test_datasets')

    args, sess = initialize_for_running(sys.argv[1])

    datasets_args = test_datasets.load_file(sys.argv[2])
    print("")

    try:
        for dataset in datasets_args.test_datasets:
            check_dataset_and_coders(dataset, args.encoders)
    except Exception as exc:
        # Exceptions have no ``message`` attribute on Python 3, so reading
        # it here would raise AttributeError and hide the real error.
        log(str(exc), color='red')
        sys.exit(1)

    for dataset in datasets_args.test_datasets:
        _, _, evaluation = run_on_dataset(
            sess, args.runner, args.encoders + [args.decoder], args.decoder,
            dataset, args.evaluation, args.postprocess, write_out=True)
        if evaluation:
            print_dataset_evaluation(dataset.name, evaluation)
예제 #3
0
def post_request():
    """Run the model on the JSON body of the current request.

    The JSON payload is wrapped in a ``Dataset`` named "request", checked
    against the encoders and passed to ``run_on_dataset``. The reply is a
    JSON object holding either the result under the decoder's ``data_id``
    (status 200) or an ``error`` message (status 400), plus the handling
    time in seconds under ``duration``.
    """
    started_at = datetime.datetime.now()
    payload = request.get_json()

    if payload is not None:
        run_args = APP.config['args']
        session = APP.config['sess']

        try:
            request_dataset = Dataset("request", payload, {})
            check_dataset_and_coders(request_dataset, run_args.encoders)

            result, _, _ = run_on_dataset(
                session,
                run_args.runner,
                run_args.encoders + [run_args.decoder],
                run_args.decoder,
                request_dataset,
                run_args.evaluation,
                run_args.postprocess,
                write_out=True)
            response_data = {run_args.decoder.data_id: result}
            code = 200
        #pylint: disable=broad-except
        except Exception as exc:
            # Any failure while processing is reported back to the client.
            response_data = {'error': str(exc)}
            code = 400
    else:
        response_data = {"error": "No data were provided."}
        code = 400

    elapsed = datetime.datetime.now() - started_at
    response_data['duration'] = elapsed.total_seconds()

    body = json.dumps(response_data)
    response = flask.Response(
        body, content_type='application/json; charset=utf-8')
    # The length header is computed from the UTF-8 encoded body.
    response.headers.add('content-length', len(body.encode('utf-8')))
    response.status_code = code
    return response
예제 #4
0
def post_request():
    """Answer a request by running the model on its JSON payload.

    Returns a ``flask.Response`` whose JSON body contains the model output
    keyed by the decoder's ``data_id`` on success (status 200), or an
    ``error`` string when no payload was sent or processing raised
    (status 400). The body always carries ``duration``, the elapsed
    handling time in seconds.
    """
    t_begin = datetime.datetime.now()
    incoming = request.get_json()

    if incoming is None:
        response_data, code = {"error": "No data were provided."}, 400
    else:
        config_args = APP.config['args']
        tf_session = APP.config['sess']

        try:
            dataset = Dataset("request", incoming, {})
            check_dataset_and_coders(dataset, config_args.encoders)

            result, _, _ = run_on_dataset(
                tf_session, config_args.runner,
                config_args.encoders + [config_args.decoder],
                config_args.decoder, dataset, config_args.evaluation,
                config_args.postprocess, write_out=True)
            response_data, code = {config_args.decoder.data_id: result}, 200
        #pylint: disable=broad-except
        except Exception as exc:
            # Report every processing error to the caller as a 400.
            response_data, code = {'error': str(exc)}, 400

    t_end = datetime.datetime.now()
    response_data['duration'] = (t_end - t_begin).total_seconds()

    serialized = json.dumps(response_data)
    content_length = len(serialized.encode('utf-8'))

    response = flask.Response(
        serialized, content_type='application/json; charset=utf-8')
    response.headers.add('content-length', content_length)
    response.status_code = code
    return response
예제 #5
0
    def build_model(self) -> None:
        """Build the configuration and the computational graph.

        This function is invoked by all of the main entrypoints of the
        `Experiment` class (`train`, `evaluate`, `run`). It manages the
        building of the TensorFlow graph.

        The building procedure is executed as follows:
        1. Random seeds are set.
        2. Configuration is built (instantiated) and normalized.
        3. TODO(tf-data) tf.data.Dataset instance is created and registered
            in the model parts. (This is not implemented yet!)
        4. Graph executors are "blessed". This causes the rest of the TF Graph
            to be built.
        5. Sessions are initialized using the TF Manager object.

        In training mode, the training and validation datasets are then
        checked against the runners and, when `visualize_embeddings` is
        set in the model, the embeddings visualization is written to the
        model output. Finally, unused initializers are checked.

        Raises:
            `RuntimeError` when the model is already built.
        """
        if self._model_built:
            raise RuntimeError("build_model() called twice")

        random.seed(self.config.args.random_seed)
        np.random.seed(self.config.args.random_seed)

        with self.graph.as_default():
            tf.set_random_seed(self.config.args.random_seed)

            # Enable the created model parts to find this experiment.
            type(self)._current_experiment = self  # type: ignore
            try:
                self.config.build_model(warn_unused=self.train_mode)
                normalize_configuration(self.config.model, self.train_mode)

                self._model = self.config.model
                self._model_built = True

                self._bless_graph_executors()
                self.model.tf_manager.initialize_sessions()
            finally:
                # Clear the class-level pointer even when building fails so
                # that a failed experiment does not stay registered as the
                # current one.
                type(self)._current_experiment = None

            if self.train_mode:
                check_dataset_and_coders(self.model.train_dataset,
                                         self.model.runners)
                # The validation data is either one dataset or an iterable
                # of datasets.
                if isinstance(self.model.val_dataset, Dataset):
                    check_dataset_and_coders(self.model.val_dataset,
                                             self.model.runners)
                else:
                    for val_dataset in self.model.val_dataset:
                        check_dataset_and_coders(val_dataset,
                                                 self.model.runners)

            if self.train_mode and self.model.visualize_embeddings:
                visualize_embeddings(self.model.visualize_embeddings,
                                     self.model.output)

        self._check_unused_initializers()
예제 #6
0
    def build_model(self) -> None:
        """Build the configured model inside this experiment's graph.

        Seeds ``random`` and NumPy (and TensorFlow, inside the graph) with
        the configured random seed, builds the configuration and stores the
        resulting model. Missing ``runners_batch_size`` defaults to the
        training ``batch_size`` and a missing ``tf_manager`` is replaced by
        ``get_default_tf_manager()``.

        In training mode, the training and validation datasets are checked
        against the runners and, when ``visualize_embeddings`` is set in
        the model, the embeddings visualization is written to the model
        output. Finally, unused initializers are checked.

        Raises:
            `RuntimeError` when the model is already built.
        """
        if self._model_built:
            raise RuntimeError("build_model() called twice")

        random.seed(self.config.args.random_seed)
        np.random.seed(self.config.args.random_seed)

        with self.graph.as_default():
            tf.set_random_seed(self.config.args.random_seed)

            # Enable the created model parts to find this experiment.
            type(self)._current_experiment = self  # type: ignore
            try:
                self.config.build_model(warn_unused=self.train_mode)
            finally:
                # Clear the class-level pointer even when building fails so
                # that a failed experiment does not stay registered as the
                # current one.
                type(self)._current_experiment = None

            self._model = self.config.model
            self._model_built = True

            if self.model.runners_batch_size is None:
                self.model.runners_batch_size = self.model.batch_size

            if self.model.tf_manager is None:
                self.model.tf_manager = get_default_tf_manager()

            if self.train_mode:
                check_dataset_and_coders(self.model.train_dataset,
                                         self.model.runners)
                # The validation data is either one dataset or an iterable
                # of datasets.
                if isinstance(self.model.val_dataset, Dataset):
                    check_dataset_and_coders(self.model.val_dataset,
                                             self.model.runners)
                else:
                    for val_dataset in self.model.val_dataset:
                        check_dataset_and_coders(val_dataset,
                                                 self.model.runners)

            if self.train_mode and self.model.visualize_embeddings:
                visualize_embeddings(self.model.visualize_embeddings,
                                     self.model.output)

        self._check_unused_initializers()
예제 #7
0
def main() -> None:
    """Entry point of train.py: prepare the experiment directory and train.

    Takes the experiment INI file as the only command-line argument. The
    function

    1. loads the configuration and seeds ``random``, NumPy and TensorFlow
       (falling back to the seed 2574600 when none is configured),
    2. refuses to reuse an output directory that already contains
       ``experiment.ini`` unless ``overwrite_output_dir`` is set, and
       creates the directory when it is missing,
    3. picks ``.cont-N`` suffixed file names when files of a previous run
       are present, copies the INI file and records the git commit and
       diff of the repository containing this file,
    4. builds the model, checks the training and validation datasets
       against the runners and calls ``training_loop``.

    Every failure handled here is logged in red and ends the process with
    exit status 1.
    """
    # Local import: only needed to quote paths in the git shell commands.
    import shlex

    if len(sys.argv) != 2:
        print("Usage: train.py <ini_file>")
        # sys.exit rather than the builtin exit(): the latter is injected by
        # the ``site`` module for interactive use and may be unavailable.
        sys.exit(1)

    # define valid parameters and defaults
    cfg = create_config()
    # load the params from the config file, getting also the simple arguments
    cfg.load_file(sys.argv[1])
    # various things like randseed or summarywriter should be set up here
    # so that graph building can be recorded
    # build all the objects specified in the config

    if cfg.args.random_seed is None:
        cfg.args.random_seed = 2574600
    random.seed(cfg.args.random_seed)
    np.random.seed(cfg.args.random_seed)
    tf.set_random_seed(cfg.args.random_seed)

    # pylint: disable=no-member
    if (os.path.isdir(cfg.args.output) and os.path.exists(
            os.path.join(cfg.args.output, "experiment.ini"))):
        if cfg.args.overwrite_output_dir:
            # we do not want to delete the directory contents
            log("Directory with experiment.ini '{}' exists, "
                "overwriting enabled, proceeding.".format(cfg.args.output))
        else:
            log("Directory with experiment.ini '{}' exists, "
                "overwriting disabled.".format(cfg.args.output),
                color='red')
            sys.exit(1)

    # pylint: disable=broad-except
    if not os.path.isdir(cfg.args.output):
        try:
            os.mkdir(cfg.args.output)
        except Exception as exc:
            log("Failed to create experiment directory: {}. Exception: {}".
                format(cfg.args.output, exc),
                color='red')
            sys.exit(1)

    log_file = "{}/experiment.log".format(cfg.args.output)
    ini_file = "{}/experiment.ini".format(cfg.args.output)
    git_commit_file = "{}/git_commit".format(cfg.args.output)
    git_diff_file = "{}/git_diff".format(cfg.args.output)
    variables_file_prefix = "{}/variables.data".format(cfg.args.output)

    cont_index = 0

    # Find the first continuation index for which none of the output files
    # exists yet, so that files of previous runs are never overwritten.
    while (os.path.exists(log_file) or os.path.exists(ini_file)
           or os.path.exists(git_commit_file) or os.path.exists(git_diff_file)
           or os.path.exists(variables_file_prefix)
           or os.path.exists("{}.0".format(variables_file_prefix))):
        cont_index += 1

        log_file = "{}/experiment.log.cont-{}".format(cfg.args.output,
                                                      cont_index)
        ini_file = "{}/experiment.ini.cont-{}".format(cfg.args.output,
                                                      cont_index)
        git_commit_file = "{}/git_commit.cont-{}".format(
            cfg.args.output, cont_index)
        git_diff_file = "{}/git_diff.cont-{}".format(cfg.args.output,
                                                     cont_index)
        variables_file_prefix = "{}/variables.data.cont-{}".format(
            cfg.args.output, cont_index)

    copyfile(sys.argv[1], ini_file)
    Logging.set_log_file(log_file)

    # this points inside the neuralmonkey/ dir inside the repo, but
    # it does not matter for git.
    repodir = os.path.dirname(os.path.realpath(__file__))

    # we need to execute the git log command in subshell, because if
    # the log file is specified via relative path, we need to do the
    # redirection of the git-log output to the right file
    # The paths are quoted so that spaces or shell metacharacters in them
    # cannot break (or alter) the command.
    os.system("(cd {}; git log -1 --format=%H) > {}".format(
        shlex.quote(repodir), shlex.quote(git_commit_file)))

    os.system("(cd {}; git --no-pager diff --color=always) > {}".format(
        shlex.quote(repodir), shlex.quote(git_diff_file)))

    link_best_vars = "{}.best".format(variables_file_prefix)

    cfg.build_model(warn_unused=True)

    try:
        check_dataset_and_coders(cfg.model.train_dataset, cfg.model.runners)
        check_dataset_and_coders(cfg.model.val_dataset, cfg.model.runners)
    except CheckingException as exc:
        log(str(exc), color='red')
        sys.exit(1)

    Logging.print_header(cfg.model.name)

    # runners_batch_size must be set to avoid problems on GPU
    if cfg.model.runners_batch_size is None:
        cfg.model.runners_batch_size = cfg.model.batch_size

    training_loop(
        tf_manager=cfg.model.tf_manager,
        epochs=cfg.model.epochs,
        trainer=cfg.model.trainer,
        batch_size=cfg.model.batch_size,
        train_dataset=cfg.model.train_dataset,
        val_dataset=cfg.model.val_dataset,
        log_directory=cfg.model.output,
        evaluators=cfg.model.evaluation,
        runners=cfg.model.runners,
        test_datasets=cfg.model.test_datasets,
        link_best_vars=link_best_vars,
        vars_prefix=variables_file_prefix,
        logging_period=cfg.model.logging_period,
        validation_period=cfg.model.validation_period,
        val_preview_input_series=cfg.model.val_preview_input_series,
        val_preview_output_series=cfg.model.val_preview_output_series,
        val_preview_num_examples=cfg.model.val_preview_num_examples,
        postprocess=cfg.model.postprocess,
        train_start_offset=cfg.model.train_start_offset,
        runners_batch_size=cfg.model.runners_batch_size,
        initial_variables=cfg.model.initial_variables,
        minimize_metric=cfg.model.minimize)
예제 #8
0
def main() -> None:
    """Entry point of train.py: prepare the experiment directory and train.

    Command line: a positional INI file, repeatable ``-s/--set`` overrides
    of configuration options, ``-i/--init`` to only initialize the
    experiment directory, and ``-f/--overwrite`` to allow reusing an
    existing output directory.

    The function loads the configuration (applying the overrides), seeds
    ``random``, NumPy and TensorFlow (falling back to the seed 2574600),
    prepares the output directory, and stores the command line, the
    effective configuration and a copy of the original INI file there.
    When files of a previous run are present, ``.cont-N`` suffixed names
    are used instead. With ``--init`` it stops at this point with exit
    status 0.

    Otherwise it records the git commit and diff of the repository
    containing this file, builds the model, initializes variable saving,
    checks the datasets against the runners, optionally sets up the
    embeddings visualization, and calls ``training_loop``.

    Every failure handled here is logged in red and ends the process with
    exit status 1.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("config",
                        metavar="INI-FILE",
                        help="the configuration file for the experiment")
    parser.add_argument("-s",
                        "--set",
                        type=str,
                        metavar="SETTING",
                        action="append",
                        dest="config_changes",
                        help="override an option in the configuration; the "
                        "syntax is [section.]option=value")
    parser.add_argument("-i",
                        "--init",
                        dest="init_only",
                        action="store_true",
                        help="initialize the experiment directory and exit "
                        "without building the model")
    parser.add_argument("-f",
                        "--overwrite",
                        action="store_true",
                        help="force overwriting the output directory; can be "
                        "used to start an experiment created with --init")
    args = parser.parse_args()

    # define valid parameters and defaults
    cfg = create_config()
    # load the params from the config file, getting also the simple arguments
    cfg.load_file(args.config, changes=args.config_changes)
    # various things like randseed or summarywriter should be set up here
    # so that graph building can be recorded
    # build all the objects specified in the config

    if cfg.args.random_seed is None:
        cfg.args.random_seed = 2574600
    random.seed(cfg.args.random_seed)
    np.random.seed(cfg.args.random_seed)
    tf.set_random_seed(cfg.args.random_seed)

    # pylint: disable=no-member
    if (os.path.isdir(cfg.args.output) and os.path.exists(
            os.path.join(cfg.args.output, "experiment.ini"))):
        if cfg.args.overwrite_output_dir or args.overwrite:
            # we do not want to delete the directory contents
            log("Directory with experiment.ini '{}' exists, "
                "overwriting enabled, proceeding.".format(cfg.args.output))
        else:
            log("Directory with experiment.ini '{}' exists, "
                "overwriting disabled.".format(cfg.args.output),
                color="red")
            # sys.exit rather than the builtin exit(): the latter is
            # injected by the ``site`` module and may be unavailable.
            sys.exit(1)

    # pylint: disable=broad-except
    if not os.path.isdir(cfg.args.output):
        try:
            os.mkdir(cfg.args.output)
        except Exception as exc:
            log("Failed to create experiment directory: {}. Exception: {}".
                format(cfg.args.output, exc),
                color="red")
            sys.exit(1)

    args_file = "{}/args".format(cfg.args.output)
    log_file = "{}/experiment.log".format(cfg.args.output)
    ini_file = "{}/experiment.ini".format(cfg.args.output)
    orig_ini_file = "{}/original.ini".format(cfg.args.output)
    git_commit_file = "{}/git_commit".format(cfg.args.output)
    git_diff_file = "{}/git_diff".format(cfg.args.output)
    variables_file_prefix = "{}/variables.data".format(cfg.args.output)

    cont_index = 0

    # Find the first continuation index for which none of the checked
    # output files exists yet, so that previous runs are not overwritten.
    while (os.path.exists(log_file) or os.path.exists(ini_file)
           or os.path.exists(git_commit_file) or os.path.exists(git_diff_file)
           or os.path.exists(variables_file_prefix)
           or os.path.exists("{}.0".format(variables_file_prefix))):
        cont_index += 1

        args_file = "{}/args.cont-{}".format(cfg.args.output, cont_index)
        log_file = "{}/experiment.log.cont-{}".format(cfg.args.output,
                                                      cont_index)
        ini_file = "{}/experiment.ini.cont-{}".format(cfg.args.output,
                                                      cont_index)
        orig_ini_file = "{}/original.ini.cont-{}".format(
            cfg.args.output, cont_index)
        git_commit_file = "{}/git_commit.cont-{}".format(
            cfg.args.output, cont_index)
        git_diff_file = "{}/git_diff.cont-{}".format(cfg.args.output,
                                                     cont_index)
        variables_file_prefix = "{}/variables.data.cont-{}".format(
            cfg.args.output, cont_index)

    # Store the shell-quoted command line used to start this run.
    with open(args_file, "w") as file:
        print(" ".join(shlex.quote(a) for a in sys.argv), file=file)

    cfg.save_file(ini_file)
    copyfile(args.config, orig_ini_file)

    if args.init_only:
        log("Experiment directory initialized.")

        cmd = [os.path.basename(sys.argv[0]), "-f", ini_file]
        log("To start experiment, run: {}".format(" ".join(
            shlex.quote(a) for a in cmd)))
        sys.exit(0)

    Logging.set_log_file(log_file)

    # this points inside the neuralmonkey/ dir inside the repo, but
    # it does not matter for git.
    repodir = os.path.dirname(os.path.realpath(__file__))

    # we need to execute the git log command in subshell, because if
    # the log file is specified via relative path, we need to do the
    # redirection of the git-log output to the right file
    # The paths are quoted so that spaces or shell metacharacters in them
    # cannot break (or alter) the command.
    os.system("(cd {}; git log -1 --format=%H) > {}".format(
        shlex.quote(repodir), shlex.quote(git_commit_file)))

    os.system("(cd {}; git --no-pager diff --color=always) > {}".format(
        shlex.quote(repodir), shlex.quote(git_diff_file)))

    cfg.build_model(warn_unused=True)

    cfg.model.tf_manager.init_saving(variables_file_prefix)

    try:
        check_dataset_and_coders(cfg.model.train_dataset, cfg.model.runners)
        # The validation data is either one dataset or an iterable of
        # datasets.
        if isinstance(cfg.model.val_dataset, Dataset):
            check_dataset_and_coders(cfg.model.val_dataset, cfg.model.runners)
        else:
            for val_dataset in cfg.model.val_dataset:
                check_dataset_and_coders(val_dataset, cfg.model.runners)
    except CheckingException as exc:
        log(str(exc), color="red")
        sys.exit(1)

    if cfg.model.visualize_embeddings:

        tb_projector = projector.ProjectorConfig()

        for sequence in cfg.model.visualize_embeddings:
            # TODO this check should be done when abstract class of embedded
            # sequences will be created, not only EmbeddedFactorSequence
            if not isinstance(sequence, EmbeddedFactorSequence):
                raise ValueError("Visualization must be embedded sequence.")
            sequence.tb_embedding_visualization(cfg.model.output, tb_projector)

        summary_writer = tf.summary.FileWriter(cfg.model.output)
        projector.visualize_embeddings(summary_writer, tb_projector)

    Logging.print_header(cfg.model.name, cfg.args.output)

    # runners_batch_size must be set to avoid problems on GPU
    if cfg.model.runners_batch_size is None:
        cfg.model.runners_batch_size = cfg.model.batch_size

    training_loop(
        tf_manager=cfg.model.tf_manager,
        epochs=cfg.model.epochs,
        trainer=cfg.model.trainer,
        batch_size=cfg.model.batch_size,
        log_directory=cfg.model.output,
        evaluators=cfg.model.evaluation,
        runners=cfg.model.runners,
        train_dataset=cfg.model.train_dataset,
        val_dataset=cfg.model.val_dataset,
        test_datasets=cfg.model.test_datasets,
        logging_period=cfg.model.logging_period,
        validation_period=cfg.model.validation_period,
        val_preview_input_series=cfg.model.val_preview_input_series,
        val_preview_output_series=cfg.model.val_preview_output_series,
        val_preview_num_examples=cfg.model.val_preview_num_examples,
        postprocess=cfg.model.postprocess,
        train_start_offset=cfg.model.train_start_offset,
        runners_batch_size=cfg.model.runners_batch_size,
        initial_variables=cfg.model.initial_variables)
예제 #9
0
def main():
    """Entry point of train.py: prepare the experiment directory and train.

    Takes the experiment INI file as the only command-line argument. The
    function builds the configuration with ``create_config``, optionally
    seeds TensorFlow, refuses to reuse an output directory that already
    contains ``experiment.ini`` unless ``overwrite_output_dir`` is set,
    checks the training, validation and test datasets against the coders,
    and creates the output directory when it is missing.

    It then picks ``.cont-N`` suffixed file names when files of a previous
    run are present, copies the INI file, records the git commit and diff
    from the current working directory, initializes TensorFlow and calls
    ``training_loop``.

    Every failure handled here is logged in red and ends the process with
    exit status 1.
    """
    if len(sys.argv) != 2:
        print("Usage: train.py <ini_file>")
        # sys.exit rather than the builtin exit(): the latter is injected by
        # the ``site`` module for interactive use and may be unavailable.
        sys.exit(1)

    args = create_config(sys.argv[1])

    print("")

    #pylint: disable=no-member,broad-except
    if args.random_seed is not None:
        tf.set_random_seed(args.random_seed)

    if os.path.isdir(args.output) and \
            os.path.exists(os.path.join(args.output, "experiment.ini")):
        if args.overwrite_output_dir:
            # we do not want to delete the directory contents
            log("Directory with experiment.ini '{}' exists, "
                "overwriting enabled, proceeding."
                .format(args.output))
        else:
            log("Directory with experiment.ini '{}' exists, "
                "overwriting disabled."
                .format(args.output), color='red')
            sys.exit(1)

    try:
        check_dataset_and_coders(args.train_dataset,
                                 args.encoders + [args.decoder])
        check_dataset_and_coders(args.val_dataset,
                                 args.encoders + [args.decoder])
        # Test datasets are checked against the encoders only.
        for test in args.test_datasets:
            check_dataset_and_coders(test, args.encoders)
    except Exception as exc:
        # Exceptions have no ``message`` attribute on Python 3, so reading
        # it here would raise AttributeError and hide the real error.
        log(str(exc), color='red')
        sys.exit(1)

    if not os.path.isdir(args.output):
        try:
            os.mkdir(args.output)
        except Exception as exc:
            log("Failed to create experiment directory: {}. Exception: {}"
                .format(args.output, exc), color='red')
            sys.exit(1)

    log_file = "{}/experiment.log".format(args.output)
    ini_file = "{}/experiment.ini".format(args.output)
    git_commit_file = "{}/git_commit".format(args.output)
    git_diff_file = "{}/git_diff".format(args.output)
    variables_file_prefix = "{}/variables.data".format(args.output)

    cont_index = 0

    # Find the first continuation index for which none of the output files
    # exists yet, so that files of previous runs are never overwritten.
    while (os.path.exists(log_file)
           or os.path.exists(ini_file)
           or os.path.exists(git_commit_file)
           or os.path.exists(git_diff_file)
           or os.path.exists(variables_file_prefix)
           or os.path.exists("{}.0".format(variables_file_prefix))):
        cont_index += 1

        log_file = "{}/experiment.log.cont-{}".format(args.output, cont_index)
        ini_file = "{}/experiment.ini.cont-{}".format(args.output, cont_index)
        git_commit_file = "{}/git_commit.cont-{}".format(
            args.output, cont_index)
        git_diff_file = "{}/git_diff.cont-{}".format(args.output, cont_index)
        variables_file_prefix = "{}/variables.data.cont-{}".format(
            args.output, cont_index)

    copyfile(sys.argv[1], ini_file)
    Logging.set_log_file(log_file)
    Logging.print_header(args.name)

    # Both git commands run in the current working directory.
    os.system("git log -1 --format=%H > {}".format(git_commit_file))
    os.system("git --no-pager diff --color=always > {}".format(git_diff_file))

    link_best_vars = "{}.best".format(variables_file_prefix)

    sess, saver = initialize_tf(args.initial_variables, args.threads)
    training_loop(sess, saver, args.epochs, args.trainer,
                  args.encoders + [args.decoder], args.decoder,
                  args.batch_size, args.train_dataset, args.val_dataset,
                  args.output, args.evaluation, args.runner,
                  test_datasets=args.test_datasets,
                  save_n_best_vars=args.save_n_best,
                  link_best_vars=link_best_vars,
                  vars_prefix=variables_file_prefix,
                  logging_period=args.logging_period,
                  validation_period=args.validation_period,
                  postprocess=args.postprocess,
                  minimize_metric=args.minimize)
예제 #10
0
def main():
    """Entry point of train.py: prepare the experiment directory and train.

    Takes the experiment INI file as the only command-line argument. The
    function builds the configuration with ``create_config``, optionally
    seeds TensorFlow, refuses to reuse an output directory that already
    contains ``experiment.ini`` unless ``overwrite_output_dir`` is set,
    checks the training, validation and test datasets against the coders,
    and creates the output directory when it is missing.

    It then picks ``.cont-N`` suffixed file names when files of a previous
    run are present, copies the INI file, records the git commit and diff
    of the repository containing this file, initializes TensorFlow and
    calls ``training_loop``.

    Every failure handled here is logged in red and ends the process with
    exit status 1.
    """
    if len(sys.argv) != 2:
        print("Usage: train.py <ini_file>")
        # sys.exit rather than the builtin exit(): the latter is injected by
        # the ``site`` module for interactive use and may be unavailable.
        sys.exit(1)

    args = create_config(sys.argv[1])

    print("")

    #pylint: disable=no-member,broad-except
    if args.random_seed is not None:
        tf.set_random_seed(args.random_seed)

    if os.path.isdir(args.output) and \
            os.path.exists(os.path.join(args.output, "experiment.ini")):
        if args.overwrite_output_dir:
            # we do not want to delete the directory contents
            log("Directory with experiment.ini '{}' exists, "
                "overwriting enabled, proceeding.".format(args.output))
        else:
            log("Directory with experiment.ini '{}' exists, "
                "overwriting disabled.".format(args.output),
                color='red')
            sys.exit(1)

    try:
        check_dataset_and_coders(args.train_dataset,
                                 args.encoders + [args.decoder])
        check_dataset_and_coders(args.val_dataset,
                                 args.encoders + [args.decoder])
        # Test datasets are checked against the encoders only.
        for test in args.test_datasets:
            check_dataset_and_coders(test, args.encoders)
    except Exception as exc:
        # Exceptions have no ``message`` attribute on Python 3, so reading
        # it here would raise AttributeError and hide the real error.
        log(str(exc), color='red')
        sys.exit(1)

    if not os.path.isdir(args.output):
        try:
            os.mkdir(args.output)
        except Exception as exc:
            log("Failed to create experiment directory: {}. Exception: {}".
                format(args.output, exc),
                color='red')
            sys.exit(1)

    log_file = "{}/experiment.log".format(args.output)
    ini_file = "{}/experiment.ini".format(args.output)
    git_commit_file = "{}/git_commit".format(args.output)
    git_diff_file = "{}/git_diff".format(args.output)
    variables_file_prefix = "{}/variables.data".format(args.output)

    cont_index = 0

    # Find the first continuation index for which none of the output files
    # exists yet, so that files of previous runs are never overwritten.
    while (os.path.exists(log_file) or os.path.exists(ini_file)
           or os.path.exists(git_commit_file) or os.path.exists(git_diff_file)
           or os.path.exists(variables_file_prefix)
           or os.path.exists("{}.0".format(variables_file_prefix))):
        cont_index += 1

        log_file = "{}/experiment.log.cont-{}".format(args.output, cont_index)
        ini_file = "{}/experiment.ini.cont-{}".format(args.output, cont_index)
        git_commit_file = "{}/git_commit.cont-{}".format(
            args.output, cont_index)
        git_diff_file = "{}/git_diff.cont-{}".format(args.output, cont_index)
        variables_file_prefix = "{}/variables.data.cont-{}".format(
            args.output, cont_index)

    copyfile(sys.argv[1], ini_file)
    Logging.set_log_file(log_file)
    Logging.print_header(args.name)

    # this points inside the neuralmonkey/ dir inside the repo, but
    # it does not matter for git.
    repodir = os.path.dirname(os.path.realpath(__file__))

    # The git commands run in a subshell so that the ``cd`` does not affect
    # the redirection: with a relative output path, redirecting after a
    # plain ``cd`` would write the file relative to ``repodir`` instead of
    # the directory the script was started from.
    os.system("(cd {}; git log -1 --format=%H) > {}".format(
        repodir, git_commit_file))

    os.system("(cd {}; git --no-pager diff --color=always) > {}".format(
        repodir, git_diff_file))

    link_best_vars = "{}.best".format(variables_file_prefix)

    sess, saver = initialize_tf(args.initial_variables, args.threads)
    training_loop(sess,
                  saver,
                  args.epochs,
                  args.trainer,
                  args.encoders + [args.decoder],
                  args.decoder,
                  args.batch_size,
                  args.train_dataset,
                  args.val_dataset,
                  args.output,
                  args.evaluation,
                  args.runner,
                  test_datasets=args.test_datasets,
                  save_n_best_vars=args.save_n_best,
                  link_best_vars=link_best_vars,
                  vars_prefix=variables_file_prefix,
                  logging_period=args.logging_period,
                  validation_period=args.validation_period,
                  postprocess=args.postprocess,
                  minimize_metric=args.minimize)
예제 #11
0
def main() -> None:
    """Entry point of train.py: prepare the experiment directory and train.

    Command line: a positional INI file and repeatable ``-s/--set``
    overrides of configuration options.

    The function loads the configuration (applying the overrides), seeds
    ``random``, NumPy and TensorFlow (falling back to the seed 2574600),
    refuses to reuse an output directory that already contains
    ``experiment.ini`` unless ``overwrite_output_dir`` is set, and creates
    the directory when it is missing. It stores the command line, the
    effective configuration and a copy of the original INI file there,
    using ``.cont-N`` suffixed names when files of a previous run are
    present.

    It then records the git commit and diff of the repository containing
    this file, builds the model, initializes variable saving, checks the
    training and validation datasets against the runners and calls
    ``training_loop``.

    Every failure handled here is logged in red and ends the process with
    exit status 1.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('config',
                        metavar='INI-FILE',
                        help='the configuration file for the experiment')
    parser.add_argument('-s',
                        '--set',
                        type=str,
                        metavar='SETTING',
                        action='append',
                        dest='config_changes',
                        help='override an option in the configuration; the '
                        'syntax is [section.]option=value')
    args = parser.parse_args()

    # define valid parameters and defaults
    cfg = create_config()
    # load the params from the config file, getting also the simple arguments
    cfg.load_file(args.config, changes=args.config_changes)
    # various things like randseed or summarywriter should be set up here
    # so that graph building can be recorded
    # build all the objects specified in the config

    if cfg.args.random_seed is None:
        cfg.args.random_seed = 2574600
    random.seed(cfg.args.random_seed)
    np.random.seed(cfg.args.random_seed)
    tf.set_random_seed(cfg.args.random_seed)

    # pylint: disable=no-member
    if (os.path.isdir(cfg.args.output) and os.path.exists(
            os.path.join(cfg.args.output, "experiment.ini"))):
        if cfg.args.overwrite_output_dir:
            # we do not want to delete the directory contents
            log("Directory with experiment.ini '{}' exists, "
                "overwriting enabled, proceeding.".format(cfg.args.output))
        else:
            log("Directory with experiment.ini '{}' exists, "
                "overwriting disabled.".format(cfg.args.output),
                color='red')
            # sys.exit rather than the builtin exit(): the latter is
            # injected by the ``site`` module and may be unavailable.
            sys.exit(1)

    # pylint: disable=broad-except
    if not os.path.isdir(cfg.args.output):
        try:
            os.mkdir(cfg.args.output)
        except Exception as exc:
            log("Failed to create experiment directory: {}. Exception: {}".
                format(cfg.args.output, exc),
                color='red')
            sys.exit(1)

    args_file = "{}/args".format(cfg.args.output)
    log_file = "{}/experiment.log".format(cfg.args.output)
    ini_file = "{}/experiment.ini".format(cfg.args.output)
    orig_ini_file = "{}/original.ini".format(cfg.args.output)
    git_commit_file = "{}/git_commit".format(cfg.args.output)
    git_diff_file = "{}/git_diff".format(cfg.args.output)
    variables_file_prefix = "{}/variables.data".format(cfg.args.output)

    cont_index = 0

    # Find the first continuation index for which none of the checked
    # output files exists yet, so that previous runs are not overwritten.
    while (os.path.exists(log_file) or os.path.exists(ini_file)
           or os.path.exists(git_commit_file) or os.path.exists(git_diff_file)
           or os.path.exists(variables_file_prefix)
           or os.path.exists("{}.0".format(variables_file_prefix))):
        cont_index += 1

        args_file = "{}/args.cont-{}".format(cfg.args.output, cont_index)
        log_file = "{}/experiment.log.cont-{}".format(cfg.args.output,
                                                      cont_index)
        ini_file = "{}/experiment.ini.cont-{}".format(cfg.args.output,
                                                      cont_index)
        orig_ini_file = "{}/original.ini.cont-{}".format(
            cfg.args.output, cont_index)
        git_commit_file = "{}/git_commit.cont-{}".format(
            cfg.args.output, cont_index)
        git_diff_file = "{}/git_diff.cont-{}".format(cfg.args.output,
                                                     cont_index)
        variables_file_prefix = "{}/variables.data.cont-{}".format(
            cfg.args.output, cont_index)

    # Store the shell-quoted command line used to start this run.
    with open(args_file, 'w') as file:
        print(' '.join(shlex.quote(a) for a in sys.argv), file=file)

    cfg.save_file(ini_file)
    copyfile(args.config, orig_ini_file)

    Logging.set_log_file(log_file)

    # this points inside the neuralmonkey/ dir inside the repo, but
    # it does not matter for git.
    repodir = os.path.dirname(os.path.realpath(__file__))

    # we need to execute the git log command in subshell, because if
    # the log file is specified via relative path, we need to do the
    # redirection of the git-log output to the right file
    # The paths are quoted so that spaces or shell metacharacters in them
    # cannot break (or alter) the command.
    os.system("(cd {}; git log -1 --format=%H) > {}".format(
        shlex.quote(repodir), shlex.quote(git_commit_file)))

    os.system("(cd {}; git --no-pager diff --color=always) > {}".format(
        shlex.quote(repodir), shlex.quote(git_diff_file)))

    cfg.build_model(warn_unused=True)

    cfg.model.tf_manager.init_saving(variables_file_prefix)

    try:
        check_dataset_and_coders(cfg.model.train_dataset, cfg.model.runners)
        check_dataset_and_coders(cfg.model.val_dataset, cfg.model.runners)
    except CheckingException as exc:
        log(str(exc), color='red')
        sys.exit(1)

    Logging.print_header(cfg.model.name)

    # runners_batch_size must be set to avoid problems on GPU
    if cfg.model.runners_batch_size is None:
        cfg.model.runners_batch_size = cfg.model.batch_size

    training_loop(
        tf_manager=cfg.model.tf_manager,
        epochs=cfg.model.epochs,
        trainer=cfg.model.trainer,
        batch_size=cfg.model.batch_size,
        train_dataset=cfg.model.train_dataset,
        val_dataset=cfg.model.val_dataset,
        log_directory=cfg.model.output,
        evaluators=cfg.model.evaluation,
        runners=cfg.model.runners,
        test_datasets=cfg.model.test_datasets,
        logging_period=cfg.model.logging_period,
        validation_period=cfg.model.validation_period,
        val_preview_input_series=cfg.model.val_preview_input_series,
        val_preview_output_series=cfg.model.val_preview_output_series,
        val_preview_num_examples=cfg.model.val_preview_num_examples,
        postprocess=cfg.model.postprocess,
        train_start_offset=cfg.model.train_start_offset,
        runners_batch_size=cfg.model.runners_batch_size,
        initial_variables=cfg.model.initial_variables)