예제 #1
0
 def staggered_shape(self, batch_size=1, name=None, extrapolation=None, age=0.0):
     """
     Build a StaggeredGrid template whose per-axis components hold shapes instead of data.

     For each axis in ``range(self.rank)`` one unnamed CenteredGrid is created. Its data is
     ``_extend1(tensor_shape(batch_size, self.resolution, 1), axis)``, presumably the
     resolution grown by one along that axis (TODO confirm in ``_extend1``). Its box comes
     from ``staggered_component_box``. ``age``, ``extrapolation`` and ``batch_size`` are
     passed to every component and to the result; ``name`` only to the result.

     Construction happens inside ``struct.unsafe()``, presumably so that shape tuples are
     accepted as grid data.
     """
     with struct.unsafe():
         components = [
             CenteredGrid(_extend1(tensor_shape(batch_size, self.resolution, 1), dim),
                          staggered_component_box(self.resolution, dim, self.box),
                          age=age,
                          extrapolation=extrapolation,
                          name=None,
                          batch_size=batch_size)
             for dim in range(self.rank)
         ]
         return StaggeredGrid(components,
                              age=age,
                              box=self.box,
                              name=name,
                              batch_size=batch_size,
                              extrapolation=extrapolation)
예제 #2
0
def broadcast_function(backend, func, args, kwargs):
    """
    Call the backend method named ``func`` element-wise over a struct of arguments.

    ``argument_assembler(args, kwargs)`` yields the struct to map over and a builder that
    turns the mapped values into ``(args, kwargs)`` for one call. Presumably the struct holds
    the struct-valued arguments; confirm in ``argument_assembler``. The backend method is
    invoked once per mapped element inside ``struct.unsafe()``, and the mapped struct of
    results is returned.
    """
    target = getattr(backend, func)
    template, assemble = argument_assembler(args, kwargs)

    def apply_backend(*leaf_values):
        call_args, call_kwargs = assemble(leaf_values)
        return target(*call_args, **call_kwargs)

    with struct.unsafe():
        return struct.map(apply_backend, template)
예제 #3
0
def _transform_for_writing(obj):
    """
    Replace grids inside ``obj`` with their raw data so the struct can be written out.

    StaggeredGrid values become ``value.staggered_tensor()``, CenteredGrid values become
    ``value.data``, and everything else is returned unchanged. The third argument to
    ``struct.map`` (presumably the leaf condition) marks both grid types as leaves, so they
    are converted as a whole rather than recursed into.
    """
    grid_types = (field.StaggeredGrid, field.CenteredGrid)

    def to_writable(value):
        # StaggeredGrid is checked first, mirroring the original precedence.
        if isinstance(value, field.StaggeredGrid):
            return value.staggered_tensor()
        return value.data if isinstance(value, field.CenteredGrid) else value

    with struct.unsafe():
        return struct.map(to_writable, obj, lambda x: isinstance(x, grid_types))
예제 #4
0
 def test_identity(self):
     """Mapping the identity over each test struct must reproduce it for all three map options."""
     def identity(value):
         return value

     for structure in generate_test_structs():
         with struct.unsafe():
             for map_options in ({'recursive': False},
                                 {'recursive': True},
                                 {'item_condition': struct.ALL_ITEMS}):
                 mapped = struct.map(identity, structure, **map_options)
                 self.assertEqual(structure, mapped)
예제 #5
0
 def _add_default_fields(self):
     """
     Register one field per CenteredGrid/StaggeredGrid found in ``self.world.state``.

     The field label is the grid's name with its first character upper-cased. Its generator
     re-reads ``self.world.state`` on every call and locates the grid again via the struct
     trace (``trace.find_in``), so it follows later world states.
     """
     grid_types = (CenteredGrid, StaggeredGrid)

     def register(trace):
         grid = trace.value
         if not isinstance(grid, grid_types):
             return None

         def field_generator():
             return trace.find_in(self.world.state)

         label = grid.name[0].upper() + grid.name[1:]
         self.add_field(label, field_generator)
         return None

     with struct.unsafe():
         struct.map(register, self.world.state,
                    leaf_condition=lambda x: isinstance(x, grid_types),
                    trace=True)
예제 #6
0
 def shape(self):
     """
     Return a copy whose ``data`` and ``sample_points`` are replaced by shape tuples.

     ``data`` becomes ``(batch_size, point_count, component_count)``, or ``()`` when
     ``math.ndims(self.data)`` is not positive. ``sample_points`` becomes
     ``(batch_size, point_count, rank)``. The copy is made inside ``struct.unsafe()``,
     presumably so shape tuples are accepted as field values.
     """
     with struct.unsafe():
         if math.ndims(self.data) <= 0:
             data_shape = ()
         else:
             data_shape = (self._batch_size, self._point_count, self.component_count)
         return self.copied_with(
             data=data_shape,
             sample_points=(self._batch_size, self._point_count, self.rank),
         )
예제 #7
0
 def test_struct_placeholders(self):
     """placeholder/placeholder_like must mirror a nested struct (tuple + CenteredGrid)."""
     domain_box = box[0:1]  # deliberately created outside struct.unsafe()
     with struct.unsafe():
         template = ([4], CenteredGrid([1, 4, 1], domain_box), ([9], [8, 2]))
     tensorflow.reset_default_graph()
     holders = placeholder(template)
     self.assertEqual(holders[0].name, '0:0')
     self.assertEqual(holders[1].data.name, '1/data:0')
     self.assertIsInstance(holders, tuple)
     copies = placeholder_like(holders)
     self.assertIsInstance(copies, tuple)
     numpy.testing.assert_equal(copies[1].data.shape.as_list(), [1, 4, 1])
예제 #8
0
def load_state(state):
    """
    Create placeholders matching ``state`` and a feed mapping for loading values into them.

    A StateProxy is unwrapped to its ``.state`` first; the result must be a State.
    Grids are flattened via ``_transform_for_writing``, placeholders are created from the
    flattened state's ``shape`` inside ``struct.unsafe()``, and the placeholders are passed
    once more through an identity ``struct.map``.

    Returns:
        ``(state_in, {placeholders: names})``, where ``state_in`` is the identity-mapped
        placeholder struct and ``names`` is ``struct.names`` of the flattened state.
    """
    if isinstance(state, StateProxy):
        state = state.state
    assert isinstance(state, State)
    writable = _transform_for_writing(state)
    field_names = struct.names(writable)
    with struct.unsafe():
        holders = placeholder(writable.shape)
    # Identity map outside unsafe(): validates fields, splits staggered tensors.
    validated = struct.map(lambda x: x, holders)
    return validated, {holders: field_names}
예제 #9
0
    def test_copy(self):
        """copied_with must replace only the given field, and velocity='D2' must be rejected."""
        with struct.unsafe():
            fluid = Fluid(Domain([4]), density='Density', velocity='Velocity')
            v = fluid.copied_with(velocity='V2')
            self.assertEqual(v.velocity, 'V2')
            self.assertEqual(v.density, 'Density')

            # Previously `self.fail()` sat inside `try/except AssertionError`. fail() itself
            # raises AssertionError, so the failure was swallowed and this check could never fail.
            # NOTE(review): if copied_with(velocity='D2') does not actually raise, this now
            # surfaces as a real test failure that the old form hid.
            with self.assertRaises(AssertionError):
                fluid.copied_with(velocity='D2')
예제 #10
0
    def shape(self):
        """
        Similar to phi.math.shape(self) but respects unknown dimensions.

        Each variable item (``struct.VARIABLES``) maps to its static shape. When that shape has
        at least two entries, its first entry is replaced by ``self._batch_size``. Items that
        are ``None`` stay ``None``.
        """
        def batched_static_shape(tensor):
            if tensor is None:
                return None
            static = staticshape(tensor)
            if len(static) < 2:
                return static
            return (self._batch_size,) + static[1:]

        with struct.unsafe():
            return struct.map(batched_static_shape, self, item_condition=struct.VARIABLES)
예제 #11
0
파일: domain.py 프로젝트: syyunn/PhiFlow
 def centered_shape(self,
                    components=1,
                    batch_size=1,
                    name=None,
                    extrapolation=None,
                    age=0.0):
     """
     Deprecated: build a CenteredGrid whose data is a shape tuple instead of a tensor.

     The data is ``tensor_shape(batch_size, self.resolution, components)``, the box is
     ``self.box``, and flags are empty. Construction happens inside ``struct.unsafe()``,
     presumably so that the shape tuple is accepted as grid data.

     Emits a DeprecationWarning pointing at the caller; use ``CenteredGrid.sample()``.
     """
     # stacklevel=2 attributes the warning to the caller's line rather than this method.
     warnings.warn(
         "Domain.centered_shape and Domain.centered_grid are deprecated. Use CenteredGrid.sample() instead.",
         DeprecationWarning,
         stacklevel=2)
     with struct.unsafe():
         from phi.physics.field import CenteredGrid
         return CenteredGrid(tensor_shape(batch_size, self.resolution,
                                          components),
                             age=age,
                             box=self.box,
                             extrapolation=extrapolation,
                             name=name,
                             batch_size=batch_size,
                             flags=())
예제 #12
0
파일: session.py 프로젝트: xyuan/PhiFlow
    def run(self,
            fetches,
            feed_dict=None,
            summary_key=None,
            time=None,
            merged_summary=None):
        """
        Evaluate ``fetches`` in the wrapped session and return them with tensors replaced by values.

        Args:
            fetches: a numpy array (returned unchanged), ``None`` (returns ``None``), or a
                struct / sequence whose tensor items (and, for tuples/lists, direct
                elements) are evaluated.
            feed_dict: optional mapping. Each key/value pair is zipped item-wise with
                ``struct.zip``, and every zipped pair whose key is a placeholder is fed.
            summary_key: together with ``merged_summary``, selects (or creates under
                ``self.summary_directory``) the FileWriter the evaluated summary is written to.
            time: step value passed to ``add_summary``.
            merged_summary: summary op evaluated alongside ``fetches``.

        Returns:
            ``fetches`` with every evaluated item replaced by its result; all other items
            (including unhashable ones) are returned unchanged.
        """
        if isinstance(fetches, np.ndarray):
            return fetches
        if fetches is None:
            return None

        tensor_feed_dict = None
        if feed_dict is not None:
            tensor_feed_dict = {}

            def add_to_dict(key_tensor, value_tensor):
                # Only placeholders can be fed; other zipped pairs are ignored.
                if isplaceholder(key_tensor):
                    tensor_feed_dict[key_tensor] = value_tensor
                return None

            for key, value in feed_dict.items():
                pairs = struct.zip([key, value],
                                   item_condition=struct.ALL_ITEMS,
                                   zip_parents_if_incompatible=True)
                with struct.unsafe():
                    struct.map(add_to_dict,
                               pairs,
                               item_condition=struct.ALL_ITEMS)

        tensor_fetches = struct.flatten(fetches,
                                        item_condition=struct.ALL_ITEMS)
        if isinstance(fetches, (tuple, list)):
            def is_fetch(x):
                return istensor(x) or x in fetches
        else:
            def is_fetch(x):
                return istensor(x) or x is fetches
        tensor_fetches = tuple(filter(is_fetch, tensor_fetches))

        # Handle tracing
        trace = _trace_stack.get_default(raise_error=False)
        if trace:
            options = trace.timeliner.options
            run_metadata = trace.timeliner.run_metadata
        else:
            options = None
            run_metadata = None

        # Summary: fetch and write it only when both the key and the op are given. One flag
        # keeps the prepend and the write in sync. Previously the write checked only
        # `if summary_key:`, which could pass a regular fetch result to add_summary.
        write_summary = summary_key is not None and merged_summary is not None
        if write_summary:
            tensor_fetches = (merged_summary, ) + tensor_fetches

        result_fetches = self._session.run(tensor_fetches, tensor_feed_dict,
                                           options, run_metadata)
        result_dict = {
            fetch: result
            for fetch, result in zip(tensor_fetches, result_fetches)
        }

        if write_summary:
            summary_buffer = result_fetches[0]
            if summary_key in self.summary_writers:
                summary_writer = self.summary_writers[summary_key]
            else:
                summary_writer = tf.summary.FileWriter(
                    os.path.join(self.summary_directory, str(summary_key)),
                    self.graph)
                self.summary_writers[summary_key] = summary_writer
            summary_writer.add_summary(summary_buffer, time)
            summary_writer.flush()

        if trace:
            trace.timeliner.add_run()

        def replace_tensor_with_value(fetch):
            try:
                if fetch in result_dict:
                    return result_dict[fetch]
                else:
                    return fetch
            except TypeError:  # not hashable
                return fetch

        result = struct.map(replace_tensor_with_value,
                            fetches,
                            item_condition=struct.ALL_ITEMS)
        return result
예제 #13
0
 def _get_batch(self, indices):
     """
     Assemble the batch for ``indices`` as a struct shaped like ``self._fields``.

     Data comes from ``self._cache.get(indices, self._load, add_to_cache=True)``, presumably
     loading via ``self._load`` on a miss. ``list_swap_axes`` then presumably regroups it per
     stream; verify against that helper. Each stream in ``self._fields`` is replaced by its
     data.
     """
     loaded = self._cache.get(indices, self._load, add_to_cache=True)
     per_stream = list_swap_axes(loaded)
     # NOTE(review): keys come from self.streams but the count from self._streams;
     # presumably the same sequence -- confirm.
     lookup = {}
     for position in range(len(self._streams)):
         stream_key = self.streams[position]
         lookup[stream_key] = per_stream[position]
     with struct.unsafe():
         return struct.map(lambda stream: lookup[stream], self._fields)
예제 #14
0
 def centered_shape(self, components=1, batch_size=1, name=None, extrapolation=None, age=0.0):
     """
     Build a CenteredGrid whose data is ``tensor_shape(batch_size, self.resolution, components)``.

     The grid uses ``self.box``, and construction happens inside ``struct.unsafe()``,
     presumably so that a shape tuple is accepted as grid data.
     """
     with struct.unsafe():
         data_shape = tensor_shape(batch_size, self.resolution, components)
         return CenteredGrid(data_shape,
                             age=age,
                             box=self.box,
                             extrapolation=extrapolation,
                             name=name,
                             batch_size=batch_size)
예제 #15
0
 def test_zip(self):
     """Zipping two grids and mapping a tuple-builder must pair their data item-wise."""
     with struct.unsafe():
         grid_a, grid_b = CenteredGrid('a'), CenteredGrid('b')
         zipped = struct.zip([grid_a, grid_b])
         stacked = struct.map(lambda *items: items, zipped)
         numpy.testing.assert_equal(stacked.data, ('a', 'b'))