예제 #1
0
파일: model.py 프로젝트: NTT123/vietTTS
 def __call__(self, x, lengths):
     """Encode tokens with a conv stack followed by a bidirectional LSTM.

     Args:
         x: input passed to ``self.embed``; after the conv stack the
             activations are unpacked as a rank-3 ``(B, L, D)`` array.
         lengths: per-example lengths, indexed as ``lengths[:, None]``
             (presumably shape ``(B,)`` -- confirm against the caller).

     Returns:
         The forward LSTM outputs concatenated on the last axis with the
         backward LSTM outputs (flipped back along axis 1).
     """
     x = self.embed(x)

     # Three conv -> batch-norm -> ReLU stages; dropout (and the RNG key it
     # consumes) is only applied while training.
     stages = (
         (self.conv1, self.bn1),
         (self.conv2, self.bn2),
         (self.conv3, self.bn3),
     )
     for conv, bn in stages:
         x = jax.nn.relu(bn(conv(x), is_training=self.is_training))
         if self.is_training:
             x = hk.dropout(hk.next_rng_key(), self.dropout_rate, x)

     B, L, _ = x.shape
     # True from position ``lengths - 1`` onwards along axis 1.
     mask = jnp.arange(0, L)[None, :] >= (lengths[:, None] - 1)

     fwd_out, _ = hk.dynamic_unroll(self.lstm_fwd,
                                    x,
                                    self.lstm_fwd.initial_state(B),
                                    time_major=False)

     # Reverse along axis 1 for the backward pass.  Plain ``jnp.flip`` calls
     # replace ``jax.tree_map``, which has been removed from recent JAX.
     x_bwd = jnp.flip(x, axis=1)
     mask_bwd = jnp.flip(mask, axis=1)
     # NOTE(review): the backward core is fed (input, mask) pairs, so it is
     # presumably a reset-aware core (e.g. hk.ResetCore) -- confirm.
     bwd_out, _ = hk.dynamic_unroll(self.lstm_bwd,
                                    (x_bwd, mask_bwd),
                                    self.lstm_bwd.initial_state(B),
                                    time_major=False)

     # Flip the backward outputs back before concatenating with the forward
     # outputs.
     return jnp.concatenate((fwd_out, jnp.flip(bwd_out, axis=1)), axis=-1)
예제 #2
0
파일: train.py 프로젝트: tirkarthi/dm-haiku
def sample(
    rng_key: jnp.ndarray,
    context: jnp.ndarray,
    sample_length: int,
) -> jnp.ndarray:
    """Draws samples from the model, given an initial context.

    Args:
        rng_key: JAX PRNG key; split once per sampled token.
        context: rank-1 array the network is unrolled over before sampling.
        sample_length: number of tokens to produce.

    Returns:
        An int32 array of length ``sample_length`` holding the sampled tokens.
    """
    # Note: this function is impure; we hk.transform() it below.
    assert context.ndim == 1  # No batching for now.
    core = make_network()

    def body_fn(t: int, v: LoopValues) -> LoopValues:
        """Feed token ``t`` to the core and store the sampled token at ``t + 1``."""
        token = v.tokens[t]
        next_logits, next_state = core(token, v.state)
        key, subkey = jax.random.split(v.rng_key)
        next_token = jax.random.categorical(subkey, next_logits, axis=-1)
        # ``jax.ops.index_update`` was removed from JAX; ``.at[].set`` is the
        # functional replacement.
        new_tokens = v.tokens.at[t + 1].set(next_token)
        return LoopValues(tokens=new_tokens, state=next_state, rng_key=key)

    logits, state = hk.dynamic_unroll(core, context, core.initial_state(None))
    key, subkey = jax.random.split(rng_key)
    first_token = jax.random.categorical(subkey, logits[-1])
    # Must be a JAX array (not NumPy) for the ``.at`` indexed-update API.
    tokens = jnp.zeros(sample_length, dtype=jnp.int32)
    tokens = tokens.at[0].set(first_token)
    initial_values = LoopValues(tokens=tokens, state=state, rng_key=key)
    # Token 0 is already filled in, so only ``sample_length - 1`` further
    # steps are needed; step ``t`` writes index ``t + 1``.  Looping up to
    # ``sample_length`` would issue one extra out-of-bounds write.
    values: LoopValues = lax.fori_loop(0, sample_length - 1, body_fn,
                                       initial_values)

    return values.tokens
예제 #3
0
        def unroll_fn(observations, state):
            """Unroll an LSTM over `observations` and apply two linear heads.

            Returns ``((logits, values), new_state)``; `values` has its last
            axis squeezed away.
            """
            core = hk.LSTM(hidden_size)
            core_out, next_state = hk.dynamic_unroll(core, observations, state)

            # Heads are built in the same order as before so Haiku assigns
            # them the same module names.
            policy_head = hk.Linear(num_actions)
            logits = policy_head(core_out)

            value_head = hk.Linear(1)
            values = jnp.squeeze(value_head(core_out), axis=-1)

            return (logits, values), next_state
예제 #4
0
        def loss(trajectory: buffer.Trajectory, rnn_unroll_state: RNNState):
            """Compute the combined policy-gradient, value and entropy loss.

            Returns ``(combined_loss, new_rnn_unroll_state)`` where the
            combined loss is
            ``actor_loss + critic_cost * critic_loss + entropy_cost * entropy_loss``.
            """
            # Dynamically unroll the network over the packed trajectory,
            # threading the recurrent state through the steps.
            outputs, new_rnn_unroll_state = hk.dynamic_unroll(
                network, pack(trajectory), rnn_unroll_state)
            logits, values, _, _, _ = outputs

            num_steps = trajectory.actions.shape[0]
            step_weights = jnp.ones(num_steps)
            # All steps but the last, index 0 on the second axis.
            policy_logits = logits[:-1, 0]

            # Critic: mean squared TD(lambda) error.
            td_errors = rlax.td_lambda(
                v_tm1=values[:-1, 0],
                r_t=jnp.squeeze(trajectory.rewards, -1),
                discount_t=trajectory.discounts * discount,
                v_t=values[1:, 0],
                lambda_=jnp.array(td_lambda))
            critic_loss = jnp.mean(td_errors**2)

            # Actor: policy gradient with the TD errors as advantages.
            actor_loss = rlax.policy_gradient_loss(
                logits_t=policy_logits,
                a_t=jnp.squeeze(trajectory.actions, 1),
                adv_t=td_errors,
                w_t=step_weights)

            # Entropy regularizer over the same logits.
            entropy_loss = jnp.mean(
                rlax.entropy_loss(policy_logits, step_weights))

            combined_loss = (actor_loss + critic_cost * critic_loss +
                             entropy_cost * entropy_loss)
            return combined_loss, new_rnn_unroll_state
예제 #5
0
def sequence_loss(batch: dataset.Batch) -> jnp.ndarray:
  """Unrolls the network over the batch and returns the mean cross-entropy.

  The summed negative log-likelihood of the targets is divided by
  ``sequence_length * batch_size``.
  """
  # Note: this function is impure; we hk.transform() it below.
  core = make_network()
  inputs = batch['input']
  sequence_length, batch_size = inputs.shape

  logits, _ = hk.dynamic_unroll(core, inputs, core.initial_state(batch_size))

  num_classes = logits.shape[-1]
  one_hot_labels = jax.nn.one_hot(batch['target'], num_classes=num_classes)
  total_nll = -jnp.sum(one_hot_labels * jax.nn.log_softmax(logits))
  return total_nll / (sequence_length * batch_size)
예제 #6
0
 def recurrence_function(sequence, initial_state=None):
     """Build the latent-dynamics core and unroll it over `sequence`.

     Args:
         sequence: inputs stepped along axis 0; ``sequence.shape[1]`` is
             used as the batch size when a default state has to be built.
         initial_state: optional starting state for the core.  When
             ``None``, ``core.initial_state(sequence.shape[1])`` is used.

     Returns:
         The result of ``hk.dynamic_unroll(core, sequence, initial_state)``.
     """
     core = nets.make_flexible_recurrent_net(
         core_type=latent_dynamics_type,
         net_type=latent_system_net_type,
         output_dims=self.latent_system_dim,
         **self.latent_system_kwargs["net_kwargs"])
     # Compare with None explicitly: ``initial_state or ...`` calls bool() on
     # the state, which raises for array states with more than one element.
     if initial_state is None:
         initial_state = core.initial_state(sequence.shape[1])
     # One direct call on the first step; its result is discarded.
     # NOTE(review): presumably this creates the core's parameters outside
     # the unroll -- confirm before removing.
     core(sequence[0], initial_state)
     return hk.dynamic_unroll(core, sequence, initial_state)
예제 #7
0
    def unroll(self, x, state):
        """Unrolls more efficiently than dynamic_unroll.

        The torso and the head are applied through ``hk.BatchApply``; only
        the recurrent core (when enabled) goes through ``hk.dynamic_unroll``.

        Returns ``(head_output, state)``.
        """
        torso = AtariDeepTorso() if self._use_resnet else AtariShallowTorso()
        features = hk.BatchApply(torso)(x.observation)

        if not self._use_lstm:
            # No recurrent core: the state passes through unchanged.
            return hk.BatchApply(self._head)(features), state

        # Flag the steps whose type is FIRST and hand them to the core
        # together with the torso features.
        should_reset = jnp.equal(x.step_type, int(dm_env.StepType.FIRST))
        core_output, state = hk.dynamic_unroll(
            self._core, (features, should_reset), state)
        return hk.BatchApply(self._head)(core_output), state
예제 #8
0
  def __call__(self, inputs: AcousticInput):
    """Run the acoustic model on a batch and return both mel predictions.

    The encoder output is upsampled with `inputs.durations`, concatenated
    with the prenet output of `inputs.mels`, and decoded with a zoneout
    wrapper around `self.decoder`.

    Returns:
      A pair ``(x, x + residual)``: the projection output, and the same
      output with the postnet residual added.
    """
    x = self.encoder(inputs.phonemes, inputs.lengths)
    x = self.upsample(x, inputs.durations, inputs.mels.shape[1])
    mels = self.prenet(inputs.mels)
    x = jnp.concatenate((x, mels), axis=-1)
    B, L, _ = x.shape
    hx = self.decoder.initial_state(B)

    def zoneout_decoder(inputs, prev_state):
      """Step the decoder, keeping the previous state where the mask is set."""
      x, mask = inputs
      x, state = self.decoder(x, prev_state)
      # `jax.tree_multimap` was removed from JAX; `jax.tree_util.tree_map`
      # accepts several trees and is the drop-in replacement.
      state = jax.tree_util.tree_map(
          lambda m, s1, s2: s1 * m + s2 * (1 - m), mask, prev_state, state)
      return x, state

    # One Bernoulli(0.1) mask per state leaf, with a fresh RNG key per leaf.
    mask = jax.tree_util.tree_map(
        lambda h: jax.random.bernoulli(
            hk.next_rng_key(), 0.1, (B, L, h.shape[-1])),
        hx)
    x, _ = hk.dynamic_unroll(zoneout_decoder, (x, mask), hx, time_major=False)
    x = self.projection(x)
    residual = self.postnet(x)
    return x, x + residual
예제 #9
0
파일: model.py 프로젝트: NTT123/vietTTS
    def inference(self, tokens, durations, n_frames):
        """Generate mel frames one step at a time, feeding each frame back in.

        The encoder output is upsampled to `n_frames` and used as the
        per-step conditioning input.  Returns the projected frames plus the
        postnet residual.
        """
        B, L = tokens.shape
        lengths = jnp.array([L], dtype=jnp.int32)
        cond_seq = self.upsample(self.encoder(tokens, lengths), durations,
                                 n_frames)

        def loop_fn(inputs, state):
            """One decoder step: conditioning + previous frame -> next frame."""
            prev_mel, hxcx = state
            decoder_in = jnp.concatenate((inputs, self.prenet(prev_mel)),
                                         axis=-1)
            decoder_out, new_hxcx = self.decoder(decoder_in, hxcx)
            mel = self.projection(decoder_out)
            # The new frame is both the step output and part of the carry.
            return mel, (mel, new_hxcx)

        # Start from an all-zero frame and the decoder's initial state.
        first_mel = jnp.zeros((B, FLAGS.mel_dim), dtype=jnp.float32)
        init_state = (first_mel, self.decoder.initial_state(B))
        mel_seq, _ = hk.dynamic_unroll(loop_fn, cond_seq, init_state,
                                       time_major=False)
        return mel_seq + self.postnet(mel_seq)
예제 #10
0
    def __call__(self, inputs: Tuple[ndarray, ndarray]):
        """Compute attention log-probabilities of decoder steps over encoder steps.

        `inputs` is a ``(sequence, mask)`` pair; the sequence is rank 3 and
        both arrays have their first two axes swapped before the LSTM is
        unrolled.  The result is transposed so that, per the original
        layout comment, the batch axis comes first and the logit axis last.
        """
        seq_in, mask_in = inputs
        batch_size, _, _ = seq_in.shape

        # Swap the first two axes for the unroll.
        seq_tm = jnp.swapaxes(seq_in, 0, 1)
        mask_tm = jnp.swapaxes(mask_in, 0, 1)

        # Reset wherever the mask is off, delayed by one step (the mask is
        # moved one position to the right; position 0 keeps its value).
        mask_off = jnp.logical_not(mask_tm)
        reset_mask = mask_off.at[1:].set(mask_off[:-1])

        h0c0: hk.LSTMState = self.lstm.initial_state(batch_size)  # type: ignore
        hx, _ = hk.dynamic_unroll(self.lstm, (seq_tm, reset_mask), h0c0)

        # Split encoder/decoder outputs and append the initial hidden state,
        # which serves as the encoder state for the [end] token.
        split = self.padded_input_len
        encoder_hx = jnp.concatenate([hx[:split], h0c0.hidden[None]], axis=0)
        decoder_hx = hx[split:]

        # Values and queries for the attention mechanism.
        encoder_value = self.enc_att_fc(encoder_hx)[None]
        decoder_query = self.dec_att_fc(decoder_hx)[:, None]

        # Scaled dot-product energy; the scale uses the feature size of the
        # element-wise product, taken before the sum removes that axis.
        pairwise = encoder_value * decoder_query
        scores = jnp.sum(pairwise, axis=-1) / math.sqrt(pairwise.shape[-1])

        # Mask out invalid input positions (including the appended slot).
        visible = mask_tm[:split + 1][None]
        scores = jnp.where(visible, scores, float('-inf'))

        # Normalize over axis 1, then move the batch axis to the front.
        log_probs = jax.nn.log_softmax(scores, axis=1)
        return jnp.transpose(log_probs, [2, 0, 1])
예제 #11
0
def generate(seq, args):
  """Greedily decode a hull for `seq`, dump attention tensors, and plot it.

  Writes the encoder values to '/tmp/encoded_value.pk', the stacked decoder
  queries to '/tmp/decoder_query.pk', and the plot to 'imgs/prediction.png'.
  """
  net = _create_network(args)
  init_state = net.lstm.initial_state(1)

  # Prefix each point with 0., add a leading axis, then swap the first two
  # axes before unrolling.
  points = jnp.swapaxes(jnp.asarray([(0., *e) for e in seq])[None], 0, 1)
  encoder_hx, state = hk.dynamic_unroll(net.lstm.core, points, init_state)
  # Append the initial hidden state as one extra encoder entry.
  encoder_hx = jnp.concatenate([encoder_hx, init_state.hidden[None]], axis=0)
  encoder_value = net.enc_att_fc(encoder_hx)

  with open('/tmp/encoded_value.pk', 'wb') as f:
    pickle.dump(jax.device_get(encoder_value), f)

  # Index of the extra encoder entry; choosing it stops decoding.
  end_index = len(seq)
  hull = []
  queries = []
  inp = jnp.asarray([[1., 0., 0.]])  # [start] token

  for _ in range(end_index + 1):
    hidden, state = net.lstm.core(inp, state)
    decoder_query = net.dec_att_fc(hidden)
    queries.append(decoder_query)

    # Scaled dot-product score of the query against every encoder value.
    scores = encoder_value * decoder_query[None]
    scores = jnp.sum(scores, axis=-1) / math.sqrt(scores.shape[-1])
    idx = jnp.argmax(scores, axis=0).item()
    if idx == end_index:
      break
    hull.append(idx)
    inp = points[idx]

  queries = jnp.concatenate(queries, axis=0)

  with open('/tmp/decoder_query.pk', 'wb') as f:
    pickle.dump(jax.device_get(queries), f)

  plot_points_and_hull(seq, hull, 'imgs/prediction.png')
예제 #12
0
    def loss(trajectory: sequence.Trajectory, rnn_unroll_state: LSTMState):
      """Actor-critic loss.

      Returns ``(combined_loss, new_rnn_unroll_state)`` where the combined
      loss is ``actor_loss + critic_loss + entropy_cost * entropy_loss``.
      """
      # Insert a singleton axis after the leading axis of the observations;
      # the ``[..., 0]`` indexing below removes it again from the outputs.
      observations = trajectory.observations[:, None, ...]
      (logits, values), new_rnn_unroll_state = hk.dynamic_unroll(
          network, observations, rnn_unroll_state)

      step_weights = jnp.ones(trajectory.actions.shape[0])
      policy_logits = logits[:-1, 0]

      # Critic: mean squared TD(lambda) error.
      td_errors = rlax.td_lambda(
          v_tm1=values[:-1, 0],
          r_t=trajectory.rewards,
          discount_t=trajectory.discounts * discount,
          v_t=values[1:, 0],
          lambda_=jnp.array(td_lambda),
      )
      critic_loss = jnp.mean(td_errors**2)

      # Actor: policy gradient with the TD errors as advantages.
      actor_loss = rlax.policy_gradient_loss(
          logits_t=policy_logits,
          a_t=trajectory.actions,
          adv_t=td_errors,
          w_t=step_weights)

      # Entropy regularizer over the same logits.
      entropy_loss = jnp.mean(rlax.entropy_loss(policy_logits, step_weights))

      combined_loss = actor_loss + critic_loss + entropy_cost * entropy_loss
      return combined_loss, new_rnn_unroll_state
예제 #13
0
        jax.nn.relu,
        hk.Linear(10),
    ])(features)


def lstm_model(x, vocab_size=10_000, seq_len=256, args=None, **_):
    """Embed tokens, run an LSTM over them, average-pool and classify.

    Args:
        x: integer token ids; ``x.shape[0]`` is used as the batch size for
            the LSTM's initial state.
        vocab_size: vocabulary size; the embedding table has
            ``vocab_size + 4`` rows.
        seq_len: window and stride of the average pooling.
        args: optional object; when it is truthy and has a truthy
            ``dynamic_unroll`` attribute, ``hk.dynamic_unroll`` is used
            instead of ``hk.static_unroll``.

    Returns:
        The output of the final ``hk.Linear(2)`` layer.
    """
    embed_init = hk.initializers.TruncatedNormal(stddev=0.02)
    token_embedding_map = hk.Embed(vocab_size + 4, 100, w_init=embed_init)
    embedded = token_embedding_map(x)

    # Batch-major -> time-major for the unroll.  This must be an axis swap:
    # a reshape to the swapped shape keeps the flat element order and so
    # mixes time steps of different examples together.
    time_major = jnp.swapaxes(embedded, 0, 1)

    # LSTM Part of Network
    core = hk.LSTM(100)
    unroll = (hk.dynamic_unroll
              if args and args.dynamic_unroll else hk.static_unroll)
    outs, _ = unroll(core, time_major, core.initial_state(x.shape[0]))
    # Back to batch-major, again with an axis swap rather than a reshape.
    outs = jnp.swapaxes(outs, 0, 1)

    # Avg Pool -> Linear
    red_dim_outs = hk.avg_pool(outs, seq_len, seq_len, "SAME").squeeze()
    final_layer = hk.Linear(2)
    return final_layer(red_dim_outs)


def embedding_model(arr, vocab_size=10_000, seq_len=256, **_):
    x = arr