示例#1
0
def placeholder(shape,
                dtype=np.float32,
                basename=None,
                item_condition=struct.VARIABLES):
    """Create ``tf.placeholder`` tensors mirroring the structure of ``shape``.

    Every leaf of ``shape`` becomes one placeholder. The leaf test is
    ``is_static_shape``; presumably a shape such as ``(None, 3)`` counts as a
    leaf rather than being descended into.

    Args:
        shape: A shape, or a struct whose leaves are shapes.
        dtype: A single dtype used for every placeholder. Alternatively, a
            struct parallel to ``shape`` that gives one dtype per leaf.
        basename: Forwarded to ``_tf_name`` together with the trace of each
            leaf to build the placeholder name.
        item_condition: Forwarded to ``struct.zip`` / ``struct.map`` to
            select which struct items are visited.

    Returns:
        The result of ``struct.map``: ``shape``'s structure with a
        placeholder at every leaf.
    """
    # Both mapping paths share the same traversal settings.
    map_kwargs = dict(leaf_condition=is_static_shape,
                      trace=True,
                      item_condition=item_condition)

    if not struct.isstruct(dtype):
        # One dtype for all leaves: trace.value is the leaf shape itself.
        def _make_single(trace):
            return tf.placeholder(dtype, trace.value,
                                  _tf_name(trace, basename))

        return struct.map(_make_single, shape, **map_kwargs)

    # Per-leaf dtypes: zip shape and dtype so each trace.value is a pair.
    def _make_from_pair(trace):
        leaf_shape, leaf_dtype = trace.value
        return tf.placeholder(leaf_dtype, leaf_shape,
                              _tf_name(trace, basename))

    paired = struct.zip([shape, dtype],
                        leaf_condition=is_static_shape,
                        item_condition=item_condition)
    return struct.map(_make_from_pair, paired, **map_kwargs)
示例#2
0
def argument_assembler(args, kwargs):
    """Split call arguments into a struct object plus a function that rebuilds them.

    ``build_keymap`` (defined elsewhere) returns ``structs`` and ``keymap``.
    ``structs`` is the list of struct-valued arguments. ``keymap`` maps a
    positional index (int) or keyword name to an ``(is_struct, value)`` pair.
    For struct entries, ``value`` is used as an index into the items passed
    to the returned assembler. For other entries, ``value`` is the literal
    argument.

    Returns:
        ``(obj, assemble_arguments)``. ``obj`` is the single struct
        argument, or ``struct.zip`` of all of them when there are several.
        ``assemble_arguments(items)`` returns ``(args_list, kwargs_dict)``
        with struct entries replaced by ``items[value]``.
    """
    structs, keymap = build_keymap(args, kwargs)
    assert len(structs) > 0
    obj = structs[0] if len(structs) == 1 else struct.zip(structs)

    def assemble_arguments(items):
        positional = []
        index = 0
        # Positional arguments are stored under consecutive int keys from 0.
        while index in keymap:
            is_struct, payload = keymap[index]
            positional.append(items[payload] if is_struct else payload)
            index += 1
        # Every entry is unpacked, but only non-int keys become keywords.
        keyword = {
            key: (items[payload] if is_struct else payload)
            for key, (is_struct, payload) in keymap.items()
            if not isinstance(key, int)
        }
        return positional, keyword

    return obj, assemble_arguments
示例#3
0
 def _get_batch(self, indices):
     """Load the data for ``indices`` and substitute it into ``self._fields``.

     Rows are fetched through ``self._cache`` (loading via ``self._load`` and
     caching the result). ``list_swap_axes`` then transposes them so that
     ``columns[i]`` belongs to the i-th stream. Leaves of ``self._fields``
     whose ``self.stream_mask`` entry is truthy are replaced by that stream's
     data. All other leaves are returned unchanged.
     """
     cached_rows = self._cache.get(indices, self._load, add_to_cache=True)
     columns = list_swap_axes(cached_rows)

     stream_to_column = {}
     for idx in range(len(self._streams)):
         stream_to_column[self.streams[idx]] = columns[idx]

     def _substitute(field, is_stream):
         # Stream leaves are keys into stream_to_column; other leaves pass through.
         if is_stream:
             return stream_to_column[field]
         return field

     fields_with_mask = struct.zip([self._fields, self.stream_mask])
     # NOTE(review): content_type=struct.INVALID presumably marks the mapped
     # struct's content as not matching the original type - confirm.
     return struct.map(_substitute, fields_with_mask,
                       content_type=struct.INVALID)
示例#4
0
def placeholder(shape, dtype=np.float32, basename='Placeholder'):
    """Create one ``tf.placeholder`` per leaf of ``shape``.

    Args:
        shape: A shape, or a struct of shapes. Leaves are detected with
            ``is_static_shape``.
        dtype: One dtype for every leaf, or a struct parallel to ``shape``
            giving a dtype per leaf.
        basename: Forwarded to ``_tf_name`` together with each leaf's trace
            to name the placeholder.

    Returns:
        ``shape``'s structure with a placeholder at every leaf.
    """
    per_leaf_dtype = struct.isstruct(dtype)

    def _make(trace):
        # With per-leaf dtypes the traversed struct is the zip of
        # (shape, dtype), so trace.value is a pair. Otherwise it is the shape.
        if per_leaf_dtype:
            leaf_shape, leaf_dtype = trace.value
        else:
            leaf_shape, leaf_dtype = trace.value, dtype
        return tf.placeholder(leaf_dtype, leaf_shape, _tf_name(trace, basename))

    if per_leaf_dtype:
        source = struct.zip([shape, dtype], leaf_condition=is_static_shape)
    else:
        source = shape
    return struct.map(_make, source, leaf_condition=is_static_shape, trace=True)
示例#5
0
def placeholder(shape, dtype=None, basename='Placeholder'):
    """Create one ``tf.placeholder`` per leaf of ``shape``.

    Args:
        shape: A shape, or a struct of shapes. Leaves are detected with
            ``is_static_shape``.
        dtype: ``None`` to use ``TF_BACKEND.precision_dtype`` for every leaf,
            a single dtype for every leaf, or a struct parallel to ``shape``
            giving a dtype per leaf.
        basename: Forwarded to ``_tf_name`` together with each leaf's trace
            to name the placeholder.

    Returns:
        ``shape``'s structure with a placeholder at every leaf.
    """
    if not struct.isstruct(dtype):
        def _make_single(trace):
            # The backend precision is read per leaf, only when a leaf is
            # actually visited.
            if dtype is None:
                resolved_dtype = TF_BACKEND.precision_dtype
            else:
                resolved_dtype = dtype
            return tf.placeholder(resolved_dtype, trace.value,
                                  _tf_name(trace, basename))

        return struct.map(_make_single, shape,
                          leaf_condition=is_static_shape, trace=True)

    # Per-leaf dtypes: each trace.value is a (shape, dtype) pair.
    def _make_from_pair(trace):
        leaf_shape, leaf_dtype = trace.value
        return tf.placeholder(leaf_dtype, leaf_shape, _tf_name(trace, basename))

    paired = struct.zip([shape, dtype], leaf_condition=is_static_shape)
    return struct.map(_make_from_pair, paired,
                      leaf_condition=is_static_shape, trace=True)
示例#6
0
    def run(self,
            fetches,
            feed_dict=None,
            summary_key=None,
            time=None,
            merged_summary=None,
            item_condition=struct.ALL_ITEMS):
        """Evaluate ``fetches`` in ``self._session`` and return their values.

        Tensor leaves of ``fetches`` are evaluated in one
        ``self._session.run`` call. Each is then replaced by its value inside
        the original structure. Unhashable leaves, and leaves that were not
        fetched, are returned unchanged.

        Args:
            fetches: A tensor, a struct containing tensors, an
                ``np.ndarray`` or ``None``. Arrays and ``None`` are returned
                as-is without running the session.
            feed_dict: Optional mapping whose keys and values are zipped
                leaf-wise. Only key leaves for which ``isplaceholder`` is
                true are fed.
            summary_key: Together with ``merged_summary``, enables summary
                writing. The summary is written to a ``tf.summary.FileWriter``
                at ``os.path.join(self.summary_directory, str(summary_key))``,
                cached in ``self.summary_writers``.
            time: Step value passed to ``add_summary``.
            merged_summary: Summary op evaluated alongside the fetches when
                ``summary_key`` is given.
            item_condition: Forwarded to ``struct.zip``, ``struct.flatten``
                and ``struct.map``.

        Returns:
            ``fetches`` with every evaluated tensor replaced by its result.
        """
        if isinstance(fetches, np.ndarray):
            return fetches
        if fetches is None:
            return None

        tensor_feed_dict = None
        if feed_dict is not None:
            tensor_feed_dict = {}

            def add_to_dict(key_tensor, value_tensor):
                # Only placeholder leaves can be fed; other leaves are skipped.
                if isplaceholder(key_tensor):
                    tensor_feed_dict[key_tensor] = value_tensor
                return None

            for key, value in feed_dict.items():
                pairs = struct.zip([key, value],
                                   item_condition=item_condition,
                                   zip_parents_if_incompatible=True)
                struct.map(add_to_dict,
                           pairs,
                           item_condition=item_condition,
                           content_type=struct.INVALID)

        tensor_fetches = struct.flatten(fetches, item_condition=item_condition)
        # Keep tensors. For a top-level tuple/list, also keep leaves that
        # `_identity_in` reports as members of `fetches`; otherwise also keep
        # `fetches` itself. (Presumably `_identity_in` is an identity-based
        # membership test - confirm.)
        if isinstance(fetches, (tuple, list)):

            def is_fetch(x):
                return istensor(x) or _identity_in(x, fetches)
        else:

            def is_fetch(x):
                return istensor(x) or x is fetches

        tensor_fetches = tuple(filter(is_fetch, tensor_fetches))

        # Handle tracing
        trace = _trace_stack.get_default(raise_error=False)
        if trace:
            options = trace.timeliner.options
            run_metadata = trace.timeliner.run_metadata
        else:
            options = None
            run_metadata = None

        # Summary. Decide once, so the summary op is prepended and its result
        # consumed under the same condition. Testing `summary_key` alone
        # would misread the first real fetch result as the summary buffer
        # when `merged_summary` is None.
        write_summary = summary_key is not None and merged_summary is not None
        if write_summary:
            tensor_fetches = (merged_summary, ) + tensor_fetches

        result_fetches = self._session.run(tensor_fetches, tensor_feed_dict,
                                           options, run_metadata)
        result_dict = {
            fetch: result
            for fetch, result in zip(tensor_fetches, result_fetches)
        }

        if write_summary:
            summary_buffer = result_fetches[0]
            if summary_key in self.summary_writers:
                summary_writer = self.summary_writers[summary_key]
            else:
                summary_writer = tf.summary.FileWriter(
                    os.path.join(self.summary_directory, str(summary_key)),
                    self.graph)
                self.summary_writers[summary_key] = summary_writer
            summary_writer.add_summary(summary_buffer, time)
            summary_writer.flush()

        if trace:
            trace.timeliner.add_run()

        def replace_tensor_with_value(fetch):
            try:
                if fetch in result_dict:
                    return result_dict[fetch]
                else:
                    return fetch
            except TypeError:  # not hashable
                return fetch

        result = struct.map(replace_tensor_with_value,
                            fetches,
                            item_condition=item_condition)
        return result
示例#7
0
 def test_zip(self):
     """Mapping over ``struct.zip`` of two grids receives both grids' data together."""
     # NOTE(review): struct.unsafe() presumably relaxes validation so that
     # plain strings are accepted as grid data - confirm.
     with struct.unsafe():
         first, second = CenteredGrid('a'), CenteredGrid('b')
         pairs = struct.zip([first, second])
         # The mapped function receives one argument per zipped struct.
         stacked = struct.map(lambda *values: values, pairs)
         numpy.testing.assert_equal(stacked.data, ('a', 'b'))