def parse_serialized_simulation_example(example_proto, metadata):
    """Parses a serialized simulation tf.SequenceExample.

    Args:
      example_proto: A string encoding of the tf.SequenceExample proto.
      metadata: A dict of metadata for the dataset. Must contain
        'sequence_length' and 'dim'. If it also contains 'context_mean', the
        per-step 'step_context' feature is parsed and reshaped as well.

    Returns:
      context: A dict, with features that do not vary over the trajectory.
      parsed_features: A dict of tf.Tensors representing the parsed examples
        across time, where axis zero is the time axis.
    """
    has_global_context = 'context_mean' in metadata
    if has_global_context:
        feature_description = _FEATURE_DESCRIPTION_WITH_GLOBAL_CONTEXT
    else:
        feature_description = _FEATURE_DESCRIPTION

    context, parsed_features = tf.io.parse_single_sequence_example(
        example_proto,
        context_features=_CONTEXT_FEATURES,
        sequence_features=feature_description)

    # Decode every sequence feature using its own encoded ('in') and
    # decoded ('out') dtypes. Reassigning existing keys while iterating
    # .items() is safe because the dict's size does not change.
    for feature_key, item in parsed_features.items():
        convert_fn = functools.partial(
            convert_to_tensor,
            encoded_dtype=_FEATURE_DTYPES[feature_key]['in'])
        parsed_features[feature_key] = tf.py_function(
            convert_fn,
            inp=[item.values],
            Tout=_FEATURE_DTYPES[feature_key]['out'])

    # There is an extra frame at the beginning so we can calculate pos change
    # for all frames used in the paper.
    sequence_length = metadata['sequence_length'] + 1
    # Reshape positions to [time, particles, dim].
    parsed_features['position'] = tf.reshape(
        parsed_features['position'],
        [sequence_length, -1, metadata['dim']])

    # Set correct shapes of the remaining tensors.
    if has_global_context:
        context_feat_len = len(metadata['context_mean'])
        parsed_features['step_context'] = tf.reshape(
            parsed_features['step_context'],
            [sequence_length, context_feat_len])

    # Decode particle type explicitly. Bind `convert_to_tensor` directly
    # instead of reusing the loop-local `convert_fn`: that name is unbound
    # when there are no sequence features, and reusing it only worked by
    # overriding the dtype it had captured in the last loop iteration.
    context['particle_type'] = tf.py_function(
        functools.partial(convert_to_tensor, encoded_dtype=np.int64),
        inp=[context['particle_type'].values],
        Tout=[tf.int64])
    # Tout is a list, so py_function returns a one-element list; flatten it.
    context['particle_type'] = tf.reshape(context['particle_type'], [-1])
    return context, parsed_features
Example #2
0
    def receive_op(self, name, dtype):
        """Builds an op that receives the value `name` for the current iteration.

        The returned tensor is produced by a `tf.py_function` that calls
        `self.receive(self._current_iter_id, name)` and converts the result
        to `dtype`. The bridge must have been started (a current iteration id
        set) by the time the op executes.
        """
        def _fetch():
            assert self._current_iter_id is not None, "Bridge not started"
            received = self.receive(self._current_iter_id, name)
            return tf.convert_to_tensor(received, dtype=dtype)

        # Tout is a one-element list, so py_function returns a list; unwrap it.
        outputs = tf.py_function(func=_fetch, inp=[], Tout=[dtype])
        return outputs[0]
Example #3
0
    def send_op(self, name, x):
        """Builds an op named 'send_<name>' that sends `x` for the current iteration.

        When executed, the op calls
        `self.send(self._current_iter_id, name, x.numpy())`. It produces no
        output tensors (Tout is empty). The bridge must have been started
        before the op runs.
        """
        def _push(tensor):
            assert self._current_iter_id is not None, "Bridge not started"
            self.send(self._current_iter_id, name, tensor.numpy())

        op_name = 'send_' + name
        return tf.py_function(func=_push, inp=[x], Tout=[], name=op_name)
 def tf_encode(x):
     """Encodes string tensor `x` with `encoder.encode` inside a tf.py_function.

     Returns an int32 tensor whose static shape is set to [None] (rank 1).
     """
     def _encode(text):
         return tf.constant(encoder.encode(text.numpy()))

     tokens = tf.py_function(_encode, [x], tf.int32)
     # py_function loses static shape information; restore the rank.
     tokens.set_shape([None])
     return tokens
Example #5
0
    def receive_op(self, name, dtype):
        """Builds an op named 'recv_<name>' that yields `self.receive(name)`.

        The received value is converted to a tensor of `dtype` inside a
        `tf.py_function`, so the receive happens when the op executes.
        """
        def _fetch():
            value = self.receive(name)
            return tf.convert_to_tensor(value, dtype=dtype)

        op_name = 'recv_' + name
        return tf.py_function(func=_fetch, inp=[], Tout=dtype, name=op_name)
    def decode_tf(self, ids):
        """Decode in TensorFlow.

        Runs `self.decode` on `ids` inside a `tf.py_function`.

        Args:
          ids: a 1d tf.Tensor with dtype tf.int32
        Returns:
          a tf Scalar with dtype tf.string
        """
        decoded = tf.py_function(func=self.decode, inp=[ids], Tout=tf.string)
        return decoded
Example #7
0
def compute_connectivity_for_batch_pyfunc(positions, n_node, radius):
    """`_compute_connectivity_for_batch` wrapped in a pyfunc.

    Returns:
      senders, receivers, n_edge: int32 tensors. `senders` and `receivers`
      are given static shape [None]; `n_edge` takes the static shape of
      `n_node`.
    """
    output_dtypes = [tf.int32] * 3
    senders, receivers, n_edge = tf.py_function(
        _compute_connectivity_for_batch,
        [positions, n_node, radius],
        output_dtypes)
    # py_function drops static shapes, so restore what is known.
    for endpoint_ids in (senders, receivers):
        endpoint_ids.set_shape([None])
    n_edge.set_shape(n_node.get_shape())
    return senders, receivers, n_edge
def tf_parse_function(im1_filename, im2_filename, flo_filename):
    """Runs `_parse_function` in a tf.py_function and rescales the two images.

    Returns:
      (im1, im2, flo, mode): im1 and im2 are cast from uint8 to float32 and
      divided by 255.0. flo (float32) and mode (uint8) are returned as
      produced by `_parse_function`.
    """
    filenames = [im1_filename, im2_filename, flo_filename]
    output_dtypes = [tf.uint8, tf.uint8, tf.float32, tf.uint8]
    im1, im2, flo, mode = tf.py_function(
        _parse_function, inp=filenames, Tout=output_dtypes)

    def _rescale(image):
        return tf.cast(image, tf.float32) / 255.0

    return _rescale(im1), _rescale(im2), flo, mode
Example #9
0
def compute_connectivity_for_batch_pyfunc(node_connections, node_locations,
                                          n_node):
    """`_compute_connectivity_for_batch` wrapped in a pyfunc.

    Returns:
      senders, receivers, n_edge: int32 tensors. `senders` and `receivers`
      are given static shape [None]; `n_edge` takes the static shape of
      `n_node`.
    """
    # The result is unpacked into three tensors, so Tout must declare three
    # dtypes; declaring only two makes the unpacking fail at call time.
    # NOTE(review): assumes _compute_connectivity_for_batch returns
    # (senders, receivers, n_edge) as int32 arrays, like the other
    # connectivity helpers — TODO confirm against its definition.
    senders, receivers, n_edge = tf.py_function(
        _compute_connectivity_for_batch, [node_connections, node_locations],
        [tf.int32, tf.int32, tf.int32])
    senders.set_shape([None])
    receivers.set_shape([None])
    # TODO: check that n_edge really has one entry per graph in n_node.
    n_edge.set_shape(n_node.get_shape())
    return senders, receivers, n_edge
Example #10
0
    def gan_cap(x):
        """Runs `call_generator` on `x` and downsizes its output to 64x64.

        The py_function output is reshaped to [-1, 256, 256, 3], resized to
        (64, 64), and returned reshaped to [-1, 64, 64, 3]. This assumes the
        generator emits 256x256x3 images — TODO confirm.
        """
        # Stray debug print and dead commented-out calls removed.
        generated = tf_v1.py_function(call_generator, [x], tf_v1.float32)
        generated = tf_v1.reshape(generated, [-1, 256, 256, 3])

        resized = tf_v1.image.resize(generated, (64, 64))
        return tf_v1.reshape(resized, [-1, 64, 64, 3])
Example #11
0
def compute_connectivity_for_batch_pyfunc(positions,
                                          n_node,
                                          radius,
                                          add_self_edges=True):
    """`_compute_connectivity_for_batch` wrapped in a pyfunc.

    Args:
      positions, n_node, radius: passed positionally to
        `_compute_connectivity_for_batch`.
      add_self_edges: forwarded to the helper as a keyword argument.

    Returns:
      senders, receivers, n_edge: int32 tensors. `senders` and `receivers`
      are given static shape [None]; `n_edge` takes the static shape of
      `n_node`.
    """
    def _connectivity(batch_positions, batch_n_node, batch_radius):
        return _compute_connectivity_for_batch(
            batch_positions, batch_n_node, batch_radius,
            add_self_edges=add_self_edges)

    output_dtypes = [tf.int32] * 3
    senders, receivers, n_edge = tf.py_function(
        _connectivity, [positions, n_node, radius], output_dtypes)
    # py_function drops static shapes, so restore what is known.
    for endpoint_ids in (senders, receivers):
        endpoint_ids.set_shape([None])
    n_edge.set_shape(n_node.get_shape())
    return senders, receivers, n_edge
Example #12
0
def tf_plot_1d_signal(name, signals, labels, max_outputs=3, step=None):
    """Visualizes a list of 1d signals.

    Args:
      name: name of the summary.
      signals: a [batch, lines, steps] tensor, each line a 1d signal.
      labels: a [lines] list of labels for each signal.
      max_outputs: the maximum number of plots to add to summaries.
      step: an explicit step or None.

    Returns:
      the summary result.
    """
    # Never ask for more plots than there are batch elements.
    batch_size = tf.shape(signals)[0]
    num_plots = tf.math.minimum(max_outputs, batch_size)
    image = tf.py_function(plot_1d_signals, (signals, labels, num_plots),
                           tf.uint8)
    return tfs.image(name, image, step, max_outputs=max_outputs)
Example #13
0
    def send_op(self, name, x):
        """Builds an op named 'send_<name>' that hands `x` to `self.send(name, ...)`.

        The op has no outputs (Tout is empty).
        """
        def _push(tensor):
            self.send(name, tensor)

        op_name = 'send_' + name
        return tf.py_function(func=_push, inp=[x], Tout=[], name=op_name)
Example #14
0
 def _mapper(image):
     """Applies `augment` to `image` in a tf.py_function, keeping dtype and static shape."""
     dtype, static_shape = image.dtype, image.shape
     augmented = tf.py_function(augment, inp=[image], Tout=dtype)
     # py_function drops static shape information; copy it from the input.
     augmented.set_shape(static_shape)
     return augmented
Example #15
0
    def receive_op(self, name, dtype):
        """Builds a placeholder receive op that must never be executed.

        The returned tensor comes from a `tf.py_function` whose function
        raises RuntimeError when the op runs. It exists only so graph
        construction can proceed with a tensor of `dtype`.
        """
        def func():
            # Error message typo ("Unexcepted") fixed.
            raise RuntimeError("Unexpected call receive op")

        return tf.py_function(func=func, inp=[], Tout=[dtype])[0]
Example #16
0
    def send_op(self, name, x):
        """Builds a placeholder send op named 'send_<name>' that must never run.

        The wrapped function raises RuntimeError when the op executes. The op
        has no outputs (Tout is empty).
        """
        def func(x):
            # Error message typo ("Unexcepted") fixed.
            raise RuntimeError("Unexpected call send op")

        out = tf.py_function(func=func, inp=[x], Tout=[], name='send_' + name)
        return out