# Example #1
# 0
def test_assign_dev_data():
  """
  Check that ``assign_dev_data`` accepts two generated batches from a small
  dummy dataset, reports success, and reports one assigned batch per input batch.
  """
  cfg = Config()
  cfg.update(dummyconfig_dict)
  dev = DummyDevice(config=cfg)
  # Dimensions come from the dummy config; 0 is used if a key is missing.
  n_in = cfg.int("num_inputs", 0)
  n_out = cfg.int("num_outputs", 0)
  data = DummyDataset(input_dim=n_in, output_dim=n_out, num_seqs=10)
  batch_list = [generate_batch(idx, data) for idx in range(2)]
  ok, assigned = assign_dev_data(dev, data, batch_list)
  assert_true(ok)
  assert_equal(assigned, len(batch_list))
# Example #2
# 0
def main():
    """
    Command-line entry point: convert Sprint NN params into a saved RETURNN model.

    Parses the command-line arguments, loads the RETURNN config file, reads the
    Sprint layer parameters via ``loadSprintNetwork`` and saves them via
    ``saveCrnnNetwork`` for the requested epoch.

    Side effects: assigns the module-level globals ``configFile``,
    ``archiverExec``, ``inputDim``, ``outputDim`` and ``config``
    (NOTE(review): presumably read by the load/save helpers — confirm).

    Invalid input is reported via ``parser.error`` (prints usage and exits with
    status 2) instead of ``assert``, so the checks still run under ``python -O``.
    """
    global configFile, archiverExec, inputDim, outputDim, config
    parser = argparse.ArgumentParser()
    parser.add_argument('--sprintLoadParams',
                        required=True,
                        help='Sprint NN params path prefix')
    parser.add_argument('--sprintFirstLayer',
                        default=1,
                        type=int,
                        help='Sprint NN params first layer (default 1)')
    parser.add_argument('--crnnSaveEpoch',
                        type=int,
                        required=True,
                        help='save this train epoch number in RETURNN model')
    parser.add_argument('--crnnConfigFile',
                        required=True,
                        help='RETURNN (CRNN) config file')
    parser.add_argument('--sprintArchiverExec',
                        default=archiverExec,
                        help='path to Sprint/RASR archiver executable')
    parser.add_argument('--floatType',
                        default="f32",
                        help='float type (f32/f64)')
    args = parser.parse_args()

    configFile = args.crnnConfigFile
    if not os.path.exists(configFile):
        parser.error("RETURNN config file not found: %s" % configFile)
    archiverExec = args.sprintArchiverExec
    if not os.path.exists(archiverExec):
        parser.error("Sprint archiver not found: %s" % archiverExec)
    if args.crnnSaveEpoch < 1:
        parser.error("--crnnSaveEpoch must be >= 1, got %i" % args.crnnSaveEpoch)

    # Imported lazily, only after argument validation succeeded.
    from returnn.config import Config
    config = Config()
    config.load_file(configFile)

    inputDim = config.int('num_inputs', None)
    outputDim = config.int('num_outputs', None)
    # Both must be set and non-zero in the RETURNN config.
    if not (inputDim and outputDim):
        parser.error("RETURNN config %s must define non-zero num_inputs and num_outputs"
                     % configFile)

    layers = loadSprintNetwork(params_prefix_path=args.sprintLoadParams,
                               first_layer=args.sprintFirstLayer,
                               float_type=args.floatType)
    saveCrnnNetwork(epoch=args.crnnSaveEpoch, layers=layers)

    print("Done.")
# Example #3
# 0
def init_config(config_filename=None,
                command_line_options=(),
                default_config=None,
                extra_updates=None):
    """
    Initialize the global ``config``.

    :param str|None config_filename:
    :param list[str]|tuple[str] command_line_options: e.g. ``sys.argv[1:]``.
        All leading entries that do not start with ``-`` or ``+`` are treated
        as config filenames; the first entry starting with ``-`` or ``+`` and
        everything after it are parsed as options.
    :param dict[str]|None default_config:
    :param dict[str]|None extra_updates:

    The config is built from these sources, applied in this order, so that
    later sources take priority over earlier ones:

      * ``default_config``
      * ``extra_updates``
      * options from ``command_line_options``
      * ``config_filename``
      * config filenames from the start of ``command_line_options``
      * ``extra_updates``
      * options from ``command_line_options``

    ``extra_updates`` and the command-line options are applied twice. The first
    pass makes them available while the config files are loaded, so the configs
    can access them, e.g. from Python code. The second pass makes them overwrite
    any option set by the config files.

    The command-line options are applied after ``extra_updates`` so that the
    user can still overwrite anything set by ``extra_updates``.
    """
    global config
    config = Config()

    config_filenames_by_cmd_line = []
    if command_line_options:
        # Split at the first argument starting with "-" or "+": everything
        # before it is a config filename, everything from it on is an option.
        # NOTE: an empty-string argument also counts as an option start,
        # because ``""[:1] in "-+"`` is True.
        num_filenames = next(
            (i for i, arg in enumerate(command_line_options) if arg[:1] in "-+"),
            len(command_line_options))
        config_filenames_by_cmd_line = list(command_line_options[:num_filenames])
        command_line_options = command_line_options[num_filenames:]

    # First pass: defaults plus early overrides, visible while loading configs.
    if default_config:
        config.update(default_config)
    if extra_updates:
        config.update(extra_updates)
    if command_line_options:
        config.parse_cmd_args(command_line_options)
    # Config files.
    if config_filename:
        config.load_file(config_filename)
    for fn in config_filenames_by_cmd_line:
        config.load_file(fn)
    # Second pass: re-apply overrides so they win over the config files.
    if extra_updates:
        config.update(extra_updates)
    if command_line_options:
        config.parse_cmd_args(command_line_options)

    # Global side setting that depends on the final config.
    if config.bool("EnableAutoNumpySharedMemPickling", False):
        import returnn.util.task_system
        returnn.util.task_system.SharedMemNumpyConfig["enabled"] = True
    # Server default options.
    if config.value('task', 'train') == 'server':
        config.set('num_inputs', 2)
        config.set('num_outputs', 1)

    BehaviorVersion.set(config.int('behavior_version', None))