Example 1
 def _build_clp_multiplication(self, clp_kernel):
   from returnn.tf.util.basic import safe_log
   input_placeholder = self.input_data.get_placeholder_as_batch_major()
   tf_compat.v1.assert_equal(tf.shape(clp_kernel)[1], tf.shape(input_placeholder)[2] // 2)
   tf_compat.v1.assert_equal(tf.shape(clp_kernel)[2], self._nr_of_filters)
   input_real = tf.strided_slice(input_placeholder, [0, 0, 0], tf.shape(input_placeholder), [1, 1, 2])
   input_imag = tf.strided_slice(input_placeholder, [0, 0, 1], tf.shape(input_placeholder), [1, 1, 2])
   kernel_real = self._clp_kernel[0, :, :]
   kernel_imag = self._clp_kernel[1, :, :]
   output_real = tf.einsum('btf,fp->btp', input_real, kernel_real) - tf.einsum('btf,fp->btp', input_imag, kernel_imag)
   output_imag = tf.einsum('btf,fp->btp', input_imag, kernel_real) + tf.einsum('btf,fp->btp', input_real, kernel_imag)
   output_uncompressed = tf.sqrt(tf.pow(output_real, 2) + tf.pow(output_imag, 2))
   output_compressed = safe_log(output_uncompressed)
   return output_compressed
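For reference, the einsum pair above is just a complex matrix product followed by magnitude and log compression. A minimal NumPy sketch (hypothetical shapes, independent of the RETURNN layer) that checks the equivalence:

import numpy

# Hypothetical sizes: batch, time, complex feature pairs, filters.
n_batch, n_time, n_feat, n_filters = 2, 5, 3, 4
x_real = numpy.random.randn(n_batch, n_time, n_feat)
x_imag = numpy.random.randn(n_batch, n_time, n_feat)
k_real = numpy.random.randn(n_feat, n_filters)
k_imag = numpy.random.randn(n_feat, n_filters)

# Same einsum pattern as in _build_clp_multiplication above.
out_real = numpy.einsum('btf,fp->btp', x_real, k_real) - numpy.einsum('btf,fp->btp', x_imag, k_imag)
out_imag = numpy.einsum('btf,fp->btp', x_imag, k_real) + numpy.einsum('btf,fp->btp', x_real, k_imag)
magnitude = numpy.sqrt(out_real ** 2 + out_imag ** 2)

# Reference via complex dtype: |(X_r + i*X_i) @ (K_r + i*K_i)|.
reference = numpy.abs((x_real + 1j * x_imag) @ (k_real + 1j * k_imag))
assert numpy.allclose(magnitude, reference)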
def full_sum_loss_with_stop_grad_prior(logits, logits_seq_lens, logits_time_major, targets, targets_seq_lens):
  """
  Similar to :func:`tf.nn.ctc_loss`.
  We use our :func:`fast_baum_welch`.
  Also see :class:`FastBaumWelchLoss`.

  :param tf.Tensor logits: (time,batch,dim) or (batch,time,dim). unnormalized (before softmax)
  :param tf.Tensor logits_seq_lens: shape (batch,) of int32|int64
  :param bool logits_time_major:
  :param tf.Tensor targets: batch-major, [batch,time]
  :param tf.Tensor targets_seq_lens: (batch,)
  :return: loss, shape (batch,)
  :rtype: tf.Tensor
  """
  from returnn.TFUtil import safe_log, where_bc, custom_gradient
  assert logits.get_shape().ndims == 3 and logits.get_shape().dims[-1].value
  dim = logits.get_shape().dims[-1].value
  if not logits_time_major:
    logits = tf.transpose(logits, [1, 0, 2])  # (time,batch,dim)

  # No need for stop_gradient here; we will control it via custom_gradient.
  log_sm = tf.nn.log_softmax(logits)  # (time,batch,dim)
  sm = tf.exp(log_sm)
  # Note: Not the correct masking here. Should be fixed, but does not matter for the demo here.
  avg_sm = tf.reduce_mean(sm, axis=0, keep_dims=True)  # (1,batch,dim)
  am_scores = log_sm - safe_log(avg_sm)

  from returnn.TFUtil import sequence_mask_time_major
  seq_mask = sequence_mask_time_major(logits_seq_lens)  # (time,batch)

  from returnn.TFNativeOp import get_ctc_fsa_fast_bw, fast_baum_welch
  edges, weights, start_end_states = get_ctc_fsa_fast_bw(
    targets=targets, seq_lens=targets_seq_lens, blank_idx=dim - 1)
  fwdbwd, obs_scores = fast_baum_welch(
    am_scores=-am_scores,  # -log space
    float_idx=seq_mask,
    edges=edges, weights=weights, start_end_states=start_end_states)
  loss = obs_scores[0]  # (batch,)
  n_batch = tf.shape(loss)[0]
  bw = tf.exp(-fwdbwd)  # (time,batch,dim). fwdbwd in -log space
  grad_x = where_bc(tf.expand_dims(seq_mask, 2), sm - bw, 0.0)  # (time,batch,dim)
  loss = tf.reshape(loss, [1, n_batch, 1])  # (1,batch,1), such that we can broadcast to logits/grad_x
  loss = custom_gradient.generic_loss_and_error_signal(loss=loss, x=logits, grad_x=grad_x)
  loss = tf.reshape(loss, [n_batch])
  return loss
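A minimal call sketch for the loss above (hypothetical shapes and label values; TF1 graph mode, with the blank index being dim - 1 as assumed by get_ctc_fsa_fast_bw):

# Hypothetical example values; dim includes the blank label at index dim - 1.
logits = tf.zeros([20, 1, 6])                  # (time, batch, dim), unnormalized
logits_seq_lens = tf.convert_to_tensor([20])   # (batch,)
targets = tf.convert_to_tensor([[1, 2, 3]])    # (batch, time), batch-major labels
targets_seq_lens = tf.convert_to_tensor([3])   # (batch,)
loss = full_sum_loss_with_stop_grad_prior(
  logits=logits, logits_seq_lens=logits_seq_lens, logits_time_major=True,
  targets=targets, targets_seq_lens=targets_seq_lens)  # (batch,)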
Example 3
 def __init__(self,
              beam_size,
              search=NotSpecified,
              input_type="prob",
              prob_scale=1.0,
              base_beam_score_scale=1.0,
              random_sample_scale=0.0,
              length_normalization=True,
              custom_score_combine=None,
              source_beam_sizes=None,
              scheduled_sampling=False,
              cheating=False,
              explicit_search_sources=None,
              **kwargs):
     super(ChoiceStateVarLayer, self).__init__(**kwargs)
     rec_layer = self.network.parent_layer
     assert isinstance(rec_layer, RecStepByStepLayer)
     assert len(self.sources) == 1
     source = self.sources[0]
     assert source.output.is_batch_major and len(source.output.shape) == 1
     scores_in = source.output.placeholder
     if input_type == "prob":
         if source.output_before_activation:
             scores_in = source.output_before_activation.get_log_output()
         else:
             from returnn.tf.util.basic import safe_log
             scores_in = safe_log(scores_in)
     elif input_type == "log_prob":
         pass
     else:
         raise ValueError("Not handled input_type %r" % (input_type, ))
      rec_layer.create_state_var(
          name="stochastic_var_scores_%s" % self.name, data_shape=source.output)
      rec_layer.set_state_var_final_value(
          name="stochastic_var_scores_%s" % self.name, final_value=scores_in)
     self.output.placeholder = rec_layer.create_state_var(
         name="stochastic_var_choice_%s" % self.name,
         data_shape=self.output)
     rec_layer.add_stochastic_var(self.name)
def full_sum_loss_with_prior(logits, logits_seq_lens, logits_time_major, targets, targets_seq_lens):
  """
  Similar to :func:`tf.nn.ctc_loss`.
  We use our :func:`fast_baum_welch`.
  Also see :class:`FastBaumWelchLoss`.

  :param tf.Tensor logits: (time,batch,dim) or (batch,time,dim). unnormalized (before softmax)
  :param tf.Tensor logits_seq_lens: shape (batch,) of int32|int64
  :param bool logits_time_major:
  :param tf.Tensor|Fsa targets: batch-major, [batch,time]
  :param tf.Tensor|None targets_seq_lens: (batch,)
  :return: loss, shape (batch,)
  :rtype: tf.Tensor
  """
  assert isinstance(targets, Fsa)  # not implemented otherwise
  assert logits_time_major
  # Warning: logits_seq_lens ignored currently...

  from returnn.TFUtil import safe_log
  log_sm = tf.nn.log_softmax(logits)  # (time,batch,dim)
  sm = tf.exp(log_sm)
  avg_sm = tf.reduce_mean(tf.squeeze(sm, axis=1), axis=0)  # (dim,), time-averaged posteriors, assumes batch==1
  scores = log_sm - safe_log(avg_sm)  # renormalize by the softmax-based prior
  return -targets.tf_get_full_sum(logits=scores)
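Both prior-based losses above renormalize the frame log-posteriors by a softmax-based prior estimated as the time-average of the posteriors, roughly

  am_score_t(c) = log p(c | x_t) - log( (1/T) * sum_t' p(c | x_t') )

The "_sg" (stop-gradient) variant keeps this prior term out of the gradient: the error signal w.r.t. the logits is fixed to softmax(logits) minus the Baum-Welch posteriors via custom_gradient.generic_loss_and_error_signal.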
def model(
      num_classes, target_seq,
      model_type,
      num_frames=None,
      input_seq=None,
      scale_sm_by_prior=False,
      loss_type="sum",
      init_type="zero", rnd_scale=1., rnd_seed=42, blank_bias_init=None,
      opt_class=tf1.train.GradientDescentOptimizer, learning_rate=0.1,
      logits_time_dropout=0, grad_noise=0, weight_noise=0,
      scale_update_inv_param_size=False,
      update_exact=False,
      scale_grads_by_1_m_prior=False):
  """
  :param int num_classes: except blank
  :param int|None num_frames:
  :param list[int]|Fsa|None target_seq:
  :param str|dict|list model_type: "bias", "mem", "mem+bias", "model_free", "gen_model_free",
    or a dict/list describing a generic NN (layer classes "linear", "blstm")
  :param list[int]|None input_seq: if given, length should be like num_frames. will use tf.one_hot
  :param bool scale_sm_by_prior:
  :param str loss_type: "sum", "gen_sum", "sum_with_prior", "sum_with_prior_sg", "max" or "sum+max"
  :param str init_type: "zero", "rnd_normal", "rnd_uniform" or "identity"
  :param float rnd_scale:
  :param int rnd_seed:
  :param float|None blank_bias_init:
  :param type opt_class:
  :param float learning_rate:
  :param bool scale_update_inv_param_size:
  :param float logits_time_dropout:
  :param float grad_noise:
  :param float weight_noise:
  :param bool update_exact:
  :param bool scale_grads_by_1_m_prior:
  :return: (loss, logits, baum_welch, viterbi, update_op, params, param_grads)
  """
  if num_frames is None:
    assert input_seq is not None
    num_frames = len(input_seq)
  dim = num_classes + 1
  rnd = numpy.random.RandomState(rnd_seed)
  tf1.set_random_seed(rnd.randint(0, 2 ** 16))
  if init_type == "zero":
    init_func = numpy.zeros
  elif init_type == "rnd_normal":
    def init_func(shape, dtype):
      return rnd.normal(size=shape, scale=rnd_scale).astype(dtype)
  elif init_type == "rnd_uniform":    
    def init_func(shape, dtype):
      return rnd.uniform(size=shape, low=-rnd_scale, high=rnd_scale).astype(dtype)
  elif init_type == "identity":
    def init_func(shape, dtype):
      if len(shape) == 2 and shape[0] == shape[1]:
        return numpy.eye(shape[0], dtype=dtype)
      return numpy.zeros(shape=shape, dtype=dtype)
  else:
    raise ValueError("invalid init_type %r" % (init_type,))
  if loss_type == "sum":
    loss_func = full_sum_loss
  elif loss_type == "gen_sum":
    loss_func = full_sum_loss_no_renorm
  elif loss_type == "sum_with_prior":
    loss_func = full_sum_loss_with_prior
  elif loss_type == "sum_with_prior_sg":
    # Very similar to `scale_sm_by_prior` option, but no renorm afterwards.
    loss_func = full_sum_loss_with_stop_grad_prior
  elif loss_type == "max":
    loss_func = ctc_loss_viterbi
  elif loss_type == "sum+max":
    def loss_func(**kwargs):
      return (ctc_loss(**kwargs) + ctc_loss_viterbi(**kwargs)) * 0.5
  else:
    raise ValueError("invalid loss_type %r" % (loss_type,))
  global_step = tf1.train.get_or_create_global_step()
  mem = None

  if model_type == "bias":
    bias_init = init_func((dim,), dtype="float32")
    if blank_bias_init is not None:
      bias_init[-1] = blank_bias_init
    bias = tf.get_variable("bias", shape=(dim,), initializer=tf.constant_initializer(value=bias_init))
    params = [bias]
    bias = apply_weight_noise(bias, weight_noise)
    logits = tf.expand_dims(expand_dims_unbroadcast(bias, axis=0, dim=num_frames), axis=1)  # (time,batch,dim)

  elif model_type == "mem":
    mem_init = init_func((num_frames, dim), dtype="float32")
    if blank_bias_init is not None:
      mem_init[:, -1] = blank_bias_init
    mem = tf.get_variable("mem", shape=(num_frames, dim), initializer=tf.constant_initializer(value=mem_init))
    params = [mem]
    mem = apply_weight_noise(mem, weight_noise)
    logits = tf.expand_dims(mem, axis=1)  # (time,batch,dim)

  elif model_type == "mem+bias":
    mem_init = init_func((num_frames, dim), dtype="float32")
    if blank_bias_init is not None:
      mem_init[:, -1] = blank_bias_init
    mem = tf.get_variable("mem", shape=(num_frames, dim), initializer=tf.constant_initializer(value=mem_init))
    bias_init = numpy.zeros((dim,), dtype="float32")
    if blank_bias_init is not None:
      bias_init[-1] = blank_bias_init
    bias = tf.get_variable("bias", shape=(dim,), initializer=tf.constant_initializer(value=bias_init))
    params = [mem, bias]
    mem = apply_weight_noise(mem, weight_noise)
    bias = apply_weight_noise(bias, weight_noise)
    logits_bias = tf.expand_dims(expand_dims_unbroadcast(bias, axis=0, dim=num_frames), axis=1)  # (time,batch,dim)
    logits = tf.expand_dims(mem, axis=1) + logits_bias  # (time,batch,dim)

  elif model_type == "model_free":
    assert dim == 2  # currently not implemented otherwise
    input_symbols = sorted(set(input_seq))
    param_init = init_func((len(input_symbols),), dtype="float32")
    param = tf.get_variable(  # single variable for all input symbols; plot_loss_grad_map can nicely plot this
      "param", shape=(len(input_symbols),), initializer=tf.constant_initializer(value=param_init))
    params = [param]
    input_seq_tensors = []
    for x in input_seq:
      p_idx = input_symbols.index(x)  # index param by symbol position, not by raw symbol value
      tensor = [param[p_idx], -param[p_idx]]
      if dim == len(input_symbols) == 2 and p_idx == 1:
        # Just for somewhat nicer / more consistent plotting to swap this around.
        # Really somewhat arbitrary, but does not matter anyway.
        tensor = [-param[p_idx], param[p_idx]]
      input_seq_tensors.append(tensor)
    logits = tf.convert_to_tensor(input_seq_tensors)  # (time,dim)
    logits.set_shape((len(input_seq), dim))
    logits = tf.expand_dims(logits, axis=1)  # (time,batch,dim)

  elif model_type == "gen_model_free":
    input_symbols = sorted(set(input_seq))
    assert len(input_symbols) == 2  # currently not implemented otherwise
    param_init = init_func((dim,), dtype="float32")
    param = tf.get_variable(  # single variable for all input symbols; plot_loss_grad_map can nicely plot this
      "param", shape=(dim,), initializer=tf.constant_initializer(value=param_init))
    params = [param]
    log_probs = []
    for i in range(dim):
      i_logits = [param[i], -param[i]]
      if i == 1 and dim == 2:
        # Just for somewhat nicer / more consistent plotting to swap this around.
        # Really somewhat arbitrary, but does not matter anyway.
        i_logits = [-param[i], param[i]]
      log_probs.append(tf.nn.log_softmax(i_logits))
    input_seq_tensors = []
    for x in input_seq:
      p_idx = input_symbols.index(x)
      tensor = [log_probs[i][p_idx] for i in range(dim)]
      input_seq_tensors.append(tensor)
    logits = tf.convert_to_tensor(input_seq_tensors)  # (time,dim)
    logits.set_shape((len(input_seq), dim))
    logits = tf.expand_dims(logits, axis=1)  # (time,batch,dim)

  elif isinstance(model_type, (dict, list)):  # generic NN
    mem = None
    if isinstance(model_type, list):
      model_type = {"layers": model_type}
    layers = model_type.get("layers", [])
    assert isinstance(layers, list)
    params = []
    x = generate_input(input_seq=input_seq, num_frames=num_frames, num_classes=num_classes)
    n_batch = 1
    x = tf.expand_dims(x, axis=1)  # (T,B,D)
    index = tf.ones([num_frames, n_batch])

    for i, layer in enumerate(layers):
      if isinstance(layer, str):
        layer = {"class": layer}
      assert isinstance(layer, dict)
      dim = layer.get("dim", max(num_classes * 2, 10))
      layer_class = layer["class"]

      if layer_class == "linear":
        shape = (x.shape[-1].value, dim)
        mat_init = init_func(shape, dtype="float32")
        mat = tf.get_variable("W%i" % i, shape=shape, initializer=tf.constant_initializer(value=mat_init))
        mat = apply_weight_noise(mat, weight_noise)
        x = dot(x, mat)
        if layer.get("bias", model_type.get("bias", True)):
          bias_init = init_func((dim,), dtype="float32")
          bias = tf.get_variable("b%i" % i, shape=(dim,), initializer=tf.constant_initializer(value=bias_init))
          bias = apply_weight_noise(bias, weight_noise)
          x = x + bias
        x.set_shape((num_frames, 1, dim))
        act = layer.get("act", "relu")
        if act:
          x = getattr(tf.nn, act)(x)

      elif layer_class == "blstm":
        shape = (x.shape[-1].value, dim * 4)
        xs = []
        for d in (-1, 1):
          with tf.variable_scope("blstm%i_%s" % (i, {-1: "bwd", 1: "fwd"}[d])):
            x_ = x
            mat_init = init_func(shape, dtype="float32")
            mat = tf.get_variable("W_ff", shape=shape, initializer=tf.constant_initializer(value=mat_init))
            mat = apply_weight_noise(mat, weight_noise)
            x_ = dot(x_, mat)
            if layer.get("bias", model_type.get("bias", True)):
              bias_init = init_func(shape[-1:], dtype="float32")
              bias = tf.get_variable("b", shape=shape[-1:], initializer=tf.constant_initializer(value=bias_init))
              bias = apply_weight_noise(bias, weight_noise)
              x_ = x_ + bias
            cell = NativeLstm2(n_hidden=dim, n_input_dim=shape[0], step=d)
            x_, _ = cell(
              x_, index,
              recurrent_weights_initializer=tf.constant_initializer(
                value=init_func((dim, dim * 4), dtype="float32")))
            xs.append(x_)
        x = tf.concat(axis=2, values=xs)  # [T,B,D*2]

      else:
        raise ValueError("invalid layer %i %r in model %r" % (i, layer, model_type))

    shape = (x.shape[-1].value, num_classes + 1)
    mat_init = init_func(shape, dtype="float32")
    mat = tf1.get_variable("W_final", shape=shape, initializer=tf.constant_initializer(value=mat_init))
    mat = apply_weight_noise(mat, weight_noise)
    x = dot(x, mat)
    if model_type.get("bias", True):
      bias_init = init_func(shape[-1:], dtype="float32")
      bias = tf1.get_variable("b_final", shape=shape[-1:], initializer=tf.constant_initializer(value=bias_init))
      bias = apply_weight_noise(bias, weight_noise)
      x = x + bias
    logits = x

    for p in tf1.get_collection(tf1.GraphKeys.TRAINABLE_VARIABLES):
      if p not in params:
        params.append(p)

  else:
    raise ValueError("invalid model_type %r" % (model_type,))

  logits.set_shape((num_frames, 1, num_classes + 1))
  if logits_time_dropout:
    logits = tf.nn.dropout(
      logits, noise_shape=[num_frames, 1, 1],
      rate=tf.where(tf.equal(global_step % 2, 0), logits_time_dropout, 0.))
  if scale_sm_by_prior:
    # such that we can rescale by prior, norm them now
    logits -= tf.stop_gradient(tf.reduce_logsumexp(logits, axis=-1, keep_dims=True))
    sm = tf.exp(logits)
    avg_sm = tf.reduce_mean(sm, axis=0, keep_dims=True)  # (1,1,dim)
    logits -= tf.stop_gradient(safe_log(avg_sm))
  if scale_grads_by_1_m_prior:
    logits = get_scaled_grads_by_1_m_prior(logits)
  logits_seq_len = tf.convert_to_tensor([num_frames])  # (batch,)
  if target_seq is None:
    target_seq = list(range(num_classes))
  if isinstance(model_type, str) and model_type.startswith("gen_"):  # e.g. "gen_model_free"
    am_scores = logits
  else:
    am_scores = tf.nn.log_softmax(logits)
  if not isinstance(target_seq, Fsa):
    assert len(target_seq) <= num_frames  # and that even might not be enough, e.g. for repeating entries
  targets_seq_len = None
  if isinstance(target_seq, Fsa):
    targets = target_seq
    viterbi, _ = targets.tf_get_best_alignment(logits=am_scores)
  else:
    targets = tf.convert_to_tensor([target_seq])  # (batch,time)
    targets_seq_len = tf.convert_to_tensor([len(target_seq)])  # (batch,)
    edges, weights, start_end_states = get_ctc_fsa_fast_bw(
      targets=targets, seq_lens=targets_seq_len, blank_idx=num_classes)
    viterbi, _ = fast_viterbi(
      am_scores=am_scores, am_seq_len=logits_seq_len,
      edges=edges, weights=weights, start_end_states=start_end_states)
    if input_seq:
      fer = tf.cast(tf.reduce_sum(tf.cast(tf.not_equal(viterbi[:,0], input_seq), tf.int32)), tf.float32) / tf.cast(num_frames, tf.float32)
      tf.summary.scalar("fer_viterbi_to_ref", fer)
      fer = tf.cast(tf.reduce_sum(tf.cast(tf.not_equal(tf.argmax(logits[:,0],axis=-1,output_type=tf.int32), input_seq), tf.int32)), tf.float32) / tf.cast(num_frames, tf.float32)
      tf.summary.scalar("fer_softmax_to_ref", fer)
    fer = tf.cast(tf.reduce_sum(tf.cast(tf.not_equal(tf.argmax(logits[:,0],axis=-1,output_type=tf.int32), viterbi[:,0]), tf.int32)), tf.float32) / tf.cast(num_frames, tf.float32)
    tf.summary.scalar("fer_softmax_to_viterbi", fer)
  assert isinstance(targets, (tf.Tensor, Fsa))
  loss = loss_func(
    logits=logits, logits_seq_lens=logits_seq_len, logits_time_major=True,
    targets=targets, targets_seq_lens=targets_seq_len)
  loss = tf.reduce_mean(loss)
  tf.summary.scalar("loss", loss)
  grads = tf.gradients(ys=[loss], xs=[logits] + params)
  logits_grad, param_grads = grads[0], grads[1:]
  # logits_grad == softmax(logits) - baum_welch
  baum_welch = tf.nn.softmax(logits) - logits_grad
  loss_bw = full_sum_loss(
    logits=safe_log(baum_welch), logits_seq_lens=logits_seq_len, logits_time_major=True,
    targets=targets, targets_seq_lens=targets_seq_len)
  tf.summary.scalar("loss_bw", tf.reduce_mean(loss_bw))
  assert len(params) == len(param_grads)
  for param, grad in zip(params, param_grads):
    assert grad is not None, "no grad for param %r?" % param
  grads_l2 = tf.reduce_sum([tf.nn.l2_loss(v) for v in param_grads])
  tf.summary.scalar("grad_norm", grads_l2)
  opt = opt_class(learning_rate=learning_rate)
  assert isinstance(opt, tf1.train.Optimizer)
  grads_and_vars = [(g, v) for (g, v) in zip(param_grads, params)]
  if grad_noise:
    grads_and_vars = add_scaled_noise_to_gradients(grads_and_vars, grad_noise)
  if scale_update_inv_param_size:
    max_num_elements = max([param.get_shape().num_elements() for param in params])
    grads_and_vars = [(g * (float(v.get_shape().num_elements()) / max_num_elements), v) for (g, v) in grads_and_vars]
  if update_exact:
    assert model_type == "mem"
    update_op = tf1.assign(mem, safe_log(tf.squeeze(baum_welch, axis=1)))
  else:
    update_op = opt.apply_gradients(grads_and_vars)
  with tf.control_dependencies([update_op]):
    update_op = tf1.assign_add(global_step, 1)
  return loss, logits, baum_welch, viterbi, update_op, params, param_grads
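A rough, hypothetical driver for the graph built by model (TF1 graph mode; the argument values are placeholders, and helpers like full_sum_loss are assumed to be defined elsewhere in the same script):

# Hypothetical usage; tf1 is assumed to be tf.compat.v1, matching its use above.
loss_t, logits_t, bw_t, viterbi_t, update_op, params, param_grads = model(
  num_classes=5, target_seq=[1, 2, 3], model_type="mem", num_frames=20)
with tf1.Session() as session:
  session.run(tf1.global_variables_initializer())
  for step in range(100):
    loss_v, _ = session.run((loss_t, update_op))
    print("step %i, loss %f" % (step, loss_v))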