def __init__(self,
             hparams,
             train_iterations,
             eval_steps,
             per_host_v1=False):
        """Sets up runner state, builds the TPU init/shutdown graph, opens a
        session and runs TPU system initialization."""
        tf.logging.info("TrainLowLevelRunner: constructor")

        # Static configuration derived from the constructor arguments.
        self.hparams = hparams
        self.per_host_v1 = per_host_v1
        self.iterations = train_iterations
        self.eval_steps = eval_steps
        self.is_local = (hparams.master == "") and (hparams.tpu_name is None)
        self.num_hosts = hparams.num_shards // hparams.num_shards_per_host

        # Mutable per-run containers; train and eval keep parallel copies.
        self.feature_structure = {}
        self.eval_feature_structure = {}
        self.loss = None
        self.infeed_queue = []
        self.eval_infeed_queue = []
        self.enqueue_ops = []
        self.eval_enqueue_ops = []
        self.dataset_initializer = []
        self.eval_dataset_initializer = []
        self.outfeed_tensors = []
        self.outfeed_names = []
        self.dequeue_ops = []
        self.predictions = {}
        self.sess = None

        # TPU system init/shutdown ops live on the runner's own graph.
        self.graph = tf.Graph()
        with self.graph.as_default():
            self.tpu_init = [tpu.initialize_system()]
            self.tpu_shutdown = tpu.shutdown_system()

        self.resolver = get_resolver(hparams)
        config = tf.ConfigProto(
            allow_soft_placement=True,
            isolate_session_state=True,
            operation_timeout_in_ms=600 * 60 * 1000,
            graph_options=tf.GraphOptions(
                rewrite_options=rewriter_config_pb2.RewriterConfig(
                    disable_meta_optimizer=True)))

        # Resolve the session target: explicit master, or via the resolver.
        if self.hparams.tpu_name is None:
            target = self.hparams.master
        else:
            spec = self.resolver.cluster_spec()
            if spec:
                config.cluster_def.CopyFrom(spec.as_cluster_def())
            target = self.resolver.get_master()
        self.sess = tf.Session(target, graph=self.graph, config=config)
        self.sess.run(self.tpu_init)
# Example #2
  def __init__(self, sparse_features_key, embedding, **kwargs):
    """Initializes the runner.

    Args:
      sparse_features_key: Key under which the sparse features live in the
        feature dictionary.
      embedding: TPU embedding object; its `config_proto` is passed to
        `tpu.initialize_system` so the TPU is brought up with the embedding
        configuration attached.
      **kwargs: Forwarded to the base runner. `do_initialize` is forced to
        False because TPU initialization is performed here instead.
    """
    super(DLRMEmbeddingRunner, self).__init__(**kwargs, do_initialize=False)
    self.embedding = embedding
    self.embedding_config = embedding.config_proto
    self.features_key = sparse_features_key
    self.embed_vars_and_ops = None
    self.retrieve_ops = None
    # Keyed by is_train: True -> training enqueue data, False -> eval.
    self.enqueue_datas_list = {True: [], False: []}
    self.dummy_variables = None
    self.dummy_variables_init = None
    self.num_outfeeds = 1

    with self.graph.as_default():
      self.embed_vars_and_ops = self.embedding.create_variables_and_ops()
      self.dummy_variables, self.dummy_variables_init = (
          tpu_embedding_gradient.create_dummy_table_variables(self.embedding))
    # Run the one-off initialize_system call inside a context manager so the
    # temporary session is closed instead of leaked.
    with tf.Session(self.master, config=self.config) as sess:
      self.device_topology = sess.run(
          tpu.initialize_system(embedding_config=self.embedding_config))
# Example #3
    def __init__(self,
                 iterations_per_loop,
                 train_steps,
                 eval_steps,
                 num_replicas,
                 eval_dataset_repeats=True,
                 do_initialize=True):
        """Initializes the runner.

        Args:
          iterations_per_loop: Training steps executed per TPU training loop.
            0 disables loop-based rounding of `train_steps`.
          train_steps: Total number of training steps; rounded up to a whole
            multiple of `iterations_per_loop` when not already one.
          eval_steps: Number of eval steps; also used as the outfeed count.
          num_replicas: Total number of TPU replicas.
          eval_dataset_repeats: Whether the eval dataset repeats.
          do_initialize: Whether to run `tpu.initialize_system()` here and
            record the resulting device topology.
        """
        self.feature_structure = {}
        self.infeed_op = {}
        self.num_replicas = num_replicas
        self.eval_dataset_repeats = eval_dataset_repeats
        # Set number of input graphs to number of hosts up to a maximum of 32.
        self.num_input_graphs = min(
            32, self.num_replicas // FLAGS.replicas_per_host)
        # Following data has separated copies for training and eval, thus
        # represented as a map from is_train(boolean) to actual data
        self.dataset_initializer = {True: [], False: []}
        self.input_graph = {True: [], False: []}
        self.input_sess = {True: [], False: []}
        self.enqueue_ops = {True: [], False: []}
        for _ in range(self.num_input_graphs):
            self.input_graph[True].append(tf.Graph())
            self.input_graph[False].append(tf.Graph())
            self.dataset_initializer[True].append([])
            self.dataset_initializer[False].append([])
            self.enqueue_ops[True].append([])
            self.enqueue_ops[False].append([])
            self.input_sess[True].append([])
            self.input_sess[False].append([])
        # dequeue_ops is only for eval
        self.dequeue_ops = []
        self.iterations_per_loop = iterations_per_loop
        self.sess = None
        self.output_sess = None
        self.train_eval_thread = None
        self.graph = tf.Graph()
        # Round train_steps up so training runs a whole number of TPU loops.
        if iterations_per_loop != 0 and train_steps % iterations_per_loop != 0:
            train_steps = iterations_per_loop * int(
                math.ceil(train_steps / iterations_per_loop))
        self.train_steps = train_steps
        if iterations_per_loop == 0:
            self.max_train_iterations = 1
        else:
            self.max_train_iterations = train_steps // iterations_per_loop
        self.eval_steps = int(eval_steps)
        self.train_batch_size = 0
        self.eval_batch_size = 0
        self.eval_has_labels = 0
        self.model_fn = None
        self.num_outfeeds = self.eval_steps
        self.config = tf.ConfigProto(
            operation_timeout_in_ms=600 * 60 * 1000,
            allow_soft_placement=True,
            graph_options=tf.GraphOptions(
                rewrite_options=rewriter_config_pb2.RewriterConfig(
                    disable_meta_optimizer=True)),
            isolate_session_state=True)

        if FLAGS.enable_mlir_bridge:
            self.config.experimental.enable_mlir_bridge = True

        tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
            FLAGS.master,
            zone=FLAGS.tpu_zone,
            project=FLAGS.gcp_project,
            job_name="tpu_worker")
        self.master = tpu_cluster_resolver.get_master()
        self.job_name = tpu_cluster_resolver.get_job_name() or "tpu_worker"
        self.embedding_config = None
        self.device_topology = None
        if do_initialize:
            # Run the one-off initialize_system call inside a context manager
            # so the temporary session is closed instead of leaked.
            with tf.Session(self.master, config=self.config) as sess:
                self.device_topology = sess.run(tpu.initialize_system())