def testDynamicUpdateSlice(self):
    """Check the static output shape reported by `xla.dynamic_update_slice`.

    In both cases below the result's static shape is expected to equal the
    static shape of the operand being updated.
    """
    # Case 1: operand with a fully defined static shape.
    operand = array_ops.placeholder(np.float32, shape=(2, 3, 4))
    update = array_ops.placeholder(np.float32, shape=(1, 2, 3))
    starts = array_ops.placeholder(np.int32, shape=(3,))
    result = xla.dynamic_update_slice(operand, update, starts)
    self.assertEqual(result.shape.as_list(), [2, 3, 4])

    # Case 2: operand with unknown first and last dimensions; the same
    # update and start-index placeholders are reused.
    partial_operand = array_ops.placeholder(
        np.float32, shape=(None, 3, None))
    partial_result = xla.dynamic_update_slice(partial_operand, update, starts)
    self.assertEqual(partial_result.shape.as_list(), [None, 3, None])
def dlrm_llr_model_fn(params,
                      feature_config,
                      features,
                      labels,
                      is_training,
                      eval_step_num=None,
                      predictions=None):
    """Model fn: builds either the train op or the eval predictions update.

    Args:
        params: Params dict for the model.
        feature_config: Configuration of features.
        features: Features dict for the model. The click labels are read from
            `features["clicked"]`.
        labels: Must be None (asserted); the labels come from `features`.
        is_training: Boolean, True if training.
        eval_step_num: Int tensor, representing the batch number during eval.
            It is used as the start index along the first axis of
            `predictions`. Only read when not training.
        predictions: [num_batches, batch_size, 2] tensor holding all
            predictions. Only read when not training.

    Returns:
        A `(train_op, predictions)` pair: `(train_op, None)` when training,
        otherwise `(None, {"results": updated_predictions})`.
    """
    assert labels is None, "Labels should be None. Reconfigure."
    labels = features["clicked"]
    preds, _ = logits_fn(features, params, feature_config)
    # Fetched (or created) in both modes, before branching.
    global_step = tf.train.get_or_create_global_step()

    if not is_training:
        # TODO(tayo): Consider adding a local key-value sort.
        # Pair each prediction with its label, then write this batch into the
        # `predictions` buffer at offset [eval_step_num, 0, 0].
        batch_rows = tf.concat([preds, tf.cast(labels, tf.float32)], axis=1)
        batch_block = tf.expand_dims(batch_rows, axis=0)
        write_offset = tf.stack(
            [eval_step_num, tf.constant(0), tf.constant(0)])
        predictions = xla.dynamic_update_slice(
            predictions, batch_block, write_offset)
        return None, {"results": predictions}

    # Training: per-example binary cross-entropy, averaged into a scalar.
    bce = tf.keras.losses.BinaryCrossentropy(
        from_logits=False,
        reduction=tf.compat.v2.keras.losses.Reduction.NONE)
    per_example_loss = bce(labels, preds)
    loss = tf.reduce_mean(per_example_loss)

    learning_rate = utils.lr_fn(params, global_step)
    base_optimizer = ConditionalOptimizer(params, learning_rate, global_step)
    sharded_optimizer = tf.tpu.CrossShardOptimizer(base_optimizer)
    train_op = contrib_layers.optimize_loss(
        name="training",
        loss=loss,
        global_step=global_step,
        learning_rate=learning_rate,
        optimizer=sharded_optimizer,
        colocate_gradients_with_ops=True)
    return train_op, None
def _dynamic_update_slice(operand, update, *start_indices):
    """Call `tfxla.dynamic_update_slice` on type-promoted inputs.

    Args:
        operand: Value to be updated.
        update: Value written into `operand`.
        *start_indices: Start indices, passed individually; they are stacked
            into a single tensor before the call.

    Returns:
        The result of `tfxla.dynamic_update_slice`.
    """
    # `promote_types` is defined elsewhere; presumably it brings both values
    # to a common dtype -- TODO confirm. Its result is star-unpacked as the
    # leading positional arguments.
    promoted = promote_types(operand, update)
    stacked_starts = tf.stack(start_indices)
    return tfxla.dynamic_update_slice(*promoted, stacked_starts)
def _dynamic_update_slice(operand, update, *start_indices):
    """Call `tfxla.dynamic_update_slice` with stacked start indices.

    Args:
        operand: Value to be updated.
        update: Value written into `operand`.
        *start_indices: Start indices, passed individually; they are stacked
            into a single tensor before the call.

    Returns:
        The result of `tfxla.dynamic_update_slice`.
    """
    # NOTE(review): `operand` and `update` are forwarded unchanged. If another
    # `_dynamic_update_slice` is defined earlier in the same module, this
    # definition replaces it -- confirm that is intended.
    stacked_starts = tf.stack(start_indices)
    return tfxla.dynamic_update_slice(operand, update, stacked_starts)