Example #1
0
 def testCreateVariableSkipLpRegularization(self):
   """With skip_lp_regularization=True, both 'w' and 'b' skip Lp reg."""
   params = TestLayer.Params().Set(
       name='test', skip_lp_regularization=True)
   test_layer = params.Instantiate()
   skip_lp_vars = tf.get_collection(py_utils.SKIP_LP_REGULARIZATION)
   for var in (test_layer.vars.w, test_layer.vars.b):
     self.assertIn(var, skip_lp_vars)
Example #2
0
 def testCreateVariable(self):
   """Checks variable names and the default Lp-regularization skip list."""
   test_layer = TestLayer.Params().Set(name='test').Instantiate()
   w_var = test_layer.vars.w
   b_var = test_layer.vars.b
   self.assertEqual('test/w/var:0', w_var.name)
   self.assertEqual('test/b/var:0', b_var.name)
   skip_lp_vars = tf.get_collection(py_utils.SKIP_LP_REGULARIZATION)
   self.assertNotIn(w_var, skip_lp_vars)
   # 'b' always skips Lp regularization, even without the explicit flag.
   self.assertIn(b_var, skip_lp_vars)
Example #3
0
    def _testDecoderFPropFloatHelper(self,
                                     func_inline=False,
                                     num_decoder_layers=1,
                                     target_seq_len=5,
                                     residual_start=0):
        """Computes decoder from params and computes loss with random inputs.

        Builds a decoder from `self._DecoderParams`, runs
        `FPropDefaultTheta` on a seeded random encoder output and hand-written
        targets, evaluates the merged summaries, and returns the loss.

        Args:
          func_inline: Value for `do_function_inlining` in the session's graph
            optimizer options.
          num_decoder_layers: Assigned to `p.rnn_layers`.
          target_seq_len: Assigned to `p.target_seq_len`.
          residual_start: Assigned to `p.residual_start`.

        Returns:
          The evaluated value of `metrics['loss'][0]`.
        """
        cluster = cluster_factory.ForTestingWorker(add_summary=True)
        # Toggle graph-level function inlining according to `func_inline`.
        config = tf.ConfigProto(graph_options=tf.GraphOptions(
            optimizer_options=tf.OptimizerOptions(
                do_function_inlining=func_inline)))
        with cluster, self.session(use_gpu=False, config=config) as sess:
            # Fix the graph-level random seed for this test graph.
            tf.set_random_seed(8372749040)
            vn_config = py_utils.VariationalNoiseParams(None, False, False)
            p = self._DecoderParams(vn_config)
            p.rnn_layers = num_decoder_layers
            p.residual_start = residual_start
            p.target_seq_len = target_seq_len
            dec = p.Instantiate()
            src_seq_len = 5
            # Random encoder output of shape [src_seq_len, 2, 8] with a fixed
            # op-level seed.
            src_enc = tf.random_normal([src_seq_len, 2, 8], seed=9283748)
            # One row per source step, one column per batch entry (presumably
            # 1.0 marks padded positions).
            src_enc_padding = tf.constant(
                [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
                dtype=tf.float32)
            encoder_outputs = py_utils.NestedMap(encoded=src_enc,
                                                 padding=src_enc_padding)
            # The target literals below are written as 5 rows of 4 and then
            # transposed.
            target_ids = tf.transpose(
                tf.constant([[0, 1, 2, 3], [1, 2, 3, 4], [10, 11, 12, 15],
                             [5, 6, 7, 8], [10, 5, 2, 5]],
                            dtype=tf.int32))
            target_labels = tf.transpose(
                tf.constant([[0, 1, 2, 3], [1, 2, 3, 4], [10, 11, 12, 13],
                             [5, 7, 8, 10], [10, 5, 2, 4]],
                            dtype=tf.int32))
            target_paddings = tf.transpose(
                tf.constant([[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 1, 0],
                             [0, 1, 0, 0], [1, 1, 1, 1]],
                            dtype=tf.float32))
            target_transcripts = tf.constant(
                ['abcd', 'bcde', 'klmp', 'fghi', 'kfcf'])
            # Positions with padding 1.0 get weight 0.0 and vice versa.
            target_weights = 1.0 - target_paddings
            targets = py_utils.NestedMap({
                'ids': target_ids,
                'labels': target_labels,
                'weights': target_weights,
                'paddings': target_paddings,
                'transcripts': target_transcripts,
            })
            metrics = dec.FPropDefaultTheta(encoder_outputs, targets).metrics
            loss = metrics['loss'][0]
            correct_predicts = metrics['fraction_of_correct_next_step_preds'][
                0]
            summaries = tf.summary.merge(
                tf.get_collection(tf.GraphKeys.SUMMARIES))

            tf.global_variables_initializer().run()
            # correct_predicts is evaluated alongside the loss but its value
            # is discarded.
            loss_v, _ = sess.run([loss, correct_predicts])

            # The merged summary result is not used; presumably this only
            # checks that the summary ops run.
            summaries.eval()

            return loss_v
Example #4
0
    def _TpuEmbLookup(self) -> Dict[str, tf.Tensor]:
        """TPU Embedding lookup.

        Fetches the activations from `self._tpu_embedding`, records them in
        the graph-level TPU_EMBEDDING_ACTIVATIONS collection keyed by the
        current task call scope, and returns them.

        Returns:
          A py_utils.NestedMap from feature name to activation tensor.
          Features listed in `self._sequence_features` are returned unchanged.
          All other features get a size-1 dimension inserted at axis 1 (the
          "time" dimension).
        """
        activations = self._tpu_embedding.get_activations()
        task = py_utils.GetTaskCallScope()
        # tf.get_collection returns a list. It is empty on the first call and
        # otherwise holds exactly one dict (task -> activations) that later
        # calls update in place.
        tpu_embedding_activations = tf.get_collection(
            py_utils.TPU_EMBEDDING_ACTIVATIONS)
        if tpu_embedding_activations:
            tpu_embedding_activations_dict = tpu_embedding_activations[0]
        else:
            tpu_embedding_activations_dict = {}
            tf.add_to_collection(py_utils.TPU_EMBEDDING_ACTIVATIONS,
                                 tpu_embedding_activations_dict)
        tpu_embedding_activations_dict[task] = activations

        ret = py_utils.NestedMap()
        for k, v in activations.items():
            if k in self._sequence_features:
                ret[k] = v
            else:
                # Non-sequence embeddings: insert a size-1 "time" dimension.
                # `axis` is documented as a scalar integer.
                ret[k] = tf.expand_dims(v, axis=1)
        return ret
Example #5
0
    def testBatchNormUpdatesWithUpdateUseGlobalStatsForTraining(self):
        """BatchNorm FProp with use_moving_avg_in_training=True.

        Checks the output sums against reference values, and checks that
        FindRelevantBatchNormUpdates finds no relevant updates for the output.
        """
        tf.random.set_seed(398847392)
        np.random.seed(12345)
        bn_params = layers.BatchNormLayer.Params()
        bn_params.name = 'bn'
        bn_params.dim = 3
        bn_params.use_moving_avg_in_training = True
        bn_params.params_init = py_utils.WeightInit.Gaussian(0.1)

        bn_layer = layers.BatchNormLayer(bn_params)
        paddings = tf.zeros([2, 8, 1], dtype=tf.float32)
        inputs = tf.constant(np.random.normal(0.1, 0.5, [2, 8, 3]),
                             dtype=tf.float32)

        outputs = bn_layer.FPropDefaultTheta(inputs, paddings)
        out_sum = tf.reduce_sum(outputs)
        out_sq_sum = tf.reduce_sum(outputs * outputs)

        # IMPORTANT: Keep these values consistent with the corresponding
        # test in layers_test.py
        self.assertAllClose(2.6575434, out_sum, atol=1e-5)
        self.assertAllClose(15.473802, out_sq_sum)

        relevant_l1, relevant_l2 = py_utils.FindRelevantBatchNormUpdates(
            outputs, tf.get_collection(py_utils.BATCH_NORM_UPDATES))
        self.assertEqual(relevant_l1, [])
        self.assertEqual(relevant_l2, [])
    def _CreateLayerVariables(self):
        """Builds or reuses the graph-wide TPUEmbedding object.

        Always resets `self._sequence_features`. When running on TPU, it also:

        - Collects table/feature configs and load/retrieve ops from
          `self.tables`.
        - Publishes the op lists to the TPU_EMBEDDING_LOAD_OPS and
          TPU_EMBEDDING_RETRIEVE_OPS collections.
        - Sets `self._tpu_embedding` to the single TPUEmbedding stored in the
          TPU_EMBEDDING collection, creating and registering it if absent.
        """
        super()._CreateLayerVariables()
        p = self.params

        # Maps each feature whose table has max_sequence_length > 0 to True.
        self._sequence_features = {}

        if not py_utils.use_tpu():
            return

        num_cores = self.cluster.params.worker.tpus_per_replica
        global_batch_size = p.batch_size * self.cluster.num_splits_per_client

        load_ops = []
        retrieve_ops = []
        table_configs = {}
        feature_configs = {}
        for table in self.tables:
            table_configs[table.table_name] = table.table_config
            load_ops += table.load_op_list
            retrieve_ops += table.retrieve_op_list
            is_sequence_table = table.max_sequence_length > 0
            for feature in table.input_keys:
                if is_sequence_table:
                    self._sequence_features[feature] = True
                feature_configs[feature] = tpu_embedding_lib.FeatureConfig(
                    table.table_name,
                    max_sequence_length=table.max_sequence_length)

        tf.logging.info('adding load and retrieve ops to collection.')
        tf.add_to_collection(py_utils.TPU_EMBEDDING_LOAD_OPS, load_ops)
        tf.add_to_collection(py_utils.TPU_EMBEDDING_RETRIEVE_OPS,
                             retrieve_ops)

        existing = tf.get_collection(py_utils.TPU_EMBEDDING)
        assert len(existing) <= 1
        if len(existing) == 1:
            tf.logging.info(
                'TPUEmbedding API singleton already exists, reusing')
            self._tpu_embedding = existing[0]
            return

        device_config = tpu_embedding_lib.DeviceConfig(
            num_cores=num_cores,
            num_hosts=p.tables[0].num_tpu_hosts,
            job_name=self.cluster.params.worker.name)
        self._tpu_embedding = tpu_embedding_lib.TPUEmbedding(
            table_configs,
            feature_configs,
            global_batch_size,
            tpu_embedding_lib.TRAINING,
            master=None,
            pipeline_execution_with_tensor_core=(
                p.pipeline_execution_with_tensor_core),
            partition_strategy=p.partition_strategy,
            device_config=device_config)
        tf.add_to_collection(py_utils.TPU_EMBEDDING, self._tpu_embedding)
    def CreateTpuEmbeddingEnqueueOps(self):
        """Creates the TpuEmbedding enqueue ops on the host.

        Note that this must be called after the instantiation of the
        monolithic TPUEmbeddingLayer.

        For each infeed host, every TPU-embedding input feature of
        `self._batch` is handled as follows:

        - It is split across the host's cores.
        - It is converted from a dense id tensor to (ids, indices) form by
          dropping entries equal to -1.
        - It is turned into enqueue ops.

        The grouped ops are appended to `self._tpu_infeed_op`. When the graph
        has no TPUEmbedding object, this only logs and returns.
        """
        p = self.params
        cluster = self.cluster
        num_tpu_hosts = cluster.num_tpu_hosts
        num_infeed_hosts = num_tpu_hosts if p.use_per_host_infeed else 1

        tpu_embedding_collection = tf.get_collection(py_utils.TPU_EMBEDDING)
        tpu_embedding = (tpu_embedding_collection[0]
                         if tpu_embedding_collection else None)

        if (num_tpu_hosts > 1 and tpu_embedding is not None and
                not p.use_per_host_infeed):
            tf.logging.fatal(
                'TPU Embedding must be used with per_host_infeed with multiple '
                'TPU host topologies.')
        tpu_emb_input_keys = (list(tpu_embedding.feature_to_config_dict.keys())
                              if tpu_embedding is not None else [])
        tf.logging.info('tpu_emb_input_keys: %r', tpu_emb_input_keys)
        # Explicit None check, consistent with the checks above, instead of
        # relying on the truthiness of the TPUEmbedding object.
        if tpu_embedding is None:
            return

        # Loop-invariant: the number of cores is the same for every host.
        num_cores_per_host = tpu_embedding.num_cores_per_host
        enqueue_ops = []
        for task_id in range(num_infeed_hosts):
            host_device = '/task:{}/device:CPU:0'.format(task_id)
            with tf.device(host_device):
                if isinstance(self._batch, py_utils.NestedMap):
                    # Hack: bucket_keys and xxx.bucket_keys are not needed on TPU.
                    # Note that when MultiTaskData is used, bucket_keys will be at the
                    # second level of the dictionary.
                    self._batch = self._batch.FilterKeyVal(
                        lambda k, _: not k.endswith('bucket_keys'))
                tf.logging.info('host_device: %s, batch: %r', host_device,
                                self._batch)

                enqueue_dict_per_core = [
                    {} for _ in range(num_cores_per_host)
                ]
                for key in tpu_emb_input_keys:
                    feat = self._batch[key]
                    tpu_emb_feat_splitted = tf.split(feat, num_cores_per_host)
                    for core, split in enumerate(tpu_emb_feat_splitted):
                        # Dense to sparse. Assumes -1 is the padding id.
                        sample_indices = tf.where(tf.not_equal(split, -1))
                        embedding_indices = tf.gather_nd(split, sample_indices)
                        enqueue_dict_per_core[core][key] = (
                            tpu_embedding_lib.EnqueueData(
                                embedding_indices, sample_indices))
                enqueue_ops += tpu_embedding.generate_enqueue_ops(
                    enqueue_dict_per_core)
        self._tpu_infeed_op.append(tf.group(*enqueue_ops))
Example #8
0
 def variables_for_ema(self):
   """Returns the set of this task's variables that EMA should track.

   The candidates are all trainable and moving-average variables, plus the
   'moving_vars' collection when `p.train.ema_decay_moving_vars` is set. The
   result keeps only variables that also appear in `self.vars`.
   """
   p = self.params
   trainable = set(tf.trainable_variables())
   moving_avg = set(tf.moving_average_variables())
   selected = trainable | moving_avg
   if p.train.ema_decay_moving_vars:
     selected |= set(tf.get_collection('moving_vars'))
   selected &= set(self.vars.Flatten())
   for v in selected:
     tf.logging.debug('variables_for_ema: %s', v.name)
   return selected
Example #9
0
 def Get(cls):
   """Returns the TpuEmbeddingCollection associated with the current graph.

   The instance lives in the graph collection `cls.GRAPH_COLLECTION_NAME`. It
   is created and registered there on first use and reused afterwards.
   """
   existing = tf.get_collection(cls.GRAPH_COLLECTION_NAME)
   assert len(existing) <= 1
   if len(existing) != 1:
     instance = cls()
     tf.add_to_collection(cls.GRAPH_COLLECTION_NAME, instance)
     return instance
   tf.logging.info(
       'TpuEmbeddingCollection singleton already exists, reusing')
   return existing[0]
Example #10
0
    def __init__(self, decoder_type, *args, **kwargs):
        """Builds the decode graph for one decoder job.

        Args:
          decoder_type: Suffix of the job name; the job is named
            'decoder_' + decoder_type.
          *args: Forwarded to the base class constructor.
          **kwargs: Forwarded to the base class constructor.
        """
        super().__init__(*args, **kwargs)
        self._job_name = 'decoder_' + decoder_type
        self.params.cluster.do_eval = True
        self._cluster = cluster_factory.Cluster(self.params.cluster)
        self._decoder_dir = GetDecoderDir(self._logdir, self._job_name,
                                          self._model_task_name)
        tf.io.gfile.makedirs(self._decoder_dir)

        self._decode_path = None
        # Multitask params don't have 'task'; in that case _decode_path stays
        # None.
        if 'task' in self.params:
            self._decode_path = checkpointer.GetSpecificCheckpoint(
                self.params.task.eval.load_checkpoint_from)

        # True only if this job's name starts with the cluster's reporting_job.
        self._should_report_metrics = self._job_name.startswith(
            self._cluster.reporting_job)

        with self._graph.as_default(), tf.container(self._container_id):
            self._summary_writer = self._CreateSummaryWriter(self._decoder_dir)
            self._CreateTF2SummaryWriter(self._decoder_dir)
            with self._cluster, tf.device(
                    self._cluster.GetPlacer()), self._TF2SummaryContext():
                self._model = self.params.Instantiate()
                self._params = self._model.params
                self._task = self._model.GetTask(self._model_task_name)
                # Note, different graphs are being constructed for different model
                # tasks, which may result in different node names being chosen.
                # Obviously, variable names have to stay the same between train
                # and decode.
                cluster = self._cluster
                # Build the input pipeline on the cluster's input device.
                with tf.device(cluster.input_device):
                    input_batch = (
                        self._task.input_generator.GetPreprocessedInputBatch())

                self._dec_output = self._task.Decode(input_batch)
                self._summary_op = tf.summary.merge_all()
                # The checkpointer is pointed at the training directory
                # (presumably to restore training checkpoints — confirm).
                self.checkpointer = self._CreateCheckpointer(
                    self._train_dir, self._model)
            self._CreateTF2SummaryOps()
            self._initialize_tables = tf.tables_initializer()
            self._initialize_local_vars = tf.local_variables_initializer()
            # No queues are allowed for decoder models.
            self.enqueue_ops = tf.get_collection(py_utils.ENQUEUE_OPS)
            assert not self.enqueue_ops

        # Saves the params text and, on cluster task 0 only, the graph def.
        self._WriteToLog(self.params.ToText(), self._decoder_dir, 'params.txt')
        if self.params.cluster.task == 0:
            tf.io.write_graph(self._graph.as_graph_def(), self._decoder_dir,
                              '%s.pbtxt' % self._job_name)
Example #11
0
 def testConv2DLayerConstruction(self):
   """Conv2DLayerWithPadding registers only 'conv/w/var:0' in its collection."""
   with self.session(use_gpu=True):
     tf.random.set_seed(398847392)
     np.random.seed(12345)
     conv_p = conv_layers.Conv2DLayerWithPadding.Params()
     conv_p.name = 'conv'
     conv_p.filter_shape = [3, 3, 3, 32]
     conv_p.filter_stride = [2, 2]
     conv_p.params_init = py_utils.WeightInit.Gaussian(0.1)
     conv_p.Instantiate()
     created_names = [
         v.name for v in tf.get_collection('Conv2DLayerWithPadding_vars')
     ]
     self.assertEqual(['conv/w/var:0'], created_names)
Example #12
0
    def __init__(self, *args, **kwargs):
        """Builds the trainer graph and its summary writers.

        Steps:

        - Instantiates the model from `self.params` and constructs its
          FProp/BProp graph.
        - Creates the main summary writer (cluster task 0 only) and one
          summary writer per task of a multi-task schedule.
        - Collects enqueue ops.
        - Writes the params and graph def (task 0 only).
        - Computes this worker's start-up delay.

        Args:
          *args: Forwarded to the base class constructor.
          **kwargs: Forwarded to the base class constructor.
        """
        super().__init__(*args, **kwargs)
        self._job_name = 'trainer'
        with self._graph.as_default(), tf.container(self._container_id):
            if self.params.cluster.task == 0:
                self._summary_writer = self._CreateSummaryWriter(
                    self._train_dir)
                self._CreateTF2SummaryWriter(self._train_dir)
            else:
                self._summary_writer = None

            with self._cluster, tf.device(
                    self._cluster.GetPlacer()), self._TF2SummaryContext():
                self._model = self.params.Instantiate()
                self._params = self._model.params
                self._model.ConstructFPropBPropGraph()

            # Per-task summary writers. This must run after self._model is
            # assigned above; reading self._model before that point cannot see
            # the task schedule of the model built here. Only the attribute
            # lookup is guarded so unrelated AttributeErrors are not hidden.
            self._task_probs_summary_writers = []
            try:
                tasks = self._model.task_schedule.tasks
            except AttributeError:
                tf.logging.info(
                    'AttributeError. Expected for single task models.')
                tasks = []
            for task in tasks:
                path = os.path.join(self._train_dir, task)
                tf.io.gfile.makedirs(path)
                self._task_probs_summary_writers.append(
                    self._CreateSummaryWriter(path))

            self._CreateTF2SummaryOps()
            self._initialize_tables = tf.tables_initializer()
            self._initialize_local_vars = tf.local_variables_initializer()
            self.enqueue_ops = tf.get_collection(py_utils.ENQUEUE_OPS)
            tf.logging.info('Trainer number of enqueue ops: %d',
                            len(self.enqueue_ops))

        self._step_rate_tracker = summary_utils.StepRateTracker()

        # Saves the graph def.
        if self.params.cluster.task == 0:
            self._WriteToLog(self.params.ToText(), self._train_dir,
                             'trainer_params.txt')
            tf.io.write_graph(self._graph.as_graph_def(), self._train_dir,
                              'train.pbtxt')
        # Worker k gets k*(k+1)/2 units of start_up_delay_steps (presumably to
        # stagger worker start-up).
        worker_id = self.params.cluster.task
        self._start_up_delay_steps = (((worker_id + 1) * worker_id / 2) *
                                      self.params.train.start_up_delay_steps)
Example #13
0
 def ApplyExponentialMovingAverage(self, ema):
   """Wraps `self.train_op` with an op updating exponential moving average.

   The op returned by `ema.apply(...)` is appended to `self._post_train_ops`.

   Args:
     ema: Object whose `apply(var_list)` builds the EMA update op (presumably
       a tf.train.ExponentialMovingAverage — confirm against callers).

   Raises:
     ValueError: if called before the layer variables were instantiated.
   """
   status = self._create_variables_status
   completed = base_layer._CreateLayerVariablesStatus.COMPLETED  # pylint: disable=protected-access
   if status != completed:
     raise ValueError(
         'ApplyExponentialMovingAverage called before InstantiateVariables!')
   # TODO(rpang): raise an exception if this is called in the eval mode.
   p = self.params
   # Cover every trainable and moving-average variable of this Task, not just
   # the bprop vars, so each gets a shadow '/ExponentialMovingAverage'
   # variable.
   trainable = set(tf.trainable_variables())
   moving_avg = set(tf.moving_average_variables())
   ema_vars = trainable | moving_avg
   if p.train.ema_decay_moving_vars:
     ema_vars |= set(tf.get_collection('moving_vars'))
   ema_vars &= set(self.vars.Flatten())
   for v in ema_vars:
     tf.logging.debug('ApplyExponentialMovingAverage: %s', v.name)
   with tf.name_scope('moving_average'):
     update_op = ema.apply(ema_vars)
     self._post_train_ops.append(update_op)
Example #14
0
def AddToPruningCollections(weight,
                            mask,
                            threshold,
                            gradient=None,
                            old_weight=None,
                            old_old_weight=None):
    """Add mask, threshold, and weight vars to their respective collections.

    Registration is keyed on `mask`. If this exact mask object is already in
    pruning.MASK_COLLECTION, nothing is added.

    Args:
      weight: Added to pruning.WEIGHT_COLLECTION.
      mask: Added to pruning.MASK_COLLECTION.
      threshold: Added to pruning.THRESHOLD_COLLECTION.
      gradient: Optional; added to pruning.WEIGHT_GRADIENT_COLLECTION. When
        given, `old_weight` and `old_old_weight` must be given too.
      old_weight: The weight tensor one step before; added to
        pruning.OLD_WEIGHT_COLLECTION when `gradient` is given.
      old_old_weight: The weight tensor two steps before; added to
        pruning.OLD_OLD_WEIGHT_COLLECTION when `gradient` is given.

    Raises:
      ValueError: if a new mask is registered with `gradient` but without both
        `old_weight` and `old_old_weight`. This is checked before any
        collection is modified, so a failed call leaves them untouched.
    """
    # Identity check: avoids `==`, which for TF variables can be element-wise
    # under TF2 equality semantics.
    if any(m is mask for m in tf.get_collection(pruning.MASK_COLLECTION)):
        return

    # Validate before mutating, and raise instead of `assert` (which is
    # stripped under `python -O`).
    if gradient is not None and (old_weight is None or
                                 old_old_weight is None):
        raise ValueError('old_weight and old_old_weight must be provided '
                         'together with gradient.')

    tf.add_to_collection(pruning.WEIGHT_COLLECTION, weight)
    tf.add_to_collection(pruning.MASK_COLLECTION, mask)
    tf.add_to_collection(pruning.THRESHOLD_COLLECTION, threshold)

    # Add gradient, old_weight, and old_old_weight to collections approximating
    # gradient and hessian, where old_weight is the weight tensor one step
    # before and old_old_weight is the weight tensor two steps before.
    if gradient is not None:
        tf.add_to_collection(pruning.WEIGHT_GRADIENT_COLLECTION, gradient)
        tf.add_to_collection(pruning.OLD_WEIGHT_COLLECTION, old_weight)
        tf.add_to_collection(pruning.OLD_OLD_WEIGHT_COLLECTION,
                             old_old_weight)
Example #15
0
    def __init__(self, train_dir, model):
        """Initialize Checkpointer.

        Args:
          train_dir: Training directory for saving checkpoints.
          model: Model.
        """
        self._train_dir = train_dir
        self._model = model
        self._params = model.params

        # Graph ops over all global variables: uninitialized-variable report
        # and a global initializer.
        all_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
        self._vars = all_vars
        self._uninitialized_vars = tf.report_uninitialized_variables(all_vars)
        self._initialize_vars = tf.global_variables_initializer()

        self._save_path = os.path.join(train_dir, 'ckpt')
        self._model_tasks = model.tasks

        self._save_interval_seconds = (
            self._params.train.save_interval_seconds)
        self._next_checkpoint_seconds = 0
        self._saver = self._GetSaver()
Example #16
0
    def __init__(self, train_dir, model, train_params=None, save_only=False):
        """Initialize Checkpointer.

        Args:
          train_dir: Training directory for saving checkpoints.
          model: A BaseModel instance or None. It may be None only when
            `train_params` is given and `save_only` is True.
          train_params: If specified, use these training params instead of
            those in the `model`.
          save_only: This checkpointer is only intended for saving checkpoints.

        Raises:
          ValueError: if `model` is missing but required, i.e. when
            `train_params` is not given or `save_only` is False.
        """
        # Validate up front, before any graph ops are created. Previously a
        # None model with save_only=False failed later with an opaque
        # AttributeError on `model.params`.
        if not train_params and not model:
            raise ValueError(
                'model is required when train_params is not given.')
        if not save_only and model is None:
            raise ValueError('model is required unless save_only=True.')

        self._train_dir = train_dir
        self._save_only = save_only

        self._vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
        self._uninitialized_vars = tf.report_uninitialized_variables(
            self._vars)
        self._initialize_vars = tf.global_variables_initializer()

        self._save_path = os.path.join(self._train_dir, 'ckpt')

        if train_params:
            self._train_params = train_params
            self._model = None
        else:
            self._train_params = model.params.train
            self._model = model

        if not self._save_only:
            self._params = model.params
            self._model_tasks = model.tasks
            self._model = model

        self._next_checkpoint_seconds = 0
        self._save_interval_seconds = self._train_params.save_interval_seconds
        self._saver = self._GetSaver()
Example #17
0
  def __init__(self, params):
    """Constructs the task from `params`.

    Steps:

    - Ensures a global step variable exists.
    - For eval tasks, may adjust the input/eval params so summaries cover a
      known dataset size.
    - Creates the 'input' child.
    - Resolves this task's global step variable.
    - When `p.train` is set and the cluster job is one of the listed training
      jobs, creates the 'learners' children.

    Args:
      params: Task params; `params.cls` must be a subclass of BaseTask.

    Raises:
      ValueError: if the 'TASK_GLOBAL_STEP' collection holds more than one
        global step for this task's name.
    """
    assert issubclass(params.cls, BaseTask)
    # Ensure global_step exists before calling super.
    py_utils.GetOrCreateGlobalStepVar()
    super(BaseTask, self).__init__(params)

    p = self.params

    if p.input:
      # TODO(zhifengc): Consider a simpler way to ensure the input
      # generator stops after one epoch.
      if p.is_eval and p.eval:
        # True for file-based input generators.
        seq_inp = issubclass(p.input.cls,
                             base_input_generator.BaseInputGeneratorFromFiles)
        if p.input.num_samples == 0:
          # Dataset size is unknown, so eval summaries rely on
          # eval.samples_per_summary, which must be positive.
          assert p.eval.samples_per_summary > 0
        elif (p.eval.samples_per_summary == 0) or (p.input.num_samples <
                                                   p.eval.samples_per_summary):
          # If we know the dataset size and we want to evaluate the full
          # set, we need to coordinate the input generator to flush out
          # all samples so the evaler and decoder compute metrics on the
          # whole set for each summary step.
          if seq_inp:
            p.input.flush_every_n = p.input.num_samples
          p.eval.samples_per_summary = p.input.num_samples
        if seq_inp and p.input.num_batcher_threads > 1:
          tf.logging.warning('input.num_batcher_threads > 1 inside eval mode.  '
                             'The input generator may not iterate over exactly '
                             'one epoch per run')
      tf.logging.info('input_params: %s', p.input)
      input_params = self.cluster.PlaceInput(p.input)
      with py_utils.outside_all_rewrites():
        self.CreateChild('input', input_params)

    self._encoder = None
    self._online_encoder = None
    self._decoder = None

    self._loss = None
    self._num_predictions = None
    self._train_op = None
    self._eval_metrics = {}
    self._per_example = {}
    self._trainer_verbose_tensors = {}

    # Per-input gradient mask; None here (presumably populated later — confirm).
    self._per_input_gradient_mask = None
    # Look for a task-specific global step named '<task name>_global_step' in
    # the 'TASK_GLOBAL_STEP' collection; fall back to the shared global step.
    task_global_step_list = tf.get_collection('TASK_GLOBAL_STEP',
                                              '^%s_global_step' % p.name)
    if len(task_global_step_list) > 1:
      raise ValueError('Found multiple task_global_step for task %s' % p.name)
    self._global_step_var = (
        task_global_step_list[0] if len(task_global_step_list) == 1 else
        py_utils.GetOrCreateGlobalStepVar())
    self._global_step = tf.identity(
        self._global_step_var, name='global_step_tensor')
    tp = p.train
    # p.train can be None if this task is the teacher/student task in a
    # DistillationTask.
    if tp and self.cluster.job in ('worker', 'trainer', 'trainer_client',
                                   'controller', 'executor_tpu'):
      self._SetLearnerFromLegacyParams(tp)
      if tp.learner is not None:
        # tp.learner may be a single learner or a list/tuple of learners.
        if isinstance(tp.learner, (list, tuple)):
          self.CreateChildren('learners', tp.learner)
        else:
          self.CreateChildren('learners', [tp.learner])
    self._UpdateVnConfig()
Example #18
0
    def CreateTpuFeeds(self):
        """Creates the TPU infeed queue from preprocessed batch."""
        p = self.params
        cluster = self.cluster
        num_tpu_hosts = cluster.num_tpu_hosts
        num_cores_per_host = cluster.total_worker_devices // num_tpu_hosts
        tf.logging.info(
            'CreateTPUFeeds num_splits_per_client={} '
            'num_devices_per_split={} num_tpu_hosts={} use_per_host_infeed={}'.
            format(cluster.num_splits_per_client,
                   cluster.num_devices_per_split, num_tpu_hosts,
                   p.use_per_host_infeed))

        assert num_tpu_hosts > 0, ('num_tpu_hosts: %d' % num_tpu_hosts)
        if (cluster.num_devices_per_split > num_cores_per_host
                and p.use_per_host_infeed):
            tf.logging.fatal(
                'Doesn\'t support per host infeed mode when '
                'num_devices_per_split({}) > num_cores_per_host({})'.format(
                    cluster.num_devices_per_split, num_cores_per_host))
        num_infeed_hosts = num_tpu_hosts if p.use_per_host_infeed else 1

        with py_utils.outside_all_rewrites():
            assert py_utils.use_tpu()
            assert not self._made_tpu_infeed

            shards = tpu_function.get_tpu_context(
            ).number_of_shards // num_infeed_hosts
            tf.logging.info('shards {}'.format(shards))

            input_ops_list = []
            queues = []
            tpu_embedding_collection = tf.get_collection(
                py_utils.TPU_EMBEDDING)
            tpu_embedding = (tpu_embedding_collection[0]
                             if tpu_embedding_collection else None)

            if num_tpu_hosts > 1 and tpu_embedding is not None:
                if not p.use_per_host_infeed:
                    tf.logging.fatal(
                        'TPU Embedding must be used with per_host_infeed with multiple '
                        'TPU host topologies.')
            tpu_emb_input_keys = (list(
                tpu_embedding.feature_to_config_dict.keys())
                                  if tpu_embedding is not None else [])
            tf.logging.info('tpu_emb_input_keys: %r', tpu_emb_input_keys)

            batch = None
            for task_id in range(num_infeed_hosts):
                host_device = '/task:{}/device:CPU:0'.format(task_id)
                with tf.device(host_device):
                    batch = self.GetPreprocessedInputBatch()
                    if isinstance(batch, py_utils.NestedMap):
                        # Hack: bucket_keys and xxx.bucket_keys are not needed on TPU.
                        # Note that when MultiTaskData is used, bucket_keys will be at the
                        # second level of the dictionary.
                        batch = batch.FilterKeyVal(
                            lambda k, _: not k.endswith('bucket_keys'))
                    tf.logging.info('host_device: %s, batch: %r', host_device,
                                    batch)

                    if tpu_embedding is not None:
                        enqueue_dict_per_core = [
                            {} for _ in range(tpu_embedding.num_cores_per_host)
                        ]
                        num_cores_per_host = tpu_embedding.num_cores_per_host
                        for key in tpu_emb_input_keys:
                            feat = batch[key]
                            tpu_emb_feat_splitted = tf.split(
                                feat, num_cores_per_host)
                            for core, split in enumerate(
                                    tpu_emb_feat_splitted):
                                # Dense to sparse. Note the assumption of a padding id.
                                sample_indices = tf.where(
                                    tf.not_equal(split, -1))
                                embedding_indices = tf.gather_nd(
                                    split, sample_indices)
                                enqueue_data = tpu_embedding_lib.EnqueueData(
                                    embedding_indices, sample_indices)
                                enqueue_dict_per_core[core][key] = enqueue_data
                        input_ops_list += tpu_embedding.generate_enqueue_ops(
                            enqueue_dict_per_core)

                    for k, x in batch.FlattenItems():
                        assert x.shape.is_fully_defined(), (
                            'Shape must be fully defined: %s: %s' % (k, x))
                        # TODO(cwhipkey): if it's a string (or other type not supported on
                        # TPU), drop it from feeding and on the other end add in an op that
                        # fails if used.
                    shapes = batch.Transform(lambda x: x.shape).Flatten()
                    dtypes = batch.Transform(lambda x: x.dtype).Flatten()
                    tf.logging.info('host_device: %s infeed shapes: %r',
                                    host_device, shapes)
                    tf.logging.info('host_device: %s infeed dtypes: %r',
                                    host_device, dtypes)
                    if p.use_partitioned_infeed_queue:
                        device_assignment = py_utils.GetTpuDeviceAssignment()

                        host_device = device_assignment.host_device(
                            replica=0, job=tf.flags.FLAGS.tf_master)
                        host_id = int(
                            host_device.split('/task:')[1].split('/device:')
                            [0])
                        tf.logging.info('host_id: {} host_device: {}'.format(
                            host_id, host_device))
                        q = tpu_feed._PartitionedInfeedQueue(  # pylint: disable=protected-access
                            number_of_tuple_elements=len(dtypes),
                            device_assignment=device_assignment,
                            host_id=host_id,
                            input_partition_dims=[[p.num_partitions, 1]
                                                  for _ in dtypes],
                            tuple_types=dtypes,
                            tuple_shapes=shapes)
                    else:
                        q = tpu_feed.InfeedQueue(tuple_types=dtypes,
                                                 tuple_shapes=shapes)
                        assert shards is not None
                        q.set_number_of_shards(shards)

                    queues.append(q)
                    tf.logging.info('q=%r', q)

                    if p.use_partitioned_infeed_queue:
                        input_ops = q.generate_enqueue_ops([batch.Flatten()])
                    elif p.use_per_host_infeed:
                        # TODO(ylc/zhifengc): Add this to a policy module and test it.
                        def TPUOrdinalFunction(shard_index_in_host):
                            device_assignment = py_utils.GetTpuDeviceAssignment(
                            )
                            if device_assignment:
                                # We put both enqueue/dequeue ops at core 0 in each replica.
                                replica = device_assignment.lookup_replicas(
                                    task_id, 0)[shard_index_in_host]  # pylint: disable=cell-var-from-loop
                                return device_assignment.tpu_ordinal(
                                    replica=replica)
                            else:
                                return shard_index_in_host

                        input_ops = q.split_inputs_and_generate_enqueue_ops(
                            batch.Flatten(),
                            placement_function=lambda x: host_device,  # pylint: disable=cell-var-from-loop
                            tpu_ordinal_function=TPUOrdinalFunction)
                    else:
                        input_ops = q.split_inputs_and_generate_enqueue_ops(
                            batch.Flatten(),
                            device_assignment=py_utils.GetTpuDeviceAssignment(
                            ))

                    input_ops_list += input_ops
            tf.logging.info('input_ops_list %s', input_ops_list)
            tpu_infeed_op = tf.group(*input_ops_list)
        self._made_tpu_infeed = True
        # Let trainer.py use multiple threads to drive the infeed op.
        for _ in range(p.tpu_infeed_parallelism):
            tf.add_to_collection(py_utils.ENQUEUE_OPS, tpu_infeed_op)

        self._tpu_infeed_op = tpu_infeed_op

        with tf.device(tf.tpu.core(0)):
            tensors = queues[0].generate_dequeue_op()
        return batch.Pack(tensors)
Example #19
0
    def __init__(self, train_cfg, ps_params_dict, model_task_name, logdir,
                 tf_master, **kwargs):
        """Construct an ExecutorTpu BaseRunner.

        Args:
          train_cfg: SingleTaskModelParams or MultiTaskModelParams
          ps_params_dict: A dict of top-level task name -> ProgramSchedule params,
            if train_cfg is a SingleTaskModelParams, we expect only one entry.
          model_task_name: An override for multi-task models, currently unused.
          logdir:  String path to the log directory to output to.
          tf_master: String path to the master job, e.g. 'local'.
          **kwargs: keyword args to pass through to BaseRunner.
        """
        super().__init__(train_cfg, model_task_name, logdir, tf_master,
                         **kwargs)

        data_parallelism = self._cluster.num_splits_per_client

        assert data_parallelism
        num_devices_per_split = self._cluster.num_devices_per_split
        tf.logging.info('data_parallelism: %d, num_devices_per_split: %d',
                        data_parallelism, num_devices_per_split)

        self.task_scheduler = None
        self._checkpoint_dir = os.path.join(logdir, 'train')

        self._variable_renaming_rules = []

        # Resolve the MLPerf config up front: it decides (below) whether the
        # 'benchmark' and 'init_start' MLPerf events are emitted, and both of
        # those happen before the program schedules are instantiated. The last
        # program schedule that names a benchmark wins.
        self._ml_perf = None
        for program_schedule_params in ps_params_dict.values():
            if program_schedule_params.ml_perf.benchmark_name is not None:
                self._ml_perf = program_schedule_params.ml_perf

        # If this is a multi-task model, grab the params for the TaskScheduler.
        if issubclass(train_cfg.cls, base_model.SingleTaskModel):
            tf.logging.info('single_task_model')
            assert len(ps_params_dict) == 1
            self._model_task_name = list(ps_params_dict.keys())[0]
            self._single_task_mode = True
        elif issubclass(train_cfg.cls, base_model.MultiTaskModel):
            tf.logging.info('multi_task_model')

            if issubclass(train_cfg.cls,
                          multitask_model.RegExSharedVariableModel):
                self._variable_renaming_rules = train_cfg.variable_renaming_rules

            if train_cfg.task_schedule is None:
                task_schedule_params = task_scheduler.ConstantScheduler.Params(
                )
                task_schedule_params.task_probs = sorted(
                    list(train_cfg.task_probs.IterParams()))
            else:
                task_schedule_params = train_cfg.task_schedule
            self.task_scheduler = task_schedule_params.Instantiate()
            self._single_task_mode = False
        else:
            tf.logging.fatal(
                'Model %s is not a sub-class of SingleTaskModel or MultiTaskModel',
                train_cfg.cls)

        tf.logging.info('train_cfg.cls: %s', train_cfg.cls)

        self._WriteToLog(train_cfg.ToText(), self._checkpoint_dir,
                         'trainer_params.txt')
        if self._ml_perf is not None:
            self._ml_perf_log = True
            mlp_log.mlperf_print(key='benchmark',
                                 value=self._ml_perf.benchmark_name)
        else:
            self._ml_perf_log = False

        # BaseRunner legacy
        self.enqueue_ops = None

        train_cfg = self.params

        @py_utils.RetryOnTransientTfError()
        def _WaitTillInit(job=None):
            """Wait until the model is ready."""
            try:
                # tpu.initialize_system() is called with None as embedding_config, as
                # embedding_config is not available yet. Later in _Loop, it is called
                # with the correct embedding_config. Since it cannot be called twice in
                # the same graph with different embedding_config, we use a dummy_graph
                # here.
                dummy_graph = tf.Graph()
                with dummy_graph.as_default():
                    tpu_initialize_system_op = tf.tpu.initialize_system(
                        embedding_config=None, job=job)

                with self._GetSession(graph=dummy_graph) as sess:
                    topology = sess.run(tpu_initialize_system_op)

                if train_cfg.train.tpu_device_order_mode is None:
                    device_assignment = device_assignment_lib.device_assignment(
                        topology,
                        computation_shape=py_utils.ComputationShape(
                            num_devices_per_split, topology),
                        num_replicas=data_parallelism)
                else:
                    device_assignment = device_assignment_lib.device_assignment(
                        topology,
                        computation_shape=py_utils.ComputationShape(
                            num_devices_per_split, topology),
                        num_replicas=data_parallelism,
                        device_order_mode=train_cfg.train.tpu_device_order_mode
                    )
                py_utils.SetTpuDeviceAssignment(device_assignment, job)
                tf.logging.info('device_assignment.core_assignment: %s',
                                str(device_assignment.core_assignment))
                tf.logging.info(
                    'device_assignment.topology.device_coordinates: %s',
                    str(device_assignment.topology.device_coordinates))
            except py_utils.transient_tf_errors as e:
                tf.logging.info('TPU initialization failed: %s', e)
                raise

        if self._ml_perf_log:
            mlp_log.mlperf_print(key='init_start', value=None)
        # With several workers, initialize (and record a device assignment
        # for) each worker job; otherwise initialize with no explicit job.
        if len(self._cluster.all_worker_names) > 1:
            for worker in self._cluster.all_worker_names:
                _WaitTillInit(worker)
        else:
            _WaitTillInit(None)

        shared_model = self._MaybeConstructSharedModel(train_cfg)

        self._program_schedule_dict = {}
        self._programs = []

        for task_string, program_schedule_params in ps_params_dict.items():
            program_schedule_params.logdir = logdir
            program_schedule_params.num_splits_per_client = data_parallelism
            program_schedule_params.task_name = task_string
            # If the model was created above, we'll inject it here as a shared_model.
            ps = program_schedule_params.Instantiate(shared_model=shared_model,
                                                     tf_master=self._tf_master)
            self._program_schedule_dict[task_string] = ps
            tf.logging.info('program_schedule_params: %s',
                            program_schedule_params.ToText())
            self._programs += ps.Programs()

        tf.logging.info('num_programs: %d', len(self._programs))

        with self._graph.as_default(), tf.container(self._container_id):
            with self._cluster, tf.device(self._cluster.GetPlacer()):
                with py_utils.VariableRenameScope(
                        self._variable_renaming_rules):
                    _ = py_utils.GetOrCreateGlobalStepVar()
                    for program in self._programs:
                        program.BuildTpuSubgraph()
                        py_utils.ClearTpuSummaryTensors()

                self._initialize_tables = tf.tables_initializer()
                self._initialize_local_vars = tf.local_variables_initializer()
                self._initialize_global_vars = tf.global_variables_initializer(
                )

                for program in self._programs:
                    program.SetStatusMessageFn(self._SetStatusMessage)
                    program.CreateCheckpointer(
                        init_op=self._initialize_global_vars)

                self.save_only_checkpointer = checkpointer.Checkpointer(
                    self._checkpoint_dir,
                    model=None,
                    init_op=self._initialize_global_vars,
                    train_params=train_cfg.train,
                    save_only=True)

            self._load_ops = tf.get_collection(py_utils.TPU_EMBEDDING_LOAD_OPS)
            self._retrieve_ops = tf.get_collection(
                py_utils.TPU_EMBEDDING_RETRIEVE_OPS)
            tpu_embedding_collection = tf.get_collection(
                py_utils.TPU_EMBEDDING)
            self._tpu_embedding = (tpu_embedding_collection[0]
                                   if tpu_embedding_collection else None)
            tf.io.write_graph(self._graph.as_graph_def(), self._checkpoint_dir,
                              'train.pbtxt')
Example #20
0
def get_masked_weights():
    """Fetch the masked weights registered in the current default graph.

    Returns:
      The contents of the graph collection keyed by
      `_MASKED_WEIGHT_COLLECTION`.
    """
    collection_key = _MASKED_WEIGHT_COLLECTION
    return tf.get_collection(collection_key)
Example #21
0
def get_masks():
    """Fetch the masks registered in the current default graph.

    Returns:
      The contents of the graph collection keyed by `_MASK_COLLECTION`.
    """
    collection_key = _MASK_COLLECTION
    return tf.get_collection(collection_key)
Example #22
0
  def _BPropGenTrainOps(self, vmap, metrics=None, add_summary=True):
    """Populates the train_ops dictionary in a backwards pass.

    Args:
      vmap: Variables handed to every learner's `Apply` (presumably a NestedMap
        of trainable variables -- TODO confirm against callers).
      metrics: Metrics handed to every learner's `Apply`; falls back to
        `self._metrics` when falsy.
      add_summary: If True, registers each learner's eval metrics and any TPU
        embedding summary tensors through `AddEvalMetric`.

    Returns:
      A dict mapping op name to op, containing 'mask_updates' and
      'global_step', plus 'tpu_embedding' when a TPU embedding is registered.
      The per-learner and batch-norm update ops are not returned as entries;
      they are grouped into `var_update_ops`, under whose control dependencies
      the post-step / mask-update / global-step ops are built.
    """
    metrics = metrics or self._metrics

    bprop_variable_filters = self.input_generator.GetBpropVariableFilters()
    # Only compute the mask if the variable filters are not empty.
    if bprop_variable_filters != [''] * len(bprop_variable_filters):
      self._ComputeGradientMask(bprop_variable_filters)
    train_ops = {}  # mapping from op name to op.
    gradient_mask = None
    if self._per_input_gradient_mask:
      # TODO(neerajgaur): Change this to use source_selected from input_batch.
      # Collapse each per-input-source mask into a single mask by contracting
      # it with the one-hot input-source selector.
      onehot = self.input_generator.GetInputSourceOneHot()
      gradient_mask = {
          k: tf.tensordot(v, onehot, 1)
          for k, v in self._per_input_gradient_mask.items()
      }
    all_losses = []
    for optimization in self.learners:
      learner_name = optimization.params.name
      # Each learner returns its losses, its train op (stored under
      # 'train/<learner_name>') and a dict of (value, weight) eval metrics.
      (losses, train_ops['train/%s' % learner_name],
       eval_metrics) = optimization.Apply(
           metrics,
           vmap,
           gradient_mask=gradient_mask,
           gradient_adjuster=self.AdjustGradients)
      all_losses.extend(losses)
      if add_summary:
        for key, (value, weight) in eval_metrics.items():
          self.AddEvalMetric(key + '/' + learner_name, value, weight)

    # Keep only the batch-norm updates that the collected losses depend on.
    relevant_bn_updates, _ = py_utils.FindRelevantBatchNormUpdates(
        all_losses, tf.get_collection(py_utils.BATCH_NORM_UPDATES))
    train_ops['bn_updates'] = relevant_bn_updates

    var_update_ops = [
        tf.group(*tf.nest.flatten(train_ops), name='var_update_ops')
    ]
    # Post training step update.
    with tf.control_dependencies(var_update_ops):
      post_step_op = self.PostTrainingStepUpdate()

    # Start a fresh dict: the learner and batch-norm ops collected above are
    # reached only through the control-dependency chain
    # var_update_ops -> post_step_op -> mask_update_op -> global step.
    # NOTE(review): that chain relies on PostTrainingStepUpdate and
    # _GetMaskUpdateOp creating new ops inside these scopes -- confirm.
    train_ops = {}
    with tf.control_dependencies([post_step_op]):
      # Get the op to update the weight masks and thresholds
      mask_update_op = self._GetMaskUpdateOp()
      train_ops['mask_updates'] = mask_update_op
      with tf.control_dependencies([mask_update_op]):
        true_global_step = py_utils.GetOrCreateGlobalStepVar()
        with tf.ops.colocate_with(true_global_step):
          if self.params.defer_global_step_update:
            # NOTE(review): here the variable itself is used and no new op is
            # created under the mask-update control dependency.
            increment_global_steps = true_global_step
          else:
            increment_global_steps = tf.assign_add(true_global_step, 1)
        # When this model keeps its own step variable distinct from the true
        # global step, increment it as well, colocated with that variable.
        if self._global_step_var != true_global_step:
          with tf.ops.colocate_with(self._global_step_var):
            increment_global_steps = tf.group(
                increment_global_steps, tf.assign_add(self._global_step_var, 1))
        train_ops['global_step'] = increment_global_steps

    # If we are using Tpu Embeddings, generate the monolithic send
    # gradient op.
    if tf.get_collection(py_utils.TPU_EMBEDDING):
      tpu_embedding = tf.get_collection(py_utils.TPU_EMBEDDING)[0]
      sparse_grads = (
          tpu_embedding_gradient.get_gradients_through_dummy_table_variables(
              tpu_embedding))
      tpu_embedding_send_gradient_op = tpu_embedding.generate_send_gradients_op(
          sparse_grads, py_utils.GetGlobalStep())
      train_ops['tpu_embedding'] = tpu_embedding_send_gradient_op

      tpu_embedding_summary_tensors = tf.get_collection(
          py_utils.TPU_EMBEDDING_SUMMARY_TENSORS)
      if add_summary:
        for name, value, weight in tpu_embedding_summary_tensors:
          self.AddEvalMetric(name, value, weight, raise_if_already_added=False)

    # Every returned entry must be a real op.
    for op_name, op in train_ops.items():
      assert op is not None, op_name
    return train_ops
    def CreateTpuEnqueueOps(self):
        """Create the host-side enqueue ops.

        This should be called in an outer non-TPU context.

        Builds one infeed queue per infeed host (every TPU host when
        `p.use_per_host_infeed` is set, otherwise a single host) and appends
        the queues to `self._tpu_queues`. Sets `self._tpu_infeed_op` to a list
        holding `p.tpu_infeed_parallelism` references to one grouped enqueue
        op.
        """
        assert not self._tpu_queues, (
            'CreateTpuEnqueueOps should only be called '
            'once.')
        self._tpu_queues = []
        p = self.params
        cluster = self.cluster
        num_tpu_hosts = cluster.num_tpu_hosts
        # Validate before dividing by num_tpu_hosts; otherwise a zero host
        # count surfaces as a ZeroDivisionError and this message is never seen.
        assert num_tpu_hosts > 0, ('num_tpu_hosts: %d' % num_tpu_hosts)
        num_cores_per_host = cluster.total_worker_devices // num_tpu_hosts
        tf.logging.info(
            'CreateTpuEnqueueOps num_splits_per_client={} '
            'num_devices_per_split={} num_tpu_hosts={} use_per_host_infeed={}'.
            format(cluster.num_splits_per_client,
                   cluster.num_devices_per_split, num_tpu_hosts,
                   p.use_per_host_infeed))

        if (cluster.num_devices_per_split > num_cores_per_host
                and p.use_per_host_infeed):
            tf.logging.fatal(
                'Doesn\'t support per host infeed mode when '
                'num_devices_per_split({}) > num_cores_per_host({})'.format(
                    cluster.num_devices_per_split, num_cores_per_host))
        num_infeed_hosts = num_tpu_hosts if p.use_per_host_infeed else 1

        # Number of shards each infeed queue is split into.
        shards = (cluster.total_worker_devices //
                  num_infeed_hosts) // cluster.num_devices_per_split
        tf.logging.info('shards {}'.format(shards))

        input_ops_list = []
        tpu_embedding_collection = tf.get_collection(py_utils.TPU_EMBEDDING)
        tpu_embedding = (tpu_embedding_collection[0]
                         if tpu_embedding_collection else None)

        if num_tpu_hosts > 1 and tpu_embedding is not None:
            if not p.use_per_host_infeed:
                tf.logging.fatal(
                    'TPU Embedding must be used with per_host_infeed with multiple '
                    'TPU host topologies.')

        tpu_emb_input_keys = (list(tpu_embedding.feature_to_config_dict.keys())
                              if tpu_embedding is not None else [])
        tf.logging.info('tpu_emb_input_keys: %r', tpu_emb_input_keys)
        tf.logging.info('num_infeed_hosts: %d', num_infeed_hosts)

        for task_id in range(num_infeed_hosts):
            host_device = '/task:{}/device:CPU:0'.format(task_id)
            with tf.device(host_device):
                self._batch = self.GetPreprocessedInputBatch()
                if isinstance(self._batch, py_utils.NestedMap):
                    # Hack: bucket_keys and xxx.bucket_keys are not needed on TPU.
                    # Note that when MultiTaskData is used, bucket_keys will be at the
                    # second level of the dictionary.
                    self._batch = self._batch.FilterKeyVal(
                        lambda k, _: not k.endswith('bucket_keys'))
                tf.logging.info('host_device: %s, batch: %r', host_device,
                                self._batch)

                for k, x in self._batch.FlattenItems():
                    assert x.shape.is_fully_defined(), (
                        'Shape must be fully defined: %s: %s' % (k, x))
                    # TODO(cwhipkey): if it's a string (or other type not supported on
                    # TPU), drop it from feeding and on the other end add in an op that
                    # fails if used.
                shapes = self._batch.Transform(lambda x: x.shape).Flatten()
                dtypes = self._batch.Transform(lambda x: x.dtype).Flatten()

                tf.logging.info('host_device: %s infeed shapes: %r',
                                host_device, shapes)
                tf.logging.info('host_device: %s infeed dtypes: %r',
                                host_device, dtypes)

                if p.use_partitioned_infeed_queue:
                    device_assignment = py_utils.GetTpuDeviceAssignment()

                    host_device = device_assignment.host_device(
                        replica=0, job=tf.flags.FLAGS.tf_master)
                    host_id = int(
                        host_device.split('/task:')[1].split('/device:')[0])
                    tf.logging.info('host_id: {} host_device: {}'.format(
                        host_id, host_device))
                    # Partition only the leading dimension of every input.
                    q = tpu_feed._PartitionedInfeedQueue(  # pylint: disable=protected-access
                        number_of_tuple_elements=len(dtypes),
                        device_assignment=device_assignment,
                        host_id=host_id,
                        input_partition_dims=[[p.num_partitions] + [1] *
                                              (len(s) - 1) for s in shapes],
                        tuple_types=dtypes,
                        tuple_shapes=shapes)
                else:
                    q = tpu_feed.InfeedQueue(tuple_types=dtypes,
                                             tuple_shapes=shapes)
                    assert shards is not None
                    q.set_number_of_shards(shards)

                self._tpu_queues.append(q)

                if p.use_partitioned_infeed_queue:
                    input_ops = q.generate_enqueue_ops([self._batch.Flatten()])
                elif p.use_per_host_infeed:
                    # TODO(ylc/zhifengc): Add this to a policy module and test it.
                    def TPUOrdinalFunction(shard_index_in_host):
                        device_assignment = py_utils.GetTpuDeviceAssignment()
                        if device_assignment:
                            # We put both enqueue/dequeue ops at core 0 in each replica.
                            replica = device_assignment.lookup_replicas(
                                task_id, 0)[shard_index_in_host]  # pylint: disable=cell-var-from-loop
                            return device_assignment.tpu_ordinal(
                                replica=replica)
                        else:
                            return shard_index_in_host

                    input_ops = q.split_inputs_and_generate_enqueue_ops(
                        self._batch.Flatten(),
                        placement_function=lambda x: host_device,  # pylint: disable=cell-var-from-loop
                        tpu_ordinal_function=TPUOrdinalFunction)
                else:
                    input_ops = q.split_inputs_and_generate_enqueue_ops(
                        self._batch.Flatten(),
                        device_assignment=py_utils.GetTpuDeviceAssignment())
                input_ops_list += input_ops

        tf.logging.info('input_ops_list %s', input_ops_list)
        grouped_infeed_op = tf.group(*input_ops_list)
        # One reference to the same grouped op per unit of infeed parallelism
        # (presumably one per infeed-driving thread).
        self._tpu_infeed_op = [grouped_infeed_op] * p.tpu_infeed_parallelism
Example #24
0
    def _CreateLayerVariables(self):
        """Creates layer variables and wires up the shared TPUEmbedding object.

        On TPU only: records which input features come from sequence tables,
        builds the TPUEmbedding API object if no earlier layer has registered
        one (it is a singleton in the `py_utils.TPU_EMBEDDING` collection), and
        exposes the embedding dummy variables as this layer's private
        vars and theta.
        """
        super()._CreateLayerVariables()
        p = self.params

        def _BuildTpuEmbeddingApi():
            """Builds the TPUEmbedding object and registers it in collections."""
            load_ops = []
            retrieve_ops = []

            cores_per_replica = self.cluster.params.worker.tpus_per_replica
            batch_size_global = (p.batch_size *
                                 self.cluster.num_splits_per_client)
            tables_by_name = {}
            features_by_name = {}
            for emb_table in self.tables:
                tables_by_name[emb_table.table_name] = emb_table.table_config
                load_ops += emb_table.load_op_list
                retrieve_ops += emb_table.retrieve_op_list
                for feature_name in emb_table.input_keys:
                    features_by_name[feature_name] = (
                        tpu_embedding_lib.FeatureConfig(
                            emb_table.table_name,
                            max_sequence_length=emb_table.max_sequence_length))

            device_cfg = tpu_embedding_lib.DeviceConfig(
                num_cores=cores_per_replica,
                num_hosts=p.tables[0].num_tpu_hosts,
                job_name=self.cluster.params.worker.name)
            embedding_api = tpu_embedding_lib.TPUEmbedding(
                tables_by_name,
                features_by_name,
                batch_size_global,
                tpu_embedding_lib.TRAINING,
                master=None,
                pipeline_execution_with_tensor_core=(
                    p.pipeline_execution_with_tensor_core),
                partition_strategy=p.partition_strategy,
                device_config=device_cfg)

            # Dummy variables are created outside any function-building graph.
            with tf.init_scope():
                dummy_vars, dummy_vars_init = (
                    tpu_embedding_gradient.create_dummy_table_variables(
                        embedding_api))
            load_ops += [dummy_vars_init]

            tf.add_to_collection(py_utils.TPU_EMBEDDING, embedding_api)
            tf.add_to_collection(py_utils.TPU_EMBEDDING_DUMMY_VARS, dummy_vars)
            tf.add_to_collection(py_utils.TPU_EMBEDDING_LOAD_OPS, load_ops)
            tf.add_to_collection(py_utils.TPU_EMBEDDING_RETRIEVE_OPS,
                                 retrieve_ops)

        if not py_utils.use_tpu():
            return

        # Features fed from tables with a positive max_sequence_length are
        # tracked as "sequence embeddings".
        self._sequence_features = {
            feature_name: True
            for emb_table in self.tables
            for feature_name in emb_table.input_keys
            if emb_table.max_sequence_length > 0
        }

        # Multiple TPUEmbeddingLayers can be created, but they all share one
        # TPUEmbedding API object; only the first layer builds it.
        registered = tf.get_collection(py_utils.TPU_EMBEDDING)
        assert len(registered) <= 1
        if registered:
            tf.logging.info(
                'TPUEmbedding API singleton already exists, reusing')
        else:
            _BuildTpuEmbeddingApi()

        self._tpu_embedding = tf.get_collection(py_utils.TPU_EMBEDDING)[0]
        self._dummy_variables = tf.get_collection(
            py_utils.TPU_EMBEDDING_DUMMY_VARS)[0]
        for var_name, var in self._dummy_variables.items():
            self._private_vars[var_name] = var
            self._private_theta[var_name] = var
Example #25
0
    def CreateTpuFeeds(self):
        """Creates the TPU infeed queue from preprocessed batch.

        Builds one infeed queue per infeed host (every TPU host when
        `p.use_per_host_infeed` is set, otherwise a single host), plus TPU
        embedding enqueue ops when a TPU embedding is registered. The grouped
        enqueue op is added `p.tpu_infeed_parallelism` times to
        `py_utils.ENQUEUE_OPS` and stored on `self.tpu_infeed_op`.

        Returns:
          The last host's batch structure packed with the tensors dequeued
          from the first queue on TPU core 0.
        """
        p = self.params
        cluster = self.cluster
        num_tpu_hosts = cluster.num_tpu_hosts
        # Validate before dividing by num_tpu_hosts; otherwise a zero host
        # count surfaces as a ZeroDivisionError and this message is never seen.
        assert num_tpu_hosts > 0, ('num_tpu_hosts: %d' % num_tpu_hosts)
        num_cores_per_host = cluster.total_worker_devices // num_tpu_hosts
        tf.logging.info('num_cores_per_host {}'.format(num_cores_per_host))
        tf.logging.info('num_devices_per_split {}'.format(
            cluster.num_devices_per_split))

        if (cluster.num_devices_per_split > num_cores_per_host
                and p.use_per_host_infeed):
            tf.logging.fatal(
                'Doesn\'t support per host infeed mode when '
                'num_devices_per_split({}) > num_cores_per_host({})'.format(
                    cluster.num_devices_per_split, num_cores_per_host))
        num_infeed_hosts = num_tpu_hosts if p.use_per_host_infeed else 1

        with py_utils.outside_all_rewrites():
            assert py_utils.use_tpu()
            assert not self._made_tpu_infeed

            shards = tpu_function.get_tpu_context(
            ).number_of_shards // num_infeed_hosts
            input_ops_list = []
            queues = []
            tpu_embedding_collection = tf.get_collection(
                py_utils.TPU_EMBEDDING)
            tpu_embedding = (tpu_embedding_collection[0]
                             if tpu_embedding_collection else None)

            tpu_emb_input_keys = (list(
                tpu_embedding.feature_to_config_dict.keys())
                                  if tpu_embedding is not None else [])
            tf.logging.info('tpu_emb_input_keys: %r', tpu_emb_input_keys)

            batch = None
            for task_id in range(num_infeed_hosts):
                host_device = '/task:{}/device:CPU:0'.format(task_id)
                with tf.device(host_device):
                    batch = self.GetPreprocessedInputBatch()
                    if 'bucket_keys' in batch:
                        # Hack: bucket_keys are not needed on TPU.
                        del batch['bucket_keys']
                    tf.logging.info('host_device: %s, batch: %r', host_device,
                                    batch)

                    if tpu_embedding is not None:
                        # Core count as seen by the TPU embedding; kept distinct
                        # from the cluster-derived num_cores_per_host above.
                        emb_cores_per_host = tpu_embedding.num_cores_per_host
                        enqueue_dict_per_core = [
                            {} for _ in range(emb_cores_per_host)
                        ]
                        for key in tpu_emb_input_keys:
                            feat = batch[key]
                            tpu_emb_feat_splitted = tf.split(
                                feat, emb_cores_per_host)
                            for core, split in enumerate(
                                    tpu_emb_feat_splitted):
                                # Dense to sparse. Note the assumption of a padding id.
                                sample_indices = tf.where(
                                    tf.not_equal(split, -1))
                                embedding_indices = tf.gather_nd(
                                    split, sample_indices)
                                enqueue_data = tpu_embedding_lib.EnqueueData(
                                    embedding_indices, sample_indices)
                                enqueue_dict_per_core[core][key] = enqueue_data
                        input_ops_list += tpu_embedding.generate_enqueue_ops(
                            enqueue_dict_per_core)

                    for k, x in batch.FlattenItems():
                        assert x.shape.is_fully_defined(), (
                            'Shape must be fully defined: %s: %s' % (k, x))
                        # TODO(cwhipkey): if it's a string (or other type not supported on
                        # TPU), drop it from feeding and on the other end add in an op that
                        # fails if used.
                    shapes = batch.Transform(lambda x: x.shape).Flatten()
                    dtypes = batch.Transform(lambda x: x.dtype).Flatten()
                    tf.logging.info('host_device: %s infeed shapes: %r',
                                    host_device, shapes)
                    tf.logging.info('host_device: %s infeed dtypes: %r',
                                    host_device, dtypes)
                    q = tpu_feed.InfeedQueue(tuple_types=dtypes,
                                             tuple_shapes=shapes)
                    queues.append(q)
                    assert shards is not None
                    q.set_number_of_shards(shards)

                    if p.use_per_host_infeed:

                        # TODO(ylc/zhifengc): Add this to a policy module and test it.
                        def TPUOrdinalFunction(shard_index_in_host):
                            device_assignment = py_utils.GetTpuDeviceAssignment(
                            )
                            if device_assignment:
                                # We put both enqueue/dequeue ops at core 0 in each replica.
                                replica = device_assignment.lookup_replicas(
                                    task_id, 0)[shard_index_in_host]  # pylint: disable=cell-var-from-loop
                                return device_assignment.tpu_ordinal(
                                    replica=replica)
                            else:
                                return shard_index_in_host

                        input_ops = q.split_inputs_and_generate_enqueue_ops(
                            batch.Flatten(),
                            placement_function=lambda x: host_device,  # pylint: disable=cell-var-from-loop
                            tpu_ordinal_function=TPUOrdinalFunction)
                    else:
                        input_ops = q.split_inputs_and_generate_enqueue_ops(
                            batch.Flatten(),
                            device_assignment=py_utils.GetTpuDeviceAssignment(
                            ))

                    input_ops_list += input_ops
            tf.logging.info('input_ops_list %s', input_ops_list)
            tpu_infeed_op = tf.group(*input_ops_list)
        self._made_tpu_infeed = True
        # Let trainer.py use multiple threads to drive the infeed op.
        for _ in range(p.tpu_infeed_parallelism):
            tf.add_to_collection(py_utils.ENQUEUE_OPS, tpu_infeed_op)

        # For executor-driven multiple programs, we need more fine-grained
        # access rather than using a single global graph collection.
        self.tpu_infeed_op = tpu_infeed_op

        with tf.device(tf.tpu.core(0)):
            tensors = queues[0].generate_dequeue_op()
        return batch.Pack(tensors)
Example #26
0
  def _BPropForVariables(self, vmap):
    """Constructs the backward graph.

    Builds one train op per learner and collects the batch-norm, weight-mask
    and post-training-step update ops. It then creates the global-step
    increment(s) inside a `tf.control_dependencies` scope over all of those
    ops. When TPU embedding activations are registered, it adds the TPU
    embedding send-gradient op. The grouped result is stored in
    `self._train_op`.

    Args:
      vmap: variables forwarded to each learner's `Apply()` (presumably a
        NestedMap of the variables to train -- TODO confirm against callers).

    Raises:
      ValueError: if a learner's loss name is missing from `self._metrics`, or
        if any collected train op is None.
    """

    def _CheckNoneOps(ops):
      # Report a None op by name. This has to run before the ops are handed to
      # tf.control_dependencies / tf.group, which reject None without saying
      # which entry was missing.
      for op_name, op in six.iteritems(ops):
        if op is None:
          raise ValueError('Train op %s is None.' % op_name)

    bprop_variable_filters = self.input_generator.GetBpropVariableFilters()
    # Only compute the mask if the variable filters are not empty.
    if bprop_variable_filters != [''] * len(bprop_variable_filters):
      self._ComputeGradientMask(bprop_variable_filters)
    train_ops = {}  # mapping from op name to op.
    gradient_mask = None
    if self._per_input_gradient_mask:
      # TODO(neerajgaur): Change this to use source_selected from input_batch.
      onehot = self.input_generator.GetInputSourceOneHot()
      gradient_mask = {
          k: tf.tensordot(v, onehot, 1)
          for k, v in six.iteritems(self._per_input_gradient_mask)
      }
    all_losses = []
    for optimization in self.learners:
      loss_name = optimization.params.name
      metric = self._metrics.get(loss_name, None)
      if metric is None:
        raise ValueError('Loss %s not found in metrics %s' %
                         (loss_name, list(self._metrics.keys())))
      # Element 0 of each metrics entry is the loss tensor to optimize.
      loss = metric[0]
      all_losses.append(loss)
      train_ops['train/%s' % loss_name], eval_metrics = optimization.Apply(
          loss,
          vmap,
          gradient_mask=gradient_mask,
          gradient_adjuster=self.AdjustGradients)
      for key, (value, weight) in six.iteritems(eval_metrics):
        self.AddEvalMetric(key + '/' + loss_name, value, weight)

    # Only the batch-norm updates that the learners' losses depend on.
    relevant_bn_updates, _ = py_utils.FindRelevantBatchNormUpdates(
        all_losses, tf.get_collection(py_utils.BATCH_NORM_UPDATES))
    train_ops['bn_updates'] = relevant_bn_updates

    # Get the op to update the weight masks and thresholds
    train_ops['mask_updates'] = self._GetMaskUpdateOp()

    # Post training step update.
    train_ops['post_step'] = self.PostTrainingStepUpdate(self.global_step)

    # Validate now: everything collected so far becomes a control dependency
    # of the global-step increment below.
    _CheckNoneOps(train_ops)

    with tf.control_dependencies(tf.nest.flatten(train_ops)):
      true_global_step = py_utils.GetOrCreateGlobalStepVar()
      with tf.colocate_with(true_global_step):
        increment_global_steps = tf.assign_add(true_global_step, 1)
      # When the task keeps its own step variable, advance it alongside the
      # true global step.
      if self._global_step_var != true_global_step:
        with tf.colocate_with(self._global_step_var):
          increment_global_steps = tf.group(
              increment_global_steps, tf.assign_add(self._global_step_var, 1))
      train_ops['global_step'] = increment_global_steps

    # If we are using Tpu Embeddings, generate the monolithic send
    # gradient op.
    tpu_embedding_activations = tf.get_collection(
        py_utils.TPU_EMBEDDING_ACTIVATIONS)
    if tpu_embedding_activations:
      tpu_embedding_activations_dict = tpu_embedding_activations[0]
      tpu_embedding = tf.get_collection(py_utils.TPU_EMBEDDING)[0]
      tpu_embedding_send_gradient_op = py_utils.ComputeTpuEmbeddingGradients(
          self.loss, tpu_embedding_activations_dict, tpu_embedding)
      train_ops['tpu_embedding'] = tpu_embedding_send_gradient_op

    # Covers the entries added after the first check ('global_step' and
    # optionally 'tpu_embedding') before the final tf.group.
    _CheckNoneOps(train_ops)

    # TODO(rpang): try to structure _train_op as:
    #   tf.cond(skip_step, <only update skip stats>, <all updates>)
    # so that we skip all other updates when a step is skipped.
    self._train_op = tf.group(*tf.nest.flatten(train_ops), name='bprop')
Example #27
0
  def _BPropGenTrainOps(self, vmap, metrics=None, add_summary=True):
    """Populates the train_ops dictionary in a backwards pass.

    The ops are built as a chain:
      1. the learner train ops and relevant batch-norm updates are grouped;
      2. the post-training-step update is created under a control dependency
         on that group;
      3. the weight-mask update is created under a control dependency on the
         post-step op;
      4. the global-step increment(s) are created under a control dependency
         on the mask update.
    Note that `tf.control_dependencies` only affects ops created inside its
    scope. If `PostTrainingStepUpdate` / `_GetMaskUpdateOp` return pre-existing
    ops, those are not ordered by this chain.

    Args:
      vmap: variables forwarded to each learner's `Apply()` (presumably a
        NestedMap of the variables to train -- TODO confirm against callers).
      metrics: mapping from loss name to an indexable whose element 0 is the
        loss tensor. `self._metrics` is used when this is None (or otherwise
        falsy).
      add_summary: if True, each learner's eval metrics are registered via
        `AddEvalMetric`.

    Returns:
      A dict with 'mask_updates', 'global_step' and, when TPU embedding
      activations are registered, 'tpu_embedding'. The learner, batch-norm and
      post-step ops are not in the dict. They are reached only through the
      control-dependency chain described above.

    Raises:
      ValueError: if a learner's loss name is missing from `metrics`, or if any
        op produced along the way is None.
    """

    def _CheckNoneOps(ops):
      # Report a None op by name. This has to run before the ops are handed to
      # tf.group / tf.control_dependencies, which reject None without saying
      # which entry was missing.
      for op_name, op in ops.items():
        if op is None:
          raise ValueError('Train op %s is None.' % op_name)

    metrics = metrics or self._metrics

    bprop_variable_filters = self.input_generator.GetBpropVariableFilters()
    # Only compute the mask if the variable filters are not empty.
    if bprop_variable_filters != [''] * len(bprop_variable_filters):
      self._ComputeGradientMask(bprop_variable_filters)
    train_ops = {}  # mapping from op name to op.
    gradient_mask = None
    if self._per_input_gradient_mask:
      # TODO(neerajgaur): Change this to use source_selected from input_batch.
      onehot = self.input_generator.GetInputSourceOneHot()
      gradient_mask = {
          k: tf.tensordot(v, onehot, 1)
          for k, v in self._per_input_gradient_mask.items()
      }
    all_losses = []
    for optimization in self.learners:
      learner_name = optimization.params.name
      # A learner may optimize a loss whose name differs from its own.
      loss_name = optimization.params.loss_name or learner_name
      metric = metrics.get(loss_name, None)
      if metric is None:
        raise ValueError('Loss %s not found in metrics %s' %
                         (loss_name, list(metrics.keys())))
      loss = metric[0]
      all_losses.append(loss)
      train_ops['train/%s' % learner_name], eval_metrics = optimization.Apply(
          loss,
          vmap,
          gradient_mask=gradient_mask,
          gradient_adjuster=self.AdjustGradients)
      if add_summary:
        for key, (value, weight) in eval_metrics.items():
          self.AddEvalMetric(key + '/' + learner_name, value, weight)

    # Only the batch-norm updates that the learners' losses depend on.
    relevant_bn_updates, _ = py_utils.FindRelevantBatchNormUpdates(
        all_losses, tf.get_collection(py_utils.BATCH_NORM_UPDATES))
    train_ops['bn_updates'] = relevant_bn_updates

    # These entries are grouped below and then dropped from the returned dict,
    # so this is the only place they can be validated by name.
    _CheckNoneOps(train_ops)
    var_update_ops = [
        tf.group(*tf.nest.flatten(train_ops), name='var_update_ops')
    ]
    # Post training step update.
    with tf.control_dependencies(var_update_ops):
      post_step_op = self.PostTrainingStepUpdate(self.global_step)
    _CheckNoneOps({'post_step': post_step_op})

    # Start a fresh dict: the ops above are reached through the control
    # dependencies below rather than being returned directly.
    train_ops = {}
    with tf.control_dependencies([post_step_op]):
      # Get the op to update the weight masks and thresholds
      mask_update_op = self._GetMaskUpdateOp()
      train_ops['mask_updates'] = mask_update_op
      _CheckNoneOps(train_ops)
      with tf.control_dependencies([mask_update_op]):
        true_global_step = py_utils.GetOrCreateGlobalStepVar()
        with tf.ops.colocate_with(true_global_step):
          increment_global_steps = tf.assign_add(true_global_step, 1)
        # When the task keeps its own step variable, advance it alongside the
        # true global step.
        if self._global_step_var != true_global_step:
          with tf.ops.colocate_with(self._global_step_var):
            increment_global_steps = tf.group(
                increment_global_steps, tf.assign_add(self._global_step_var, 1))
        train_ops['global_step'] = increment_global_steps

    # If we are using Tpu Embeddings, generate the monolithic send
    # gradient op.
    # NOTE(review): this uses self.loss, not the losses taken from `metrics`;
    # confirm that is intended when a custom `metrics` is passed.
    tpu_embedding_activations = tf.get_collection(
        py_utils.TPU_EMBEDDING_ACTIVATIONS)
    if tpu_embedding_activations:
      tpu_embedding_activations_dict = tpu_embedding_activations[0]
      tpu_embedding = tf.get_collection(py_utils.TPU_EMBEDDING)[0]
      tpu_embedding_send_gradient_op = py_utils.ComputeTpuEmbeddingGradients(
          self.loss, tpu_embedding_activations_dict, tpu_embedding)
      train_ops['tpu_embedding'] = tpu_embedding_send_gradient_op

    _CheckNoneOps(train_ops)
    return train_ops
Example #28
0
def get_thresholds():
    """Return the items stored in the threshold graph collection.

    Returns:
      Whatever `tf.get_collection` yields for `_THRESHOLD_COLLECTION`.
    """
    collection_name = _THRESHOLD_COLLECTION
    thresholds = tf.get_collection(collection_name)
    return thresholds
Example #29
0
def get_weights():
    """Return the items stored in the weight graph collection.

    Returns:
      Whatever `tf.get_collection` yields for `_WEIGHT_COLLECTION`.
    """
    collection_name = _WEIGHT_COLLECTION
    weights = tf.get_collection(collection_name)
    return weights