示例#1
0
def test_align_func():
    """Check that ``utils.align_func`` converts a function's output dtype."""

    def list_source():
        return [0, 1, 2, 3]

    expected = [0, 1, 2, 3]

    # a TensorFlow dtype is requested; the result is checked against the
    # matching NumPy dtype
    as_float32 = utils.align_func(tf.float32)(list_source)
    result = as_float32()
    assert result.shape == (4,)
    assert result.dtype == np.float32
    assert np.allclose(result, expected)

    # a NumPy dtype is requested; the nested expected value relies on
    # broadcasting inside np.allclose
    as_int64 = utils.align_func(np.int64)(list_source)
    result = as_int64()
    assert result.dtype == np.int64
    assert np.allclose(result, [expected])
示例#2
0
def test_align_func():
    """Check that ``utils.align_func`` reshapes and casts a function's output."""

    def flat_source():
        return [0, 1, 2, 3]

    # (requested shape, requested dtype, expected NumPy dtype, expected values)
    cases = (
        ((4,), tf.float32, np.float32, [0, 1, 2, 3]),
        ((2, 2), np.int64, np.int64, [[0, 1], [2, 3]]),
    )

    for shape, requested_dtype, numpy_dtype, values in cases:
        aligned = utils.align_func(shape, requested_dtype)(flat_source)
        result = aligned()
        assert result.shape == shape
        assert result.dtype == numpy_dtype
        assert np.allclose(result, values)
示例#3
0
        def merged_func(time, inputs):  # pragma: no cover
            """
            Apply the function of every op in ``ops`` to its slice of
            ``inputs``, one minibatch item at a time, and concatenate the
            results along the first axis.
            """
            results = []
            start = 0
            for op in ops:
                has_output = op.output is not None

                # ops with an output get their function wrapped by align_func
                # using the output shape and the builder's output dtype
                func = (
                    utils.align_func(op.output.shape, self.output_dtype)(op.fn)
                    if has_output
                    else op.fn
                )

                # each op consumes the next op.x.shape[0] rows of the input
                stop = start + op.x.shape[0]
                op_input = inputs[start:stop]
                start = stop

                per_item = []
                for j in range(signals.minibatch_size):
                    item = op_input[..., j]
                    # the function is always called, even when it has no
                    # output; time is only passed if the op has a time signal
                    value = func(item) if op.t is None else func(time, item)

                    # ops without an output return time as a noop (since
                    # something has to be returned)
                    per_item.append(value if has_output else time)
                results.append(np.stack(per_item, axis=-1))

            return np.concatenate(results, axis=0)
示例#4
0
        def merged_func(time, inputs):  # pragma: no cover
            """
            Evaluate each op's function on its rows of ``inputs`` for every
            minibatch item, then join all op results along axis 0.
            """
            op_results = []
            row = 0
            for op in ops:
                # only ops that produce an output are wrapped with align_func
                if op.output is not None:
                    aligner = utils.align_func(
                        op.output.shape, self.output_dtype)
                    func = aligner(op.fn)
                else:
                    func = op.fn

                # take this op's rows and advance the running row offset
                n_rows = op.x.shape[0]
                rows = inputs[row:row + n_rows]
                row += n_rows

                batch_values = []
                for j in range(signals.minibatch_size):
                    # pass time first only when the op has a time signal
                    call_args = (rows[..., j],)
                    if op.t is not None:
                        call_args = (time,) + call_args
                    value = func(*call_args)

                    if op.output is None:
                        # the function's result is discarded and time is
                        # returned as a noop (something must be returned)
                        value = time
                    batch_values.append(value)
                op_results.append(np.stack(batch_values, axis=-1))

            return np.concatenate(op_results, axis=0)
示例#5
0
    def build_post(self):
        """
        Runs the post-build stage for operators.

        This happens once the graph has been constructed, and again every
        time the Simulator is reset.
        """

        process_rng = np.random.RandomState(self.seed)

        # input functions are (re)built here rather than during graph
        # construction, because the ones derived from processes have to be
        # recreated on every reset
        self.input_funcs = {}
        for node, node_output in self.invariant_inputs.items():
            if isinstance(node_output, np.ndarray):
                # array outputs are stored as-is
                self.input_funcs[node] = node_output
                continue

            if isinstance(node_output, Process):
                shape_in = (node.size_in,)
                shape_out = (node.size_out,)

                # one state object is created and handed to the step function
                # of every minibatch item
                shared_state = node_output.make_state(
                    shape_in, shape_out, self.dt)

                step_funcs = []
                for _ in range(self.minibatch_size):
                    step_funcs.append(
                        node_output.make_step(
                            shape_in,
                            shape_out,
                            self.dt,
                            node_output.get_rng(process_rng),
                            shared_state,
                        )
                    )
                self.input_funcs[node] = step_funcs
                continue

            if node.size_out > 0:
                # wrap the output with align_func for the simulation dtype
                aligner = utils.align_func(self.dtype)
                self.input_funcs[node] = [aligner(node_output)]
            else:
                # a node with no inputs and no outputs, but it can still
                # have side effects
                self.input_funcs[node] = [node_output]

        # run build_post on the op builder
        self.op_builder.build_post(self.signals)
示例#6
0
    def build_post(self, sess, rng):
        """
        Runs the post-build stage for operators (once the graph has been
        constructed and the session/variables have been initialized).

        In contrast to the other build functions, this one is called on
        every simulator reset.

        Parameters
        ----------
        sess : ``tf.Session``
            The TensorFlow session for the simulator
        rng : :class:`~numpy:numpy.random.RandomState`
            Seeded random number generator
        """

        # input functions are (re)built here because the ones derived from
        # processes depend on the rng and have to be recreated on reset
        self.input_funcs = {}
        for node, node_output in self.invariant_inputs.items():
            if isinstance(node_output, np.ndarray):
                # array outputs are stored as-is
                self.input_funcs[node] = node_output
                continue

            if isinstance(node_output, Process):
                # one step function per minibatch item, each with its own rng
                step_funcs = []
                for _ in range(self.minibatch_size):
                    step_funcs.append(
                        node_output.make_step(
                            (node.size_in,),
                            (node.size_out,),
                            self.dt,
                            node_output.get_rng(rng),
                        )
                    )
                self.input_funcs[node] = step_funcs
                continue

            if node.size_out > 0:
                # wrap the output with align_func for the node's output
                # shape and the simulation dtype
                aligner = utils.align_func((node.size_out,), self.dtype)
                self.input_funcs[node] = [aligner(node_output)]
            else:
                # a node with no inputs and no outputs, but it can still
                # have side effects
                self.input_funcs[node] = [node_output]

        # run build_post on every op builder
        for op_group, builder in self.op_builds.items():
            builder.build_post(op_group, self.signals, sess, rng)
示例#7
0
    def build_post(self, sess, rng):
        """
        Runs the post-build stage for operators (once the graph has been
        constructed and the session/variables have been initialized).

        In contrast to the other build functions, this one is called on
        every simulator reset.

        Parameters
        ----------
        sess : ``tf.Session``
            The TensorFlow session for the simulator
        rng : `~numpy.random.RandomState`
            Seeded random number generator
        """

        # input functions are (re)built here because the ones derived from
        # processes depend on the rng and have to be recreated on reset
        self.input_funcs = {}
        for node, node_output in self.invariant_inputs.items():
            if isinstance(node_output, np.ndarray):
                # array outputs are stored as-is
                funcs = node_output
            elif isinstance(node_output, Process):
                # one step function per minibatch item, each with its own rng
                funcs = []
                for _ in range(self.minibatch_size):
                    funcs.append(
                        node_output.make_step(
                            (node.size_in,), (node.size_out,), self.dt,
                            node_output.get_rng(rng)))
            elif node.size_out > 0:
                # wrap the output with align_func for the node's output
                # shape and the simulation dtype
                aligner = utils.align_func((node.size_out,), self.dtype)
                funcs = [aligner(node_output)]
            else:
                # a node with no inputs and no outputs, but it can still
                # have side effects
                funcs = [node_output]

            self.input_funcs[node] = funcs

        # run post_build on the op builder
        self.op_builder.post_build(sess, rng)
示例#8
0
    def build_post(self, signals):
        """
        Builds the initial process states and the step functions for every
        op in ``self.ops``.

        Sets ``self.step_states`` (one array per state entry, combined across
        ops and tiled across the minibatch) and ``self.step_fs`` (one step
        function per op and minibatch item).

        Parameters
        ----------
        signals
            Object from which ``dt_val`` and ``minibatch_size`` are read.
        """
        # generate state for each op (ops without an input use an input shape
        # of ``(0,)``)
        step_states = [
            op.process.make_state(
                op.input.shape if op.input is not None else (0, ),
                op.output.shape,
                signals.dt_val,
            ) for op in self.ops
        ]

        # build all the states into combined array with shape
        # (n_states, n_ops, *state_d)
        # NOTE(review): the number of state entries is taken from the first
        # op, which assumes every op has the same set of states -- confirm
        combined_states = [[None for _ in self.ops]
                           for _ in range(len(self.ops[0].state))]
        for i, op in enumerate(self.ops):
            # note: we iterate over op.state so that the order is always based
            # on that dict's order (which is what we used to set up
            # self.state_data)
            for j, name in enumerate(op.state):
                combined_states[j][i] = step_states[i][name]

        # combine op states, giving shape
        # (n_states, n_ops * state_d[0], *state_d[1:])
        # (keeping track of the offset of where each op's state lies in the
        # combined array)
        offsets = [[s.shape[0] for s in state] for state in combined_states]
        # after the cumulative sum, offsets[k][i] is the end index (along the
        # first axis) of op i's slice within combined state k
        offsets = np.cumsum(offsets, axis=-1)
        self.step_states = [
            np.concatenate(state, axis=0) for state in combined_states
        ]

        # cast each combined state to the dtype of the matching entry in
        # self.state_data
        for i, state in enumerate(self.state_data):
            self.step_states[i] = self.step_states[i].astype(state.dtype)

        # duplicate state for each minibatch, giving shape
        # (n_states, minibatch_size, n_ops * state_d[0], *state_d[1:])
        assert all(s.minibatched for op in self.ops for s in op.state.values())
        for i, state in enumerate(self.step_states):
            # add a leading axis and repeat it minibatch_size times
            self.step_states[i] = np.tile(
                state[None,
                      ...], (signals.minibatch_size, ) + (1, ) * state.ndim)

        # build the step functions (indexed as step_fs[op index][minibatch
        # item])
        self.step_fs = [[None for _ in range(signals.minibatch_size)]
                        for _ in self.ops]
        for i, op in enumerate(self.ops):
            for j in range(signals.minibatch_size):
                # pass each make_step function a view into the combined state
                state = {}
                for k, name in enumerate(op.state):
                    # op i's slice starts where the previous op's slice ends
                    start = 0 if i == 0 else offsets[k][i - 1]
                    stop = offsets[k][i]

                    state[name] = self.step_states[k][j, start:stop]

                    # sanity check: the slice should match the state that was
                    # originally generated for this op
                    assert np.allclose(state[name], step_states[i][name])

                self.step_fs[i][j] = op.process.make_step(
                    op.input.shape if op.input is not None else (0, ),
                    op.output.shape,
                    signals.dt_val,
                    op.process.get_rng(self.config.rng),
                    state,
                )

                # wrap the step function with align_func for the dtype of
                # self.output_data (presumably so its output is cast to that
                # dtype -- confirm against utils.align_func)
                self.step_fs[i][j] = utils.align_func(self.output_data.dtype)(
                    self.step_fs[i][j])