# Example #1
def _prepare_forwarding():
    """
    Prepare the global ``engine`` for forwarding:
    checks the config, loads the network, and (Theano backend only)
    copies the net params over to the device.
    """
    assert engine
    assert config

    # Should already be set via setTargetMode().
    # We only support dumping posteriors here; fail loudly otherwise.
    assert config.list('extract') == ["posteriors"], (
        "You need to have extract = posteriors in your RETURNN config. You have: %s"
        % config.list('extract'))

    # Load network.
    engine.init_network_from_config(config)

    # Copy over net params (Theano keeps params on the device object).
    if BackendEngine.is_theano_selected():
        engine.devices[0].prepare(engine.network)
# Example #2
def execute_main_task():
    """
    Executes the main task (via config ``task`` option).

    Dispatches on the ``task`` config value (default ``"train"``) to one of
    the supported modes (train, eval, forward, search, compute_priors, ...),
    using the module-level globals ``engine``, ``config``, ``log`` and the
    datasets ``train_data`` / ``dev_data`` / ``eval_data``.
    Prints the elapsed wall-clock time when the task finishes.
    """
    from returnn.util.basic import hms_fraction
    start_time = time.time()
    task = config.value('task', 'train')
    if config.is_true("dry_run"):
        print("Dry run, will not save anything.", file=log.v1)
    if task == 'train':
        # Standard training over the configured epochs.
        assert train_data.have_seqs(
        ), "no train files specified, check train option: %s" % config.value(
            'train', None)
        engine.init_train_from_config(config, train_data, dev_data, eval_data)
        engine.train()
    elif task == "eval":
        # Evaluate one existing epoch. "epoch" and "load_epoch" are two ways
        # to select it; when both are given they must agree.
        epoch = config.int("epoch", -1)
        load_epoch = config.int("load_epoch", -1)
        if epoch >= 0:
            assert (load_epoch < 0) or (
                load_epoch == epoch), "epoch and load_epoch have to match"
            engine.epoch = epoch
            # Propagate the selection so init_train_from_config loads it.
            config.set('load_epoch', engine.epoch)
        else:
            assert load_epoch >= 0, "specify epoch or load_epoch"
            engine.epoch = load_epoch
        engine.init_train_from_config(config, train_data, dev_data, eval_data)
        print("Evaluate epoch", engine.epoch, file=log.v4)
        engine.eval_model(
            output_file=config.value("eval_output_file", None),
            output_per_seq_file=config.value("eval_output_file_per_seq", None),
            loss_name=config.value("loss_name", None),
            output_per_seq_format=config.list("output_per_seq_format",
                                              ["score"]),
            output_per_seq_file_format=config.value(
                "output_per_seq_file_format", "txt"))
    elif task in ['forward', 'hpx']:
        # Forward the eval dataset through the network and dump the output
        # to an HDF file.
        assert eval_data is not None, 'no eval data provided'
        combine_labels = config.value('combine_labels', '')
        engine.use_search_flag = config.bool("forward_use_search", False)
        if config.has("epoch"):
            config.set('load_epoch', config.int('epoch', 0))
        engine.init_network_from_config(config)
        output_file = config.value('output_file',
                                   'dump-fwd-epoch-%i.hdf' % engine.epoch)
        engine.forward_to_hdf(data=eval_data,
                              output_file=output_file,
                              combine_labels=combine_labels,
                              batch_size=config.int('forward_batch_size', 0))
    elif task == "search":
        # Beam search / decoding. "search_data" either names one of the
        # global datasets or describes a new dataset to initialize.
        engine.use_search_flag = True
        engine.use_eval_flag = config.bool("search_do_eval", True)
        engine.init_network_from_config(config)
        if config.value("search_data", "eval") in ["train", "dev", "eval"]:
            data = {
                "train": train_data,
                "dev": dev_data,
                "eval": eval_data
            }[config.value("search_data", "eval")]
            assert data, "set search_data"
        else:
            data = init_dataset(config.opt_typed_value("search_data"))
        engine.search(
            data,
            do_eval=config.bool("search_do_eval", True),
            output_layer_names=config.typed_value("search_output_layer",
                                                  "output"),
            output_file=config.value("search_output_file", ""),
            output_file_format=config.value("search_output_file_format",
                                            "txt"))
    elif task == 'compute_priors':
        # Estimate output priors from the training data.
        assert train_data is not None, 'train data for priors should be provided'
        engine.init_network_from_config(config)
        engine.compute_priors(dataset=train_data, config=config)
    elif task == 'theano_graph':
        # Dump the compiled Theano computation graph(s) to text/png files
        # for debugging. Only meaningful with the Theano backend.
        # noinspection PyPackageRequirements,PyUnresolvedReferences
        import theano.printing
        # noinspection PyPackageRequirements,PyUnresolvedReferences
        import theano.compile.io
        # noinspection PyPackageRequirements,PyUnresolvedReferences
        import theano.compile.function_module
        engine.start_epoch = 1
        engine.init_network_from_config(config)
        # NOTE(review): this loop variable shadows the outer ``task``.
        # Harmless here (no later branch of this if-chain can still run),
        # but fragile if code is ever added after the dispatch.
        for task in config.list('theano_graph.task', ['train']):
            func = engine.devices[-1].get_compute_func(task)
            # NOTE(review): prefix is the same fixed ".task" suffix for every
            # iteration, so multiple tasks overwrite each other's dumps —
            # confirm whether the task name was meant to be appended.
            prefix = config.value("theano_graph.prefix", "current") + ".task"
            print("dumping to %s.* ..." % prefix, file=log.v1)
            theano.printing.debugprint(func,
                                       file=open(
                                           "%s.optimized_func.txt" % prefix,
                                           "w"))
            assert isinstance(func.maker,
                              theano.compile.function_module.FunctionMaker)
            # Also dump each input's update expression (unoptimized form).
            for inp in func.maker.inputs:
                assert isinstance(inp, theano.compile.io.In)
                if inp.update:
                    theano.printing.debugprint(
                        inp.update,
                        file=open(
                            "%s.unoptimized.var_%s_update.txt" %
                            (prefix, inp.name), "w"))
            theano.printing.pydotprint(func,
                                       format='png',
                                       var_with_name_simple=True,
                                       outfile="%s.png" % prefix)
    elif task == 'analyze':  # anything based on the network + Device
        statistics = config.list('statistics', None)
        engine.init_network_from_config(config)
        engine.analyze(data=eval_data or dev_data, statistics=statistics)
    elif task == "analyze_data":  # anything just based on the data
        analyze_data(config)
    elif task == "classify":
        # Classify the eval data and write labels to the given file.
        assert eval_data is not None, 'no eval data provided'
        assert config.has('label_file'), 'no output file provided'
        label_file = config.value('label_file', '')
        engine.init_network_from_config(config)
        engine.classify(eval_data, label_file)
    elif task == "hyper_param_tuning":
        import returnn.tf.hyper_param_tuning
        tuner = returnn.tf.hyper_param_tuning.Optimization(
            config=config, train_data=train_data)
        tuner.work()
    elif task == "cleanup_old_models":
        # Interactive: asks before deleting old model checkpoints.
        engine.cleanup_old_models(ask_for_confirmation=True)
    elif task == "daemon":
        engine.init_network_from_config(config)
        engine.daemon(config)
    elif task == "server":
        # ``server`` is a module-level global — presumably set up during
        # init; verify against the caller.
        print("Server Initiating", file=log.v1)
        server.run()
    elif task == "search_server":
        engine.use_search_flag = True
        engine.init_network_from_config(config)
        engine.web_server(port=config.int("web_server_port", 12380))
    elif task.startswith("config:"):
        # "config:<name>" runs the callable stored under <name> in the config.
        action = config.typed_dict[task[len("config:"):]]
        print("Task: %r" % action, file=log.v1)
        assert callable(action)
        action()
    elif task.startswith("optional-config:"):
        # Like "config:<name>" but a missing entry is not an error.
        action = config.typed_dict.get(task[len("optional-config:"):], None)
        if action is None:
            print("No task found for %r, so just quitting." % task,
                  file=log.v1)
        else:
            print("Task: %r" % action, file=log.v1)
            assert callable(action)
            action()
    elif task == "nop":
        print("Task: No-operation", file=log.v1)
    elif task == "nop_init_net_train":
        print(
            "Task: No-operation, despite initializing the network (for training)",
            file=log.v1)
        engine.init_train_from_config(config, train_data, dev_data, eval_data)
    elif task == "initialize_model":
        # Initialize and immediately save the model without training.
        engine.init_train_from_config(config, train_data, dev_data, eval_data)
        engine.save_model(config.value('model', 'dummy'))
    else:
        assert False, "unknown task: %s" % task

    print(("elapsed: %s" % hms_fraction(time.time() - start_time)),
          file=log.v3)