Exemplo n.º 1
0
    def train(self) -> None:
        """Run the training procedure of this experiment.

        This is one of the entrypoints of the experiment. The model is built
        first if it has not been built yet. Then the experiment directory is
        filled (saved configuration, copy of the original configuration file,
        git info files, log file) and the training loop is executed within
        the experiment graph.

        Raises:
            `RuntimeError` when the experiment was created with
            `train_mode=False`.
        """
        if not self.train_mode:
            raise RuntimeError(
                "train() was called, but the experiment was created with "
                "train_mode=False")

        if not self._model_built:
            self.build_model()

        self.cont_index += 1

        # Fill the experiment directory; the paths are resolved only after
        # the continuation index has been incremented.
        experiment_ini = self.get_path("experiment.ini")
        self.config.save_file(experiment_ini)

        original_ini = self.get_path("original.ini")
        shutil.copyfile(self._config_path, original_ini)

        commit_path = self.get_path("git_commit")
        diff_path = self.get_path("git_diff")
        save_git_info(commit_path, diff_path)

        log_path = self.get_path("experiment.log")
        Logging.set_log_file(log_path)

        model = self.model
        Logging.print_header(model.name, model.output)

        with self.graph.as_default():
            model.tf_manager.init_saving(self.get_path("variables.data"))

            # Collect the arguments of the training loop from the model
            # attributes; two of them are stored under different names
            # (output -> log_directory, evaluation -> evaluators).
            loop_kwargs = dict(
                tf_manager=model.tf_manager,
                epochs=model.epochs,
                trainers=model.trainers,
                batching_scheme=model.batching_scheme,
                runners_batching_scheme=model.runners_batching_scheme,
                log_directory=model.output,
                evaluators=model.evaluation,
                main_metric=model.main_metric,
                runners=model.runners,
                train_dataset=model.train_dataset,
                val_datasets=model.val_datasets,
                test_datasets=model.test_datasets,
                log_timer=model.log_timer,
                val_timer=model.val_timer,
                val_preview_input_series=model.val_preview_input_series,
                val_preview_output_series=model.val_preview_output_series,
                val_preview_num_examples=model.val_preview_num_examples,
                postprocess=model.postprocess,
                train_start_offset=model.train_start_offset,
                initial_variables=model.initial_variables,
                final_variables=self.get_path("variables.data.final"))
            training_loop(**loop_kwargs)

            self._vars_loaded = True
Exemplo n.º 2
0
    def train(self) -> None:
        """Train the model of this experiment and evaluate the test sets.

        This is one of the entrypoints of the experiment. The model is built
        first if it has not been built yet. After the experiment directory is
        filled, the training loop is run, the final variables are saved, and
        every test dataset of the model (if there are any) is evaluated.

        Raises:
            `RuntimeError` when the experiment was created with
            `train_mode=False`.
        """
        if not self.train_mode:
            raise RuntimeError(
                "train() was called, but the experiment was created with "
                "train_mode=False")

        if not self._model_built:
            self.build_model()

        self.cont_index += 1

        # Fill the experiment directory with the saved configuration, a copy
        # of the original configuration file, the git info and the log file.
        saved_config_path = self.get_path("experiment.ini")
        self.config.save_file(saved_config_path)
        original_config_path = self.get_path("original.ini")
        shutil.copyfile(self._config_path, original_config_path)
        save_git_info(self.get_path("git_commit"), self.get_path("git_diff"))
        Logging.set_log_file(self.get_path("experiment.log"))

        Logging.print_header(self.model.name, self.model.output)

        with self.graph.as_default():
            self.model.tf_manager.init_saving(self.get_path("variables.data"))

            training_loop(cfg=self.model)

            final_variables = self.get_path("variables.data.final")
            log("Saving final variables in {}".format(final_variables))
            manager = self.model.tf_manager
            manager.save(final_variables)

            # The best variables are restored only when there is something
            # to evaluate and a best score index has been recorded.
            test_datasets = self.model.test_datasets
            if test_datasets and manager.best_score_index is not None:
                manager.restore_best_vars()

            for test_id, dataset in enumerate(test_datasets or ()):
                self.evaluate(
                    dataset, write_out=True, name="test_{}".format(test_id))

            log("Finished.")
            self._vars_loaded = True
Exemplo n.º 3
0
    def train(self) -> None:
        """Execute training for this experiment, then run the test sets.

        This is one of the entrypoints of the experiment. The model gets
        built on demand. The method fills the experiment directory, runs the
        training loop, saves the final variables and, when the model has
        test datasets, evaluates each of them with `write_out=True`.

        Raises:
            `RuntimeError` when the experiment was created with
            `train_mode=False`.
        """
        if not self.train_mode:
            message = ("train() was called, but the experiment was "
                       "created with train_mode=False")
            raise RuntimeError(message)

        if not self._model_built:
            self.build_model()

        self.cont_index += 1

        # Experiment directory: saved configuration, original configuration
        # file, git info and log file.
        self.config.save_file(self.get_path("experiment.ini"))
        shutil.copyfile(self._config_path, self.get_path("original.ini"))
        git_commit_path = self.get_path("git_commit")
        git_diff_path = self.get_path("git_diff")
        save_git_info(git_commit_path, git_diff_path)
        Logging.set_log_file(self.get_path("experiment.log"))

        Logging.print_header(self.model.name, self.model.output)

        with self.graph.as_default():
            variables_prefix = self.get_path("variables.data")
            self.model.tf_manager.init_saving(variables_prefix)

            training_loop(cfg=self.model)

            final_variables = self.get_path("variables.data.final")
            log("Saving final variables in {}".format(final_variables))
            self.model.tf_manager.save(final_variables)

            datasets = self.model.test_datasets
            if datasets:
                # Evaluate with the best variables when a best score index
                # has been recorded.
                if self.model.tf_manager.best_score_index is not None:
                    self.model.tf_manager.restore_best_vars()

                for index, test_set in enumerate(datasets):
                    test_name = "test_" + str(index)
                    self.evaluate(test_set, write_out=True, name=test_name)

            log("Finished.")
            self._vars_loaded = True
Exemplo n.º 4
0
    def train(self) -> None:
        """Train the model specified by this experiment.

        The model is built first if it has not been built yet. The method
        then fills the experiment directory (saved configuration, copy of
        the original configuration file, git info, log file) and runs the
        training loop within the experiment graph.

        Raises:
            `RuntimeError` when the experiment was created with
            `train_mode=False`.
        """
        if not self.train_mode:
            raise RuntimeError(
                "train() was called, but the experiment was created with "
                "train_mode=False")

        if not self._model_built:
            self.build_model()

        self.cont_index += 1

        # Initialize the experiment directory.
        ini_path = self.get_path("experiment.ini")
        self.config.save_file(ini_path)
        ini_copy_path = self.get_path("original.ini")
        shutil.copyfile(self._config_path, ini_copy_path)
        save_git_info(self.get_path("git_commit"), self.get_path("git_diff"))
        Logging.set_log_file(self.get_path("experiment.log"))

        model = self.model
        Logging.print_header(model.name, model.output)

        with self.graph.as_default():
            model.tf_manager.init_saving(self.get_path("variables.data"))

            # Note that model.output is passed as the log directory and
            # model.evaluation as the evaluators.
            training_loop(
                tf_manager=model.tf_manager,
                epochs=model.epochs,
                trainer=model.trainer,
                batch_size=model.batch_size,
                batching_scheme=model.batching_scheme,
                log_directory=model.output,
                evaluators=model.evaluation,
                runners=model.runners,
                train_dataset=model.train_dataset,
                val_dataset=model.val_dataset,
                test_datasets=model.test_datasets,
                logging_period=model.logging_period,
                validation_period=model.validation_period,
                val_preview_input_series=model.val_preview_input_series,
                val_preview_output_series=model.val_preview_output_series,
                val_preview_num_examples=model.val_preview_num_examples,
                postprocess=model.postprocess,
                train_start_offset=model.train_start_offset,
                runners_batch_size=model.runners_batch_size,
                initial_variables=model.initial_variables,
                final_variables=self.get_path("variables.data.final"))

            self._vars_loaded = True
Exemplo n.º 5
0
def main() -> None:
    """Run a training experiment described by an INI file.

    Exactly one command-line argument is expected: the path to the
    configuration file. The function loads the configuration, seeds the
    random number generators, prepares the output directory (copy of the
    configuration, log file, git commit and git diff), builds the model and
    runs the training loop.

    The process exits with status 1 when the number of arguments is wrong,
    when the output directory already contains ``experiment.ini`` and
    overwriting is disabled, when the output directory cannot be created,
    or when the dataset check raises ``CheckingException``.
    """
    # Local import: needed only to quote paths passed to the shell below.
    import shlex

    if len(sys.argv) != 2:
        print("Usage: train.py <ini_file>")
        sys.exit(1)

    # define valid parameters and defaults
    cfg = create_config()
    # load the params from the config file, getting also the simple arguments
    cfg.load_file(sys.argv[1])
    # various things like randseed or summarywriter should be set up here
    # so that graph building can be recorded
    # build all the objects specified in the config

    # A fixed default seed is used when the configuration does not set one.
    if cfg.args.random_seed is None:
        cfg.args.random_seed = 2574600
    random.seed(cfg.args.random_seed)
    np.random.seed(cfg.args.random_seed)
    tf.set_random_seed(cfg.args.random_seed)

    # pylint: disable=no-member
    if (os.path.isdir(cfg.args.output) and os.path.exists(
            os.path.join(cfg.args.output, "experiment.ini"))):
        if cfg.args.overwrite_output_dir:
            # we do not want to delete the directory contents
            log("Directory with experiment.ini '{}' exists, "
                "overwriting enabled, proceeding.".format(cfg.args.output))
        else:
            log("Directory with experiment.ini '{}' exists, "
                "overwriting disabled.".format(cfg.args.output),
                color='red')
            sys.exit(1)

    # pylint: disable=broad-except
    if not os.path.isdir(cfg.args.output):
        try:
            os.mkdir(cfg.args.output)
        except Exception as exc:
            log("Failed to create experiment directory: {}. Exception: {}".
                format(cfg.args.output, exc),
                color='red')
            sys.exit(1)

    log_file = "{}/experiment.log".format(cfg.args.output)
    ini_file = "{}/experiment.ini".format(cfg.args.output)
    git_commit_file = "{}/git_commit".format(cfg.args.output)
    git_diff_file = "{}/git_diff".format(cfg.args.output)
    variables_file_prefix = "{}/variables.data".format(cfg.args.output)

    cont_index = 0

    # Find the first continuation index for which none of the experiment
    # files exists, so that files of previous runs are not overwritten.
    while (os.path.exists(log_file) or os.path.exists(ini_file)
           or os.path.exists(git_commit_file) or os.path.exists(git_diff_file)
           or os.path.exists(variables_file_prefix)
           or os.path.exists("{}.0".format(variables_file_prefix))):
        cont_index += 1

        log_file = "{}/experiment.log.cont-{}".format(cfg.args.output,
                                                      cont_index)
        ini_file = "{}/experiment.ini.cont-{}".format(cfg.args.output,
                                                      cont_index)
        git_commit_file = "{}/git_commit.cont-{}".format(
            cfg.args.output, cont_index)
        git_diff_file = "{}/git_diff.cont-{}".format(cfg.args.output,
                                                     cont_index)
        variables_file_prefix = "{}/variables.data.cont-{}".format(
            cfg.args.output, cont_index)

    copyfile(sys.argv[1], ini_file)
    Logging.set_log_file(log_file)

    # this points inside the neuralmonkey/ dir inside the repo, but
    # it does not matter for git.
    repodir = os.path.dirname(os.path.realpath(__file__))

    # we need to execute the git log command in subshell, because if
    # the log file is specified via relative path, we need to do the
    # redirection of the git-log output to the right file
    # The paths are quoted so that spaces or shell metacharacters in the
    # repository or output path cannot break (or alter) the command.
    os.system("(cd {}; git log -1 --format=%H) > {}".format(
        shlex.quote(repodir), shlex.quote(git_commit_file)))

    os.system("(cd {}; git --no-pager diff --color=always) > {}".format(
        shlex.quote(repodir), shlex.quote(git_diff_file)))

    link_best_vars = "{}.best".format(variables_file_prefix)

    cfg.build_model(warn_unused=True)

    try:
        check_dataset_and_coders(cfg.model.train_dataset, cfg.model.runners)
        check_dataset_and_coders(cfg.model.val_dataset, cfg.model.runners)
    except CheckingException as exc:
        log(str(exc), color='red')
        sys.exit(1)

    Logging.print_header(cfg.model.name)

    # runners_batch_size must be set to avoid problems on GPU
    if cfg.model.runners_batch_size is None:
        cfg.model.runners_batch_size = cfg.model.batch_size

    training_loop(
        tf_manager=cfg.model.tf_manager,
        epochs=cfg.model.epochs,
        trainer=cfg.model.trainer,
        batch_size=cfg.model.batch_size,
        train_dataset=cfg.model.train_dataset,
        val_dataset=cfg.model.val_dataset,
        log_directory=cfg.model.output,
        evaluators=cfg.model.evaluation,
        runners=cfg.model.runners,
        test_datasets=cfg.model.test_datasets,
        link_best_vars=link_best_vars,
        vars_prefix=variables_file_prefix,
        logging_period=cfg.model.logging_period,
        validation_period=cfg.model.validation_period,
        val_preview_input_series=cfg.model.val_preview_input_series,
        val_preview_output_series=cfg.model.val_preview_output_series,
        val_preview_num_examples=cfg.model.val_preview_num_examples,
        postprocess=cfg.model.postprocess,
        train_start_offset=cfg.model.train_start_offset,
        runners_batch_size=cfg.model.runners_batch_size,
        initial_variables=cfg.model.initial_variables,
        minimize_metric=cfg.model.minimize)
Exemplo n.º 6
0
def main() -> None:
    """Run a training experiment configured on the command line.

    The command line takes the INI configuration file and the optional
    ``--set`` (override a configuration option), ``--init`` (only initialize
    the experiment directory) and ``--overwrite`` switches. The function
    loads the configuration, seeds the random number generators, prepares
    the output directory (command-line arguments, saved and original
    configuration, log file, git commit and git diff), builds the model and
    runs the training loop.

    The process exits with status 1 when the output directory already
    contains ``experiment.ini`` and overwriting is disabled, when the output
    directory cannot be created, or when the dataset check raises
    ``CheckingException``. With ``--init`` it exits with status 0 right
    after the configuration files are written.

    Raises:
        ValueError: when an item of ``visualize_embeddings`` is not an
            ``EmbeddedFactorSequence``.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("config",
                        metavar="INI-FILE",
                        help="the configuration file for the experiment")
    parser.add_argument("-s",
                        "--set",
                        type=str,
                        metavar="SETTING",
                        action="append",
                        dest="config_changes",
                        help="override an option in the configuration; the "
                        "syntax is [section.]option=value")
    parser.add_argument("-i",
                        "--init",
                        dest="init_only",
                        action="store_true",
                        help="initialize the experiment directory and exit "
                        "without building the model")
    parser.add_argument("-f",
                        "--overwrite",
                        action="store_true",
                        help="force overwriting the output directory; can be "
                        "used to start an experiment created with --init")
    args = parser.parse_args()

    # define valid parameters and defaults
    cfg = create_config()
    # load the params from the config file, getting also the simple arguments
    cfg.load_file(args.config, changes=args.config_changes)
    # various things like randseed or summarywriter should be set up here
    # so that graph building can be recorded
    # build all the objects specified in the config

    # A fixed default seed is used when the configuration does not set one.
    if cfg.args.random_seed is None:
        cfg.args.random_seed = 2574600
    random.seed(cfg.args.random_seed)
    np.random.seed(cfg.args.random_seed)
    tf.set_random_seed(cfg.args.random_seed)

    # pylint: disable=no-member
    if (os.path.isdir(cfg.args.output) and os.path.exists(
            os.path.join(cfg.args.output, "experiment.ini"))):
        if cfg.args.overwrite_output_dir or args.overwrite:
            # we do not want to delete the directory contents
            log("Directory with experiment.ini '{}' exists, "
                "overwriting enabled, proceeding.".format(cfg.args.output))
        else:
            log("Directory with experiment.ini '{}' exists, "
                "overwriting disabled.".format(cfg.args.output),
                color="red")
            sys.exit(1)

    # pylint: disable=broad-except
    if not os.path.isdir(cfg.args.output):
        try:
            os.mkdir(cfg.args.output)
        except Exception as exc:
            log("Failed to create experiment directory: {}. Exception: {}".
                format(cfg.args.output, exc),
                color="red")
            sys.exit(1)

    args_file = "{}/args".format(cfg.args.output)
    log_file = "{}/experiment.log".format(cfg.args.output)
    ini_file = "{}/experiment.ini".format(cfg.args.output)
    orig_ini_file = "{}/original.ini".format(cfg.args.output)
    git_commit_file = "{}/git_commit".format(cfg.args.output)
    git_diff_file = "{}/git_diff".format(cfg.args.output)
    variables_file_prefix = "{}/variables.data".format(cfg.args.output)

    cont_index = 0

    # Find the first continuation index for which none of the checked
    # experiment files exists, so that previous runs are not overwritten.
    while (os.path.exists(log_file) or os.path.exists(ini_file)
           or os.path.exists(git_commit_file) or os.path.exists(git_diff_file)
           or os.path.exists(variables_file_prefix)
           or os.path.exists("{}.0".format(variables_file_prefix))):
        cont_index += 1

        args_file = "{}/args.cont-{}".format(cfg.args.output, cont_index)
        log_file = "{}/experiment.log.cont-{}".format(cfg.args.output,
                                                      cont_index)
        ini_file = "{}/experiment.ini.cont-{}".format(cfg.args.output,
                                                      cont_index)
        orig_ini_file = "{}/original.ini.cont-{}".format(
            cfg.args.output, cont_index)
        git_commit_file = "{}/git_commit.cont-{}".format(
            cfg.args.output, cont_index)
        git_diff_file = "{}/git_diff.cont-{}".format(cfg.args.output,
                                                     cont_index)
        variables_file_prefix = "{}/variables.data.cont-{}".format(
            cfg.args.output, cont_index)

    # Record the exact command line, quoted so that it can be pasted back
    # into a shell.
    with open(args_file, "w") as file:
        print(" ".join(shlex.quote(a) for a in sys.argv), file=file)

    cfg.save_file(ini_file)
    copyfile(args.config, orig_ini_file)

    if args.init_only:
        log("Experiment directory initialized.")

        cmd = [os.path.basename(sys.argv[0]), "-f", ini_file]
        log("To start experiment, run: {}".format(" ".join(
            shlex.quote(a) for a in cmd)))
        sys.exit(0)

    Logging.set_log_file(log_file)

    # this points inside the neuralmonkey/ dir inside the repo, but
    # it does not matter for git.
    repodir = os.path.dirname(os.path.realpath(__file__))

    # we need to execute the git log command in subshell, because if
    # the log file is specified via relative path, we need to do the
    # redirection of the git-log output to the right file
    # The paths are quoted so that spaces or shell metacharacters in the
    # repository or output path cannot break (or alter) the command.
    os.system("(cd {}; git log -1 --format=%H) > {}".format(
        shlex.quote(repodir), shlex.quote(git_commit_file)))

    os.system("(cd {}; git --no-pager diff --color=always) > {}".format(
        shlex.quote(repodir), shlex.quote(git_diff_file)))

    cfg.build_model(warn_unused=True)

    cfg.model.tf_manager.init_saving(variables_file_prefix)

    try:
        check_dataset_and_coders(cfg.model.train_dataset, cfg.model.runners)
        # val_dataset is either a single Dataset or an iterable of them.
        if isinstance(cfg.model.val_dataset, Dataset):
            check_dataset_and_coders(cfg.model.val_dataset, cfg.model.runners)
        else:
            for val_dataset in cfg.model.val_dataset:
                check_dataset_and_coders(val_dataset, cfg.model.runners)
    except CheckingException as exc:
        log(str(exc), color="red")
        sys.exit(1)

    if cfg.model.visualize_embeddings:

        tb_projector = projector.ProjectorConfig()

        for sequence in cfg.model.visualize_embeddings:
            # TODO this check should be done when abstract class of embedded
            # sequences will be created, not only EmbeddedFactorSequence
            if not isinstance(sequence, EmbeddedFactorSequence):
                raise ValueError("Visualization must be embedded sequence.")
            sequence.tb_embedding_visualization(cfg.model.output, tb_projector)

        summary_writer = tf.summary.FileWriter(cfg.model.output)
        projector.visualize_embeddings(summary_writer, tb_projector)

    Logging.print_header(cfg.model.name, cfg.args.output)

    # runners_batch_size must be set to avoid problems on GPU
    if cfg.model.runners_batch_size is None:
        cfg.model.runners_batch_size = cfg.model.batch_size

    training_loop(
        tf_manager=cfg.model.tf_manager,
        epochs=cfg.model.epochs,
        trainer=cfg.model.trainer,
        batch_size=cfg.model.batch_size,
        log_directory=cfg.model.output,
        evaluators=cfg.model.evaluation,
        runners=cfg.model.runners,
        train_dataset=cfg.model.train_dataset,
        val_dataset=cfg.model.val_dataset,
        test_datasets=cfg.model.test_datasets,
        logging_period=cfg.model.logging_period,
        validation_period=cfg.model.validation_period,
        val_preview_input_series=cfg.model.val_preview_input_series,
        val_preview_output_series=cfg.model.val_preview_output_series,
        val_preview_num_examples=cfg.model.val_preview_num_examples,
        postprocess=cfg.model.postprocess,
        train_start_offset=cfg.model.train_start_offset,
        runners_batch_size=cfg.model.runners_batch_size,
        initial_variables=cfg.model.initial_variables)
Exemplo n.º 7
0
def main():
    """Run a training experiment described by an INI file.

    Exactly one command-line argument is expected: the path to the
    configuration file. The function creates the configuration, checks the
    datasets against the encoders and the decoder, prepares the output
    directory (copy of the configuration, log file, git commit and git
    diff), initializes TensorFlow and runs the training loop.

    The process exits with status 1 when the number of arguments is wrong,
    when the output directory already contains ``experiment.ini`` and
    overwriting is disabled, when the dataset check raises an exception, or
    when the output directory cannot be created.
    """
    if len(sys.argv) != 2:
        print("Usage: train.py <ini_file>")
        sys.exit(1)

    args = create_config(sys.argv[1])

    print("")

    #pylint: disable=no-member,broad-except
    if args.random_seed is not None:
        tf.set_random_seed(args.random_seed)

    if os.path.isdir(args.output) and \
            os.path.exists(os.path.join(args.output, "experiment.ini")):
        if args.overwrite_output_dir:
            # we do not want to delete the directory contents
            log("Directory with experiment.ini '{}' exists, "
                "overwriting enabled, proceeding."
                .format(args.output))
        else:
            log("Directory with experiment.ini '{}' exists, "
                "overwriting disabled."
                .format(args.output), color='red')
            sys.exit(1)

    try:
        # Training and validation data are checked against the encoders and
        # the decoder; test data only against the encoders.
        check_dataset_and_coders(args.train_dataset,
                                 args.encoders + [args.decoder])
        check_dataset_and_coders(args.val_dataset,
                                 args.encoders + [args.decoder])
        for test in args.test_datasets:
            check_dataset_and_coders(test, args.encoders)
    except Exception as exc:
        # str(exc) is used because exceptions have no `message` attribute
        # in Python 3; reading it would raise AttributeError here.
        log(str(exc), color='red')
        sys.exit(1)

    if not os.path.isdir(args.output):
        try:
            os.mkdir(args.output)
        except Exception as exc:
            log("Failed to create experiment directory: {}. Exception: {}"
                .format(args.output, exc), color='red')
            sys.exit(1)

    log_file = "{}/experiment.log".format(args.output)
    ini_file = "{}/experiment.ini".format(args.output)
    git_commit_file = "{}/git_commit".format(args.output)
    git_diff_file = "{}/git_diff".format(args.output)
    variables_file_prefix = "{}/variables.data".format(args.output)

    cont_index = 0

    # Find the first continuation index for which none of the experiment
    # files exists, so that files of previous runs are not overwritten.
    while (os.path.exists(log_file)
           or os.path.exists(ini_file)
           or os.path.exists(git_commit_file)
           or os.path.exists(git_diff_file)
           or os.path.exists(variables_file_prefix)
           or os.path.exists("{}.0".format(variables_file_prefix))):
        cont_index += 1

        log_file = "{}/experiment.log.cont-{}".format(args.output, cont_index)
        ini_file = "{}/experiment.ini.cont-{}".format(args.output, cont_index)
        git_commit_file = "{}/git_commit.cont-{}".format(
            args.output, cont_index)
        git_diff_file = "{}/git_diff.cont-{}".format(args.output, cont_index)
        variables_file_prefix = "{}/variables.data.cont-{}".format(
            args.output, cont_index)

    copyfile(sys.argv[1], ini_file)
    Logging.set_log_file(log_file)
    Logging.print_header(args.name)

    # The git commands run in the current working directory.
    os.system("git log -1 --format=%H > {}".format(git_commit_file))
    os.system("git --no-pager diff --color=always > {}".format(git_diff_file))

    link_best_vars = "{}.best".format(variables_file_prefix)

    sess, saver = initialize_tf(args.initial_variables, args.threads)
    training_loop(sess, saver, args.epochs, args.trainer,
                  args.encoders + [args.decoder], args.decoder,
                  args.batch_size, args.train_dataset, args.val_dataset,
                  args.output, args.evaluation, args.runner,
                  test_datasets=args.test_datasets,
                  save_n_best_vars=args.save_n_best,
                  link_best_vars=link_best_vars,
                  vars_prefix=variables_file_prefix,
                  logging_period=args.logging_period,
                  validation_period=args.validation_period,
                  postprocess=args.postprocess,
                  minimize_metric=args.minimize)
Exemplo n.º 8
0
def main():
    """Run a training experiment described by an INI file.

    Exactly one command-line argument is expected: the path to the
    configuration file. The function creates the configuration, checks the
    datasets against the encoders and the decoder, prepares the output
    directory (copy of the configuration, log file, git commit and git
    diff), initializes TensorFlow and runs the training loop.

    The process exits with status 1 when the number of arguments is wrong,
    when the output directory already contains ``experiment.ini`` and
    overwriting is disabled, when the dataset check raises an exception, or
    when the output directory cannot be created.
    """
    if len(sys.argv) != 2:
        print("Usage: train.py <ini_file>")
        sys.exit(1)

    args = create_config(sys.argv[1])

    print("")

    #pylint: disable=no-member,broad-except
    if args.random_seed is not None:
        tf.set_random_seed(args.random_seed)

    if os.path.isdir(args.output) and \
            os.path.exists(os.path.join(args.output, "experiment.ini")):
        if args.overwrite_output_dir:
            # we do not want to delete the directory contents
            log("Directory with experiment.ini '{}' exists, "
                "overwriting enabled, proceeding.".format(args.output))
        else:
            log("Directory with experiment.ini '{}' exists, "
                "overwriting disabled.".format(args.output),
                color='red')
            sys.exit(1)

    try:
        # Training and validation data are checked against the encoders and
        # the decoder; test data only against the encoders.
        check_dataset_and_coders(args.train_dataset,
                                 args.encoders + [args.decoder])
        check_dataset_and_coders(args.val_dataset,
                                 args.encoders + [args.decoder])
        for test in args.test_datasets:
            check_dataset_and_coders(test, args.encoders)
    except Exception as exc:
        # str(exc) is used because exceptions have no `message` attribute
        # in Python 3; reading it would raise AttributeError here.
        log(str(exc), color='red')
        sys.exit(1)

    if not os.path.isdir(args.output):
        try:
            os.mkdir(args.output)
        except Exception as exc:
            log("Failed to create experiment directory: {}. Exception: {}".
                format(args.output, exc),
                color='red')
            sys.exit(1)

    log_file = "{}/experiment.log".format(args.output)
    ini_file = "{}/experiment.ini".format(args.output)
    git_commit_file = "{}/git_commit".format(args.output)
    git_diff_file = "{}/git_diff".format(args.output)
    variables_file_prefix = "{}/variables.data".format(args.output)

    cont_index = 0

    # Find the first continuation index for which none of the experiment
    # files exists, so that files of previous runs are not overwritten.
    while (os.path.exists(log_file) or os.path.exists(ini_file)
           or os.path.exists(git_commit_file) or os.path.exists(git_diff_file)
           or os.path.exists(variables_file_prefix)
           or os.path.exists("{}.0".format(variables_file_prefix))):
        cont_index += 1

        log_file = "{}/experiment.log.cont-{}".format(args.output, cont_index)
        ini_file = "{}/experiment.ini.cont-{}".format(args.output, cont_index)
        git_commit_file = "{}/git_commit.cont-{}".format(
            args.output, cont_index)
        git_diff_file = "{}/git_diff.cont-{}".format(args.output, cont_index)
        variables_file_prefix = "{}/variables.data.cont-{}".format(
            args.output, cont_index)

    copyfile(sys.argv[1], ini_file)
    Logging.set_log_file(log_file)
    Logging.print_header(args.name)

    # this points inside the neuralmonkey/ dir inside the repo, but
    # it does not matter for git.
    repodir = os.path.dirname(os.path.realpath(__file__))

    # The git commands run in a subshell so that `cd` does not affect the
    # output redirection: a relative output path must be resolved against
    # the original working directory, not against the repository directory.
    os.system("(cd {}; git log -1 --format=%H) > {}".format(
        repodir, git_commit_file))

    os.system("(cd {}; git --no-pager diff --color=always) > {}".format(
        repodir, git_diff_file))

    link_best_vars = "{}.best".format(variables_file_prefix)

    sess, saver = initialize_tf(args.initial_variables, args.threads)
    training_loop(sess,
                  saver,
                  args.epochs,
                  args.trainer,
                  args.encoders + [args.decoder],
                  args.decoder,
                  args.batch_size,
                  args.train_dataset,
                  args.val_dataset,
                  args.output,
                  args.evaluation,
                  args.runner,
                  test_datasets=args.test_datasets,
                  save_n_best_vars=args.save_n_best,
                  link_best_vars=link_best_vars,
                  vars_prefix=variables_file_prefix,
                  logging_period=args.logging_period,
                  validation_period=args.validation_period,
                  postprocess=args.postprocess,
                  minimize_metric=args.minimize)
Exemplo n.º 9
0
def main() -> None:
    """Run a training experiment configured on the command line.

    The command line takes the INI configuration file and the optional
    ``--set`` switch that overrides a configuration option. The function
    loads the configuration, seeds the random number generators, prepares
    the output directory (command-line arguments, saved and original
    configuration, log file, git commit and git diff), builds the model and
    runs the training loop.

    The process exits with status 1 when the output directory already
    contains ``experiment.ini`` and overwriting is disabled, when the output
    directory cannot be created, or when the dataset check raises
    ``CheckingException``.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('config',
                        metavar='INI-FILE',
                        help='the configuration file for the experiment')
    parser.add_argument('-s',
                        '--set',
                        type=str,
                        metavar='SETTING',
                        action='append',
                        dest='config_changes',
                        help='override an option in the configuration; the '
                        'syntax is [section.]option=value')
    args = parser.parse_args()

    # define valid parameters and defaults
    cfg = create_config()
    # load the params from the config file, getting also the simple arguments
    cfg.load_file(args.config, changes=args.config_changes)
    # various things like randseed or summarywriter should be set up here
    # so that graph building can be recorded
    # build all the objects specified in the config

    # A fixed default seed is used when the configuration does not set one.
    if cfg.args.random_seed is None:
        cfg.args.random_seed = 2574600
    random.seed(cfg.args.random_seed)
    np.random.seed(cfg.args.random_seed)
    tf.set_random_seed(cfg.args.random_seed)

    # pylint: disable=no-member
    if (os.path.isdir(cfg.args.output) and os.path.exists(
            os.path.join(cfg.args.output, "experiment.ini"))):
        if cfg.args.overwrite_output_dir:
            # we do not want to delete the directory contents
            log("Directory with experiment.ini '{}' exists, "
                "overwriting enabled, proceeding.".format(cfg.args.output))
        else:
            log("Directory with experiment.ini '{}' exists, "
                "overwriting disabled.".format(cfg.args.output),
                color='red')
            sys.exit(1)

    # pylint: disable=broad-except
    if not os.path.isdir(cfg.args.output):
        try:
            os.mkdir(cfg.args.output)
        except Exception as exc:
            log("Failed to create experiment directory: {}. Exception: {}".
                format(cfg.args.output, exc),
                color='red')
            sys.exit(1)

    args_file = "{}/args".format(cfg.args.output)
    log_file = "{}/experiment.log".format(cfg.args.output)
    ini_file = "{}/experiment.ini".format(cfg.args.output)
    orig_ini_file = "{}/original.ini".format(cfg.args.output)
    git_commit_file = "{}/git_commit".format(cfg.args.output)
    git_diff_file = "{}/git_diff".format(cfg.args.output)
    variables_file_prefix = "{}/variables.data".format(cfg.args.output)

    cont_index = 0

    # Find the first continuation index for which none of the checked
    # experiment files exists, so that previous runs are not overwritten.
    while (os.path.exists(log_file) or os.path.exists(ini_file)
           or os.path.exists(git_commit_file) or os.path.exists(git_diff_file)
           or os.path.exists(variables_file_prefix)
           or os.path.exists("{}.0".format(variables_file_prefix))):
        cont_index += 1

        args_file = "{}/args.cont-{}".format(cfg.args.output, cont_index)
        log_file = "{}/experiment.log.cont-{}".format(cfg.args.output,
                                                      cont_index)
        ini_file = "{}/experiment.ini.cont-{}".format(cfg.args.output,
                                                      cont_index)
        orig_ini_file = "{}/original.ini.cont-{}".format(
            cfg.args.output, cont_index)
        git_commit_file = "{}/git_commit.cont-{}".format(
            cfg.args.output, cont_index)
        git_diff_file = "{}/git_diff.cont-{}".format(cfg.args.output,
                                                     cont_index)
        variables_file_prefix = "{}/variables.data.cont-{}".format(
            cfg.args.output, cont_index)

    # Record the exact command line, quoted so that it can be pasted back
    # into a shell.
    with open(args_file, 'w') as file:
        print(' '.join(shlex.quote(a) for a in sys.argv), file=file)

    cfg.save_file(ini_file)
    copyfile(args.config, orig_ini_file)

    Logging.set_log_file(log_file)

    # this points inside the neuralmonkey/ dir inside the repo, but
    # it does not matter for git.
    repodir = os.path.dirname(os.path.realpath(__file__))

    # we need to execute the git log command in subshell, because if
    # the log file is specified via relative path, we need to do the
    # redirection of the git-log output to the right file
    # The paths are quoted so that spaces or shell metacharacters in the
    # repository or output path cannot break (or alter) the command.
    os.system("(cd {}; git log -1 --format=%H) > {}".format(
        shlex.quote(repodir), shlex.quote(git_commit_file)))

    os.system("(cd {}; git --no-pager diff --color=always) > {}".format(
        shlex.quote(repodir), shlex.quote(git_diff_file)))

    cfg.build_model(warn_unused=True)

    cfg.model.tf_manager.init_saving(variables_file_prefix)

    try:
        check_dataset_and_coders(cfg.model.train_dataset, cfg.model.runners)
        check_dataset_and_coders(cfg.model.val_dataset, cfg.model.runners)
    except CheckingException as exc:
        log(str(exc), color='red')
        sys.exit(1)

    Logging.print_header(cfg.model.name)

    # runners_batch_size must be set to avoid problems on GPU
    if cfg.model.runners_batch_size is None:
        cfg.model.runners_batch_size = cfg.model.batch_size

    training_loop(
        tf_manager=cfg.model.tf_manager,
        epochs=cfg.model.epochs,
        trainer=cfg.model.trainer,
        batch_size=cfg.model.batch_size,
        train_dataset=cfg.model.train_dataset,
        val_dataset=cfg.model.val_dataset,
        log_directory=cfg.model.output,
        evaluators=cfg.model.evaluation,
        runners=cfg.model.runners,
        test_datasets=cfg.model.test_datasets,
        logging_period=cfg.model.logging_period,
        validation_period=cfg.model.validation_period,
        val_preview_input_series=cfg.model.val_preview_input_series,
        val_preview_output_series=cfg.model.val_preview_output_series,
        val_preview_num_examples=cfg.model.val_preview_num_examples,
        postprocess=cfg.model.postprocess,
        train_start_offset=cfg.model.train_start_offset,
        runners_batch_size=cfg.model.runners_batch_size,
        initial_variables=cfg.model.initial_variables)