Example #1
import collections

import tensorflow as tf  # TF1-style API (tf.FixedLenFeature, tf.parse_single_example)

# Assumed here for self-containedness: in the original module, SARSTransition
# is a namedtuple whose fields match how it is constructed below.
SARSTransition = collections.namedtuple(
    'SARSTransition', ['state', 'action', 'reward', 'state_p1', 'done', 'aux'])


def parse_tfexample_sequence(example_proto,
                             img_height=48,
                             img_width=48,
                             action_size=None,
                             episode_length=16):
    """Parse TFExamples saved by episode_to_transitions.

  Args:
    example_proto: tf.String tensor representing a serialized protobuf.
    img_height: Height of parsed image tensors.
    img_width: Width of parsed image tensors.
    action_size: Size of continuous actions. If None, actions are assumed to be
      integer-encoded discrete actions.
    episode_length: Intended length of each episode.
  Returns:
    NamedTuple of type SARSTransition containing unbatched Tensors.
  """
    if action_size is None:
        # Is discrete.
        action_feature_spec = tf.FixedLenFeature((episode_length, ), tf.int64)
    else:
        # Vector-encoded float feature.
        action_feature_spec = tf.FixedLenFeature((episode_length, action_size),
                                                 tf.float32)

    features = {
        'S/img': tf.FixedLenFeature((episode_length, ), tf.string),
        'A': action_feature_spec,
        'R': tf.FixedLenFeature((episode_length, ), tf.float32),
        'S_p1/img': tf.FixedLenFeature((episode_length, ), tf.string),
        'done': tf.FixedLenFeature((episode_length, ), tf.int64),
        't': tf.FixedLenFeature((episode_length, ), tf.int64)
    }
    parsed_features = tf.parse_single_example(example_proto, features)
    # Decode the jpeg-encoded images into numeric tensors.
    states = []
    for key in ('S/img', 'S_p1/img'):
        state = tf.stack([
            tf.image.decode_jpeg(img, channels=3)
            for img in tf.unstack(parsed_features[key], num=episode_length)
        ])
        state.set_shape([episode_length, img_height, img_width, 3])
        states.append(tf.cast(state, tf.float32))

    action = parsed_features['A']
    reward = parsed_features['R']
    done = tf.cast(parsed_features['done'], tf.float32)
    step = tf.cast(parsed_features['t'], tf.int32)
    aux = {'step': step}

    return SARSTransition((states[0], step), action, reward,
                          (states[1], step + 1), done, aux)
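
A minimal usage sketch for this parser; the file name and batch size below are
illustrative placeholders, not part of the original code:

# Hypothetical wiring: map the parser over a TFRecord dataset of episodes.
dataset = tf.data.TFRecordDataset('episodes.tfrecord')  # placeholder path
dataset = dataset.map(parse_tfexample_sequence)  # yields SARSTransition tuples
dataset = dataset.batch(4)  # placeholder batch size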
Example #2
  def parser(record):
    """Parses one serialized tfrecord example.

    Defined as a closure: seq_len, reuse_len, perm_size, num_predict and
    use_bfloat16, plus the helpers _local_perm and _convert_example, come
    from the enclosing scope.
    """

    record_spec = {
        "input": tf.FixedLenFeature([seq_len], tf.int64),
        "target": tf.FixedLenFeature([seq_len], tf.int64),
        "seg_id": tf.FixedLenFeature([seq_len], tf.int64),
        "label": tf.FixedLenFeature([1], tf.int64),
        "is_masked": tf.FixedLenFeature([seq_len], tf.int64),
    }

    # parse the serialized record into a dict of dense tensors
    example = tf.parse_single_example(
        serialized=record,
        features=record_spec)

    inputs = example.pop("input")
    target = example.pop("target")
    is_masked = tf.cast(example.pop("is_masked"), tf.bool)

    non_reuse_len = seq_len - reuse_len
    assert perm_size <= reuse_len and perm_size <= non_reuse_len

    # Build permutation-LM inputs separately for the reused (memory) segment
    # and the fresh (non-reused) segment of the sequence.
    perm_mask_0, target_0, target_mask_0, input_k_0, input_q_0 = _local_perm(
        inputs[:reuse_len],
        target[:reuse_len],
        is_masked[:reuse_len],
        perm_size,
        reuse_len)

    perm_mask_1, target_1, target_mask_1, input_k_1, input_q_1 = _local_perm(
        inputs[reuse_len:],
        target[reuse_len:],
        is_masked[reuse_len:],
        perm_size,
        non_reuse_len)

    # Stitch the two masks into the full (seq_len, seq_len) permutation mask.
    # Here perm_mask[i][j] == 1 means position i cannot attend to position j:
    # the reused segment never attends to the fresh segment (ones block),
    # while the fresh segment may attend to the whole reused segment
    # (zeros block).
    perm_mask_0 = tf.concat([perm_mask_0, tf.ones([reuse_len, non_reuse_len])],
                            axis=1)
    perm_mask_1 = tf.concat([tf.zeros([non_reuse_len, reuse_len]), perm_mask_1],
                            axis=1)
    perm_mask = tf.concat([perm_mask_0, perm_mask_1], axis=0)
    target = tf.concat([target_0, target_1], axis=0)
    target_mask = tf.concat([target_mask_0, target_mask_1], axis=0)
    input_k = tf.concat([input_k_0, input_k_1], axis=0)
    input_q = tf.concat([input_q_0, input_q_1], axis=0)

    if num_predict is not None:
      # Gather the indices of positions selected for prediction, then pad
      # everything out to a static num_predict so the example keeps a fixed
      # shape.
      indices = tf.range(seq_len, dtype=tf.int64)
      bool_target_mask = tf.cast(target_mask, tf.bool)
      indices = tf.boolean_mask(indices, bool_target_mask)

      ##### extra padding due to CLS/SEP introduced after preprocessing
      actual_num_predict = tf.shape(indices)[0]
      pad_len = num_predict - actual_num_predict

      ##### target_mapping
      target_mapping = tf.one_hot(indices, seq_len, dtype=tf.float32)
      paddings = tf.zeros([pad_len, seq_len], dtype=target_mapping.dtype)
      target_mapping = tf.concat([target_mapping, paddings], axis=0)
      example["target_mapping"] = tf.reshape(target_mapping,
                                             [num_predict, seq_len])

      ##### target
      target = tf.boolean_mask(target, bool_target_mask)
      paddings = tf.zeros([pad_len], dtype=target.dtype)
      target = tf.concat([target, paddings], axis=0)
      example["target"] = tf.reshape(target, [num_predict])

      ##### target mask
      target_mask = tf.concat(
          [tf.ones([actual_num_predict], dtype=tf.float32),
           tf.zeros([pad_len], dtype=tf.float32)],
          axis=0)
      example["target_mask"] = tf.reshape(target_mask, [num_predict])
    else:
      example["target"] = tf.reshape(target, [seq_len])
      example["target_mask"] = tf.reshape(target_mask, [seq_len])

    # reshape back to fixed shape
    example["perm_mask"] = tf.reshape(perm_mask, [seq_len, seq_len])
    example["input_k"] = tf.reshape(input_k, [seq_len])
    example["input_q"] = tf.reshape(input_q, [seq_len])

    # Convert feature dtypes in place (e.g. cast floats to bfloat16 when
    # use_bfloat16 is set).
    _convert_example(example, use_bfloat16)

    for k, v in example.items():
      tf.logging.info("%s: %s", k, v)

    return example
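
A sketch of how a nested parser like this is typically consumed; the dataset
variables below (file_names, batch_size) are illustrative placeholders
supplied by the enclosing function, not part of the original code:

  # Hypothetical wiring inside the enclosing input-fn factory.
  dataset = tf.data.TFRecordDataset(file_names)
  dataset = dataset.map(parser, num_parallel_calls=8)
  dataset = dataset.batch(batch_size, drop_remainder=True)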
Example #3
def Read(record_file):
    """Reads one example from record_file and publishes its two encoded images.

    Relies on `tf` (TensorFlow 1.x) and an external `publish` visualization
    helper from the surrounding codebase.
    """
    keys_to_features = {
        'view1/image/encoded':
        tf.FixedLenFeature((), dtype=tf.string, default_value=''),
        'view1/image/format':
        tf.FixedLenFeature([], dtype=tf.string, default_value='png'),
        'view1/image/height':
        tf.FixedLenFeature([1], dtype=tf.int64, default_value=64),
        'view1/image/width':
        tf.FixedLenFeature([1], dtype=tf.int64, default_value=64),
        'view2/image/encoded':
        tf.FixedLenFeature((), dtype=tf.string, default_value=''),
        'view2/image/format':
        tf.FixedLenFeature([], dtype=tf.string, default_value='png'),
        'view2/image/height':
        tf.FixedLenFeature([1], dtype=tf.int64, default_value=64),
        'view2/image/width':
        tf.FixedLenFeature([1], dtype=tf.int64, default_value=64),
        'image/encoded':
        tf.FixedLenFeature([2], dtype=tf.string, default_value=['', '']),
        'same_object':
        tf.FixedLenFeature([1], dtype=tf.int64, default_value=-1),
        'relative_pos':
        tf.FixedLenFeature([3], dtype=tf.float32),
    }
    with tf.Graph().as_default():
        filename_queue = tf.train.string_input_producer([record_file],
                                                        capacity=10)
        reader = tf.TFRecordReader()
        _, serialized_example = reader.read(filename_queue)
        example = tf.parse_single_example(serialized_example, keys_to_features)
        #png1 = example['view1/image/encoded']
        #png2 = example['view2/image/encoded']
        png = example['image/encoded']
        coord = tf.train.Coordinator()
        print('Reading images:')
        with tf.Session() as sess:
            queue_threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            #image1, image2 = sess.run([png1, png2])
            image1, image2 = sess.run([png[0], png[1]])
            publish.image(encoded_image=image1, width=20)
            publish.image(encoded_image=image2, width=20)
        coord.request_stop()
        coord.join(queue_threads)
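
Queue runners and string_input_producer are deprecated in favor of tf.data. A
minimal sketch of the same read using TFRecordDataset, keeping only the paired
image feature; this is an illustrative rewrite, not part of the original code:

import tensorflow as tf  # TF1-style API, as above

def read_with_dataset(record_file):
    # Parse just the paired-image feature from each serialized record.
    keys_to_features = {
        'image/encoded':
        tf.FixedLenFeature([2], dtype=tf.string, default_value=['', '']),
    }
    dataset = tf.data.TFRecordDataset([record_file])
    dataset = dataset.map(
        lambda record: tf.parse_single_example(record, keys_to_features))
    png = dataset.make_one_shot_iterator().get_next()['image/encoded']
    with tf.Session() as sess:
        image1, image2 = sess.run([png[0], png[1]])
    return image1, image2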