Example #1
    def test_nest_pack_as(self):
        self.assertEqual(self.n2, nest.pack_as(self.n2, nest.flatten(self.n2)))

        with self.assertRaisesRegex(ValueError, "didn't exhaust sequence"):
            nest.pack_as(self.n2, nest.flatten(self.n2) + [None])
        with self.assertRaisesRegex(ValueError, "Too few elements"):
            nest.pack_as(self.n2, nest.flatten(self.n2)[1:])
    def test_nest_flatten(self):
        t = torch.tensor(1)
        t2 = torch.tensor(2)
        n = (t, t2)
        d = {"hey": t}

        self.assertEqual(list(nest.flatten((t, t2))), [t, t2])
        self.assertEqual(list(nest.flatten(d)), [t])
        self.assertEqual(list(nest.flatten((d, t))), [t, t])
        self.assertEqual(list(nest.flatten((d, n, t))), [t, t, t2, t])

        self.assertEqual(list(nest.flatten(((t, t2), (t, t2)))),
                         [t, t2, t, t2])
        self.assertEqual(list(nest.flatten(self.n1)), ["Test", "More", 32, 4])
        self.assertEqual(list(nest.flatten(self.n2)),
                         ["Test", "More", 32, None, 43, 4])

        d2 = {"hey": t2, "there": d, "more": t2}
        # Nest uses a map to store dicts, so the keys are sorted. To the C++
        # code the dict therefore looks like this:
        # {"hey": t2, "more": t2, "there": d}.
        self.assertEqual(list(nest.flatten(d2)), [t2, t2, t])

        self.assertEqual(list(nest.flatten(None)), [None])
        self.assertEqual(list(nest.flatten(self.n1)), ["Test", "More", 32, 4])
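
A minimal, self-contained round trip in the same spirit as the tests above; the `import nest` path is an assumption about how the module is exposed, while `flatten` and `pack_as` are used exactly as in the tests.

import torch

import nest  # assumed import path for the nest module under test

# flatten() returns the leaves as a list (dict keys in sorted order);
# pack_as() rebuilds the original structure from such a flat list.
n = {"obs": (torch.zeros(2), torch.ones(3)), "done": torch.tensor(False)}
flat = nest.flatten(n)
same = nest.pack_as(n, flat)
assert all(a is b for a, b in zip(nest.flatten(same), flat))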
Example #3
    def append(self, actor_ids, timesteps):
        assert len(actor_ids) == len(actor_ids.unique()), \
            f"Duplicate actor ids: {list(sorted(actor_ids))}"
        for s in nest.flatten(timesteps):
            assert s.shape[0] == actor_ids.shape[0], \
                "Batch dimensions don't match"

        curr_indices = self._index[actor_ids]

        for s, v in zip(nest.flatten(self._state), nest.flatten(timesteps)):
            s[actor_ids, curr_indices] = v

        self._index[actor_ids] += 1

        return self._complete_unrolls(actor_ids)
Example #4
    def load_on_gpu():
        # TODO: Use CUDA streams?
        entries = 0
        next_batch = []

        while entries < flags.batch_size:
            # TODO: This isn't guaranteed to be exact if inference_batch_size
            # does not divide batch_size evenly.
            ids, *data = unroll_queue.get()
            next_batch.append(data)
            entries += ids.numel()

        # Batch.
        batch, initial_agent_state = nest.map_many(lambda d: torch.cat(d),
                                                   *next_batch)

        # Make time major (excluding agent states).
        # After this step, tensors in `batch` are of shape (T + 1, N, ...) with
        # `T` the unroll length and `N` the number of actors in the batch.
        for t in nest.flatten(batch):
            t.transpose_(0, 1)

        if not flags.learner_device.startswith("cuda"):
            return nest.map(lambda t: t.contiguous(),
                            (batch, initial_agent_state))
        return nest.map(
            lambda t: t.to(flags.learner_device,
                           memory_format=torch.contiguous_format),
            (batch, initial_agent_state),
        )
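
A small, hypothetical sketch of the `nest.map_many` call used above: the mapped function receives the corresponding leaves of all passed nests as one sequence, which `load_on_gpu` concatenates along the batch dimension. Shapes and names below are illustrative only.

import torch

import nest  # assumed import path

a = {"frame": torch.zeros(2, 84, 84), "reward": torch.zeros(2)}
b = {"frame": torch.ones(3, 84, 84), "reward": torch.ones(3)}

# As in load_on_gpu: the lambda gets one sequence of matching leaves per call.
batched = nest.map_many(lambda ts: torch.cat(ts), a, b)
# batched["frame"].shape == (5, 84, 84); batched["reward"].shape == (5,)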
Example #5
    def reset(self, actor_ids):
        j = self._num_overlapping_steps
        self._index.scatter_(0, actor_ids, j)

        for s in nest.flatten(self._state):
            # .zero_() doesn't work with tensor indexing?
            s[actor_ids, :j] = 0
Example #6
def gradient_checkpointing(state,
                           body_fn,
                           total_iterations,
                           block_size=16,
                           checkpoint_last_iter=True):
    """
    checkpoint_last_iter: Indicates whether to checkpoint the final state
        (useful if more operations are done afterwards).
    """
    if total_iterations == 0:
        return state
    if block_size == 0:
        # Skip gradient_checkpointing
        for _ in range(total_iterations):
            state = body_fn(state)
        return state
    structure = nest.map_structure(lambda x: None, state)
    state = nest.flatten(state)
    current_iteration = 0
    if total_iterations > block_size:
        for _ in range(int(total_iterations // block_size - 1)):
            state = GradientCheckpointBlock.apply(structure, block_size,
                                                  body_fn, *state)
            current_iteration += block_size

    if checkpoint_last_iter:
        state = GradientCheckpointBlock.apply(
            structure, total_iterations - current_iteration, body_fn, *state)
        current_iteration += total_iterations - current_iteration
        state = nest.pack_sequence_as(structure, state)
    else:
        state = nest.pack_sequence_as(structure, state)
        for _ in range(current_iteration, total_iterations):
            state = body_fn(state)
    return state
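
A hypothetical usage sketch for `gradient_checkpointing`. It assumes the `nest` functions and the `GradientCheckpointBlock` autograd function used above (shown in Examples #9 and #11) are in scope; the state structure, shapes and `body_fn` are illustrative only.

import torch

w = torch.randn(32, 32, requires_grad=True)
state = (torch.randn(4, 32, requires_grad=True),
         torch.randn(4, 32, requires_grad=True))

def body_fn(state):
    # One loop iteration; must take and return the same nested structure.
    h, c = state
    c = c + torch.tanh(h @ w)
    return torch.tanh(c), c

# Only one block of 16 iterations is held in memory at a time; intermediate
# activations inside a block are recomputed during the backward pass.
h, c = gradient_checkpointing(state, body_fn, total_iterations=64, block_size=16)
(h.sum() + c.sum()).backward()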
Example #7
    def add_loss_op(self, logits, labels):
        '''

        :param logits: shape(b_sz, c_num) type(float)
        :param labels: shape(b_sz,) type(int)
        :return:
        '''

        self.prediction = tf.argmax(logits, axis=-1, output_type=labels.dtype)

        self.accuracy = tf.reduce_mean(
            tf.cast(tf.equal(self.prediction, labels), tf.float32))

        loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits,
                                                              labels=labels)
        ce_loss = tf.reduce_mean(entry_stop_gradients(loss, self.stop))
        # ce_loss = tf.reduce_mean(loss)

        exclude_vars = nest.flatten(
            [[v for v in tf.trainable_variables(o.name)]
             for o in self.EX_REG_SCOPE])
        exclude_vars_2 = [
            v for v in tf.trainable_variables() if '/bias:' in v.name
        ]
        exclude_vars = exclude_vars + exclude_vars_2

        reg_var_list = [
            v for v in tf.trainable_variables() if v not in exclude_vars
        ]
        reg_loss = tf.add_n([tf.nn.l2_loss(v) for v in reg_var_list])
        self.param_cnt = np.sum(
            [np.prod(v.get_shape().as_list()) for v in reg_var_list])

        print('===' * 20)
        print('total reg parameter count: %.3f M' %
              (self.param_cnt / 1000000.))
        print('excluded variables from regularization')
        print([v.name for v in exclude_vars])
        print('===' * 20)

        print('regularized variables')
        print([
            '%s:%.3fM' % (v.name, np.prod(v.get_shape().as_list()) / 1000000.)
            for v in reg_var_list
        ])
        print('===' * 20)
        # loss has shape (b_sz,).
        self.ce_loss = ce_loss
        self.w_loss = tf.reduce_mean(tf.multiply(loss, self.ph_sample_weights))
        reg = self.config.reg

        return self.ce_loss + reg * reg_loss
Example #8
    def _complete_unrolls(self, actor_ids):
        """Obtain unrolls that have reached the desired length"""
        actor_indices = self._index[actor_ids]

        actor_ids = actor_ids[actor_indices == self._full_length]
        unrolls = nest.map(lambda s: s[actor_ids], self._state)

        # Reset state of completed actors to start from the end of the previous
        # ones (NB: since `unrolls` is a copy it is ok to do it in place).
        j = self._num_overlapping_steps + 1
        for s in nest.flatten(self._state):
            s[actor_ids, :j] = s[actor_ids, -j:]

        self._index.scatter_(0, actor_ids, 1 + self._num_overlapping_steps)

        return actor_ids, unrolls
Example #9
    def forward(ctx, structure, block_size, body_fn, *state):
        ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(*state)
        with torch.enable_grad():
            ctx.devices = [s.device for s in state]
            cpu_state = nest.map_structure(
                lambda x: x.to('cpu', non_blocking=True), state)
        ctx.save_for_backward(*cpu_state)
        ctx.structure = structure
        ctx.block_size = block_size
        ctx.body_fn = body_fn
        state = nest.pack_sequence_as(ctx.structure, state)
        ctx.fwd_cpu_state = torch.get_rng_state()
        with torch.no_grad():
            for _ in range(block_size):
                state = body_fn(state)
        state = nest.flatten(state)
        return tuple(state)
Example #10
    def test_nest_map(self):
        t1 = torch.tensor(0)
        t2 = torch.tensor(1)
        d = {"hey": t2}

        n = nest.map(lambda t: t + 42, (t1, t2))

        self.assertSequenceEqual(n, [t1 + 42, t2 + 42])
        self.assertSequenceEqual(n, nest.flatten(n))

        n1 = (d, n, t1)
        n2 = nest.map(lambda t: t * 2, n1)

        self.assertEqual(n2[0], {"hey": torch.tensor(2)})
        self.assertEqual(n2[1], (torch.tensor(84), torch.tensor(86)))
        self.assertEqual(n2[2], torch.tensor(0))

        t = torch.tensor(42)

        # Doesn't work with pybind11/functional.h, but does with py::function.
        self.assertEqual(nest.map(t.add, t2), torch.tensor(43))
Example #11
    def backward(ctx, *grad_output):
        with torch.enable_grad():
            detached_inputs = [
                detach_variable(v.to(device, non_blocking=True))
                for v, device in zip(ctx.saved_tensors, ctx.devices)
            ]
            state = nest.pack_sequence_as(ctx.structure, detached_inputs)
            next_state = state
            rng_devices = ctx.fwd_gpu_devices
            with torch.random.fork_rng(devices=rng_devices, enabled=True):
                torch.set_rng_state(ctx.fwd_cpu_state)
                set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states)
                for _ in range(ctx.block_size):
                    next_state = ctx.body_fn(next_state)
        next_state = nest.flatten(next_state)

        next_state, grad_output = zip(*[
            sg for sg in zip(next_state, grad_output) if sg[0].requires_grad
        ])
        torch.autograd.backward(next_state, grad_output)

        return (None, None, None) + tuple(
            inp.grad if isinstance(inp, torch.Tensor) else None
            for inp in detached_inputs)
Example #12
    def clear(self, ids):
        for s in nest.flatten(self._state):
            # .zero_() doesn't work with tensor indexing?
            s[ids] = 0
Example #13
    def add(self, ids, values):
        for s, v in zip(nest.flatten(self._state), nest.flatten(values)):
            s[ids] += v
Example #14
        def body(args):
            it, params, optimizer_state = args
            if training_schedule_backwards:
                x, y_one_hot = automl.generator(iterations - it - 1,
                                                *generator_args)
            else:
                x, y_one_hot = automl.generator(it, *generator_args)
            with torch.enable_grad():
                if use_intermediate_losses > 0 and (
                        it >= use_intermediate_losses
                        and it % use_intermediate_losses == 0):
                    params = SurrogateLoss.apply(intermediate_loss, it,
                                                 *nest.flatten(params))
                    params = nest.pack_sequence_as(initial_params, params[1:])
                params, buffers = params
                for p in params:
                    if not p.requires_grad:
                        p.requires_grad = True

                learner.model.set_parameters(
                    list(zip(names, split_params(params))))
                if buffer_names:
                    learner.model.set_buffers(
                        list(zip(buffer_names, split_buffer(buffers))))
                learner.model.train()
                output = learner.model(x)
                if isinstance(output, tuple):
                    output1, output2 = output
                    loss = -(output1 * y_one_hot).sum() * (1 /
                                                           output1.shape[0])
                    loss = loss - (output2 *
                                   y_one_hot).sum() * (1 / output2.shape[0])
                    pred = output1
                else:
                    loss = -(output * y_one_hot).sum() * (1 / output.shape[0])
                    pred = output
                if it.item() not in losses:
                    losses[it.item()] = loss.detach().cpu().item()
                    accuracies[it.item()] = (
                        pred.max(-1).indices == y_one_hot.max(-1).indices).to(
                            torch.float).mean().item()

                grads = grad(loss,
                             params,
                             create_graph=x.requires_grad,
                             allow_unused=True)
            # assert len(grads) == len(names)
            new_params, optimizer_state = learner.optimizer(
                it, params, grads, optimizer_state)
            buffers = list(learner.model.buffers())
            buffers = [torch.cat([b.flatten()
                                  for b in buffers])] if buffers else buffers
            if callback is not None:
                learner.model.set_parameters(
                    list(zip(names, split_params(params))))
                if buffer_names:
                    learner.model.set_buffers(
                        list(zip(buffer_names, split_buffer(buffers))))
                callback(learner)

            return (it + 1, (
                new_params,
                buffers,
            ), optimizer_state)
Example #15
    def test_nest_flatten(self):
        self.assertEqual(nest.flatten(None), [None])
        self.assertEqual(nest.flatten(self.n1), ["Test", "More", 32, 4])
Example #16
    def test_nest_flatten_no_asserts(self):
        t = torch.tensor(1)
        t2 = torch.tensor(2)
        n = (t, t2)
        d = {"hey": t}

        nest.flatten((t, t2))
        nest.flatten(d)
        nest.flatten((d, t))
        nest.flatten((d, n, t))

        nest.flatten(((t, t2), (t, t2)))

        nest.flatten(self.n1)
        nest.flatten(self.n2)

        d2 = {"hey": t2, "there": d, "more": t2}
        nest.flatten(d2)  # Careful here, order not necessarily as above.