Esempio n. 1
0
def _dump_dataset_print_num_seqs(dataset):
    """
    Print the (estimated) number of sequences of ``dataset`` (the ``options.get_num_seqs`` mode).

    If accessing ``dataset.num_seqs`` raises, the sequences are instead counted by loading
    them one by one, printing the index (and the first targets of the first target key,
    if there is one) every 10000 seqs.

    :type dataset: Dataset.Dataset
    """
    print("Get num seqs.")
    print("estimated_num_seqs: %r" % dataset.estimated_num_seqs)
    try:
        print("num_seqs: %r" % dataset.num_seqs)
    except Exception as exc:
        # An exception here is accepted (see message); fall back to counting manually.
        print("num_seqs exception %r, which is valid, so we count." % exc)
        target_list = dataset.get_target_list()
        default_target = target_list[0] if target_list else None
        seq_idx = 0
        while dataset.is_less_than_num_seqs(seq_idx):
            dataset.load_seqs(seq_idx, seq_idx + 1)
            if seq_idx % 10000 == 0:
                if default_target:
                    targets = dataset.get_targets(default_target, seq_idx)
                    postfix = " (targets = %r...)" % (targets[:10], )
                else:
                    postfix = ""
                print("%i ...%s" % (seq_idx, postfix))
            seq_idx += 1
        print("accumulated num seqs: %i" % seq_idx)
    print("Done.")


def _dump_dataset_open_output(options):
    """
    Announce the selected dump type and prepare its output.

    :param options: argparse.Namespace (reads type, dump_prefix, dump_postfix,
      stdout_limit, stdout_as_bytes, dump_stats)
    :return: the opened dump file for the "dump_tag" and "dump_seq_len" types, else None
    :raises Exception: if ``options.type`` is not a known dump type
    """
    dump_file = None
    if options.type == "numpy":
        print("Dump files: %r*%r" %
              (options.dump_prefix, options.dump_postfix),
              file=log.v3)
    elif options.type == "stdout":
        print("Dump to stdout", file=log.v3)
        if options.stdout_limit is not None:
            util.set_pretty_print_default_limit(options.stdout_limit)
            numpy.set_printoptions(
                threshold=sys.maxsize if options.stdout_limit ==
                float("inf") else int(options.stdout_limit))
        if options.stdout_as_bytes:
            util.set_pretty_print_as_bytes(options.stdout_as_bytes)
    elif options.type == "print_tag":
        print("Dump seq tag to stdout", file=log.v3)
    elif options.type == "dump_tag":
        dump_file = open("%sseq-tags.txt" % options.dump_prefix, "w")
        print("Dump seq tag to file: %s" % (dump_file.name, ), file=log.v3)
    elif options.type == "dump_seq_len":
        dump_file = open("%sseq-lens.txt" % options.dump_prefix, "w")
        print("Dump seq lens to file: %s" % (dump_file.name, ), file=log.v3)
        # The seq-lens file is written as a Python dict literal: {tag: len, ...}
        dump_file.write("{\n")
    elif options.type == "print_shape":
        print("Dump shape to stdout", file=log.v3)
    elif options.type == "plot":
        print("Plot.", file=log.v3)
    elif options.type == "interactive":
        print("Interactive debug shell.", file=log.v3)
    elif options.type == "null":
        if options.dump_stats:
            print("No dump (except stats).")
        else:
            print("No dump.")
    else:
        raise Exception("unknown dump option type %r" % options.type)
    return dump_file


def dump_dataset(dataset, options):
    """
    Iterate over the sequences of ``dataset`` and dump / print / plot them according to ``options.type``.

    Supported types: "numpy", "stdout", "print_tag", "dump_tag", "dump_seq_len",
    "print_shape", "plot", "interactive", "null".
    With ``options.get_num_seqs``, only the number of seqs is printed and nothing is dumped.
    Sequence-length statistics per data key are always collected and printed at the end;
    data statistics only with ``options.stats`` / ``options.dump_stats``.

    :type dataset: Dataset.Dataset
    :param options: argparse.Namespace.
      Note: a negative ``options.endseq`` is replaced in-place by ``float("inf")``.
    :raises Exception: if ``options.type`` is unknown
    """
    print("Epoch: %i" % options.epoch, file=log.v3)
    dataset.init_seq_order(epoch=options.epoch)
    print("Dataset keys:", dataset.get_data_keys(), file=log.v3)
    print("Dataset target keys:", dataset.get_target_list(), file=log.v3)
    assert options.key in dataset.get_data_keys()

    if options.get_num_seqs:
        _dump_dataset_print_num_seqs(dataset)
        return

    dump_file = _dump_dataset_open_output(options)
    try:
        start_time = time.time()
        stats = Stats() if (options.stats or options.dump_stats) else None
        seq_len_stats = {key: Stats() for key in dataset.get_data_keys()}
        seq_idx = options.startseq
        if options.endseq < 0:
            options.endseq = float("inf")
        while dataset.is_less_than_num_seqs(seq_idx) and seq_idx <= options.endseq:
            dataset.load_seqs(seq_idx, seq_idx + 1)
            complete_frac = dataset.get_complete_frac(seq_idx)
            start_elapsed = time.time() - start_time
            try:
                num_seqs_s = str(dataset.num_seqs)
            except NotImplementedError:
                try:
                    num_seqs_s = "~%i" % dataset.estimated_num_seqs
                except TypeError:  # a number is required, not NoneType
                    num_seqs_s = "?"
            progress_prefix = "%i/%s" % (seq_idx, num_seqs_s)
            progress = "%s (%.02f%%)" % (progress_prefix, complete_frac * 100)
            # Only the types below which actually load the data set this; others keep None.
            data = None
            if complete_frac > 0:
                total_time_estimated = start_elapsed / complete_frac
                remaining_estimated = total_time_estimated - start_elapsed
                progress += " (%s)" % hms(remaining_estimated)
            if options.type in ("print_tag", "dump_tag"):
                seq_tag = dataset.get_tag(seq_idx)
                print(
                    "seq %s tag:" %
                    (progress if log.verbose[2] else progress_prefix),
                    seq_tag)
                if options.type == "dump_tag":
                    dump_file.write("%s\n" % seq_tag)
            elif options.type == "dump_seq_len":
                seq_tag = dataset.get_tag(seq_idx)
                seq_len = dataset.get_seq_length(seq_idx)[options.key]
                print(
                    "seq %s tag:" %
                    (progress if log.verbose[2] else progress_prefix),
                    seq_tag, "%r len:" % options.key, seq_len)
                dump_file.write("%r: %r,\n" % (seq_tag, seq_len))
            else:
                data = dataset.get_data(seq_idx, options.key)
                if options.type == "numpy":
                    numpy.savetxt(
                        "%s%i.data%s" %
                        (options.dump_prefix, seq_idx, options.dump_postfix), data)
                elif options.type == "stdout":
                    print("seq %s tag:" % progress, dataset.get_tag(seq_idx))
                    print("seq %s data:" % progress, pretty_print(data))
                elif options.type == "print_shape":
                    print("seq %s data shape:" % progress, data.shape)
                elif options.type == "plot":
                    plot(data)
                for target in dataset.get_target_list():
                    targets = dataset.get_targets(target, seq_idx)
                    if options.type == "numpy":
                        numpy.savetxt("%s%i.targets.%s%s" %
                                      (options.dump_prefix, seq_idx, target,
                                       options.dump_postfix),
                                      targets,
                                      fmt='%i')
                    elif options.type == "stdout":
                        extra = ""
                        # Also show the human-readable form if the target has a label vocabulary.
                        if target in dataset.labels and len(
                                dataset.labels[target]) > 1:
                            assert dataset.can_serialize_data(target)
                            extra += " (%r)" % dataset.serialize_data(key=target,
                                                                      data=targets)
                        print("seq %i target %r: %s%s" %
                              (seq_idx, target, pretty_print(targets), extra))
                    elif options.type == "print_shape":
                        print("seq %i target %r shape:" % (seq_idx, target),
                              targets.shape)
                if options.type == "interactive":
                    from returnn.util.debug import debug_shell
                    debug_shell(locals())
            seq_len = dataset.get_seq_length(seq_idx)
            for key in dataset.get_data_keys():
                seq_len_stats[key].collect([seq_len[key]])
            # The tag/len dump types never load the data, so there is nothing to collect.
            if stats is not None and data is not None:
                stats.collect(data)
            if options.type == "null":
                util.progress_bar_with_time(complete_frac, prefix=progress_prefix)

            seq_idx += 1

        print("Done. Total time %s. More seqs which we did not dumped: %s" %
              (hms_fraction(time.time() - start_time),
               dataset.is_less_than_num_seqs(seq_idx)),
              file=log.v2)
        for key in dataset.get_data_keys():
            seq_len_stats[key].dump(stream_prefix="Seq-length %r " % key,
                                    stream=log.v2)
        if stats is not None:
            stats.dump(output_file_prefix=options.dump_stats,
                       stream_prefix="Data %r " % options.key,
                       stream=log.v1)
        if options.type == "dump_seq_len":
            # Close the dict literal opened in _dump_dataset_open_output.
            dump_file.write("}\n")
        if dump_file:
            print("Dumped to file:", dump_file.name, file=log.v2)
    finally:
        # Always release the dump file, also if iterating the dataset raised.
        if dump_file:
            dump_file.close()
Esempio n. 2
0
def execute_main_task():
    """
    Executes the main task (via config ``task`` option).

    Dispatches on ``config.value('task', 'train')`` and uses the module-level
    ``config``, ``engine``, ``train_data``, ``dev_data``, ``eval_data`` and ``server``.
    Prints the elapsed time at the end.

    :raises AssertionError: on an unknown task or missing required data/options
    """
    from returnn.util.basic import hms_fraction
    start_time = time.time()
    task = config.value('task', 'train')
    if config.is_true("dry_run"):
        print("Dry run, will not save anything.", file=log.v1)
    if task == 'train':
        assert train_data.have_seqs(), (
            "no train files specified, check train option: %s" % config.value('train', None))
        engine.init_train_from_config(config, train_data, dev_data, eval_data)
        engine.train()
    elif task == "eval":
        epoch = config.int("epoch", -1)
        load_epoch = config.int("load_epoch", -1)
        if epoch >= 0:
            assert (load_epoch < 0) or (load_epoch == epoch), "epoch and load_epoch have to match"
            engine.epoch = epoch
            config.set('load_epoch', engine.epoch)
        else:
            assert load_epoch >= 0, "specify epoch or load_epoch"
            engine.epoch = load_epoch
        engine.init_train_from_config(config, train_data, dev_data, eval_data)
        print("Evaluate epoch", engine.epoch, file=log.v4)
        engine.eval_model(
            output_file=config.value("eval_output_file", None),
            output_per_seq_file=config.value("eval_output_file_per_seq", None),
            loss_name=config.value("loss_name", None),
            output_per_seq_format=config.list("output_per_seq_format", ["score"]),
            output_per_seq_file_format=config.value("output_per_seq_file_format", "txt"))
    elif task in ['forward', 'hpx']:
        assert eval_data is not None, 'no eval data provided'
        combine_labels = config.value('combine_labels', '')
        engine.use_search_flag = config.bool("forward_use_search", False)
        if config.has("epoch"):
            config.set('load_epoch', config.int('epoch', 0))
        engine.init_network_from_config(config)
        output_file = config.value('output_file', 'dump-fwd-epoch-%i.hdf' % engine.epoch)
        engine.forward_to_hdf(data=eval_data,
                              output_file=output_file,
                              combine_labels=combine_labels,
                              batch_size=config.int('forward_batch_size', 0))
    elif task == "search":
        engine.use_search_flag = True
        engine.use_eval_flag = config.bool("search_do_eval", True)
        engine.init_network_from_config(config)
        search_data = config.value("search_data", "eval")
        if search_data in ["train", "dev", "eval"]:
            data = {"train": train_data, "dev": dev_data, "eval": eval_data}[search_data]
            assert data, "set search_data"
        else:
            # Otherwise search_data is a dataset definition in the config.
            data = init_dataset(config.opt_typed_value("search_data"))
        engine.search(
            data,
            do_eval=config.bool("search_do_eval", True),
            output_layer_names=config.typed_value("search_output_layer", "output"),
            output_file=config.value("search_output_file", ""),
            output_file_format=config.value("search_output_file_format", "txt"))
    elif task == 'compute_priors':
        assert train_data is not None, 'train data for priors should be provided'
        engine.init_network_from_config(config)
        engine.compute_priors(dataset=train_data, config=config)
    elif task == 'theano_graph':
        # noinspection PyPackageRequirements,PyUnresolvedReferences
        import theano.printing
        # noinspection PyPackageRequirements,PyUnresolvedReferences
        import theano.compile.io
        # noinspection PyPackageRequirements,PyUnresolvedReferences
        import theano.compile.function_module
        engine.start_epoch = 1
        engine.init_network_from_config(config)
        # NOTE(review): the prefix does not contain the graph task name, so with several
        # theano_graph.task entries each one overwrites the files of the previous one.
        prefix = config.value("theano_graph.prefix", "current") + ".task"
        for graph_task in config.list('theano_graph.task', ['train']):
            func = engine.devices[-1].get_compute_func(graph_task)
            print("dumping to %s.* ..." % prefix, file=log.v1)
            with open("%s.optimized_func.txt" % prefix, "w") as out:
                theano.printing.debugprint(func, file=out)
            assert isinstance(func.maker, theano.compile.function_module.FunctionMaker)
            for inp in func.maker.inputs:
                assert isinstance(inp, theano.compile.io.In)
                if inp.update:
                    with open("%s.unoptimized.var_%s_update.txt" % (prefix, inp.name), "w") as out:
                        theano.printing.debugprint(inp.update, file=out)
            theano.printing.pydotprint(func,
                                       format='png',
                                       var_with_name_simple=True,
                                       outfile="%s.png" % prefix)
    elif task == 'analyze':  # anything based on the network + Device
        statistics = config.list('statistics', None)
        engine.init_network_from_config(config)
        engine.analyze(data=eval_data or dev_data, statistics=statistics)
    elif task == "analyze_data":  # anything just based on the data
        analyze_data(config)
    elif task == "classify":
        assert eval_data is not None, 'no eval data provided'
        assert config.has('label_file'), 'no output file provided'
        label_file = config.value('label_file', '')
        engine.init_network_from_config(config)
        engine.classify(eval_data, label_file)
    elif task == "hyper_param_tuning":
        import returnn.tf.hyper_param_tuning
        tuner = returnn.tf.hyper_param_tuning.Optimization(config=config, train_data=train_data)
        tuner.work()
    elif task == "cleanup_old_models":
        engine.cleanup_old_models(ask_for_confirmation=True)
    elif task == "daemon":
        engine.init_network_from_config(config)
        engine.daemon(config)
    elif task == "server":
        print("Server Initiating", file=log.v1)
        server.run()
    elif task == "search_server":
        engine.use_search_flag = True
        engine.init_network_from_config(config)
        engine.web_server(port=config.int("web_server_port", 12380))
    elif task.startswith("config:"):
        # "config:<name>" calls the callable <name> defined in the config (must exist).
        action = config.typed_dict[task[len("config:"):]]
        print("Task: %r" % action, file=log.v1)
        assert callable(action)
        action()
    elif task.startswith("optional-config:"):
        # Like "config:", but it is fine if the config does not define it.
        action = config.typed_dict.get(task[len("optional-config:"):], None)
        if action is None:
            print("No task found for %r, so just quitting." % task, file=log.v1)
        else:
            print("Task: %r" % action, file=log.v1)
            assert callable(action)
            action()
    elif task == "nop":
        print("Task: No-operation", file=log.v1)
    elif task == "nop_init_net_train":
        print(
            "Task: No-operation, despite initializing the network (for training)",
            file=log.v1)
        engine.init_train_from_config(config, train_data, dev_data, eval_data)
    elif task == "initialize_model":
        engine.init_train_from_config(config, train_data, dev_data, eval_data)
        engine.save_model(config.value('model', 'dummy'))
    else:
        # Explicit raise (not `assert False`) so this also fails under `python -O`.
        raise AssertionError("unknown task: %s" % task)

    print(("elapsed: %s" % hms_fraction(time.time() - start_time)), file=log.v3)
Esempio n. 3
0
def main():
    """
    Main entry.

    Parses the command line (``opt=value`` overrides of ``base_settings``, plus
    --no-cpu / --no-gpu / --selected / --no-setup-tf-thread-pools), runs ``benchmark``
    for each selected LSTM cell type on GPU (if available and not disabled) and on CPU
    (skipping ``GpuOnlyCellTypes``), and prints the results sorted by runtime.
    """
    global LstmCellTypes
    print("Benchmarking LSTMs.")
    better_exchook.install()
    print("Args:", " ".join(sys.argv))
    arg_parser = ArgumentParser()
    arg_parser.add_argument("cfg",
                            nargs="*",
                            help="opt=value, opt in %r" %
                            sorted(base_settings.keys()))
    arg_parser.add_argument("--no-cpu", action="store_true")
    arg_parser.add_argument("--no-gpu", action="store_true")
    arg_parser.add_argument("--selected",
                            help="comma-separated list from %r" %
                            LstmCellTypes)
    arg_parser.add_argument("--no-setup-tf-thread-pools", action="store_true")
    args = arg_parser.parse_args()
    for opt in args.cfg:
        # Report malformed/unknown overrides as usage errors instead of bare tracebacks.
        if "=" not in opt:
            arg_parser.error("invalid cfg %r, expected opt=value" % opt)
        key, value = opt.split("=", 1)
        if key not in base_settings:
            arg_parser.error("unknown cfg opt %r, expected one of %r" % (key, sorted(base_settings.keys())))
        value_type = type(base_settings[key])
        if value_type is bool:
            # bool("False") would be True, so parse the usual spellings explicitly.
            bool_values = {"1": True, "true": True, "yes": True, "0": False, "false": False, "no": False}
            if value.lower() not in bool_values:
                arg_parser.error("invalid bool value %r for cfg opt %r" % (value, key))
            base_settings[key] = bool_values[value.lower()]
        else:
            base_settings[key] = value_type(value)
    print("Settings:")
    pprint(base_settings)

    log.initialize(verbosity=[4])
    print("Returnn:", describe_returnn_version(), file=log.v3)
    print("TensorFlow:", describe_tensorflow_version(), file=log.v3)
    print("Python:", sys.version.replace("\n", ""), sys.platform)
    if not args.no_setup_tf_thread_pools:
        setup_tf_thread_pools(log_file=log.v2)
    else:
        print(
            "Not setting up the TF thread pools. Will be done automatically by TF to number of CPU cores."
        )
    if args.no_gpu:
        print("GPU will not be used.")
    else:
        print("GPU available: %r" % is_gpu_available())
    print_available_devices()

    if args.selected:
        LstmCellTypes = args.selected.split(",")
    benchmarks = {}  # "<device>:<lstm_unit>" -> value returned by benchmark()
    if not args.no_gpu and is_gpu_available():
        for lstm_unit in LstmCellTypes:
            benchmarks["GPU:" + lstm_unit] = benchmark(lstm_unit=lstm_unit, use_gpu=True)
    if not args.no_cpu:
        for lstm_unit in LstmCellTypes:
            if lstm_unit in GpuOnlyCellTypes:
                continue
            benchmarks["CPU:" + lstm_unit] = benchmark(lstm_unit=lstm_unit, use_gpu=False)

    print("-" * 20)
    print("Settings:")
    pprint(base_settings)
    print("Final results:")
    # Fastest first; ties broken by name.
    for lstm_unit, t in sorted(benchmarks.items(), key=lambda item: (item[1], item[0])):
        print("  %s: %s" % (lstm_unit, hms_fraction(t)))
    print("Done.")