Example #1
def main(argv):
    """
  Main entry.
  """
    arg_parser = argparse.ArgumentParser(
        description='Dump raw strings from dataset. Same format as in search.')
    arg_parser.add_argument(
        '--config',
        help="filename to config-file. will use dataset 'eval' from it")
    arg_parser.add_argument("--dataset", help="dataset, overwriting config")
    arg_parser.add_argument('--startseq',
                            type=int,
                            default=0,
                            help='start seq idx (inclusive) (default: 0)')
    arg_parser.add_argument('--endseq',
                            type=int,
                            default=-1,
                            help='end seq idx (inclusive) or -1 (default: -1)')
    arg_parser.add_argument(
        "--key",
        default="raw",
        help="data-key, e.g. 'data' or 'classes'. (default: 'raw')")
    arg_parser.add_argument("--verbosity",
                            default=4,
                            type=int,
                            help="5 for all seqs (default: 4)")
    arg_parser.add_argument("--out",
                            required=True,
                            help="out-file. py-format as in task=search")
    args = arg_parser.parse_args(argv[1:])
    assert args.config or args.dataset

    init(config_filename=args.config, log_verbosity=args.verbosity)
    if args.dataset:
        dataset = init_dataset(args.dataset)
    elif config.value("dump_data", "eval") in ["train", "dev", "eval"]:
        dataset = init_dataset(
            config.opt_typed_value(config.value("search_data", "eval")))
    else:
        dataset = init_dataset(config.opt_typed_value("wer_data"))
    dataset.init_seq_order(epoch=1)

    try:
        with generic_open(args.out, "w") as output_file:
            refs = get_raw_strings(dataset=dataset, options=args)
            output_file.write("{\n")
            for seq_tag, ref in refs:
                output_file.write("%r: %r,\n" % (seq_tag, ref))
            output_file.write("}\n")
        print("Done. Wrote to %r." % args.out)
    except KeyboardInterrupt:
        print("KeyboardInterrupt")
        sys.exit(1)
    finally:
        rnn.finalize()
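
The file written via --out is a Python-literal dict mapping sequence tags to raw strings. A minimal sketch of reading such a dump back, assuming it was written as plain text to a placeholder path refs.py:

import ast

with open("refs.py") as f:  # placeholder path, same file as given via --out
    refs = ast.literal_eval(f.read())  # {seq_tag: raw_string}
for seq_tag, ref in sorted(refs.items()):
    print("%r: %r" % (seq_tag, ref))
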
Example #2
def load_data(config, cache_byte_size, files_config_key, **kwargs):
    """
  :param Config config:
  :param int cache_byte_size:
  :param str files_config_key: such as "train" or "dev"
  :param kwargs: passed on to init_dataset() or init_dataset_via_str()
  :rtype: (Dataset,int)
  :returns: the dataset, and the cache byte size left over if we cache the whole dataset.
  """
    if not config.bool_or_other(files_config_key, None):
        return None, 0
    kwargs = kwargs.copy()
    kwargs.setdefault("name", files_config_key)
    if config.is_typed(files_config_key) and isinstance(
            config.typed_value(files_config_key), dict):
        config_opts = config.typed_value(files_config_key)
        assert isinstance(config_opts, dict)
        kwargs.update(config_opts)
        if 'cache_byte_size' not in config_opts:
            if kwargs.get('class', None) == 'HDFDataset':
                kwargs["cache_byte_size"] = cache_byte_size
        Dataset.kwargs_update_from_config(config, kwargs)
        data = init_dataset(kwargs)
    else:
        config_str = config.value(files_config_key, "")
        data = init_dataset_via_str(config_str,
                                    config=config,
                                    cache_byte_size=cache_byte_size,
                                    **kwargs)
    cache_leftover = 0
    if isinstance(data, HDFDataset):
        cache_leftover = data.definite_cache_leftover
    return data, cache_leftover
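
A hedged sketch of how load_data() above might be called, assuming a loaded Config whose "train" key is either a dataset dict or a plain string:

# Hypothetical usage; a cache budget of ~2 GB is an arbitrary example value.
train_data, cache_left = load_data(config, 2 * 1024 ** 3, "train")
# If the config's "train" value is a dict (e.g. {"class": "HDFDataset", "files": [...]}),
# it goes through init_dataset(); a plain string goes through init_dataset_via_str().
print(train_data, cache_left)
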
Example #3
def init(config_filename, cmd_line_opts, dataset_config_str):
    """
  :param str config_filename: global config for CRNN
  :param list[str] cmd_line_opts: options for init_config method
  :param str dataset_config_str: dataset via init_dataset_via_str()
  """
    rnn.init_better_exchook()
    rnn.init_thread_join_hack()
    if config_filename:
        rnn.init_config(config_filename, cmd_line_opts)
        rnn.init_log()
    else:
        log.initialize(verbosity=[5])
    print("Returnn hdf_dump starting up.", file=log.v3)
    rnn.init_faulthandler()
    if config_filename:
        rnn.init_data()
        rnn.print_task_properties()
        assert isinstance(rnn.train_data, Dataset)
        dataset = rnn.train_data
    else:
        assert dataset_config_str
        dataset = init_dataset(dataset_config_str)
    print("Source dataset:", dataset.len_info(), file=log.v3)
    return dataset
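
A hedged sketch of the two ways this init() can be driven; the file names are placeholders:

# Either give a RETURNN config file (its train dataset is used) ...
dataset = init(config_filename="setup.config", cmd_line_opts=[], dataset_config_str=None)
# ... or no config file but an explicit dataset description:
# dataset = init(config_filename=None, cmd_line_opts=[],
#                dataset_config_str='{"class": "HDFDataset", "files": ["data.hdf"]}')
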
Example #4
def init(config_str, config_dataset, verbosity):
    """
  :param str config_str: either a filename of a config-file, a dict repr for a dataset, or an .hdf file path
  :param str|None config_dataset:
  :param int verbosity:
  """
    global dataset
    rnn.init_better_exchook()
    rnn.init_thread_join_hack()
    dataset_dict = None
    config_filename = None
    if config_str.strip().startswith("{"):
        print("Using dataset %s." % config_str)
        dataset_dict = eval(config_str.strip())
    elif config_str.endswith(".hdf"):
        dataset_dict = {"class": "HDFDataset", "files": [config_str]}
        print("Using dataset %r." % dataset_dict)
        assert os.path.exists(config_str)
    else:
        config_filename = config_str
        print("Using config file %r." % config_filename)
        assert os.path.exists(config_filename)
    rnn.init_config(config_filename=config_filename,
                    default_config={"cache_size": "0"})
    global config
    config = rnn.config
    config.set("log", None)
    config.set("log_verbosity", verbosity)
    if dataset_dict:
        assert not config_dataset
        dataset = init_dataset(dataset_dict)
    elif config_dataset and config_dataset != "train":
        print("Use dataset %r from config." % config_dataset)
        dataset = init_dataset("config:%s" % config_dataset)
    else:
        print("Use train dataset from config.")
        assert config.value("train", None)
        dataset = init_dataset("config:train")
    rnn.init_log()
    print("Returnn dump-dataset starting up.", file=log.v2)
    rnn.returnn_greeting()
    rnn.init_faulthandler()
    rnn.init_config_json_network()
    print("Dataset:", file=log.v2)
    print("  input:", dataset.num_inputs, "x", dataset.window, file=log.v2)
    print("  output:", dataset.num_outputs, file=log.v2)
    print(" ", dataset.len_info() or "no info", file=log.v2)
Example #5
def init(config_filename, command_line_options, args):
    """
  :param str config_filename:
  :param list[str] command_line_options:
  :param args: argparse.Namespace
  """
    global config, engine, dataset
    rnn.init(config_filename=config_filename,
             command_line_options=command_line_options,
             config_updates={
                 "log": None,
                 "need_data": False
             },
             extra_greeting="RETURNN dump-forward starting up.")
    config = rnn.config
    engine = rnn.engine

    dataset_str = args.dataset
    if dataset_str in {"train", "dev", "eval", "search_data"}:
        dataset_str = "config:%s" % dataset_str
    extra_dataset_kwargs = {}
    if args.reset_partition_epoch:
        print("NOTE: We are resetting partition epoch to %i." %
              (args.reset_partition_epoch, ))
        extra_dataset_kwargs["partition_epoch"] = args.reset_partition_epoch
    if args.reset_seq_ordering:
        print("NOTE: We will use %r seq ordering." %
              (args.reset_seq_ordering, ))
        extra_dataset_kwargs["seq_ordering"] = args.reset_seq_ordering
    if args.reset_epoch_wise_filter:
        extra_dataset_kwargs["epoch_wise_filter"] = eval(
            args.reset_epoch_wise_filter)
    dataset = init_dataset(dataset_str, extra_kwargs=extra_dataset_kwargs)
    if hasattr(dataset,
               "epoch_wise_filter") and args.reset_epoch_wise_filter is None:
        if dataset.epoch_wise_filter:
            print("NOTE: Resetting epoch_wise_filter to None.")
            dataset.epoch_wise_filter = None
    if args.reset_partition_epoch:
        assert dataset.partition_epoch == args.reset_partition_epoch
    if args.reset_seq_ordering:
        assert dataset.seq_ordering == args.reset_seq_ordering

    config.set("task", "eval")
    if args.load:
        config.set("load", args.load)

    epoch, model_epoch_filename = Engine.get_epoch_model(config)
    engine.pretrain = pretrain_from_config(config)
    engine.custom_get_net_dict = config.typed_value("get_network")
    net_dict = engine.get_net_dict_for_epoch(epoch)
    engine.make_tf_session()
    engine.network = TFNetwork(name="root")
    engine.network.construct_layer(net_dict, args.layer)
    print("Load model:", model_epoch_filename)
    engine.network.load_params_from_file(model_epoch_filename,
                                         session=engine.tf_session)
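
A hedged sketch of driving this init() with a hand-built argparse.Namespace; the option values are only illustrative:

import argparse

args = argparse.Namespace(
    dataset="dev", layer="output", load=None,
    reset_partition_epoch=1, reset_seq_ordering="sorted_reverse",
    reset_epoch_wise_filter=None)
init(config_filename="setup.config", command_line_options=[], args=args)
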
Example #6
def benchmark(lstm_unit, use_gpu):
    """
  :param str lstm_unit: e.g. "LSTMBlock", one of LstmCellTypes
  :param bool use_gpu:
  :return: runtime in seconds of the training itself, excluding initialization
  :rtype: float
  """
    device = {True: "GPU", False: "CPU"}[use_gpu]
    key = "%s:%s" % (device, lstm_unit)
    print(">>> Start benchmark for %s." % key)
    config = Config()
    config.update(make_config_dict(lstm_unit=lstm_unit, use_gpu=use_gpu))
    dataset_kwargs = config.typed_value("train")
    Dataset.kwargs_update_from_config(config, dataset_kwargs)
    dataset = init_dataset(dataset_kwargs)
    engine = Engine(config=config)
    engine.init_train_from_config(config=config, train_data=dataset)
    print(">>> Start training now for %s." % key)
    start_time = time.time()
    engine.train()
    runtime = time.time() - start_time
    print(">>> Runtime of %s: %s" % (key, hms_fraction(runtime)))
    engine.finalize()
    return runtime
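
A hypothetical driver around benchmark() comparing a few LSTM units on CPU and GPU; which unit names are available depends on LstmCellTypes in the RETURNN version at hand:

results = {}
for lstm_unit in ["LSTMBlock", "NativeLstm2", "BasicLSTM"]:  # assumed unit names
    for use_gpu in (False, True):
        device = "GPU" if use_gpu else "CPU"
        results["%s:%s" % (device, lstm_unit)] = benchmark(lstm_unit=lstm_unit, use_gpu=use_gpu)
for key, runtime in sorted(results.items(), key=lambda item: item[1]):
    print("%s: %.3f sec" % (key, runtime))
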
Example #7
def execute_main_task():
    """
  Executes the main task (via config ``task`` option).
  """
    from returnn.util.basic import hms_fraction
    start_time = time.time()
    task = config.value('task', 'train')
    if config.is_true("dry_run"):
        print("Dry run, will not save anything.", file=log.v1)
    if task == 'train':
        assert train_data.have_seqs(
        ), "no train files specified, check train option: %s" % config.value(
            'train', None)
        engine.init_train_from_config(config, train_data, dev_data, eval_data)
        engine.train()
    elif task == "eval":
        epoch = config.int("epoch", -1)
        load_epoch = config.int("load_epoch", -1)
        if epoch >= 0:
            assert (load_epoch < 0) or (
                load_epoch == epoch), "epoch and load_epoch have to match"
            engine.epoch = epoch
            config.set('load_epoch', engine.epoch)
        else:
            assert load_epoch >= 0, "specify epoch or load_epoch"
            engine.epoch = load_epoch
        engine.init_train_from_config(config, train_data, dev_data, eval_data)
        print("Evaluate epoch", engine.epoch, file=log.v4)
        engine.eval_model(
            output_file=config.value("eval_output_file", None),
            output_per_seq_file=config.value("eval_output_file_per_seq", None),
            loss_name=config.value("loss_name", None),
            output_per_seq_format=config.list("output_per_seq_format",
                                              ["score"]),
            output_per_seq_file_format=config.value(
                "output_per_seq_file_format", "txt"))
    elif task in ['forward', 'hpx']:
        assert eval_data is not None, 'no eval data provided'
        combine_labels = config.value('combine_labels', '')
        engine.use_search_flag = config.bool("forward_use_search", False)
        if config.has("epoch"):
            config.set('load_epoch', config.int('epoch', 0))
        engine.init_network_from_config(config)
        output_file = config.value('output_file',
                                   'dump-fwd-epoch-%i.hdf' % engine.epoch)
        engine.forward_to_hdf(data=eval_data,
                              output_file=output_file,
                              combine_labels=combine_labels,
                              batch_size=config.int('forward_batch_size', 0))
    elif task == "search":
        engine.use_search_flag = True
        engine.use_eval_flag = config.bool("search_do_eval", True)
        engine.init_network_from_config(config)
        if config.value("search_data", "eval") in ["train", "dev", "eval"]:
            data = {
                "train": train_data,
                "dev": dev_data,
                "eval": eval_data
            }[config.value("search_data", "eval")]
            assert data, "set search_data"
        else:
            data = init_dataset(config.opt_typed_value("search_data"))
        engine.search(
            data,
            do_eval=config.bool("search_do_eval", True),
            output_layer_names=config.typed_value("search_output_layer",
                                                  "output"),
            output_file=config.value("search_output_file", ""),
            output_file_format=config.value("search_output_file_format",
                                            "txt"))
    elif task == 'compute_priors':
        assert train_data is not None, 'train data for priors should be provided'
        engine.init_network_from_config(config)
        engine.compute_priors(dataset=train_data, config=config)
    elif task == 'theano_graph':
        # noinspection PyPackageRequirements,PyUnresolvedReferences
        import theano.printing
        # noinspection PyPackageRequirements,PyUnresolvedReferences
        import theano.compile.io
        # noinspection PyPackageRequirements,PyUnresolvedReferences
        import theano.compile.function_module
        engine.start_epoch = 1
        engine.init_network_from_config(config)
        for task in config.list('theano_graph.task', ['train']):
            func = engine.devices[-1].get_compute_func(task)
            prefix = config.value("theano_graph.prefix", "current") + ".task"
            print("dumping to %s.* ..." % prefix, file=log.v1)
            theano.printing.debugprint(func,
                                       file=open(
                                           "%s.optimized_func.txt" % prefix,
                                           "w"))
            assert isinstance(func.maker,
                              theano.compile.function_module.FunctionMaker)
            for inp in func.maker.inputs:
                assert isinstance(inp, theano.compile.io.In)
                if inp.update:
                    theano.printing.debugprint(
                        inp.update,
                        file=open(
                            "%s.unoptimized.var_%s_update.txt" %
                            (prefix, inp.name), "w"))
            theano.printing.pydotprint(func,
                                       format='png',
                                       var_with_name_simple=True,
                                       outfile="%s.png" % prefix)
    elif task == 'analyze':  # anything based on the network + Device
        statistics = config.list('statistics', None)
        engine.init_network_from_config(config)
        engine.analyze(data=eval_data or dev_data, statistics=statistics)
    elif task == "analyze_data":  # anything just based on the data
        analyze_data(config)
    elif task == "classify":
        assert eval_data is not None, 'no eval data provided'
        assert config.has('label_file'), 'no output file provided'
        label_file = config.value('label_file', '')
        engine.init_network_from_config(config)
        engine.classify(eval_data, label_file)
    elif task == "hyper_param_tuning":
        import returnn.tf.hyper_param_tuning
        tuner = returnn.tf.hyper_param_tuning.Optimization(
            config=config, train_data=train_data)
        tuner.work()
    elif task == "cleanup_old_models":
        engine.cleanup_old_models(ask_for_confirmation=True)
    elif task == "daemon":
        engine.init_network_from_config(config)
        engine.daemon(config)
    elif task == "server":
        print("Server Initiating", file=log.v1)
        server.run()
    elif task == "search_server":
        engine.use_search_flag = True
        engine.init_network_from_config(config)
        engine.web_server(port=config.int("web_server_port", 12380))
    elif task.startswith("config:"):
        action = config.typed_dict[task[len("config:"):]]
        print("Task: %r" % action, file=log.v1)
        assert callable(action)
        action()
    elif task.startswith("optional-config:"):
        action = config.typed_dict.get(task[len("optional-config:"):], None)
        if action is None:
            print("No task found for %r, so just quitting." % task,
                  file=log.v1)
        else:
            print("Task: %r" % action, file=log.v1)
            assert callable(action)
            action()
    elif task == "nop":
        print("Task: No-operation", file=log.v1)
    elif task == "nop_init_net_train":
        print(
            "Task: No-operation, despite initializing the network (for training)",
            file=log.v1)
        engine.init_train_from_config(config, train_data, dev_data, eval_data)
    elif task == "initialize_model":
        engine.init_train_from_config(config, train_data, dev_data, eval_data)
        engine.save_model(config.value('model', 'dummy'))
    else:
        assert False, "unknown task: %s" % task

    print(("elapsed: %s" % hms_fraction(time.time() - start_time)),
          file=log.v3)
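
Besides the built-in tasks, execute_main_task() dispatches to callables defined in the config via the "config:" prefix; a minimal sketch of such a config entry (the function name is made up):

# In the RETURNN config file:
task = "config:report_done"

def report_done():
    """Runs instead of training when task == "config:report_done"."""
    print("custom task ran")
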
Example #8

def main(argv):
    """
  Main entry.
  """
    argparser = argparse.ArgumentParser(description=__doc__)
    argparser.add_argument("config_file",
                           type=str,
                           help="RETURNN config, or model-dir")
    argparser.add_argument("--epoch", type=int)
    argparser.add_argument(
        '--data',
        default="train",
        help=
        "e.g. 'train', 'config:train', or sth like 'config:get_dataset('dev')'"
    )
    argparser.add_argument('--do_search', default=False, action='store_true')
    argparser.add_argument('--beam_size', default=12, type=int)
    argparser.add_argument('--dump_dir', help="for npy or png")
    argparser.add_argument("--output_file", help="hdf")
    argparser.add_argument("--device", help="gpu or cpu (default: automatic)")
    argparser.add_argument("--layers",
                           default=["att_weights"],
                           action="append",
                           help="Layer of subnet to grab")
    argparser.add_argument("--rec_layer",
                           default="output",
                           help="Subnet layer to grab from; decoder")
    argparser.add_argument("--enc_layer", default="encoder")
    argparser.add_argument("--batch_size", type=int, default=5000)
    argparser.add_argument("--seq_list",
                           default=[],
                           action="append",
                           help="predefined list of seqs")
    argparser.add_argument("--min_seq_len",
                           default="0",
                           help="can also be dict")
    argparser.add_argument("--num_seqs",
                           default=-1,
                           type=int,
                           help="stop after this many seqs")
    argparser.add_argument("--output_format",
                           default="npy",
                           help="npy, png or hdf")
    argparser.add_argument("--dropout",
                           default=None,
                           type=float,
                           help="if set, overwrites all dropout values")
    argparser.add_argument("--train_flag", action="store_true")
    argparser.add_argument("--reset_partition_epoch", type=int, default=1)
    argparser.add_argument("--reset_seq_ordering", default="sorted_reverse")
    argparser.add_argument("--reset_epoch_wise_filter", default=None)
    args = argparser.parse_args(argv[1:])

    layers = args.layers
    assert isinstance(layers, list)
    config_fn = args.config_file
    explicit_model_dir = None
    if os.path.isdir(config_fn):
        # Assume we gave a model dir.
        explicit_model_dir = config_fn
        train_log_dir_config_pattern = "%s/train-*/*.config" % config_fn
        train_log_dir_configs = sorted(glob(train_log_dir_config_pattern))
        assert train_log_dir_configs
        config_fn = train_log_dir_configs[-1]
        print("Using this config via model dir:", config_fn)
    else:
        assert os.path.isfile(config_fn)
    model_name = ".".join(config_fn.split("/")[-1].split(".")[:-1])

    init_returnn(config_fn=config_fn, args=args)
    if explicit_model_dir:
        config.set(
            "model", "%s/%s" %
            (explicit_model_dir, os.path.basename(config.value('model', ''))))
    print("Model file prefix:", config.value('model', ''))

    if args.do_search:
        raise NotImplementedError
    min_seq_length = NumbersDict(eval(args.min_seq_len))

    assert args.output_format in ["npy", "png", "hdf"]
    if args.output_format in ["npy", "png"]:
        assert args.dump_dir
        if not os.path.exists(args.dump_dir):
            os.makedirs(args.dump_dir)
    plt = ticker = None
    if args.output_format == "png":
        import matplotlib.pyplot as plt  # need to import early? https://stackoverflow.com/a/45582103/133374
        import matplotlib.ticker as ticker

    dataset_str = args.data
    if dataset_str in ["train", "dev", "eval"]:
        dataset_str = "config:%s" % dataset_str
    extra_dataset_kwargs = {}
    if args.reset_partition_epoch:
        print("NOTE: We are resetting partition epoch to %i." %
              (args.reset_partition_epoch, ))
        extra_dataset_kwargs["partition_epoch"] = args.reset_partition_epoch
    if args.reset_seq_ordering:
        print("NOTE: We will use %r seq ordering." %
              (args.reset_seq_ordering, ))
        extra_dataset_kwargs["seq_ordering"] = args.reset_seq_ordering
    if args.reset_epoch_wise_filter:
        extra_dataset_kwargs["epoch_wise_filter"] = eval(
            args.reset_epoch_wise_filter)
    dataset = init_dataset(dataset_str, extra_kwargs=extra_dataset_kwargs)
    if hasattr(dataset,
               "epoch_wise_filter") and args.reset_epoch_wise_filter is None:
        if dataset.epoch_wise_filter:
            print("NOTE: Resetting epoch_wise_filter to None.")
            dataset.epoch_wise_filter = None
    if args.reset_partition_epoch:
        assert dataset.partition_epoch == args.reset_partition_epoch
    if args.reset_seq_ordering:
        assert dataset.seq_ordering == args.reset_seq_ordering

    init_net(args, layers)
    network = rnn.engine.network

    hdf_writer = None
    if args.output_format == "hdf":
        assert args.output_file
        assert len(layers) == 1
        sub_layer = network.get_layer("%s/%s" % (args.rec_layer, layers[0]))
        from returnn.datasets.hdf import SimpleHDFWriter
        hdf_writer = SimpleHDFWriter(filename=args.output_file,
                                     dim=sub_layer.output.dim,
                                     ndim=sub_layer.output.ndim)

    extra_fetches = {
        "output":
        network.layers[args.rec_layer].output.get_placeholder_as_batch_major(),
        "output_len":
        network.layers[
            args.rec_layer].output.get_sequence_lengths(),  # decoder length
        "encoder_len":
        network.layers[
            args.enc_layer].output.get_sequence_lengths(),  # encoder length
        "seq_idx":
        network.get_extern_data("seq_idx"),
        "seq_tag":
        network.get_extern_data("seq_tag"),
        "target_data":
        network.get_extern_data(network.extern_data.default_input),
        "target_classes":
        network.get_extern_data(network.extern_data.default_target),
    }
    for layer in layers:
        sub_layer = rnn.engine.network.get_layer("%s/%s" %
                                                 (args.rec_layer, layer))
        extra_fetches[
            "rec_%s" %
            layer] = sub_layer.output.get_placeholder_as_batch_major()
    dataset.init_seq_order(
        epoch=1, seq_list=args.seq_list
        or None)  # use always epoch 1, such that we have same seqs
    dataset_batch = dataset.generate_batches(
        recurrent_net=network.recurrent,
        batch_size=args.batch_size,
        max_seqs=rnn.engine.max_seqs,
        max_seq_length=sys.maxsize,
        min_seq_length=min_seq_length,
        max_total_num_seqs=args.num_seqs,
        used_data_keys=network.used_data_keys)

    stats = {layer: Stats() for layer in layers}

    # (**dict[str, numpy.ndarray|str|list[numpy.ndarray|str]]) -> None
    def fetch_callback(seq_idx, seq_tag, target_data, target_classes, output,
                       output_len, encoder_len, **kwargs):
        """
    :param list[int] seq_idx: len is n_batch
    :param list[str] seq_tag: len is n_batch
    :param numpy.ndarray target_data: extern data default input (e.g. "data"), shape e.g. (B,enc-T,...)
    :param numpy.ndarray target_classes: extern data default target (e.g. "classes"), shape e.g. (B,dec-T,...)
    :param numpy.ndarray output: rec layer output, shape e.g. (B,dec-T,...)
    :param numpy.ndarray output_len: rec layer seq len, i.e. decoder length, shape (B,)
    :param numpy.ndarray encoder_len: encoder seq len, shape (B,)
    :param kwargs: contains "rec_%s" % l for l in layers, the sub layers (e.g att weights) we are interested in
    """
        n_batch = len(seq_idx)
        for i in range(n_batch):
            # noinspection PyShadowingNames
            for layer in layers:
                att_weights = kwargs["rec_%s" % layer][i]
                stats[layer].collect(att_weights.flatten())
        if args.output_format == "npy":
            data = {}
            for i in range(n_batch):
                data[i] = {
                    'tag': seq_tag[i],
                    'data': target_data[i],
                    'classes': target_classes[i],
                    'output': output[i],
                    'output_len': output_len[i],
                    'encoder_len': encoder_len[i],
                }
                # noinspection PyShadowingNames
                for layer in [("rec_%s" % layer) for layer in layers]:
                    assert layer in kwargs
                    out = kwargs[layer][i]
                    assert out.ndim >= 2
                    assert out.shape[0] >= output_len[i] and out.shape[
                        1] >= encoder_len[i]
                    data[i][layer] = out[:output_len[i], :encoder_len[i]]
                fname = args.dump_dir + '/%s_ep%03d_data_%i_%i.npy' % (
                    model_name, rnn.engine.epoch, seq_idx[0], seq_idx[-1])
                np.save(fname, data)
        elif args.output_format == "png":
            for i in range(n_batch):
                # noinspection PyShadowingNames
                for layer in layers:
                    extra_postfix = ""
                    if args.dropout is not None:
                        extra_postfix += "_dropout%.2f" % args.dropout
                    elif args.train_flag:
                        extra_postfix += "_train"
                    fname = args.dump_dir + '/%s_ep%03d_plt_%05i_%s%s.png' % (
                        model_name, rnn.engine.epoch, seq_idx[i], layer,
                        extra_postfix)
                    att_weights = kwargs["rec_%s" % layer][i]
                    att_weights = att_weights.squeeze(axis=2)  # (out,enc)
                    assert att_weights.shape[0] >= output_len[
                        i] and att_weights.shape[1] >= encoder_len[i]
                    att_weights = att_weights[:output_len[i], :encoder_len[i]]
                    print("Seq %i, %s: Dump att weights with shape %r to: %s" %
                          (seq_idx[i], seq_tag[i], att_weights.shape, fname))
                    plt.matshow(att_weights)
                    title = seq_tag[i]
                    if dataset.can_serialize_data(
                            network.extern_data.default_target):
                        title += "\n" + dataset.serialize_data(
                            network.extern_data.default_target,
                            target_classes[i][:output_len[i]])
                        ax = plt.gca()
                        tick_labels = [
                            dataset.serialize_data(
                                network.extern_data.default_target,
                                np.array([x], dtype=target_classes[i].dtype))
                            for x in target_classes[i][:output_len[i]]
                        ]
                        ax.set_yticklabels([''] + tick_labels, fontsize=8)
                        ax.yaxis.set_major_locator(ticker.MultipleLocator(1))
                    plt.title(title)
                    plt.savefig(fname)
                    plt.close()
        elif args.output_format == "hdf":
            assert len(layers) == 1
            att_weights = kwargs["rec_%s" % layers[0]]
            hdf_writer.insert_batch(inputs=att_weights,
                                    seq_len={
                                        0: output_len,
                                        1: encoder_len
                                    },
                                    seq_tag=seq_tag)
        else:
            raise Exception("output format %r" % args.output_format)

    runner = Runner(engine=rnn.engine,
                    dataset=dataset,
                    batches=dataset_batch,
                    train=False,
                    train_flag=bool(args.dropout) or args.train_flag,
                    extra_fetches=extra_fetches,
                    extra_fetches_callback=fetch_callback)
    runner.run(report_prefix="att-weights epoch %i" % rnn.engine.epoch)
    for layer in layers:
        stats[layer].dump(stream_prefix="Layer %r " % layer)
    if not runner.finalized:
        print("Some error occured, not finalized.")
        sys.exit(1)

    if hdf_writer:
        hdf_writer.close()
    rnn.finalize()
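
For --output_format npy, each dump file holds a Python dict saved with np.save; a minimal sketch of loading one back (the file name is a placeholder):

import numpy as np

dump = np.load("model_ep001_data_0_4.npy", allow_pickle=True).item()  # placeholder file name
for i, entry in sorted(dump.items()):
    print(i, entry["tag"], entry["output_len"], entry["encoder_len"])
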
Example #9

def main(argv):
    """
  Main entry.
  """
    arg_parser = argparse.ArgumentParser(
        description='Dump search scores and other info to HDF file.')
    arg_parser.add_argument('config', help="filename to config-file")
    arg_parser.add_argument("--dataset", default="config:train")
    arg_parser.add_argument("--epoch",
                            type=int,
                            default=-1,
                            help="-1 for last epoch")
    arg_parser.add_argument("--output_file", help='hdf', required=True)
    arg_parser.add_argument("--rec_layer_name", default="output")
    arg_parser.add_argument("--cheating",
                            action="store_true",
                            help="add ground truth to the beam")
    arg_parser.add_argument("--att_weights",
                            action="store_true",
                            help="dump all softmax_over_spatial layers")
    arg_parser.add_argument("--verbosity",
                            default=4,
                            type=int,
                            help="5 for all seqs (default: 4)")
    arg_parser.add_argument("--seq_list",
                            nargs="+",
                            help="use only these seqs")
    args, remaining_args = arg_parser.parse_known_args(argv[1:])
    init(config_filename=args.config,
         log_verbosity=args.verbosity,
         remaining_args=remaining_args)

    dataset = init_dataset(args.dataset)
    print("Dataset:")
    pprint(dataset)
    if args.seq_list:
        dataset.seq_tags_filter = set(args.seq_list)
        dataset.partition_epoch = 1  # reset
        if isinstance(dataset, MetaDataset):
            for sub_dataset in dataset.datasets.values():
                sub_dataset.seq_tags_filter = set(args.seq_list)
                sub_dataset.partition_epoch = 1
        dataset.finish_epoch()  # enforce reset
    if dataset.seq_tags_filter is not None:
        print("Using sequences:")
        pprint(dataset.seq_tags_filter)
    if args.epoch >= 1:
        config.set("load_epoch", args.epoch)

    def net_dict_post_proc(net_dict):
        """
    :param dict[str] net_dict:
    :return: net_dict
    :rtype: dict[str]
    """
        prepare_compile(rec_layer_name=args.rec_layer_name,
                        net_dict=net_dict,
                        cheating=args.cheating,
                        dump_att_weights=args.att_weights,
                        hdf_filename=args.output_file,
                        possible_labels=dataset.labels)
        return net_dict

    engine = Engine(config=config)
    engine.use_search_flag = True
    engine.init_network_from_config(config,
                                    net_dict_post_proc=net_dict_post_proc)
    engine.search(dataset,
                  do_eval=config.bool("search_do_eval", True),
                  output_layer_names=args.rec_layer_name)
    engine.finalize()
    print("Search finished.")
    assert os.path.exists(args.output_file), "hdf file not dumped?"
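
A sketch of a typical invocation of this dump script; the script and file names are placeholders:

main(["dump_search_scores.py", "setup.config",
      "--dataset", "config:dev",
      "--epoch", "80",
      "--output_file", "scores.hdf",
      "--att_weights"])
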
Example #10
def main():
    """
  Main entry.
  """
    arg_parser = ArgumentParser()
    arg_parser.add_argument("--action")
    arg_parser.add_argument("--print_seq", action='store_true')
    arg_parser.add_argument("--print_allos", action='store_true')
    arg_parser.add_argument("--print_targets", action='store_true')
    arg_parser.add_argument("--dataset")
    arg_parser.add_argument("--corpus")
    arg_parser.add_argument("--lexicon", help="filename")
    arg_parser.add_argument("--silence", type=int, help="index")
    arg_parser.add_argument("--context", default=1, type=int)
    arg_parser.add_argument("--hmm_states", default=3, type=int)
    arg_parser.add_argument("--state_tying_type", help="'monophone' or 'full'")
    arg_parser.add_argument("--state_tying_output", help="filename")
    arg_parser.add_argument("--allo_add_all", action="store_true")
    args = arg_parser.parse_args()

    dataset = init_dataset(args.dataset) if args.dataset else None
    corpus = dict(iter_bliss_orth(
        filename=args.corpus)) if args.corpus else None
    lexicon = Lexicon(filename=args.lexicon) if args.lexicon else None
    silence_label = args.silence

    if args.action == "show_corpus":
        pprint(corpus)
        return

    print("Num phones: %i" % len(lexicon.phonemes), file=log.v1)
    print("Phones: %r" % sorted(lexicon.phonemes.keys()), file=log.v1)

    orth_handler = OrthHandler(lexicon=lexicon,
                               allo_context_len=args.context,
                               allo_num_states=args.hmm_states)
    map_idx_to_allo = defaultdict(
        set)  # type: typing.Dict[int, typing.Set[AllophoneState]]
    map_allo_to_idx = {}  # type: typing.Dict[AllophoneState, int]
    if args.allo_add_all:
        orth_handler.allo_add_all = True

    print("Num HMM states: %i" % orth_handler.allo_num_states, file=log.v1)
    if args.state_tying_type == "monophone":
        print("Monophone state tying.", file=log.v1)
        num_labels = orth_handler.expected_num_labels_for_monophone_state_tying(
        )
        all_label_idx_are_used = True
    elif args.state_tying_type == "full":
        print("Full state tying.", file=log.v1)
        phone_idxs = {k: i + 1
                      for (i, k) in enumerate(lexicon.phoneme_list)
                      }  # +1 to keep 0 reserved as the term-symbol
        for phon in lexicon.phoneme_list:
            for allo in orth_handler.all_allophone_variations(
                    phon, all_boundary_variations=True):
                allo_idx = allo.index(
                    phone_idxs=phone_idxs,
                    num_states=orth_handler.allo_num_states,
                    context_length=orth_handler.allo_context_len)
                map_idx_to_allo[allo_idx].add(allo)
        num_labels = max(map_idx_to_allo.keys()) + 1
        all_label_idx_are_used = False
    else:
        raise Exception("invalid state tying type %r" % args.state_tying_type)
    print("Num labels: %i" % num_labels, file=log.v1)

    if dataset:
        count = 0
        for segment_name, targets in iter_dataset_targets(dataset):
            count += 1
            if silence_label is None or count == 1:
                likely_silence_label = collections.Counter(
                    targets).most_common(1)[0][0]
                if silence_label is None:
                    silence_label = likely_silence_label
                if silence_label != likely_silence_label:
                    print("warning: silence %i but likely %i" %
                          (silence_label, likely_silence_label),
                          file=log.v2)
                print("Silence label: %i" % silence_label, file=log.v1)
                orth_handler.si_label = silence_label
                # Monophone state tying:
                for allo in orth_handler.all_allophone_variations(
                        orth_handler.si_phone):
                    map_idx_to_allo[silence_label].add(allo)
                    map_allo_to_idx[allo] = silence_label
            assert segment_name in corpus
            orth = corpus[segment_name]
            allo_states = orth_handler.orth_to_allophone_states(orth=orth)
            if args.print_seq:
                print("%r %r" % (segment_name, orth))
            if args.print_allos:
                print("  allophone state seq: %r" % allo_states)
            tgt_seq = [t for t in uniq(targets) if t != silence_label]
            if args.print_targets:
                print("  target seq: %r" % (tgt_seq, ))
            assert len(allo_states) == len(tgt_seq), "check --hmm_states or so"
            for allo, t in zip(allo_states, tgt_seq):
                allo.boundary = 0  # do not differ between boundaries
                allos = map_idx_to_allo[t]
                if allo in map_allo_to_idx:
                    assert allo in allos, "bad mapping"
                else:
                    assert allo not in allos
                    allos.add(allo)
                    map_allo_to_idx[allo] = t
            if len(map_idx_to_allo) >= num_labels:
                assert len(map_idx_to_allo) == num_labels
                assert 0 in map_idx_to_allo
                assert num_labels - 1 in map_idx_to_allo
                print("Finished with uniq mapping after %i sequences." % count,
                      file=log.v1)
                break
            if count % 100 == 0:
                print("Have indices: %i (num labels: %i)" %
                      (len(map_idx_to_allo), num_labels),
                      file=log.v1)

        print("Finished. Have indices: %i (num labels: %i)" %
              (len(map_idx_to_allo), num_labels),
              file=log.v1)
        if len(map_idx_to_allo) < num_labels:
            found = []
            not_found = []
            for p in sorted(lexicon.phonemes.keys()):
                allo = AllophoneState(p, state=0)
                if allo in map_allo_to_idx:
                    found.append(p)
                else:
                    not_found.append(p)
            print("Phonemes found: %r" % found)
            print("Phonemes not found: %r" % not_found)

    if args.state_tying_output:
        assert not os.path.exists(args.state_tying_output)
        if all_label_idx_are_used:
            assert len(map_idx_to_allo) == num_labels
            assert 0 in map_idx_to_allo
            assert num_labels - 1 in map_idx_to_allo
        f = open(args.state_tying_output, "w")
        for i, allos in sorted(map_idx_to_allo.items()):
            for allo in allos:
                f.write("%s %i\n" % (allo.format(), i))
        f.close()
        print("Wrote state tying to %r." % args.state_tying_output,
              file=log.v1)

    print("The end.")
Example #11

# Assumed prefix (not part of the original fragment): imports and the opening of the
# config.update(...) call; the example's earlier config options (network definition,
# training setup, etc.) are omitted here.
from returnn.config import Config
from returnn.datasets.basic import init_dataset
from returnn.tf.engine import Engine
from returnn.util.basic import get_login_username

config = Config()
config.update(
    dict(
        debug_add_check_numerics_ops=True,
        model="/tmp/%s/returnn-demo-as-framework/model" % get_login_username(),
        cleanup_old_models=True,
        learning_rate_control="newbob_multi_epoch",
        learning_rate_control_relative_error_relative_lr=True,
        newbob_multi_num_epochs=3,
        newbob_multi_update_interval=1,
        newbob_learning_rate_decay=0.9,
        learning_rate_file="/tmp/%s/returnn-demo-as-framework/newbob.data" %
        get_login_username(),

        # log
        log_verbosity=3))

engine = Engine(config)

train_data = init_dataset({
    "class": "Task12AXDataset",
    "num_seqs": 1000,
    "name": "train"
})
dev_data = init_dataset({
    "class": "Task12AXDataset",
    "num_seqs": 100,
    "name": "dev",
    "fixed_random_seed": 1
})

engine.init_train_from_config(train_data=train_data, dev_data=dev_data)
engine.train()
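
Taken together, the examples on this page call init_dataset() with either a dict of options or a string; a short recap sketch with placeholder values:

from returnn.datasets.basic import init_dataset

ds = init_dataset({"class": "Task12AXDataset", "num_seqs": 10, "name": "demo"})  # dict of options
# ds = init_dataset("config:train")  # key from an already-loaded RETURNN config
# ds = init_dataset('{"class": "HDFDataset", "files": ["data.hdf"]}')  # dict given as string repr
ds.init_seq_order(epoch=1)
print(ds.len_info())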