def main(argv):
  """
  Command-line entry point: dump search scores and other info to an HDF file.

  Parses the options, initializes via ``init``, loads the dataset, optionally restricts it
  to the given sequence tags, and runs search with a network dict which was passed through
  ``prepare_compile``. Finally checks that the HDF output file exists.

  :param list[str] argv: full argument vector. ``argv[0]`` is skipped.
    Arguments which are not known to this parser are forwarded to ``init`` as ``remaining_args``.
  """
  arg_parser = argparse.ArgumentParser(description='Dump search scores and other info to HDF file.')
  arg_parser.add_argument('config', help="filename to config-file")
  arg_parser.add_argument("--dataset", default="config:train")
  arg_parser.add_argument("--epoch", type=int, default=-1, help="-1 for last epoch")
  arg_parser.add_argument("--output_file", help='hdf', required=True)
  arg_parser.add_argument("--rec_layer_name", default="output")
  arg_parser.add_argument("--cheating", action="store_true", help="add ground truth to the beam")
  arg_parser.add_argument("--att_weights", action="store_true", help="dump all softmax_over_spatial layers")
  arg_parser.add_argument("--verbosity", default=4, type=int, help="5 for all seqs (default: 4)")
  arg_parser.add_argument("--seq_list", nargs="+", help="use only these seqs")
  args, remaining_args = arg_parser.parse_known_args(argv[1:])
  init(config_filename=args.config, log_verbosity=args.verbosity, remaining_args=remaining_args)

  dataset = init_dataset(args.dataset)
  print("Dataset:")
  pprint(dataset)
  if args.seq_list:
    dataset.seq_tags_filter = set(args.seq_list)
    dataset.partition_epoch = 1  # reset
    if isinstance(dataset, MetaDataset):
      for sub_dataset in dataset.datasets.values():
        # Set the filter on the sub dataset itself.
        # Previously this assigned to the outer dataset again, so the sub datasets never got the filter.
        sub_dataset.seq_tags_filter = set(args.seq_list)
        sub_dataset.partition_epoch = 1
    dataset.finish_epoch()  # enforce reset
  if dataset.seq_tags_filter is not None:
    print("Using sequences:")
    pprint(dataset.seq_tags_filter)
  if args.epoch >= 1:
    # Only override for an explicit epoch; the default -1 leaves "load_epoch" untouched.
    config.set("load_epoch", args.epoch)

  def net_dict_post_proc(net_dict):
    """
    Hands the net dict to ``prepare_compile`` together with the command-line options.

    :param dict[str] net_dict:
    :return: net_dict
    :rtype: dict[str]
    """
    prepare_compile(
      rec_layer_name=args.rec_layer_name, net_dict=net_dict,
      cheating=args.cheating, dump_att_weights=args.att_weights,
      hdf_filename=args.output_file, possible_labels=dataset.labels)
    return net_dict

  engine = Engine(config=config)
  engine.use_search_flag = True
  engine.init_network_from_config(config, net_dict_post_proc=net_dict_post_proc)
  engine.search(
    dataset,
    do_eval=config.bool("search_do_eval", True),
    output_layer_names=args.rec_layer_name)
  engine.finalize()
  print("Search finished.")
  # Sanity check that the search actually produced the requested output.
  assert os.path.exists(args.output_file), "hdf file not dumped?"
# Example no. 2
# 0
def benchmark(lstm_unit, use_gpu):
  """
  Runs one training with the given LSTM unit on the selected device and times it.

  :param str lstm_unit: e.g. "LSTMBlock", one of LstmCellTypes
  :param bool use_gpu: selects "GPU" or "CPU" for the benchmark key, and is passed on to ``make_config_dict``
  :return: wall-clock seconds spent inside ``engine.train()`` only; all setup before it is not counted
  :rtype: float
  """
  device_names = {True: "GPU", False: "CPU"}
  key = "%s:%s" % (device_names[use_gpu], lstm_unit)
  print(">>> Start benchmark for %s." % key)

  # Build the config, then derive the training dataset from its "train" entry.
  config = Config()
  config.update(make_config_dict(lstm_unit=lstm_unit, use_gpu=use_gpu))
  train_opts = config.typed_value("train")
  Dataset.kwargs_update_from_config(config, train_opts)
  train_data = init_dataset(train_opts)

  engine = Engine(config=config)
  engine.init_train_from_config(config=config, train_data=train_data)

  # Only the train() call itself is measured.
  print(">>> Start training now for %s." % key)
  t0 = time.time()
  engine.train()
  elapsed = time.time() - t0
  print(">>> Runtime of %s: %s" % (key, hms_fraction(elapsed)))

  engine.finalize()
  return elapsed
# Example no. 3
# 0
def benchmark(lstm_unit, use_gpu):
  """
  Benchmarks a single training run for one LSTM unit / device combination.

  :param str lstm_unit: e.g. "LSTMBlock", one of LstmCellTypes
  :param bool use_gpu: picks the device label ("GPU" or "CPU") and is forwarded to ``make_config_dict``
  :return: seconds elapsed during ``engine.train()``, measured via ``time.time()``; initialization is excluded
  :rtype: float
  """
  device_label = {True: "GPU", False: "CPU"}[use_gpu]
  bench_name = "%s:%s" % (device_label, lstm_unit)
  print(">>> Start benchmark for %s." % bench_name)

  # Config first; the dataset options are read from it and completed from it.
  config = Config()
  config_dict = make_config_dict(lstm_unit=lstm_unit, use_gpu=use_gpu)
  config.update(config_dict)
  dataset_opts = config.typed_value("train")
  Dataset.kwargs_update_from_config(config, dataset_opts)
  train_dataset = init_dataset(dataset_opts)

  engine = Engine(config=config)
  engine.init_train_from_config(config=config, train_data=train_dataset)

  print(">>> Start training now for %s." % bench_name)
  train_begin = time.time()
  engine.train()
  train_seconds = time.time() - train_begin
  print(">>> Runtime of %s: %s" % (bench_name, hms_fraction(train_seconds)))

  engine.finalize()
  return train_seconds