def testCreateVariableSkipLpRegularization(self):
  # With skip_lp_regularization=True, both the weight and the bias must be
  # registered in the SKIP_LP_REGULARIZATION collection.
  layer = TestLayer.Params().Set(
      name='test', skip_lp_regularization=True).Instantiate()
  skip_lp_vars = tf.get_collection(py_utils.SKIP_LP_REGULARIZATION)
  for var in (layer.vars.w, layer.vars.b):
    self.assertIn(var, skip_lp_vars)
def testCreateVariable(self):
  # Default params: variables are named '<layer>/<var>/var:0' and only the
  # bias ends up in the SKIP_LP_REGULARIZATION collection.
  layer = TestLayer.Params().Set(name='test').Instantiate()
  weight, bias = layer.vars.w, layer.vars.b
  self.assertEqual('test/w/var:0', weight.name)
  self.assertEqual('test/b/var:0', bias.name)
  skip_lp_vars = tf.get_collection(py_utils.SKIP_LP_REGULARIZATION)
  self.assertNotIn(weight, skip_lp_vars)
  # 'b' always skips Lp regularization.
  self.assertIn(bias, skip_lp_vars)
def _testDecoderFPropFloatHelper(self,
                                 func_inline=False,
                                 num_decoder_layers=1,
                                 target_seq_len=5,
                                 residual_start=0):
  """Computes decoder from params and computes loss with random inputs.

  Builds a decoder from `self._DecoderParams`, runs FProp on a seeded random
  encoder output plus fixed target tensors, evaluates the merged summaries,
  and returns the evaluated loss.

  Args:
    func_inline: Value for OptimizerOptions.do_function_inlining in the
      session config.
    num_decoder_layers: Assigned to the decoder params' `rnn_layers`.
    target_seq_len: Assigned to the decoder params' `target_seq_len`.
    residual_start: Assigned to the decoder params' `residual_start`.

  Returns:
    The value of `metrics['loss'][0]` fetched via `sess.run`.
  """
  cluster = cluster_factory.ForTestingWorker(add_summary=True)
  config = tf.ConfigProto(
      graph_options=tf.GraphOptions(
          optimizer_options=tf.OptimizerOptions(
              do_function_inlining=func_inline)))
  with cluster, self.session(use_gpu=False, config=config) as sess:
    # Graph-level seed; together with the op-level seed below this makes the
    # random encoder output reproducible. Statement order matters here.
    tf.set_random_seed(8372749040)
    vn_config = py_utils.VariationalNoiseParams(None, False, False)
    p = self._DecoderParams(vn_config)
    p.rnn_layers = num_decoder_layers
    p.residual_start = residual_start
    p.target_seq_len = target_seq_len
    dec = p.Instantiate()
    src_seq_len = 5
    # Encoder output of shape [src_seq_len, 2, 8].
    src_enc = tf.random_normal([src_seq_len, 2, 8], seed=9283748)
    # Padding of shape [5, 2]; 1.0 marks padded positions.
    src_enc_padding = tf.constant(
        [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
        dtype=tf.float32)
    encoder_outputs = py_utils.NestedMap(
        encoded=src_enc, padding=src_enc_padding)
    # The target tensors below are written as [5, 4] literals and transposed
    # to [4, 5].
    target_ids = tf.transpose(
        tf.constant([[0, 1, 2, 3], [1, 2, 3, 4], [10, 11, 12, 15], [5, 6, 7, 8],
                     [10, 5, 2, 5]],
                    dtype=tf.int32))
    target_labels = tf.transpose(
        tf.constant([[0, 1, 2, 3], [1, 2, 3, 4], [10, 11, 12, 13], [5, 7, 8, 10],
                     [10, 5, 2, 4]],
                    dtype=tf.int32))
    target_paddings = tf.transpose(
        tf.constant([[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0],
                     [1, 1, 1, 1]],
                    dtype=tf.float32))
    target_transcripts = tf.constant(['abcd', 'bcde', 'klmp', 'fghi', 'kfcf'])
    # Weights are the complement of the paddings.
    target_weights = 1.0 - target_paddings
    targets = py_utils.NestedMap({
        'ids': target_ids,
        'labels': target_labels,
        'weights': target_weights,
        'paddings': target_paddings,
        'transcripts': target_transcripts,
    })
    metrics = dec.FPropDefaultTheta(encoder_outputs, targets).metrics
    loss = metrics['loss'][0]
    correct_predicts = metrics['fraction_of_correct_next_step_preds'][0]
    summaries = tf.summary.merge(tf.get_collection(tf.GraphKeys.SUMMARIES))

    tf.global_variables_initializer().run()
    # correct_predicts is fetched alongside the loss but its value is unused.
    loss_v, _ = sess.run([loss, correct_predicts])

    # Evaluating the merged summaries checks that they can be computed.
    summaries.eval()

  return loss_v
def _TpuEmbLookup(self) -> Dict[str, tf.Tensor]:
  """TPU Embedding lookup.

  Fetches the activations from `self._tpu_embedding` and records them, keyed
  by the current task call scope, in a single dict held in the
  py_utils.TPU_EMBEDDING_ACTIVATIONS graph collection. The dict is created on
  the first call and updated in place on later calls.

  Returns:
    A py_utils.NestedMap from feature name to activation. Features present in
    `self._sequence_features` are passed through unchanged; every other
    feature gets an extra size-1 dimension inserted at axis 1.
  """
  activations = self._tpu_embedding.get_activations()
  task = py_utils.GetTaskCallScope()

  # Either empty (first call) or holding exactly one task -> activations dict.
  registered = tf.get_collection(py_utils.TPU_EMBEDDING_ACTIVATIONS)
  if registered:
    # Subsequent call: update the existing dictionary in place.
    registered[0][task] = activations
  else:
    tf.add_to_collection(py_utils.TPU_EMBEDDING_ACTIVATIONS,
                         {task: activations})

  ret = py_utils.NestedMap()
  for feature_name, act in activations.items():
    if feature_name not in self._sequence_features:
      # Non-sequence embeddings, we fill the "time" dimension with 1.
      act = tf.expand_dims(act, axis=[1])
    ret[feature_name] = act
  return ret
def testBatchNormUpdatesWithUpdateUseGlobalStatsForTraining(self):
  # Builds a BatchNormLayer with use_moving_avg_in_training=True, checks two
  # golden reductions of its output, and expects FindRelevantBatchNormUpdates
  # to return two empty lists for that output.
  # The seeds and the order of random calls below determine the golden values.
  tf.random.set_seed(398847392)
  np.random.seed(12345)
  params = layers.BatchNormLayer.Params()
  params.name = 'bn'
  params.dim = 3
  params.use_moving_avg_in_training = True
  params.params_init = py_utils.WeightInit.Gaussian(0.1)

  bn_layer = layers.BatchNormLayer(params)
  # Input of shape [2, 8, 3] with an all-zero padding of shape [2, 8, 1].
  in_padding1 = tf.zeros([2, 8, 1], dtype=tf.float32)
  bn_in1 = tf.constant(
      np.random.normal(0.1, 0.5, [2, 8, 3]), dtype=tf.float32)

  bn_out = bn_layer.FPropDefaultTheta(bn_in1, in_padding1)
  sig1 = tf.reduce_sum(bn_out)
  sig2 = tf.reduce_sum(bn_out * bn_out)

  # IMPORTANT: Keep these values consistent with the corresponding
  # test in layers_test.py
  self.assertAllClose(2.6575434, sig1, atol=1e-5)
  self.assertAllClose(15.473802, sig2)

  updates_collection = tf.get_collection(py_utils.BATCH_NORM_UPDATES)
  l1, l2 = py_utils.FindRelevantBatchNormUpdates(bn_out, updates_collection)
  self.assertEqual(l1, [])
  self.assertEqual(l2, [])
def _CreateLayerVariables(self):
  """Creates the TPUEmbedding singleton (TPU only) and related bookkeeping.

  Always resets `self._sequence_features`. When running on TPU, it also:
  - collects per-table configs and load/retrieve ops;
  - adds the load and retrieve op lists to their graph collections;
  - reuses the TPUEmbedding object in py_utils.TPU_EMBEDDING if one exists,
    or creates and registers a new one.
  """
  super()._CreateLayerVariables()
  p = self.params

  load_op_list = []
  retrieve_op_list = []
  # At the feature level, track which are associated
  # with "sequence embeddings".
  self._sequence_features = {}

  if not py_utils.use_tpu():
    return

  num_cores = self.cluster.params.worker.tpus_per_replica
  global_batch_size = p.batch_size * self.cluster.num_splits_per_client

  table_to_config_dict = {}
  feature_to_config_dict = {}
  for table in self.tables:
    table_to_config_dict[table.table_name] = table.table_config
    load_op_list.extend(table.load_op_list)
    retrieve_op_list.extend(table.retrieve_op_list)
    for feature in table.input_keys:
      if table.max_sequence_length > 0:
        self._sequence_features[feature] = True
      feature_to_config_dict[feature] = tpu_embedding_lib.FeatureConfig(
          table.table_name, max_sequence_length=table.max_sequence_length)

  tf.logging.info('adding load and retrieve ops to collection.')
  tf.add_to_collection(py_utils.TPU_EMBEDDING_LOAD_OPS, load_op_list)
  tf.add_to_collection(py_utils.TPU_EMBEDDING_RETRIEVE_OPS, retrieve_op_list)

  existing = tf.get_collection(py_utils.TPU_EMBEDDING)
  assert len(existing) <= 1
  if len(existing) == 1:
    tf.logging.info('TPUEmbedding API singleton already exists, reusing')
    self._tpu_embedding = existing[0]
    return

  device_config = tpu_embedding_lib.DeviceConfig(
      num_cores=num_cores,
      num_hosts=p.tables[0].num_tpu_hosts,
      job_name=self.cluster.params.worker.name)
  self._tpu_embedding = tpu_embedding_lib.TPUEmbedding(
      table_to_config_dict,
      feature_to_config_dict,
      global_batch_size,
      tpu_embedding_lib.TRAINING,
      master=None,
      pipeline_execution_with_tensor_core=(
          p.pipeline_execution_with_tensor_core),
      partition_strategy=p.partition_strategy,
      device_config=device_config)
  tf.add_to_collection(py_utils.TPU_EMBEDDING, self._tpu_embedding)
def CreateTpuEmbeddingEnqueueOps(self):
  """Creates the TpuEmbedding enqueue ops on the host.

  Note that this must be called after the instantiation of the monolithic
  TPUEmbeddingLayer.

  For each infeed host (one per TPU host when `p.use_per_host_infeed`,
  otherwise one), every TPU-embedding input feature of `self._batch` is split
  across the host's cores. Each split is converted from dense ids to
  (indices, ids) form by dropping entries equal to -1. The enqueue ops
  generated for all hosts are grouped into a single op and appended to
  `self._tpu_infeed_op`.

  If no TPUEmbedding is registered in the py_utils.TPU_EMBEDDING collection,
  this only logs and returns.
  """
  p = self.params
  cluster = self.cluster
  num_tpu_hosts = cluster.num_tpu_hosts
  num_infeed_hosts = num_tpu_hosts if p.use_per_host_infeed else 1

  tpu_embedding_collection = tf.get_collection(py_utils.TPU_EMBEDDING)
  tpu_embedding = (
      tpu_embedding_collection[0] if tpu_embedding_collection else None)

  if (num_tpu_hosts > 1 and tpu_embedding is not None and
      not p.use_per_host_infeed):
    tf.logging.fatal(
        'TPU Embedding must be used with per_host_infeed with multiple '
        'TPU host topologies.')
  tpu_emb_input_keys = (
      list(tpu_embedding.feature_to_config_dict.keys())
      if tpu_embedding is not None else [])
  tf.logging.info('tpu_emb_input_keys: %r', tpu_emb_input_keys)
  if tpu_embedding is None:
    return

  # Invariant across hosts; read once.
  num_cores_per_host = tpu_embedding.num_cores_per_host
  enqueue_ops = []
  for task_id in range(num_infeed_hosts):
    host_device = '/task:{}/device:CPU:0'.format(task_id)
    with tf.device(host_device):
      if isinstance(self._batch, py_utils.NestedMap):
        # Hack: bucket_keys and xxx.bucket_keys are not needed on TPU.
        # Note that when MultiTaskData is used, bucket_keys will be at the
        # second level of the dictionary.
        self._batch = self._batch.FilterKeyVal(
            lambda k, _: not k.endswith('bucket_keys'))
      tf.logging.info('host_device: %s, batch: %r', host_device, self._batch)

      enqueue_dict_per_core = [{} for _ in range(num_cores_per_host)]
      for key in tpu_emb_input_keys:
        feat = self._batch[key]
        tpu_emb_feat_splitted = tf.split(feat, num_cores_per_host)
        for core, split in enumerate(tpu_emb_feat_splitted):
          # Dense to sparse. Note the assumption of a padding id (-1).
          sample_indices = tf.where(tf.not_equal(split, -1))
          embedding_indices = tf.gather_nd(split, sample_indices)
          enqueue_dict_per_core[core][key] = tpu_embedding_lib.EnqueueData(
              embedding_indices, sample_indices)
      enqueue_ops += tpu_embedding.generate_enqueue_ops(enqueue_dict_per_core)
  self._tpu_infeed_op.append(tf.group(*enqueue_ops))
def variables_for_ema(self):
  """Returns the variables of this task that should be tracked by EMA.

  Candidates are all trainable and moving-average variables in the graph.
  When `params.train.ema_decay_moving_vars` is set, the 'moving_vars'
  collection is added as well. The result is restricted to the variables
  owned by this layer (`self.vars`).

  Returns:
    A set of variables.
  """
  train_p = self.params.train
  selected = set(tf.trainable_variables()) | set(
      tf.moving_average_variables())
  if train_p.ema_decay_moving_vars:
    selected |= set(tf.get_collection('moving_vars'))
  # Keep only variables that belong to this layer.
  selected &= set(self.vars.Flatten())
  for variable in selected:
    tf.logging.debug('variables_for_ema: %s', variable.name)
  return selected
def Get(cls):
  """Returns the TpuEmbeddingCollection associated with the current graph.

  The instance lives in the `cls.GRAPH_COLLECTION_NAME` graph collection. It
  is created and registered on first use and reused afterwards.
  """
  found = tf.get_collection(cls.GRAPH_COLLECTION_NAME)
  assert len(found) <= 1
  if len(found) != 1:
    created = cls()
    tf.add_to_collection(cls.GRAPH_COLLECTION_NAME, created)
    return created
  tf.logging.info('TpuEmbeddingCollection singleton already exists, reusing')
  return found[0]
def __init__(self, decoder_type, *args, **kwargs):
  """Builds the decode graph for a decoder job.

  Args:
    decoder_type: Suffix of the job name; the job is named
      'decoder_' + decoder_type.
    *args: Forwarded to the base class constructor.
    **kwargs: Forwarded to the base class constructor.
  """
  super().__init__(*args, **kwargs)
  self._job_name = 'decoder_' + decoder_type
  # do_eval is set on the cluster params before the Cluster is built from them.
  self.params.cluster.do_eval = True
  self._cluster = cluster_factory.Cluster(self.params.cluster)
  self._decoder_dir = GetDecoderDir(self._logdir, self._job_name,
                                    self._model_task_name)
  tf.io.gfile.makedirs(self._decoder_dir)

  # Checkpoint to decode, if one is configured via
  # task.eval.load_checkpoint_from.
  self._decode_path = None
  # Multitask params doesn't have 'task'.
  if 'task' in self.params:
    self._decode_path = checkpointer.GetSpecificCheckpoint(
        self.params.task.eval.load_checkpoint_from)

  self._should_report_metrics = self._job_name.startswith(
      self._cluster.reporting_job)

  with self._graph.as_default(), tf.container(self._container_id):
    self._summary_writer = self._CreateSummaryWriter(self._decoder_dir)
    self._CreateTF2SummaryWriter(self._decoder_dir)
    with self._cluster, tf.device(
        self._cluster.GetPlacer()), self._TF2SummaryContext():
      self._model = self.params.Instantiate()
      self._params = self._model.params
      self._task = self._model.GetTask(self._model_task_name)
      # Note, different graphs are being constructed for different model
      # tasks, which may result in different node names being chosen.
      # Obviously, variable names has to be stay the same between train and
      # decode.
      cluster = self._cluster
      # Input pipeline ops are placed on the cluster's input device.
      with tf.device(cluster.input_device):
        input_batch = (
            self._task.input_generator.GetPreprocessedInputBatch())

      self._dec_output = self._task.Decode(input_batch)
      self._summary_op = tf.summary.merge_all()
      self.checkpointer = self._CreateCheckpointer(self._train_dir,
                                                   self._model)
    self._CreateTF2SummaryOps()
    self._initialize_tables = tf.tables_initializer()
    self._initialize_local_vars = tf.local_variables_initializer()
    # No queues are allowed for decoder models.
    self.enqueue_ops = tf.get_collection(py_utils.ENQUEUE_OPS)
    assert not self.enqueue_ops

  # Saves the graph def.
  self._WriteToLog(self.params.ToText(), self._decoder_dir, 'params.txt')
  # Only task 0 writes the graph definition.
  if self.params.cluster.task == 0:
    tf.io.write_graph(self._graph.as_graph_def(), self._decoder_dir,
                      '%s.pbtxt' % self._job_name)
def testConv2DLayerConstruction(self):
  # Instantiating a Conv2DLayerWithPadding should create exactly one variable,
  # 'conv/w/var:0', in the 'Conv2DLayerWithPadding_vars' collection.
  with self.session(use_gpu=True):
    tf.random.set_seed(398847392)
    np.random.seed(12345)
    conv_params = conv_layers.Conv2DLayerWithPadding.Params().Set(
        name='conv',
        filter_shape=[3, 3, 3, 32],
        filter_stride=[2, 2],
        params_init=py_utils.WeightInit.Gaussian(0.1))
    conv_params.Instantiate()
    created_names = [
        v.name for v in tf.get_collection('Conv2DLayerWithPadding_vars')
    ]
    self.assertEqual(['conv/w/var:0'], created_names)
def __init__(self, *args, **kwargs):
  """Builds the trainer graph and its summary writers.

  Args:
    *args: Forwarded to the base class constructor.
    **kwargs: Forwarded to the base class constructor.
  """
  super().__init__(*args, **kwargs)
  self._job_name = 'trainer'
  with self._graph.as_default(), tf.container(self._container_id):
    if self.params.cluster.task == 0:
      self._summary_writer = self._CreateSummaryWriter(self._train_dir)
      self._CreateTF2SummaryWriter(self._train_dir)
    else:
      self._summary_writer = None

    with self._cluster, tf.device(
        self._cluster.GetPlacer()), self._TF2SummaryContext():
      self._model = self.params.Instantiate()
      self._params = self._model.params
      self._model.ConstructFPropBPropGraph()

    # Per-task summary writers for task-probability summaries. This must run
    # after self._model is instantiated. Previously it ran first, so the
    # lookup always failed and the AttributeError was silently swallowed.
    # Only the attribute lookup is guarded, so unrelated AttributeErrors from
    # writer creation are no longer hidden.
    try:
      task_names = self._model.task_schedule.tasks
    except AttributeError:
      tf.logging.info('AttributeError. Expected for single task models.')
      task_names = []
    self._task_probs_summary_writers = []
    for task in task_names:
      path = os.path.join(self._train_dir, task)
      tf.io.gfile.makedirs(path)
      self._task_probs_summary_writers.append(self._CreateSummaryWriter(path))

    self._CreateTF2SummaryOps()
    self._initialize_tables = tf.tables_initializer()
    self._initialize_local_vars = tf.local_variables_initializer()
    self.enqueue_ops = tf.get_collection(py_utils.ENQUEUE_OPS)
    tf.logging.info('Trainer number of enqueue ops: %d', len(self.enqueue_ops))

  self._step_rate_tracker = summary_utils.StepRateTracker()

  # Saves the graph def.
  if self.params.cluster.task == 0:
    self._WriteToLog(self.params.ToText(), self._train_dir,
                     'trainer_params.txt')
    tf.io.write_graph(self._graph.as_graph_def(), self._train_dir,
                      'train.pbtxt')
  # Worker k waits k*(k+1)/2 * start_up_delay_steps steps before starting.
  worker_id = self.params.cluster.task
  self._start_up_delay_steps = (((worker_id + 1) * worker_id / 2) *
                                self.params.train.start_up_delay_steps)
def ApplyExponentialMovingAverage(self, ema):
  """Wraps `self.train_op` with an op updating exponential moving average.

  The update op returned by `ema.apply` is appended to
  `self._post_train_ops`.

  Args:
    ema: An exponential-moving-average object exposing `apply(var_list)`.

  Raises:
    ValueError: if called before the layer variables have been created.
  """
  if (self._create_variables_status !=
      base_layer._CreateLayerVariablesStatus.COMPLETED):  # pylint: disable=protected-access
    raise ValueError(
        'ApplyExponentialMovingAverage called before InstantiateVariables!')
  # TODO(rpang): raise an exception if this is called in the eval mode.
  p = self.params
  # We need to apply EMA to trainable and moving average variable of this
  # Task, not just bprop vars, so that we create a shadow
  # '/ExponentialMovingAverage' variable for every trainable and moving
  # average variable.
  all_vars = set(tf.trainable_variables()) | set(
      tf.moving_average_variables())
  if p.train.ema_decay_moving_vars:
    all_vars |= set(tf.get_collection('moving_vars'))
  all_vars &= set(self.vars.Flatten())
  # Variables hash by identity, so set iteration order can differ from run to
  # run. Sorting by name makes shadow-variable creation order, and therefore
  # the built graph, deterministic.
  ema_vars = sorted(all_vars, key=lambda v: v.name)
  for var in ema_vars:
    tf.logging.debug('ApplyExponentialMovingAverage: %s', var.name)
  with tf.name_scope('moving_average'):
    self._post_train_ops.append(ema.apply(ema_vars))
def AddToPruningCollections(weight,
                            mask,
                            threshold,
                            gradient=None,
                            old_weight=None,
                            old_old_weight=None):
  """Add mask, threshold, and weight vars to their respective collections.

  Nothing is added if `mask` (compared by identity) is already present in
  pruning.MASK_COLLECTION, so repeated calls for the same mask are no-ops.

  Args:
    weight: Weight variable, added to pruning.WEIGHT_COLLECTION.
    mask: Mask variable, added to pruning.MASK_COLLECTION.
    threshold: Threshold variable, added to pruning.THRESHOLD_COLLECTION.
    gradient: Optional weight gradient. When given, `old_weight` and
      `old_old_weight` must also be given.
    old_weight: The weight tensor one step before.
    old_old_weight: The weight tensor two steps before.
  """
  # Compare by identity. A plain `in` falls back to `==`, which for TF
  # variables/tensors can return an element-wise tensor rather than a bool
  # when TF2 tensor-equality semantics are active.
  if any(m is mask for m in tf.get_collection(pruning.MASK_COLLECTION)):
    return

  tf.add_to_collection(pruning.WEIGHT_COLLECTION, weight)
  tf.add_to_collection(pruning.MASK_COLLECTION, mask)
  tf.add_to_collection(pruning.THRESHOLD_COLLECTION, threshold)

  # Add gradient, old_weight, and old_old_weight to collections approximating
  # gradient and hessian, where old_weight is the weight tensor one step
  # before and old_old_weight is the weight tensor two steps before.
  if gradient is not None:
    assert old_weight is not None, 'old_weight is required with gradient'
    assert old_old_weight is not None, (
        'old_old_weight is required with gradient')
    tf.add_to_collection(pruning.WEIGHT_GRADIENT_COLLECTION, gradient)
    tf.add_to_collection(pruning.OLD_WEIGHT_COLLECTION, old_weight)
    tf.add_to_collection(pruning.OLD_OLD_WEIGHT_COLLECTION, old_old_weight)
def __init__(self, train_dir, model):
  """Initialize Checkpointer.

  Records the graph's global variables and builds an op reporting the
  uninitialized ones and an op initializing all of them, then creates the
  saver via `self._GetSaver()`.

  Args:
    train_dir: Training directory for saving checkpoints.
    model: Model.
  """
  self._train_dir = train_dir
  self._model = model
  self._params = model.params

  global_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
  self._vars = global_vars
  self._uninitialized_vars = tf.report_uninitialized_variables(global_vars)
  self._initialize_vars = tf.global_variables_initializer()

  self._save_path = os.path.join(train_dir, 'ckpt')
  self._model_tasks = model.tasks

  self._save_interval_seconds = self._params.train.save_interval_seconds
  # 0 means "no checkpoint scheduled yet".
  self._next_checkpoint_seconds = 0
  self._saver = self._GetSaver()
def __init__(self, train_dir, model, train_params=None, save_only=False):
  """Initialize Checkpointer.

  Args:
    train_dir: Training directory for saving checkpoints.
    model: A BaseModel instance or None.
    train_params: If specified, use these training params instead of those
      in the `model`.
    save_only: This checkpointer is only intended for saving checkpoints.
  """
  self._train_dir = train_dir
  self._save_only = save_only

  global_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
  self._vars = global_vars
  self._uninitialized_vars = tf.report_uninitialized_variables(global_vars)
  self._initialize_vars = tf.global_variables_initializer()
  self._save_path = os.path.join(train_dir, 'ckpt')

  if not train_params:
    # Training params come from the model, which is then mandatory.
    assert model
    self._train_params = model.params.train
    self._model = model
  else:
    self._train_params = train_params
    self._model = None

  if not save_only:
    # Restoring needs the model's params and tasks; this also (re)binds
    # self._model to `model` even when explicit train_params were given.
    self._params = model.params
    self._model_tasks = model.tasks
    self._model = model

  self._next_checkpoint_seconds = 0
  self._save_interval_seconds = self._train_params.save_interval_seconds
  self._saver = self._GetSaver()
def __init__(self, params):
  """Constructs the task: input generator, global step and learners.

  Args:
    params: Task params; `params.cls` must be a subclass of BaseTask.
  """
  assert issubclass(params.cls, BaseTask)
  # Ensure global_step exists before calling super.
  py_utils.GetOrCreateGlobalStepVar()
  super(BaseTask, self).__init__(params)

  p = self.params

  if p.input:
    # TODO(zhifengc): Consider a simpler way to ensure the input
    # generator stops after one epoch.
    if p.is_eval and p.eval:
      seq_inp = issubclass(p.input.cls,
                           base_input_generator.BaseInputGeneratorFromFiles)
      if p.input.num_samples == 0:
        # Dataset size is unknown. Computes eval summary based on num_samples.
        assert p.eval.samples_per_summary > 0
      elif (p.eval.samples_per_summary == 0) or (p.input.num_samples <
                                                 p.eval.samples_per_summary):
        # If we know the dataset size and we want to evaluate the full
        # set, we need to coordinate the input generator to flush out
        # all samples so the evaler and decoder compute metrics on the
        # whole set for each summary step.
        if seq_inp:
          p.input.flush_every_n = p.input.num_samples
        p.eval.samples_per_summary = p.input.num_samples
      if seq_inp and p.input.num_batcher_threads > 1:
        tf.logging.warning('input.num_batcher_threads > 1 inside eval mode. '
                           'The input generator may not iterate over exactly '
                           'one epoch per run')
    tf.logging.info('input_params: %s', p.input)
    input_params = self.cluster.PlaceInput(p.input)
    # The input child is created outside any rewrite context (e.g. TPU).
    with py_utils.outside_all_rewrites():
      self.CreateChild('input', input_params)

  self._encoder = None
  self._online_encoder = None
  self._decoder = None

  self._loss = None
  self._num_predictions = None
  self._train_op = None
  self._eval_metrics = {}
  self._per_example = {}
  self._trainer_verbose_tensors = {}

  # Create the gradient mask,
  self._per_input_gradient_mask = None

  # A task-specific global step is looked up in the 'TASK_GLOBAL_STEP'
  # collection, using '^<task name>_global_step' as the scope filter. If none
  # is found, the shared global step is used.
  task_global_step_list = tf.get_collection('TASK_GLOBAL_STEP',
                                            '^%s_global_step' % p.name)
  if len(task_global_step_list) > 1:
    raise ValueError('Found multiple task_global_step for task %s' % p.name)
  self._global_step_var = (
      task_global_step_list[0] if len(task_global_step_list) == 1 else
      py_utils.GetOrCreateGlobalStepVar())
  self._global_step = tf.identity(
      self._global_step_var, name='global_step_tensor')

  tp = p.train
  # p.train can be None if this task is the teacher/student task in a
  # DistillationTask.
  # Learners are only created on the training-type jobs listed below.
  if tp and self.cluster.job in ('worker', 'trainer', 'trainer_client',
                                 'controller', 'executor_tpu'):
    self._SetLearnerFromLegacyParams(tp)
    if tp.learner is not None:
      if isinstance(tp.learner, (list, tuple)):
        self.CreateChildren('learners', tp.learner)
      else:
        self.CreateChildren('learners', [tp.learner])
  self._UpdateVnConfig()
def CreateTpuFeeds(self):
  """Creates the TPU infeed queue from preprocessed batch.

  On each infeed host (every TPU host when `p.use_per_host_infeed`, else just
  one), this method:
  - builds a preprocessed batch;
  - generates TPU-embedding enqueue ops for it, if a TPUEmbedding is
    registered;
  - creates an infeed queue and its enqueue ops.

  All enqueue ops are grouped into one op, which is stored as
  `self._tpu_infeed_op` and added `p.tpu_infeed_parallelism` times to
  py_utils.ENQUEUE_OPS.

  Returns:
    The last host's batch structure, re-packed with the tensors dequeued from
    the first queue on TPU core 0.
  """
  p = self.params
  cluster = self.cluster
  num_tpu_hosts = cluster.num_tpu_hosts
  num_cores_per_host = cluster.total_worker_devices // num_tpu_hosts
  tf.logging.info(
      'CreateTPUFeeds num_splits_per_client={} '
      'num_devices_per_split={} num_tpu_hosts={} use_per_host_infeed={}'
      .format(cluster.num_splits_per_client, cluster.num_devices_per_split,
              num_tpu_hosts, p.use_per_host_infeed))

  assert num_tpu_hosts > 0, ('num_tpu_hosts: %d' % num_tpu_hosts)
  if (cluster.num_devices_per_split > num_cores_per_host and
      p.use_per_host_infeed):
    tf.logging.fatal('Doesn\'t support per host infeed mode when '
                     'num_devices_per_split({}) > num_cores_per_host({})'
                     .format(cluster.num_devices_per_split,
                             num_cores_per_host))
  num_infeed_hosts = num_tpu_hosts if p.use_per_host_infeed else 1

  with py_utils.outside_all_rewrites():
    assert py_utils.use_tpu()
    assert not self._made_tpu_infeed

    # Number of shards fed by each infeed host.
    shards = tpu_function.get_tpu_context(
    ).number_of_shards // num_infeed_hosts
    tf.logging.info('shards {}'.format(shards))

    input_ops_list = []
    queues = []
    tpu_embedding_collection = tf.get_collection(py_utils.TPU_EMBEDDING)
    tpu_embedding = (
        tpu_embedding_collection[0] if tpu_embedding_collection else None)

    if num_tpu_hosts > 1 and tpu_embedding is not None:
      if not p.use_per_host_infeed:
        tf.logging.fatal(
            'TPU Embedding must be used with per_host_infeed with multiple '
            'TPU host topologies.')
    tpu_emb_input_keys = (
        list(tpu_embedding.feature_to_config_dict.keys())
        if tpu_embedding is not None else [])
    tf.logging.info('tpu_emb_input_keys: %r', tpu_emb_input_keys)

    batch = None
    for task_id in range(num_infeed_hosts):
      host_device = '/task:{}/device:CPU:0'.format(task_id)
      with tf.device(host_device):
        batch = self.GetPreprocessedInputBatch()
        if isinstance(batch, py_utils.NestedMap):
          # Hack: bucket_keys and xxx.bucket_keys are not needed on TPU.
          # Note that when MultiTaskData is used, bucket_keys will be at the
          # second level of the dictionary.
          batch = batch.FilterKeyVal(
              lambda k, _: not k.endswith('bucket_keys'))
        tf.logging.info('host_device: %s, batch: %r', host_device, batch)

        if tpu_embedding is not None:
          enqueue_dict_per_core = [
              {} for _ in range(tpu_embedding.num_cores_per_host)
          ]
          # Note: this shadows the num_cores_per_host computed above.
          num_cores_per_host = tpu_embedding.num_cores_per_host
          for key in tpu_emb_input_keys:
            feat = batch[key]
            tpu_emb_feat_splitted = tf.split(feat, num_cores_per_host)
            for core, split in enumerate(tpu_emb_feat_splitted):
              # Dense to sparse. Note the assumption of a padding id.
              # Entries equal to -1 are dropped.
              sample_indices = tf.where(tf.not_equal(split, -1))
              embedding_indices = tf.gather_nd(split, sample_indices)
              enqueue_data = tpu_embedding_lib.EnqueueData(
                  embedding_indices, sample_indices)
              enqueue_dict_per_core[core][key] = enqueue_data
          input_ops_list += tpu_embedding.generate_enqueue_ops(
              enqueue_dict_per_core)

        for k, x in batch.FlattenItems():
          assert x.shape.is_fully_defined(), (
              'Shape must be fully defined: %s: %s' % (k, x))
          # TODO(cwhipkey): if it's a string (or other type not supported on
          # TPU), drop it from feeding and on the other end add in an op that
          # fails if used.
        shapes = batch.Transform(lambda x: x.shape).Flatten()
        dtypes = batch.Transform(lambda x: x.dtype).Flatten()
        tf.logging.info('host_device: %s infeed shapes: %r', host_device,
                        shapes)
        tf.logging.info('host_device: %s infeed dtypes: %r', host_device,
                        dtypes)
        if p.use_partitioned_infeed_queue:
          device_assignment = py_utils.GetTpuDeviceAssignment()

          # Note: host_device is rebound here to the device assignment's host
          # for replica 0, and the numeric task id is parsed out of it.
          host_device = device_assignment.host_device(
              replica=0, job=tf.flags.FLAGS.tf_master)
          host_id = int(host_device.split('/task:')[1].split('/device:')[0])
          tf.logging.info('host_id: {} host_device: {}'.format(
              host_id, host_device))
          q = tpu_feed._PartitionedInfeedQueue(  # pylint: disable=protected-access
              number_of_tuple_elements=len(dtypes),
              device_assignment=device_assignment,
              host_id=host_id,
              input_partition_dims=[[p.num_partitions, 1] for _ in dtypes],
              tuple_types=dtypes,
              tuple_shapes=shapes)
        else:
          q = tpu_feed.InfeedQueue(tuple_types=dtypes, tuple_shapes=shapes)
          assert shards is not None
          q.set_number_of_shards(shards)

        queues.append(q)
        tf.logging.info('q=%r', q)

        if p.use_partitioned_infeed_queue:
          input_ops = q.generate_enqueue_ops([batch.Flatten()])
        elif p.use_per_host_infeed:
          # TODO(ylc/zhifengc): Add this to a policy module and test it.
          def TPUOrdinalFunction(shard_index_in_host):
            """Maps a shard index on this host to a TPU ordinal."""
            device_assignment = py_utils.GetTpuDeviceAssignment()
            if device_assignment:
              # We put both enqueue/dequeue ops at core 0 in each replica.
              replica = device_assignment.lookup_replicas(
                  task_id, 0)[shard_index_in_host]  # pylint: disable=cell-var-from-loop
              return device_assignment.tpu_ordinal(replica=replica)
            else:
              return shard_index_in_host

          input_ops = q.split_inputs_and_generate_enqueue_ops(
              batch.Flatten(),
              placement_function=lambda x: host_device,  # pylint: disable=cell-var-from-loop
              tpu_ordinal_function=TPUOrdinalFunction)
        else:
          input_ops = q.split_inputs_and_generate_enqueue_ops(
              batch.Flatten(),
              device_assignment=py_utils.GetTpuDeviceAssignment())

        input_ops_list += input_ops
    tf.logging.info('input_ops_list %s', input_ops_list)
    tpu_infeed_op = tf.group(*input_ops_list)
  self._made_tpu_infeed = True
  # Let trainer.py use multiple threads to drive the infeed op.
  for _ in range(p.tpu_infeed_parallelism):
    tf.add_to_collection(py_utils.ENQUEUE_OPS, tpu_infeed_op)

  self._tpu_infeed_op = tpu_infeed_op

  # Dequeue from the first queue on TPU core 0 and repack into the batch
  # structure of the last host's batch.
  with tf.device(tf.tpu.core(0)):
    tensors = queues[0].generate_dequeue_op()
  return batch.Pack(tensors)
def __init__(self, train_cfg, ps_params_dict, model_task_name, logdir,
             tf_master, **kwargs):
  """Construct an ExecutorTpu BaseRunner.

  Args:
    train_cfg: SingleTaskModelParams or MultiTaskModelParams
    ps_params_dict: A dict of top-level task name -> ProgramSchedule params,
      if train_cfg is a SingleTaskModelParams, we expect only one entry.
    model_task_name: An override for multi-task models, currently unused.
    logdir: String path to the log directory to output to.
    tf_master: String path to the master job, e.g. 'local'.
    **kwargs: keyword args to pass through to BaseRunner.
  """
  super().__init__(train_cfg, model_task_name, logdir, tf_master, **kwargs)

  data_parallelism = self._cluster.num_splits_per_client
  assert data_parallelism
  num_devices_per_split = self._cluster.num_devices_per_split
  tf.logging.info('data_parallelism: %d, num_devices_per_split: %d',
                  data_parallelism, num_devices_per_split)

  self.task_scheduler = None
  self._checkpoint_dir = os.path.join(logdir, 'train')

  self._variable_renaming_rules = []

  self._ml_perf = None

  # If this is a multi-task model, grab the params for the TaskScheduler.
  if issubclass(train_cfg.cls, base_model.SingleTaskModel):
    tf.logging.info('single_task_model')
    assert len(ps_params_dict) == 1
    self._model_task_name = list(ps_params_dict.keys())[0]
    self._single_task_mode = True
  elif issubclass(train_cfg.cls, base_model.MultiTaskModel):
    tf.logging.info('multi_task_model')

    if issubclass(train_cfg.cls, multitask_model.RegExSharedVariableModel):
      self._variable_renaming_rules = train_cfg.variable_renaming_rules

    if train_cfg.task_schedule is None:
      task_schedule_params = task_scheduler.ConstantScheduler.Params()
      task_schedule_params.task_probs = sorted(
          list(train_cfg.task_probs.IterParams()))
    else:
      task_schedule_params = train_cfg.task_schedule
    self.task_scheduler = task_schedule_params.Instantiate()
    self._single_task_mode = False
  else:
    tf.logging.fatal(
        'Model %s is not a sub-class of SingleTaskModel or MultiTaskModel',
        train_cfg.cls)

  tf.logging.info('train_cfg.cls: %s', train_cfg.cls)

  self._WriteToLog(train_cfg.ToText(), self._checkpoint_dir,
                   'trainer_params.txt')

  # Resolve the MLPerf settings before deciding whether to log. The
  # program-schedule loop below also records them, but it runs after
  # `_ml_perf_log` is computed and used for 'benchmark' and 'init_start'.
  # Without this scan, `_ml_perf_log` was always False. Last match wins, the
  # same as in that loop.
  for program_schedule_params in ps_params_dict.values():
    if program_schedule_params.ml_perf.benchmark_name is not None:
      self._ml_perf = program_schedule_params.ml_perf

  if self._ml_perf is not None:
    self._ml_perf_log = True
    mlp_log.mlperf_print(key='benchmark', value=self._ml_perf.benchmark_name)
  else:
    self._ml_perf_log = False

  # BaseRunner legacy
  self.enqueue_ops = None

  train_cfg = self.params

  @py_utils.RetryOnTransientTfError()
  def _WaitTillInit(job=None):
    """Wait until the model is ready."""
    try:
      # tpu.initialize_system() is called with None as embedding_config, as
      # embedding_config is not available yet. Later in _Loop, it is called
      # with the correct embedding_config. Since it cannot be called twice in
      # the same graph with different embedding_config, we use a dummy_graph
      # here.
      dummy_graph = tf.Graph()
      with dummy_graph.as_default():
        tpu_initialize_system_op = tf.tpu.initialize_system(
            embedding_config=None, job=job)

      with self._GetSession(graph=dummy_graph) as sess:
        topology = sess.run(tpu_initialize_system_op)

      if train_cfg.train.tpu_device_order_mode is None:
        device_assignment = device_assignment_lib.device_assignment(
            topology,
            computation_shape=py_utils.ComputationShape(
                num_devices_per_split, topology),
            num_replicas=data_parallelism)
      else:
        device_assignment = device_assignment_lib.device_assignment(
            topology,
            computation_shape=py_utils.ComputationShape(
                num_devices_per_split, topology),
            num_replicas=data_parallelism,
            device_order_mode=train_cfg.train.tpu_device_order_mode)
      py_utils.SetTpuDeviceAssignment(device_assignment, job)
      tf.logging.info('device_assignment.core_assignment: %s',
                      str(device_assignment.core_assignment))
      tf.logging.info('device_assignment.topology.device_coordinates: %s',
                      str(device_assignment.topology.device_coordinates))
    except py_utils.transient_tf_errors as e:
      tf.logging.info('TPU initialization failed: %s', e)
      raise

  if self._ml_perf_log:
    mlp_log.mlperf_print(key='init_start', value=None)
  if len(self._cluster.all_worker_names) > 1:
    for worker in self._cluster.all_worker_names:
      _WaitTillInit(worker)
  else:
    _WaitTillInit(None)

  shared_model = self._MaybeConstructSharedModel(train_cfg)

  self._program_schedule_dict = {}
  self._programs = []

  for task_string, program_schedule_params in ps_params_dict.items():
    program_schedule_params.logdir = logdir
    program_schedule_params.num_splits_per_client = data_parallelism
    program_schedule_params.task_name = task_string
    # If the model was created above, we'll inject it here as a shared_model.
    ps = program_schedule_params.Instantiate(
        shared_model=shared_model, tf_master=self._tf_master)
    self._program_schedule_dict[task_string] = ps
    tf.logging.info('program_schedule_params: %s',
                    program_schedule_params.ToText())
    self._programs += ps.Programs()
    if program_schedule_params.ml_perf.benchmark_name is not None:
      self._ml_perf = program_schedule_params.ml_perf

  tf.logging.info('num_programs: %d', len(self._programs))

  with self._graph.as_default(), tf.container(self._container_id):
    with self._cluster, tf.device(self._cluster.GetPlacer()):
      with py_utils.VariableRenameScope(self._variable_renaming_rules):
        _ = py_utils.GetOrCreateGlobalStepVar()
        for program in self._programs:
          program.BuildTpuSubgraph()
          py_utils.ClearTpuSummaryTensors()

      self._initialize_tables = tf.tables_initializer()
      self._initialize_local_vars = tf.local_variables_initializer()
      self._initialize_global_vars = tf.global_variables_initializer()

      for program in self._programs:
        program.SetStatusMessageFn(self._SetStatusMessage)
        program.CreateCheckpointer(init_op=self._initialize_global_vars)

      self.save_only_checkpointer = checkpointer.Checkpointer(
          self._checkpoint_dir,
          model=None,
          init_op=self._initialize_global_vars,
          train_params=train_cfg.train,
          save_only=True)

      self._load_ops = tf.get_collection(py_utils.TPU_EMBEDDING_LOAD_OPS)
      self._retrieve_ops = tf.get_collection(
          py_utils.TPU_EMBEDDING_RETRIEVE_OPS)
      tpu_embedding_collection = tf.get_collection(py_utils.TPU_EMBEDDING)
      self._tpu_embedding = (
          tpu_embedding_collection[0] if tpu_embedding_collection else None)
      tf.io.write_graph(self._graph.as_graph_def(), self._checkpoint_dir,
                        'train.pbtxt')
def get_masked_weights():
  """Returns the items stored in the `_MASKED_WEIGHT_COLLECTION` collection."""
  collection_items = tf.get_collection(_MASKED_WEIGHT_COLLECTION)
  return collection_items
def get_masks():
  """Returns the items stored in the `_MASK_COLLECTION` graph collection."""
  collection_items = tf.get_collection(_MASK_COLLECTION)
  return collection_items
def _BPropGenTrainOps(self, vmap, metrics=None, add_summary=True):
  """Populates the train_ops dictionary in a backwards pass.

  Ops are built in chained control-dependency stages:
    1. One train op per learner (from `optimization.Apply`) plus the batch-norm
       updates relevant to the learners' losses, grouped as 'var_update_ops'.
    2. `self.PostTrainingStepUpdate()`, gated on stage 1.
    3. The mask update op, gated on stage 2, and then the global-step update(s),
       gated on the mask update op.
  If a TPU embedding object is registered in `py_utils.TPU_EMBEDDING`, a
  send-gradients op is added as well.

  Args:
    vmap: Variables to update; passed through to each learner's `Apply`.
    metrics: Metrics passed to each learner's `Apply`. Falls back to
      `self._metrics` when falsy (None or empty).
    add_summary: If True, each learner's eval metrics are registered via
      `AddEvalMetric` with a '/<learner name>' suffix, and TPU embedding
      summary tensors are registered too.

  Returns:
    A dict with 'mask_updates', 'global_step' and, when TPU embeddings are in
    use, 'tpu_embedding'. The stage-1 ops are NOT in the returned dict; they are
    only reachable through the control dependencies above.
  """
  metrics = metrics or self._metrics
  bprop_variable_filters = self.input_generator.GetBpropVariableFilters()
  # Only compute the mask if the variable filters are not empty.
  if bprop_variable_filters != [''] * len(bprop_variable_filters):
    self._ComputeGradientMask(bprop_variable_filters)
  train_ops = {}  # mapping from op name to op.
  gradient_mask = None
  if self._per_input_gradient_mask:
    # TODO(neerajgaur): Change this to use source_selected from input_batch.
    onehot = self.input_generator.GetInputSourceOneHot()
    gradient_mask = {
        k: tf.tensordot(v, onehot, 1)
        for k, v in self._per_input_gradient_mask.items()
    }
  all_losses = []
  for optimization in self.learners:
    learner_name = optimization.params.name
    (losses, train_ops['train/%s' % learner_name],
     eval_metrics) = optimization.Apply(
         metrics,
         vmap,
         gradient_mask=gradient_mask,
         gradient_adjuster=self.AdjustGradients)
    all_losses.extend(losses)
    if add_summary:
      for key, (value, weight) in eval_metrics.items():
        self.AddEvalMetric(key + '/' + learner_name, value, weight)

  relevant_bn_updates, _ = py_utils.FindRelevantBatchNormUpdates(
      all_losses, tf.get_collection(py_utils.BATCH_NORM_UPDATES))
  train_ops['bn_updates'] = relevant_bn_updates

  var_update_ops = [
      tf.group(*tf.nest.flatten(train_ops), name='var_update_ops')
  ]
  # Post training step update.
  with tf.control_dependencies(var_update_ops):
    post_step_op = self.PostTrainingStepUpdate()

  # Start a fresh dict: the variable-update ops are reached only through the
  # control dependencies of the ops created below.
  train_ops = {}
  with tf.control_dependencies([post_step_op]):
    # Get the op to update the weight masks and thresholds
    mask_update_op = self._GetMaskUpdateOp()
    train_ops['mask_updates'] = mask_update_op
    with tf.control_dependencies([mask_update_op]):
      true_global_step = py_utils.GetOrCreateGlobalStepVar()
      with tf.ops.colocate_with(true_global_step):
        if self.params.defer_global_step_update:
          # No increment op is created here; the step variable itself is
          # returned under 'global_step'.
          increment_global_steps = true_global_step
        else:
          increment_global_steps = tf.assign_add(true_global_step, 1)
      # Explicit identity check. In graph mode `!=` on variables already falls
      # back to identity, but under TF2 eager equality it compares VALUES, which
      # would skip incrementing a distinct variable that happens to hold the
      # same step count.
      if self._global_step_var is not true_global_step:
        with tf.ops.colocate_with(self._global_step_var):
          increment_global_steps = tf.group(
              increment_global_steps, tf.assign_add(self._global_step_var, 1))
      train_ops['global_step'] = increment_global_steps

  # If we are using Tpu Embeddings, generate the monolithic send
  # gradient op.
  tpu_embedding_collection = tf.get_collection(py_utils.TPU_EMBEDDING)
  if tpu_embedding_collection:
    tpu_embedding = tpu_embedding_collection[0]
    sparse_grads = (
        tpu_embedding_gradient.get_gradients_through_dummy_table_variables(
            tpu_embedding))
    tpu_embedding_send_gradient_op = tpu_embedding.generate_send_gradients_op(
        sparse_grads, py_utils.GetGlobalStep())
    train_ops['tpu_embedding'] = tpu_embedding_send_gradient_op

    tpu_embedding_summary_tensors = tf.get_collection(
        py_utils.TPU_EMBEDDING_SUMMARY_TENSORS)
    if add_summary:
      for name, value, weight in tpu_embedding_summary_tensors:
        self.AddEvalMetric(name, value, weight, raise_if_already_added=False)

  for op_name, op in train_ops.items():
    assert op is not None, op_name
  return train_ops
def CreateTpuEnqueueOps(self):
  """Create the host-side enqueue ops.

  This should be called in an outer non-TPU context.

  For each infeed host (every TPU host when `p.use_per_host_infeed`, otherwise
  only task 0), this builds the preprocessed input batch on that host's CPU,
  creates an infeed queue for it (partitioned when
  `p.use_partitioned_infeed_queue`) and generates the enqueue ops. The queues
  are stored in `self._tpu_queues`. The grouped enqueue op is stored
  `p.tpu_infeed_parallelism` times in `self._tpu_infeed_op`.

  Raises:
    AssertionError: if called more than once, if the cluster reports no TPU
      hosts, or if any input tensor's shape is not fully defined.
  """
  assert not self._tpu_queues, ('CreateTpuEnqueueOps should only be called '
                                'once.')
  self._tpu_queues = []
  p = self.params
  cluster = self.cluster
  num_tpu_hosts = cluster.num_tpu_hosts
  # Validate before dividing by num_tpu_hosts. Otherwise a zero host count
  # surfaces as a ZeroDivisionError instead of this message.
  assert num_tpu_hosts > 0, ('num_tpu_hosts: %d' % num_tpu_hosts)
  num_cores_per_host = cluster.total_worker_devices // num_tpu_hosts
  tf.logging.info(
      'CreateTpuEnqueueOps num_splits_per_client={} '
      'num_devices_per_split={} num_tpu_hosts={} use_per_host_infeed={}'.format(
          cluster.num_splits_per_client, cluster.num_devices_per_split,
          num_tpu_hosts, p.use_per_host_infeed))

  if (cluster.num_devices_per_split > num_cores_per_host and
      p.use_per_host_infeed):
    # NOTE(review): if `tf.logging` is TF's stdlib-logging-backed logger,
    # `fatal` only logs at FATAL level and does not stop execution. Confirm
    # whether this should raise instead.
    tf.logging.fatal(
        'Doesn\'t support per host infeed mode when '
        'num_devices_per_split({}) > num_cores_per_host({})'.format(
            cluster.num_devices_per_split, num_cores_per_host))
  num_infeed_hosts = num_tpu_hosts if p.use_per_host_infeed else 1

  shards = (cluster.total_worker_devices //
            num_infeed_hosts) // cluster.num_devices_per_split
  tf.logging.info('shards {}'.format(shards))

  input_ops_list = []
  tpu_embedding_collection = tf.get_collection(py_utils.TPU_EMBEDDING)
  tpu_embedding = (
      tpu_embedding_collection[0] if tpu_embedding_collection else None)

  if num_tpu_hosts > 1 and tpu_embedding is not None:
    if not p.use_per_host_infeed:
      tf.logging.fatal(
          'TPU Embedding must be used with per_host_infeed with multiple '
          'TPU host topologies.')
  tpu_emb_input_keys = (
      list(tpu_embedding.feature_to_config_dict.keys())
      if tpu_embedding is not None else [])
  tf.logging.info('tpu_emb_input_keys: %r', tpu_emb_input_keys)
  tf.logging.info('num_infeed_hosts: %d', num_infeed_hosts)

  for task_id in range(num_infeed_hosts):
    host_device = '/task:{}/device:CPU:0'.format(task_id)
    with tf.device(host_device):
      self._batch = self.GetPreprocessedInputBatch()
      if isinstance(self._batch, py_utils.NestedMap):
        # Hack: bucket_keys and xxx.bucket_keys are not needed on TPU.
        # Note that when MultiTaskData is used, bucket_keys will be at the
        # second level of the dictionary.
        self._batch = self._batch.FilterKeyVal(
            lambda k, _: not k.endswith('bucket_keys'))
      tf.logging.info('host_device: %s, batch: %r', host_device, self._batch)

      for k, x in self._batch.FlattenItems():
        assert x.shape.is_fully_defined(), (
            'Shape must be fully defined: %s: %s' % (k, x))
        # TODO(cwhipkey): if it's a string (or other type not supported on
        # TPU), drop it from feeding and on the other end add in an op that
        # fails if used.
      shapes = self._batch.Transform(lambda x: x.shape).Flatten()
      dtypes = self._batch.Transform(lambda x: x.dtype).Flatten()

      tf.logging.info('host_device: %s infeed shapes: %r', host_device, shapes)
      tf.logging.info('host_device: %s infeed dtypes: %r', host_device, dtypes)

      if p.use_partitioned_infeed_queue:
        device_assignment = py_utils.GetTpuDeviceAssignment()

        host_device = device_assignment.host_device(
            replica=0, job=tf.flags.FLAGS.tf_master)
        host_id = int(host_device.split('/task:')[1].split('/device:')[0])
        tf.logging.info('host_id: {} host_device: {}'.format(
            host_id, host_device))
        q = tpu_feed._PartitionedInfeedQueue(  # pylint: disable=protected-access
            number_of_tuple_elements=len(dtypes),
            device_assignment=device_assignment,
            host_id=host_id,
            # Partition only the leading dimension of each input, into
            # p.num_partitions pieces.
            input_partition_dims=[
                [p.num_partitions] + [1] * (len(s) - 1) for s in shapes
            ],
            tuple_types=dtypes,
            tuple_shapes=shapes)
      else:
        q = tpu_feed.InfeedQueue(tuple_types=dtypes, tuple_shapes=shapes)
        assert shards is not None
        q.set_number_of_shards(shards)

      self._tpu_queues.append(q)

      if p.use_partitioned_infeed_queue:
        input_ops = q.generate_enqueue_ops([self._batch.Flatten()])
      elif p.use_per_host_infeed:
        # TODO(ylc/zhifengc): Add this to a policy module and test it.
        def TPUOrdinalFunction(shard_index_in_host):
          device_assignment = py_utils.GetTpuDeviceAssignment()
          if device_assignment:
            # We put both enqueue/dequeue ops at core 0 in each replica.
            replica = device_assignment.lookup_replicas(
                task_id, 0)[shard_index_in_host]  # pylint: disable=cell-var-from-loop
            return device_assignment.tpu_ordinal(replica=replica)
          else:
            return shard_index_in_host

        # The closures above capture loop variables. That is safe here
        # because the enqueue ops are generated within this same iteration.
        input_ops = q.split_inputs_and_generate_enqueue_ops(
            self._batch.Flatten(),
            placement_function=lambda x: host_device,  # pylint: disable=cell-var-from-loop
            tpu_ordinal_function=TPUOrdinalFunction)
      else:
        input_ops = q.split_inputs_and_generate_enqueue_ops(
            self._batch.Flatten(),
            device_assignment=py_utils.GetTpuDeviceAssignment())

      input_ops_list += input_ops

  tf.logging.info('input_ops_list %s', input_ops_list)
  grouped_infeed_op = tf.group(*input_ops_list)
  self._tpu_infeed_op = []
  for _ in range(p.tpu_infeed_parallelism):
    self._tpu_infeed_op.append(grouped_infeed_op)
def _CreateLayerVariables(self):
  """Creates the layer's variables and attaches the shared TPUEmbedding API.

  The base-class variables are always created. When `py_utils.use_tpu()` is
  true, this also:
    * records in `self._sequence_features` every input key whose table has
      `max_sequence_length > 0`;
    * ensures exactly one TPUEmbedding API object exists in the
      `py_utils.TPU_EMBEDDING` collection. The first caller builds it (along
      with dummy table variables and load/retrieve op lists); later callers
      reuse it;
    * exposes the dummy table variables through `self._private_vars` and
      `self._private_theta`.
  """
  super()._CreateLayerVariables()
  p = self.params

  def _BuildTpuEmbeddingApi():
    # Builds the TPUEmbedding object and publishes it, its dummy table
    # variables and the load/retrieve op lists into graph collections.
    load_ops = []
    retrieve_ops = []

    num_cores = self.cluster.params.worker.tpus_per_replica
    global_batch_size = (
        self.params.batch_size * self.cluster.num_splits_per_client)
    table_configs = {}
    feature_configs = {}
    for tbl in self.tables:
      table_configs[tbl.table_name] = tbl.table_config
      load_ops += tbl.load_op_list
      retrieve_ops += tbl.retrieve_op_list
      for feature_name in tbl.input_keys:
        feature_configs[feature_name] = tpu_embedding_lib.FeatureConfig(
            tbl.table_name, max_sequence_length=tbl.max_sequence_length)

    training_mode = tpu_embedding_lib.TRAINING
    device_config = tpu_embedding_lib.DeviceConfig(
        num_cores=num_cores,
        num_hosts=self.params.tables[0].num_tpu_hosts,
        job_name=self.cluster.params.worker.name)
    tpu_embedding = tpu_embedding_lib.TPUEmbedding(
        table_configs,
        feature_configs,
        global_batch_size,
        training_mode,
        master=None,
        pipeline_execution_with_tensor_core=(
            self.params.pipeline_execution_with_tensor_core),
        partition_strategy=p.partition_strategy,
        device_config=device_config)

    with tf.init_scope():
      dummy_vars, dummy_vars_init = (
          tpu_embedding_gradient.create_dummy_table_variables(tpu_embedding))
    load_ops.append(dummy_vars_init)

    registrations = (
        (py_utils.TPU_EMBEDDING, tpu_embedding),
        (py_utils.TPU_EMBEDDING_DUMMY_VARS, dummy_vars),
        (py_utils.TPU_EMBEDDING_LOAD_OPS, load_ops),
        (py_utils.TPU_EMBEDDING_RETRIEVE_OPS, retrieve_ops),
    )
    for collection_key, value in registrations:
      tf.add_to_collection(collection_key, value)

  if not py_utils.use_tpu():
    return

  # At the feature level, track which are associated
  # with "sequence embeddings".
  self._sequence_features = {
      feature_name: True
      for tbl in self.tables
      for feature_name in tbl.input_keys
      if tbl.max_sequence_length > 0
  }

  # Multiple TPUEmbeddingLayers can be created but we must have a singleton
  # TPUEmbedding API object.
  existing_embeddings = tf.get_collection(py_utils.TPU_EMBEDDING)
  assert len(existing_embeddings) <= 1
  if existing_embeddings:
    tf.logging.info('TPUEmbedding API singleton already exists, reusing')
  else:
    _BuildTpuEmbeddingApi()

  self._tpu_embedding = tf.get_collection(py_utils.TPU_EMBEDDING)[0]
  self._dummy_variables = tf.get_collection(
      py_utils.TPU_EMBEDDING_DUMMY_VARS)[0]
  for var_name, var in self._dummy_variables.items():
    self._private_vars[var_name] = var
    self._private_theta[var_name] = var
def CreateTpuFeeds(self):
  """Creates the TPU infeed queue from preprocessed batch.

  Builds infeed queues and enqueue ops for every infeed host (every TPU host
  when `p.use_per_host_infeed`, otherwise only task 0). When a TPU embedding
  object is registered, it also builds that object's enqueue ops. The grouped
  enqueue op is added to `py_utils.ENQUEUE_OPS` `p.tpu_infeed_parallelism`
  times and stored in `self.tpu_infeed_op`. Finally, a dequeue op is built on
  TPU core 0 from the first queue.

  Returns:
    The last host's batch structure packed with the dequeued tensors.

  Raises:
    AssertionError: if the cluster reports no TPU hosts, if not running on
      TPU, if the infeed was already created, or if any input tensor's shape
      is not fully defined.
  """
  p = self.params
  cluster = self.cluster
  num_tpu_hosts = cluster.num_tpu_hosts
  # Validate before dividing by num_tpu_hosts. Otherwise a zero host count
  # surfaces as a ZeroDivisionError instead of this message.
  assert num_tpu_hosts > 0, ('num_tpu_hosts: %d' % num_tpu_hosts)
  num_cores_per_host = cluster.total_worker_devices // num_tpu_hosts
  tf.logging.info('num_cores_per_host {}'.format(num_cores_per_host))
  tf.logging.info('num_devices_per_split {}'.format(
      cluster.num_devices_per_split))

  if (cluster.num_devices_per_split > num_cores_per_host and
      p.use_per_host_infeed):
    # NOTE(review): if `tf.logging` is TF's stdlib-logging-backed logger,
    # `fatal` only logs at FATAL level and does not stop execution. Confirm
    # whether this should raise instead.
    tf.logging.fatal(
        'Doesn\'t support per host infeed mode when '
        'num_devices_per_split({}) > num_cores_per_host({})'.format(
            cluster.num_devices_per_split, num_cores_per_host))
  num_infeed_hosts = num_tpu_hosts if p.use_per_host_infeed else 1

  with py_utils.outside_all_rewrites():
    assert py_utils.use_tpu()
    assert not self._made_tpu_infeed

    shards = tpu_function.get_tpu_context(
    ).number_of_shards // num_infeed_hosts
    input_ops_list = []
    queues = []
    tpu_embedding_collection = tf.get_collection(py_utils.TPU_EMBEDDING)
    tpu_embedding = (
        tpu_embedding_collection[0] if tpu_embedding_collection else None)

    tpu_emb_input_keys = (
        list(tpu_embedding.feature_to_config_dict.keys())
        if tpu_embedding is not None else [])
    tf.logging.info('tpu_emb_input_keys: %r', tpu_emb_input_keys)

    batch = None
    for task_id in range(num_infeed_hosts):
      host_device = '/task:{}/device:CPU:0'.format(task_id)
      with tf.device(host_device):
        batch = self.GetPreprocessedInputBatch()
        if 'bucket_keys' in batch:
          # Hack: bucket_keys are not needed on TPU.
          del batch['bucket_keys']
        tf.logging.info('host_device: %s, batch: %r', host_device, batch)

        if tpu_embedding is not None:
          enqueue_dict_per_core = [
              {} for _ in range(tpu_embedding.num_cores_per_host)
          ]
          num_cores_per_host = tpu_embedding.num_cores_per_host
          for key in tpu_emb_input_keys:
            feat = batch[key]
            tpu_emb_feat_splitted = tf.split(feat, num_cores_per_host)
            for core, split in enumerate(tpu_emb_feat_splitted):
              # Dense to sparse. Note the assumption of a padding id.
              sample_indices = tf.where(tf.not_equal(split, -1))
              embedding_indices = tf.gather_nd(split, sample_indices)
              enqueue_data = tpu_embedding_lib.EnqueueData(
                  embedding_indices, sample_indices)
              enqueue_dict_per_core[core][key] = enqueue_data
          input_ops_list += tpu_embedding.generate_enqueue_ops(
              enqueue_dict_per_core)

        for k, x in batch.FlattenItems():
          assert x.shape.is_fully_defined(), (
              'Shape must be fully defined: %s: %s' % (k, x))
          # TODO(cwhipkey): if it's a string (or other type not supported on
          # TPU), drop it from feeding and on the other end add in an op that
          # fails if used.
        shapes = batch.Transform(lambda x: x.shape).Flatten()
        dtypes = batch.Transform(lambda x: x.dtype).Flatten()
        tf.logging.info('host_device: %s infeed shapes: %r', host_device,
                        shapes)
        tf.logging.info('host_device: %s infeed dtypes: %r', host_device,
                        dtypes)
        q = tpu_feed.InfeedQueue(tuple_types=dtypes, tuple_shapes=shapes)
        queues.append(q)
        assert shards is not None
        q.set_number_of_shards(shards)

        if p.use_per_host_infeed:

          # TODO(ylc/zhifengc): Add this to a policy module and test it.
          def TPUOrdinalFunction(shard_index_in_host):
            device_assignment = py_utils.GetTpuDeviceAssignment()
            if device_assignment:
              # We put both enqueue/dequeue ops at core 0 in each replica.
              replica = device_assignment.lookup_replicas(
                  task_id, 0)[shard_index_in_host]  # pylint: disable=cell-var-from-loop
              return device_assignment.tpu_ordinal(replica=replica)
            else:
              return shard_index_in_host

          # The closures above capture loop variables. That is safe here
          # because the enqueue ops are generated within this same iteration.
          input_ops = q.split_inputs_and_generate_enqueue_ops(
              batch.Flatten(),
              placement_function=lambda x: host_device,  # pylint: disable=cell-var-from-loop
              tpu_ordinal_function=TPUOrdinalFunction)
        else:
          input_ops = q.split_inputs_and_generate_enqueue_ops(
              batch.Flatten(),
              device_assignment=py_utils.GetTpuDeviceAssignment())

        input_ops_list += input_ops
    tf.logging.info('input_ops_list %s', input_ops_list)
    tpu_infeed_op = tf.group(*input_ops_list)
  self._made_tpu_infeed = True
  # Let trainer.py use multiple threads to drive the infeed op.
  for _ in range(p.tpu_infeed_parallelism):
    tf.add_to_collection(py_utils.ENQUEUE_OPS, tpu_infeed_op)

  # For executor-driven multiple programs, we need more fine-grained
  # access rather than using a single global graph collection.
  self.tpu_infeed_op = tpu_infeed_op

  with tf.device(tf.tpu.core(0)):
    tensors = queues[0].generate_dequeue_op()
  return batch.Pack(tensors)
def _BPropForVariables(self, vmap):
  """Constructs the backward graph.

  Builds one train op per learner, the batch-norm updates relevant to the
  learners' losses, the mask update op and the post-training-step op. Gated on
  all of them, it builds the global-step increment(s). When a TPU embedding
  activations dict is registered, a TPU embedding send-gradients op is added.
  All ops are grouped into `self._train_op` (named 'bprop').

  Args:
    vmap: Variables to update; passed through to each learner's `Apply`.

  Raises:
    ValueError: if a learner's name has no matching entry in `self._metrics`.
  """
  bprop_variable_filters = self.input_generator.GetBpropVariableFilters()
  # Only compute the mask if the variable filters are not empty.
  if bprop_variable_filters != [''] * len(bprop_variable_filters):
    self._ComputeGradientMask(bprop_variable_filters)
  train_ops = {}  # mapping from op name to op.
  gradient_mask = None
  if self._per_input_gradient_mask:
    # TODO(neerajgaur): Change this to use source_selected from input_batch.
    onehot = self.input_generator.GetInputSourceOneHot()
    gradient_mask = {
        k: tf.tensordot(v, onehot, 1)
        for k, v in six.iteritems(self._per_input_gradient_mask)
    }
  all_losses = []
  for optimization in self.learners:
    loss_name = optimization.params.name
    metric = self._metrics.get(loss_name, None)
    if metric is None:
      raise ValueError('Loss %s not found in metrics %s' %
                       (loss_name, list(self._metrics.keys())))
    loss = metric[0]
    all_losses.append(loss)
    train_ops['train/%s' % loss_name], eval_metrics = optimization.Apply(
        loss,
        vmap,
        gradient_mask=gradient_mask,
        gradient_adjuster=self.AdjustGradients)
    for key, (value, weight) in six.iteritems(eval_metrics):
      self.AddEvalMetric(key + '/' + loss_name, value, weight)

  relevant_bn_updates, _ = py_utils.FindRelevantBatchNormUpdates(
      all_losses, tf.get_collection(py_utils.BATCH_NORM_UPDATES))
  train_ops['bn_updates'] = relevant_bn_updates

  # Get the op to update the weight masks and thresholds
  train_ops['mask_updates'] = self._GetMaskUpdateOp()

  # Post training step update.
  train_ops['post_step'] = self.PostTrainingStepUpdate(self.global_step)

  with tf.control_dependencies(tf.nest.flatten(train_ops)):
    true_global_step = py_utils.GetOrCreateGlobalStepVar()
    with tf.colocate_with(true_global_step):
      increment_global_steps = tf.assign_add(true_global_step, 1)
    # Explicit identity check. In graph mode `!=` on variables already falls
    # back to identity, but under TF2 eager equality it compares VALUES, which
    # would skip incrementing a distinct variable that happens to hold the
    # same step count.
    if self._global_step_var is not true_global_step:
      with tf.colocate_with(self._global_step_var):
        increment_global_steps = tf.group(
            increment_global_steps, tf.assign_add(self._global_step_var, 1))

  train_ops['global_step'] = increment_global_steps

  # If we are using Tpu Embeddings, generate the monolithic send
  # gradient op.
  tpu_embedding_activations = tf.get_collection(
      py_utils.TPU_EMBEDDING_ACTIVATIONS)
  if tpu_embedding_activations:
    tpu_embedding_activations_dict = tpu_embedding_activations[0]
    tpu_embedding = tf.get_collection(py_utils.TPU_EMBEDDING)[0]
    # Gradients are computed from `self.loss`, not the per-learner losses.
    tpu_embedding_send_gradient_op = py_utils.ComputeTpuEmbeddingGradients(
        self.loss, tpu_embedding_activations_dict, tpu_embedding)
    train_ops['tpu_embedding'] = tpu_embedding_send_gradient_op

  for op_name, op in six.iteritems(train_ops):
    assert op is not None, op_name

  # TODO(rpang): try to structure _train_op as:
  #   tf.cond(skip_step, <only update skip stats>, <all updates>)
  # so that we skip all other updates when a step is skipped.
  self._train_op = tf.group(*tf.nest.flatten(train_ops), name='bprop')
def _BPropGenTrainOps(self, vmap, metrics=None, add_summary=True):
  """Populates the train_ops dictionary in a backwards pass.

  Ops are built in chained control-dependency stages:
    1. One train op per learner plus the batch-norm updates relevant to the
       learners' losses, grouped as 'var_update_ops'.
    2. `self.PostTrainingStepUpdate(self.global_step)`, gated on stage 1.
    3. The mask update op, gated on stage 2, and then the global-step
       increment(s), gated on the mask update op.
  When a TPU embedding activations dict is registered, a send-gradients op is
  added as well.

  Args:
    vmap: Variables to update; passed through to each learner's `Apply`.
    metrics: Metrics dict to look losses up in. Falls back to `self._metrics`
      when falsy (None or empty).
    add_summary: If True, each learner's eval metrics are registered via
      `AddEvalMetric` with a '/<learner name>' suffix.

  Returns:
    A dict with 'mask_updates', 'global_step' and, when TPU embeddings are in
    use, 'tpu_embedding'. The stage-1 ops are NOT in the returned dict; they are
    only reachable through the control dependencies above.

  Raises:
    ValueError: if a learner's loss (`params.loss_name`, falling back to the
      learner name) is missing from `metrics`.
  """
  metrics = metrics or self._metrics
  bprop_variable_filters = self.input_generator.GetBpropVariableFilters()
  # Only compute the mask if the variable filters are not empty.
  if bprop_variable_filters != [''] * len(bprop_variable_filters):
    self._ComputeGradientMask(bprop_variable_filters)
  train_ops = {}  # mapping from op name to op.
  gradient_mask = None
  if self._per_input_gradient_mask:
    # TODO(neerajgaur): Change this to use source_selected from input_batch.
    onehot = self.input_generator.GetInputSourceOneHot()
    gradient_mask = {
        k: tf.tensordot(v, onehot, 1)
        for k, v in self._per_input_gradient_mask.items()
    }
  all_losses = []
  for optimization in self.learners:
    learner_name = optimization.params.name
    loss_name = optimization.params.loss_name or learner_name
    metric = metrics.get(loss_name, None)
    if metric is None:
      raise ValueError('Loss %s not found in metrics %s' %
                       (loss_name, list(metrics.keys())))
    loss = metric[0]
    all_losses.append(loss)
    train_ops['train/%s' % learner_name], eval_metrics = optimization.Apply(
        loss,
        vmap,
        gradient_mask=gradient_mask,
        gradient_adjuster=self.AdjustGradients)
    if add_summary:
      for key, (value, weight) in eval_metrics.items():
        self.AddEvalMetric(key + '/' + learner_name, value, weight)

  relevant_bn_updates, _ = py_utils.FindRelevantBatchNormUpdates(
      all_losses, tf.get_collection(py_utils.BATCH_NORM_UPDATES))
  train_ops['bn_updates'] = relevant_bn_updates

  var_update_ops = [
      tf.group(*tf.nest.flatten(train_ops), name='var_update_ops')
  ]
  # Post training step update.
  with tf.control_dependencies(var_update_ops):
    post_step_op = self.PostTrainingStepUpdate(self.global_step)

  # Start a fresh dict: the variable-update ops are reached only through the
  # control dependencies of the ops created below.
  train_ops = {}
  with tf.control_dependencies([post_step_op]):
    # Get the op to update the weight masks and thresholds
    mask_update_op = self._GetMaskUpdateOp()
    train_ops['mask_updates'] = mask_update_op
    with tf.control_dependencies([mask_update_op]):
      true_global_step = py_utils.GetOrCreateGlobalStepVar()
      with tf.ops.colocate_with(true_global_step):
        increment_global_steps = tf.assign_add(true_global_step, 1)
      # Explicit identity check. In graph mode `!=` on variables already falls
      # back to identity, but under TF2 eager equality it compares VALUES, which
      # would skip incrementing a distinct variable that happens to hold the
      # same step count.
      if self._global_step_var is not true_global_step:
        with tf.ops.colocate_with(self._global_step_var):
          increment_global_steps = tf.group(
              increment_global_steps, tf.assign_add(self._global_step_var, 1))
      train_ops['global_step'] = increment_global_steps

  # If we are using Tpu Embeddings, generate the monolithic send
  # gradient op.
  tpu_embedding_activations = tf.get_collection(
      py_utils.TPU_EMBEDDING_ACTIVATIONS)
  if tpu_embedding_activations:
    tpu_embedding_activations_dict = tpu_embedding_activations[0]
    tpu_embedding = tf.get_collection(py_utils.TPU_EMBEDDING)[0]
    # Gradients are computed from `self.loss`, not the per-learner losses.
    tpu_embedding_send_gradient_op = py_utils.ComputeTpuEmbeddingGradients(
        self.loss, tpu_embedding_activations_dict, tpu_embedding)
    train_ops['tpu_embedding'] = tpu_embedding_send_gradient_op

  for op_name, op in train_ops.items():
    assert op is not None, op_name
  return train_ops
def get_thresholds():
  """Fetches everything registered under the threshold graph collection.

  Returns:
    The list returned by `tf.get_collection(_THRESHOLD_COLLECTION)` for the
    current default graph.
  """
  collection_key = _THRESHOLD_COLLECTION
  return tf.get_collection(collection_key)
def get_weights():
  """Fetches everything registered under the weight graph collection.

  Returns:
    The list returned by `tf.get_collection(_WEIGHT_COLLECTION)` for the
    current default graph.
  """
  collection_key = _WEIGHT_COLLECTION
  return tf.get_collection(collection_key)