Example #1
    def backward_impl(self, inputs, outputs, prop_down, accum):
        # inputs: [inputs_fwd_graph] + [inputs_bwd_graph] or
        # [inputs_fwd_graph] + [outputs_fwd_graph] + [inputs_bwd_graph]

        #raise NotImplementedError("The backward method of BinaryCrossEntropyBackward class is not implemented.")

        # Inputs
        x0 = inputs[0].data  # probabilities
        t0 = inputs[1].data  # labels
        dz = inputs[2].data  # grad_input
        # Outputs
        dx0 = outputs[0].data
        dt0 = outputs[1].data
        # Grads of inputs
        g_x0 = inputs[0].grad
        g_t0 = inputs[1].grad
        g_dz = inputs[2].grad
        # Grads of outputs
        g_dx0 = outputs[0].grad
        g_dt0 = outputs[1].grad

        # Computation
        ## w.r.t. x0
        if prop_down[0]:
            u0 = g_dx0 * (t0 / x0 ** 2.0 + (1.0 - t0) / (1 - x0) ** 2.0)
            u1 = g_dt0 / (x0 * (1.0 - x0))
            g_x0_ = dz * (u0 - u1)
            if accum[0]:
                g_x0 += g_x0_
            else:
                g_x0.copy_from(g_x0_)

        ## w.r.t. t0
        if prop_down[1]:
            #g_t0_ = g_dx0 * dz * (1.0 / x0 + 1.0 / (1.0 - x0))
            g_t0_ = g_dx0 * dz / (x0 * (1.0 - x0))
            if accum[1]:
                g_t0 -= g_t0_
            else:
                g_t0.copy_from(-g_t0_)

        ## w.r.t. dz
        if prop_down[2]:
            #u0 = g_dx0 * ((1.0 - t0) / (1.0 - x0) - t0 / x0)
            u0 = g_dx0 * (x0 - t0) / (x0 * (1.0 - x0))
            u1 = g_dt0 * (F.log(1.0 - x0) - F.log(x0))
            g_dz_ = u0 + u1
            if accum[2]:
                g_dz += g_dz_
            else:
                g_dz.copy_from(g_dz_)
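
For reference, here is a sketch of the algebra behind these terms (an editorial addition, assuming the forward function is the elementwise binary cross entropy y = -(t*log(x) + (1-t)*log(1-x)) and dz is the incoming gradient with respect to y):

    dx = dz * (x - t) / (x * (1 - x))
    dt = dz * (log(1 - x) - log(x))
    d(dx)/dx = dz * (t / x**2 + (1 - t) / (1 - x)**2)    # equals dz * u0 above
    d(dt)/dx = -dz / (x * (1 - x))                       # equals -dz * u1 above
    d(dx)/dt = -dz / (x * (1 - x))                       # why g_t0_ enters with a minus sign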
Example #2
def kl_divergence(ctx, pred, label, log_var):
    with nn.context_scope(ctx):
        s = F.pow_scalar(F.exp(log_var), 0.5)
        elms = softmax_with_temperature(ctx, label, s) \
               * F.log(F.softmax(pred, axis=1))
        loss = -F.mean(F.sum(elms, axis=1))
    return loss
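
A brief note on the formula above (an editorial reading of the code, assuming softmax_with_temperature(ctx, label, s) returns softmax(label / s)): the loss is

    -mean_batch( sum_c q_c * log p_c ),  with q = softmax(label / s), p = softmax(pred),

i.e. the cross entropy against a temperature-softened target. It differs from the full KL(q || p) only by the entropy of q, which does not depend on pred and therefore does not affect the gradients.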
Example #3
    def backward_impl(self, inputs, outputs, prop_down, accum):
        # inputs: [inputs_fwd_graph] + [inputs_bwd_graph] or
        # [inputs_fwd_graph] + [outputs_fwd_graph] + [inputs_bwd_graph]

        # Inputs
        x0 = inputs[0].data
        dy = inputs[1].data
        # Outputs
        dx0 = outputs[0].data
        # Grads of inputs
        g_x0 = inputs[0].grad
        g_dy = inputs[1].grad
        # Grads of outputs
        g_dx0 = outputs[0].grad

        # Compute
        val = self.forward_func.info.args["val"]
        if prop_down[0] or prop_down[1]:
            cv = F.constant(val, x0.shape)
            if not nn.get_auto_forward():
                cv.forward()
            log_v = F.log(cv.data)
        if prop_down[0]:
            if accum[0]:
                g_x0 += g_dx0 * dy * F.r_pow_scalar(x0, val) * log_v**2.0
            else:
                g_x0.copy_from(g_dx0 * dy * F.r_pow_scalar(x0, val) *
                               log_v**2.0)
        if prop_down[1]:
            if accum[1]:
                g_dy += g_dx0 * F.r_pow_scalar(x0, val) * log_v
            else:
                g_dy.copy_from(g_dx0 * F.r_pow_scalar(x0, val) * log_v)
Example #4
    def warp_coordinates(self, coordinates):
        theta = self.theta
        theta = F.reshape(
            theta, theta.shape[:1] + (1,) + theta.shape[1:], inplace=False)
        if coordinates.shape[0] == self.bs:
            transformed = F.batch_matmul(
                            F.tile(theta[:, :, :, :2],
                                   (1, coordinates.shape[1], 1, 1)),
                            F.reshape(coordinates, coordinates.shape + (1,), inplace=False)) + theta[:, :, :, 2:]
        else:
            transformed = F.batch_matmul(
                            F.tile(theta[:, :, :, :2],
                                   (1, coordinates.shape[1], 1, 1)),
                            F.tile(F.reshape(coordinates, coordinates.shape + (1,), inplace=False),
                                   (self.bs // coordinates.shape[0], 1, 1, 1))) + theta[:, :, :, 2:]
        transformed = F.reshape(
            transformed, transformed.shape[:-1], inplace=False)

        if self.tps:
            control_points = self.control_points
            control_params = self.control_params
            distances = F.reshape(
                coordinates, (coordinates.shape[0], -1, 1, 2), inplace=False) - F.reshape(control_points, (1, 1, -1, 2))
            distances = F.sum(F.abs(distances), axis=distances.ndim - 1)

            result = distances ** 2
            result = result * F.log(distances + 1e-6)
            result = result * control_params
            result = F.sum(result, axis=2)
            result = F.reshape(
                result, (self.bs, coordinates.shape[1], 1), inplace=False)
            transformed = transformed + result

        return transformed
Example #6
def gaussian_log_likelihood(x, mean, logstd, orig_max_val=255):
    """
    Compute the log-likelihood of a Gaussian distribution for given data `x`.

    Args:
        x (nn.Variable): Target data. The values are assumed to be rescaled to [-1, 1]
                         from the original range [0, orig_max_val].
        mean (nn.Variable): Gaussian mean. Must be the same shape as x.
        logstd (nn.Variable): Gaussian log standard deviation. Must be the same shape as x.
        orig_max_val (int): The maximum value that x originally has before being rescaled.

    Return:
        Log probabilities of x in nats.
    """
    assert x.shape == mean.shape == logstd.shape
    centered_x = x - mean
    inv_std = F.exp(-logstd)
    half_bin = 1.0 / orig_max_val

    def clamp(val):
        # Here we don't need to clip max
        return F.clip_by_value(val, min=1e-12, max=1e8)

    # x + 0.5 (in original scale)
    plus_in = inv_std * (centered_x + half_bin)
    cdf_plus = approx_standard_normal_cdf(plus_in)
    log_cdf_plus = F.log(clamp(cdf_plus))

    # x - 0.5 (in original scale)
    minus_in = inv_std * (centered_x - half_bin)
    cdf_minus = approx_standard_normal_cdf(minus_in)
    log_one_minus_cdf_minus = F.log(clamp(1.0 - cdf_minus))

    log_cdf_delta = F.log(clamp(cdf_plus - cdf_minus))

    log_probs = F.where(
        F.less_scalar(x, -0.999),
        log_cdf_plus,  # Edge case for 0. It uses cdf for -inf as cdf_minus.
        F.where(F.greater_scalar(x, 0.999),
                # Edge case for orig_max_val. It uses cdf for +inf as cdf_plus.
                log_one_minus_cdf_minus,
                log_cdf_delta  # otherwise
                )
    )

    assert log_probs.shape == x.shape
    return log_probs
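
The helper approx_standard_normal_cdf is not part of this snippet. A minimal sketch of the usual tanh-based approximation (whether the original repository uses exactly this form is an assumption) would be:

import numpy as np
import nnabla.functions as F

def approx_standard_normal_cdf(x):
    # Tanh-based approximation of the standard normal CDF, commonly used for
    # discretized Gaussian likelihoods.
    return 0.5 * (1.0 + F.tanh(np.sqrt(2.0 / np.pi) *
                               (x + 0.044715 * F.pow_scalar(x, 3))))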
Example #7
def er_loss(ctx, pred):
    with nn.context_scope(ctx):
        bs = pred.shape[0]
        d = np.prod(pred.shape[1:])
        denominator = bs * d
        pred_normalized = F.softmax(pred)
        pred_log_normalized = F.log(F.softmax(pred))
        loss_er = - F.sum(pred_normalized * pred_log_normalized) / denominator
    return loss_er
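
A minimal usage sketch (the context construction below is an assumption; any nnabla extension context works):

import numpy as np
import nnabla as nn
from nnabla.ext_utils import get_extension_context

ctx = get_extension_context('cpu')
pred = nn.Variable.from_numpy_array(np.random.randn(8, 10).astype(np.float32))
loss = er_loss(ctx, pred)  # average entropy of the softmax predictions
loss.forward()
print(loss.d)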
Example #8
def er_loss(ctx, pred):
    with nn.context_scope(ctx):
        bs = pred.shape[0]
        d = np.prod(pred.shape[1:])
        denominator = bs * d
        pred_normalized = F.softmax(pred)
        pred_log_normalized = F.log(F.softmax(pred))
        loss_er = -F.sum(pred_normalized * pred_log_normalized) / denominator
    return loss_er
def net(n_class,
        xs,
        xq,
        init_type='nnabla',
        embedding='conv4',
        net_type='prototypical',
        distance='euclid',
        test=False):
    '''
    Similarity net function
        This function implements the network with settings as specified.

        Args:
            n_class (int): number of classes. Typical setting is 5 or 20.
            xs (~nnabla.Variable): support images.
            xq (~nnabla.Variable): query images.
            init_type (str, optional): initialization type for weights and bias parameters. See conv_initializer function.
            embedding (str, optional): embedding network.
            net_type (str, optional): network type, either 'prototypical' or 'matching'.
            distance (str, optional): similarity metric to use. See similarity function.
            test (bool, optional): switch flag for training dataset and test dataset.
        Returns:
            h (~nnabla.Variable): output variable indicating similarity between support and query.
    '''

    # feature embedding for supports and queries
    n_shot = xs.shape[0] // n_class
    n_query = xq.shape[0] // n_class
    if embedding == 'conv4':
        fs = conv4(xs, test, init_type)  # tensor of (n_support, fdim)
        fq = conv4(xq, test, init_type)  # tensor of (n_query, fdim)

    if net_type == 'matching':
        # This example does not include the full-context-embedding of matching networks.
        fs = F.reshape(fs, (1, ) + fs.shape)  # (1, n_way, fdim)
        # (n_way*n_query, 1, fdim)
        fq = F.reshape(fq, (fq.shape[0], 1) + fq.shape[1:])
        h = similarity(fq, fs, distance)
        h = h - F.mean(h, axis=1, keepdims=True)
        if 1 < n_shot:
            h = F.minimum_scalar(F.maximum_scalar(h, -35), 35)
            h = F.softmax(h)
            h = F.reshape(h, (h.shape[0], n_class, n_shot))
            h = F.mean(h, axis=2)
            # Reverse to logit to use same softmax cross entropy
            h = F.log(h)
    elif net_type == 'prototypical':
        if 1 < n_shot:
            fs = F.reshape(fs, (n_class, n_shot) + fs.shape[1:])
            fs = F.mean(fs, axis=1)
        fs = F.reshape(fs, (1, ) + fs.shape)  # (1, n_way, fdim)
        # (n_way*n_query, 1, fdim)
        fq = F.reshape(fq, (fq.shape[0], 1) + fq.shape[1:])
        h = similarity(fq, fs, distance)
        h = h - F.mean(h, axis=1, keepdims=True)

    return h
Example #10
def sr_loss_with_uncertainty(ctx, pred0, pred1, log_var0, log_var1):
    var0 = F.exp(log_var0)
    var1 = F.exp(log_var1)
    s0 = F.pow_scalar(var0, 0.5)
    s1 = F.pow_scalar(var1, 0.5)
    squared_error = F.squared_error(pred0, pred1)
    with nn.context_scope(ctx):
        loss = F.log(s1/s0) + (var0/var1 + squared_error/var1) * 0.5
        loss_sr = F.mean(loss)
    return loss_sr
Example #11
	def forward(self, x):
		N, C, H, W = x.shape

		log_abs = F.log(F.abs(self.scale))
		logdet = H*W*F.sum(log_abs)

		if self.logdet:
			return self.scale * (x + self.loc), logdet
		else:
			return self.scale * (x + self.loc)
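
A short note on the logdet term (an editorial addition, assuming scale holds one parameter per channel, as in ActNorm): for an elementwise affine transform y = scale * (x + loc) applied to an (N, C, H, W) input, the per-sample log-determinant of the Jacobian is

    logdet = H * W * sum_c log|scale_c|

which is exactly what H*W*F.sum(log_abs) computes above.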
Example #12
def log_spectrogram(wave, window_size):
    r"""Return log spectrogram.

    Args:
        wave (nn.Variable): Input waveform of shape (B, 1, L).
        window_size (int): Window size.

    Returns:
        nn.Variable: Log spectrogram.
    """
    linear = spectrogram(wave, window_size)
    return F.log(linear * 1e4 + 1.0)
Example #13
    def compute_mel(self, wave):
        hp = self.hparams
        reals, imags = F.stft(wave,
                              window_size=hp.win_length,
                              stride=hp.hop_length,
                              fft_size=hp.n_fft)
        linear = F.pow_scalar(
            F.add2(F.pow_scalar(reals, 2), F.pow_scalar(imags, 2)), 0.5)
        mels = F.batch_matmul(self.basis, linear)
        mels = F.log(F.clip_by_value(mels, 1e-5,
                                     np.inf)).apply(need_grad=False)
        return mels
Example #14
File: pow2.py Project: sony/nnabla
def pow2_backward(inputs, inplace=False):  # Inplacing is obsoleted.
    """
    Args:
      inputs (list of nn.Variable): Incoming grads/inputs to/of the forward function.
      kwargs (dict of arguments): Dictionary of the corresponding function arguments.

    Return:
      list of Variable: Return the gradients wrt inputs of the corresponding function.
    """
    dy = inputs[0]
    x0 = inputs[1]
    x1 = inputs[2]
    dx0 = dy * x1 * x0**(x1 - 1)
    dx1 = dy * (x0**x1) * F.log(x0)
    return dx0, dx1
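
A minimal usage sketch (shapes and values below are illustrative assumptions):

import numpy as np
import nnabla as nn
import nnabla.functions as F

x0 = nn.Variable.from_numpy_array(np.random.rand(2, 3).astype(np.float32) + 0.5)  # keep x0 > 0 for F.log
x1 = nn.Variable.from_numpy_array(np.random.rand(2, 3).astype(np.float32))
dy = nn.Variable.from_numpy_array(np.ones((2, 3), dtype=np.float32))

dx0, dx1 = pow2_backward([dy, x0, x1])  # gradients of y = x0 ** x1 w.r.t. x0 and x1
F.sink(dx0, dx1).forward()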
Example #15
def loss(target_action, target_action_type, target_action_mask, rule_prob,
         terminal_gen_action_prob, token_prob, copy_prob):
    batch_size, max_action_length, _ = target_action.shape
    _, _, rule_num = rule_prob.shape
    _, _, token_num = token_prob.shape
    _, _, max_query_length = copy_prob.shape

    # (batch_size, max_action_length)
    target_rule, target_token, target_copy = F.split(target_action, axis=2)

    target_rule = F.reshape(target_rule, (batch_size, max_action_length, 1))
    target_rule = F.one_hot(
        target_rule, (rule_num, ))  # (batch_size, max_action_length, rule_num)
    rule_tgt_prob = rule_prob * target_rule  # (batch_size, max_action_length, rule_num)
    rule_tgt_prob = F.sum(rule_tgt_prob,
                          axis=2)  # (batch_size, max_action_length)

    target_token = F.reshape(target_token, (batch_size, max_action_length, 1))
    target_token = F.one_hot(
        target_token,
        (token_num, ))  # (batch_size, max_action_length, token_num)
    token_tgt_prob = token_prob * target_token  # (batch_size, max_action_length, token_num)
    token_tgt_prob = F.sum(token_tgt_prob,
                           axis=2)  # (batch_size, max_action_length)

    target_copy = F.reshape(target_copy, (batch_size, max_action_length, 1))
    target_copy = F.one_hot(
        target_copy, (max_query_length,
                      ))  # (batch_size, max_action_length, max_query_length)
    copy_tgt_prob = copy_prob * target_copy  # (batch_size, max_action_length, max_query_length)
    copy_tgt_prob = F.sum(copy_tgt_prob,
                          axis=2)  # (batch_size, max_action_length)

    # (batch_size, max_action_length)
    gen_token_prob, copy_token_prob = F.split(terminal_gen_action_prob, axis=2)
    # (batch_size, max_action_length)
    rule_mask, token_mask, copy_mask = F.split(target_action_type, axis=2)

    # (batch_size, max_action_length)
    target_prob = rule_mask * rule_tgt_prob + \
                  token_mask * gen_token_prob * token_tgt_prob + \
                  copy_mask * copy_token_prob * copy_tgt_prob
    # (batch_size, max_action_length)
    likelihood = F.log(target_prob + 1e-7)
    loss = -likelihood * target_action_mask
    # (batch_size)
    loss = F.sum(loss, axis=1)
    return F.mean(loss)
Example #16
    def __call__(self, _out_var=None):
        # input
        # _out_var : type=nn.Variable(), The discriminator output
        # --- self ---
        # self.coef_dict : type=OrderedDict(), The coefficient dict of the synthesis network (This needs to be on the graph.)
        # self.data_iterator : type=nnabla data iterator

        # output
        # loss : type=nn.Variable()

        # --- Calculation of the Fisher Information ---
        if _out_var is not None:
            temp_need_grad = self.y.need_grad
            self.y.need_grad = True
            if len(self.FisherInformation_val_dict) == 0:
                log_likelihood_var = F.log(F.sigmoid(_out_var))
                for i in range(self.iter_num):
                    log_likelihood_var.forward(clear_no_need_grad=True)
                    self._zero_grad_all()
                    log_likelihood_var.backward(clear_buffer=True)
                    self._accumulate_grads()
                    sys.stdout.write(
                        '\rFisher Information Accumulating ... {}/{}'.format(
                            i + 1, self.iter_num))
                    sys.stdout.flush()
                print('')
                for key in self.FisherInformation_val_dict:
                    self.FisherInformation_val_dict[key] /= self.iter_num
            self.y.need_grad = temp_need_grad
        # --- make loss graph ---
        loss = 0
        for key in self.FisherInformation_val_dict:
            key_source = key.replace(self.FI_scope + '/', '')
            FI_var = nn.Variable.from_numpy_array(
                self.FisherInformation_val_dict[key].copy())
            FI_var.name = key
            coef_source_var = nn.Variable.from_numpy_array(
                self.coef_dict_for_FI[key_source].d.copy())
            coef_source_var.name = key.replace(self.FI_scope + '/',
                                               'weight_source/')
            loss += F.mean(
                FI_var *
                (self.coef_dict_for_FI[key_source] - coef_source_var)**2)
        # --- save Fisher Information ---
        if self.FI_save_switch:
            self._save_FisherInformation()
        print('[ElasticWeightConsolidation] Success!')
        return loss
Example #17
def kl_multinomial_backward(inputs, base_axis=1):
    """
    Args:
      inputs (list of nn.Variable): Incoming grads/inputs to/of the forward function.
      kwargs (dict of arguments): Dictionary of the corresponding function arguments.

    Return:
      list of Variable: Return the gradients wrt inputs of the corresponding function.
    """
    dy = inputs[0]
    p = inputs[1]
    q = inputs[2]
    reshape = list(dy.shape[:base_axis]) + \
        [1 for _ in range(p.ndim - base_axis)]
    dy = F.reshape(dy, reshape, inplace=False)
    dp = dy * (F.log(p / q) + 1)
    dq = -dy * p / q
    return dp, dq
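
For reference, a sketch of the underlying algebra (an editorial addition, assuming the forward function computes the elementwise multinomial KL divergence y = sum_i p_i * log(p_i / q_i) reduced over the axes after base_axis):

    dp_i = dy * (log(p_i / q_i) + 1)
    dq_i = -dy * p_i / q_i

The F.reshape of dy simply re-expands the incoming per-sample gradient so that it broadcasts over the reduced axes.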
Example #18
    def __call__(self, x, return_encoding_indices=False):

        x = F.transpose(x, (0, 2, 3, 1))
        x_flat = x.reshape((-1, self.embedding_dim))

        x_flat_squared = F.broadcast(F.sum(x_flat**2, axis=1, keepdims=True),
                                     (x_flat.shape[0], self.num_embedding))
        emb_wt_squared = F.transpose(
            F.sum(self.embedding_weight**2, axis=1, keepdims=True), (1, 0))

        distances = x_flat_squared + emb_wt_squared - 2 * \
            F.affine(x_flat, F.transpose(self.embedding_weight, (1, 0)))

        encoding_indices = F.min(distances,
                                 only_index=True,
                                 axis=1,
                                 keepdims=True)
        encoding_indices.need_grad = False

        quantized = F.embed(
            encoding_indices.reshape(encoding_indices.shape[:-1]),
            self.embedding_weight).reshape(x.shape)

        if return_encoding_indices:
            return encoding_indices, F.transpose(quantized, (0, 3, 1, 2))

        encodings = F.one_hot(encoding_indices, (self.num_embedding, ))

        e_latent_loss = F.mean(
            F.squared_error(quantized.get_unlinked_variable(need_grad=False),
                            x))
        q_latent_loss = F.mean(
            F.squared_error(quantized,
                            x.get_unlinked_variable(need_grad=False)))
        loss = q_latent_loss + self.commitment_cost * e_latent_loss

        quantized = x + (quantized - x).get_unlinked_variable(need_grad=False)

        avg_probs = F.mean(encodings, axis=0)
        perplexity = F.exp(-F.sum(avg_probs * F.log(avg_probs + 1.0e-10)))

        return loss, F.transpose(quantized,
                                 (0, 3, 1, 2)), perplexity, encodings
Example #19
def compute_mel(wave, basis, hp):
    r"""Compute the mel-spectrogram from the waveform.

    Args:
        wave (nn.Variable): Waveform variable of shape (B, 1, L).
        basis (nn.Variable): Basis for mel-spectrogram computation.
        hp (HParams): Hyper-parameters.

    Returns:
        nn.Variable: Output variable.
    """
    reals, imags = stft(wave,
                        window_size=hp.win_length,
                        stride=hp.hop_length,
                        fft_size=hp.n_fft)
    linear = (reals**2 + imags**2)**0.5
    mels = F.batch_matmul(basis, linear)
    mels = F.log(F.clip_by_value(mels, 1e-5, np.inf))

    return mels
Example #20
    def __init__(self, num_actions, num_envs, batch_size, v_coeff, ent_coeff,
                 lr_scheduler):
        # inference graph
        self.infer_obs_t = nn.Variable((num_envs, 4, 84, 84))
        self.infer_pi_t,\
        self.infer_value_t = cnn_network(self.infer_obs_t, num_actions,
                                         'network')
        self.infer_t = F.sink(self.infer_pi_t, self.infer_value_t)

        # evaluation graph
        self.eval_obs_t = nn.Variable((1, 4, 84, 84))
        self.eval_pi_t, _ = cnn_network(self.eval_obs_t, num_actions,
                                        'network')

        # training graph
        self.obss_t = nn.Variable((batch_size, 4, 84, 84))
        self.acts_t = nn.Variable((batch_size, 1))
        self.rets_t = nn.Variable((batch_size, 1))
        self.advs_t = nn.Variable((batch_size, 1))

        pi_t, value_t = cnn_network(self.obss_t, num_actions, 'network')

        # value loss
        l2loss = F.squared_error(value_t, self.rets_t)
        self.value_loss = v_coeff * F.mean(l2loss)

        # policy loss
        log_pi_t = F.log(pi_t + 1e-20)
        a_one_hot = F.one_hot(self.acts_t, (num_actions, ))
        log_probs_t = F.sum(log_pi_t * a_one_hot, axis=1, keepdims=True)
        self.pi_loss = F.mean(log_probs_t * self.advs_t)

        # entropy regularization
        entropy = -ent_coeff * F.mean(F.sum(pi_t * log_pi_t, axis=1))

        self.loss = self.value_loss - self.pi_loss - entropy

        self.params = nn.get_parameters()
        self.solver = S.RMSprop(lr_scheduler(0.0), 0.99, 1e-5)
        self.solver.set_parameters(self.params)
        self.lr_scheduler = lr_scheduler
Example #21
def invertible_conv(x, reverse, rng, scope):
    r"""Invertible 1x1 Convolution Layer.

    Args:
        x (nn.Variable): Input variable.
        reverse (bool): Whether it's a reverse direction.
        rng (numpy.random.RandomState): A random generator.
        scope (str): The scope.

    Returns:
        nn.Variable: The output variable.
    """
    batch_size, c, n_groups = x.shape
    with nn.parameter_scope(scope):
        # initialize w by an orthonormal matrix
        w_init = np.linalg.qr(rng.randn(c, c))[0][None, ...]
        W_var = get_parameter_or_create("W", (1, c, c), w_init, True, True)
        W = F.batch_inv(W_var) if reverse else W_var
        x = F.convolution(x, F.reshape(W, (c, c, 1)), None, stride=(1, ))
    if reverse:
        return x
    log_det = batch_size * n_groups * F.log(F.abs(F.batch_det(W)))
    return x, log_det
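
A short note on the log_det term (an editorial addition): for an input of shape (batch_size, c, n_groups), the same c-by-c matrix W is applied to every one of the batch_size * n_groups channel vectors, so the batch-summed change-of-variables correction is

    log_det = batch_size * n_groups * log|det W|

which matches the expression returned above.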
Example #22
def log_mel_spectrogram(wave, sr, window_size, n_mels=80):
    """Return log mel-spectrogram.

    Args:
        wave (nn.Variable): Input waveform of shape (B, 1, L).
        sr (int): Sampling rate.
        window_size (int): Window size.
        n_mels (int): Number of mel banks.

    Returns:
        nn.Variable: Log mel-spectrogram.
    """
    linear = spectrogram(wave, window_size)
    mel_basis = librosa_mel_fn(sr,
                               window_size,
                               n_mels=n_mels,
                               fmin=80.0,
                               fmax=7600.0)
    basis = nn.Variable.from_numpy_array(mel_basis[None, ...])
    mels = F.batch_matmul(basis, linear)
    return F.log(mels * 1e4 + 1.0)
Example #23
def train_nerf(config, comm, model, dataset='blender'):

    use_transient = False
    use_embedding = False

    if model == 'wild':
        use_transient = True
        use_embedding = True
    elif model == 'uncertainty':
        use_transient = True
    elif model == 'appearance':
        use_embedding = True

    save_results_dir = config.log.save_results_dir
    os.makedirs(save_results_dir, exist_ok=True)

    train_loss_dict = {
        'train_coarse_loss': 0.0,
        'train_fine_loss': 0.0,
        'train_total_loss': 0.0,
    }

    test_metric_dict = {'test_loss': 0.0, 'test_psnr': 0.0}

    monitor_manager = MonitorManager(train_loss_dict, test_metric_dict,
                                     save_results_dir)

    if dataset != 'phototourism':
        images, poses, _, hwf, i_test, i_train, near_plane, far_plane = get_data(
            config)
        height, width, focal_length = hwf
    else:
        di = get_photo_tourism_dataiterator(config, 'train', comm)
        val_di = get_photo_tourism_dataiterator(config, 'val', comm)

    if model != 'vanilla':
        if dataset != 'phototourism':
            config.train.n_vocab = max(np.max(i_train), np.max(i_test)) + 1
        print(
            f'Setting Vocabulary size of embedding as {config.train.n_vocab}')

    if dataset != 'phototourism':
        if model in ['vanilla']:
            if comm is not None:
                # uncomment the following line to test on fewer images
                i_test = i_test[3 * comm.rank:3 * (comm.rank + 1)]
                pass
            else:
                # uncomment the following line to test on fewer images
                i_test = i_test[:3]
                pass
        else:
            # i_test = i_train[0:5]
            i_test = [i * (comm.rank + 1) for i in range(5)]
    else:
        i_test = [1]

    encode_position_function = get_encoding_function(
        config.train.num_encodings_position, True, True)
    if config.train.use_view_directions:
        encode_direction_function = get_encoding_function(
            config.train.num_encodings_direction, True, True)
    else:
        encode_direction_function = None

    lr = config.solver.lr
    num_decay_steps = config.solver.lr_decay_step * 1000
    lr_decay_factor = config.solver.lr_decay_factor
    solver = S.Adam(alpha=lr)

    load_solver_state = False
    if config.checkpoint.param_path is not None:
        nn.load_parameters(config.checkpoint.param_path)
        load_solver_state = True

    if comm is not None:
        num_decay_steps /= comm.n_procs
        comm_size = comm.n_procs
    else:
        comm_size = 1
    pbar = trange(config.train.num_iterations // comm_size,
                  disable=(comm is not None and comm.rank > 0))

    for i in pbar:

        if dataset != 'phototourism':

            idx = np.random.choice(i_train)
            image = nn.Variable.from_numpy_array(images[idx][None, :, :, :3])
            pose = nn.Variable.from_numpy_array(poses[idx])

            ray_directions, ray_origins = get_ray_bundle(
                height, width, focal_length, pose)

            grid = get_direction_grid(width,
                                      height,
                                      focal_length,
                                      return_ij_2d_grid=True)
            grid = F.reshape(grid, (-1, 2))

            select_inds = np.random.choice(grid.shape[0],
                                           size=[config.train.num_rand_points],
                                           replace=False)
            select_inds = F.gather_nd(grid, select_inds[None, :])
            select_inds = F.transpose(select_inds, (1, 0))

            embed_inp = nn.Variable.from_numpy_array(
                np.full((config.train.chunksize_fine, ), idx, dtype=int))

            ray_origins = F.gather_nd(ray_origins, select_inds)
            ray_directions = F.gather_nd(ray_directions, select_inds)

            image = F.gather_nd(image[0], select_inds)

        else:
            rays, embed_inp, image = di.next()
            ray_origins = nn.Variable.from_numpy_array(rays[:, :3])
            ray_directions = nn.Variable.from_numpy_array(rays[:, 3:6])
            near_plane = nn.Variable.from_numpy_array(rays[:, 6])
            far_plane = nn.Variable.from_numpy_array(rays[:, 7])

            embed_inp = nn.Variable.from_numpy_array(embed_inp)
            image = nn.Variable.from_numpy_array(image)

            hwf = None

        app_emb, trans_emb = None, None
        if use_embedding:
            with nn.parameter_scope('embedding_a'):
                app_emb = PF.embed(embed_inp, config.train.n_vocab,
                                   config.train.n_app)

        if use_transient:
            with nn.parameter_scope('embedding_t'):
                trans_emb = PF.embed(embed_inp, config.train.n_vocab,
                                     config.train.n_trans)

        if use_transient:
            rgb_map_course, rgb_map_fine, static_rgb_map_fine, transient_rgb_map_fine, beta, static_sigma, transient_sigma = forward_pass(
                ray_directions,
                ray_origins,
                near_plane,
                far_plane,
                app_emb,
                trans_emb,
                encode_position_function,
                encode_direction_function,
                config,
                use_transient,
                hwf=hwf,
                image=image)
            course_loss = 0.5 * F.mean(F.squared_error(rgb_map_course, image))
            fine_loss = 0.5 * F.mean(
                F.squared_error(rgb_map_fine, image) /
                F.reshape(F.pow_scalar(beta, 2), beta.shape + (1, )))
            beta_reg_loss = 3 + F.mean(F.log(beta))
            sigma_trans_reg_loss = 0.01 * F.mean(transient_sigma)
            loss = course_loss + fine_loss + beta_reg_loss + sigma_trans_reg_loss
        else:
            rgb_map_course, _, _, _, rgb_map_fine, _, _, _ = forward_pass(
                ray_directions,
                ray_origins,
                near_plane,
                far_plane,
                app_emb,
                trans_emb,
                encode_position_function,
                encode_direction_function,
                config,
                use_transient,
                hwf=hwf)
            course_loss = F.mean(F.squared_error(rgb_map_course, image))
            fine_loss = F.mean(F.squared_error(rgb_map_fine, image))
            loss = course_loss + fine_loss

        pbar.set_description(
            f'Total: {np.around(loss.d, 4)}, Course: {np.around(course_loss.d, 4)}, Fine: {np.around(fine_loss.d, 4)}'
        )

        solver.set_parameters(nn.get_parameters(),
                              reset=False,
                              retain_state=True)
        if load_solver_state:
            solver.load_states(config['checkpoint']['solver_path'])
            load_solver_state = False

        solver.zero_grad()

        loss.backward(clear_buffer=True)

        # Exponential LR decay
        if dataset != 'phototourism':
            lr_factor = (lr_decay_factor**((i) / num_decay_steps))
            solver.set_learning_rate(lr * lr_factor)
        else:
            if i % num_decay_steps == 0 and i != 0:
                solver.set_learning_rate(lr * lr_decay_factor)

        if comm is not None:
            params = [x.grad for x in nn.get_parameters().values()]
            comm.all_reduce(params, division=False, inplace=True)
        solver.update()

        if ((i % config.train.save_interval == 0
             or i == config.train.num_iterations - 1)
                and i != 0) and (comm is not None and comm.rank == 0):
            nn.save_parameters(os.path.join(save_results_dir, f'iter_{i}.h5'))
            solver.save_states(
                os.path.join(save_results_dir, f'solver_iter_{i}.h5'))

        if (i % config.train.test_interval == 0
                or i == config.train.num_iterations - 1) and i != 0:
            avg_psnr, avg_mse = 0.0, 0.0
            for i_t in trange(len(i_test)):

                if dataset != 'phototourism':
                    idx_test = i_test[i_t]
                    image = nn.NdArray.from_numpy_array(
                        images[idx_test][None, :, :, :3])
                    pose = nn.NdArray.from_numpy_array(poses[idx_test])

                    ray_directions, ray_origins = get_ray_bundle(
                        height, width, focal_length, pose)

                    ray_directions = F.reshape(ray_directions, (-1, 3))
                    ray_origins = F.reshape(ray_origins, (-1, 3))

                    embed_inp = nn.NdArray.from_numpy_array(
                        np.full((config.train.chunksize_fine, ),
                                idx_test,
                                dtype=int))

                else:
                    rays, embed_inp, image = val_di.next()
                    ray_origins = nn.NdArray.from_numpy_array(rays[0, :, :3])
                    ray_directions = nn.NdArray.from_numpy_array(rays[0, :,
                                                                      3:6])
                    near_plane_ = nn.NdArray.from_numpy_array(rays[0, :, 6])
                    far_plane_ = nn.NdArray.from_numpy_array(rays[0, :, 7])

                    embed_inp = nn.NdArray.from_numpy_array(
                        embed_inp[0, :config.train.chunksize_fine])
                    image = nn.NdArray.from_numpy_array(image[0].transpose(
                        1, 2, 0))
                    image = F.reshape(image, (1, ) + image.shape)
                    idx_test = 1

                app_emb, trans_emb = None, None
                if use_embedding:
                    with nn.parameter_scope('embedding_a'):
                        app_emb = PF.embed(embed_inp, config.train.n_vocab,
                                           config.train.n_app)

                if use_transient:
                    with nn.parameter_scope('embedding_t'):
                        trans_emb = PF.embed(embed_inp, config.train.n_vocab,
                                             config.train.n_trans)

                num_ray_batches = ray_directions.shape[
                    0] // config.train.ray_batch_size + 1

                if use_transient:
                    rgb_map_fine_list, static_rgb_map_fine_list, transient_rgb_map_fine_list = [], [], []
                else:
                    rgb_map_fine_list, depth_map_fine_list = [], []

                for r_idx in trange(num_ray_batches):
                    if r_idx != num_ray_batches - 1:
                        ray_d, ray_o = ray_directions[
                            r_idx * config.train.ray_batch_size:(r_idx + 1) *
                            config.train.ray_batch_size], ray_origins[
                                r_idx *
                                config.train.ray_batch_size:(r_idx + 1) *
                                config.train.ray_batch_size]

                        if dataset == 'phototourism':
                            near_plane = near_plane_[
                                r_idx *
                                config.train.ray_batch_size:(r_idx + 1) *
                                config.train.ray_batch_size]
                            far_plane = far_plane_[r_idx *
                                                   config.train.ray_batch_size:
                                                   (r_idx + 1) *
                                                   config.train.ray_batch_size]

                    else:
                        if ray_directions.shape[0] - (
                                num_ray_batches -
                                1) * config.train.ray_batch_size == 0:
                            break
                        ray_d, ray_o = ray_directions[
                            r_idx *
                            config.train.ray_batch_size:, :], ray_origins[
                                r_idx * config.train.ray_batch_size:, :]
                        if dataset == 'phototourism':
                            near_plane = near_plane_[r_idx * config.train.
                                                     ray_batch_size:]
                            far_plane = far_plane_[r_idx * config.train.
                                                   ray_batch_size:]

                    if use_transient:
                        rgb_map_course, rgb_map_fine, static_rgb_map_fine, transient_rgb_map_fine, beta, static_sigma, transient_sigma = forward_pass(
                            ray_d,
                            ray_o,
                            near_plane,
                            far_plane,
                            app_emb,
                            trans_emb,
                            encode_position_function,
                            encode_direction_function,
                            config,
                            use_transient,
                            hwf=hwf)

                        rgb_map_fine_list.append(rgb_map_fine)
                        static_rgb_map_fine_list.append(static_rgb_map_fine)
                        transient_rgb_map_fine_list.append(
                            transient_rgb_map_fine)

                    else:
                        _, _, _, _, rgb_map_fine, depth_map_fine, _, _ = \
                            forward_pass(ray_d, ray_o, near_plane, far_plane, app_emb, trans_emb,
                                         encode_position_function, encode_direction_function, config, use_transient, hwf=hwf)

                        rgb_map_fine_list.append(rgb_map_fine)
                        depth_map_fine_list.append(depth_map_fine)

                if use_transient:
                    rgb_map_fine = F.concatenate(*rgb_map_fine_list, axis=0)
                    static_rgb_map_fine = F.concatenate(
                        *static_rgb_map_fine_list, axis=0)
                    transient_rgb_map_fine = F.concatenate(
                        *transient_rgb_map_fine_list, axis=0)

                    rgb_map_fine = F.reshape(rgb_map_fine, image[0].shape)
                    static_rgb_map_fine = F.reshape(static_rgb_map_fine,
                                                    image[0].shape)
                    transient_rgb_map_fine = F.reshape(transient_rgb_map_fine,
                                                       image[0].shape)
                    static_trans_img_to_save = np.concatenate(
                        (static_rgb_map_fine.data,
                         np.ones((image[0].shape[0], 5, 3)),
                         transient_rgb_map_fine.data),
                        axis=1)
                    img_to_save = np.concatenate(
                        (image[0].data, np.ones(
                            (image[0].shape[0], 5, 3)), rgb_map_fine.data),
                        axis=1)
                else:

                    rgb_map_fine = F.concatenate(*rgb_map_fine_list, axis=0)
                    depth_map_fine = F.concatenate(*depth_map_fine_list,
                                                   axis=0)

                    rgb_map_fine = F.reshape(rgb_map_fine, image[0].shape)
                    depth_map_fine = F.reshape(depth_map_fine,
                                               image[0].shape[:-1])
                    img_to_save = np.concatenate(
                        (image[0].data, np.ones(
                            (image[0].shape[0], 5, 3)), rgb_map_fine.data),
                        axis=1)

                filename = os.path.join(save_results_dir,
                                        f'{i}_{idx_test}.png')
                try:
                    imsave(filename,
                           np.clip(img_to_save, 0, 1),
                           channel_first=False)
                    print(f'Saved generation at {filename}')
                    if use_transient:
                        filename_static_trans = os.path.join(
                            save_results_dir, f's_t_{i}_{idx_test}.png')
                        imsave(filename_static_trans,
                               np.clip(static_trans_img_to_save, 0, 1),
                               channel_first=False)

                    else:
                        filename_dm = os.path.join(save_results_dir,
                                                   f'dm_{i}_{idx_test}.png')
                        depth_map_fine = (depth_map_fine.data -
                                          depth_map_fine.data.min()) / (
                                              depth_map_fine.data.max() -
                                              depth_map_fine.data.min())
                        imsave(filename_dm,
                               depth_map_fine[:, :, None],
                               channel_first=False)
                        plt.imshow(depth_map_fine.data)
                        plt.savefig(filename_dm)
                        plt.close()
                except:
                    pass

                avg_mse += F.mean(F.squared_error(rgb_map_fine, image[0])).data
                avg_psnr += (-10. * np.log10(
                    F.mean(F.squared_error(rgb_map_fine, image[0])).data))

            test_metric_dict['test_loss'] = avg_mse / len(i_test)
            test_metric_dict['test_psnr'] = avg_psnr / len(i_test)
            monitor_manager.add(i, test_metric_dict)
            print(
                f'Saved generations after {i} training iterations! Average PSNR: {avg_psnr/len(i_test)}. Average MSE: {avg_mse/len(i_test)}'
            )
Example #24
def kl_divergence(ctx, pred, label):
    with nn.context_scope(ctx):
        elms = F.softmax(label, axis=1) * F.log(F.softmax(pred, axis=1))
        loss = -F.mean(F.sum(elms, axis=1))
    return loss
Example #25
def log2(x):
    return F.log(x) / np.log(2.)
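
A minimal usage sketch (values are illustrative assumptions):

import numpy as np
import nnabla as nn
import nnabla.functions as F  # used inside log2

x = nn.Variable.from_numpy_array(np.array([1.0, 2.0, 8.0], dtype=np.float32))
y = log2(x)
y.forward()
print(y.d)  # approximately [0., 1., 3.]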
Example #26
def quantize_pow2(v):
    return 2**F.round(F.log(v) / np.log(2.))
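
A minimal usage sketch (values are illustrative assumptions):

import numpy as np
import nnabla as nn
import nnabla.functions as F  # used inside quantize_pow2

v = nn.Variable.from_numpy_array(np.array([3.0, 6.0, 100.0], dtype=np.float32))
q = quantize_pow2(v)  # rounds each value to the nearest power of two (in log space)
q.forward()
print(q.d)  # approximately [4., 8., 128.]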
Example #27
with nn.parameter_scope('central_bias'):
    central_bias = PF.embed(x_central, vocab_size, 1)
with nn.parameter_scope('context_bias'):
    context_bias = PF.embed(x_context, vocab_size, 1)

dot_product = F.reshape(F.batch_matmul(
    F.reshape(central_embedding, shape=(batch_size, 1, embedding_size)),
    F.reshape(context_embedding, shape=(batch_size, embedding_size, 1))),
                        shape=(batch_size, 1))

prediction = dot_product + central_bias + context_bias

t = nn.Variable((batch_size, 1))
zero = F.constant(0, shape=(batch_size, 1))
one = F.constant(1, shape=(batch_size, 1))
weight = F.clip_by_value(t / 100, zero, one)**0.75
loss = F.sum(weight * ((prediction - F.log(t))**2))

# Create solver.
solver = S.Adam()
solver.set_parameters(nn.get_parameters())

# Create monitor
monitor = M.Monitor('./log')
monitor_loss = M.MonitorSeries("Training loss", monitor, interval=1000)
monitor_valid_loss = M.MonitorSeries("Validation loss", monitor, interval=1)
monitor_time = M.MonitorTimeElapsed("Training time", monitor, interval=1000)


# Create updater
def train_data_feeder():
    x_central.d, x_context.d, t.d = train_data_iter.next()
Example #28
def ce_loss_soft(ctx, pred, target):
    with nn.context_scope(ctx):
        # TODO: divide or not
        loss = - F.mean(F.sum(F.softmax(target) * F.log(F.softmax(pred)), axis=1))
    return loss
Example #29
    def p_mean_var(self, model, x_t, t, clip_denoised=True):
        """
        Compute mean and var of p(x_{t-1}|x_t) from model.

        Args:
            model (Callable): A callable that takes x_t and t and predicts noise (and more).
            x_t (nn.Variable): The (B, C, ...) tensor at timestep t (x_t).
            t (nn.Variable): A 1-D tensor of timesteps. The first axis represents batch size.
            clip_denoised (bool): If True, clip the denoised signal into [-1, 1].

        Returns:
            An AttrDict containing the following items:
                "mean": the mean predicted by model.
                "var": the variance predicted by model (or pre-defined variance).
                "log_var": the log of "var".
                "xstart": the x_0 predicted from x_t and t by model.
        """
        B, C, H, W = x_t.shape
        assert t.shape == (B, )
        pred = model(x_t, t)

        if self.model_var_type == ModelVarType.LEARNED_RANGE:
            assert pred.shape == (B, 2 * C, H, W)
            pred_noise, pred_var_coeff = chunk(pred, num_chunk=2, axis=1)

            min_log = self._extract(
                self.posterior_log_var_clipped, t, x_t.shape)
            max_log = F.log(self._extract(self.betas, t, x_t.shape))

            # pred_var_coeff should be [0, 1]
            v = F.sigmoid(pred_var_coeff)
            model_log_var = v * max_log + (1 - v) * min_log
            model_var = F.exp(model_log_var)
        else:
            # Model only predicts noise
            pred_noise = pred

            model_log_var, model_var = {
                ModelVarType.FIXED_LARGE: lambda: (
                    self._extract(self.log_betas_clipped, t, x_t.shape),
                    self._extract(self.betas_clipped, t, x_t.shape)
                ),
                ModelVarType.FIXED_SMALL: lambda: (
                    self._extract(
                        self.posterior_log_var_clipped, t, x_t.shape),
                    self._extract(self.posterior_var, t, x_t.shape)
                )
            }[self.model_var_type]()

        x_recon = self.predict_xstart_from_noise(
            x_t=x_t, t=t, noise=pred_noise)

        if clip_denoised:
            x_recon = F.clip_by_value(x_recon, -1, 1)

        model_mean, _, _ = self.q_posterior(x_start=x_recon, x_t=x_t, t=t)

        assert model_mean.shape == x_recon.shape == x_t.shape

        assert model_mean.shape == model_var.shape == model_log_var.shape or \
            (model_mean.shape[0] == model_var.shape[0] == model_log_var.shape[0] and model_var.shape[1:] == (
                1, 1, 1) and model_log_var.shape[1:] == (1, 1, 1))

        # returns
        ret = AttrDict()
        ret.mean = model_mean
        ret.var = model_var
        ret.log_var = model_log_var
        ret.xstart = x_recon

        return ret
Example #30
def ce_soft(pred, label):
    elms = - F.softmax(label, axis=1) * F.log(F.softmax(pred, axis=1))
    loss = F.mean(F.sum(elms, axis=1))
    return loss
Example #31
def loss_function(u, v, negative_samples):
    return F.sum(-F.log(
        F.exp(-distance(u, v)) / sum([
            F.exp(-distance(u, x)) for x in F.split(negative_samples, axis=2)
        ])))
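
A brief note on the objective (an editorial reading of the code): for each positive pair (u, v) and the negative samples {x_k} split from negative_samples along axis 2, the summand is

    -log( exp(-d(u, v)) / sum_k exp(-d(u, x_k)) )

i.e. a softmax over negated distances, as used in negative-sampling style embedding losses; whether v itself is included among the negatives depends on how negative_samples is built by the caller.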
Example #32
    def _build(self):
        # inference graph
        self.infer_obs_t = nn.Variable((1, ) + self.obs_shape)
        with nn.parameter_scope('trainable'):
            infer_dist = policy_network(self.infer_obs_t, self.action_size,
                                        'actor')
        self.infer_act_t, _ = _squash_action(infer_dist)
        self.deterministic_act_t = infer_dist.mean()

        # training graph
        self.obss_t = nn.Variable((self.batch_size, ) + self.obs_shape)
        self.acts_t = nn.Variable((self.batch_size, self.action_size))
        self.rews_tp1 = nn.Variable((self.batch_size, 1))
        self.obss_tp1 = nn.Variable((self.batch_size, ) + self.obs_shape)
        self.ters_tp1 = nn.Variable((self.batch_size, 1))

        with nn.parameter_scope('trainable'):
            dist = policy_network(self.obss_t, self.action_size, 'actor')
            squashed_act_t, log_prob_t = _squash_action(dist)
            v_t = v_network(self.obss_t, 'value')
            q_t1 = q_network(self.obss_t, self.acts_t, 'critic/1')
            q_t2 = q_network(self.obss_t, self.acts_t, 'critic/2')
            q_t1_with_actor = q_network(self.obss_t, squashed_act_t,
                                        'critic/1')
            q_t2_with_actor = q_network(self.obss_t, squashed_act_t,
                                        'critic/2')

        with nn.parameter_scope('target'):
            v_tp1 = v_network(self.obss_tp1, 'value')

        # value loss
        q_t = F.minimum2(q_t1_with_actor, q_t2_with_actor)
        v_target = q_t - log_prob_t
        v_target.need_grad = False
        self.value_loss = 0.5 * F.mean(F.squared_error(v_t, v_target))

        # q function loss
        scaled_rews_tp1 = self.rews_tp1 * self.reward_scale
        q_target = scaled_rews_tp1 + self.gamma * v_tp1 * (1.0 - self.ters_tp1)
        q_target.need_grad = False
        q1_loss = 0.5 * F.mean(F.squared_error(q_t1, q_target))
        q2_loss = 0.5 * F.mean(F.squared_error(q_t2, q_target))
        self.critic_loss = q1_loss + q2_loss

        # policy function loss
        mean_loss = 0.5 * F.mean(dist.mean()**2)
        logstd_loss = 0.5 * F.mean(F.log(dist.stddev())**2)
        policy_reg_loss = self.policy_reg * (mean_loss + logstd_loss)
        self.objective_loss = F.mean(log_prob_t - q_t)
        self.actor_loss = self.objective_loss + policy_reg_loss

        # trainable parameters
        with nn.parameter_scope('trainable'):
            with nn.parameter_scope('value'):
                value_params = nn.get_parameters()
            with nn.parameter_scope('critic'):
                critic_params = nn.get_parameters()
            with nn.parameter_scope('actor'):
                actor_params = nn.get_parameters()
        # target parameters
        with nn.parameter_scope('target/value'):
            target_params = nn.get_parameters()

        # target update
        update_targets = []
        sync_targets = []
        for key, src in value_params.items():
            dst = target_params[key]
            updated_dst = (1.0 - self.tau) * dst + self.tau * src
            update_targets.append(F.assign(dst, updated_dst))
            sync_targets.append(F.assign(dst, src))
        self.update_target_expr = F.sink(*update_targets)
        self.sync_target_expr = F.sink(*sync_targets)

        # setup solvers
        self.value_solver = S.Adam(self.value_lr)
        self.value_solver.set_parameters(value_params)
        self.critic_solver = S.Adam(self.critic_lr)
        self.critic_solver.set_parameters(critic_params)
        self.actor_solver = S.Adam(self.actor_lr)
        self.actor_solver.set_parameters(actor_params)
Example #33
def get_tecogan_model(conf, r_inputs, r_targets, scope_name, tecogan=True):
    """
    Create computation graph and variables for TecoGAN.
    """
    # r_inputs, r_targets : shape (batch, conf.train.rnn_n, h, w, c)
    rnn_length = conf.train.rnn_n
    if tecogan:
        r_inputs, r_targets = get_tecogan_inputs(r_inputs, r_targets)
        rnn_length = rnn_length * 2 - 1

    # get the consecutive frame sequences from the input sequence
    frame_t_pre, frame_t = r_inputs[:, 0:-1, :, :, :], r_inputs[:, 1:, :, :, :]

    # Get flow estimations
    fnet_output = get_fnet_output(conf, rnn_length, frame_t_pre, frame_t,
                                  scope_name)

    # Get the generated HR output frames
    gen_outputs = get_generator_output(conf, rnn_length, r_inputs,
                                       fnet_output.flow_hr, scope_name)

    s_gen_output = F.reshape(
        gen_outputs, (conf.train.batch_size * rnn_length,
                      conf.train.crop_size * 4, conf.train.crop_size * 4, 3),
        inplace=False)
    s_targets = F.reshape(
        r_targets, (conf.train.batch_size * rnn_length,
                    conf.train.crop_size * 4, conf.train.crop_size * 4, 3),
        inplace=False)

    # Content loss (l2 loss)
    content_loss = F.mean(
        F.sum(F.squared_error(s_gen_output, s_targets), axis=[3]))
    # Warp loss (l2 loss)
    warp_loss = get_warp_loss(conf, rnn_length, frame_t, frame_t_pre,
                              fnet_output.flow_lr)

    if tecogan:
        d_data = get_d_data(conf, fnet_output.flow_hr, gen_outputs, r_targets,
                            rnn_length)
        # Build the tempo discriminator for the real part and fake part
        t_d = get_t_d(conf, r_inputs, d_data)

        # Discriminator layer loss:
        d_layer_loss = get_d_layer(t_d.real_layers, t_d.fake_layers)
        # vgg loss (cosine similarity)
        loss_vgg = get_vgg_loss(s_gen_output, s_targets)
        # ping pong loss (an l1 loss)
        gen_out_first = gen_outputs[:, 0:conf.train.rnn_n - 1, :, :, :]
        gen_out_last_rev = gen_outputs[:, -1:-conf.train.rnn_n:-1, :, :, :]
        pp_loss = F.mean(F.abs(gen_out_first - gen_out_last_rev))
        # adversarial loss
        t_adversarial_loss = F.mean(-F.log(t_d.tdiscrim_fake_output +
                                           conf.train.eps))

        # Overall generator loss
        gen_loss = content_loss + pp_loss * conf.gan.pp_scaling + conf.gan.ratio * \
            t_adversarial_loss + conf.gan.vgg_scaling * loss_vgg + \
            conf.gan.dt_ratio_0 * d_layer_loss

        # Discriminator loss
        t_discrim_fake_loss = F.log(1 - t_d.tdiscrim_fake_output +
                                    conf.train.eps)
        t_discrim_real_loss = F.log(t_d.tdiscrim_real_output + conf.train.eps)
        t_discrim_loss = F.mean(-(t_discrim_fake_loss + t_discrim_real_loss))

        fnet_loss = gen_loss + warp_loss

        set_persistent_all(r_targets, r_inputs, loss_vgg, gen_out_first,
                           gen_out_last_rev, pp_loss, d_layer_loss,
                           content_loss, warp_loss, gen_loss,
                           t_adversarial_loss, t_discrim_loss,
                           t_discrim_real_loss, d_data.t_vel,
                           d_data.t_gen_output, s_gen_output, s_targets)

        Network = collections.namedtuple(
            'Network', 'content_loss,  warp_loss, fnet_loss, vgg_loss,'
            'gen_loss, pp_loss, sum_layer_loss,t_adversarial_loss,'
            't_discrim_loss,t_gen_output,t_discrim_real_loss')
        return Network(content_loss=content_loss,
                       warp_loss=warp_loss,
                       fnet_loss=fnet_loss,
                       vgg_loss=loss_vgg,
                       gen_loss=gen_loss,
                       pp_loss=pp_loss,
                       sum_layer_loss=d_layer_loss,
                       t_adversarial_loss=t_adversarial_loss,
                       t_discrim_loss=t_discrim_loss,
                       t_gen_output=d_data.t_gen_output,
                       t_discrim_real_loss=t_discrim_real_loss)

    gen_loss = content_loss
    fnet_loss = gen_loss + warp_loss
    set_persistent_all(content_loss, s_gen_output, warp_loss, gen_loss,
                       fnet_loss)

    Network = collections.namedtuple(
        'Network', 'content_loss, warp_loss, fnet_loss, gen_loss')
    return Network(
        content_loss=content_loss,
        warp_loss=warp_loss,
        fnet_loss=fnet_loss,
        gen_loss=gen_loss,
    )
Example #34
def sample_from_controller(args):
    """
        A 2-layer LSTM-based controller that outputs a CNN architecture,
        represented as a list of integer sequences.
        Given the number of layers, it performs two computations per layer:
        one samples the operation used at that layer,
        the other samples the skip-connection pattern.
    """

    entropys = nn.Variable([1, 1], need_grad=True)
    log_probs = nn.Variable([1, 1], need_grad=True)
    skip_penaltys = nn.Variable([1, 1], need_grad=True)

    entropys.d = log_probs.d = skip_penaltys.d = 0.0  # initialize them all

    num_layers = args.num_layers
    lstm_size = args.lstm_size
    state_size = args.state_size
    lstm_num_layers = args.lstm_layers
    skip_target = args.skip_prob
    temperature = args.temperature
    tanh_constant = args.tanh_constant
    num_branch = args.num_ops

    arc_seq = []
    initializer = I.UniformInitializer((-0.1, 0.1))

    prev_h = [
        nn.Variable([1, lstm_size], need_grad=True)
        for _ in range(lstm_num_layers)
    ]
    prev_c = [
        nn.Variable([1, lstm_size], need_grad=True)
        for _ in range(lstm_num_layers)
    ]

    for i in range(len(prev_h)):
        prev_h[i].d = 0  # initialize variables in lstm layers.
        prev_c[i].d = 0

    inputs = nn.Variable([1, lstm_size])
    inputs.d = np.random.normal(0, 0.5, [1, lstm_size])

    g_emb = nn.Variable([1, lstm_size])
    g_emb.d = np.random.normal(0, 0.5, [1, lstm_size])

    skip_targets = nn.Variable([1, 2])
    skip_targets.d = np.array([[1.0 - skip_target, skip_target]])

    for layer_id in range(num_layers):
        # One-step stacked LSTM.
        with nn.parameter_scope("controller_lstm"):
            next_h, next_c = stack_lstm(inputs, prev_h, prev_c, state_size)
        prev_h, prev_c = next_h, next_c  # shape:(1, lstm_size)

        # Compute for operation.
        with nn.parameter_scope("ops"):
            logit = PF.affine(next_h[-1],
                              num_branch,
                              w_init=initializer,
                              with_bias=False)

        if temperature is not None:
            logit = F.mul_scalar(logit, (1 / temperature))

        if tanh_constant is not None:
            logit = F.mul_scalar(F.tanh(logit),
                                 tanh_constant)  # (1, num_branch)

        # normalizing logits.
        normed_logit = np.e**logit.d
        normed_logit = normed_logit / np.sum(normed_logit)

        # Sampling operation id from multinomial distribution.
        ops_id = np.random.multinomial(1, normed_logit[0], 1).nonzero()[1]
        ops_id = nn.Variable.from_numpy_array(ops_id)  # (1, )
        arc_seq.append(ops_id.d)

        # log policy for operation.
        log_prob = F.softmax_cross_entropy(logit,
                                           F.reshape(ops_id,
                                                     shape=(1, 1)))  # (1, )
        # accumulate log policy as log probs
        log_probs = F.add2(log_probs, log_prob)

        entropy = log_prob * F.exp(-log_prob)
        entropys = F.add2(entropys, entropy)  # accumulate entropy as entropys.

        w_emb = nn.parameter.get_parameter_or_create("w_emb",
                                                     [num_branch, lstm_size],
                                                     initializer,
                                                     need_grad=False)

        inputs = F.reshape(w_emb[int(ops_id.d)],
                           (1, w_emb.shape[1]))  # (1, lstm_size)

        with nn.parameter_scope("controller_lstm"):
            next_h, next_c = stack_lstm(inputs, prev_h, prev_c, lstm_size)
        prev_h, prev_c = next_h, next_c  # (1, lstm_size)

        with nn.parameter_scope("skip_affine_3"):
            adding_w_1 = PF.affine(next_h[-1],
                                   lstm_size,
                                   w_init=initializer,
                                   with_bias=False)  # (1, lstm_size)

        if layer_id == 0:
            inputs = g_emb  # (1, lstm_size)
            anchors = next_h[-1]  # (1, lstm_size)
            anchors_w_1 = adding_w_1  # then goes back to the entry point of the loop

        else:
            # (layer_id, lstm_size) this shape during the process
            query = anchors_w_1

            with nn.parameter_scope("skip_affine_1"):
                query = F.tanh(
                    F.add2(
                        query,
                        PF.affine(next_h[-1],
                                  lstm_size,
                                  w_init=initializer,
                                  with_bias=False)))
                #              (layer_id, lstm_size)   +   (1, lstm_size)
                # broadcast occurs here. resulting shape is; (layer_id, lstm_size)

            with nn.parameter_scope("skip_affine_2"):
                query = PF.affine(query,
                                  1,
                                  w_init=initializer,
                                  with_bias=False)  # (layer_id, 1)
            # note that each weight for skip_affine_X is shared across all steps of LSTM.

            # re-define logits, now its shape is;(layer_id, 2)
            logit = F.concatenate(-query, query, axis=1)

            if temperature is not None:
                logit = F.mul_scalar(logit, (1 / temperature))

            if tanh_constant is not None:
                logit = F.mul_scalar(F.tanh(logit), tanh_constant)

            skip_prob_unnormalized = F.exp(logit)  # (layer_id, 2)

            # normalizing skip_prob_unnormalized.
            summed = F.sum(skip_prob_unnormalized, axis=1,
                           keepdims=True).apply(need_grad=False)
            summed = F.concatenate(summed, summed, axis=1)

            skip_prob_normalized = F.div2(skip_prob_unnormalized,
                                          summed)  # (layer_id, 2)

            # Sampling skip_pattern from multinomial distribution.
            skip_pattern = np.random.multinomial(
                1, skip_prob_normalized.d[0],
                layer_id).nonzero()[1]  # (layer_id, 1)
            arc_seq.append(skip_pattern)
            skip = nn.Variable.from_numpy_array(skip_pattern)

            # compute skip penalty.
            # (layer_id, 2) broadcast occurs here too
            kl = F.mul2(skip_prob_normalized,
                        F.log(F.div2(skip_prob_normalized, skip_targets)))
            kl = F.sum(kl, keepdims=True)
            # get the mean value here in advance.
            kl = kl * (1.0 / (num_layers - 1))

            # accumulate kl divergence as skip penalty.
            skip_penaltys = F.add2(skip_penaltys, kl)

            # log policy for connection.
            log_prob = F.softmax_cross_entropy(
                logit, F.reshape(skip, shape=(skip.shape[0], 1)))
            log_probs = F.add2(log_probs, F.sum(log_prob, keepdims=True))

            entropy = F.sum(log_prob * F.exp(-log_prob), keepdims=True)
            # accumulate entropy as entropys.
            entropys = F.add2(entropys, entropy)

            skip = F.reshape(skip, (1, layer_id))

            inputs = F.affine(skip,
                              anchors).apply(need_grad=False)  # (1, lstm_size)
            inputs = F.mul_scalar(inputs, (1.0 / (1.0 + (np.sum(skip.d)))))

            # add new row for the next computation
            # (layer_id + 1, lstm_size)
            anchors = F.concatenate(anchors, next_h[-1], axis=0)
            # (layer_id + 1, lstm_size)
            anchors_w_1 = F.concatenate(anchors_w_1, adding_w_1, axis=0)

    return arc_seq, log_probs, entropys, skip_penaltys