def _prepare_forwarding():
  """
  Checks the config preconditions for forwarding and initializes the network.

  Requires the module-level ``engine`` and ``config`` to be set up already
  (the target mode is expected to have been set via setTargetMode()), and
  that the config requests exactly ``extract = posteriors``.
  """
  assert engine
  assert config  # Should already be set via setTargetMode().
  expected_extract = ["posteriors"]
  assert config.list('extract') == expected_extract, (
    "You need to have extract = posteriors in your RETURNN config. You have: %s" % config.list('extract'))
  # Build the network as described by the config.
  engine.init_network_from_config(config)
  # Theano backend keeps per-device copies of the params; push them now.
  if BackendEngine.is_theano_selected():
    engine.devices[0].prepare(engine.network)
def execute_main_task():
  """
  Executes the main task (via config ``task`` option).

  Reads the ``task`` value from the module-level ``config`` and dispatches to
  the corresponding engine operation (train, eval, forward, search, ...).
  Relies on the module-level globals ``engine``, ``config``, ``train_data``,
  ``dev_data``, ``eval_data``, ``log`` and (for task "server") ``server``
  having been initialized beforehand — presumably by the surrounding init
  code in this file; TODO confirm against the callers.
  """
  from returnn.util.basic import hms_fraction
  start_time = time.time()
  task = config.value('task', 'train')
  if config.is_true("dry_run"):
    print("Dry run, will not save anything.", file=log.v1)
  if task == 'train':
    # Training requires a non-empty train dataset.
    assert train_data.have_seqs(), "no train files specified, check train option: %s" % config.value('train', None)
    engine.init_train_from_config(config, train_data, dev_data, eval_data)
    engine.train()
  elif task == "eval":
    # Either "epoch" or "load_epoch" selects the model epoch to evaluate;
    # if both are given, they must agree.
    epoch = config.int("epoch", -1)
    load_epoch = config.int("load_epoch", -1)
    if epoch >= 0:
      assert (load_epoch < 0) or (load_epoch == epoch), "epoch and load_epoch have to match"
      engine.epoch = epoch
      config.set('load_epoch', engine.epoch)
    else:
      assert load_epoch >= 0, "specify epoch or load_epoch"
      engine.epoch = load_epoch
    engine.init_train_from_config(config, train_data, dev_data, eval_data)
    print("Evaluate epoch", engine.epoch, file=log.v4)
    engine.eval_model(
      output_file=config.value("eval_output_file", None),
      output_per_seq_file=config.value("eval_output_file_per_seq", None),
      loss_name=config.value("loss_name", None),
      output_per_seq_format=config.list("output_per_seq_format", ["score"]),
      output_per_seq_file_format=config.value("output_per_seq_file_format", "txt"))
  elif task in ['forward', 'hpx']:
    # Forward the eval data through the network and dump outputs to HDF.
    assert eval_data is not None, 'no eval data provided'
    combine_labels = config.value('combine_labels', '')
    # Flags must be set before init_network_from_config, as they influence
    # network construction.
    engine.use_search_flag = config.bool("forward_use_search", False)
    if config.has("epoch"):
      config.set('load_epoch', config.int('epoch', 0))
    engine.init_network_from_config(config)
    output_file = config.value('output_file', 'dump-fwd-epoch-%i.hdf' % engine.epoch)
    engine.forward_to_hdf(
      data=eval_data, output_file=output_file, combine_labels=combine_labels,
      batch_size=config.int('forward_batch_size', 0))
  elif task == "search":
    # Beam search / decoding. Search flag must be set before network init.
    engine.use_search_flag = True
    engine.use_eval_flag = config.bool("search_do_eval", True)
    engine.init_network_from_config(config)
    if config.value("search_data", "eval") in ["train", "dev", "eval"]:
      # "search_data" names one of the pre-initialized datasets.
      data = {"train": train_data, "dev": dev_data, "eval": eval_data}[config.value("search_data", "eval")]
      assert data, "set search_data"
    else:
      # Otherwise "search_data" is a dataset description to load ad hoc.
      data = init_dataset(config.opt_typed_value("search_data"))
    engine.search(
      data,
      do_eval=config.bool("search_do_eval", True),
      output_layer_names=config.typed_value("search_output_layer", "output"),
      output_file=config.value("search_output_file", ""),
      output_file_format=config.value("search_output_file_format", "txt"))
  elif task == 'compute_priors':
    assert train_data is not None, 'train data for priors should be provided'
    engine.init_network_from_config(config)
    engine.compute_priors(dataset=train_data, config=config)
  elif task == 'theano_graph':
    # Dump the compiled Theano computation graph(s) to text/png files.
    # noinspection PyPackageRequirements,PyUnresolvedReferences
    import theano.printing
    # noinspection PyPackageRequirements,PyUnresolvedReferences
    import theano.compile.io
    # noinspection PyPackageRequirements,PyUnresolvedReferences
    import theano.compile.function_module
    engine.start_epoch = 1
    engine.init_network_from_config(config)
    # NOTE(review): this loop variable shadows the outer `task`; harmless here
    # because no later branch of the elif chain can run, but worth renaming.
    for task in config.list('theano_graph.task', ['train']):
      func = engine.devices[-1].get_compute_func(task)
      prefix = config.value("theano_graph.prefix", "current") + ".task"
      print("dumping to %s.* ..." % prefix, file=log.v1)
      theano.printing.debugprint(func, file=open("%s.optimized_func.txt" % prefix, "w"))
      assert isinstance(func.maker, theano.compile.function_module.FunctionMaker)
      for inp in func.maker.inputs:
        assert isinstance(inp, theano.compile.io.In)
        if inp.update:
          # Also dump the (unoptimized) update expression of each shared var.
          theano.printing.debugprint(
            inp.update, file=open("%s.unoptimized.var_%s_update.txt" % (prefix, inp.name), "w"))
      theano.printing.pydotprint(func, format='png', var_with_name_simple=True, outfile="%s.png" % prefix)
  elif task == 'analyze':  # anything based on the network + Device
    statistics = config.list('statistics', None)
    engine.init_network_from_config(config)
    engine.analyze(data=eval_data or dev_data, statistics=statistics)
  elif task == "analyze_data":  # anything just based on the data
    analyze_data(config)
  elif task == "classify":
    assert eval_data is not None, 'no eval data provided'
    assert config.has('label_file'), 'no output file provided'
    label_file = config.value('label_file', '')
    engine.init_network_from_config(config)
    engine.classify(eval_data, label_file)
  elif task == "hyper_param_tuning":
    import returnn.tf.hyper_param_tuning
    tuner = returnn.tf.hyper_param_tuning.Optimization(config=config, train_data=train_data)
    tuner.work()
  elif task == "cleanup_old_models":
    engine.cleanup_old_models(ask_for_confirmation=True)
  elif task == "daemon":
    engine.init_network_from_config(config)
    engine.daemon(config)
  elif task == "server":
    print("Server Initiating", file=log.v1)
    server.run()
  elif task == "search_server":
    engine.use_search_flag = True
    engine.init_network_from_config(config)
    engine.web_server(port=config.int("web_server_port", 12380))
  elif task.startswith("config:"):
    # "config:<name>" runs the callable stored under <name> in the config.
    action = config.typed_dict[task[len("config:"):]]
    print("Task: %r" % action, file=log.v1)
    assert callable(action)
    action()
  elif task.startswith("optional-config:"):
    # Like "config:<name>" but silently a no-op if <name> is not defined.
    action = config.typed_dict.get(task[len("optional-config:"):], None)
    if action is None:
      print("No task found for %r, so just quitting." % task, file=log.v1)
    else:
      print("Task: %r" % action, file=log.v1)
      assert callable(action)
      action()
  elif task == "nop":
    print("Task: No-operation", file=log.v1)
  elif task == "nop_init_net_train":
    print("Task: No-operation, despite initializing the network (for training)", file=log.v1)
    engine.init_train_from_config(config, train_data, dev_data, eval_data)
  elif task == "initialize_model":
    # Build the model and save it without any training.
    engine.init_train_from_config(config, train_data, dev_data, eval_data)
    engine.save_model(config.value('model', 'dummy'))
  else:
    assert False, "unknown task: %s" % task
  print(("elapsed: %s" % hms_fraction(time.time() - start_time)), file=log.v3)