import tensorflow as tf


def parse_tfexample_sequence(example_proto,
                             img_height=48,
                             img_width=48,
                             action_size=None,
                             episode_length=16):
  """Parse TFExamples saved by episode_to_transitions.

  Args:
    example_proto: tf.String tensor representing a serialized protobuf.
    img_height: Height of parsed image tensors.
    img_width: Width of parsed image tensors.
    action_size: Size of continuous actions. If None, actions are assumed to
      be integer-encoded discrete actions.
    episode_length: Intended length of each episode.

  Returns:
    NamedTuple of type SARSTransition containing unbatched Tensors.
  """
  if action_size is None:
    # Is discrete.
    action_feature_spec = tf.FixedLenFeature((episode_length,), tf.int64)
  else:
    # Vector-encoded float feature.
    action_feature_spec = tf.FixedLenFeature((episode_length, action_size),
                                             tf.float32)
  features = {
      'S/img': tf.FixedLenFeature((episode_length,), tf.string),
      'A': action_feature_spec,
      'R': tf.FixedLenFeature((episode_length,), tf.float32),
      'S_p1/img': tf.FixedLenFeature((episode_length,), tf.string),
      'done': tf.FixedLenFeature((episode_length,), tf.int64),
      't': tf.FixedLenFeature((episode_length,), tf.int64),
  }
  parsed_features = tf.parse_single_example(example_proto, features)
  # Decode the jpeg-encoded images into numeric tensors.
  states = []
  for key in ('S/img', 'S_p1/img'):
    state = tf.stack([
        tf.image.decode_jpeg(img, channels=3)
        for img in tf.unstack(parsed_features[key], num=episode_length)
    ])
    state.set_shape([episode_length, img_height, img_width, 3])
    states.append(tf.cast(state, tf.float32))
  action = parsed_features['A']
  reward = parsed_features['R']
  done = tf.cast(parsed_features['done'], tf.float32)
  step = tf.cast(parsed_features['t'], tf.int32)
  aux = {'step': step}
  return SARSTransition((states[0], step), action, reward,
                        (states[1], step + 1), done, aux)
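# A minimal usage sketch (an assumption, not part of the original source):
# wiring parse_tfexample_sequence into a tf.data input pipeline. The
# SARSTransition namedtuple and make_episode_dataset below are hypothetical;
# only the parser above comes from the source.
import collections

SARSTransition = collections.namedtuple(
    'SARSTransition', ['state', 'action', 'reward', 'state_p1', 'done', 'aux'])


def make_episode_dataset(file_pattern, batch_size=32):
  """Builds a batched dataset of parsed episode transitions (sketch)."""
  filenames = tf.data.Dataset.list_files(file_pattern)
  dataset = tf.data.TFRecordDataset(filenames)
  # Each serialized proto is parsed into one unbatched SARSTransition.
  dataset = dataset.map(parse_tfexample_sequence, num_parallel_calls=4)
  return dataset.batch(batch_size)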
def Read(record_file):
  keys_to_features = {
      'view1/image/encoded':
          tf.FixedLenFeature((), dtype=tf.string, default_value=''),
      'view1/image/format':
          tf.FixedLenFeature([], dtype=tf.string, default_value='png'),
      'view1/image/height':
          tf.FixedLenFeature([1], dtype=tf.int64, default_value=64),
      'view1/image/width':
          tf.FixedLenFeature([1], dtype=tf.int64, default_value=64),
      'view2/image/encoded':
          tf.FixedLenFeature((), dtype=tf.string, default_value=''),
      'view2/image/format':
          tf.FixedLenFeature([], dtype=tf.string, default_value='png'),
      'view2/image/height':
          tf.FixedLenFeature([1], dtype=tf.int64, default_value=64),
      'view2/image/width':
          tf.FixedLenFeature([1], dtype=tf.int64, default_value=64),
      'image/encoded':
          tf.FixedLenFeature([2], dtype=tf.string, default_value=['', '']),
      'same_object':
          tf.FixedLenFeature([1], dtype=tf.int64, default_value=-1),
      'relative_pos':
          tf.FixedLenFeature([3], dtype=tf.float32),
  }
  with tf.Graph().as_default():
    filename_queue = tf.train.string_input_producer([record_file], capacity=10)
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    example = tf.parse_single_example(serialized_example, keys_to_features)
    # Both views are stored together under 'image/encoded'; they are also
    # available individually as 'view1/image/encoded' and
    # 'view2/image/encoded'.
    png = example['image/encoded']
    coord = tf.train.Coordinator()
    print('Reading images:')
    with tf.Session() as sess:
      queue_threads = tf.train.start_queue_runners(sess=sess, coord=coord)
      image1, image2 = sess.run([png[0], png[1]])
      publish.image(encoded_image=image1, width=20)
      publish.image(encoded_image=image2, width=20)
      coord.request_stop()
      coord.join(queue_threads)
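# A hedged alternative sketch (an assumption, not from the original source):
# the same single-record read expressed with tf.data instead of the
# deprecated queue-runner machinery. Only the 'image/encoded' feature key and
# the record_file argument mirror Read() above; read_with_dataset is
# hypothetical.
def read_with_dataset(record_file):
  keys_to_features = {
      'image/encoded':
          tf.FixedLenFeature([2], dtype=tf.string, default_value=['', '']),
  }
  dataset = tf.data.TFRecordDataset([record_file])
  dataset = dataset.map(
      lambda serialized: tf.parse_single_example(serialized, keys_to_features))
  # tf.data iterators need no Coordinator or queue-runner threads.
  png = dataset.make_one_shot_iterator().get_next()['image/encoded']
  with tf.Session() as sess:
    image1, image2 = sess.run([png[0], png[1]])
  return image1, image2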
def parser(record):
  """function used to parse tfrecord."""
  record_spec = {
      "input": tf.FixedLenFeature([seq_len], tf.int64),
      "target": tf.FixedLenFeature([seq_len], tf.int64),
      "seg_id": tf.FixedLenFeature([seq_len], tf.int64),
      "label": tf.FixedLenFeature([1], tf.int64),
      "is_masked": tf.FixedLenFeature([seq_len], tf.int64),
  }

  # retrieve serialized example
  example = tf.parse_single_example(
      serialized=record,
      features=record_spec)

  inputs = example.pop("input")
  target = example.pop("target")
  is_masked = tf.cast(example.pop("is_masked"), tf.bool)

  non_reuse_len = seq_len - reuse_len
  assert perm_size <= reuse_len and perm_size <= non_reuse_len

  perm_mask_0, target_0, target_mask_0, input_k_0, input_q_0 = _local_perm(
      inputs[:reuse_len],
      target[:reuse_len],
      is_masked[:reuse_len],
      perm_size,
      reuse_len)

  perm_mask_1, target_1, target_mask_1, input_k_1, input_q_1 = _local_perm(
      inputs[reuse_len:],
      target[reuse_len:],
      is_masked[reuse_len:],
      perm_size,
      non_reuse_len)

  perm_mask_0 = tf.concat(
      [perm_mask_0, tf.ones([reuse_len, non_reuse_len])],
      axis=1)
  perm_mask_1 = tf.concat(
      [tf.zeros([non_reuse_len, reuse_len]), perm_mask_1],
      axis=1)
  perm_mask = tf.concat([perm_mask_0, perm_mask_1], axis=0)
  target = tf.concat([target_0, target_1], axis=0)
  target_mask = tf.concat([target_mask_0, target_mask_1], axis=0)
  input_k = tf.concat([input_k_0, input_k_1], axis=0)
  input_q = tf.concat([input_q_0, input_q_1], axis=0)

  if num_predict is not None:
    indices = tf.range(seq_len, dtype=tf.int64)
    bool_target_mask = tf.cast(target_mask, tf.bool)
    indices = tf.boolean_mask(indices, bool_target_mask)

    ##### extra padding due to CLS/SEP introduced after prepro
    actual_num_predict = tf.shape(indices)[0]
    pad_len = num_predict - actual_num_predict

    ##### target_mapping
    target_mapping = tf.one_hot(indices, seq_len, dtype=tf.float32)
    paddings = tf.zeros([pad_len, seq_len], dtype=target_mapping.dtype)
    target_mapping = tf.concat([target_mapping, paddings], axis=0)
    example["target_mapping"] = tf.reshape(target_mapping,
                                           [num_predict, seq_len])

    ##### target
    target = tf.boolean_mask(target, bool_target_mask)
    paddings = tf.zeros([pad_len], dtype=target.dtype)
    target = tf.concat([target, paddings], axis=0)
    example["target"] = tf.reshape(target, [num_predict])

    ##### target mask
    target_mask = tf.concat(
        [tf.ones([actual_num_predict], dtype=tf.float32),
         tf.zeros([pad_len], dtype=tf.float32)],
        axis=0)
    example["target_mask"] = tf.reshape(target_mask, [num_predict])
  else:
    example["target"] = tf.reshape(target, [seq_len])
    example["target_mask"] = tf.reshape(target_mask, [seq_len])

  # reshape back to fixed shape
  example["perm_mask"] = tf.reshape(perm_mask, [seq_len, seq_len])
  example["input_k"] = tf.reshape(input_k, [seq_len])
  example["input_q"] = tf.reshape(input_q, [seq_len])

  _convert_example(example, use_bfloat16)

  for k, v in example.items():
    tf.logging.info("%s: %s", k, v)

  return example
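# A usage sketch (an assumption, not from the original source): how parser()
# is typically mapped over raw TFRecords. parser closes over seq_len,
# reuse_len, perm_size, num_predict, and use_bfloat16 from its enclosing
# scope; make_pretrain_dataset, file_names, and batch_size are hypothetical.
def make_pretrain_dataset(file_names, batch_size):
  """Maps parser() over serialized records and forms fixed-size batches."""
  dataset = tf.data.TFRecordDataset(file_names)
  dataset = dataset.map(parser, num_parallel_calls=4)
  # drop_remainder keeps a static batch dimension, which TPU training needs.
  dataset = dataset.repeat().batch(batch_size, drop_remainder=True)
  return dataset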