def main(argv):
  """
  Main entry.
  """
  arg_parser = argparse.ArgumentParser(description='Dump raw strings from dataset. Same format as in search.')
  arg_parser.add_argument('--config', help="filename to config-file. will use dataset 'eval' from it")
  arg_parser.add_argument("--dataset", help="dataset, overwriting config")
  arg_parser.add_argument('--startseq', type=int, default=0, help='start seq idx (inclusive) (default: 0)')
  arg_parser.add_argument('--endseq', type=int, default=-1, help='end seq idx (inclusive) or -1 (default: -1)')
  arg_parser.add_argument("--key", default="raw", help="data-key, e.g. 'data' or 'classes'. (default: 'raw')")
  arg_parser.add_argument("--verbosity", default=4, type=int, help="5 for all seqs (default: 4)")
  arg_parser.add_argument("--out", required=True, help="out-file. py-format as in task=search")
  args = arg_parser.parse_args(argv[1:])
  assert args.config or args.dataset

  init(config_filename=args.config, log_verbosity=args.verbosity)
  if args.dataset:
    dataset = init_dataset(args.dataset)
  elif config.value("dump_data", "eval") in ["train", "dev", "eval"]:
    dataset = init_dataset(config.opt_typed_value(config.value("dump_data", "eval")))
  else:
    dataset = init_dataset(config.opt_typed_value("dump_data"))
  dataset.init_seq_order(epoch=1)

  try:
    with generic_open(args.out, "w") as output_file:
      refs = get_raw_strings(dataset=dataset, options=args)
      output_file.write("{\n")
      for seq_tag, ref in refs:
        output_file.write("%r: %r,\n" % (seq_tag, ref))
      output_file.write("}\n")
    print("Done. Wrote to %r." % args.out)
  except KeyboardInterrupt:
    print("KeyboardInterrupt")
    sys.exit(1)
  finally:
    rnn.finalize()
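# The --out file written by main() above is a Python dict literal mapping seq tags
# to raw strings, the same format as RETURNN's task="search" output.
# A sketch of the result (the tags and strings are made-up examples):
#
#   {
#   'seq-0000': 'the first raw string',
#   'seq-0001': 'the second raw string',
#   }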
def load_data(config, cache_byte_size, files_config_key, **kwargs):
  """
  :param Config config:
  :param int cache_byte_size:
  :param str files_config_key: such as "train" or "dev"
  :param kwargs: passed on to init_dataset() or init_dataset_via_str()
  :rtype: (Dataset,int)
  :returns the dataset, and the cache byte size left over if we cache the whole dataset.
  """
  if not config.bool_or_other(files_config_key, None):
    return None, 0
  kwargs = kwargs.copy()
  kwargs.setdefault("name", files_config_key)
  if config.is_typed(files_config_key) and isinstance(config.typed_value(files_config_key), dict):
    config_opts = config.typed_value(files_config_key)
    assert isinstance(config_opts, dict)
    kwargs.update(config_opts)
    if 'cache_byte_size' not in config_opts:
      if kwargs.get('class', None) == 'HDFDataset':
        kwargs["cache_byte_size"] = cache_byte_size
    Dataset.kwargs_update_from_config(config, kwargs)
    data = init_dataset(kwargs)
  else:
    config_str = config.value(files_config_key, "")
    data = init_dataset_via_str(config_str, config=config, cache_byte_size=cache_byte_size, **kwargs)
  cache_leftover = 0
  if isinstance(data, HDFDataset):
    cache_leftover = data.definite_cache_leftover
  return data, cache_leftover
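def _example_load_train_and_dev(config):
  """
  Usage sketch for load_data() above (added for illustration, not part of the original):
  the leftover cache budget from caching the train set is passed on when loading
  the dev set. The 1 GB budget is an arbitrary example value.

  :param Config config: assumed to contain "train" and "dev" entries
  :rtype: (Dataset|None,Dataset|None)
  """
  cache_byte_size = 1024 ** 3  # 1 GB, illustrative only
  train_data, cache_leftover = load_data(config, cache_byte_size, "train")
  dev_data, _ = load_data(config, cache_leftover, "dev")
  return train_data, dev_data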
def init(config_filename, cmd_line_opts, dataset_config_str):
  """
  :param str config_filename: global config for CRNN
  :param list[str] cmd_line_opts: options for init_config method
  :param str dataset_config_str: dataset via init_dataset_via_str()
  """
  rnn.init_better_exchook()
  rnn.init_thread_join_hack()
  if config_filename:
    rnn.init_config(config_filename, cmd_line_opts)
    rnn.init_log()
  else:
    log.initialize(verbosity=[5])
  print("Returnn hdf_dump starting up.", file=log.v3)
  rnn.init_faulthandler()
  if config_filename:
    rnn.init_data()
    rnn.print_task_properties()
    assert isinstance(rnn.train_data, Dataset)
    dataset = rnn.train_data
  else:
    assert dataset_config_str
    dataset = init_dataset(dataset_config_str)
  print("Source dataset:", dataset.len_info(), file=log.v3)
  return dataset
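def _example_hdf_dump_init():
  """
  Usage sketch for init() above (illustration only, not part of the original tool):
  either a config file is given and the dataset comes from rnn.train_data, or a
  dataset string for init_dataset_via_str() is given directly. The file names
  are placeholders, and it is an assumption here that init_dataset_via_str()
  accepts a plain HDF filename.
  """
  # Mode 1: via config file; the dataset is the train dataset from the config.
  dataset = init(config_filename="my-setup.config", cmd_line_opts=[], dataset_config_str=None)
  # Mode 2: no config; the dataset is given directly as a string spec.
  dataset = init(config_filename=None, cmd_line_opts=[], dataset_config_str="features.hdf")
  return dataset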
def init(config_str, config_dataset, verbosity):
  """
  :param str config_str: either filename to config-file, or dict for dataset
  :param str|None config_dataset:
  :param int verbosity:
  """
  global dataset
  rnn.init_better_exchook()
  rnn.init_thread_join_hack()
  dataset_dict = None
  config_filename = None
  if config_str.strip().startswith("{"):
    print("Using dataset %s." % config_str)
    dataset_dict = eval(config_str.strip())
  elif config_str.endswith(".hdf"):
    dataset_dict = {"class": "HDFDataset", "files": [config_str]}
    print("Using dataset %r." % dataset_dict)
    assert os.path.exists(config_str)
  else:
    config_filename = config_str
    print("Using config file %r." % config_filename)
    assert os.path.exists(config_filename)
  rnn.init_config(config_filename=config_filename, default_config={"cache_size": "0"})
  global config
  config = rnn.config
  config.set("log", None)
  config.set("log_verbosity", verbosity)
  if dataset_dict:
    assert not config_dataset
    dataset = init_dataset(dataset_dict)
  elif config_dataset and config_dataset != "train":
    print("Use dataset %r from config." % config_dataset)
    dataset = init_dataset("config:%s" % config_dataset)
  else:
    print("Use train dataset from config.")
    assert config.value("train", None)
    dataset = init_dataset("config:train")
  rnn.init_log()
  print("Returnn dump-dataset starting up.", file=log.v2)
  rnn.returnn_greeting()
  rnn.init_faulthandler()
  rnn.init_config_json_network()
  print("Dataset:", file=log.v2)
  print(" input:", dataset.num_inputs, "x", dataset.window, file=log.v2)
  print(" output:", dataset.num_outputs, file=log.v2)
  print(" ", dataset.len_info() or "no info", file=log.v2)
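# Usage sketch (illustration only): the three accepted forms of config_str for
# init() above. The file names are placeholders; the inline dataset dict uses
# the Task12AXDataset demo dataset as an example.
#
#   init(config_str='{"class": "Task12AXDataset", "num_seqs": 10}', config_dataset=None, verbosity=4)
#   init(config_str="features.hdf", config_dataset=None, verbosity=4)
#   init(config_str="my-setup.config", config_dataset="dev", verbosity=4)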
def init(config_filename, command_line_options, args):
  """
  :param str config_filename:
  :param list[str] command_line_options:
  :param args: argparse.Namespace
  """
  global config, engine, dataset
  rnn.init(
    config_filename=config_filename, command_line_options=command_line_options,
    config_updates={"log": None, "need_data": False},
    extra_greeting="RETURNN dump-forward starting up.")
  config = rnn.config
  engine = rnn.engine

  dataset_str = args.dataset
  if dataset_str in {"train", "dev", "eval", "search_data"}:
    dataset_str = "config:%s" % dataset_str
  extra_dataset_kwargs = {}
  if args.reset_partition_epoch:
    print("NOTE: We are resetting partition epoch to %i." % (args.reset_partition_epoch,))
    extra_dataset_kwargs["partition_epoch"] = args.reset_partition_epoch
  if args.reset_seq_ordering:
    print("NOTE: We will use %r seq ordering." % (args.reset_seq_ordering,))
    extra_dataset_kwargs["seq_ordering"] = args.reset_seq_ordering
  if args.reset_epoch_wise_filter:
    extra_dataset_kwargs["epoch_wise_filter"] = eval(args.reset_epoch_wise_filter)
  dataset = init_dataset(dataset_str, extra_kwargs=extra_dataset_kwargs)
  if hasattr(dataset, "epoch_wise_filter") and args.reset_epoch_wise_filter is None:
    if dataset.epoch_wise_filter:
      print("NOTE: Resetting epoch_wise_filter to None.")
      dataset.epoch_wise_filter = None
  if args.reset_partition_epoch:
    assert dataset.partition_epoch == args.reset_partition_epoch
  if args.reset_seq_ordering:
    assert dataset.seq_ordering == args.reset_seq_ordering

  config.set("task", "eval")
  if args.load:
    config.set("load", args.load)
  epoch, model_epoch_filename = Engine.get_epoch_model(config)
  engine.pretrain = pretrain_from_config(config)
  engine.custom_get_net_dict = config.typed_value("get_network")
  net_dict = engine.get_net_dict_for_epoch(epoch)
  engine.make_tf_session()
  engine.network = TFNetwork(name="root")
  engine.network.construct_layer(net_dict, args.layer)
  print("Load model:", model_epoch_filename)
  engine.network.load_params_from_file(model_epoch_filename, session=engine.tf_session)
def benchmark(lstm_unit, use_gpu):
  """
  :param str lstm_unit: e.g. "LSTMBlock", one of LstmCellTypes
  :param bool use_gpu:
  :return: runtime in seconds of the training itself, excluding initialization
  :rtype: float
  """
  device = {True: "GPU", False: "CPU"}[use_gpu]
  key = "%s:%s" % (device, lstm_unit)
  print(">>> Start benchmark for %s." % key)
  config = Config()
  config.update(make_config_dict(lstm_unit=lstm_unit, use_gpu=use_gpu))
  dataset_kwargs = config.typed_value("train")
  Dataset.kwargs_update_from_config(config, dataset_kwargs)
  dataset = init_dataset(dataset_kwargs)
  engine = Engine(config=config)
  engine.init_train_from_config(config=config, train_data=dataset)
  print(">>> Start training now for %s." % key)
  start_time = time.time()
  engine.train()
  runtime = time.time() - start_time
  print(">>> Runtime of %s: %s" % (key, hms_fraction(runtime)))
  engine.finalize()
  return runtime
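def _example_benchmark_all(use_gpu=False):
  """
  Usage sketch for benchmark() above (illustration only, not part of the original
  script): run the benchmark over several LSTM unit implementations and report
  the runtimes sorted from fastest to slowest. The unit names here are
  assumptions; valid names are the LstmCellTypes mentioned in the benchmark()
  docstring.
  """
  runtimes = {}
  for lstm_unit in ["BasicLSTM", "LSTMBlock", "NativeLstm2"]:
    runtimes[lstm_unit] = benchmark(lstm_unit, use_gpu=use_gpu)
  for lstm_unit, runtime in sorted(runtimes.items(), key=lambda item: item[1]):
    print("%s: %s" % (lstm_unit, hms_fraction(runtime)))
  return runtimes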
def execute_main_task():
  """
  Executes the main task (via config ``task`` option).
  """
  from returnn.util.basic import hms_fraction
  start_time = time.time()
  task = config.value('task', 'train')
  if config.is_true("dry_run"):
    print("Dry run, will not save anything.", file=log.v1)
  if task == 'train':
    assert train_data.have_seqs(), "no train files specified, check train option: %s" % config.value('train', None)
    engine.init_train_from_config(config, train_data, dev_data, eval_data)
    engine.train()
  elif task == "eval":
    epoch = config.int("epoch", -1)
    load_epoch = config.int("load_epoch", -1)
    if epoch >= 0:
      assert (load_epoch < 0) or (load_epoch == epoch), "epoch and load_epoch have to match"
      engine.epoch = epoch
      config.set('load_epoch', engine.epoch)
    else:
      assert load_epoch >= 0, "specify epoch or load_epoch"
      engine.epoch = load_epoch
    engine.init_train_from_config(config, train_data, dev_data, eval_data)
    print("Evaluate epoch", engine.epoch, file=log.v4)
    engine.eval_model(
      output_file=config.value("eval_output_file", None),
      output_per_seq_file=config.value("eval_output_file_per_seq", None),
      loss_name=config.value("loss_name", None),
      output_per_seq_format=config.list("output_per_seq_format", ["score"]),
      output_per_seq_file_format=config.value("output_per_seq_file_format", "txt"))
  elif task in ['forward', 'hpx']:
    assert eval_data is not None, 'no eval data provided'
    combine_labels = config.value('combine_labels', '')
    engine.use_search_flag = config.bool("forward_use_search", False)
    if config.has("epoch"):
      config.set('load_epoch', config.int('epoch', 0))
    engine.init_network_from_config(config)
    output_file = config.value('output_file', 'dump-fwd-epoch-%i.hdf' % engine.epoch)
    engine.forward_to_hdf(
      data=eval_data, output_file=output_file, combine_labels=combine_labels,
      batch_size=config.int('forward_batch_size', 0))
  elif task == "search":
    engine.use_search_flag = True
    engine.use_eval_flag = config.bool("search_do_eval", True)
    engine.init_network_from_config(config)
    if config.value("search_data", "eval") in ["train", "dev", "eval"]:
      data = {"train": train_data, "dev": dev_data, "eval": eval_data}[config.value("search_data", "eval")]
      assert data, "set search_data"
    else:
      data = init_dataset(config.opt_typed_value("search_data"))
    engine.search(
      data,
      do_eval=config.bool("search_do_eval", True),
      output_layer_names=config.typed_value("search_output_layer", "output"),
      output_file=config.value("search_output_file", ""),
      output_file_format=config.value("search_output_file_format", "txt"))
  elif task == 'compute_priors':
    assert train_data is not None, 'train data for priors should be provided'
    engine.init_network_from_config(config)
    engine.compute_priors(dataset=train_data, config=config)
  elif task == 'theano_graph':
    # noinspection PyPackageRequirements,PyUnresolvedReferences
    import theano.printing
    # noinspection PyPackageRequirements,PyUnresolvedReferences
    import theano.compile.io
    # noinspection PyPackageRequirements,PyUnresolvedReferences
    import theano.compile.function_module
    engine.start_epoch = 1
    engine.init_network_from_config(config)
    for task in config.list('theano_graph.task', ['train']):
      func = engine.devices[-1].get_compute_func(task)
      prefix = config.value("theano_graph.prefix", "current") + ".task"
      print("dumping to %s.* ..." % prefix, file=log.v1)
      theano.printing.debugprint(func, file=open("%s.optimized_func.txt" % prefix, "w"))
      assert isinstance(func.maker, theano.compile.function_module.FunctionMaker)
      for inp in func.maker.inputs:
        assert isinstance(inp, theano.compile.io.In)
        if inp.update:
          theano.printing.debugprint(
            inp.update, file=open("%s.unoptimized.var_%s_update.txt" % (prefix, inp.name), "w"))
      theano.printing.pydotprint(func, format='png', var_with_name_simple=True, outfile="%s.png" % prefix)
  elif task == 'analyze':  # anything based on the network + Device
    statistics = config.list('statistics', None)
    engine.init_network_from_config(config)
    engine.analyze(data=eval_data or dev_data, statistics=statistics)
  elif task == "analyze_data":  # anything just based on the data
    analyze_data(config)
  elif task == "classify":
    assert eval_data is not None, 'no eval data provided'
    assert config.has('label_file'), 'no output file provided'
    label_file = config.value('label_file', '')
    engine.init_network_from_config(config)
    engine.classify(eval_data, label_file)
  elif task == "hyper_param_tuning":
    import returnn.tf.hyper_param_tuning
    tuner = returnn.tf.hyper_param_tuning.Optimization(config=config, train_data=train_data)
    tuner.work()
  elif task == "cleanup_old_models":
    engine.cleanup_old_models(ask_for_confirmation=True)
  elif task == "daemon":
    engine.init_network_from_config(config)
    engine.daemon(config)
  elif task == "server":
    print("Server Initiating", file=log.v1)
    server.run()
  elif task == "search_server":
    engine.use_search_flag = True
    engine.init_network_from_config(config)
    engine.web_server(port=config.int("web_server_port", 12380))
  elif task.startswith("config:"):
    action = config.typed_dict[task[len("config:"):]]
    print("Task: %r" % action, file=log.v1)
    assert callable(action)
    action()
  elif task.startswith("optional-config:"):
    action = config.typed_dict.get(task[len("optional-config:"):], None)
    if action is None:
      print("No task found for %r, so just quitting." % task, file=log.v1)
    else:
      print("Task: %r" % action, file=log.v1)
      assert callable(action)
      action()
  elif task == "nop":
    print("Task: No-operation", file=log.v1)
  elif task == "nop_init_net_train":
    print("Task: No-operation, despite initializing the network (for training)", file=log.v1)
    engine.init_train_from_config(config, train_data, dev_data, eval_data)
  elif task == "initialize_model":
    engine.init_train_from_config(config, train_data, dev_data, eval_data)
    engine.save_model(config.value('model', 'dummy'))
  else:
    assert False, "unknown task: %s" % task
  print(("elapsed: %s" % hms_fraction(time.time() - start_time)), file=log.v3)
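# Usage sketch (illustration only): the "config:" dispatch in execute_main_task()
# above lets a Python-format RETURNN config define a custom task as a callable.
# "my_action" is a made-up name; config/engine are already set up by the time
# the action runs. E.g. in the config file:
#
#   task = "config:my_action"
#
#   def my_action():
#     print("custom action; config and engine are initialized at this point")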
def main(argv):
  """
  Main entry.
  """
  argparser = argparse.ArgumentParser(description=__doc__)
  argparser.add_argument("config_file", type=str, help="RETURNN config, or model-dir")
  argparser.add_argument("--epoch", type=int)
  argparser.add_argument(
    '--data', default="train",
    help="e.g. 'train', 'config:train', or sth like 'config:get_dataset('dev')'")
  argparser.add_argument('--do_search', default=False, action='store_true')
  argparser.add_argument('--beam_size', default=12, type=int)
  argparser.add_argument('--dump_dir', help="for npy or png")
  argparser.add_argument("--output_file", help="hdf")
  argparser.add_argument("--device", help="gpu or cpu (default: automatic)")
  argparser.add_argument("--layers", default=["att_weights"], action="append", help="Layer of subnet to grab")
  argparser.add_argument("--rec_layer", default="output", help="Subnet layer to grab from; decoder")
  argparser.add_argument("--enc_layer", default="encoder")
  argparser.add_argument("--batch_size", type=int, default=5000)
  argparser.add_argument("--seq_list", default=[], action="append", help="predefined list of seqs")
  argparser.add_argument("--min_seq_len", default="0", help="can also be dict")
  argparser.add_argument("--num_seqs", default=-1, type=int, help="stop after this many seqs")
  argparser.add_argument("--output_format", default="npy", help="npy, png or hdf")
  argparser.add_argument("--dropout", default=None, type=float, help="if set, overwrites all dropout values")
  argparser.add_argument("--train_flag", action="store_true")
  argparser.add_argument("--reset_partition_epoch", type=int, default=1)
  argparser.add_argument("--reset_seq_ordering", default="sorted_reverse")
  argparser.add_argument("--reset_epoch_wise_filter", default=None)
  args = argparser.parse_args(argv[1:])

  layers = args.layers
  assert isinstance(layers, list)
  config_fn = args.config_file
  explicit_model_dir = None
  if os.path.isdir(config_fn):
    # Assume we gave a model dir.
    explicit_model_dir = config_fn
    train_log_dir_config_pattern = "%s/train-*/*.config" % config_fn
    train_log_dir_configs = sorted(glob(train_log_dir_config_pattern))
    assert train_log_dir_configs
    config_fn = train_log_dir_configs[-1]
    print("Using this config via model dir:", config_fn)
  else:
    assert os.path.isfile(config_fn)
  model_name = ".".join(config_fn.split("/")[-1].split(".")[:-1])

  init_returnn(config_fn=config_fn, args=args)
  if explicit_model_dir:
    config.set("model", "%s/%s" % (explicit_model_dir, os.path.basename(config.value('model', ''))))
  print("Model file prefix:", config.value('model', ''))

  if args.do_search:
    raise NotImplementedError
  min_seq_length = NumbersDict(eval(args.min_seq_len))

  assert args.output_format in ["npy", "png", "hdf"]
  if args.output_format in ["npy", "png"]:
    assert args.dump_dir
    if not os.path.exists(args.dump_dir):
      os.makedirs(args.dump_dir)
  plt = ticker = None
  if args.output_format == "png":
    import matplotlib.pyplot as plt  # need to import early? https://stackoverflow.com/a/45582103/133374
    import matplotlib.ticker as ticker

  dataset_str = args.data
  if dataset_str in ["train", "dev", "eval"]:
    dataset_str = "config:%s" % dataset_str
  extra_dataset_kwargs = {}
  if args.reset_partition_epoch:
    print("NOTE: We are resetting partition epoch to %i." % (args.reset_partition_epoch,))
    extra_dataset_kwargs["partition_epoch"] = args.reset_partition_epoch
  if args.reset_seq_ordering:
    print("NOTE: We will use %r seq ordering." % (args.reset_seq_ordering,))
    extra_dataset_kwargs["seq_ordering"] = args.reset_seq_ordering
  if args.reset_epoch_wise_filter:
    extra_dataset_kwargs["epoch_wise_filter"] = eval(args.reset_epoch_wise_filter)
  dataset = init_dataset(dataset_str, extra_kwargs=extra_dataset_kwargs)
  if hasattr(dataset, "epoch_wise_filter") and args.reset_epoch_wise_filter is None:
    if dataset.epoch_wise_filter:
      print("NOTE: Resetting epoch_wise_filter to None.")
      dataset.epoch_wise_filter = None
  if args.reset_partition_epoch:
    assert dataset.partition_epoch == args.reset_partition_epoch
  if args.reset_seq_ordering:
    assert dataset.seq_ordering == args.reset_seq_ordering

  init_net(args, layers)
  network = rnn.engine.network

  hdf_writer = None
  if args.output_format == "hdf":
    assert args.output_file
    assert len(layers) == 1
    sub_layer = network.get_layer("%s/%s" % (args.rec_layer, layers[0]))
    from returnn.datasets.hdf import SimpleHDFWriter
    hdf_writer = SimpleHDFWriter(filename=args.output_file, dim=sub_layer.output.dim, ndim=sub_layer.output.ndim)

  extra_fetches = {
    "output": network.layers[args.rec_layer].output.get_placeholder_as_batch_major(),
    "output_len": network.layers[args.rec_layer].output.get_sequence_lengths(),  # decoder length
    "encoder_len": network.layers[args.enc_layer].output.get_sequence_lengths(),  # encoder length
    "seq_idx": network.get_extern_data("seq_idx"),
    "seq_tag": network.get_extern_data("seq_tag"),
    "target_data": network.get_extern_data(network.extern_data.default_input),
    "target_classes": network.get_extern_data(network.extern_data.default_target),
  }
  for layer in layers:
    sub_layer = rnn.engine.network.get_layer("%s/%s" % (args.rec_layer, layer))
    extra_fetches["rec_%s" % layer] = sub_layer.output.get_placeholder_as_batch_major()
  dataset.init_seq_order(epoch=1, seq_list=args.seq_list or None)  # use always epoch 1, such that we have same seqs
  dataset_batch = dataset.generate_batches(
    recurrent_net=network.recurrent,
    batch_size=args.batch_size,
    max_seqs=rnn.engine.max_seqs,
    max_seq_length=sys.maxsize,
    min_seq_length=min_seq_length,
    max_total_num_seqs=args.num_seqs,
    used_data_keys=network.used_data_keys)

  stats = {layer: Stats() for layer in layers}

  # (**dict[str,numpy.ndarray|str|list[numpy.ndarray|str])->None
  def fetch_callback(seq_idx, seq_tag, target_data, target_classes, output, output_len, encoder_len, **kwargs):
    """
    :param list[int] seq_idx: len is n_batch
    :param list[str] seq_tag: len is n_batch
    :param numpy.ndarray target_data: extern data default input (e.g. "data"), shape e.g. (B,enc-T,...)
    :param numpy.ndarray target_classes: extern data default target (e.g. "classes"), shape e.g. (B,dec-T,...)
    :param numpy.ndarray output: rec layer output, shape e.g. (B,dec-T,...)
    :param numpy.ndarray output_len: rec layer seq len, i.e. decoder length, shape (B,)
    :param numpy.ndarray encoder_len: encoder seq len, shape (B,)
    :param kwargs: contains "rec_%s" % l for l in layers, the sub layers (e.g. att weights) we are interested in
    """
    n_batch = len(seq_idx)
    for i in range(n_batch):
      # noinspection PyShadowingNames
      for layer in layers:
        att_weights = kwargs["rec_%s" % layer][i]
        stats[layer].collect(att_weights.flatten())
    if args.output_format == "npy":
      data = {}
      for i in range(n_batch):
        data[i] = {
          'tag': seq_tag[i],
          'data': target_data[i],
          'classes': target_classes[i],
          'output': output[i],
          'output_len': output_len[i],
          'encoder_len': encoder_len[i],
        }
        # noinspection PyShadowingNames
        for layer in [("rec_%s" % layer) for layer in layers]:
          assert layer in kwargs
          out = kwargs[layer][i]
          assert out.ndim >= 2
          assert out.shape[0] >= output_len[i] and out.shape[1] >= encoder_len[i]
          data[i][layer] = out[:output_len[i], :encoder_len[i]]
      fname = args.dump_dir + '/%s_ep%03d_data_%i_%i.npy' % (
        model_name, rnn.engine.epoch, seq_idx[0], seq_idx[-1])
      np.save(fname, data)
    elif args.output_format == "png":
      for i in range(n_batch):
        # noinspection PyShadowingNames
        for layer in layers:
          extra_postfix = ""
          if args.dropout is not None:
            extra_postfix += "_dropout%.2f" % args.dropout
          elif args.train_flag:
            extra_postfix += "_train"
          fname = args.dump_dir + '/%s_ep%03d_plt_%05i_%s%s.png' % (
            model_name, rnn.engine.epoch, seq_idx[i], layer, extra_postfix)
          att_weights = kwargs["rec_%s" % layer][i]
          att_weights = att_weights.squeeze(axis=2)  # (out,enc)
          assert att_weights.shape[0] >= output_len[i] and att_weights.shape[1] >= encoder_len[i]
          att_weights = att_weights[:output_len[i], :encoder_len[i]]
          print("Seq %i, %s: Dump att weights with shape %r to: %s" % (
            seq_idx[i], seq_tag[i], att_weights.shape, fname))
          plt.matshow(att_weights)
          title = seq_tag[i]
          if dataset.can_serialize_data(network.extern_data.default_target):
            title += "\n" + dataset.serialize_data(
              network.extern_data.default_target, target_classes[i][:output_len[i]])
            ax = plt.gca()
            tick_labels = [
              dataset.serialize_data(network.extern_data.default_target, np.array([x], dtype=target_classes[i].dtype))
              for x in target_classes[i][:output_len[i]]]
            ax.set_yticklabels([''] + tick_labels, fontsize=8)
            ax.yaxis.set_major_locator(ticker.MultipleLocator(1))
          plt.title(title)
          plt.savefig(fname)
          plt.close()
    elif args.output_format == "hdf":
      assert len(layers) == 1
      att_weights = kwargs["rec_%s" % layers[0]]
      hdf_writer.insert_batch(inputs=att_weights, seq_len={0: output_len, 1: encoder_len}, seq_tag=seq_tag)
    else:
      raise Exception("output format %r" % args.output_format)

  runner = Runner(
    engine=rnn.engine, dataset=dataset, batches=dataset_batch,
    train=False, train_flag=bool(args.dropout) or args.train_flag,
    extra_fetches=extra_fetches,
    extra_fetches_callback=fetch_callback)
  runner.run(report_prefix="att-weights epoch %i" % rnn.engine.epoch)
  for layer in layers:
    stats[layer].dump(stream_prefix="Layer %r " % layer)
  if not runner.finalized:
    print("Some error occurred, not finalized.")
    sys.exit(1)

  if hdf_writer:
    hdf_writer.close()
  rnn.finalize()
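# Example invocation (sketch; the script name and paths are placeholders, the
# flags come from the argparser in main() above):
#
#   python3 get-attention-weights.py my-setup.config --epoch 80 \
#       --data config:dev --output_format png --dump_dir /tmp/att-plots \
#       --rec_layer output --enc_layer encoder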
def main(argv):
  """
  Main entry.
  """
  arg_parser = argparse.ArgumentParser(description='Dump search scores and other info to HDF file.')
  arg_parser.add_argument('config', help="filename to config-file")
  arg_parser.add_argument("--dataset", default="config:train")
  arg_parser.add_argument("--epoch", type=int, default=-1, help="-1 for last epoch")
  arg_parser.add_argument("--output_file", help='hdf', required=True)
  arg_parser.add_argument("--rec_layer_name", default="output")
  arg_parser.add_argument("--cheating", action="store_true", help="add ground truth to the beam")
  arg_parser.add_argument("--att_weights", action="store_true", help="dump all softmax_over_spatial layers")
  arg_parser.add_argument("--verbosity", default=4, type=int, help="5 for all seqs (default: 4)")
  arg_parser.add_argument("--seq_list", nargs="+", help="use only these seqs")
  args, remaining_args = arg_parser.parse_known_args(argv[1:])
  init(config_filename=args.config, log_verbosity=args.verbosity, remaining_args=remaining_args)
  dataset = init_dataset(args.dataset)
  print("Dataset:")
  pprint(dataset)
  if args.seq_list:
    dataset.seq_tags_filter = set(args.seq_list)
    dataset.partition_epoch = 1  # reset
    if isinstance(dataset, MetaDataset):
      for sub_dataset in dataset.datasets.values():
        sub_dataset.seq_tags_filter = set(args.seq_list)
        sub_dataset.partition_epoch = 1
    dataset.finish_epoch()  # enforce reset
  if dataset.seq_tags_filter is not None:
    print("Using sequences:")
    pprint(dataset.seq_tags_filter)
  if args.epoch >= 1:
    config.set("load_epoch", args.epoch)

  def net_dict_post_proc(net_dict):
    """
    :param dict[str] net_dict:
    :return: net_dict
    :rtype: dict[str]
    """
    prepare_compile(
      rec_layer_name=args.rec_layer_name, net_dict=net_dict,
      cheating=args.cheating, dump_att_weights=args.att_weights,
      hdf_filename=args.output_file, possible_labels=dataset.labels)
    return net_dict

  engine = Engine(config=config)
  engine.use_search_flag = True
  engine.init_network_from_config(config, net_dict_post_proc=net_dict_post_proc)
  engine.search(
    dataset, do_eval=config.bool("search_do_eval", True),
    output_layer_names=args.rec_layer_name)
  engine.finalize()
  print("Search finished.")
  assert os.path.exists(args.output_file), "hdf file not dumped?"
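# Example invocation (sketch; the script name, config filename and paths are
# placeholders, the flags come from the argparser in main() above):
#
#   python3 dump-search-scores.py my-setup.config --dataset config:dev \
#       --epoch 80 --output_file scores.hdf --att_weights \
#       --seq_list seq-0000 seq-0001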
def main():
  """
  Main entry.
  """
  arg_parser = ArgumentParser()
  arg_parser.add_argument("--action")
  arg_parser.add_argument("--print_seq", action='store_true')
  arg_parser.add_argument("--print_allos", action='store_true')
  arg_parser.add_argument("--print_targets", action='store_true')
  arg_parser.add_argument("--dataset")
  arg_parser.add_argument("--corpus")
  arg_parser.add_argument("--lexicon", help="filename")
  arg_parser.add_argument("--silence", type=int, help="index")
  arg_parser.add_argument("--context", default=1, type=int)
  arg_parser.add_argument("--hmm_states", default=3, type=int)
  arg_parser.add_argument("--state_tying_type", help="'monophone' or 'full'")
  arg_parser.add_argument("--state_tying_output", help="filename")
  arg_parser.add_argument("--allo_add_all", action="store_true")
  args = arg_parser.parse_args()

  dataset = init_dataset(args.dataset) if args.dataset else None
  corpus = dict(iter_bliss_orth(filename=args.corpus)) if args.corpus else None
  lexicon = Lexicon(filename=args.lexicon) if args.lexicon else None
  silence_label = args.silence

  if args.action == "show_corpus":
    pprint(corpus)
    return

  print("Num phones: %i" % len(lexicon.phonemes), file=log.v1)
  print("Phones: %r" % sorted(lexicon.phonemes.keys()), file=log.v1)
  orth_handler = OrthHandler(lexicon=lexicon, allo_context_len=args.context, allo_num_states=args.hmm_states)
  map_idx_to_allo = defaultdict(set)  # type: typing.Dict[int,typing.Set[AllophoneState]]
  map_allo_to_idx = {}  # type: typing.Dict[AllophoneState,int]
  if args.allo_add_all:
    orth_handler.allo_add_all = True
  print("Num HMM states: %i" % orth_handler.allo_num_states, file=log.v1)

  if args.state_tying_type == "monophone":
    print("Monophone state tying.", file=log.v1)
    num_labels = orth_handler.expected_num_labels_for_monophone_state_tying()
    all_label_idx_are_used = True
  elif args.state_tying_type == "full":
    print("Full state tying.", file=log.v1)
    phone_idxs = {k: i + 1 for (i, k) in enumerate(lexicon.phoneme_list)}  # +1 to keep 0 reserved as the term-symbol
    for phon in lexicon.phoneme_list:
      for allo in orth_handler.all_allophone_variations(phon, all_boundary_variations=True):
        allo_idx = allo.index(
          phone_idxs=phone_idxs,
          num_states=orth_handler.allo_num_states,
          context_length=orth_handler.allo_context_len)
        map_idx_to_allo[allo_idx].add(allo)
    num_labels = max(map_idx_to_allo.keys()) + 1
    all_label_idx_are_used = False
  else:
    raise Exception("invalid state tying type %r" % args.state_tying_type)
  print("Num labels: %i" % num_labels, file=log.v1)

  if dataset:
    count = 0
    for segment_name, targets in iter_dataset_targets(dataset):
      count += 1
      if silence_label is None or count == 1:
        likely_silence_label = collections.Counter(targets).most_common(1)[0][0]
        if silence_label is None:
          silence_label = likely_silence_label
        if silence_label != likely_silence_label:
          print("warning: silence %i but likely %i" % (silence_label, likely_silence_label), file=log.v2)
        print("Silence label: %i" % silence_label, file=log.v1)
        orth_handler.si_label = silence_label
        # Monophone state tying:
        for allo in orth_handler.all_allophone_variations(orth_handler.si_phone):
          map_idx_to_allo[silence_label].add(allo)
          map_allo_to_idx[allo] = silence_label
      assert segment_name in corpus
      orth = corpus[segment_name]
      allo_states = orth_handler.orth_to_allophone_states(orth=orth)
      if args.print_seq:
        print("%r %r" % (segment_name, orth))
      if args.print_allos:
        print(" allophone state seq: %r" % allo_states)
      tgt_seq = [t for t in uniq(targets) if t != silence_label]
      if args.print_targets:
        print(" target seq: %r" % (tgt_seq,))
      assert len(allo_states) == len(tgt_seq), "check --hmm_states or so"
      for allo, t in zip(allo_states, tgt_seq):
        allo.boundary = 0  # do not distinguish between boundaries
        allos = map_idx_to_allo[t]
        if allo in map_allo_to_idx:
          assert allo in allos, "bad mapping"
        else:
          assert allo not in allos
          allos.add(allo)
          map_allo_to_idx[allo] = t
      if len(map_idx_to_allo) >= num_labels:
        assert len(map_idx_to_allo) == num_labels
        assert 0 in map_idx_to_allo
        assert num_labels - 1 in map_idx_to_allo
        print("Finished with uniq mapping after %i sequences." % count, file=log.v1)
        break
      if count % 100 == 0:
        print("Have indices: %i (num labels: %i)" % (len(map_idx_to_allo), num_labels), file=log.v1)
    print("Finished. Have indices: %i (num labels: %i)" % (len(map_idx_to_allo), num_labels), file=log.v1)
    if len(map_idx_to_allo) < num_labels:
      found = []
      not_found = []
      for p in sorted(lexicon.phonemes.keys()):
        allo = AllophoneState(p, state=0)
        if allo in map_allo_to_idx:
          found.append(p)
        else:
          not_found.append(p)
      print("Phonemes found: %r" % found)
      print("Phonemes not found: %r" % not_found)

  if args.state_tying_output:
    assert not os.path.exists(args.state_tying_output)
    if all_label_idx_are_used:
      assert len(map_idx_to_allo) == num_labels
      assert 0 in map_idx_to_allo
      assert num_labels - 1 in map_idx_to_allo
    f = open(args.state_tying_output, "w")
    for i, allos in sorted(map_idx_to_allo.items()):
      for allo in allos:
        f.write("%s %i\n" % (allo.format(), i))
    f.close()
    print("Wrote state tying to %r." % args.state_tying_output, file=log.v1)

  print("The end.")
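# Example invocation (sketch; the script name and file names are placeholders,
# the flags come from the argparser in main() above). This derives a full state
# tying and dumps it via the write loop above, one "<allophone-state> <index>"
# line per entry:
#
#   python3 dump-state-tying.py --dataset config:train --corpus corpus.xml.gz \
#       --lexicon lexicon.xml.gz --state_tying_type full \
#       --state_tying_output state-tying.txt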
  debug_add_check_numerics_ops=True,
  model="/tmp/%s/returnn-demo-as-framework/model" % get_login_username(),
  cleanup_old_models=True,
  learning_rate_control="newbob_multi_epoch",
  learning_rate_control_relative_error_relative_lr=True,
  newbob_multi_num_epochs=3,
  newbob_multi_update_interval=1,
  newbob_learning_rate_decay=0.9,
  learning_rate_file="/tmp/%s/returnn-demo-as-framework/newbob.data" % get_login_username(),
  # log
  log_verbosity=3))

engine = Engine(config)

train_data = init_dataset({"class": "Task12AXDataset", "num_seqs": 1000, "name": "train"})
dev_data = init_dataset({"class": "Task12AXDataset", "num_seqs": 100, "name": "dev", "fixed_random_seed": 1})

engine.init_train_from_config(train_data=train_data, dev_data=dev_data)
engine.train()