def test_enc_dec1_init():
    """
    Check that the enc-dec config can be turned into a network.

    Loads ``config_enc_dec1_json`` into a fresh config, derives the network JSON
    from it, and builds a ``LayerNetwork`` from that JSON.
    Both intermediate results must be truthy.
    """
    cfg = Config()
    cfg.load_file(StringIO(config_enc_dec1_json))

    net_json = LayerNetwork.json_from_config(cfg)
    assert_true(net_json)

    net = LayerNetwork.from_json_and_config(net_json, cfg)
    assert_true(net)
def create_first_epoch(config_filename):
    """
    Initialize an engine from the given config and save its model as epoch 1.

    :param str config_filename: config file to load
    """
    cfg = Config()
    cfg.load_file(config_filename)

    trainer = Engine([])
    trainer.init_train_from_config(config=cfg, train_data=None)
    trainer.epoch = 1

    model_path = trainer.get_epoch_model_filename()
    trainer.save_model(model_path, epoch=trainer.epoch)

    # Reset the class-level attribute after saving.
    # NOTE(review): presumably a cached epoch/model lookup which would be stale now -- confirm in Engine.
    Engine._epoch_model = None
def is_crnn_config(filename):
    """
    Probe whether the given file can be loaded as a config.

    :param str filename: path to check
    :return: False for any ``*.gz`` name or if loading raises, True if loading succeeds
    :rtype: bool
    """
    if filename.endswith(".gz"):
        return False
    # Best-effort probe: any failure while loading just means "not a config".
    # noinspection PyBroadException
    try:
        Config().load_file(filename)
    except Exception:
        return False
    return True
def cleanup_tmp_models(config_filename):
    """
    Delete all files matching ``<model>.*``, where ``<model>`` is the ``model`` option of the config.

    The checks are raised explicitly instead of using ``assert`` statements,
    because ``assert`` is stripped under ``python -O``,
    and the ``/tmp/`` check is the only thing which protects against deleting files elsewhere.
    The exception type stays ``AssertionError``, as before.

    :param str config_filename: path to an existing config file
    :raises AssertionError: if the config file does not exist, if the config has no ``model`` option,
      or if the model filename does not start with ``/tmp/``
    """
    if not os.path.exists(config_filename):
        raise AssertionError("config file not found: %r" % (config_filename,))
    from returnn.config import Config
    config = Config()
    config.load_file(config_filename)
    model_filename = config.value('model', '')
    if not model_filename:
        raise AssertionError("no 'model' option in config %r" % (config_filename,))
    # Remove existing models
    # Safety guard: never delete anything outside of /tmp/.
    if not model_filename.startswith("/tmp/"):
        raise AssertionError("refusing to remove models outside of /tmp/: %r" % (model_filename,))
    for f in glob(model_filename + ".*"):
        os.remove(f)
def is_returnn_config(filename):
    """
    Probe whether the given file can be loaded as a config.

    :param str filename: path to check
    :return: False for any ``*.gz`` name or if loading raises, True if loading succeeds
    :rtype: bool
    """
    if filename.endswith(".gz"):
        return False
    # Best-effort probe: any failure while loading just means "not a config".
    # noinspection PyBroadException
    try:
        Config().load_file(filename)
    except Exception:
        return False
    return True
def main():
    """
    Command line entry point: load Sprint NN params and save them as a RETURNN model epoch.

    Parses the command line, checks that the config file and the archiver executable exist,
    loads the config into the global ``config``, reads ``num_inputs`` / ``num_outputs`` from it
    into the globals ``inputDim`` / ``outputDim``, then loads the Sprint network
    and saves it for the requested epoch.
    """
    global configFile, archiverExec, inputDim, outputDim, config

    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        '--sprintLoadParams', required=True,
        help='Sprint NN params path prefix')
    arg_parser.add_argument(
        '--sprintFirstLayer', default=1, type=int,
        help='Sprint NN params first layer (default 1)')
    arg_parser.add_argument(
        '--crnnSaveEpoch', type=int, required=True,
        help='save this train epoch number in RETURNN model')
    arg_parser.add_argument(
        '--crnnConfigFile', required=True,
        help='RETURNN (CRNN) config file')
    # The current value of the global is used as the default.
    arg_parser.add_argument(
        '--sprintArchiverExec', default=archiverExec,
        help='path to Sprint/RASR archiver executable')
    arg_parser.add_argument(
        '--floatType', default="f32",
        help='float type (f32/f64)')
    opts = arg_parser.parse_args()

    configFile = opts.crnnConfigFile
    assert os.path.exists(configFile), "RETURNN config file not found"
    archiverExec = opts.sprintArchiverExec
    assert os.path.exists(archiverExec), "Sprint archiver not found"
    assert opts.crnnSaveEpoch >= 1

    from returnn.config import Config
    config = Config()
    config.load_file(configFile)

    inputDim = config.int('num_inputs', None)
    outputDim = config.int('num_outputs', None)
    # Both dims must be set (and non-zero) in the config.
    assert inputDim and outputDim

    sprint_layers = loadSprintNetwork(
        params_prefix_path=opts.sprintLoadParams,
        first_layer=opts.sprintFirstLayer,
        float_type=opts.floatType)
    saveCrnnNetwork(epoch=opts.crnnSaveEpoch, layers=sprint_layers)
    print("Done.")
def test_enc_dec1_hdf():
    """
    Check that the network topology survives a save to / load from HDF.

    Builds the enc-dec network from ``config_enc_dec1_json``, saves it into a temporary HDF file,
    loads the topology back and compares the (sorted) keys of ``hidden``, ``y`` and ``j``.

    The temporary file is created via ``tempfile.mkstemp`` (``mktemp`` is deprecated and race-prone),
    both HDF handles are closed via context managers,
    and the file is removed even when an assertion fails.
    """
    fd, filename = tempfile.mkstemp(prefix="returnn-model-test")
    os.close(fd)  # only the name is needed; h5py reopens (and truncates) the file itself
    try:
        config = Config()
        config.load_file(StringIO(config_enc_dec1_json))
        network_json = LayerNetwork.json_from_config(config)
        assert_true(network_json)
        network = LayerNetwork.from_json_and_config(network_json, config)
        assert_true(network)

        with h5py.File(filename, "w") as model:
            network.save_hdf(model, epoch=42)

        # Do all comparisons while the file is still open, in case the loaded net refers to it.
        with h5py.File(filename, "r") as loaded_model:
            loaded_net = LayerNetwork.from_hdf_model_topology(loaded_model)
            assert_true(loaded_net)
            assert_equal(sorted(network.hidden.keys()), sorted(loaded_net.hidden.keys()))
            assert_equal(sorted(network.y.keys()), sorted(loaded_net.y.keys()))
            assert_equal(sorted(network.j.keys()), sorted(loaded_net.j.keys()))
    finally:
        os.remove(filename)
def init_config(config_filename=None, command_line_options=(), default_config=None, extra_updates=None):
    """
    Initializes the global config.

    :param str|None config_filename: config file to load, if given
    :param list[str]|tuple[str] command_line_options: e.g. ``sys.argv[1:]``.
      All leading entries before the first one starting with ``-`` or ``+``
      are interpreted as further config filenames; the rest are options.
    :param dict[str]|None default_config: applied first, before anything else
    :param dict[str]|None extra_updates: applied before and again after loading the config files

    The sources are applied in this order:

    * ``default_config``
    * ``extra_updates``
    * options from ``command_line_options``
    * ``config_filename``
    * config filenames from the leading entries of ``command_line_options``
    * ``extra_updates``
    * options from ``command_line_options``

    ``extra_updates`` and the command line options are applied twice:
    the first time so that they are already available while the config files are loaded
    (which can e.g. use them via Python code),
    the second time so that they overwrite any option set by the config files.
    The command line options come after ``extra_updates``
    so that the user can still overwrite anything set by ``extra_updates``.
    """
    global config
    config = Config()

    config_filenames_by_cmd_line = []
    if command_line_options:
        # Assume that the first argument prefixed with "+" or "-" and all following is not a config file.
        first_option_idx = next(
            (idx for idx, arg in enumerate(command_line_options) if arg[:1] in "-+"),
            len(command_line_options))
        config_filenames_by_cmd_line = list(command_line_options[:first_option_idx])
        command_line_options = command_line_options[first_option_idx:]

    def _apply_overrides():
        """Apply ``extra_updates`` and then the remaining command line options."""
        if extra_updates:
            config.update(extra_updates)
        if command_line_options:
            config.parse_cmd_args(command_line_options)

    if default_config:
        config.update(default_config)
    _apply_overrides()

    files_to_load = ([config_filename] if config_filename else []) + config_filenames_by_cmd_line
    for fn in files_to_load:
        config.load_file(fn)

    _apply_overrides()

    # I really don't know where to put this otherwise:
    if config.bool("EnableAutoNumpySharedMemPickling", False):
        import returnn.util.task_system
        returnn.util.task_system.SharedMemNumpyConfig["enabled"] = True

    # Server default options
    if config.value('task', 'train') == 'server':
        config.set('num_inputs', 2)
        config.set('num_outputs', 1)
def demo():
    """
    Demo.

    Loads the ``train`` dataset from the given config and a ``SprintCacheDataset`` from the given kwargs,
    checks that their input/output dims agree, and then, depending on ``--action``:

    * ``compare``: iterates through the config dataset and asserts that each seq
      has the same shape and (close to) the same values as the seq with the same tag in the cache dataset.
    * ``benchmark``: measures the time to iterate through the config dataset,
      and then the time to fetch the same seq tags from the cache dataset.

    Any other action raises an exception (only after both datasets were created).
    """
    print("SprintDataset demo.")
    from argparse import ArgumentParser
    from returnn.util.basic import progress_bar_with_time
    from returnn.log import log
    from returnn.config import Config
    from returnn.datasets.basic import init_dataset

    parser = ArgumentParser()
    parser.add_argument("--config", help="config with ExternSprintDataset", required=True)
    parser.add_argument("--sprint_cache_dataset", help="kwargs dict for SprintCacheDataset", required=True)
    parser.add_argument("--max_num_seqs", default=sys.maxsize, type=int)
    parser.add_argument("--action", default="compare", help="compare or benchmark")
    opts = parser.parse_args()
    log.initialize(verbosity=[4])

    # NOTE(review): eval() on a command line argument executes arbitrary Python code.
    # Acceptable for a local demo tool only; never feed this untrusted input.
    cache_kwargs = eval(opts.sprint_cache_dataset)
    assert isinstance(cache_kwargs, dict)
    sprint_cache_dataset = SprintCacheDataset(**cache_kwargs)
    print("SprintCacheDataset: %r" % sprint_cache_dataset)

    cfg = Config()
    cfg.load_file(opts.config)
    dataset = init_dataset(cfg.typed_value("train"))
    print("Dataset via config: %r" % dataset)

    assert sprint_cache_dataset.num_inputs == dataset.num_inputs
    assert tuple(sprint_cache_dataset.num_outputs["classes"]) == tuple(dataset.num_outputs["classes"])
    sprint_cache_dataset.init_seq_order(epoch=1)

    if opts.action == "compare":
        print("Iterating through dataset...")
        seq_idx = 0
        dataset.init_seq_order(epoch=1)
        while seq_idx < opts.max_num_seqs and dataset.is_less_than_num_seqs(seq_idx):
            dataset.load_seqs(seq_idx, seq_idx + 1)
            tag = dataset.get_tag(seq_idx)
            assert not tag.startswith("seq-"), "dataset does not provide tag-names for seqs"
            ref_seq = sprint_cache_dataset.get_dataset_seq_for_name(tag)
            # Fetch both streams first, then check all shapes, then all values.
            stream_keys = ("data", "classes")
            actual = {key: dataset.get_data(seq_idx, key) for key in stream_keys}
            for key in stream_keys:
                assert actual[key].shape == ref_seq.features[key].shape
            for key in stream_keys:
                assert numpy.allclose(actual[key], ref_seq.features[key])
            seq_idx += 1
            progress_bar_with_time(dataset.get_complete_frac(seq_idx))
        print("Finished through dataset. Num seqs: %i" % seq_idx)
        print("SprintCacheDataset has num seqs: %i." % sprint_cache_dataset.num_seqs)

    elif opts.action == "benchmark":
        print("Iterating through dataset...")
        start_time = time.time()
        seq_tags = []
        seq_idx = 0
        dataset.init_seq_order(epoch=1)
        while seq_idx < opts.max_num_seqs and dataset.is_less_than_num_seqs(seq_idx):
            dataset.load_seqs(seq_idx, seq_idx + 1)
            tag = dataset.get_tag(seq_idx)
            assert not tag.startswith("seq-"), "dataset does not provide tag-names for seqs"
            seq_tags.append(tag)
            # Only fetch the data; the results are not needed for the benchmark.
            dataset.get_data(seq_idx, "data")
            dataset.get_data(seq_idx, "classes")
            seq_idx += 1
            progress_bar_with_time(dataset.get_complete_frac(seq_idx))
        elapsed = time.time() - start_time
        print("Finished through dataset. Num seqs: %i, time: %f" % (seq_idx, elapsed))
        print("SprintCacheDataset has num seqs: %i." % sprint_cache_dataset.num_seqs)

        if not hasattr(dataset, "exit_handler"):
            print("No way to stop any background tasks.")
        else:
            dataset.exit_handler()
        del dataset

        start_time = time.time()
        print("Iterating through SprintCacheDataset...")
        for i, tag in enumerate(seq_tags):
            sprint_cache_dataset.get_dataset_seq_for_name(tag)
            progress_bar_with_time(float(i) / len(seq_tags))
        print("Finished through SprintCacheDataset. time: %f" % (time.time() - start_time,))

    else:
        raise Exception("invalid action: %r" % opts.action)
def main():
    """
    Main entry point.

    Loads the per-epoch scores / errors, either via a RETURNN config (``--config``)
    or directly from a learning rate file (``--learning-rate-file``); exactly one of both must be given.
    Then prints the ``--n`` epochs with the lowest value for the selected key.
    If ``--key`` is not given, the key is chosen automatically
    via the last epoch which has any error info.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", help="RETURNN config")
    parser.add_argument(
        "--learning-rate-file",
        help="The learning rate file contains scores / errors per epoch.")
    parser.add_argument("--key", help="key to use, e.g. 'dev_error'")
    parser.add_argument("--n", type=int, default=5, help="print best N epochs")
    opts = parser.parse_args()

    # Exactly one of both sources must be given.
    if bool(opts.config) == bool(opts.learning_rate_file):
        print("Error: provide either --config or --learning-rate-file")
        parser.print_help()
        sys.exit(1)

    if opts.config:
        cfg = Config()
        cfg.load_file(opts.config)
        lr = LearningRateControl.load_initial_from_config(cfg)
    elif opts.learning_rate_file:
        lr = LearningRateControl(
            default_learning_rate=1,  # default lr not relevant
            filename=opts.learning_rate_file)
    else:
        assert False, "should not get here with %r" % opts

    epochs = sorted(lr.epoch_data.keys())
    if not epochs:
        print("Error: no epochs found")
        sys.exit(1)
    print("Loaded epochs", epochs[0], "..", epochs[-1])

    if args_key := None:  # placeholder never taken; keeps the branches below flat
        pass
    if opts.key:
        key = opts.key
        print("Using key %s" % key)
    else:
        # Search backwards for the most recent epoch which has any error info.
        last_epoch_with_error_info = next(
            (ep for ep in reversed(epochs) if lr.epoch_data[ep].error), None)
        if last_epoch_with_error_info is None:
            print("Error: no scores/errors found")
            sys.exit(1)
        key = lr.get_error_key(last_epoch_with_error_info)
        print("Using key %s (auto via epoch %i)" % (key, last_epoch_with_error_info))

    scored_epochs = []  # list of (value, epoch)
    missing_epochs = []
    for ep in epochs:
        ep_errors = lr.epoch_data[ep].error
        if key not in ep_errors:
            missing_epochs.append(ep)
            continue
        scored_epochs.append((ep_errors[key], ep))

    if missing_epochs:
        print("Epochs missing the key:", missing_epochs)
    else:
        print("All epochs have the key.")
    assert scored_epochs

    # Ascending by value (ties by epoch), i.e. lowest score / error first.
    scored_epochs.sort()
    for value, ep in scored_epochs[:opts.n]:
        ep_errors = lr.epoch_data[ep].error
        fields = ["Epoch %i" % ep, "%r %r" % (key, value)]
        fields += ["%r %r" % (k, v) for k, v in ep_errors.items() if k != key]
        print(", ".join(fields))