예제 #1
0
  def test_autoregressive_sample_reformer2_lsh_attn_quality(self):
    """Decodes a trained Reformer2 copy-task checkpoint with LSH attention.

    For several input lengths, greedy (temperature 0) sampling from the
    checkpoint should reproduce the random input sequence exactly.
    """
    gin.add_config_file_search_path(_CONFIG_DIR)
    max_len = 32  # 32 is the max length we trained the checkpoint for.
    test_lengths = [8, 16, 32]
    vocab_size = 13
    # The checkpoint is correct on ~90% sequences, set random seed to deflake.
    np.random.seed(0)
    for test_len in test_lengths:
      # Reset gin and re-parse the config so each length starts clean.
      gin.clear_config()
      gin.parse_config_file('reformer2_copy.gin')
      # Size the fast-decoding attention buffers to cover input + samples.
      gin.bind_parameter('LSHSelfAttention.predict_mem_len', 2 * max_len)
      gin.bind_parameter('LSHSelfAttention.predict_drop_len', 2 * max_len)

      pred_model = models.Reformer2(mode='predict')

      # Signatures for (padded encoder input, one decoder token).
      shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32)
      shape1l = shapes.ShapeDtype((1, max_len), dtype=np.int32)

      model_path = os.path.join(_TESTDATA, 'reformer2_copy_lsh_attn.pkl.gz')
      pred_model.init_from_file(model_path, weights_only=True,
                                input_signature=(shape1l, shape11))
      # Snapshot the fresh decoding state so it can be restored per sample.
      initial_state = pred_model.state

      for _ in range(2):  # Set low to make the test run reasonably fast.
        # Pick a length in [1, test_len] at random.
        inp_len = np.random.randint(low=1, high=test_len + 1)
        # Tokens drawn from [1, vocab_size - 2]; 0 is reserved for padding.
        inputs = np.random.randint(low=1, high=vocab_size-1, size=(1, inp_len))
        inputs = np.pad(inputs, [(0, 0), (0, max_len - inp_len)],
                        mode='constant', constant_values=0)
        # eos_id=-1 never matches a token, so exactly inp_len are sampled.
        s = decoding.autoregressive_sample(
            pred_model, inputs=inputs, eos_id=-1, max_length=inp_len,
            temperature=0.0)
        # The model was trained to copy: output must equal unpadded input.
        np.testing.assert_equal(s[0], inputs[0, :inp_len])
        # Restore the pristine decoding cache before the next sample.
        pred_model.state = initial_state
    gin.clear_config()  # Make sure to not affect other tests.
예제 #2
0
 def test_autoregressive_sample_transformerlm(self):
     """Checks shapes and prefix handling when sampling from TransformerLM."""
     model = models.TransformerLM(
         10, d_model=32, d_ff=64, n_layers=1, n_heads=2, mode='predict')
     model.init(shapes.ShapeDtype((1, 1), dtype=jnp.int32))
     # eos_id=-1 never matches a sampled token, so we get max_length tokens.
     sample = trainer_lib.autoregressive_sample(
         model, batch_size=1, eos_id=-1, max_length=10)
     self.assertEqual(sample.shape[0], 1)
     self.assertEqual(sample.shape[1], 10)
     # Re-initialize for batch decoding, sized per available device.
     batch_per_device = 2 // fastmath.device_count()
     model.init(shapes.ShapeDtype((batch_per_device, 1), dtype=jnp.int32))
     batch_sample = trainer_lib.autoregressive_sample(
         model, batch_size=2, max_length=10)
     self.assertEqual(batch_sample.shape[0], 2)
     # Default eos may terminate decoding early, so at most 10 tokens.
     self.assertLess(batch_sample.shape[1], 11)
     # Sampling with a prefix: the prefix must be echoed at the start.
     model.init(shapes.ShapeDtype((1, 1), dtype=jnp.int32))
     prefix = jnp.array([[1, 2, 3]])
     prefixed_sample = trainer_lib.autoregressive_sample(
         model, eos_id=-1, max_length=10, batch_size=1, prefix=prefix)
     self.assertEqual(prefixed_sample.shape[0], 1)
     for position, token in enumerate([1, 2, 3]):
         self.assertEqual(int(prefixed_sample[0][position]), token)
예제 #3
0
    def test_can_predict_with_trained_model(self):
        """Trains a two-headed model on two tasks, then runs the eval model."""
        model = tl.Serial(tl.Dense(3), tl.Branch(tl.Dense(1), tl.Dense(2)))
        train_tasks, eval_tasks = [], []
        for output_dim in [1, 2]:
            # Head 0 produces 1-dim outputs, head 1 produces 2-dim outputs.
            head = output_dim - 1

            def make_loss(h=head):
                # Fresh loss layer per task: select one head, L2 against it.
                return tl.Serial(tl.Select([h], n_in=2), tl.L2Loss())

            train_tasks.append(
                training.TrainTask(_very_simple_data(output_dim), make_loss(),
                                   optimizers.SGD(.01)))
            eval_tasks.append(
                training.EvalTask(
                    # deliberately re-use training data
                    _very_simple_data(output_dim), [make_loss()]))
        output_dir = self.create_tempdir().full_path
        loop = training.Loop(
            model,
            tasks=train_tasks,
            eval_tasks=eval_tasks,
            checkpoint_at=lambda step_n: step_n == 1,
            output_dir=output_dir,
            # Alternate between the two tasks on successive steps.
            which_task=lambda step_n: step_n % 2,
        )
        loop.run(n_steps=2)

        trained_model = loop.eval_model
        inp = next(_very_simple_data())[0]
        out = trained_model(inp)
        self.assertEqual(
            shapes.signature(out),
            (shapes.ShapeDtype((8, 1)), shapes.ShapeDtype((8, 2))),
        )
예제 #4
0
 def test_autoregressive_sample_transformerlm_tfnp(self):
     """Runs TransformerLM sampling shape checks on the TF-numpy backend."""
     with fastmath.use_backend(fastmath.Backend.TFNP):
         model = models.TransformerLM(
             10, d_model=32, d_ff=64, n_layers=1, n_heads=2, mode='predict')
         model.init(shapes.ShapeDtype((1, 1), dtype=np.int32))
         # eos_id=-1 never matches a token, so we get max_length tokens.
         sample = decoding.autoregressive_sample(
             model, batch_size=1, eos_id=-1, max_length=10)
         self.assertEqual(sample.shape[0], 1)
         self.assertEqual(sample.shape[1], 10)
         # Re-initialize for batch decoding, sized per available device.
         batch_per_device = 2 // fastmath.device_count()
         model.init(shapes.ShapeDtype((batch_per_device, 1),
                                      dtype=np.int32))
         batch_sample = decoding.autoregressive_sample(
             model, batch_size=2, max_length=10)
         self.assertEqual(batch_sample.shape[0], 2)
         # Default eos may terminate decoding early, so at most 10 tokens.
         self.assertLess(batch_sample.shape[1], 11)
         model.init(shapes.ShapeDtype((1, 1), dtype=np.int32))
         prefix = np.array([[1, 2, 3]])
         # The prefix is passed positionally (the inputs argument).
         prefixed_sample = decoding.autoregressive_sample(
             model, prefix, eos_id=-1, max_length=10, batch_size=1)
         self.assertEqual(prefixed_sample.shape[0], 1)
         self.assertEqual(prefixed_sample.shape[1], 10)
예제 #5
0
    def test_can_predict_with_trained_model(self):
        """Trains a two-headed model on two tasks, then runs the eval model."""
        model = tl.Serial(tl.Dense(3), tl.Branch(tl.Dense(1), tl.Dense(2)))
        train_tasks = []
        eval_tasks = []
        for output_dim in (1, 2):
            train_tasks.append(
                training.TrainTask(
                    _very_simple_data(output_dim),
                    tl.L2Loss(),
                    optimizers.SGD(.01),
                ))
        for output_dim in (1, 2):
            # One single-element list of EvalTasks per training task.
            eval_tasks.append([
                training.EvalTask(
                    # deliberately re-using training data
                    _very_simple_data(output_dim),
                    [tl.L2Loss()],
                )
            ])
        output_dir = self.create_tempdir().full_path
        loop = training.Loop(
            model,
            tasks=tuple(train_tasks),
            eval_tasks=tuple(eval_tasks),
            checkpoint_at=lambda step_n: step_n == 1,
            output_dir=output_dir,
            # Alternate between the two tasks on successive steps.
            which_task=lambda step_n: step_n % 2,
        )
        loop.run(n_steps=2)

        trained_model = loop.eval_model
        inp = next(_very_simple_data())[0]
        out = trained_model(inp)
        self.assertEqual(
            shapes.signature(out),
            (shapes.ShapeDtype((8, 1)), shapes.ShapeDtype((8, 2))),
        )
예제 #6
0
    def test_extract_inner_model(self):
        """extract_inner_model recovers weights/state usable by the inner model."""
        vocab_size = 3

        inner_model = transformer.TransformerLM(
            vocab_size=vocab_size, d_model=2, d_ff=2, n_layers=0)
        obs_serializer = space_serializer.create(
            gym.spaces.Discrete(2), vocab_size=vocab_size)
        act_serializer = space_serializer.create(
            gym.spaces.Discrete(2), vocab_size=vocab_size)
        serialized_model = serialization_utils.SerializedModel(
            inner_model,
            observation_serializer=obs_serializer,
            action_serializer=act_serializer,
            significance_decay=0.9,
        )

        obs_sig = shapes.ShapeDtype((1, 2))
        act_sig = shapes.ShapeDtype((1, 1))
        weights, state = serialized_model.init(
            input_signature=(obs_sig, act_sig, obs_sig, obs_sig))
        inner_weights = serialization_utils.extract_inner_model(weights)
        inner_state = serialization_utils.extract_inner_model(state)
        # Smoke check: extracted weights/state must run in the inner model.
        inner_model(np.array([[0]]), weights=inner_weights, state=inner_state)
예제 #7
0
    def test_serialized_model_extracts_seq_model_weights_and_state(self):
        """Weights/state exposed by SerializedModel fit the raw seq model."""
        vocab_size = 3

        seq_model_fn = functools.partial(
            transformer.TransformerLM,
            vocab_size=vocab_size,
            d_model=2,
            d_ff=2,
            n_layers=0,
        )
        seq_model = seq_model_fn(mode='eval')
        obs_serializer = space_serializer.create(
            gym.spaces.Discrete(2), vocab_size=vocab_size)
        act_serializer = space_serializer.create(
            gym.spaces.Discrete(2), vocab_size=vocab_size)
        serialized_model = serialization_utils.SerializedModel(
            seq_model_fn,
            observation_serializer=obs_serializer,
            action_serializer=act_serializer,
            significance_decay=0.9,
        )

        obs_sig = shapes.ShapeDtype((1, 2))
        act_sig = shapes.ShapeDtype((1, 1))
        serialized_model.init(
            input_signature=(obs_sig, act_sig, obs_sig, obs_sig))
        seq_model.weights = serialized_model.seq_model_weights
        seq_model.state = serialized_model.seq_model_state
        # If the structures didn't match, this forward pass would fail.
        seq_model(jnp.array([[0]]))
예제 #8
0
 def _make_schedule(
         self,
         history,
         control_configs,
         observation_metrics=(('eval', 'metrics/accuracy'), ),
         action_multipliers=(1.0, ),
         vocab_size=None,
 ):
     """Builds a PolicySchedule backed by a freshly saved tiny policy.

     Args:
       history: Metric history object the schedule observes.
       control_configs: Configs of the hyperparameters under control.
       observation_metrics: (mode, metric) pairs fed to the policy.
       action_multipliers: Multipliers the policy may apply to controls.
       vocab_size: Optional vocabulary size for a discrete policy input.

     Returns:
       An lr_schedules.PolicySchedule restoring the policy just saved.
     """
     policy_and_value_model = functools.partial(
         transformer.TransformerDecoder,
         d_model=2,
         d_ff=2,
         n_layers=0,
         vocab_size=vocab_size,
     )
     # One observation dimension per metric, each constrained to [0, 1].
     observation_space = gym.spaces.Box(
         shape=(len(observation_metrics), ),
         low=0.0,
         high=1.0,
     )
     # One discrete action (choice of multiplier) per controlled config.
     action_space = gym.spaces.MultiDiscrete(
         nvec=(len(action_multipliers), ) * len(control_configs))
     (net, _) = policy_based_utils.policy_and_value_net(
         bottom_layers_fn=policy_and_value_model,
         observation_space=observation_space,
         action_space=action_space,
         vocab_size=vocab_size,
         two_towers=False,
     )
     # Batch of 1, two observation steps and one action step.
     input_signature = (
         shapes.ShapeDtype((1, 2) + observation_space.shape,
                           observation_space.dtype),
         shapes.ShapeDtype((1, 1) + action_space.shape, action_space.dtype),
     )
     (params, state) = net.init(input_signature)
     policy_dir = self.get_temp_dir()
     # Optimizer slots and parameters should not be used for anything.
     slots = None
     opt_params = None
     opt_state = (params, slots, opt_params)
     # Save a checkpoint for PolicySchedule to restore the policy from.
     policy_based_utils.save_opt_state(policy_dir,
                                       opt_state,
                                       state,
                                       epoch=0,
                                       total_opt_step=0,
                                       history=history)
     return lr_schedules.PolicySchedule(
         history,
         observation_metrics=observation_metrics,
         include_controls_in_observation=False,
         action_multipliers=action_multipliers,
         control_configs=control_configs,
         policy_and_value_model=policy_and_value_model,
         policy_and_value_two_towers=False,
         policy_and_value_vocab_size=vocab_size,
         policy_dir=policy_dir,
     )
예제 #9
0
    def _test_sparse_fast_inference(self, length):
        """Checks predict-mode logits against eval-mode for a sparse model.

        Builds a ConfigurableTransformer with sparse attention/ff/loss in
        both 'eval' and 'predict' modes, transfers the eval weights to the
        predict model through a checkpoint file, then verifies the per-step
        logits agree to 5 decimal places over `length` decoding steps.

        Args:
          length: Number of decoder positions to verify.
        """
        import os
        import tempfile

        with fastmath.use_backend(fastmath.Backend.JAX):
            vocab_size = 16
            d_model = 4

            encoder_decoder_attention_type = functools.partial(
                tl.MultiplicativeConvCausalAttention,
                sparsity=2,
                length_kernel_size=1,
            )

            model_fn = functools.partial(
                ct.ConfigurableTransformer,
                input_vocab_size=vocab_size,
                d_model=d_model,
                d_ff=8,
                n_encoder_layers=2,
                n_decoder_layers=2,
                n_heads=2,
                loss_sparsity=2,
                ff_sparsity=2,
                encoder_decoder_attention_type=encoder_decoder_attention_type,

                # SRU currently doesn't work for second token and further.
                # ff_use_sru=(1, 4),
            )

            model_slow = model_fn(mode='eval')
            model_fast = model_fn(mode='predict')
            rng = fastmath.random.get_prng(0)
            batch_size = 2
            input_signature = (shapes.ShapeDtype(
                (batch_size, length),
                np.int32), shapes.ShapeDtype((batch_size, 1), np.int32))
            model_slow.init(input_signature)
            model_fast.init(input_signature)
            # Transfer weights through a file in a unique temporary directory.
            # (Previously used a fixed '/tmp/unique_weights' path, which races
            # when tests run in parallel and leaks files on shared machines.)
            with tempfile.TemporaryDirectory() as tmp_dir:
                weights_path = os.path.join(tmp_dir, 'weights')
                model_slow.save_to_file(weights_path)
                model_fast.init_from_file(weights_path,
                                          weights_only=True,
                                          input_signature=input_signature)

            inp = np.random.randint(vocab_size, size=(batch_size, length))
            buf = np.zeros((batch_size, length), dtype=np.int32)
            next_sym = np.zeros((batch_size, 1), dtype=np.int32)

            for index in range(length):
                # Eval model re-reads the whole buffer; predict model sees
                # only the newest token and relies on its internal cache.
                logits_slow = model_slow((inp, buf), rng=rng)[0]
                logits_fast = model_fast((inp, next_sym), rng=rng)[0]
                np.testing.assert_array_almost_equal(
                    logits_slow[:, index, :],
                    logits_fast[:, 0, :],
                    decimal=5,
                    err_msg='Error on token {} out of {}.'.format(
                        index + 1, length))
                next_sym = np.random.randint(vocab_size, size=(batch_size, 1))
                buf[:, index] = next_sym[:, 0]
예제 #10
0
    def test_input_signatures(self):
        """Setting an input signature propagates recursively to sublayers."""
        layer = tl.Serial(DivideBy(2.0), DivideBy(5.0))
        self.assertIsNone(layer.input_signature)

        signature = shapes.ShapeDtype((3, 2))
        layer._set_input_signature_recursive(signature)
        self.assertEqual(layer.input_signature, signature)
        self.assertLen(layer.sublayers, 2)
        for sublayer in layer.sublayers:
            self.assertEqual(sublayer.input_signature, signature)
예제 #11
0
    def test_autoregressive_sample_reformer2_timing(self):
        """Times greedy sampling from a large Reformer2 in predict mode.

        This is a smoke/perf test: it samples 14 tokens, accumulates timing
        for tokens after the first few, and fails only when decoding is
        implausibly slow or the output shapes are wrong.
        """
        max_len = 16

        def _self_attention_fn():
            # Fast-decoding buffers sized to cover input + sampled tokens.
            return functools.partial(layers.SelfAttention,
                                     predict_drop_len=2 * max_len,
                                     predict_mem_len=2 * max_len)

        def _causal_attention_fn():
            return functools.partial(layers.ModularCausalAttention,
                                     n_modules=64,
                                     max_inference_length=2 * max_len)

        pred_model = models.Reformer2(
            mode='predict',
            d_model=8 * 1024,
            d_ff=64 * 1024,
            dropout=0.05,
            max_len=max_len,
            n_heads=64,
            n_encoder_layers=2,
            n_decoder_layers=2,
            encoder_attention_type=_self_attention_fn(),
            encoder_decoder_attention_type=_causal_attention_fn(),
            input_vocab_size=4,
            ff_sparsity=256,
            axial_pos_shape=None,
        )

        # Signatures for (encoder input, one decoder token).
        shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32)
        shape1l = shapes.ShapeDtype((1, max_len), dtype=np.int32)
        pred_model.init(input_signature=(shape1l, shape11))
        inputs = np.arange(16, dtype=np.int32).reshape(1, 16)

        # This is decoding.autoregressive_sample but simplified and with timing.
        result, counter, start_time, total_time = [], 0, time.time(), 0.0
        for sample in decoding.autoregressive_sample_stream(
                pred_model, inputs, temperature=0.0):  # accelerate=False):
            elapsed_time = time.time() - start_time
            start_time = time.time()
            # Only accumulate time after the first few tokens (presumably
            # warm-up/compilation steps — first 4 tokens are excluded).
            if counter > 3:
                total_time += elapsed_time
            result.append(sample[:, None])
            counter += 1
            if counter >= 14:
                break

        # We print 5* time for 10 tokens, @2 layers this is ~1 token @ 100 layers.
        print('\n\nTime for 5x10 tokens (~1tok @100): %.4fs\n\n\n' %
              (5 * total_time))
        self.assertLess(total_time, 20.0)  # If it's > 20s, it's some bug.
        # Check resulting shapes.
        s = np.concatenate(result, axis=1)
        self.assertEqual(s.shape[0], 1)
        self.assertEqual(s.shape[1], 14)
예제 #12
0
 def test_new_rngs_deterministic(self):
   """Two freshly-initialized layers hand out identical rng streams."""
   sig_single = shapes.ShapeDtype((2, 3, 5))
   sig_pair = (shapes.ShapeDtype((2, 3, 5)), shapes.ShapeDtype((2, 3, 5)))
   layer_a = base.Layer()
   layer_b = base.Layer(n_in=2, n_out=2)
   layer_a.init(sig_single)
   layer_b.init(sig_pair)
   rngs_a = layer_a.new_rngs(2)
   rngs_b = layer_b.new_rngs(2)
   # Same init seed => the two layers produce the same rng sequence.
   for rng_a, rng_b in zip(rngs_a, rngs_b):
     self.assertEqual(rng_a.tolist(), rng_b.tolist())
예제 #13
0
    def test_output_signature_no_weights(self):
        """output_signature works for multi-input and multi-output Fn layers."""
        sig_2_3_5 = shapes.ShapeDtype((2, 3, 5))
        add_layer = tl.Fn('2in1out', lambda x, y: x + y)
        # Two inputs collapse into one output with the shared shape.
        self.assertEqual(add_layer.output_signature((sig_2_3_5, sig_2_3_5)),
                         sig_2_3_5)

        sig_5_7 = shapes.ShapeDtype((5, 7))
        fanout_layer = tl.Fn('1in3out', lambda x: (x, 2 * x, 3 * x), n_out=3)
        # One input fans out into three same-shaped outputs.
        self.assertEqual(fanout_layer.output_signature(sig_5_7),
                         (sig_5_7, sig_5_7, sig_5_7))
예제 #14
0
    def test_run_reversible_large_weights(self):
        """Runs the reversible trainer with a lot of weights to test memory use."""
        # This test requires > 18GB RAM, only run on TPUs. It does pass on GPU
        # and CPU when you run it locally, but it's too big for unit-testing.
        ram_limited = True  # Set to False to run this test locally.
        if fastmath.global_device_count() == 1 and ram_limited:
            return

        # Create inputs and rngs.
        inputs_batch = np.arange(8).reshape((2, 4))
        targets_batch = inputs_batch
        labeled_batch = (inputs_batch, targets_batch,
                         np.ones_like(targets_batch))
        # Embedding followed by Dup produces the two-stream stack that the
        # reversible layers below operate on.
        first_layer = tl.Serial(tl.Embedding(9, 16 * 1024), tl.Dup())
        rng_init = fastmath.random.get_prng(12)
        rng_step = fastmath.random.get_prng(13)

        # Initialize layers.
        first_layer.init(labeled_batch, rng=rng_init)
        n_layers = 18  # 18 layers each 16K x 16K = 256M weights ~= 1GB, 18GB ram
        rev_layers = []
        int_shape = shapes.ShapeDtype((2, 4), dtype=np.int32)
        shape = shapes.ShapeDtype((2, 4, 16 * 1024))
        sig = (shape, shape)
        for _ in range(n_layers):
            layer = tl.ReversibleHalfResidual(tl.Dense(16 * 1024))
            layer.init(sig, rng=rng_init)
            layer.weights = tl.on_cpu(
                layer.weights)  # store weights in cpu memory
            rev_layers.append(layer)
            # Swap the two reversible streams between successive layers.
            rev_layers.append(tl.ReversibleSwap())
        # Loss head: concatenate both streams, project to vocab, cross-entropy.
        loss_layer = tl.Serial(tl.Concatenate(), tl.Dense(9), tl.LogSoftmax(),
                               tl.CrossEntropyLoss())
        loss_layer.init((shape, shape, int_shape, int_shape))
        optimizer_fn = optimizers.Adafactor

        # Make a step with reversible trainer.
        trainer = optimizers.ReversibleSerialTrainer(
            [(first_layer, rev_layers)], loss_layer, optimizer_fn)
        loss, _ = trainer.one_step(labeled_batch, rng_step)
        self.assertLess(float(loss.sum()), 10000.0)  # Just to get the loss.
        # Set to true to run again, e.g., for profiling.
        run_twice = False
        if run_twice:
            t = time.time()
            loss, _ = trainer.one_step(labeled_batch, rng_step)
            self.assertLess(float(loss.sum()),
                            10000.0)  # Just to get the loss.
            print('Took %.3f seconds to run, loss %s' %
                  (time.time() - t, loss))
예제 #15
0
    def test_autoregressive_sample_reformer2_copy_self_attn_quality(self):
        """Greedy decoding of a trained Reformer2 copy model (self-attention).

        Loads a checkpoint trained to copy its input and checks that
        temperature-0 sampling reproduces a fixed input sequence exactly.
        """
        max_len = 32

        def _self_attention_fn():
            # Fast-decoding buffers sized to cover input + sampled tokens.
            return functools.partial(
                tl.SelfAttention,
                predict_drop_len=2 * max_len,
                predict_mem_len=2 * max_len,
            )

        pred_model = models.Reformer2(
            mode='predict',
            d_model=256,
            d_ff=512,
            dropout=0.05,
            max_len=max_len,
            n_heads=4,
            n_encoder_layers=3,
            n_decoder_layers=3,
            ff_use_sru=1,
            d_attention_key=64,
            d_attention_value=64,
            encoder_attention_type=_self_attention_fn(),
            encoder_decoder_attention_type=_self_attention_fn(),
            n_decoder_attention_layers=1,
            input_vocab_size=13,
            axial_pos_shape=None,
        )

        # Signatures for (padded encoder input, one decoder token).
        shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32)
        shape1l = shapes.ShapeDtype((1, max_len), dtype=np.int32)

        model_path = os.path.join(_TESTDATA, 'reformer2_copy_self_attn.pkl.gz')
        pred_model.init_from_file(model_path,
                                  weights_only=True,
                                  input_signature=(shape1l, shape11))

        inputs = np.array([[11, 5, 1, 2, 3, 4]], dtype=np.int32)
        inp_len = inputs.shape[1]
        # Zero-pad the input out to the trained maximum length.
        inputs = np.pad(inputs, [(0, 0), (0, max_len - inp_len)],
                        mode='constant',
                        constant_values=0)
        # eos_id=-1 never matches a token, so exactly inp_len are sampled.
        s = decoding.autoregressive_sample(pred_model,
                                           inputs=inputs,
                                           eos_id=-1,
                                           max_length=inp_len,
                                           temperature=0.0)

        # The model was trained to copy: output must equal the unpadded input.
        np.testing.assert_equal(s[0], inputs[0, :inp_len])
예제 #16
0
 def get_slice_for_val(x):
     """Returns the slice of `x` at position `i` along axis 1.

     NOTE(review): relies on `i` from an enclosing scope that is not
     visible in this chunk — presumably the current decoding position;
     confirm against the surrounding function.
     """
     if isinstance(x, shapes.ShapeDtype):
         # For a signature, build the signature of a length-1 slice along
         # axis 1, keeping all other dimensions and the dtype.
         return shapes.ShapeDtype(shape=x.shape[:1] + (1, ) +
                                  x.shape[2:],
                                  dtype=x.dtype)
     else:
         # For an actual array, take column i while keeping the axis.
         return x[:, i:i + 1]
예제 #17
0
    def _test_fast_inference(self, length):
        """Predict-mode logits must match eval-mode logits token by token."""
        with fastmath.use_backend(fastmath.Backend.JAX):
            vocab_size = 16
            model_fn = functools.partial(
                configurable_transformer.ConfigurableTransformerLM,
                vocab_size=vocab_size,
                d_model=4,
                d_ff=8,
                n_layers=2,
                n_heads=2,
            )
            model_slow = model_fn(mode='eval')
            model_fast = model_fn(mode='predict')
            rng = fastmath.random.get_prng(0)
            batch_size = 2
            input_signature = shapes.ShapeDtype((batch_size, 1), np.int32)
            # Given the same rng, both models initialize with the same parameters.
            model_slow.init(input_signature, rng)
            model_fast.init(input_signature, rng)

            history = np.zeros((batch_size, length), dtype=np.int32)
            current = np.zeros((batch_size, 1), dtype=np.int32)

            for position in range(length):
                # The eval model re-reads the whole history; the predict
                # model consumes one token and relies on its internal cache.
                full_logits = model_slow(history, rng=rng)
                step_logits = model_fast(current, rng=rng)
                np.testing.assert_array_almost_equal(
                    full_logits[:, position, :],
                    step_logits[:, 0, :],
                    decimal=5,
                )
                current = np.random.randint(vocab_size,
                                            size=(batch_size, 1))
                history[:, position] = current[:, 0]
예제 #18
0
    def test_custom_id_grad(self):
        """A custom backward that returns its inputs yields grad == input."""
        class IdWithIdGrad(base.Layer):
            # NOTE(review): uses the old-style forward(x, weights) signature;
            # `math` below is presumably an alias of trax fastmath — confirm
            # against the file's imports.
            def forward(self, x, weights):
                return x

            @property
            def has_backward(self):
                # Opt in to the custom backward defined below.
                return True

            def backward(self, inputs, output, grad, weights, state, new_state,
                         rng):
                # The "gradient" is simply the forward inputs; incoming
                # `grad` is ignored. No weight gradients (empty tuple).
                return (inputs, ())

        layer = IdWithIdGrad()
        rng = math.random.get_prng(0)
        input_signature = shapes.ShapeDtype((9, 17))
        random_input = math.random.uniform(rng,
                                           input_signature.shape,
                                           minval=-1.0,
                                           maxval=1.0)
        layer.init(input_signature)
        f = lambda x: jnp.mean(layer(x))
        grad = math.grad(f)(random_input)
        self.assertEqual(grad.shape, (9, 17))  # Gradient for each input.
        self.assertEqual(sum(sum(grad)),
                         sum(sum(random_input)))  # Same as input.
예제 #19
0
    def test_autoregressive_sample_reformerlm(self):
        """Greedy-samples 10 tokens from a small two-layer ReformerLM."""
        lsh_attention = self._lsh_self_attention_fn()
        timebin_attention = self._timebin_self_attention_fn()

        model = models.ReformerLM(
            vocab_size=256,
            d_model=256,
            d_ff=512,
            d_attention_key=128,
            d_attention_value=128,
            n_layers=2,
            n_heads=2,
            dropout=0.05,
            max_len=65536,
            # One time-binned and one LSH self-attention layer.
            attention_type=[timebin_attention, lsh_attention],
            pos_axial_shape=(256, 256),
            pos_d_axial_embs=(128, 128),
            ff_activation=tl.Relu,
            ff_use_sru=0,
            mode='predict',
        )
        model.init(shapes.ShapeDtype((1, 1), dtype=np.int32))
        # eos_id=-1 never matches a token, so we get exactly max_length.
        sample = decoding.autoregressive_sample(
            model, batch_size=1, eos_id=-1, max_length=10)
        self.assertEqual(sample.shape[0], 1)
        self.assertEqual(sample.shape[1], 10)
예제 #20
0
    def test_custom_zero_grad(self, backend):
        """A custom backward returning zeros yields an all-zero gradient."""
        class ZeroGradIdentity(tl.Layer):
            """Identity forward; backward pass kills the incoming gradient."""

            def forward(self, x):
                return x

            @property
            def has_backward(self):
                return True

            def backward(self, inputs, output, grad, weights, state, new_state,
                         rng):
                return (jnp.zeros_like(grad), ())

        with fastmath.use_backend(backend):
            layer = ZeroGradIdentity()
            rng = fastmath.random.get_prng(0)
            signature = shapes.ShapeDtype((9, 17))
            x = fastmath.random.uniform(
                rng, signature.shape, minval=-1.0, maxval=1.0)
            layer.init(signature)
            grad = fastmath.grad(lambda a: jnp.mean(layer(a)))(x)
            self.assertEqual(grad.shape, (9, 17))  # One entry per input.
            # Sum of squares is zero iff every entry is exactly zero.
            self.assertEqual(sum(sum(grad * grad)), 0.0)
예제 #21
0
파일: training.py 프로젝트: yliu45/trax
 def _value_model_signature(self):
   """Returns the (inputs, targets, mask) signature for the value model.

   NOTE(review): `self._task.action_space` is used directly as the last
   shape dimension; a gym Space object is not an int, so this looks like
   it should be the number of actions (e.g. `action_space.n`) — confirm
   against the task definition.
   """
   obs_sig = shapes.signature(self._task.observation_space)
   target_sig = mask_sig = shapes.ShapeDtype(
       shape=(1, 1, self._task.action_space),
   )
   # Prepend (batch=1, time=1) axes to the raw observation shape.
   inputs_sig = obs_sig.replace(shape=(1, 1) + obs_sig.shape)
   return (inputs_sig, target_sig, mask_sig)
예제 #22
0
def test_eval_equals_predict_discrete(
    model_fn, vocab_size=10, length=5, batch_size=3
):
  """Tests the equivalence of eval and predict modes for discrete models."""
  with fastmath.use_backend(fastmath.Backend.JAX):
    eval_model = model_fn(mode='eval', vocab_size=vocab_size)
    predict_model = model_fn(mode='predict', vocab_size=vocab_size)
    rng = fastmath.random.get_prng(0)
    sig = shapes.ShapeDtype((batch_size, 1), np.int32)
    # Same rng means both models start from identical parameters.
    eval_model.init(sig, rng)
    predict_model.init(sig, rng)

    history = np.zeros((batch_size, length), dtype=np.int32)
    current = np.zeros((batch_size, 1), dtype=np.int32)

    for step in range(length):
      full_logits = eval_model(history, rng=rng)
      step_logits = predict_model(current, rng=rng)
      # Eval over the whole buffer must agree with predict on one token.
      np.testing.assert_array_almost_equal(
          full_logits[:, step, :], step_logits[:, 0, :],
          decimal=5,
      )
      current = np.random.randint(vocab_size, size=(batch_size, 1))
      history[:, step] = current[:, 0]
예제 #23
0
    def test_autoregressive_sample_terraformer_pure_lsh_attn_quality(self):
        """Greedy-decodes a copy task with a pretrained pure-LSH Terraformer.

        Loads a checkpoint trained on sequence copying and checks that
        temperature-0 autoregressive sampling reproduces the input prefix
        for several test lengths.
        """
        gin.add_config_file_search_path(_CONFIG_DIR)
        max_len = 32  # 32 is the max length we trained the checkpoint for.
        test_lengths = [8, 16, 32]
        vocab_size = 13
        # The checkpoint is correct on ~90% sequences, set random seed to deflake.
        np.random.seed(0)
        for test_len in test_lengths:
            # Re-parse the config from scratch each iteration so earlier
            # bindings cannot leak into this one.
            gin.clear_config()
            gin.parse_config_file('terraformer_purelsh_copy.gin')
            gin.bind_parameter('PureLSHSelfAttention.predict_mem_len',
                               2 * max_len)
            gin.bind_parameter('PureLSHSelfAttention.predict_drop_len',
                               2 * max_len)
            gin.bind_parameter('PureLSHSelfAttentionWrapper.bias', False)
            gin.bind_parameter('PureLSHSelfAttentionWrapper.num_weights', 2)

            pred_model = models.ConfigurableTerraformer(mode='predict')

            # Signatures: (inputs of length max_len, single decoder token).
            shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32)
            shape1l = shapes.ShapeDtype((1, max_len), dtype=np.int32)

            model_path = os.path.join(_TESTDATA,
                                      'terraformer_purelsh_copy.pkl.gz')
            pred_model.init_from_file(model_path,
                                      weights_only=True,
                                      input_signature=(shape1l, shape11))
            initial_state = pred_model.state

            for _ in range(2):  # Set low to make the test run reasonably fast.
                # Pick a length in [1, test_len] at random.
                inp_len = np.random.randint(low=1, high=test_len + 1)
                inputs = np.random.randint(low=1,
                                           high=vocab_size - 1,
                                           size=(1, max_len))
                # TODO(jaszczur): properly fix padding in terraformer predict mode,
                # and add a test here.
                s = decoding.autoregressive_sample(pred_model,
                                                   inputs=inputs,
                                                   eos_id=-1,
                                                   max_length=inp_len,
                                                   temperature=0.0)

                # Greedy decode must copy the first inp_len input symbols.
                np.testing.assert_equal(s[0], inputs[0, :inp_len])
                # Reset decoder caches so the next sample starts fresh.
                pred_model.state = initial_state
        gin.clear_config()  # Make sure to not affect other tests.
예제 #24
0
    def test_train_with_mixed_lsh_attention(self,
                                            backend=fastmath.Backend.JAX):
        """Trains a tiny mixed-LSH Terraformer, then reloads its checkpoint."""
        with fastmath.use_backend(backend):

            def build_model(mode='train'):
                # Deliberately tiny so one training step stays fast.
                return models.ConfigurableTerraformer(
                    mode=mode,
                    d_model=16,
                    d_ff=16,
                    n_heads=2,
                    dropout=0.05,
                    n_decoder_layers=1,
                    n_encoder_layers=1,
                    input_vocab_size=256,
                    encoder_attention_type=_mixed_lsh_self_attention_fn(),
                    encoder_decoder_attention_type=(
                        _mixed_lsh_self_attention_fn()),
                )

            seq_len = 128
            lm_inputs = _test_inputs_lm(vocab_size=256, seq_len=seq_len)

            # Train for one step, evaluating once.
            output_dir = self.create_tempdir().full_path
            trainer_lib.train(output_dir,
                              model=build_model,
                              inputs=lm_inputs,
                              steps=1,
                              eval_steps=1,
                              eval_frequency=1)

            # Reload the written checkpoint in predict mode to verify that
            # the saved weights round-trip.
            checkpoint_file = os.path.join(output_dir, 'model.pkl.gz')

            sig_token = trax_shapes.ShapeDtype((1, 1), dtype=jnp.int32)
            sig_seq = trax_shapes.ShapeDtype((1, seq_len), dtype=jnp.int32)

            predictor = build_model(mode='predict')
            predictor.init_from_file(checkpoint_file,
                                     weights_only=True,
                                     input_signature=(sig_seq, sig_token))
예제 #25
0
    def test_run_reversible_weights_trainsfer_xprof(self):
        # NOTE(review): "trainsfer" is a typo for "transfer" in the name;
        # kept as-is since renaming would change the test id.
        """Runs the reversible trainer and profiles weight transfer stats."""
        run_this_test = False  # We only run this test manually.
        if not run_this_test or fastmath.global_device_count(
        ) == 1:  # TPU only
            return

        # Create inputs and rngs.
        inputs_batch = np.ones((1024, 128), dtype=np.int32)
        targets_batch = inputs_batch
        labeled_batch = (inputs_batch, targets_batch,
                         np.ones_like(targets_batch))
        first_layer = tl.Serial(tl.Embedding(4, 1024), tl.Dup())
        rng_init = fastmath.random.get_prng(12)
        rng_step = fastmath.random.get_prng(13)

        # Initialize layers.
        first_layer.init(labeled_batch, rng=rng_init)
        n_layers = 6
        rev_layers = []
        int_shape = shapes.ShapeDtype((1024, 128), dtype=np.int32)
        shape = shapes.ShapeDtype((1024, 128, 1024))
        sig = (shape, shape)
        for _ in range(n_layers):
            layer = tl.ReversibleHalfResidual(tl.Dense(1024))
            layer.init(sig, rng=rng_init)
            layer.weights = tl.on_cpu(
                layer.weights)  # store weights in cpu memory
            rev_layers.append(layer)
            rev_layers.append(tl.ReversibleSwap())
        loss_layer = tl.Serial(tl.Concatenate(), tl.Dense(9), tl.LogSoftmax(),
                               tl.CrossEntropyLoss())
        loss_layer.init((shape, shape, int_shape, int_shape))
        optimizer_fn = optimizers.SGD

        # Make a step with reversible trainer.
        trainer = optimizers.ReversibleSerialTrainer(
            [(first_layer, rev_layers)], loss_layer, optimizer_fn)
        # First step includes compilation/warm-up, so it is not timed.
        loss, _ = trainer.one_step(labeled_batch, rng_step)
        self.assertLess(float(loss.sum()), 10000.0)  # Just to get the loss.
        # We profile here.
        t = time.time()
        loss, _ = trainer.one_step(labeled_batch, rng_step)
        self.assertLess(float(loss.sum()), 10000.0)  # Just to get the loss.
        print('Took %.3f seconds to run, loss %s' % (time.time() - t, loss))
예제 #26
0
 def test_new_rng_deterministic(self):
   """Two freshly initialized layers derive the same first rng."""
   sig = shapes.ShapeDtype((2, 3, 5))
   first = base.Layer()
   second = base.Layer(n_in=2, n_out=2)
   _, _ = first.init(sig)
   _, _ = second.init(sig)
   # Initialization seeds both layers identically, so new_rng agrees.
   self.assertEqual(first.new_rng().tolist(), second.new_rng().tolist())
예제 #27
0
  def test_loss_layer_timing(self):
    """Benchmarks SparseDenseWithOptions variants in predict mode.

    For each setting, runs 50 warm-up calls plus 100 timed calls and prints
    the accumulated latency; assertions only check the output shape.
    """
    all_settings = [
        # The first run is sometimes slower, less reliable.
        {'output': 32000, 'input': 2048, 'prob': None,
         'type': None, 'sparsity': 0, 'lowrank': 0, 'use_bias': False},

        {'output': 32000, 'input': 2048, 'prob': None,
         'type': None, 'sparsity': 0, 'lowrank': 0, 'use_bias': False},
        {'output': 32000, 'input': 2048, 'prob': None,
         'type': 'einsum', 'sparsity': 0, 'lowrank': 0, 'use_bias': False},
        {'output': 32000, 'input': 2048, 'prob': None,
         'type': 'mult', 'sparsity': 2, 'lowrank': 0, 'use_bias': False},

        {'output': 32000, 'input': 2048, 'prob': None,
         'type': None, 'sparsity': 0, 'lowrank': 0, 'use_bias': True},
        {'output': 32000, 'input': 2048, 'prob': None,
         'type': 'einsum', 'sparsity': 0, 'lowrank': 0, 'use_bias': True},
        {'output': 32000, 'input': 2048, 'prob': None,
         'type': 'mult', 'sparsity': 2, 'lowrank': 0, 'use_bias': True},
    ]

    messages = []
    for settings in all_settings:
      pred_model = tl.SparseDenseWithOptions(
          n_units=settings['output'],
          d_input=settings['input'],
          sparsity_type=settings['type'],
          sparsity=settings['sparsity'],
          d_lowrank=settings['lowrank'],
          prob_sparse=settings['prob'],
          use_bias=settings['use_bias'],
          mode='predict',
          )
      pred_model = tl.Accelerate(pred_model)

      shape1l = shapes.ShapeDtype((1, settings['input']))
      pred_model.init(input_signature=shape1l)
      inputs = np.ones((1, settings['input']))

      total_time = 0.0
      # Negative counter values are warm-up iterations (compilation etc.)
      # and are excluded from the accumulated time.
      for counter in range(-50, 100):
        # time.perf_counter is monotonic and high-resolution, unlike
        # time.time, so it is the right clock for short elapsed intervals.
        start_time = time.perf_counter()
        y = pred_model(inputs)
        self.assertEqual(y.shape, (1, settings['output']))
        elapsed_time = time.perf_counter() - start_time
        if counter >= 0:
          total_time += elapsed_time

      message = (
          '\n\nParams: %d Settings: %s\nTime for 100 tokens: %.4f s\n\n\n'
          % (_size_of_model(pred_model), settings, total_time))
      messages.append(message)
      print(message)

    print('Final results (recap):')
    for message in messages:
      print(message)
예제 #28
0
 def test_new_rng_new_value_each_call(self):
   """Consecutive new_rng calls on one layer yield distinct rngs."""
   sig = shapes.ShapeDtype((2, 3, 5))
   layer = base.Layer()
   _, _ = layer.init(sig)
   drawn = [layer.new_rng().tolist() for _ in range(3)]
   # Each successive draw must differ from the previous one.
   self.assertNotEqual(drawn[0], drawn[1])
   self.assertNotEqual(drawn[1], drawn[2])
예제 #29
0
 def test_autoregressive_sample_transformer(self):
   """Sampling from an untrained Transformer has the requested shape."""
   model = models.Transformer(10, d_model=32, d_ff=64, n_encoder_layers=1,
                              n_decoder_layers=1, n_heads=2, mode='predict')
   inputs = np.ones((1, 3), dtype=np.int32)
   sig = (shapes.signature(inputs),
          shapes.ShapeDtype((1, 1), dtype=np.int32))
   model.init(sig)
   # eos_id=-1 disables early stopping, so exactly max_length tokens emerge.
   samples = decoding.autoregressive_sample(
       model, inputs=inputs, eos_id=-1, max_length=10)
   self.assertEqual(samples.shape[0], 1)
   self.assertEqual(samples.shape[1], 10)
예제 #30
0
    def test_run_reversible_same_as_default_terraformer(self):
        """Runs the reversible trainer, check results are the same as default."""
        inputs_batch = np.arange(8).reshape((2, 4)) + 1
        targets_batch = 2 * inputs_batch
        labeled_batch = (inputs_batch, targets_batch,
                         np.ones_like(targets_batch))
        int_sig = shapes.ShapeDtype((2, 4), dtype=np.int32)
        input_sig = (int_sig, int_sig, int_sig)
        # We want to test rng propagation too, so adding some dropout layers.
        # NOTE(review): dropout is actually 0.0 below, so the comment above
        # may be stale -- confirm whether nonzero dropout was intended.
        model = terraformer.ConfigurableTerraformer(20,
                                                    d_model=8,
                                                    d_ff=32,
                                                    n_heads=1,
                                                    dropout=0.0,
                                                    n_encoder_layers=2,
                                                    n_decoder_layers=2,
                                                    ff_sparsity=(4, 8, 0.0,
                                                                 1.0),
                                                    pos_type=None,
                                                    reversible_encoder=True)
        loss = tl.Serial(tl.LogSoftmax(), tl.CrossEntropyLoss())
        optimizer_fn = optimizers.Adafactor
        # Split the model+loss into (standard, reversible) blocks that the
        # reversible trainer consumes.
        blocks, loss_layer = optimizers.trainer.extract_reversible_blocks(
            [model, loss], loss_chunk_size=4)
        blocks_serial = [(tl.Serial(std), rev) for (std, rev) in blocks]
        model_with_loss = tl.Serial(model, loss)
        rng_init = fastmath.random.get_prng(12)
        model_with_loss.init(input_sig, rng=rng_init)

        # Make 3 steps with the original trainer.
        optimizer = optimizer_fn()
        optimizer.tree_init(model_with_loss.weights)
        trainer = optimizers.Trainer(model_with_loss, optimizer)
        # Fixed per-step rngs so both trainers see identical randomness.
        rng_step1 = fastmath.random.get_prng(7)
        rng_step2 = fastmath.random.get_prng(8)
        rng_step3 = fastmath.random.get_prng(9)
        trainer.one_step(labeled_batch, rng_step1)
        trainer.one_step(labeled_batch, rng_step2, learning_rate=0.02)
        trainer.one_step(labeled_batch, rng_step3, learning_rate=0.03)
        # Snapshot the trained weights BEFORE re-initialization below;
        # the captured structures are compared against the second run.
        first_weights = blocks_serial[0][0].weights
        first_rev_weights = blocks[0][1][0].weights
        loss_weights = loss_layer.weights

        # Now make 3 steps with reversible trainer.
        # Re-init with the same rng so both runs start from equal weights.
        model_with_loss.init(input_sig, rng=rng_init)
        trainer = optimizers.ReversibleSerialTrainer(blocks, loss_layer,
                                                     optimizer_fn)
        trainer.one_step(labeled_batch, rng_step1)
        trainer.one_step(labeled_batch, rng_step2, learning_rate=0.02)
        trainer.one_step(labeled_batch, rng_step3, learning_rate=0.03)

        # Check that weights end up the same.
        self._assert_all_equal(loss_weights, loss_layer.weights)
        self._assert_all_equal(first_rev_weights, blocks[0][1][0].weights)
        self._assert_all_equal(first_weights, blocks_serial[0][0].weights)