Example #1
def mu_law_encode(audio, quantization_channels):
    '''Quantizes waveform amplitudes.'''
    with tf.name_scope('encode'):
        mu = tf.to_float(quantization_channels - 1)
        # Perform mu-law companding transformation (ITU-T, 1988).
        # Minimum operation is here to deal with rare large amplitudes caused
        # by resampling.
        safe_audio_abs = tf.minimum(tf.abs(audio), 1.0)
        magnitude = tf.log1p(mu * safe_audio_abs) / tf.log1p(mu)
        signal = tf.sign(audio) * magnitude
        # Quantize signal to the specified number of levels.
        # Adding 0.5 makes the int32 cast round to the nearest level.
        return tf.cast((signal + 1) / 2 * mu + 0.5, dtype=tf.int32)
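For reference, here is a sketch of the matching decoder. It is not part of the example above; the name mu_law_decode and its exact form are assumptions based on the usual companion function:

def mu_law_decode(output, quantization_channels):
    '''Recovers waveform amplitudes from quantized levels (inverse sketch).'''
    with tf.name_scope('decode'):
        mu = quantization_channels - 1
        # Map the integer levels back onto [-1, 1].
        signal = 2 * (tf.to_float(output) / mu) - 1
        # Invert the mu-law companding: x = sign(y) * ((1 + mu)^|y| - 1) / mu.
        magnitude = (1 / mu) * ((1 + mu)**abs(signal) - 1)
        return tf.sign(signal) * magnitude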
Example #2
def focal_loss(logits, targets, alpha, gamma, normalizer):
    """Compute the focal loss between `logits` and the golden `target` values.

  Focal loss = -(1 - p_t)^gamma * log(p_t)
  where p_t is the model's estimated probability of the true class.

  Args:
    logits: A float32 tensor of size [batch, height_in, width_in,
      num_predictions].
    targets: A float32 tensor of size [batch, height_in, width_in,
      num_predictions].
    alpha: A float32 scalar that weights the loss from positive examples by
      alpha and the loss from negative examples by (1 - alpha).
    gamma: A float32 scalar modulating loss from hard and easy examples.
    normalizer: A float32 scalar that normalizes the total loss from all
      examples.

  Returns:
    loss: A float32 scalar representing normalized total loss.
  """
    with tf.name_scope('focal_loss'):
        positive_label_mask = tf.equal(targets, 1.0)
        cross_entropy = (tf.nn.sigmoid_cross_entropy_with_logits(
            labels=targets, logits=logits))
        # Below are comments/derivations for computing the modulator.
        # For brevity, let x = logits, z = targets, r = gamma, and p_t = sigmoid(x)
        # for positive samples and 1 - sigmoid(x) for negative examples.
        #
        # The modulator, defined as (1 - p_t)^r, is a critical part of the focal
        # loss computation. For r > 0, it puts more weight on hard examples and
        # less weight on easy ones. However, if it is computed directly as
        # (1 - p_t)^r, its back-propagation is not stable when r < 1. The
        # implementation here resolves the issue.
        #
        # For positive samples (labels being 1),
        #    (1 - p_t)^r
        #  = (1 - sigmoid(x))^r
        #  = (1 - (1 / (1 + exp(-x))))^r
        #  = (exp(-x) / (1 + exp(-x)))^r
        #  = exp(log((exp(-x) / (1 + exp(-x)))^r))
        #  = exp(r * log(exp(-x)) - r * log(1 + exp(-x)))
        #  = exp(- r * x - r * log(1 + exp(-x)))
        #
        # For negative samples (labels being 0),
        #    (1 - p_t)^r
        #  = (sigmoid(x))^r
        #  = (1 / (1 + exp(-x)))^r
        #  = exp(log((1 / (1 + exp(-x)))^r))
        #  = exp(-r * log(1 + exp(-x)))
        #
        # Therefore one unified form for positive (z = 1) and negative (z = 0)
        # samples is:
        #      (1 - p_t)^r = exp(-r * z * x - r * log(1 + exp(-x))).
        neg_logits = -1.0 * logits
        modulator = tf.exp(gamma * targets * neg_logits -
                           gamma * tf.log1p(tf.exp(neg_logits)))
        loss = modulator * cross_entropy
        weighted_loss = tf.where(positive_label_mask, alpha * loss,
                                 (1.0 - alpha) * loss)
        weighted_loss /= normalizer
    return weighted_loss
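As a quick sanity check of the derivation in the comments, the stabilized modulator can be compared against the naive (1 - p_t)^r form in NumPy (an illustration, not part of the original):

import numpy as np

x = np.array([-3.0, -0.5, 0.5, 3.0])  # logits
z = np.array([1.0, 0.0, 1.0, 0.0])    # targets
r = 2.0                               # gamma
p = 1.0 / (1.0 + np.exp(-x))          # sigmoid(x)
p_t = np.where(z == 1.0, p, 1.0 - p)
naive = (1.0 - p_t)**r
stable = np.exp(-r * z * x - r * np.log1p(np.exp(-x)))
assert np.allclose(naive, stable)     # both forms agree numerically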
Example #3
  def _input_fn():

    if partition == "train":
      dataset = tf.data.Dataset.from_tensor_slices(({
          FEATURES_KEY: tf.log1p(x_train)
      }, tf.constant(y_train)))
    else:
      dataset = tf.data.Dataset.from_tensor_slices(({
          FEATURES_KEY: tf.log1p(x_test)
      }, tf.constant(y_test)))

    # We call repeat after shuffling, rather than before, to prevent separate
    # epochs from blending together.
    if training:
      dataset = dataset.shuffle(10 * batch_size, seed=RANDOM_SEED).repeat()

    dataset = dataset.batch(batch_size)
    iterator = dataset.make_one_shot_iterator()
    features, labels = iterator.get_next()
    return features, labels
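In TF1, the tensors returned by such an input function can be evaluated directly in a session; a minimal usage sketch (assuming _input_fn and the variables it closes over are in scope):

features, labels = _input_fn()
with tf.Session() as sess:
    f, l = sess.run([features, labels])
    # f[FEATURES_KEY] holds one batch of log1p-transformed features.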
Example #4
def legacy_focal_loss(logits, targets, alpha, gamma, normalizer, _):
  """A legacy focal loss that does not support label smooth."""
  with tf.name_scope('focal_loss'):
    positive_label_mask = tf.equal(targets, 1.0)
    cross_entropy = (
        tf.nn.sigmoid_cross_entropy_with_logits(labels=targets, logits=logits))

    neg_logits = -1.0 * logits
    modulator = tf.exp(gamma * targets * neg_logits -
                       gamma * tf.log1p(tf.exp(neg_logits)))
    loss = modulator * cross_entropy
    weighted_loss = tf.where(positive_label_mask, alpha * loss,
                             (1.0 - alpha) * loss)
    weighted_loss /= normalizer
  return weighted_loss
Example #5
def zero_truncated_log_poisson_loss(targets,
                                    log_input,
                                    compute_full_loss=False,
                                    name=None):
    """Calculate log-loss for a zero-truncated Poisson distribution.

  See tf.nn.log_poisson_loss for details and sanity checks.

  Args:
    targets: A `Tensor` of the same type and shape as `log_input`.
    log_input: A `Tensor` of type `float32` or `float64`.
    compute_full_loss: whether to compute the full loss. If false, a constant
      term is dropped in favor of more efficient optimization.
    name: optional name for this op.

  Returns:
    A `Tensor` of the same shape as `log_input` with the componentwise
    log-Poisson losses.
  """
    with tf.name_scope(name, 'ZeroTruncatedLogPoissonLoss',
                       [targets, log_input]) as scope:
        targets = tf.convert_to_tensor(targets, name='targets')
        log_input = tf.convert_to_tensor(log_input, name='log_input')

        input_ = tf.exp(log_input)
        zeros = tf.zeros_like(targets)

        # We switch to a different formula to avoid numerical stability problems
        # when log_input is small (see the go link above for details).
        approximate = log_input < -5.0
        # Use the nested-tf.where trick (go/tf-where-nan) to avoid NaNs in the
        # gradient. The value when approximate is True is arbitrary:
        input2 = tf.where(approximate, 1.0 + zeros, input_)
        result = tf.where(
            approximate, 0.5 * input_ - (targets - 1.0) * log_input,
            input_ - targets * log_input + tf.log1p(-tf.exp(-input2)))

        # The zero-truncated Poisson distribution isn't meaningfully defined for
        # targets less than one (i.e., zero):
        result = tf.where(targets < 1.0, np.nan + zeros, result)

        if compute_full_loss:
            # this is the same approximation used in tf.nn.log_poisson_loss
            stirling_approx = (targets * (tf.log(targets) - 1.0) +
                               0.5 * tf.log((2 * math.pi) * targets))
            result += tf.where(targets <= 1.0, zeros, stirling_approx)

        return tf.identity(result, name=scope)
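The non-approximate branch matches the zero-truncated Poisson negative log-likelihood up to the dropped log(k!) constant; a NumPy/SciPy cross-check (illustrative only, and assumes SciPy is available):

import numpy as np
from scipy import stats
from scipy.special import gammaln

lam, k = 2.0, 3.0
# Loss computed by the op above (compute_full_loss=False drops log k!):
nll = lam - k * np.log(lam) + np.log1p(-np.exp(-lam))
# Reference: negative log-pmf of a zero-truncated Poisson, minus log k!.
ref = -(stats.poisson.logpmf(k, lam) - np.log1p(-np.exp(-lam))) - gammaln(k + 1)
assert np.isclose(nll, ref)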
Example #6
def negative_dkl(variational_params=None,
                 clip_alpha=None,
                 eps=common.EPSILON,
                 log_alpha=None):
    R"""Compute the negative kl-divergence loss term.

  Computes the negative kl-divergence between the log-uniform prior over the
  weights and the variational posterior over the weights for each element
  in the set of variational parameters. Each contribution is summed and the
  sum is returned as a scalar Tensor.

  The true kl-divergence is intractable, so we compute the tight approximation
  from https://arxiv.org/abs/1701.05369.

  Args:
    variational_params: 2-tuple of Tensors, where the first tensor is the \theta
      values and the second contains the log of the \sigma^2 values.
    clip_alpha: Int or None. If integer, we clip the log \alpha values to
      [-clip_alpha, clip_alpha]. If None, don't clip the values.
    eps: Small constant value to use in log and sqrt operations to avoid NaNs.
    log_alpha: float32 tensor of log alpha values.
  Returns:
    Output scalar Tensor containing the sum of all negative kl-divergence
    contributions for each element in the input variational_params.

  Raises:
    RuntimeError: If the variational_params argument is not a 2-tuple.
  """

    if variational_params is not None:
        theta, log_sigma2 = _verify_variational_params(variational_params)

    if log_alpha is None:
        log_alpha = common.compute_log_alpha(log_sigma2, theta, eps,
                                             clip_alpha)

    # Constant values for approximating the kl divergence
    k1, k2, k3 = 0.63576, 1.8732, 1.48695
    c = -k1

    # Compute each term of the KL and combine
    term_1 = k1 * tf.nn.sigmoid(k2 + k3 * log_alpha)
    term_2 = -0.5 * tf.log1p(tf.exp(tf.negative(log_alpha)))
    eltwise_dkl = term_1 + term_2 + c
    return -tf.reduce_sum(eltwise_dkl)
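Written out, each elementwise contribution is the approximation -KL ~= k1*sigmoid(k2 + k3*log_alpha) - 0.5*log(1 + 1/alpha) - k1 from the paper linked above; a scalar NumPy sketch (illustrative only):

import numpy as np

def neg_dkl_approx(log_alpha):
    k1, k2, k3 = 0.63576, 1.8732, 1.48695
    sigmoid = lambda v: 1.0 / (1.0 + np.exp(-v))
    # log1p(exp(-log_alpha)) == log(1 + 1/alpha)
    return (k1 * sigmoid(k2 + k3 * log_alpha)
            - 0.5 * np.log1p(np.exp(-log_alpha)) - k1)

print(neg_dkl_approx(np.array([-5.0, 0.0, 5.0])))  # tends to 0 as alpha grows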
Example #7
def lossfn(real_input, fake_input, compress, hparams, lsgan, name):
    """Loss function."""
    eps = 1e-12
    with tf.variable_scope(name):
        d1 = discriminator(real_input, compress, hparams, "discriminator")
        d2 = discriminator(fake_input,
                           compress,
                           hparams,
                           "discriminator",
                           reuse=True)
        if lsgan:
            dloss = tf.reduce_mean(tf.squared_difference(
                d1, 0.9)) + tf.reduce_mean(tf.square(d2))
            gloss = tf.reduce_mean(tf.squared_difference(d2, 0.9))
            loss = (dloss + gloss) / 2
        else:  # cross_entropy
            # log1p(eps - d2) == log(1 - d2 + eps), the usual GAN fake term.
            dloss = -tf.reduce_mean(tf.log(d1 + eps)) - tf.reduce_mean(
                tf.log1p(eps - d2))
            gloss = -tf.reduce_mean(tf.log(d2 + eps))
            loss = (dloss + gloss) / 2
        return loss
Example #8
def focal_loss(logits, targets, alpha, gamma, normalizer):
  with tf.name_scope('focal_loss'):
    positive_label_mask = tf.equal(targets, 1.0)
    cross_entropy = (
        tf.nn.sigmoid_cross_entropy_with_logits(labels=targets, logits=logits))
    # Below are comments/derivations for computing the modulator.
    # For brevity, let x = logits, z = targets, r = gamma, and p_t = sigmoid(x)
    # for positive samples and 1 - sigmoid(x) for negative examples.
    #
    # The modulator, defined as (1 - p_t)^r, is a critical part of the focal
    # loss computation. For r > 0, it puts more weight on hard examples and
    # less weight on easy ones. However, if it is computed directly as
    # (1 - p_t)^r, its back-propagation is not stable when r < 1. The
    # implementation here resolves the issue.
    #
    # For positive samples (labels being 1),
    #    (1 - p_t)^r
    #  = (1 - sigmoid(x))^r
    #  = (1 - (1 / (1 + exp(-x))))^r
    #  = (exp(-x) / (1 + exp(-x)))^r
    #  = exp(log((exp(-x) / (1 + exp(-x)))^r))
    #  = exp(r * log(exp(-x)) - r * log(1 + exp(-x)))
    #  = exp(- r * x - r * log(1 + exp(-x)))
    #
    # For negative samples (labels being 0),
    #    (1 - p_t)^r
    #  = (sigmoid(x))^r
    #  = (1 / (1 + exp(-x)))^r
    #  = exp(log((1 / (1 + exp(-x)))^r))
    #  = exp(-r * log(1 + exp(-x)))
    #
    # Therefore one unified form for positive (z = 1) and negative (z = 0)
    # samples is:
    #      (1 - p_t)^r = exp(-r * z * x - r * log(1 + exp(-x))).
    neg_logits = -1.0 * logits
    modulator = tf.exp(gamma * targets * neg_logits -
                       gamma * tf.log1p(tf.exp(neg_logits)))
    loss = modulator * cross_entropy
    weighted_loss = tf.where(positive_label_mask, alpha * loss,
                             (1.0 - alpha) * loss)
    weighted_loss /= normalizer
  return weighted_loss
Example #9
def weighted_ce(logits, labels, lbda=1.1):
    """Weighted cross entropy.
  Args:
    logits: Tensor of predictions.
    labels: Tensor of ground truth (same shape as logits).
    lbda: Scalar weight applied to the positive-class term.
  Returns:
    weighted_cross_entropy(logits, labels)
  L = -beta*y*log(p) - (1-beta)*(1-y)*log(1-p)
    = -beta*y*log(1/(1+exp(-z))) - (1-beta)*(1-y)*log(exp(-z)/(1+exp(-z)))
    = beta*y*log(1+exp(-z)) - (1-beta)*(1-y)*(log(exp(-z)) - log(1+exp(-z)))
    = beta*y*q + z*(1-beta)*(1-y) + (1-beta)*(1-y)*q   [q=log(1+exp(-z))]
    = beta*y*q + (1-beta)*(1-y)*(z+q)
  """
    beta_true = tf.reduce_mean(
        labels, axis=(1, 2, 3),
        keepdims=True)  # num_true / (num_true + num_false)
    beta_false = 1. - beta_true
    q = tf.log1p(tf.exp(-logits))
    xentropy = lbda * beta_true * labels * q + beta_false * (1 - labels) * (
        logits + q)
    xentropy = tf.reduce_mean(xentropy)
    return xentropy
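The algebra in the docstring can be verified numerically (an illustration, not part of the original):

import numpy as np

z = np.array([-2.0, 0.5, 3.0])  # logits
y = np.array([1.0, 0.0, 1.0])   # labels
beta = 0.7
p = 1.0 / (1.0 + np.exp(-z))
direct = -beta * y * np.log(p) - (1 - beta) * (1 - y) * np.log(1 - p)
q = np.log1p(np.exp(-z))
rearranged = beta * y * q + (1 - beta) * (1 - y) * (z + q)
assert np.allclose(direct, rearranged)  # the docstring derivation holds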
Example #10
def model_fn(features, labels, mode, params, config):
  """Builds the model function for use in an estimator.

  Arguments:
    features: The input features for the estimator.
    labels: The labels, unused here.
    mode: Signifies whether it is train or test or predict.
    params: Some parameters, unused here.
    config: The RunConfig, unused here.

  Returns:
    EstimatorSpec: A tf.estimator.EstimatorSpec instance.
  """
  del labels, params, config

  if FLAGS.analytic_kl and FLAGS.mixture_components != 1:
    raise NotImplementedError(
        "Using `analytic_kl` is only supported when `mixture_components = 1` "
        "since there's no closed form otherwise.")
  if FLAGS.floating_prior and not (FLAGS.unit_posterior and
                                   FLAGS.mixture_components == 1):
    raise NotImplementedError(
        "Using `floating_prior` is only supported when `unit_posterior` = True "
        "since there's a scale ambiguity otherwise, and when "
        "`mixture_components = 1` since there's no closed form otherwise.")
  if FLAGS.fitted_samples and FLAGS.mixture_components != 1:
    raise NotImplementedError(
        "Using `fitted_samples` is only supported when "
        "`mixture_components = 1` since there's no closed form otherwise.")
  if FLAGS.bilbo and not FLAGS.floating_prior:
    raise NotImplementedError(
        "Using `bilbo` is only supported when `floating_prior = True`.")

  activation = tf.nn.leaky_relu
  encoder = make_encoder(activation, FLAGS.latent_size, FLAGS.base_depth)
  decoder = make_decoder(activation, FLAGS.latent_size, [IMAGE_SIZE] * 2 + [3],
                         FLAGS.base_depth)

  approx_posterior = encoder(features)
  approx_posterior_sample = approx_posterior.sample(FLAGS.n_samples)
  decoder_mu = decoder(approx_posterior_sample)

  if FLAGS.floating_prior or FLAGS.fitted_samples:
    posterior_batch_mean = tf.reduce_mean(approx_posterior.mean()**2, [0])
    posterior_batch_variance = tf.reduce_mean(approx_posterior.stddev()**2, [0])
    posterior_scale = posterior_batch_mean + posterior_batch_variance
    floating_prior = tfd.MultivariateNormalDiag(
        tf.zeros(FLAGS.latent_size), tf.sqrt(posterior_scale))
    tf.summary.scalar("posterior_scale", tf.reduce_sum(posterior_scale))

  if FLAGS.floating_prior:
    latent_prior = floating_prior
  else:
    latent_prior = make_mixture_prior(FLAGS.latent_size,
                                      FLAGS.mixture_components)

  # Decode samples from the prior for visualization.
  if FLAGS.fitted_samples:
    sample_distribution = floating_prior
  else:
    sample_distribution = latent_prior

  n_samples = VIZ_GRID_SIZE**2
  random_mu = decoder(sample_distribution.sample(n_samples))

  residual = tf.reshape(features - decoder_mu, [-1] + [IMAGE_SIZE] * 2 + [3])

  if FLAGS.use_students_t:
    nll = adaptive.image_lossfun(
        residual,
        color_space=FLAGS.color_space,
        representation=FLAGS.representation,
        wavelet_num_levels=FLAGS.wavelet_num_levels,
        wavelet_scale_base=FLAGS.wavelet_scale_base,
        use_students_t=FLAGS.use_students_t,
        scale_lo=FLAGS.scale_lo,
        scale_init=FLAGS.scale_init)[0]
  else:
    nll = adaptive.image_lossfun(
        residual,
        color_space=FLAGS.color_space,
        representation=FLAGS.representation,
        wavelet_num_levels=FLAGS.wavelet_num_levels,
        wavelet_scale_base=FLAGS.wavelet_scale_base,
        use_students_t=FLAGS.use_students_t,
        alpha_lo=FLAGS.alpha_lo,
        alpha_hi=FLAGS.alpha_hi,
        alpha_init=FLAGS.alpha_init,
        scale_lo=FLAGS.scale_lo,
        scale_init=FLAGS.scale_init)[0]

  nll = tf.reshape(nll, [tf.shape(decoder_mu)[0],
                         tf.shape(decoder_mu)[1]] + [IMAGE_SIZE] * 2 + [3])

  # Clipping to prevent the loss from nanning out.
  max_val = np.finfo(np.float32).max
  nll = tf.clip_by_value(nll, -max_val, max_val)

  viz_n_inputs = np.int32(np.minimum(VIZ_MAX_N_INPUTS, FLAGS.batch_size))
  viz_n_samples = np.int32(np.minimum(VIZ_MAX_N_SAMPLES, FLAGS.n_samples))

  image_tile_summary("input", tf.to_float(features), rows=1, cols=viz_n_inputs)

  image_tile_summary(
      "recon/mean",
      decoder_mu[:viz_n_samples, :viz_n_inputs],
      rows=viz_n_samples,
      cols=viz_n_inputs)

  img_summary_input = image_tile_summary(
      "input1", tf.to_float(features), rows=viz_n_inputs, cols=1)
  img_summary_recon = image_tile_summary(
      "recon1", decoder_mu[:1, :viz_n_inputs], rows=viz_n_inputs, cols=1)

  image_tile_summary(
      "random/mean", random_mu, rows=VIZ_GRID_SIZE, cols=VIZ_GRID_SIZE)

  distortion = tf.reduce_sum(nll, axis=[2, 3, 4])

  avg_distortion = tf.reduce_mean(distortion)
  tf.summary.scalar("distortion", avg_distortion)

  if FLAGS.analytic_kl:
    rate = tfd.kl_divergence(approx_posterior, latent_prior)
  else:
    rate = (
        approx_posterior.log_prob(approx_posterior_sample) -
        latent_prior.log_prob(approx_posterior_sample))
  avg_rate = tf.reduce_mean(rate)
  tf.summary.scalar("rate", avg_rate)

  elbo_local = -(rate + distortion)

  elbo = tf.reduce_mean(elbo_local)
  tf.summary.scalar("elbo", elbo)

  if FLAGS.bilbo:
    bilbo = -0.5 * tf.reduce_sum(
        tf.log1p(
            posterior_batch_mean / posterior_batch_variance)) - avg_distortion
    tf.summary.scalar("bilbo", bilbo)
    loss = -bilbo
  else:
    loss = -elbo

  importance_weighted_elbo = tf.reduce_mean(
      tf.reduce_logsumexp(elbo_local, axis=0) -
      tf.math.log(tf.to_float(FLAGS.n_samples)))
  tf.summary.scalar("elbo/importance_weighted", importance_weighted_elbo)

  # Perform variational inference by minimizing the -ELBO.
  global_step = tf.train.get_or_create_global_step()
  learning_rate = tf.train.cosine_decay(
      FLAGS.learning_rate,
      tf.maximum(
          tf.cast(0, tf.int64),
          global_step - int(FLAGS.decay_start * FLAGS.max_steps)),
      int((1. - FLAGS.decay_start) * FLAGS.max_steps))
  tf.summary.scalar("learning_rate", learning_rate)
  optimizer = tf.train.AdamOptimizer(learning_rate)

  if mode == tf.estimator.ModeKeys.TRAIN:
    train_op = optimizer.minimize(loss, global_step=global_step)
  else:
    train_op = None

  eval_metric_ops = {}
  eval_metric_ops["elbo"] = tf.metrics.mean(elbo)
  eval_metric_ops["elbo/importance_weighted"] = tf.metrics.mean(
      importance_weighted_elbo)
  eval_metric_ops["rate"] = tf.metrics.mean(avg_rate)
  eval_metric_ops["distortion"] = tf.metrics.mean(avg_distortion)
  # This ugly hackery is necessary to get TF to visualize when running the
  # eval set, apparently.
  eval_metric_ops["img_summary_input"] = (img_summary_input, tf.no_op())
  eval_metric_ops["img_summary_recon"] = (img_summary_recon, tf.no_op())
  eval_metric_ops = {str(k): v for k, v in eval_metric_ops.items()}

  return tf.estimator.EstimatorSpec(
      mode=mode,
      loss=loss,
      train_op=train_op,
      eval_metric_ops=eval_metric_ops,
  )
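The importance-weighted bound above is a log-mean-exp over the posterior samples; by Jensen's inequality it is never below the plain ELBO, which a tiny NumPy sketch illustrates (toy numbers, not from the model; assumes SciPy):

import numpy as np
from scipy.special import logsumexp

elbo_local = np.random.randn(8, 4)  # toy [n_samples, batch] values
iw = logsumexp(elbo_local, axis=0) - np.log(elbo_local.shape[0])
assert np.all(iw >= elbo_local.mean(axis=0) - 1e-9)  # IWAE bound >= ELBO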
Example #11
    def build(self):
        self.audios = tf.placeholder(tf.float32,
                                     [self.batch_size, self.n_speaker, None],
                                     name='input_signals')
        self.mix_input = tf.reduce_sum(self.audios, axis=1)

        with tf.variable_scope("encoder"):
            # [batch, encode_len, channels]
            encoded_input = tf.layers.Conv1D(
                filters=self.config["model"]["filters"]["ae"],
                kernel_size=self.fft_len,
                strides=self.fft_hop,
                activation=tf.nn.relu,
                name="conv1d_relu")(tf.expand_dims(self.mix_input, -1))

        stfts_mix = tf.signal.stft(self.mix_input,
                                   frame_length=self.fft_len,
                                   frame_step=self.fft_hop,
                                   fft_length=self.fft_len,
                                   window_fn=self.fft_wnd)
        magni_mix = tf.abs(stfts_mix)
        phase_mix = tf.atan2(tf.imag(stfts_mix), tf.real(stfts_mix))

        with tf.variable_scope("bottle_start"):
            norm_input = self.cLN(
                tf.concat([encoded_input, tf.log1p(magni_mix)], axis=-1),
                "layer_norm")
            block_input = tf.layers.Conv1D(
                filters=self.config["model"]["filters"]["1*1-conv"],
                kernel_size=1)(norm_input)

        for stack_i in range(self.num_stacks):
            for dilation in self.dilations:
                with tf.variable_scope("conv_block_{}_{}".format(
                        stack_i, dilation)):
                    block_output = tf.layers.Conv1D(
                        filters=self.config["model"]["filters"]["d-conv"],
                        kernel_size=1)(block_input)
                    block_output = self.prelu(block_output,
                                              name='1st-prelu',
                                              shared_axes=[1])
                    block_output = self.gLN(block_output, "first")
                    block_output = self._depthwise_conv1d(
                        block_output, dilation)
                    block_output = self.prelu(block_output,
                                              name='2nd-prelu',
                                              shared_axes=[1])
                    block_output = self.gLN(block_output, "second")
                    block_output = tf.layers.Conv1D(
                        filters=self.config["model"]["filters"]["1*1-conv"],
                        kernel_size=1)(block_output)
                    block_input += block_output

        if self.output_ratio == 1:
            embed_channel = self.config["model"]["filters"]["ae"]
            feature_map = encoded_input
        elif self.output_ratio == 0:
            embed_channel = self.stft_ch
            feature_map = magni_mix
        else:
            embed_channel = self.concat_channels
            feature_map = tf.concat([encoded_input, magni_mix], axis=-1)

        with tf.variable_scope('separator'):
            s_embed = tf.layers.Dense(
                embed_channel *
                self.config["model"]["embed_size"])(block_input)
            s_embed = tf.reshape(s_embed, [
                self.batch_size, -1, embed_channel,
                self.config["model"]["embed_size"]
            ])

            # Estimate attractor from best combination from anchors
            v_anchors = tf.get_variable(
                'anchors', [self.n_anchor, self.config["model"]["embed_size"]],
                dtype=tf.float32)
            c_combs = tf.constant(list(
                itertools.combinations(range(self.n_anchor), self.n_speaker)),
                                  name='combs')
            s_anchor_sets = tf.gather(v_anchors, c_combs)

            s_anchor_assignment = tf.einsum('btfe,pce->bptfc', s_embed,
                                            s_anchor_sets)
            s_anchor_assignment = tf.nn.softmax(s_anchor_assignment)

            s_attractor_sets = tf.einsum('bptfc,btfe->bpce',
                                         s_anchor_assignment, s_embed)
            s_attractor_sets /= tf.expand_dims(
                tf.reduce_sum(s_anchor_assignment, axis=(2, 3)), -1)

            sp = tf.matmul(s_attractor_sets,
                           tf.transpose(s_attractor_sets, [0, 1, 3, 2]))
            diag = tf.fill(sp.shape[:-1], float("-inf"))
            sp = tf.linalg.set_diag(sp, diag)

            s_in_set_similarities = tf.reduce_max(sp, axis=(-1, -2))

            s_subset_choice = tf.argmin(s_in_set_similarities, axis=1)
            s_subset_choice_nd = tf.transpose(
                tf.stack([
                    tf.range(self.batch_size, dtype=tf.int64), s_subset_choice
                ]))
            s_attractors = tf.gather_nd(s_attractor_sets, s_subset_choice_nd)

            s_logits = tf.einsum('btfe,bce->bctf', s_embed, s_attractors)
            output_code = s_logits * tf.expand_dims(feature_map, 1)

        with tf.variable_scope("decoder"):
            conv_out = pred_istfts = 0
            if self.output_ratio != 0:
                output_frame = tf.layers.Dense(
                    self.config["model"]["kernel_size"]["ae"])(output_code[
                        ..., :self.config["model"]["filters"]["ae"]])
                conv_out = tf.signal.overlap_and_add(signal=output_frame,
                                                     frame_step=self.fft_hop)

            if self.output_ratio != 1:
                phase_mix_expand = tf.expand_dims(phase_mix, 1)
                pred_stfts = tf.complex(
                    tf.cos(phase_mix_expand) *
                    output_code[..., -self.stft_ch:],
                    tf.sin(phase_mix_expand) *
                    output_code[..., -self.stft_ch:])
                pred_istfts = tf.signal.inverse_stft(
                    pred_stfts,
                    frame_length=self.fft_len,
                    frame_step=self.fft_hop,
                    fft_length=self.fft_len,
                    window_fn=tf.signal.inverse_stft_window_fn(
                        self.fft_hop, forward_window_fn=self.fft_wnd))

            self.data_out = conv_out * self.output_ratio + pred_istfts * (
                1 - self.output_ratio)

        self.loss, self.pred_output, self.sdr, self.perm_idxs = loss.pit_loss(
            self.audios, self.data_out, self.config, self.batch_size,
            self.n_speaker, self.n_output)

        ### Fixed loss not implemented yet; the same pit_loss is reused below. ###
        self.loss_fix, self.pred_output_fix, self.sdr_fix, self.perm_idxs_fix = loss.pit_loss(
            self.audios, self.data_out, self.config, self.batch_size,
            self.n_speaker, self.n_output)
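The anchor-combination selection in the separator scope above picks, per batch item, the anchor subset whose members are least similar to one another; a small NumPy sketch of the same selection rule (illustrative, with toy shapes):

import itertools
import numpy as np

n_anchor, n_speaker, embed = 4, 2, 3
anchors = np.random.randn(n_anchor, embed)
combs = list(itertools.combinations(range(n_anchor), n_speaker))  # C(4, 2) = 6
sets = anchors[np.array(combs)]        # [6, n_speaker, embed]
sim = sets @ sets.transpose(0, 2, 1)   # pairwise dot products within each subset
for s in sim:                          # mask self-similarity, as the graph does
    np.fill_diagonal(s, -np.inf)       # with tf.linalg.set_diag and -inf
best = int(np.argmin(sim.max(axis=(1, 2))))
print(combs[best])                     # chosen anchor subset indices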