示例#1
0
 def _init_network(self, net_desc, epoch=None):
   """
   Build a new network from ``net_desc`` inside a fresh TF graph and session.

   The current TF session is closed and the graph is reset before anything is built. The
   resulting network is stored in ``self.network``. When ``self.train_data`` is set, a new
   ``Updater`` is created for it as well.

   :param net_desc: network description, passed to ``TFNetwork.construct_from_dict`` and kept
     as ``network.layers_desc``
   :param int|None epoch: used as ``rnd_seed`` of the network; ``None`` means ``self.epoch``
   """
   epoch = self.epoch if epoch is None else epoch
   self._close_tf_session()
   self._reset_graph()
   # Creating the session only after the graph reset makes it pick up the new default graph.
   self._make_tf_session()
   tf.set_random_seed(42)  # fixed graph-level seed
   from TFUtil import get_global_train_flag_placeholder
   train_flag = get_global_train_flag_placeholder() if self.use_dynamic_train_flag else False
   # TODO: build an ExternData via ExternData().init_from_config(self.config) and make use of it.
   net = TFNetwork(
     name="root", config=self.config, rnd_seed=epoch,
     train_flag=train_flag, search_flag=self.use_search_flag)
   net.construct_from_dict(net_desc)
   net.initialize_params(session=self.tf_session)
   net.layers_desc = net_desc
   self.network = net
   if self.train_data:
     # A fresh Updater is required: it owns the learning_rate variable, which has to live in the new graph.
     self.updater = updater = Updater(config=self.config, tf_session=self.tf_session, network=net)
     updater.set_trainable_vars(net.get_trainable_params())
   net.print_network_info()
示例#2
0
def main(argv):
    """
    Command-line entry point: build the TF graph for the network of a config file and optionally
    store it, together with line-based lists of model parameter / state variable names.

    :param list[str] argv: full argument vector; ``argv[0]`` (the program name) is skipped
    """
    argparser = argparse.ArgumentParser(description='Compile some op')
    argparser.add_argument('config', help="filename to config-file")
    # Validate the mode flags via argparse ``choices`` rather than an ``assert``
    # (asserts are stripped when running under ``python -O``).
    argparser.add_argument('--train',
                           type=int,
                           default=0,
                           choices=[0, 1, -1],
                           help='0 disable (default), 1 enable, -1 dynamic')
    argparser.add_argument(
        '--eval',
        type=int,
        default=0,
        choices=[0, 1],
        help='calculate losses. 0 disable (default), 1 enable')
    argparser.add_argument('--search',
                           type=int,
                           default=0,
                           choices=[0, 1],
                           help='beam search. 0 disable (default), 1 enable')
    argparser.add_argument("--verbosity",
                           default=4,
                           type=int,
                           help="5 for all seqs (default: 4)")
    argparser.add_argument("--summaries_tensor_name",
                           help="create Tensor for tf.summary.merge_all()")
    argparser.add_argument(
        "--rec_step_by_step",
        help="make step-by-step graph for this rec layer (eg. 'output')")
    argparser.add_argument("--rec_step_by_step_output_file",
                           help="store meta info for rec_step_by_step (JSON)")
    argparser.add_argument(
        "--output_file",
        help='allowed extensions: pb, pbtxt, meta, metatxt, logdir')
    argparser.add_argument("--output_file_model_params_list",
                           help="line-based, names of model params")
    argparser.add_argument("--output_file_state_vars_list",
                           help="line-based, name of state vars")
    args = argparser.parse_args(argv[1:])
    # Extension of --output_file (e.g. ".pb"), or None when no output file was requested.
    output_ext = os.path.splitext(args.output_file)[1] if args.output_file else None

    def _write_var_names(filename, variables):
        """Write each variable's name, one per line, without the ":0" output-index suffix."""
        with open(filename, "w") as f:
            for var in variables:
                assert var.name[-2:] == ":0"
                f.write("%s\n" % var.name[:-2])

    init(config_filename=args.config, log_verbosity=args.verbosity)
    assert 'network' in config.typed_dict
    net_dict = config.typed_dict["network"]
    if args.rec_step_by_step:
        RecStepByStepLayer.prepare_compile(
            rec_layer_name=args.rec_step_by_step, net_dict=net_dict)
    with tf.Graph().as_default() as graph:
        assert isinstance(graph, tf.Graph)
        print("Create graph...")
        # See :func:`Engine._init_network`.
        tf.set_random_seed(42)
        if args.train < 0:
            from TFUtil import get_global_train_flag_placeholder
            train_flag = get_global_train_flag_placeholder()
        else:
            train_flag = bool(args.train)
        eval_flag = bool(args.eval)
        search_flag = bool(args.search)
        network = create_graph(train_flag=train_flag,
                               eval_flag=eval_flag,
                               search_flag=search_flag,
                               net_dict=net_dict)

        if args.rec_step_by_step:
            RecStepByStepLayer.post_compile(
                rec_layer_name=args.rec_step_by_step,
                network=network,
                output_file_name=args.rec_step_by_step_output_file)

        from TFNetworkLayer import LayerBase
        for layer in network.layers.values():
            assert isinstance(layer, LayerBase)
            if layer.output.time_dim_axis is None:
                continue
            # For every layer output with a time axis, add a batch-major copy named "output_batch_major".
            with layer.cls_layer_scope(layer.name):
                tf.identity(layer.output.get_placeholder_as_batch_major(),
                            name="output_batch_major")

        tf.group(*network.get_post_control_dependencies(),
                 name="post_control_dependencies")

        # Do some cleanup of collections which do not contain tensors or operations,
        # because the tf.train.import_meta_graph code might fail otherwise.
        tf.get_collection_ref(CollectionKeys.RETURNN_LAYERS).clear()

        if args.summaries_tensor_name:
            summaries_tensor = tf.summary.merge_all()
            assert isinstance(summaries_tensor,
                              tf.Tensor), "no summaries in the graph?"
            tf.identity(summaries_tensor, name=args.summaries_tensor_name)

        if output_ext in (".meta", ".metatxt"):
            # https://www.tensorflow.org/api_guides/python/meta_graph
            saver = tf.train.Saver(var_list=network.get_saveable_params_list(),
                                   max_to_keep=2**31 - 1)
            graph_def = saver.export_meta_graph()
        else:
            graph_def = graph.as_graph_def(add_shapes=True)

        print("Graph collection keys:", graph.get_all_collection_keys())
        print("Graph num operations:", len(graph.get_operations()))
        print("Graph def size:", Util.human_bytes_size(graph_def.ByteSize()))

        if args.output_file:
            filename = args.output_file
            if output_ext == ".logdir":
                print("Write TF events to logdir:", filename)
                writer = tf.summary.FileWriter(logdir=filename)
                try:
                    writer.add_graph(graph)
                finally:
                    # close() flushes pending events and releases the event file.
                    writer.close()
            else:
                assert output_ext in [".pb", ".pbtxt", ".meta", ".metatxt"
                                      ], 'filename %r extension invalid' % filename
                print("Write graph to file:", filename)
                graph_io.write_graph(graph_def,
                                     logdir=os.path.dirname(filename),
                                     name=os.path.basename(filename),
                                     as_text=output_ext.endswith("txt"))
        else:
            print("Use --output_file if you want to store the graph.")

        if args.output_file_model_params_list:
            print("Write model param list to:",
                  args.output_file_model_params_list)
            _write_var_names(args.output_file_model_params_list,
                             network.get_params_list())

        if args.output_file_state_vars_list:
            print("Write state var list to:", args.output_file_state_vars_list)
            _write_var_names(args.output_file_state_vars_list,
                             tf.get_collection(CollectionKeys.STATE_VARS))
示例#3
0
def main(argv):
    """
    Command-line entry point: build the TF graph for the network of a config file and optionally
    store it as .pb/.pbtxt, together with line-based lists of model parameter / state variable names.

    :param list[str] argv: full argument vector; ``argv[0]`` (the program name) is skipped
    """
    argparser = argparse.ArgumentParser(description='Compile some op')
    argparser.add_argument('config', help="filename to config-file")
    # Validate via argparse ``choices``. The previous ``assert args.train in [0, 1, 2]`` rejected
    # the documented -1 (dynamic) mode, which made the ``args.train < 0`` branch below unreachable.
    argparser.add_argument('--train',
                           type=int,
                           default=0,
                           choices=[0, 1, -1],
                           help='0 disable (default), 1 enable, -1 dynamic')
    argparser.add_argument(
        '--eval',
        type=int,
        default=0,
        choices=[0, 1],
        help='calculate losses. 0 disable (default), 1 enable')
    argparser.add_argument('--search',
                           type=int,
                           default=0,
                           choices=[0, 1],
                           help='beam search. 0 disable (default), 1 enable')
    argparser.add_argument("--verbosity",
                           default=4,
                           type=int,
                           help="5 for all seqs (default: 4)")
    argparser.add_argument("--summaries_tensor_name",
                           help="create Tensor for tf.summary.merge_all()")
    argparser.add_argument("--output_file", help='output pb or pbtxt file')
    argparser.add_argument("--output_file_model_params_list",
                           help="line-based, names of model params")
    argparser.add_argument("--output_file_state_vars_list",
                           help="line-based, name of state vars")
    args = argparser.parse_args(argv[1:])

    def _write_var_names(filename, variables):
        """Write each variable's name, one per line, without the ":0" output-index suffix."""
        with open(filename, "w") as f:
            for var in variables:
                assert var.name[-2:] == ":0"
                f.write("%s\n" % var.name[:-2])

    init(config_filename=args.config, log_verbosity=args.verbosity)
    with tf.Graph().as_default() as graph:
        assert isinstance(graph, tf.Graph)
        print("Create graph...")
        # See :func:`Engine._init_network`.
        tf.set_random_seed(42)
        if args.train < 0:
            from TFUtil import get_global_train_flag_placeholder
            train_flag = get_global_train_flag_placeholder()
        else:
            train_flag = bool(args.train)
        eval_flag = bool(args.eval)
        search_flag = bool(args.search)
        network = create_graph(train_flag=train_flag,
                               eval_flag=eval_flag,
                               search_flag=search_flag)

        from TFNetworkLayer import LayerBase
        for layer in network.layers.values():
            assert isinstance(layer, LayerBase)
            if layer.output.time_dim_axis is None:
                continue
            # For every layer output with a time axis, add a batch-major copy named "output_batch_major".
            with layer.cls_layer_scope(layer.name):
                tf.identity(layer.output.get_placeholder_as_batch_major(),
                            name="output_batch_major")

        tf.group(*network.post_control_dependencies,
                 name="post_control_dependencies")

        if args.summaries_tensor_name:
            summaries_tensor = tf.summary.merge_all()
            assert isinstance(summaries_tensor,
                              tf.Tensor), "no summaries in the graph?"
            tf.identity(summaries_tensor, name=args.summaries_tensor_name)

        print("Graph collection keys:", graph.get_all_collection_keys())
        print("Graph num operations:", len(graph.get_operations()))
        graph_def = graph.as_graph_def(add_shapes=True)
        print("Graph def size:", Util.human_bytes_size(graph_def.ByteSize()))

        if args.output_file:
            filename = args.output_file
            _, ext = os.path.splitext(filename)
            assert ext in [
                ".pb", ".pbtxt"
            ], 'filename %r extension should be pb or pbtxt' % filename
            print("Write graph to file:", filename)
            graph_io.write_graph(graph_def,
                                 logdir=os.path.dirname(filename),
                                 name=os.path.basename(filename),
                                 as_text=(ext == ".pbtxt"))
        else:
            print("Use --output_file if you want to store the graph.")

        if args.output_file_model_params_list:
            print("Write model param list to:",
                  args.output_file_model_params_list)
            _write_var_names(args.output_file_model_params_list,
                             network.get_params_list())

        if args.output_file_state_vars_list:
            print("Write state var list to:", args.output_file_state_vars_list)
            from TFUtil import CollectionKeys
            _write_var_names(args.output_file_state_vars_list,
                             tf.get_collection(CollectionKeys.STATE_VARS))
示例#4
0
def main(argv):
  """
  Command-line entry point: build the TF graph for the network of a config file and optionally
  store it (.pb/.pbtxt graph def or .meta/.metatxt meta graph), together with line-based lists
  of model parameter / state variable names.

  :param list[str] argv: full argument vector; ``argv[0]`` (the program name) is skipped
  """
  argparser = argparse.ArgumentParser(description='Compile some op')
  argparser.add_argument('config', help="filename to config-file")
  # Validate via argparse ``choices``. The previous ``assert args.train in [0, 1, 2]`` rejected
  # the documented -1 (dynamic) mode, which made the ``args.train < 0`` branch below unreachable.
  argparser.add_argument(
    '--train', type=int, default=0, choices=[0, 1, -1], help='0 disable (default), 1 enable, -1 dynamic')
  argparser.add_argument(
    '--eval', type=int, default=0, choices=[0, 1], help='calculate losses. 0 disable (default), 1 enable')
  argparser.add_argument(
    '--search', type=int, default=0, choices=[0, 1], help='beam search. 0 disable (default), 1 enable')
  argparser.add_argument("--verbosity", default=4, type=int, help="5 for all seqs (default: 4)")
  argparser.add_argument("--summaries_tensor_name", help="create Tensor for tf.summary.merge_all()")
  argparser.add_argument("--output_file", help='output pb, pbtxt or meta, metatxt file')
  argparser.add_argument("--output_file_model_params_list", help="line-based, names of model params")
  argparser.add_argument("--output_file_state_vars_list", help="line-based, name of state vars")
  args = argparser.parse_args(argv[1:])
  # Extension of --output_file (e.g. ".pb"), or None when no output file was requested.
  output_ext = os.path.splitext(args.output_file)[1] if args.output_file else None

  def _write_var_names(filename, variables):
    """Write each variable's name, one per line, without the ":0" output-index suffix."""
    with open(filename, "w") as f:
      for var in variables:
        assert var.name[-2:] == ":0"
        f.write("%s\n" % var.name[:-2])

  init(config_filename=args.config, log_verbosity=args.verbosity)
  with tf.Graph().as_default() as graph:
    assert isinstance(graph, tf.Graph)
    print("Create graph...")
    # See :func:`Engine._init_network`.
    tf.set_random_seed(42)
    if args.train < 0:
      from TFUtil import get_global_train_flag_placeholder
      train_flag = get_global_train_flag_placeholder()
    else:
      train_flag = bool(args.train)
    eval_flag = bool(args.eval)
    search_flag = bool(args.search)
    network = create_graph(train_flag=train_flag, eval_flag=eval_flag, search_flag=search_flag)

    from TFNetworkLayer import LayerBase
    for layer in network.layers.values():
      assert isinstance(layer, LayerBase)
      if layer.output.time_dim_axis is None:
        continue
      # For every layer output with a time axis, add a batch-major copy named "output_batch_major".
      with layer.cls_layer_scope(layer.name):
        tf.identity(layer.output.get_placeholder_as_batch_major(), name="output_batch_major")

    tf.group(*network.get_post_control_dependencies(), name="post_control_dependencies")

    if args.summaries_tensor_name:
      summaries_tensor = tf.summary.merge_all()
      assert isinstance(summaries_tensor, tf.Tensor), "no summaries in the graph?"
      tf.identity(summaries_tensor, name=args.summaries_tensor_name)

    if output_ext in (".meta", ".metatxt"):
      # https://www.tensorflow.org/api_guides/python/meta_graph
      saver = tf.train.Saver(
        var_list=network.get_saveable_params_list(), max_to_keep=2 ** 31 - 1)
      graph_def = saver.export_meta_graph()
    else:
      graph_def = graph.as_graph_def(add_shapes=True)

    print("Graph collection keys:", graph.get_all_collection_keys())
    print("Graph num operations:", len(graph.get_operations()))
    print("Graph def size:", Util.human_bytes_size(graph_def.ByteSize()))

    if args.output_file:
      filename = args.output_file
      assert output_ext in [".pb", ".pbtxt", ".meta", ".metatxt"], 'filename %r extension invalid' % filename
      print("Write graph to file:", filename)
      graph_io.write_graph(
        graph_def,
        logdir=os.path.dirname(filename),
        name=os.path.basename(filename),
        as_text=output_ext.endswith("txt"))
    else:
      print("Use --output_file if you want to store the graph.")

    if args.output_file_model_params_list:
      print("Write model param list to:", args.output_file_model_params_list)
      _write_var_names(args.output_file_model_params_list, network.get_params_list())

    if args.output_file_state_vars_list:
      print("Write state var list to:", args.output_file_state_vars_list)
      from TFUtil import CollectionKeys
      _write_var_names(args.output_file_state_vars_list, tf.get_collection(CollectionKeys.STATE_VARS))
示例#5
0
  def __init__(self, extern_data, capacity=100, seed=1, with_batch=False, enqueue_data=None):
    """
    Create a shuffling train queue (RandomShuffleQueue) and an in-order eval queue (FIFOQueue)
    for the data keys of ``extern_data``. Also create ops that enqueue into, and query the size
    of, whichever queue the global train-flag placeholder selects.

    :param ExternData extern_data: this specifies the data keys
    :param int capacity: capacity of each of the two queues
    :param int seed: random seed of the train (shuffle) queue
    :param bool with_batch: whether we have the batch-dim in input/output
    :param dict[str,tf.Tensor] enqueue_data: if provided, will be the input
    """
    self.extern_data = extern_data
    self.data_keys = extern_data.data.keys()
    self.with_batch = with_batch
    self.enqueue_data = enqueue_data

    # http://stackoverflow.com/questions/41187745/tensorflow-how-can-i-evaluate-a-validation-data-queue-multiple-times-during-tra/44067467#44067467
    # I.e. we need two separate queues, one for training (RandomShuffleQueue) and one for eval (FIFOQueue),
    # and switch between the dequeue via tf.cond.
    from TFUtil import cond, get_global_train_flag_placeholder
    self.train_flag = get_global_train_flag_placeholder()
    self.names = list(self.data_keys)
    self.dtypes = [self.extern_data.data[key].dtype for key in self.data_keys]
    self.shapes = {
      key: data.batch_shape if with_batch else data.shape
      for (key, data) in self.extern_data.data.items()}
    for key, data in self.extern_data.data.items():
      for axis in data.get_axes_with_size():
        self.shapes["%s/size%i" % (key, axis)] = (None,) if with_batch else ()

    self.enqueue_placeholders = None
    # NOTE(review): the "<key>/size<axis>" entries are added to names/dtypes only when we create the
    # placeholders ourselves. Presumably a caller-provided enqueue_data holds just the data keys — confirm.
    if not self.enqueue_data:
      self.enqueue_placeholders = {
        key: tf.placeholder(**self.extern_data.data[key].get_placeholder_kwargs(with_batch=with_batch))
        for key in self.data_keys}
      for key in self.data_keys:
        for axis in self.extern_data.data[key].get_axes_with_size():
          name = "%s/size%i" % (key, axis)
          self.names += [name]
          self.dtypes += [self.extern_data.data[key].size_dtype]
          self.enqueue_placeholders[name] = tf.placeholder(
            **self.extern_data.data[key].get_size_placeholder_kwargs(axis, with_batch=with_batch))
      self.enqueue_data = self.enqueue_placeholders

    # TF recommendation: capacity = min_after_dequeue + (num_threads + a small safety margin) * batch_size
    self.capacity = capacity
    self.train_queue_min_after_dequeue = int(capacity * 0.8)
    self.train_queue = tf.RandomShuffleQueue(
      capacity=self.capacity, min_after_dequeue=self.train_queue_min_after_dequeue,
      names=self.names, dtypes=self.dtypes,
      seed=seed, name="train_queue")
    self.eval_queue = tf.FIFOQueue(
      capacity=self.capacity, names=self.names, dtypes=self.dtypes, name="eval_queue")
    self.train_queue_size = self.train_queue.size()
    self.eval_queue_size = self.eval_queue.size()
    # Number of elements that can be dequeued now. For the train queue, that is its size minus min_after_dequeue.
    self.dequeue_size_op = cond(
      self.train_flag,
      lambda: self.train_queue.size() - self.train_queue_min_after_dequeue,
      lambda: self.eval_queue.size())
    self.have_more_op = tf.greater(self.dequeue_size_op, 0, name="queue_have_more")
    # dequeue_size >= 0, i.e. a single further enqueue would make have_more_op true.
    # Uses its own op name; the previous copy-pasted "queue_have_more" got silently uniquified by TF.
    self.one_more_enqueue_is_enough_op = tf.greater(
      self.dequeue_size_op, -1, name="queue_one_more_enqueue_is_enough")
    self.enqueue_op = cond(
      self.train_flag,
      lambda: self.train_queue.enqueue(self.enqueue_data),
      lambda: self.eval_queue.enqueue(self.enqueue_data),
      name="queue_enqueue")