def cache_fn(batch_size):
    """Create the zero-filled initial memory tensors, one per layer.

    Args:
        batch_size: leading dimension of every memory tensor.

    Returns:
        An empty list when `FLAGS.mem_len` is not positive; otherwise a list
        of `FLAGS.n_layer` tensors of dtype `tf_float`. Each tensor is shaped
        `[batch_size, mem_len, FLAGS.chunk_len, FLAGS.d_model]` when
        `FLAGS.chunk_len` is set, and `[batch_size, mem_len, FLAGS.d_model]`
        when it is None.
    """
    # NOTE(review): the guard reads FLAGS.mem_len while the shapes use the
    # bare name `mem_len` from the enclosing scope -- confirm both always
    # carry the same value.
    if FLAGS.mem_len <= 0:
        return []

    def _layer_memory():
        # Chunked memories carry an extra chunk_len axis before d_model.
        if FLAGS.chunk_len is None:
            shape = [batch_size, mem_len, FLAGS.d_model]
        else:
            shape = [batch_size, mem_len, FLAGS.chunk_len, FLAGS.d_model]
        return tf.zeros(shape, dtype=tf_float)

    return [_layer_memory() for _ in range(FLAGS.n_layer)]
def cache_fn(batch_size):
    """Call back function used to create initial cache in TPUEstimator.

    Args:
        batch_size: leading dimension of every cache tensor.

    Returns:
        An empty list when `FLAGS.seq_len` is not positive; otherwise a list
        of `FLAGS.n_layer + int(FLAGS.use_extra_layer)` zero tensors of dtype
        `tf_float`, each shaped `[batch_size, mem_len, FLAGS.d_model]`.
    """
    if FLAGS.seq_len <= 0:
        return []

    # One cache entry per layer, plus one more when the extra layer is on.
    num_layers = FLAGS.n_layer + int(FLAGS.use_extra_layer)
    return [
        tf.zeros([batch_size, mem_len, FLAGS.d_model], dtype=tf_float)
        for _ in range(num_layers)
    ]
def _transform_images(self, params, features, labels=None):
    """Transforms images.

    Rewrites `features['images']` in place, using either the fused
    transpose + space-to-depth transform or a transpose/flatten reshape,
    and optionally casts the result to bfloat16.

    Args:
        params: dict read for the keys 'conv0_space_to_depth_block_size',
            'transpose_input', 'image_size' and 'use_bfloat16'.
        features: dict holding 'images' (a 4-D tensor with a static shape)
            and, on the space-to-depth path, 'image_info'. It is mutated:
            'images' is replaced by the transformed tensor.
        labels: optional labels, passed through untouched.

    Returns:
        `(features, labels)` when `labels` is given, otherwise
        `(features, tf.zeros([batch_size]))`.
    """
    source = features['images']
    batch_size, _, _, channels = source.get_shape().as_list()

    if params['conv0_space_to_depth_block_size'] != 0:
        # Transforms (space-to-depth) images for TPU performance.
        def _space_to_depth(image_size):
            return spatial_transform.fused_transpose_and_space_to_depth(
                source,
                image_size,
                params['conv0_space_to_depth_block_size'],
                params['transpose_input'])

        # image_info[0, 3] < image_info[0, 4] selects image_size as given;
        # otherwise the reversed image_size is used.
        use_given_size = tf.less(
            features['image_info'][0, 3], features['image_info'][0, 4])
        transformed = tf.cond(
            use_given_size,
            lambda: _space_to_depth(params['image_size']),
            lambda: _space_to_depth(params['image_size'][::-1]))
    else:
        # Transposes images for TPU performance.
        image_area = params['image_size'][0] * params['image_size'][1]
        if params['transpose_input']:
            # Flattens spatial dimensions so that the image tensor has a
            # static shape.
            transformed = tf.reshape(
                tf.transpose(source, [1, 2, 0, 3]),
                [image_area, batch_size, channels])
        else:
            transformed = tf.reshape(
                source, [batch_size, image_area, channels])

    if params['use_bfloat16']:
        transformed = tf.cast(transformed, dtype=tf.bfloat16)
    features['images'] = transformed

    if labels is None:
        # No labels supplied: return a zero placeholder of length batch_size.
        return features, tf.zeros([batch_size])
    return features, labels
def zeroed_box_fn():
    """Return an empty tensor shaped [0, ori_height, ori_width, 1].

    The height and width come from `self._ori_height` / `self._ori_width`
    in the enclosing scope; the dtype is whatever `tf.zeros` defaults to.
    """
    empty_shape = [0, self._ori_height, self._ori_width, 1]
    return tf.zeros(empty_shape)
def parser(record):
    """function used to parse tfrecord.

    Parses one serialized example, runs `_local_perm` separately on the
    first `reuse_len` positions and on the remaining positions, stitches the
    two results back together, and stores the derived tensors in the example
    dict.

    Args:
        record: a serialized example accepted by `tf.parse_single_example`.

    Returns:
        The parsed example dict. "input" and "is_masked" are removed;
        "target", "target_mask", "perm_mask", "input_k" and "input_q" are
        added, plus "target_mapping" when `num_predict` is not None. The
        "seg_id" and "label" features are left as parsed, subject to
        whatever `_convert_example` does.
    """
    feature_spec = {
        "input": tf.FixedLenFeature([seq_len], tf.int64),
        "target": tf.FixedLenFeature([seq_len], tf.int64),
        "seg_id": tf.FixedLenFeature([seq_len], tf.int64),
        "label": tf.FixedLenFeature([1], tf.int64),
        "is_masked": tf.FixedLenFeature([seq_len], tf.int64),
    }

    # retrieve serialized example
    example = tf.parse_single_example(serialized=record, features=feature_spec)

    token_ids = example.pop("input")
    target = example.pop("target")
    masked_flags = tf.cast(example.pop("is_masked"), tf.bool)

    non_reuse_len = seq_len - reuse_len
    assert perm_size <= reuse_len and perm_size <= non_reuse_len

    # Permute the reused prefix and the new suffix independently.
    mask_head, target_head, tmask_head, k_head, q_head = _local_perm(
        token_ids[:reuse_len],
        target[:reuse_len],
        masked_flags[:reuse_len],
        perm_size,
        reuse_len)
    mask_tail, target_tail, tmask_tail, k_tail, q_tail = _local_perm(
        token_ids[reuse_len:],
        target[reuse_len:],
        masked_flags[reuse_len:],
        perm_size,
        non_reuse_len)

    # Widen each block to seq_len columns: ones to the right of the prefix
    # block, zeros to the left of the suffix block, then stack the rows.
    mask_head = tf.concat(
        [mask_head, tf.ones([reuse_len, non_reuse_len])], axis=1)
    mask_tail = tf.concat(
        [tf.zeros([non_reuse_len, reuse_len]), mask_tail], axis=1)
    perm_mask = tf.concat([mask_head, mask_tail], axis=0)

    target = tf.concat([target_head, target_tail], axis=0)
    target_mask = tf.concat([tmask_head, tmask_tail], axis=0)
    input_k = tf.concat([k_head, k_tail], axis=0)
    input_q = tf.concat([q_head, q_tail], axis=0)

    if num_predict is None:
        example["target"] = tf.reshape(target, [seq_len])
        example["target_mask"] = tf.reshape(target_mask, [seq_len])
    else:
        positions = tf.range(seq_len, dtype=tf.int64)
        is_target = tf.cast(target_mask, tf.bool)
        target_positions = tf.boolean_mask(positions, is_target)

        ##### extra padding due to CLS/SEP introduced after prepro
        actual_num_predict = tf.shape(target_positions)[0]
        pad_len = num_predict - actual_num_predict

        ##### target_mapping
        one_hot_rows = tf.one_hot(target_positions, seq_len, dtype=tf.float32)
        mapping_pad = tf.zeros([pad_len, seq_len], dtype=one_hot_rows.dtype)
        target_mapping = tf.concat([one_hot_rows, mapping_pad], axis=0)
        example["target_mapping"] = tf.reshape(
            target_mapping, [num_predict, seq_len])

        ##### target
        selected_target = tf.boolean_mask(target, is_target)
        target_pad = tf.zeros([pad_len], dtype=selected_target.dtype)
        padded_target = tf.concat([selected_target, target_pad], axis=0)
        example["target"] = tf.reshape(padded_target, [num_predict])

        ##### target mask
        # Ones for the real prediction slots, zeros for the padded slots.
        padded_mask = tf.concat(
            [tf.ones([actual_num_predict], dtype=tf.float32),
             tf.zeros([pad_len], dtype=tf.float32)],
            axis=0)
        example["target_mask"] = tf.reshape(padded_mask, [num_predict])

    # reshape back to fixed shape
    example["perm_mask"] = tf.reshape(perm_mask, [seq_len, seq_len])
    example["input_k"] = tf.reshape(input_k, [seq_len])
    example["input_q"] = tf.reshape(input_q, [seq_len])

    # The return value is ignored, so this presumably converts the example
    # in place -- verify against _convert_example.
    _convert_example(example, use_bfloat16)

    for name, tensor in example.items():
        tf.logging.info("%s: %s", name, tensor)

    return example