Example 1
0
    def step(self, batch, initial_states):
        """Runs one debug step: computes the policy loss and sampled controllers.

        Returns a (stats, DebugResult) pair; `stats` holds the mean loss and
        per-head distances.
        """
        bm_gamestate, restarting = batch

        # Wherever an episode restarted, replace the carried-over state with a
        # freshly-initialized one; elsewhere keep the incoming state.
        restart_mask = tf.expand_dims(restarting, -1)
        fresh_states = self.policy.initial_state(restart_mask.shape[0])
        initial_states = tf.nest.map_structure(
            lambda fresh, carried: tf.where(restart_mask, fresh, carried),
            fresh_states, initial_states)

        # Convert every tensor from batch-major to time-major layout.
        tm_gamestate = tf.nest.map_structure(to_time_major, bm_gamestate)
        gamestate, action_repeat = tm_gamestate

        p1_controller = get_p1_controller(gamestate, action_repeat)
        # The prediction target is the controller input shifted forward one frame.
        next_action = tf.nest.map_structure(lambda t: t[1:], p1_controller)

        loss, final_states, distances = self.policy.loss(
            tm_gamestate, initial_states)
        stats = dict(loss=tf.reduce_mean(loss), distances=distances)

        controller_samples, _ = utils.dynamic_rnn(
            self.sample, tm_gamestate, initial_states)

        debug = DebugResult(loss, next_action, controller_samples,
                            distances, tm_gamestate, final_states)
        return stats, debug
Example 2
0
    def forward(self, inputs, length_minor, length_major):
        """Hierarchically encodes a nested sequence with two RNN passes.

        The first two axes of `inputs` are flattened into a single batch
        axis for the minor encoder; the resulting codes are reshaped back
        and (optionally) encoded again along the second axis.

        Args:
            inputs: tensor whose first two axes are flattened before the
                minor encoder runs (presumably (batch, major, ...) —
                confirm against callers).
            length_minor: lengths for the minor sequences; reshaped to (-1,).
            length_major: lengths for the major sequences.

        Returns:
            (outputs, codes): the step outputs of the last encoder that ran
            and its final codes. When `self.major_encoder` is None, the
            codes are the minor codes with the singleton second axis
            squeezed out.
        """
        inputs_to_minor = inputs.view((-1, ) + inputs.shape[2:])

        # NOTE: the original bound this to `_` yet returned it; it is a real
        # result, so give it a meaningful name.
        minor_outputs, minor_codes = dynamic_rnn(self.minor_encoder,
                                                 inputs_to_minor,
                                                 length_minor.view(-1),
                                                 batch_first=True)

        # Restore the two leading axes on the minor codes.
        inputs_to_major = minor_codes.view(inputs.shape[:2] +
                                           minor_codes.shape[1:])

        if self.major_encoder is None:
            return minor_outputs, inputs_to_major.squeeze(1)

        major_outputs, major_codes = dynamic_rnn(self.major_encoder,
                                                 inputs_to_major,
                                                 length_major,
                                                 batch_first=True)

        return major_outputs, major_codes
Example 3
0
    def code_candidates(self, candidates):
        """Encodes each candidate text and projects the codes with `transform_mat`."""
        text_ids = candidates['text_ids']
        length = candidates['length']

        # Remember the leading two axes before flattening them away.
        leading_shape = text_ids.shape[:2]
        embedded = self.embedding(text_ids)

        flat_inputs = embedded.view(-1, *embedded.shape[2:])
        flat_lengths = length.view(-1, *length.shape[2:])
        _, codes = dynamic_rnn(self.candidates_encoder, flat_inputs,
                               flat_lengths)

        # Restore the leading axes, then apply the learned linear transform.
        codes = codes.view(*leading_shape, *codes.shape[1:])
        return torch.matmul(codes, self.transform_mat)
Example 4
0
    def test_dynamic_rnn(self):
        """Checks dynamic_rnn against static_rnn on a nested-structure core."""
        def nested_core(input_, state):
            # Adds the scalar state to every leaf; the state passes through.
            outputs = tf.nest.map_structure(lambda t: t + state, input_)
            return outputs, state

        unroll_length = 8
        batch_size = 4
        initial_state = tf.constant(1.0)

        inputs = {
            'a': tf.random.uniform([unroll_length, batch_size]),
            'b': tf.random.uniform([unroll_length, batch_size]),
        }

        static_outputs, _ = static_rnn(nested_core, inputs, initial_state)
        dynamic_outputs, _ = utils.dynamic_rnn(nested_core, inputs,
                                               initial_state)

        # Both unroll strategies must agree elementwise on every leaf.
        tf.nest.map_structure(assert_tensors_close, static_outputs,
                              dynamic_outputs)
Example 5
0
 def unroll(self, inputs, initial_state):
     """Runs `self.step` across the full input sequence."""
     # Delegate the time unroll to the shared dynamic_rnn helper.
     core = self.step
     return utils.dynamic_rnn(core, inputs, initial_state)
Example 6
0
 def unroll(self, inputs, prev_state):
     """Preprocesses the inputs and unrolls the GRU core over time."""
     processed = process_inputs(inputs)
     return utils.dynamic_rnn(self._gru, processed, prev_state)
Example 7
0
 def unroll(self, inputs, prev_state):
     """Encodes the processed inputs, then unrolls the deep RNN core."""
     encoded = self.encoder(process_inputs(inputs))
     return utils.dynamic_rnn(self.deep_rnn, encoded, prev_state)
Example 8
0
 def unroll(self, inputs, prev_state):
     """Applies the resnet trunk to the inputs, then unrolls the LSTM core."""
     features = self._resnet(process_inputs(inputs))
     return utils.dynamic_rnn(self._lstm, features, prev_state)