Code Example #1
File: cnn_model_025.py Project: kzky/works
def attention(k, q, v, div_dim=True, softmax=True):
    v_shape = v.shape
    k = F.identity(k)
    q = F.identity(q)
    k = F.reshape(k, (k.shape[0], np.prod(k.shape[1:])))
    q = F.reshape(q, (q.shape[0], np.prod(q.shape[1:])))
    v = q  # F.reshape is inplace
    cf = F.affine(q, F.transpose(k, (1, 0)))
    if div_dim:
        dim = np.prod(v_shape[1:])
        cf /= np.sqrt(dim)
    h = cf
    if softmax: 
        h = F.softmax(h)
    h = F.affine(h, v)
    h = F.reshape(h, v_shape)
    return h
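A minimal usage sketch for the attention function above (hypothetical toy shapes, assuming numpy and nnabla are imported as in the snippet); the output keeps the shape of v:

import numpy as np
import nnabla as nn

k = nn.Variable.from_numpy_array(np.random.randn(2, 4, 8, 8).astype(np.float32))
q = nn.Variable.from_numpy_array(np.random.randn(2, 4, 8, 8).astype(np.float32))
v = nn.Variable.from_numpy_array(np.random.randn(2, 4, 8, 8).astype(np.float32))
h = attention(k, q, v)
h.forward()
print(h.shape)  # (2, 4, 8, 8): same shape as v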
Code Example #2
def upscale_four(inputs, scope='upscale_four'):
    """
    Mimic the tensorflow bilinear-upscaling for a fixed ratio of 4.
    """
    with nn.parameter_scope(scope):
        b, h, w, c = inputs.shape

        p_inputs = F.concatenate(
            inputs, inputs[:, -1:, :, :], axis=1)  # pad bottom
        p_inputs = F.concatenate(
            p_inputs, p_inputs[:, :, -1:, :], axis=2)  # pad right

        hi_res_bin = [
            [
                inputs,                   # top-left
                p_inputs[:, :-1, 1:, :]   # top-right
            ],
            [
                p_inputs[:, 1:, :-1, :],  # bottom-left
                p_inputs[:, 1:, 1:, :]    # bottom-right
            ]
        ]

        hi_res_array = []
        for hi in range(4):
            for wj in range(4):
                hi_res_array.append(
                    hi_res_bin[0][0] * (1.0 - 0.25 * hi) * (1.0 - 0.25 * wj)
                    + hi_res_bin[0][1] * (1.0 - 0.25 * hi) * (0.25 * wj)
                    + hi_res_bin[1][0] * (0.25 * hi) * (1.0 - 0.25 * wj)
                    + hi_res_bin[1][1] * (0.25 * hi) * (0.25 * wj)
                )

        hi_res = F.stack(*hi_res_array, axis=3)  # shape (b,h,w,16,c)
        hi_res_reshape = F.reshape(hi_res, (b, h, w, 4, 4, c))
        hi_res_reshape = F.transpose(hi_res_reshape, (0, 1, 3, 2, 4, 5))
        hi_res_reshape = F.reshape(hi_res_reshape, (b, h*4, w*4, c))

    return hi_res_reshape
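A minimal shape-check sketch for upscale_four above (assuming standard nnabla imports); the input is NHWC and both spatial dimensions are upscaled by a factor of 4 (the scope name here is arbitrary):

import numpy as np
import nnabla as nn

x = nn.Variable.from_numpy_array(np.random.rand(1, 8, 8, 3).astype(np.float32))  # NHWC
y = upscale_four(x, scope='upscale_four_demo')
y.forward()
print(y.shape)  # (1, 32, 32, 3)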
Code Example #3
File: rnn.py Project: Pandinosaurus/nnabla
def _rnn(x, h, w, b, nonlinearity, with_bias):
    """RNN cell.
    Args:
        x (:obj:`~nnabla.Variable`): Input data.
        h (:obj:`~nnabla.Variable`): Hidden state.
        w (:obj:`~nnabla.Variable`): Weight.
        b (:obj:`~nnabla.Variable`): Bias.
        nonlinearity (str): "tanh" or "relu".
        with_bias (bool): Include the bias or not.
    """
    hidden_size = h.shape[1]
    xh = F.concatenate(*(x, h), axis=1)
    b_ = None
    if with_bias:
        b_ = b
    h_t = F.affine(xh, F.transpose(w, (1, 0)), b_)
    if nonlinearity == 'tanh':
        h_t = F.tanh(h_t)
    elif nonlinearity == 'relu':
        h_t = F.relu(h_t)

    return h_t
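A minimal sketch driving the cell once (assuming the _rnn function above and standard nnabla imports); the weight is laid out as (hidden_size, input_size + hidden_size), which is why the cell transposes it before the affine:

import numpy as np
import nnabla as nn

batch, input_size, hidden_size = 4, 10, 20
x = nn.Variable.from_numpy_array(np.random.randn(batch, input_size).astype(np.float32))
h = nn.Variable.from_numpy_array(np.zeros((batch, hidden_size), dtype=np.float32))
w = nn.Variable.from_numpy_array(
    0.1 * np.random.randn(hidden_size, input_size + hidden_size).astype(np.float32))
b = nn.Variable.from_numpy_array(np.zeros(hidden_size, dtype=np.float32))
h_t = _rnn(x, h, w, b, nonlinearity='tanh', with_bias=True)
h_t.forward()
print(h_t.shape)  # (4, 20)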
Code Example #4
    def __call__(self, conv_in, h=None):

        v_stack_in = conv_in
        h_stack_in = conv_in

        features = []
        with nn.parameter_scope('ConditionalPixelCNN'):
            for i in range(self.num_layers):
                if i == 0:
                    kernel_shape = (7, 7)
                    mask_type = self.mask_type_A
                    residual = False
                else:
                    kernel_shape = (3, 3)
                    mask_type = self.mask_type_B
                    residual = True

                v_stack_gated, v_stack_conv = self.gated_conv(v_stack_in, kernel_shape, h, mask_type=mask_type, return_payload=True,
                                                              scope_name='vertical_stack_gated_'+str(i))
                h_stack_gated = self.gated_conv(h_stack_in, (1, kernel_shape[0]), h, mask_type=mask_type,
                                                payload=v_stack_conv, scope_name='horizontal_stack_gated_'+str(i))
                h_stack_conv = self.gated_conv(h_stack_gated, (1, 1), h, mask_type=mask_type, gated=False,
                                               scope_name='horizontal_stack_conv_'+str(i))
                if residual:
                    h_stack_conv += h_stack_in

                v_stack_in = v_stack_gated
                h_stack_in = h_stack_conv

            fc_1 = self.gated_conv(
                    h_stack_in, (1, 1), gated=False, scope_name='fc_1')
            fc_2 = PF.convolution(fc_1, self.out_channels,
                                  (1, 1), apply_w=self.mask_type_B, name='fc_2')

        fc_2 = F.transpose(fc_2, (0, 2, 3, 1))
        fc_2 = F.reshape(fc_2, (-1, fc_2.shape[-1]), inplace=True)

        return fc_2
Code Example #5
    def define_network(self):

        if self.use_inst:
            obj_onehot, bm = encode_inputs(self.ist_mask,
                                           self.obj_mask,
                                           n_ids=self.conf.n_class)

            mask = F.concatenate(obj_onehot, bm, axis=1)
        else:
            om = self.obj_mask
            if len(om.shape) == 3:
                om = F.reshape(om, om.shape + (1, ))
            obj_onehot = F.one_hot(om, shape=(self.conf.n_class, ))
            mask = F.transpose(obj_onehot, (0, 3, 1, 2))

        generator = SpadeGenerator(self.conf.g_ndf,
                                   image_shape=self.conf.image_shape)
        z = F.randn(shape=(self.conf.batch_size, self.conf.z_dim))
        fake = generator(z, mask)

        # Pixel intensities of fake are [-1, 1]. Rescale it to [0, 1]
        fake = (fake + 1) / 2

        return fake
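The label-map branch above (F.one_hot on the object mask followed by F.transpose to channel-first) can be exercised on its own; a minimal sketch with hypothetical sizes, assuming standard nnabla imports:

import numpy as np
import nnabla as nn
import nnabla.functions as F

n_class = 5
om = nn.Variable.from_numpy_array(
    np.random.randint(0, n_class, size=(2, 4, 4)).astype(np.float32))
om = F.reshape(om, om.shape + (1, ))           # (2, 4, 4, 1)
obj_onehot = F.one_hot(om, shape=(n_class, ))  # (2, 4, 4, 5)
mask = F.transpose(obj_onehot, (0, 3, 1, 2))   # (2, 5, 4, 4), channel-first
mask.forward()
print(mask.shape)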
Code Example #6
File: model.py Project: TengHu/nnabla-examples
def primary_capsule(h, factor_capsules=32, out_channels=8, kernel=9, stride=2, fix_parameters=False):
    '''
    Takes Conv1 output and produces PrimaryCapsules.

    PrimaryCapsules are computed by using a single Convolution layer.

    Args:
        h (nnabla.Variable): A shape of [B, C, H, W].
        factor_capsules (int): Multiplication factor of output capsules. The output capsules will be ``factor_capsules x out_H x out_W`` where ``out_H`` and ``out_W`` are height and width of the output of the ``(kernel, kernel)`` Convolution with ``stride``. E.g. ``out_H = (H - (kernel - 1)) / stride``.
        out_channels (int): Number of units in each capsule of the output.
        kernel (int): Kernel size of the Convolution.
        stride (int): Stride of the Convolution.
        fix_parameters (bool): Fix parameters (Set need_grad=False).

    Returns:
        nn.Variable: A shape [B, factor_capsules x H' x W', out_channels].

    '''
    h = PF.convolution(h, out_channels * factor_capsules, (kernel, kernel),
                       stride=(stride, stride), fix_parameters=fix_parameters)
    num_capsules = factor_capsules * h.shape[2] * h.shape[3]
    h = F.reshape(h, [h.shape[0], out_channels, num_capsules])
    h = F.transpose(h, (0, 2, 1))
    return squash(h)
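A minimal shape-check sketch for primary_capsule above. The squash nonlinearity lives elsewhere in the original model.py, so a simplified stand-in is defined here (hypothetical), together with standard nnabla imports:

import numpy as np
import nnabla as nn
import nnabla.functions as F

def squash(s, eps=1e-8):
    # simplified capsule squashing over the last axis (stand-in, not the original)
    sq_norm = F.sum(s ** 2, axis=len(s.shape) - 1, keepdims=True)
    scale = sq_norm / (1.0 + sq_norm) / ((sq_norm + eps) ** 0.5)
    return s * F.broadcast(scale, s.shape)

h = nn.Variable.from_numpy_array(
    np.random.randn(2, 256, 20, 20).astype(np.float32))  # Conv1-like feature map
caps = primary_capsule(h)  # defaults: factor 32, 8 channels, 9x9 kernel, stride 2
caps.forward()
print(caps.shape)  # (2, 1152, 8): out_H = out_W = (20 - 8) // 2 = 6 and 32 * 6 * 6 = 1152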
Code Example #7
    def backward_impl(self, inputs, outputs, prop_down, accum):
        # inputs: [inputs_fwd_graph] + [inputs_bwd_graph] or
        # [inputs_fwd_graph] + [outputs_fwd_graph] + [inputs_bwd_graph]

        # Args
        axes = self.forward_func.info.args["axes"]
        # Inputs
        x0 = inputs[0].data
        dy = inputs[1].data
        # Outputs
        dx0 = outputs[0].data
        # Grads of inputs
        g_x0 = inputs[0].grad
        g_dy = inputs[1].grad
        # Grads of outputs
        g_dx0 = outputs[0].grad

        # Computation
        if prop_down[1]:
            g_dy_ = F.transpose(g_dx0, axes)
            if accum[1]:
                g_dy += g_dy_
            else:
                g_dy.copy_from(g_dy_)
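This double-backward simply re-applies the transpose to the propagated gradient. The first-order behaviour it builds on can be checked directly; a minimal sketch, assuming standard nnabla imports:

import numpy as np
import nnabla as nn
import nnabla.functions as F

x = nn.Variable((2, 3, 4), need_grad=True)
x.d = np.random.randn(2, 3, 4)
y = F.transpose(x, (2, 0, 1))  # y.shape == (4, 2, 3)
y.forward()

x.grad.zero()
g = np.random.randn(4, 2, 3).astype(np.float32)
y.backward(grad=g)
# the upstream gradient flows back through the inverse permutation
print(np.allclose(x.g, g.transpose(1, 2, 0)))  # True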
Code Example #8
File: model.py Project: sony/nnabla-examples
    def call(self, char_inputs, mel_inputs=None):
        r"""Return mel-spectrograms (before and after the postnet), gate outputs, and attention alignments.

        Args:
            char_inputs (nn.Variable): Inputs containing indices of characters.
                This has a shape of (B, Tx).
            mel_inputs (nn.Variable, optional): Ground-truth mel-spectrograms fed to the
                decoder for teacher forcing. Defaults to None.

        Returns:
            nn.Variable: The synthetic mel-spectrograms of shape (B, T_y/r, n_mels*r).
            nn.Variable: The synthetic mel-spectrograms of shape (B, T_y/r, n_mels*r) after postnet.
            nn.Variable: The gate outputs.
            nn.Variable: The attention matrix of shape (B, Tx, Ty).
        """
        with nn.parameter_scope('encoder'):
            encoder_outputs = self.encoder(char_inputs)  # (Tx, B, C=512)

        with nn.parameter_scope('decoder'):
            encoder_outputs = F.transpose(encoder_outputs, (1, 0, 2))
            mel_outputs, gate_outputs, alignments = self.decoder(
                encoder_outputs, mel_inputs)

        with nn.parameter_scope('post_net'):
            mel_outputs_postnet = self.postnet(mel_outputs) + mel_outputs

        return mel_outputs, mel_outputs_postnet, gate_outputs, alignments
Code Example #9
def dot(a, b, out=None):
    '''
    A compatible operation with ``numpy.dot``.

    Note:
        Any operation between nnabla's Variable/NdArray and numpy array is not supported.

        If both arguments are 1-D, it is inner product of vectors.
        If both arguments are 2-D, it is matrix multiplication.
        If either a or b is 0-D (scalar), it is equivalent to multiply.
        If b is a 1-D array, it is a sum product over the last axis of a and b.
        If b is an M-D array (M>=2), it is a sum product over the last axis of a and the second-to-last axis of b.

    Args:
        a (Variable, NdArray or scalar): Left input array.
        b (Variable, NdArray or scalar): Right input array.
        out: Output argument. This must have the same shape, dtype, and type as the result that would be returned for F.dot(a,b).

    Returns:
        ~nnabla.Variable or ~nnabla.NdArray

    Examples:

    .. code-block:: python

        import numpy as np
        import nnabla as nn
        import nnabla.functions as F

        # 2-D matrix * 2-D matrix
        arr1 = np.arange(5*6).reshape(5, 6)
        arr2 = np.arange(6*8).reshape(6, 8)
        nd1 = nn.NdArray.from_numpy_array(arr1)
        nd2 = nn.NdArray.from_numpy_array(arr2)
        ans1 = F.dot(nd1, nd2)
        print(ans1.shape)
        #(5, 8)

        var1 = nn.Variable.from_numpy_array(arr1)
        var2 = nn.Variable.from_numpy_array(arr2)
        ans2 = F.dot(var1, var2)
        ans2.forward()
        print(ans2.shape)
        #(5, 8)

        out1 = nn.NdArray((5, 8))
        out1.cast(np.float32)
        F.dot(nd1, nd2, out1)
        print(out1.shape)
        #(5, 8)

        out2 = nn.Variable((5, 8))
        out2.data.cast(np.float32)
        F.dot(var1, var2, out2)
        out2.forward()
        print(out2.shape)
        #(5, 8)

        # N-D matrix * M-D matrix (M>=2)
        arr1 = np.arange(5*6*7*8).reshape(5, 6, 7, 8)
        arr2 = np.arange(2*3*8*6).reshape(2, 3, 8, 6)
        nd1 = nn.NdArray.from_numpy_array(arr1)
        nd2 = nn.NdArray.from_numpy_array(arr2)
        ans1 = F.dot(nd1, nd2)
        print(ans1.shape)
        #(5, 6, 7, 2, 3, 6)

        var1 = nn.Variable.from_numpy_array(arr1)
        var2 = nn.Variable.from_numpy_array(arr2)
        ans2 = F.dot(var1, var2)
        ans2.forward()
        print(ans2.shape)
        #(5, 6, 7, 2, 3, 6)

        out1 = nn.NdArray((5, 6, 7, 2, 3, 6))
        out1.cast(np.float32)
        F.dot(nd1, nd2, out1)
        print(out1.shape)
        #(5, 6, 7, 2, 3, 6)

        out2 = nn.Variable((5, 6, 7, 2, 3, 6))
        out2.data.cast(np.float32)
        F.dot(var1, var2, out2)
        out2.forward()
        print(out2.shape)
        #(5, 6, 7, 2, 3, 6)
    '''
    import nnabla as nn
    import nnabla.functions as F

    def _chk(x, mark=0):
        if isinstance(x, nn.NdArray):
            return x.data, 1
        elif isinstance(x, nn.Variable):
            return x.d, 1
        else:
            return x, mark

    m, mark1 = _chk(a)
    n, mark2 = _chk(b)

    if mark1 and mark2:
        if a.ndim == 1 and b.ndim == 1:
            result = F.sum(a * b)
        elif a.ndim == 2 and b.ndim == 2:
            result = F.affine(a, b)
        elif a.ndim == 0 or b.ndim == 0:
            if a.ndim == 0:
                result = F.mul_scalar(b, m)
                if isinstance(a, nn.NdArray) and isinstance(b, nn.Variable):
                    result.forward()
                    result = result.data
            else:
                result = F.mul_scalar(a, n)
                if isinstance(a, nn.Variable) and isinstance(b, nn.NdArray):
                    result.forward()
                    result = result.data
        elif b.ndim == 1:
            h = F.affine(a, F.reshape(b, (-1, 1)), base_axis=a.ndim - 1)
            result = F.reshape(h, h.shape[:-1])
        elif b.ndim >= 2:
            index = [*range(0, b.ndim)]
            index.insert(0, index.pop(b.ndim - 2))
            b = F.transpose(b, index)
            h = F.affine(a, b, base_axis=a.ndim - 1)
            result = h
    else:
        result = np.dot(a, b)

    if out is not None:
        out_, _ = _chk(out)
        result_, _ = _chk(result)
        if (type(out) == type(result) and out_.shape == result_.shape
                and out_.dtype == result_.dtype):
            if isinstance(out, nn.NdArray):
                out.cast(result.data.dtype)[...] = result.data
            elif isinstance(out, nn.Variable):
                out.rewire_on(result)
            else:
                out = result
        else:
            raise ValueError(
                f"Output argument must have the same shape, type and dtype as the result that would be "
                f"returned for F.dot(a,b).")
    else:
        return result
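Besides the docstring examples, the 1-D branch (inner product) of dot can be exercised like this; a minimal sketch assuming the function above:

import numpy as np
import nnabla as nn

v1 = nn.NdArray.from_numpy_array(np.array([1., 2., 3.], dtype=np.float32))
v2 = nn.NdArray.from_numpy_array(np.array([4., 5., 6.], dtype=np.float32))
ans = dot(v1, v2)  # computed as F.sum(v1 * v2)
print(ans.data)    # 32.0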
Code Example #10
def train():
    rng = np.random.RandomState(803)

    conf = get_config()

    comm = init_nnabla(conf)

    # create data iterator
    if conf.dataset == "cityscapes":
        data_list = get_cityscape_datalist(conf.cityscapes,
                                           save_file=comm.rank == 0)
        n_class = conf.cityscapes.n_label_ids
        use_inst = True

        data_iter = create_cityscapes_iterator(conf.batch_size,
                                               data_list,
                                               comm=comm,
                                               image_shape=conf.image_shape,
                                               rng=rng,
                                               flip=conf.use_flip)

    elif conf.dataset == "ade20k":
        data_list = get_ade20k_datalist(conf.ade20k, save_file=comm.rank == 0)
        n_class = conf.ade20k.n_label_ids + 1  # class id + unknown
        use_inst = False

        load_shape = tuple(
            x + 30
            for x in conf.image_shape) if conf.use_crop else conf.image_shape
        data_iter = create_ade20k_iterator(conf.batch_size,
                                           data_list,
                                           comm=comm,
                                           load_shape=load_shape,
                                           crop_shape=conf.image_shape,
                                           rng=rng,
                                           flip=conf.use_flip)

    else:
        raise NotImplementedError(
            "Currently dataset {} is not supported.".format(conf.dataset))

    real = nn.Variable(shape=(conf.batch_size, 3) + conf.image_shape)
    obj_mask = nn.Variable(shape=(conf.batch_size, ) + conf.image_shape)

    if use_inst:
        ist_mask = nn.Variable(shape=(conf.batch_size, ) + conf.image_shape)
        obj_onehot, bm = encode_inputs(ist_mask, obj_mask, n_ids=n_class)
        mask = F.concatenate(obj_onehot, bm, axis=1)
    else:
        om = obj_mask
        if len(om.shape) == 3:
            om = F.reshape(om, om.shape + (1, ))
        obj_onehot = F.one_hot(om, shape=(n_class, ))
        mask = F.transpose(obj_onehot, (0, 3, 1, 2))

    # generator
    generator = SpadeGenerator(conf.g_ndf, image_shape=conf.image_shape)
    z = F.randn(shape=(conf.batch_size, conf.z_dim))
    fake = generator(z, mask)

    # unlinking
    ul_mask, ul_fake = get_unlinked_all(mask, fake)

    # discriminator
    discriminator = PatchGAN(n_scales=conf.d_n_scales)
    d_input_real = F.concatenate(real, ul_mask, axis=1)
    d_input_fake = F.concatenate(ul_fake, ul_mask, axis=1)
    d_real_out, d_real_feats = discriminator(d_input_real)
    d_fake_out, d_fake_feats = discriminator(d_input_fake)

    g_gan, g_feat, d_real, d_fake = discriminator.get_loss(
        d_real_out,
        d_real_feats,
        d_fake_out,
        d_fake_feats,
        use_fm=conf.use_fm,
        fm_lambda=conf.lambda_fm,
        gan_loss_type=conf.gan_loss_type)

    def _rescale(x):
        return rescale_values(x,
                              input_min=-1,
                              input_max=1,
                              output_min=0,
                              output_max=255)

    g_vgg = vgg16_perceptual_loss(_rescale(ul_fake),
                                  _rescale(real)) * conf.lambda_vgg

    set_persistent_all(fake, mask, g_gan, g_feat, d_real, d_fake, g_vgg)

    # loss
    g_loss = g_gan + g_feat + g_vgg
    d_loss = (d_real + d_fake) / 2

    # load params
    if conf.load_params is not None:
        print("load parameters from {}".format(conf.load_params))
        nn.load_parameters(conf.load_params)

    # Setup Solvers
    g_solver = S.Adam(beta1=0.)
    g_solver.set_parameters(get_params_startswith("spade_generator"))

    d_solver = S.Adam(beta1=0.)
    d_solver.set_parameters(get_params_startswith("discriminator"))

    # lr scheduler
    g_lrs = LinearDecayScheduler(start_lr=conf.g_lr,
                                 end_lr=0.,
                                 start_iter=100,
                                 end_iter=200)
    d_lrs = LinearDecayScheduler(start_lr=conf.d_lr,
                                 end_lr=0.,
                                 start_iter=100,
                                 end_iter=200)

    ipe = get_iteration_per_epoch(data_iter._size,
                                  conf.batch_size,
                                  round="ceil")

    if not conf.show_interval:
        conf.show_interval = ipe
    if not conf.save_interval:
        conf.save_interval = ipe
    if not conf.niter:
        conf.niter = 200 * ipe

    # Setup Reporter
    losses = {
        "g_gan": g_gan,
        "g_feat": g_feat,
        "g_vgg": g_vgg,
        "d_real": d_real,
        "d_fake": d_fake
    }
    reporter = Reporter(comm,
                        losses,
                        conf.save_path,
                        nimage_per_epoch=min(conf.batch_size, 5),
                        show_interval=10)
    progress_iterator = trange(conf.niter, disable=comm.rank > 0)
    reporter.start(progress_iterator)

    colorizer = Colorize(n_class)

    # output all config and dump to file
    if comm.rank == 0:
        conf.dump_to_stdout()
        write_yaml(os.path.join(conf.save_path, "config.yaml"), conf)

    epoch = 0
    for itr in progress_iterator:
        if itr % ipe == 0:
            g_lr = g_lrs(epoch)
            d_lr = d_lrs(epoch)
            g_solver.set_learning_rate(g_lr)
            d_solver.set_learning_rate(d_lr)
            if comm.rank == 0:
                print(
                    "\n[epoch {}] update lr to ... g_lr: {}, d_lr: {}".format(
                        epoch, g_lr, d_lr))

            epoch += 1

        if conf.dataset == "cityscapes":
            im, ist, obj = data_iter.next()
            ist_mask.d = ist
        elif conf.dataset == "ade20k":
            im, obj = data_iter.next()
        else:
            raise NotImplementedError()

        real.d = im
        obj_mask.d = obj

        # forward the generator to create the fake image
        fake.forward()

        # update discriminator
        d_solver.zero_grad()
        d_loss.forward()
        d_loss.backward(clear_buffer=True)
        comm.all_reduced_solver_update(d_solver)

        # update generator
        ul_fake.grad.zero()
        g_solver.zero_grad()
        g_loss.forward()
        g_loss.backward(clear_buffer=True)

        # backward generator
        fake.backward(grad=None, clear_buffer=True)
        comm.all_reduced_solver_update(g_solver)

        # report iteration progress
        reporter()

        # report epoch progress
        show_epoch = itr // conf.show_interval
        if (itr % conf.show_interval) == 0:
            show_images = {
                "RealImages": real.data.get_data("r").transpose((0, 2, 3, 1)),
                "ObjectMask": colorizer(obj).astype(np.uint8),
                "GeneratedImage": fake.data.get_data("r").transpose(
                    (0, 2, 3, 1))
            }

            reporter.step(show_epoch, show_images)

        if (itr % conf.save_interval) == 0 and comm.rank == 0:
            nn.save_parameters(
                os.path.join(conf.save_path,
                             'param_{:03d}.h5'.format(show_epoch)))

    if comm.rank == 0:
        nn.save_parameters(os.path.join(conf.save_path, 'param_final.h5'))
Code Example #11
File: spectral_norm.py Project: Pandinosaurus/nnabla
def _spectral_norm_backward(dw_sn, w, u, dim=0, itr=1, eps=1e-12):
    # Forward recomputation

    w_shape = w.shape
    # Transpose if the output dimension is not the most-left dimension.
    if dim != 0:
        dims_transpose = [dim] + [i for i in range(len(w_shape)) if i != dim]
        w = F.transpose(w, dims_transpose)
        w_shape = w.shape
    d0 = w.shape[0]            # Out
    d1 = np.prod(w.shape[1:])  # In
    w = F.reshape(w, [d0, d1])
    u = F.reshape(u, [1, d0])
    # Power method
    for _ in range(itr):
        # v
        v = F.affine(u, w)
        v = v / ((F.sum(v ** 2.0, keepdims=True) + eps) ** 0.5)
        v = F.reshape(v, [d1, 1])
        # u
        u = F.affine(w, v)
        u = u / ((F.sum(u ** 2.0, keepdims=True) + eps) ** 0.5)
        u = F.reshape(u, [1, d0])
    # No grad
    u = no_grad(u)
    v = no_grad(v)
    # Spectral normalization
    wv = F.affine(w, v)
    sigma = F.affine(u, wv)
    w_sn = w / sigma
    # The following process is not necessary for the gradient calculation
    # w_sn = F.reshape(w_sn, w_shape)
    # # Transpose again if the output dimension is not the most-left dimension.
    # if dim != 0:
    #     dims_transpose = [i for i in range(1, dim + 1)] \
    #                      + [0] + [i for i in range(dim + 1, len(w_shape))]
    #     w_sn = F.transpose(w_sn, dims_transpose)

    # Backward

    # Backward for post-transpose
    if dim != 0:
        dims_transpose = [dim] + [i for i in range(len(w_shape)) if i != dim]
        dw_sn = F.transpose(dw_sn, dims_transpose)
    dw_sn = dw_sn.reshape(w.shape)

    # Backward for spectral norm
    # Sum for broadcast backward
    S = sum_for_arithmetics(dw_sn * w_sn, sigma)
    # Add batch axis
    S = S.reshape((1,) + S.shape)
    u = u.reshape((1,) + u.shape)
    v = v.reshape((1,) + v.shape)
    m = F.batch_matmul(u, S, transpose_a=True)
    m = F.batch_matmul(m, v, transpose_b=True)
    # Remove batch axis
    m = m.reshape((m.shape[1], m.shape[2]))
    dw = (dw_sn - m) / sigma

    # Backward for pre-transpose
    dw = dw.reshape(w_shape)
    if dim != 0:
        dims_transpose = [i for i in range(1, dim + 1)] \
                         + [0] + [i for i in range(dim + 1, len(w_shape))]
        dw = F.transpose(dw, dims_transpose)

    return dw, None
Code Example #12
File: ops.py Project: sony/nnabla-examples
def cbhg(inputs, K, projections, depth, is_training, scope):
    r"""Returns the 1D Convolution Bank + Highway network + bidirectional
    GRU (CBHG) module.

    Args:
        inputs (nn.Variable): NNabla Variable of shape (B, C, T).
        K (int): Maximum kernel size.
        projections (list of int): A list of channels.
        depth (int): A depth. This should be an even number.
        is_training (bool): Whether training mode is activated.
        scope (str): The parameter scope name.

    Returns:
        nn.Variable: Output variable.
    """

    with nn.parameter_scope(scope):
        # Convolution bank: concatenate channels from all 1D convolutions
        with nn.parameter_scope('conv_bank'):
            conv = partial(conv1d,
                           inputs,
                           channels=128,
                           activation=F.relu,
                           is_training=is_training)
            conv_outputs = [
                conv(kernel_size=k, scope=f'conv1d_{k}')
                for k in range(1, K + 1)
            ]
            conv_outputs = F.concatenate(*conv_outputs, axis=1)

        # pad the time axis so that max_pooling keeps the sequence length
        x = F.pad(conv_outputs, (0, ) * 5 + (1, ), mode='constant')

        # Maxpooling: reshape is needed because nnabla does not support 1D pooling
        maxpool_output = F.max_pooling(x.reshape(x.shape + (1, )),
                                       kernel=(2, 1),
                                       stride=(1,
                                               1)).reshape(conv_outputs.shape)

        # Two projection layers:
        proj1_output = conv1d(maxpool_output,
                              kernel_size=3,
                              channels=projections[0],
                              activation=F.relu,
                              is_training=is_training,
                              scope='proj_1')
        proj2_output = conv1d(proj1_output,
                              kernel_size=3,
                              channels=projections[1],
                              activation=None,
                              is_training=is_training,
                              scope='proj_2')

        # Residual connection:
        highway_input = proj2_output + inputs

        assert depth % 2 == 0
        half_depth = depth // 2

        with nn.parameter_scope('highwaynet'):
            # transposing to shape (B, T, C)
            highway_input = F.transpose(highway_input, (0, 2, 1))

            # Handle dimensionality mismatch:
            if highway_input.shape[2] != half_depth:
                highway_input = PF.affine(highway_input,
                                          half_depth,
                                          base_axis=2,
                                          name='adjust_dim')

            # 4-layer HighwayNet:
            for i in range(4):
                highway_input = highwaynet(highway_input,
                                           half_depth,
                                           scope=f'highway_{i+1}')

        with nn.parameter_scope('rnn_net'):
            # transpose to shape (T, B, C)
            rnn_input = F.transpose(highway_input, (1, 0, 2))
            outputs, _ = PF.gru(rnn_input,
                                F.constant(shape=(2, 2, rnn_input.shape[1],
                                                  half_depth)),
                                training=is_training,
                                bidirectional=True)  # (T, B, C)

    return outputs
Code Example #13
    def __call__(self, x, test=False):

        fft_real, fft_imag = STFT(x, n_fft=self.n_fft, n_hop=self.n_hop)
        x_theta = F.atan2(fft_imag, fft_real)

        x = Spectrogram(fft_real,
                        fft_imag,
                        power=self.power,
                        mono=(self.nb_channels == 1))

        nb_frames, nb_samples, nb_channels, nb_bins = x.shape

        mix_spec = F.identity(x)
        x = x[..., :self.nb_bins]

        # clone
        x_bass = F.identity(x)
        x_drums = F.identity(x)
        x_vocals = F.identity(x)
        x_other = F.identity(x)

        # shift and scale input to mean=0 std=1 (across all bins)
        x_bass += F.reshape(self.input_mean_bass,
                            shape=(1, 1, 1, self.nb_bins),
                            inplace=False)
        x_drums += F.reshape(self.input_mean_drums,
                             shape=(1, 1, 1, self.nb_bins),
                             inplace=False)
        x_vocals += F.reshape(self.input_mean_vocals,
                              shape=(1, 1, 1, self.nb_bins),
                              inplace=False)
        x_other += F.reshape(self.input_mean_other,
                             shape=(1, 1, 1, self.nb_bins),
                             inplace=False)

        x_bass *= F.reshape(self.input_scale_bass,
                            shape=(1, 1, 1, self.nb_bins),
                            inplace=False)
        x_drums *= F.reshape(self.input_scale_drums,
                             shape=(1, 1, 1, self.nb_bins),
                             inplace=False)
        x_vocals *= F.reshape(self.input_scale_vocals,
                              shape=(1, 1, 1, self.nb_bins),
                              inplace=False)
        x_other *= F.reshape(self.input_scale_other,
                             shape=(1, 1, 1, self.nb_bins),
                             inplace=False)

        # encode and normalize every instance in a batch
        x_bass = self.fc_bn(x_bass,
                            self.hidden_size,
                            "fc1_bass",
                            test,
                            activation='tanh')
        x_drums = self.fc_bn(x_drums,
                             self.hidden_size,
                             "fc1_drums",
                             test,
                             activation='tanh')
        x_vocals = self.fc_bn(x_vocals,
                              self.hidden_size,
                              "fc1_vocals",
                              test,
                              activation='tanh')
        x_other = self.fc_bn(x_other,
                             self.hidden_size,
                             "fc1_other",
                             test,
                             activation='tanh')

        # Average the sources
        cross_1 = (x_bass + x_drums + x_vocals + x_other) / 4.0

        # apply 3-layers of stacked LSTM
        lstm_out_bass = self.lstm(cross_1, nb_samples, "lstm_bass", test)
        lstm_out_drums = self.lstm(cross_1, nb_samples, "lstm_drums", test)
        lstm_out_vocals = self.lstm(cross_1, nb_samples, "lstm_vocals", test)
        lstm_out_other = self.lstm(cross_1, nb_samples, "lstm_other", test)

        # lstm skip connection
        x_bass = F.concatenate(x_bass, lstm_out_bass)
        x_drums = F.concatenate(x_drums, lstm_out_drums)
        x_vocals = F.concatenate(x_vocals, lstm_out_vocals)
        x_other = F.concatenate(x_other, lstm_out_other)

        cross_2 = (x_bass + x_drums + x_vocals + x_other) / 4.0

        # first dense stage + batch norm
        x_bass = self.fc_bn(cross_2,
                            self.hidden_size,
                            "fc2_bass",
                            test,
                            activation='relu')
        x_drums = self.fc_bn(cross_2,
                             self.hidden_size,
                             "fc2_drums",
                             test,
                             activation='relu')
        x_vocals = self.fc_bn(cross_2,
                              self.hidden_size,
                              "fc2_vocals",
                              test,
                              activation='relu')
        x_other = self.fc_bn(cross_2,
                             self.hidden_size,
                             "fc2_other",
                             test,
                             activation='relu')

        # second dense stage + batch norm
        x_bass = self.fc_bn(x_bass, nb_channels * nb_bins, "fc3_bass", test)
        x_drums = self.fc_bn(x_drums, nb_channels * nb_bins, "fc3_drums", test)
        x_vocals = self.fc_bn(x_vocals, nb_channels * nb_bins, "fc3_vocals",
                              test)
        x_other = self.fc_bn(x_other, nb_channels * nb_bins, "fc3_other", test)

        # reshape back to original dim
        x_bass = F.reshape(
            x_bass, (nb_frames, nb_samples, nb_channels, self.nb_output_bins))
        x_drums = F.reshape(
            x_drums, (nb_frames, nb_samples, nb_channels, self.nb_output_bins))
        x_vocals = F.reshape(
            x_vocals,
            (nb_frames, nb_samples, nb_channels, self.nb_output_bins))
        x_other = F.reshape(
            x_other, (nb_frames, nb_samples, nb_channels, self.nb_output_bins))

        # apply output scaling
        x_bass *= F.reshape(self.output_scale_bass,
                            shape=(1, 1, 1, self.nb_output_bins),
                            inplace=False)
        x_drums *= F.reshape(self.output_scale_drums,
                             shape=(1, 1, 1, self.nb_output_bins),
                             inplace=False)
        x_vocals *= F.reshape(self.output_scale_vocals,
                              shape=(1, 1, 1, self.nb_output_bins),
                              inplace=False)
        x_other *= F.reshape(self.output_scale_other,
                             shape=(1, 1, 1, self.nb_output_bins),
                             inplace=False)

        x_bass += F.reshape(self.output_mean_bass,
                            shape=(1, 1, 1, self.nb_output_bins),
                            inplace=False)
        x_drums += F.reshape(self.output_mean_drums,
                             shape=(1, 1, 1, self.nb_output_bins),
                             inplace=False)
        x_vocals += F.reshape(self.output_mean_vocals,
                              shape=(1, 1, 1, self.nb_output_bins),
                              inplace=False)
        x_other += F.reshape(self.output_mean_other,
                             shape=(1, 1, 1, self.nb_output_bins),
                             inplace=False)

        # since our output is non-negative, we can apply RELU
        mask_bass = F.relu(x_bass)
        mask_drums = F.relu(x_drums)
        mask_vocals = F.relu(x_vocals)
        mask_other = F.relu(x_other)

        # (Frames, Bsize, Channels, Fbins)
        x_bass = mask_bass * mix_spec
        x_drums = mask_drums * mix_spec
        x_vocals = mask_vocals * mix_spec
        x_other = mask_other * mix_spec

        if not self.is_predict:
            tmp = F.stack(*[x_bass, x_drums, x_vocals, x_other], axis=0)
            # (4(sources), Frames, Bsize(16), 2(channels), Fbins) ==> (4, Bsize, Channels, Fbins, Frames)
            tmp = F.transpose(tmp, (0, 2, 3, 4, 1))
            pred_r, pred_i = [], []
            for i in range(tmp.shape[0]):
                pred_r.append(tmp[i] * F.cos(x_theta))
                pred_i.append(tmp[i] * F.sin(x_theta))
            pred_r = F.stack(*pred_r, axis=0)
            pred_i = F.stack(*pred_i, axis=0)
            pred_r = F.reshape(pred_r,
                               (4 * nb_samples * nb_channels, 2049, nb_frames))
            pred_i = F.reshape(pred_i,
                               (4 * nb_samples * nb_channels, 2049, nb_frames))
            pred = istft(pred_r,
                         pred_i,
                         self.n_fft,
                         self.n_hop,
                         self.n_fft,
                         window_type='hanning',
                         center=True)
            pred = F.reshape(pred, (4, nb_samples, nb_channels, -1))

        else:
            pred = None

        return mix_spec, F.concatenate(mask_bass,
                                       mask_drums,
                                       mask_vocals,
                                       mask_other,
                                       axis=2), pred
Code Example #14
    def infer(self, mels, sigma=0.9):
        r"""Returns the generated audio.

        Args:
            mels (nn.Variable): Inputs containing mel-spectrograms of shape (B, n_mels, Ty).
                Defaults to None. If None, the mel-spectrograms are inferred from data.
            sigma (float, optional): Sigma used to infer audio. Defaults to 0.9.

        Returns:
            nn.Variable: A synthetic audio.
        """

        hp = self.hparams
        with nn.parameter_scope('', self.parameter_scope):

            #  Upsample spectrogram to size of audio
            with nn.parameter_scope('upsample'):
                with nn.parameter_scope('deconv'):
                    mels = PF.deconvolution(mels,
                                            hp.n_mels,
                                            kernel=(1024, ),
                                            stride=(256, ))
                # cutout conv artifacts
                mels = mels[..., :-(1024 - 256)]  # kernel - stride

                # transforming to correct shape
                mels = F.reshape(mels,
                                 mels.shape[:2] + (-1, hp.n_samples_per_group))
                mels = F.transpose(mels, (0, 2, 1, 3))
                mels = F.reshape(mels, mels.shape[:2] + (-1, ))
                # (B, n_mels * n_groups, L/n_groups)
                mels = F.transpose(mels, (0, 2, 1))

            wave = F.randn(shape=(mels.shape[0], self.n_remaining_channels,
                                  mels.shape[2])) * sigma

            for k in reversed(range(hp.n_flows)):
                n_half = wave.shape[1] // 2
                audio_0 = wave[:, :n_half, :]
                audio_1 = wave[:, n_half:, :]

                with nn.parameter_scope(f'wn_{k}'):
                    output = getattr(self, f'WN_{k}')(audio_0, mels)
                    s = output[:, n_half:, :]
                    b = output[:, :n_half, :]
                    audio_1 = (audio_1 - b) / F.exp(s)
                    wave = F.concatenate(audio_0, audio_1, axis=1)

                wave = invertible_conv(wave,
                                       reverse=True,
                                       rng=self.rng,
                                       scope=f'inv_{k}')

                if k % hp.n_early_every == 0 and k > 0:
                    z = F.randn(shape=(mels.shape[0], hp.n_early_size,
                                       mels.shape[2]))
                    wave = F.concatenate(sigma * z, wave, axis=1)

            wave = F.transpose(wave, (0, 2, 1))
            wave = F.reshape(wave, (wave.shape[0], -1))

        return wave
Code Example #15
def embed_inverse(embed, n_inputs, n_features, base_axis=1):
    W = nn.parameter.get_parameter_or_create("embed/W", [n_inputs, n_features])
    W = F.transpose(W, axes=[1, 0])
    return F.affine(embed, W, base_axis=base_axis)
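embed_inverse reuses the weight created by PF.embed ("embed/W") and projects embedded vectors back to vocabulary logits. A minimal sketch with hypothetical sizes, assuming standard nnabla imports:

import numpy as np
import nnabla as nn
import nnabla.parametric_functions as PF

n_inputs, n_features = 100, 16  # hypothetical vocabulary and embedding sizes
ids = nn.Variable.from_numpy_array(np.array([[1, 5, 7]]))
emb = PF.embed(ids, n_inputs, n_features)  # (1, 3, 16); creates the "embed/W" parameter
logits = embed_inverse(emb, n_inputs, n_features, base_axis=2)
logits.forward()
print(logits.shape)  # (1, 3, 100)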
Code Example #16
def global_attention(query: nn.Variable,
                     memory: nn.Variable,
                     mask: Optional[nn.Variable] = None,
                     score: str = 'general',
                     fix_parameters: bool = False) -> nn.Variable:
    '''
    A global attention layer
    Args:
        query (nnabla.Variable): A shape of [batch_size, length_query, embedding_size]
        memory (nnabla.Variable): A shape of [batch_size, length_memory, embedding_size]
        mask (nnabla.Variable): A shape of [batch_size, length_query, length_memory]
        score (str): A kind of score functions for calculating attention weights.
                     'general', 'dot' or 'concat'.
                     see [Effective Approaches to Attention-based Neural Machine Translation]
                         (http://aclweb.org/anthology/D15-1166)
        fix_parameters (bool): Fix parameters (Set need_grad=False).
    Returns:
        nn.Variable: A shape [batch_size, length_query, embedding_size].
    '''
    batch_size, length_query, embedding_size = query.shape
    _, length_memory, _ = memory.shape
    q = query
    # -> (batch_size, length_query, embedding_size)
    k = memory
    # -> (batch_size, length_memory, embedding_size)
    v = memory
    # -> (batch_size, length_memory, embedding_size)
    if score == 'dot':
        logit = F.batch_matmul(q, k, transpose_b=True)
        # -> (batch_size, length_query, length_memory)
    elif score == 'general':
        with nn.parameter_scope('Wa'):
            wa = time_distributed(PF.affine)(q,
                                             embedding_size,
                                             with_bias=False)
            # -> (batch_size, length_query, embedding_size)
        logit = F.batch_matmul(wa, k, transpose_b=True)
        # -> (batch_size, length_query, length_memory)
    elif score == 'concat':
        a_list = []
        for _q in F.split(q, axis=1):
            _q = F.reshape(_q, shape=(batch_size, 1, embedding_size))
            _q = F.broadcast(_q,
                             shape=(batch_size, length_memory, embedding_size))
            concat = F.concatenate(_q, k, axis=2)
            # -> (batch_size, length_memory, embedding_size * 2)
            with nn.parameter_scope('Wa'):
                a = time_distributed(PF.affine)(concat, 1, with_bias=False)
                # -> (batch_size, length_memory, 1)
                a_list.append(a)

        logit = F.concatenate(*a_list, axis=2)
        # -> (batch_size, length_memory, length_query)
        logit = F.transpose(logit, axes=(0, 2, 1))
        # -> (batch_size, length_query, length_memory)

    # get_attention_logit_mask -> (batch_size, length_query, length_memory)
    if mask is not None:
        logit += get_attention_logit_mask(mask)

    attention_weights = F.softmax(logit, axis=2)
    # -> (batch_size, length_query, length_memory)

    attention_output = F.batch_matmul(attention_weights, v)
    # -> (batch_size, length_query, embedding_size)

    return attention_output
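A minimal sketch for global_attention using the 'dot' score (so no helpers beyond the function above are needed), with hypothetical sizes and standard nnabla imports:

import numpy as np
import nnabla as nn

query = nn.Variable.from_numpy_array(
    np.random.randn(2, 5, 8).astype(np.float32))   # (batch, length_query, embedding)
memory = nn.Variable.from_numpy_array(
    np.random.randn(2, 7, 8).astype(np.float32))   # (batch, length_memory, embedding)
out = global_attention(query, memory, score='dot')
out.forward()
print(out.shape)  # (2, 5, 8)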
Code Example #17
File: model.py Project: sony/nnabla-examples
    def call(self, memory, inputs=None):
        r"""Return mel-spectrogram and attention matrix.

        Args:
            memory (nn.Variable): A 3D tensor of shape (T, B, C).
            inputs (nn.Variable, optional): A 3D tensor with shape of
                [B, T/r, n_mels(*r)]. Shifted log mel-spectrogram of sound files.
                Defaults to None.

        Returns:
            nn.Variable: The synthetic mel-spectrograms of shape
                (B, Ty/r, r*n_mels).
            nn.Variable: The attention matrix of shape
                (B, Tx, Ty).

        References:
            - https://github.com/Kyubyong/tacotron/
        """
        hp = self._hparams
        bz, mel_shape = hp.batch_size, hp.n_mels * hp.r
        encoder_dim = hp.encoder_embedding_dim

        # initialize input tensor
        input = F.constant(shape=(bz, 1, mel_shape))

        # initialize hidden states
        context = F.constant(shape=(bz, 1, hp.attention_dim))
        hidden = F.constant(shape=(1, 1, bz, encoder_dim))
        h_gru = [
            F.constant(shape=(1, 1, bz, encoder_dim)),
            F.constant(shape=(1, 1, bz, encoder_dim))
        ]

        outputs, attends = [], []

        for i in range(hp.n_frames):
            if i > 0:
                input = (outputs[-1] if inputs is None else inputs[:,
                                                                   i - 1:i, :])

            # feed a prenet to the input
            input = prenet(input,
                           layer_sizes=hp.prenet_channels,
                           is_training=self.training,
                           scope='prenet_decoder')  # (bz, 1, C)

            # concat the input and context vector
            input = F.concatenate(input, context)  # (bz, 1, 384)

            with nn.parameter_scope('rnn_attention'):
                # calculate the output
                output, hidden = PF.gru(
                    input.reshape((1, bz, -1)),
                    hidden,
                    training=self.training,
                    bidirectional=False)  # (1, bz, 256), (1, 1, bz, 256)

            # compute the context and attention vectors
            context, attend = Bahdanau_attention(
                F.transpose(hidden[0], (1, 0, 2)),
                memory,
                out_features=hp.attention_dim,
                scope='Bahdanau_attention')  # (bz, 1, 256), (bz, 1, T)

            with nn.parameter_scope('rnn_decoder'):
                # concat RNN output and attention context vector
                with nn.parameter_scope('project_to_decoder'):
                    output = F.concatenate(output,
                                           F.transpose(context, (1, 0, 2)),
                                           axis=2)
                    output = PF.affine(output, encoder_dim,
                                       base_axis=2)  # (1, bz, 256)

                # decoder RNN with residual connection
                for j in range(2):
                    with nn.parameter_scope(f'gru_resisidual_{j}'):
                        out, h_gru[j] = PF.gru(output,
                                               h_gru[j],
                                               training=self.training,
                                               bidirectional=False)
                        output += out  # (1, bz, 256)

                # projector to mels
                with nn.parameter_scope('project_to_mel'):
                    output = F.transpose(output, (1, 0, 2))
                    # (bz, 1, n_mels*r)
                    output = PF.affine(output, mel_shape, base_axis=2)

            outputs.append(output)
            attends.append(attend)

        outputs = F.concatenate(*outputs, axis=1)  # (B, T2, C2)
        attends = F.concatenate(*attends, axis=1)  # (B, T2, T1)

        return outputs, attends
Code Example #18
File: train.py Project: wpfhtl/nnabla-examples
def train():
    args = get_args()

    # Set context.
    from nnabla.ext_utils import get_extension_context
    logger.info("Running in {}:{}".format(args.context, args.type_config))
    ctx = get_extension_context(args.context,
                                device_id=args.device_id,
                                type_config=args.type_config)
    nn.set_default_context(ctx)

    data_iterator = data_iterator_librispeech(args.batch_size, args.data_dir)
    _data_source = data_iterator._data_source  # dirty hack...

    # model
    x = nn.Variable(
        shape=(args.batch_size, data_config.duration, 1))  # (B, T, 1)
    onehot = F.one_hot(x, shape=(data_config.q_bit_len, ))  # (B, T, C)
    wavenet_input = F.transpose(onehot, (0, 2, 1))  # (B, C, T)

    # speaker embedding
    if args.use_speaker_id:
        s_id = nn.Variable(shape=(args.batch_size, 1))
        with nn.parameter_scope("speaker_embedding"):
            s_emb = PF.embed(s_id, n_inputs=_data_source.n_speaker,
                             n_features=WavenetConfig.speaker_dims)
            s_emb = F.transpose(s_emb, (0, 2, 1))
    else:
        s_emb = None

    net = WaveNet()
    wavenet_output = net(wavenet_input, s_emb)

    pred = F.transpose(wavenet_output, (0, 2, 1))

    # (B, T, 1)
    t = nn.Variable(shape=(args.batch_size, data_config.duration, 1))

    loss = F.mean(F.softmax_cross_entropy(pred, t))

    # for generation
    prob = F.softmax(pred)

    # Create Solver.
    solver = S.Adam(args.learning_rate)
    solver.set_parameters(nn.get_parameters())

    # Create monitor.
    monitor = Monitor(args.monitor_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=10)

    # setup save env.
    audio_save_path = os.path.join(os.path.abspath(
        args.model_save_path), "audio_results")
    if audio_save_path and not os.path.exists(audio_save_path):
        os.makedirs(audio_save_path)

    # Training loop.
    for i in range(args.max_iter):
        # todo: validation

        x.d, _speaker, t.d = data_iterator.next()
        if args.use_speaker_id:
            s_id.d = _speaker.reshape(-1, 1)

        solver.zero_grad()
        loss.forward(clear_no_need_grad=True)
        loss.backward(clear_buffer=True)
        solver.update()

        loss.data.cast(np.float32, ctx)
        monitor_loss.add(i, loss.d.copy())

        if i % args.model_save_interval == 0:
            prob.forward()
            audios = mu_law_decode(
                np.argmax(prob.d, axis=-1), quantize=data_config.q_bit_len)  # (B, T)
            save_audio(audios, i, audio_save_path)
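The input preparation at the top of this script (one-hot over the quantized waveform, then F.transpose to channel-first) can be checked in isolation; a minimal sketch with hypothetical sizes, assuming standard nnabla imports:

import numpy as np
import nnabla as nn
import nnabla.functions as F

q_bit_len, duration, batch = 256, 16, 4
x = nn.Variable.from_numpy_array(
    np.random.randint(0, q_bit_len, size=(batch, duration, 1)).astype(np.float32))
onehot = F.one_hot(x, shape=(q_bit_len, ))       # (4, 16, 256), i.e. (B, T, C)
wavenet_input = F.transpose(onehot, (0, 2, 1))   # (4, 256, 16), i.e. (B, C, T)
wavenet_input.forward()
print(wavenet_input.shape)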
Code Example #19
File: nonlocal_net.py Project: sony/nnabla-examples
def nonlocal_net(B_lab_map,
                 relu_layers,
                 temperature=0.001 * 5,
                 detach_flag=False,
                 WTA_scale_weight=1,
                 feature_noise=0):

    batch_size = B_lab_map.shape[0]
    channel = B_lab_map.shape[1]
    image_height = B_lab_map.shape[2]
    image_width = B_lab_map.shape[3]
    feature_height = int(image_height / 4)
    feature_width = int(image_width / 4)

    feature_channel = 64
    in_channels = feature_channel * 4
    inter_channels = 256

    # layer2_1
    A_feature2_1 = layer2_1(relu_layers[0])
    B_feature2_1 = layer2_1(relu_layers[4])
    # layer3_1
    A_feature3_1 = layer3_1(relu_layers[1])
    B_feature3_1 = layer3_1(relu_layers[5])
    # layer4_1
    A_feature4_1 = layer4_1(relu_layers[2])
    B_feature4_1 = layer4_1(relu_layers[6])
    # layer5_1
    A_feature5_1 = layer5_1(relu_layers[3])
    B_feature5_1 = layer5_1(relu_layers[7])

    if A_feature5_1.shape[2] != A_feature2_1.shape[2] or A_feature5_1.shape[3] != A_feature2_1.shape[3]:
        A_feature5_1 = pad_replicate(A_feature5_1)
        B_feature5_1 = pad_replicate(B_feature5_1)
    A_features = layer(
        F.concatenate(
            A_feature2_1,
            A_feature3_1,
            A_feature4_1,
            A_feature5_1,
            axis=1),
        feature_channel * 4)
    B_features = layer(
        F.concatenate(
            B_feature2_1,
            B_feature3_1,
            B_feature4_1,
            B_feature5_1,
            axis=1),
        feature_channel * 4)
    # pairwise cosine similarity
    theta = PF.convolution(
        A_features, inter_channels, kernel=(
            1, 1), stride=(
            1, 1), name='theta')
    theta_re = F.reshape(theta, (batch_size, inter_channels, -1))
    theta_re = theta_re - F.mean(theta_re, axis=2,
                                 keepdims=True)  # center the feature
    theta_norm = F.norm(
        theta_re,
        p=2,
        axis=1,
        keepdims=True) + sys.float_info.epsilon
    theta_re = F.div2(theta_re, theta_norm)
    # 2*(feature_height*feature_width)*256
    theta_permute = F.transpose(theta_re, (0, 2, 1))
    phi = PF.convolution(
        B_features, inter_channels, kernel=(
            1, 1), stride=(
            1, 1), name='phi')
    phi_re = F.reshape(phi, (batch_size, inter_channels, -1))
    # center the feature
    phi_re = phi_re - F.mean(phi_re, axis=2, keepdims=True)
    phi_norm = F.norm(phi_re, p=2, axis=1, keepdims=True) + \
        sys.float_info.epsilon
    phi_re = F.div2(phi_re, phi_norm)
    # 2*(feature_height*feature_width)*(feature_height*feature_width)
    f = F.batch_matmul(theta_permute, phi_re)

    f_shape = f.shape
    f = F.reshape(f, (1,) + f_shape)
    f_similarity = F.reshape(f, (1,) + f_shape)
    similarity_map = F.max(f_similarity, axis=3, keepdims=True)
    similarity_map = F.reshape(
        similarity_map, (batch_size, 1, feature_height, feature_width))

    # f can be negative
    # if WTA_scale_weight == 1:
    f_WTA = f

    f_WTA = f_WTA / temperature

    f_WTA_sp = f_WTA.shape
    f_WTA = F.reshape(f_WTA, (f_WTA_sp[1], f_WTA_sp[2], f_WTA_sp[3]))
    # 2*1936*1936; softmax along the horizontal line (dim=-1)
    f_div_C = F.softmax(f_WTA, axis=2)

    # downsample the reference color
    B_lab = F.average_pooling(B_lab_map, (4, 4))
    B_lab = F.reshape(B_lab, (batch_size, channel, -1))
    B_lab = F.transpose(B_lab, (0, 2, 1))  # 2*1936*channel

    # multiply the corr map with color
    y = F.batch_matmul(f_div_C, B_lab)  # 2*1936*channel
    y = F.transpose(y, (0, 2, 1))
    y = F.reshape(
        y,
        (batch_size,
         channel,
         feature_height,
         feature_width))  # 2*3*44*44
    y = F.interpolate(y, scale=(4, 4), mode='nearest', align_corners=False)
    similarity_map = F.interpolate(
        similarity_map, scale=(
            4, 4), mode='nearest', align_corners=False)

    return y, similarity_map
Code Example #20
File: nerf.py Project: sony/nnabla-examples
def train_nerf(config, comm, model, dataset='blender'):

    use_transient = False
    use_embedding = False

    if model == 'wild':
        use_transient = True
        use_embedding = True
    elif model == 'uncertainty':
        use_transient = True
    elif model == 'appearance':
        use_embedding = True

    save_results_dir = config.log.save_results_dir
    os.makedirs(save_results_dir, exist_ok=True)

    train_loss_dict = {
        'train_coarse_loss': 0.0,
        'train_fine_loss': 0.0,
        'train_total_loss': 0.0,
    }

    test_metric_dict = {'test_loss': 0.0, 'test_psnr': 0.0}

    monitor_manager = MonitorManager(train_loss_dict, test_metric_dict,
                                     save_results_dir)

    if dataset != 'phototourism':
        images, poses, _, hwf, i_test, i_train, near_plane, far_plane = get_data(
            config)
        height, width, focal_length = hwf
    else:
        di = get_photo_tourism_dataiterator(config, 'train', comm)
        val_di = get_photo_tourism_dataiterator(config, 'val', comm)

    if model != 'vanilla':
        if dataset != 'phototourism':
            config.train.n_vocab = max(np.max(i_train), np.max(i_test)) + 1
        print(
            f'Setting Vocabulary size of embedding as {config.train.n_vocab}')

    if dataset != 'phototourism':
        if model in ['vanilla']:
            if comm is not None:
                # uncomment the following line to test on fewer images
                i_test = i_test[3 * comm.rank:3 * (comm.rank + 1)]
                pass
            else:
                # uncomment the following line to test on fewer images
                i_test = i_test[:3]
                pass
        else:
            # i_test = i_train[0:5]
            i_test = [i * (comm.rank + 1) for i in range(5)]
    else:
        i_test = [1]

    encode_position_function = get_encoding_function(
        config.train.num_encodings_position, True, True)
    if config.train.use_view_directions:
        encode_direction_function = get_encoding_function(
            config.train.num_encodings_direction, True, True)
    else:
        encode_direction_function = None

    lr = config.solver.lr
    num_decay_steps = config.solver.lr_decay_step * 1000
    lr_decay_factor = config.solver.lr_decay_factor
    solver = S.Adam(alpha=lr)

    load_solver_state = False
    if config.checkpoint.param_path is not None:
        nn.load_parameters(config.checkpoint.param_path)
        load_solver_state = True

    if comm is not None:
        num_decay_steps /= comm.n_procs
        comm_size = comm.n_procs
    else:
        comm_size = 1
    pbar = trange(config.train.num_iterations // comm_size,
                  disable=(comm is not None and comm.rank > 0))

    for i in pbar:

        if dataset != 'phototourism':

            idx = np.random.choice(i_train)
            image = nn.Variable.from_numpy_array(images[idx][None, :, :, :3])
            pose = nn.Variable.from_numpy_array(poses[idx])

            ray_directions, ray_origins = get_ray_bundle(
                height, width, focal_length, pose)

            grid = get_direction_grid(width,
                                      height,
                                      focal_length,
                                      return_ij_2d_grid=True)
            grid = F.reshape(grid, (-1, 2))

            select_inds = np.random.choice(grid.shape[0],
                                           size=[config.train.num_rand_points],
                                           replace=False)
            select_inds = F.gather_nd(grid, select_inds[None, :])
            select_inds = F.transpose(select_inds, (1, 0))

            embed_inp = nn.Variable.from_numpy_array(
                np.full((config.train.chunksize_fine, ), idx, dtype=int))

            ray_origins = F.gather_nd(ray_origins, select_inds)
            ray_directions = F.gather_nd(ray_directions, select_inds)

            image = F.gather_nd(image[0], select_inds)

        else:
            rays, embed_inp, image = di.next()
            ray_origins = nn.Variable.from_numpy_array(rays[:, :3])
            ray_directions = nn.Variable.from_numpy_array(rays[:, 3:6])
            near_plane = nn.Variable.from_numpy_array(rays[:, 6])
            far_plane = nn.Variable.from_numpy_array(rays[:, 7])

            embed_inp = nn.Variable.from_numpy_array(embed_inp)
            image = nn.Variable.from_numpy_array(image)

            hwf = None

        app_emb, trans_emb = None, None
        if use_embedding:
            with nn.parameter_scope('embedding_a'):
                app_emb = PF.embed(embed_inp, config.train.n_vocab,
                                   config.train.n_app)

        if use_transient:
            with nn.parameter_scope('embedding_t'):
                trans_emb = PF.embed(embed_inp, config.train.n_vocab,
                                     config.train.n_trans)

        if use_transient:
            rgb_map_course, rgb_map_fine, static_rgb_map_fine, transient_rgb_map_fine, beta, static_sigma, transient_sigma = forward_pass(
                ray_directions,
                ray_origins,
                near_plane,
                far_plane,
                app_emb,
                trans_emb,
                encode_position_function,
                encode_direction_function,
                config,
                use_transient,
                hwf=hwf,
                image=image)
            course_loss = 0.5 * F.mean(F.squared_error(rgb_map_course, image))
            fine_loss = 0.5 * F.mean(
                F.squared_error(rgb_map_fine, image) /
                F.reshape(F.pow_scalar(beta, 2), beta.shape + (1, )))
            beta_reg_loss = 3 + F.mean(F.log(beta))
            sigma_trans_reg_loss = 0.01 * F.mean(transient_sigma)
            loss = course_loss + fine_loss + beta_reg_loss + sigma_trans_reg_loss
        else:
            rgb_map_course, _, _, _, rgb_map_fine, _, _, _ = forward_pass(
                ray_directions,
                ray_origins,
                near_plane,
                far_plane,
                app_emb,
                trans_emb,
                encode_position_function,
                encode_direction_function,
                config,
                use_transient,
                hwf=hwf)
            course_loss = F.mean(F.squared_error(rgb_map_course, image))
            fine_loss = F.mean(F.squared_error(rgb_map_fine, image))
            loss = course_loss + fine_loss

        pbar.set_description(
            f'Total: {np.around(loss.d, 4)}, Coarse: {np.around(course_loss.d, 4)}, Fine: {np.around(fine_loss.d, 4)}'
        )

        solver.set_parameters(nn.get_parameters(),
                              reset=False,
                              retain_state=True)
        if load_solver_state:
            solver.load_states(config['checkpoint']['solver_path'])
            load_solver_state = False

        solver.zero_grad()

        loss.backward(clear_buffer=True)

        # Exponential LR decay
        if dataset != 'phototourism':
            lr_factor = (lr_decay_factor**((i) / num_decay_steps))
            solver.set_learning_rate(lr * lr_factor)
        else:
            if i % num_decay_steps == 0 and i != 0:
                solver.set_learning_rate(lr * lr_decay_factor)

        if comm is not None:
            params = [x.grad for x in nn.get_parameters().values()]
            comm.all_reduce(params, division=False, inplace=True)
        solver.update()

        if ((i % config.train.save_interval == 0
             or i == config.train.num_iterations - 1)
                and i != 0) and (comm is None or comm.rank == 0):
            nn.save_parameters(os.path.join(save_results_dir, f'iter_{i}.h5'))
            solver.save_states(
                os.path.join(save_results_dir, f'solver_iter_{i}.h5'))

        if (i % config.train.test_interval == 0
                or i == config.train.num_iterations - 1) and i != 0:
            avg_psnr, avg_mse = 0.0, 0.0
            for i_t in trange(len(i_test)):

                if dataset != 'phototourism':
                    idx_test = i_test[i_t]
                    image = nn.NdArray.from_numpy_array(
                        images[idx_test][None, :, :, :3])
                    pose = nn.NdArray.from_numpy_array(poses[idx_test])

                    ray_directions, ray_origins = get_ray_bundle(
                        height, width, focal_length, pose)

                    ray_directions = F.reshape(ray_directions, (-1, 3))
                    ray_origins = F.reshape(ray_origins, (-1, 3))

                    embed_inp = nn.NdArray.from_numpy_array(
                        np.full((config.train.chunksize_fine, ),
                                idx_test,
                                dtype=int))

                else:
                    rays, embed_inp, image = val_di.next()
                    ray_origins = nn.NdArray.from_numpy_array(rays[0, :, :3])
                    ray_directions = nn.NdArray.from_numpy_array(rays[0, :,
                                                                      3:6])
                    near_plane_ = nn.NdArray.from_numpy_array(rays[0, :, 6])
                    far_plane_ = nn.NdArray.from_numpy_array(rays[0, :, 7])

                    embed_inp = nn.NdArray.from_numpy_array(
                        embed_inp[0, :config.train.chunksize_fine])
                    image = nn.NdArray.from_numpy_array(image[0].transpose(
                        1, 2, 0))
                    image = F.reshape(image, (1, ) + image.shape)
                    idx_test = 1

                app_emb, trans_emb = None, None
                if use_embedding:
                    with nn.parameter_scope('embedding_a'):
                        app_emb = PF.embed(embed_inp, config.train.n_vocab,
                                           config.train.n_app)

                if use_transient:
                    with nn.parameter_scope('embedding_t'):
                        trans_emb = PF.embed(embed_inp, config.train.n_vocab,
                                             config.train.n_trans)

                num_ray_batches = ray_directions.shape[
                    0] // config.train.ray_batch_size + 1

                if use_transient:
                    rgb_map_fine_list, static_rgb_map_fine_list, transient_rgb_map_fine_list = [], [], []
                else:
                    rgb_map_fine_list, depth_map_fine_list = [], []

                for r_idx in trange(num_ray_batches):
                    if r_idx != num_ray_batches - 1:
                        ray_d, ray_o = ray_directions[
                            r_idx * config.train.ray_batch_size:(r_idx + 1) *
                            config.train.ray_batch_size], ray_origins[
                                r_idx *
                                config.train.ray_batch_size:(r_idx + 1) *
                                config.train.ray_batch_size]

                        if dataset == 'phototourism':
                            near_plane = near_plane_[
                                r_idx *
                                config.train.ray_batch_size:(r_idx + 1) *
                                config.train.ray_batch_size]
                            far_plane = far_plane_[r_idx *
                                                   config.train.ray_batch_size:
                                                   (r_idx + 1) *
                                                   config.train.ray_batch_size]

                    else:
                        if ray_directions.shape[0] - (
                                num_ray_batches -
                                1) * config.train.ray_batch_size == 0:
                            break
                        ray_d, ray_o = ray_directions[
                            r_idx *
                            config.train.ray_batch_size:, :], ray_origins[
                                r_idx * config.train.ray_batch_size:, :]
                        if dataset == 'phototourism':
                            near_plane = near_plane_[r_idx * config.train.
                                                     ray_batch_size:]
                            far_plane = far_plane_[r_idx * config.train.
                                                   ray_batch_size:]

                    if use_transient:
                        rgb_map_course, rgb_map_fine, static_rgb_map_fine, transient_rgb_map_fine, beta, static_sigma, transient_sigma = forward_pass(
                            ray_d,
                            ray_o,
                            near_plane,
                            far_plane,
                            app_emb,
                            trans_emb,
                            encode_position_function,
                            encode_direction_function,
                            config,
                            use_transient,
                            hwf=hwf)

                        rgb_map_fine_list.append(rgb_map_fine)
                        static_rgb_map_fine_list.append(static_rgb_map_fine)
                        transient_rgb_map_fine_list.append(
                            transient_rgb_map_fine)

                    else:
                        _, _, _, _, rgb_map_fine, depth_map_fine, _, _ = \
                            forward_pass(ray_d, ray_o, near_plane, far_plane, app_emb, trans_emb,
                                         encode_position_function, encode_direction_function, config, use_transient, hwf=hwf)

                        rgb_map_fine_list.append(rgb_map_fine)
                        depth_map_fine_list.append(depth_map_fine)

                if use_transient:
                    rgb_map_fine = F.concatenate(*rgb_map_fine_list, axis=0)
                    static_rgb_map_fine = F.concatenate(
                        *static_rgb_map_fine_list, axis=0)
                    transient_rgb_map_fine = F.concatenate(
                        *transient_rgb_map_fine_list, axis=0)

                    rgb_map_fine = F.reshape(rgb_map_fine, image[0].shape)
                    static_rgb_map_fine = F.reshape(static_rgb_map_fine,
                                                    image[0].shape)
                    transient_rgb_map_fine = F.reshape(transient_rgb_map_fine,
                                                       image[0].shape)
                    static_trans_img_to_save = np.concatenate(
                        (static_rgb_map_fine.data,
                         np.ones((image[0].shape[0], 5, 3)),
                         transient_rgb_map_fine.data),
                        axis=1)
                    img_to_save = np.concatenate(
                        (image[0].data, np.ones(
                            (image[0].shape[0], 5, 3)), rgb_map_fine.data),
                        axis=1)
                else:

                    rgb_map_fine = F.concatenate(*rgb_map_fine_list, axis=0)
                    depth_map_fine = F.concatenate(*depth_map_fine_list,
                                                   axis=0)

                    rgb_map_fine = F.reshape(rgb_map_fine, image[0].shape)
                    depth_map_fine = F.reshape(depth_map_fine,
                                               image[0].shape[:-1])
                    img_to_save = np.concatenate(
                        (image[0].data, np.ones(
                            (image[0].shape[0], 5, 3)), rgb_map_fine.data),
                        axis=1)

                filename = os.path.join(save_results_dir,
                                        f'{i}_{idx_test}.png')
                try:
                    imsave(filename,
                           np.clip(img_to_save, 0, 1),
                           channel_first=False)
                    print(f'Saved generation at {filename}')
                    if use_transient:
                        filename_static_trans = os.path.join(
                            save_results_dir, f's_t_{i}_{idx_test}.png')
                        imsave(filename_static_trans,
                               np.clip(static_trans_img_to_save, 0, 1),
                               channel_first=False)

                    else:
                        filename_dm = os.path.join(save_results_dir,
                                                   f'dm_{i}_{idx_test}.png')
                        depth_map_fine = (depth_map_fine.data -
                                          depth_map_fine.data.min()) / (
                                              depth_map_fine.data.max() -
                                              depth_map_fine.data.min())
                        imsave(filename_dm,
                               depth_map_fine[:, :, None],
                               channel_first=False)
                        plt.imshow(depth_map_fine.data)
                        plt.savefig(filename_dm)
                        plt.close()
                except Exception:
                    # saving the preview images is best-effort; skip on failure
                    pass

                avg_mse += F.mean(F.squared_error(rgb_map_fine, image[0])).data
                avg_psnr += (-10. * np.log10(
                    F.mean(F.squared_error(rgb_map_fine, image[0])).data))

            test_metric_dict['test_loss'] = avg_mse / len(i_test)
            test_metric_dict['test_psnr'] = avg_psnr / len(i_test)
            monitor_manager.add(i, test_metric_dict)
            print(
                f'Saved generations after {i} training iterations! Average PSNR: {avg_psnr/len(i_test)}. Average MSE: {avg_mse/len(i_test)}'
            )
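
As a note on the metric logged above: PSNR is derived from the per-image MSE as
-10 * log10(MSE) for images scaled to [0, 1]. A minimal NumPy sketch of that
relation (the helper name is ours, not part of the training script):

import numpy as np

def psnr_from_mse(mse, max_val=1.0):
    # PSNR = 10 * log10(max_val**2 / MSE); with max_val = 1 this is -10 * log10(MSE)
    return 10.0 * np.log10(max_val ** 2 / mse)

# e.g. an MSE of 1e-3 on [0, 1] images corresponds to 30 dB
assert np.isclose(psnr_from_mse(1e-3), 30.0)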
Code example #21
File: train.py Project: satopirka/nlp-nnabla
def loss_function(u, v, negative_samples):
    return F.sum(-F.log(
        F.exp(-distance(u, v)) / sum([
            F.exp(-distance(u, x)) for x in F.split(negative_samples, axis=2)
        ])))
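
# Note: the `distance` helper used above is not shown in this snippet. It is
# presumably the Poincare-ball distance,
#   d(u, v) = arcosh(1 + 2 * ||u - v||^2 / ((1 - ||u||^2) * (1 - ||v||^2))).
# A minimal sketch of such a helper (names and numerics are assumptions, not
# necessarily the implementation in satopirka/nlp-nnabla):
def distance(u, v, eps=1e-5):
    uu = F.sum(u * u, axis=1)                    # ||u||^2, shape (batch,)
    vv = F.sum(v * v, axis=1)                    # ||v||^2
    uv = F.sum(F.squared_error(u, v), axis=1)    # ||u - v||^2
    gamma = 1.0 + 2.0 * uv / ((1.0 - uu) * (1.0 - vv) + eps)
    # arcosh(x) = log(x + sqrt(x^2 - 1))
    return F.log(gamma + (gamma * gamma - 1.0) ** 0.5)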


u = nn.Variable((batch_size, ))
v = nn.Variable((batch_size, ))
negative_samples = nn.Variable((batch_size, negative_sample_size))

_u = PF.embed(u, vocab_size, embedding_size)
_v = PF.embed(v, vocab_size, embedding_size)
_neg = PF.embed(negative_samples, vocab_size, embedding_size)
_neg = F.transpose(_neg, axes=(0, 2, 1))

loss = loss_function(_u, _v, _neg)

nn.get_parameters()["embed/W"].d = I.UniformInitializer(
    [-0.01, 0.01])(shape=(vocab_size, embedding_size))

solver = RiemannianSgd(lr=0.1)
solver.set_parameters(nn.get_parameters())

trainer = Trainer(inputs=[u, v, negative_samples], loss=loss, solver=solver)
trainer.run(train_data_iter, None, epochs=max_epoch)

line_points = [['mustang.n.01', 'odd-toed_ungulate.n.01'],
               ['elk.n.01', 'even-toed_ungulate.n.01'],
               ['even-toed_ungulate.n.01', 'ungulate.n.01'],
Code example #22
File: model.py Project: sony/nnabla-examples
    def call(self, memory, decoder_inputs=None):
        r"""Return mel-spectrograms, gate outputs and an attention matrix.

        Args:
            memory (nn.Variable): A 3D tensor of shape (B, T, C).
            decoder_inputs (nn.Variable, optional): A 3D tensor with shape of (B, T/r, r*n_mels).
                Shifted log melspectrogram of sound files. Defaults to None.

        Returns:
            nn.Variable: The synthetic mel-spectrograms of shape (B, Ty/r, r*n_mels).
            nn.Variable: The gate outputs of shape (B, Ty).
            nn.Variable: The attention matrix of shape (B, Tx, Ty).
        """
        hp = self._hparams
        mel_shape = hp.n_mels * hp.r

        # initialize decoder states
        decoder_input = F.constant(shape=(hp.batch_size, 1, mel_shape))
        decoder_hidden = F.constant(shape=(1, 1, hp.batch_size,
                                           hp.decoder_rnn_dim))
        decoder_cell = F.constant(shape=(1, 1, hp.batch_size,
                                         hp.decoder_rnn_dim))

        # initialize attention states
        attention_weights = F.constant(shape=(hp.batch_size, 1, hp.text_len))
        attention_weights_cum = F.constant(shape=(hp.batch_size, 1,
                                                  hp.text_len))
        attention_context = F.constant(shape=(hp.batch_size, 1,
                                              hp.encoder_embedding_dim))
        attention_hidden = F.constant(shape=(1, 1, hp.batch_size,
                                             hp.attention_rnn_dim))
        attention_cell = F.constant(shape=(1, 1, hp.batch_size,
                                           hp.attention_rnn_dim))

        # store outputs
        mel_outputs, gate_outputs, alignments = [], [], []

        for i in range(hp.mel_len):
            if i > 0:
                decoder_input = (mel_outputs[-1] if decoder_inputs is None else
                                 decoder_inputs[:, i - 1:i, :])
                if decoder_inputs is None:
                    decoder_input = decoder_input[None, ...]
            # decoder of shape (B, 1, prenet_channels=256)
            decoder_input = prenet(decoder_input,
                                   hp.prenet_channels,
                                   is_training=self.training,
                                   scope='prenet')

            with nn.parameter_scope('attention_rnn'):
                # cell_input of shape (B, 1, prenet_channels[-1] + C=768)
                cell_input = F.concatenate(decoder_input,
                                           attention_context,
                                           axis=2)
                _, attention_hidden, attention_cell = PF.lstm(
                    F.transpose(cell_input, (1, 0, 2)),
                    attention_hidden,
                    attention_cell,
                    training=self.training,
                    name='lstm_attention'
                )  # (1, 1, B, attention_hidden), (1, 1, B, attention_hidden)
                if self.training:
                    attention_hidden = F.dropout(attention_hidden,
                                                 hp.p_attention_dropout)

            with nn.parameter_scope('location_attention'):
                attention_weights_cat = F.concatenate(attention_weights,
                                                      attention_weights_cum,
                                                      axis=1)
                attention_context, attention_weights = location_sensitive_attention(
                    F.transpose(attention_hidden[0], (1, 0, 2)),
                    memory,
                    attention_weights_cat,
                    attention_location_kernel_size=hp.
                    attention_location_kernel_size,
                    attention_n_filters=hp.attention_location_n_filters,
                    attention_dim=hp.attention_dim,
                    is_training=self.training,
                    scope='ls_attention')
                attention_weights_cum += attention_weights
                alignments.append(attention_weights)

            with nn.parameter_scope('decoder_rnn'):
                # (1, B, attention_rnn_dim + encoder_embedding_dim)
                inp_decoder = F.concatenate(attention_hidden[0],
                                            F.transpose(
                                                attention_context, (1, 0, 2)),
                                            axis=2)
                _, decoder_hidden, decoder_cell = PF.lstm(
                    inp_decoder,
                    decoder_hidden,
                    decoder_cell,
                    training=self.training,
                    name='lstm_decoder')
                if self.training:
                    decoder_hidden = F.dropout(decoder_hidden,
                                               hp.p_decoder_dropout)

            with nn.parameter_scope('projection'):
                proj_input = F.concatenate(
                    decoder_hidden[0, 0],
                    F.reshape(attention_context, (hp.batch_size, -1),
                              inplace=False),
                    axis=1)  # (B, decoder_rnn_dim + encoder_embedding_dim)
                decoder_output = affine_norm(proj_input,
                                             mel_shape,
                                             base_axis=1,
                                             with_bias=True,
                                             w_init_gain='affine',
                                             scope='affine')
                mel_outputs.append(decoder_output)

            with nn.parameter_scope('gate_prediction'):
                gate_prediction = affine_norm(proj_input,
                                              1,
                                              base_axis=1,
                                              with_bias=True,
                                              w_init_gain='sigmoid',
                                              scope='affine')
                gate_outputs.append(gate_prediction)

        # (B, T2, n_mels*r)
        mel_outputs = F.stack(*mel_outputs, axis=1)
        gate_outputs = F.concatenate(*gate_outputs, axis=1)  # (B, T2)
        alignments = F.concatenate(*alignments, axis=1)  # (B, T1, T2)

        return mel_outputs, gate_outputs, alignments
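
The decoder above relies on a prenet helper that is not shown in this example.
In Tacotron 2 the prenet is two fully connected layers with ReLU and dropout
that is kept active even at inference; a rough sketch under that assumption
(layer names and the dropout rate are ours, not necessarily the repository's):

import nnabla as nn
import nnabla.functions as F
import nnabla.parametric_functions as PF

def prenet(x, channels, is_training=True, scope='prenet'):
    # x: (B, 1, C); two affine -> ReLU -> dropout blocks
    with nn.parameter_scope(scope):
        for i in range(2):
            x = F.relu(PF.affine(x, channels, base_axis=2, name=f'affine_{i}'))
            # dropout intentionally not gated on is_training (Tacotron 2 keeps it on)
            x = F.dropout(x, p=0.5)
    return x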
Code example #23
valid_data_iter = data_iterator_simple(load_valid_func,
                                       len(x_valid),
                                       batch_size,
                                       shuffle=True,
                                       with_file_cache=False)

char_embedding_dim = 16
lstm_size = 650
filters = [50, 150, 200, 200]
filter_sizes = [1, 3, 5, 7]
# filters = [50, 100, 150, 200, 200, 200, 200]
# filter_sizes = [1, 2, 3, 4, 5, 6, 7]

x = nn.Variable((batch_size, sentence_length, word_length))
h = PF.embed(x, char_vocab_size, char_embedding_dim)
h = F.transpose(h, (0, 3, 1, 2))
output = []
for f, f_size in zip(filters, filter_sizes):
    _h = PF.convolution(h,
                        f,
                        kernel=(1, f_size),
                        pad=(0, f_size // 2),
                        name='conv_{}'.format(f_size))
    _h = F.max_pooling(_h, kernel=(1, word_length))
    output.append(_h)
h = F.concatenate(*output, axis=1)
h = F.transpose(h, (0, 2, 1, 3))
h = F.reshape(h, (batch_size, sentence_length, sum(filters)))
# h = PF.batch_normalization(h, axes=[2])
h = TimeDistributed(Highway)(h, name='highway1')
h = TimeDistributed(Highway)(h, name='highway2')
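
TimeDistributed and Highway above come from elsewhere in the repository;
TimeDistributed presumably flattens the time axis, applies the wrapped layer,
and restores the shape. A highway layer mixes a transformed signal with the
identity through a learned gate, y = t * relu(W_H x + b_H) + (1 - t) * x.
A minimal sketch of such a layer (parameter names are ours):

import nnabla as nn
import nnabla.functions as F
import nnabla.parametric_functions as PF

def Highway(x, name='highway'):
    # x: (B, D); a sigmoid gate t decides how much of the transformed signal to pass
    dim = x.shape[-1]
    with nn.parameter_scope(name):
        h = F.relu(PF.affine(x, dim, name='plain'))
        t = F.sigmoid(PF.affine(x, dim, name='gate'))
    return t * h + (1.0 - t) * x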
Code example #24
    def call(self, wave):
        hp = self.hparams
        batch_size = hp.batch_size

        # compute mel-spectrogram from waveform
        with nn.parameter_scope('stft'):
            mels = self.compute_mel(wave)

        #  Upsample spectrogram to the size of audio
        with nn.parameter_scope('upsample'):
            with nn.parameter_scope('deconv'):
                mels = PF.deconvolution(mels,
                                        hp.n_mels,
                                        kernel=(1024, ),
                                        stride=(256, ))

            # make sure mels have the same length as wave
            if mels.shape[2] > wave.shape[1]:
                mels = mels[..., :wave.shape[1]]  # (B, n_mels, L)

            # transforming to correct shape
            mels = F.reshape(mels,
                             mels.shape[:2] + (-1, hp.n_samples_per_group))
            mels = F.transpose(mels, (0, 2, 1, 3))
            mels = F.reshape(mels, mels.shape[:2] + (-1, ))
            # (B, n_mels * n_groups, L/n_groups)
            mels = F.transpose(mels, (0, 2, 1))

        # reshape audio
        wave = F.reshape(wave, (batch_size, -1, hp.n_samples_per_group))
        wave = F.transpose(wave, (0, 2, 1))  # (B, n_groups, L/n_groups)

        output_audio, log_s_list, log_det_W_list = [], [], []

        for k in range(hp.n_flows):
            if k % hp.n_early_every == 0 and k > 0:
                output_audio.append(wave[:, :hp.n_early_size, :])
                wave = wave[:, hp.n_early_size:, :]

            # apply invertible convolution
            wave, log_det_W = invertible_conv(wave,
                                              reverse=False,
                                              rng=self.rng,
                                              scope=f'inv_{k}')
            log_det_W_list.append(log_det_W)

            n_half = wave.shape[1] // 2
            audio_0 = wave[:, :n_half, :]
            audio_1 = wave[:, n_half:, :]

            with nn.parameter_scope(f'wn_{k}'):
                output = getattr(self, f'WN_{k}')(audio_0, mels)
                log_s = output[:, n_half:, :]  # (B, n_half, L/n_groups)
                b = output[:, :n_half, :]  # (B, n_half, L/n_groups)
                audio_1 = F.add2(F.exp(log_s) * audio_1, b, inplace=True)
                log_s_list.append(log_s)

            # (B, n_half*2, L/n_groups)
            wave = F.concatenate(audio_0, audio_1, axis=1)

        output_audio.append(wave)

        return F.concatenate(*output_audio, axis=1), log_s_list, log_det_W_list
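
For reference, the affine coupling applied above (audio_1 = exp(log_s) * audio_1 + b)
is inverted at generation time as (audio_1 - b) * exp(-log_s); a minimal sketch of
that inverse step (illustrative only, not the repository's inference code):

import nnabla.functions as F

def invert_affine_coupling(audio_1, log_s, b):
    # inverse of: audio_1 = exp(log_s) * audio_1 + b
    return (audio_1 - b) * F.exp(-log_s)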
Code example #25
def sample_from_controller(args):
    """
        2-layer RNN (LSTM)-based controller that outputs a CNN architecture
        represented as a list of integer sequences.
        For each node it performs two kinds of computation:
        one samples the indices of the inputs to that node,
        the other samples the operations applied to them.
    """

    entropys = nn.Variable([1, 1], need_grad=True)
    log_probs = nn.Variable([1, 1], need_grad=True)

    entropys.d = log_probs.d = 0.0  # initialize them all

    num_cells = args.num_cells
    num_nodes = args.num_nodes
    lstm_size = args.lstm_size
    state_size = args.state_size
    lstm_num_layers = args.lstm_layers
    temperature = args.temperature
    tanh_constant = args.tanh_constant
    op_tanh_reduce = args.op_tanh_reduce
    num_branch = args.num_ops

    both_archs = [list(), list()]
    initializer = I.UniformInitializer((-0.1, 0.1))

    prev_h = [
        nn.Variable([1, lstm_size], need_grad=True)
        for _ in range(lstm_num_layers)
    ]
    prev_c = [
        nn.Variable([1, lstm_size], need_grad=True)
        for _ in range(lstm_num_layers)
    ]

    for i in range(len(prev_h)):
        prev_h[i].d = 0  # initialize.
        prev_c[i].d = 0

    inputs = nn.Variable([1, lstm_size])
    inputs.d = np.random.normal(0, 0.5, [1, lstm_size])

    g_emb = nn.Variable([1, lstm_size])
    g_emb.d = np.random.normal(0, 0.5, [1, lstm_size])

    for ind in range(2):
        # first create the conv cell and then the reduction cell.
        idx_seq = list()
        ops_seq = list()
        for node_id in range(num_nodes):
            if node_id == 0:
                anchors = nn.parameter.get_parameter_or_create("anchors",
                                                               [2, lstm_size],
                                                               initializer,
                                                               need_grad=False)
                anchors_w_1 = nn.parameter.get_parameter_or_create(
                    "anchors_w_1", [2, lstm_size],
                    initializer,
                    need_grad=False)
            else:
                assert anchors.shape[0] == node_id + \
                    2, "Something wrong with anchors."
                assert anchors_w_1.shape[0] == node_id + \
                    2, "Something wrong with anchors_w_1."

            # for each node, get the index used as inputs
            for i in range(2):
                # One-step stacked LSTM.
                with nn.parameter_scope("controller_lstm"):
                    next_h, next_c = stack_lstm(inputs, prev_h, prev_c,
                                                state_size)
                prev_h, prev_c = next_h, next_c  # shape:(1, lstm_size)
                query = anchors_w_1

                with nn.parameter_scope("skip_affine_1"):
                    query = F.tanh(
                        F.add2(
                            query,
                            PF.affine(next_h[-1],
                                      lstm_size,
                                      w_init=initializer,
                                      with_bias=False)))
                    #            (node_id + 2, lstm_size)   +   (1, lstm_size)
                    # broadcast occurs here. resulting shape is; (node_id + 2, lstm_size)

                with nn.parameter_scope("skip_affine_2"):
                    # (node_id + 2, 1)
                    logit = PF.affine(query,
                                      1,
                                      w_init=initializer,
                                      with_bias=False)

                if temperature is not None:
                    logit = F.mul_scalar(logit, (1 / temperature))
                if tanh_constant is not None:
                    logit = F.mul_scalar(F.tanh(logit), tanh_constant)

                index = F.exp(logit)
                index = F.mul_scalar(index, (1 / index.d.sum()))

                # Sampling input indices from multinomial distribution.
                index = np.random.multinomial(
                    1,
                    np.reshape(index.d, (1, index.d.size))[0], 1)
                idx_seq.append(index.nonzero()[1])

                label = nn.Variable.from_numpy_array(
                    index.transpose())  # (node_id + 2, 1)
                log_prob = F.softmax_cross_entropy(logit, label)
                log_probs = F.add2(log_probs, F.sum(log_prob, keepdims=True))

                curr_ent = F.softmax_cross_entropy(logit, F.softmax(logit))
                entropy = F.sum(curr_ent, keepdims=True)
                entropys = F.add2(entropys, entropy)
                taking_ind = int(index.nonzero()[1][0])

                # (1, lstm_size)
                inputs = F.reshape(anchors[taking_ind], (1, anchors.shape[1]))

            # ops
            for j in range(2):
                with nn.parameter_scope("controller_lstm"):
                    next_h, next_c = stack_lstm(inputs, prev_h, prev_c,
                                                state_size)
                prev_h, prev_c = next_h, next_c  # shape:(1, lstm_size)

                # Compute for operation.
                with nn.parameter_scope("ops"):
                    logit = PF.affine(next_h[-1],
                                      num_branch,
                                      w_init=initializer,
                                      with_bias=False)

                # shape of logit : (1, num_branch)
                if temperature is not None:
                    logit = F.mul_scalar(logit, (1 / temperature))

                if tanh_constant is not None:
                    op_tanh = tanh_constant / op_tanh_reduce
                    logit = F.mul_scalar(F.tanh(logit), op_tanh)

                # normalizing logits.
                normed_logit = np.e**logit.d
                normed_logit = normed_logit / np.sum(normed_logit)

                # Sampling operation id from multinomial distribution.
                branch_id = np.random.multinomial(1, normed_logit[0],
                                                  1).nonzero()[1]
                branch_id = nn.Variable.from_numpy_array(branch_id)
                ops_seq.append(branch_id.d)

                # log policy for operation.
                log_prob = F.softmax_cross_entropy(
                    logit, F.reshape(branch_id, shape=(1, 1)))
                # accumulate log policy as log probs
                log_probs = F.add2(log_probs, log_prob)

                logit = F.transpose(logit, axes=(1, 0))
                curr_ent = F.softmax_cross_entropy(logit, F.softmax(logit))
                entropy = F.sum(curr_ent, keepdims=True)
                entropys = F.add2(entropys, entropy)

                w_emb = nn.parameter.get_parameter_or_create(
                    "w_emb", [num_branch, lstm_size],
                    initializer,
                    need_grad=False)
                # (1, lstm_size)
                inputs = F.reshape(w_emb[int(branch_id.d)],
                                   (1, w_emb.shape[1]))

                with nn.parameter_scope("controller_lstm"):
                    next_h, next_c = stack_lstm(inputs, prev_h, prev_c,
                                                lstm_size)
                prev_h, prev_c = next_h, next_c

                with nn.parameter_scope("skip_affine_3"):
                    adding_w_1 = PF.affine(next_h[-1],
                                           lstm_size,
                                           w_init=initializer,
                                           with_bias=False)

            # (node_id + 2 + 1, lstm_size)
            anchors = F.concatenate(anchors, next_h[-1], axis=0)
            # (node_id + 2 + 1, lstm_size)
            anchors_w_1 = F.concatenate(anchors_w_1, adding_w_1, axis=0)

        for idx, ops in zip(idx_seq, ops_seq):
            both_archs[ind].extend([int(idx), int(ops)])

    return both_archs, log_probs, entropys
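
The controller above repeatedly calls stack_lstm, which is defined elsewhere.
Presumably it advances a stack of LSTM layers by one step and returns the new
hidden and cell states per layer; a rough sketch built from plain affine gates
(function and scope names are ours, not the repository's):

import nnabla as nn
import nnabla.functions as F
import nnabla.parametric_functions as PF

def _lstm_step(x, h, c, state_size, scope):
    # one LSTM step computed from a single affine over the concatenation [x, h]
    with nn.parameter_scope(scope):
        gates = PF.affine(F.concatenate(x, h, axis=1), 4 * state_size)
    i = F.sigmoid(gates[:, 0 * state_size:1 * state_size])
    f = F.sigmoid(gates[:, 1 * state_size:2 * state_size])
    g = F.tanh(gates[:, 2 * state_size:3 * state_size])
    o = F.sigmoid(gates[:, 3 * state_size:4 * state_size])
    c_t = f * c + i * g
    h_t = o * F.tanh(c_t)
    return h_t, c_t

def stack_lstm(x, prev_h, prev_c, state_size):
    # advance every layer by one step; layer l consumes the output of layer l - 1
    next_h, next_c = [], []
    for layer_id, (h, c) in enumerate(zip(prev_h, prev_c)):
        inp = x if layer_id == 0 else next_h[-1]
        h_t, c_t = _lstm_step(inp, h, c, state_size, scope=f'layer_{layer_id}')
        next_h.append(h_t)
        next_c.append(c_t)
    return next_h, next_c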
Code example #26
File: model.py Project: sony/nnabla-examples
def multi_head_attention(query,
                         key,
                         value,
                         d_model,
                         num_heads,
                         need_weights=False,
                         attn_mask=None):
    in_proj_weight = nn.parameter.get_parameter_or_create(
        name="attn/in_proj_W", shape=(d_model * 3, d_model))
    in_proj_bias = nn.parameter.get_parameter_or_create(name="attn/in_proj_b",
                                                        shape=(d_model * 3, ))
    out_proj_weight = nn.parameter.get_parameter_or_create(
        name="attn/out_proj/W", shape=(d_model, d_model))
    out_proj_bias = nn.parameter.get_parameter_or_create(
        name="attn/out_proj/b", shape=(d_model, ))

    tgt_len, batch_size, embed_dim = query.shape
    src_len, _, _ = key.shape

    head_dim = d_model // num_heads
    assert head_dim * num_heads == embed_dim, 'embed_dim must be divisible by num_heads'

    if attn_mask is not None:
        if attn_mask.ndim == 2:
            correct_2d_size = (tgt_len, src_len)
            if attn_mask.shape != correct_2d_size:
                raise RuntimeError(
                    f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}."
                )
            attn_mask = attn_mask.reshape((1, tgt_len, src_len))
        elif attn_mask.ndim == 3:
            correct_3d_size = (batch_size * num_heads, tgt_len, src_len)
            if attn_mask.shape != correct_3d_size:
                raise RuntimeError(
                    f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}."
                )
        else:
            raise RuntimeError(
                f"attn_mask's dimension {attn_mask.ndim} is not supported")

    q, k, v = _in_projection_packed(query, key, value, in_proj_weight,
                                    in_proj_bias)

    q = F.transpose(F.reshape(q, (tgt_len, batch_size * num_heads, head_dim)),
                    (1, 0, 2))  # q:(B*H, L_T, head_dim)
    k = F.transpose(F.reshape(k, (-1, batch_size * num_heads, head_dim)),
                    (1, 0, 2))  # k:(B*H, L_S, head_dim)
    v = F.transpose(F.reshape(v, (-1, batch_size * num_heads, head_dim)),
                    (1, 0, 2))  # v:(B*H, L_S, head_vdim)

    dropout_p = 0.0

    attn_output, attn_output_weights = _scaled_dot_product_attention(
        q, k, v, attn_mask, dropout_p)
    attn_output = F.reshape(
        F.transpose(attn_output, (1, 0, 2)),
        (tgt_len, batch_size, embed_dim))  # attn_output: (L_T, B, E_v)

    out_proj_weight = F.transpose(out_proj_weight, (1, 0))
    attn_output = F.affine(attn_output,
                           out_proj_weight,
                           out_proj_bias,
                           base_axis=2)

    return attn_output
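
_in_projection_packed and _scaled_dot_product_attention are helpers defined
elsewhere in the file. The latter is presumably standard scaled dot-product
attention; a minimal sketch under that assumption (the signature and the mask
handling are guesses, not the repository's code):

import numpy as np
import nnabla.functions as F

def _scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0):
    # q: (B*H, L_T, head_dim), k, v: (B*H, L_S, head_dim)
    scores = F.batch_matmul(q, k, transpose_b=True) * (1.0 / np.sqrt(q.shape[-1]))
    if attn_mask is not None:
        # additive mask assumed to be shaped (B*H, L_T, L_S)
        scores = scores + attn_mask
    weights = F.softmax(scores, axis=2)      # (B*H, L_T, L_S)
    if dropout_p > 0.0:
        weights = F.dropout(weights, p=dropout_p)
    return F.batch_matmul(weights, v), weights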
Code example #27
def predict_dense_motion(source_image,
                         kp_driving,
                         kp_source,
                         block_expansion,
                         num_blocks,
                         max_features,
                         num_kp,
                         num_channels,
                         estimate_occlusion_map=False,
                         scale_factor=1,
                         kp_variance=0.01,
                         test=False,
                         comm=None):
    if scale_factor != 1:
        source_image = anti_alias_interpolate(source_image, num_channels,
                                              scale_factor)

    bs, _, h, w = source_image.shape

    out_dict = dict()
    heatmap_representation = create_heatmap_representations(
        source_image, kp_driving, kp_source, kp_variance)
    sparse_motion = create_sparse_motions(source_image, kp_driving, kp_source,
                                          num_kp)
    deformed_source = create_deformed_source_image(source_image, sparse_motion,
                                                   num_kp)
    out_dict['sparse_deformed'] = deformed_source

    input = F.concatenate(heatmap_representation, deformed_source, axis=2)
    input = F.reshape(input, (bs, -1, h, w))

    with nn.parameter_scope("hourglass"):
        prediction = hourglass(input,
                               block_expansion=block_expansion,
                               num_blocks=num_blocks,
                               max_features=max_features,
                               test=test,
                               comm=comm)

    with nn.parameter_scope("mask"):
        inmaps, outmaps = prediction.shape[1], num_kp + 1
        k_w = I.calc_normal_std_he_forward(inmaps, outmaps,
                                           kernel=(7, 7)) / np.sqrt(2.)
        k_b = I.calc_normal_std_he_forward(inmaps, outmaps) / np.sqrt(2.)
        w_init = I.UniformInitializer((-k_w, k_w))
        b_init = I.UniformInitializer((-k_b, k_b))
        mask = PF.convolution(prediction,
                              outmaps=num_kp + 1,
                              kernel=(7, 7),
                              pad=(3, 3),
                              w_init=w_init,
                              b_init=b_init)

    mask = F.softmax(mask, axis=1)
    out_dict['mask'] = mask
    reshaped_mask = F.reshape(mask,
                              mask.shape[:2] + (1, ) + mask.shape[2:],
                              inplace=False)
    sparse_motion = F.transpose(sparse_motion, (0, 1, 4, 2, 3))
    deformation = F.sum(sparse_motion * reshaped_mask, axis=1)
    deformation = F.transpose(deformation, (0, 2, 3, 1))

    out_dict['deformation'] = deformation

    if estimate_occlusion_map:
        with nn.parameter_scope("occlusion_map"):
            occlusion_map = F.sigmoid(
                PF.convolution(prediction,
                               outmaps=1,
                               kernel=(7, 7),
                               pad=(3, 3),
                               w_init=w_init,
                               b_init=b_init))
        out_dict['occlusion_map'] = occlusion_map
    else:
        occlusion_map = None

    return out_dict
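
A typical way to consume the returned dictionary is to warp the source with the
predicted dense deformation and attenuate it with the occlusion map. A rough
usage sketch, assuming scale_factor=1 so the deformation matches the source
resolution (the hyperparameter values are placeholders; in the actual
first-order-motion pipeline the generator warps feature maps rather than the raw image):

import nnabla.functions as F

out = predict_dense_motion(source_image, kp_driving, kp_source,
                           block_expansion=64, num_blocks=5, max_features=1024,
                           num_kp=10, num_channels=3,
                           estimate_occlusion_map=True)

# deformation: (B, H, W, 2) -> sampling grid (B, 2, H, W) for warp_by_grid
grid = F.transpose(out['deformation'], (0, 3, 1, 2))
warped = F.warp_by_grid(source_image, grid, mode='linear',
                        padding_mode='zero', align_corners=True)
if out['occlusion_map'] is not None:
    warped = warped * out['occlusion_map']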
Code example #28
File: ops.py Project: shikisawamura/nnabla-examples
def location_sensitive_attention(query, values, attention_weights_cat,
                                 attention_location_kernel_size,
                                 attention_n_filters, attention_dim,
                                 is_training, scope):
    r"""Returns the location-sensitive attention mechanism.

    Args:
        query (nn.Variable): A query of size (B, 1, C1).
        values (nn.Variable): Values of size (B, T, C2).
        attention_weights_cat (nn.Variable): A variable of shape (B, 2, T).
        attention_dim (int): The projected dimensionality.
        scope (str): Parameter scope.

    Returns:
        nn.Variable: The context vector.
        nn.Variable: The attention weight vector.

    References:
        J. K. Chorowski, et al., "Attention-based models for speech recognition"
        in Advances in Neural Information Processing Systems, 2015, pp. 577-585.
    """

    with nn.parameter_scope(scope):
        x = affine_norm(query,
                        attention_dim,
                        base_axis=2,
                        with_bias=False,
                        w_init_gain='tanh',
                        scope='query')
        y = affine_norm(values,
                        attention_dim,
                        base_axis=2,
                        with_bias=False,
                        w_init_gain='tanh',
                        scope='memory')

        # apply a 1D-convolutional filter
        z = conv_norm(attention_weights_cat,
                      attention_n_filters,
                      kernel_size=attention_location_kernel_size,
                      stride=1,
                      padding=(attention_location_kernel_size - 1) // 2,
                      dilation=1,
                      bias=False,
                      w_init_gain='affine',
                      scope='conv_norm_lsa')
        z = F.transpose(z, (0, 2, 1))

        # location of shape (B, T, attention_dim)
        location = affine_norm(z,
                               attention_dim,
                               base_axis=2,
                               with_bias=False,
                               w_init_gain='tanh',
                               scope='location')

        # scores of shape (B, T, 1)
        scores = affine_norm(F.tanh(x + y + location),
                             1,
                             base_axis=2,
                             with_bias=False,
                             w_init_gain='affine',
                             scope='scores')

        # attention_weights of shape (B, 1, T)
        attention_weights = F.softmax(scores, axis=1).reshape(
            (query.shape[0], 1, -1))

        # context_vector shape after sum == (B, 1, C)
        context_vector = F.batch_matmul(attention_weights, values)

    return context_vector, attention_weights
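
affine_norm and conv_norm are thin wrappers defined elsewhere; judging from the
w_init_gain argument they are affine/convolution layers whose weights use a
gain-scaled Xavier-style initializer. A minimal sketch of the affine variant
under that assumption (the gain table and scoping are ours):

import numpy as np
import nnabla as nn
import nnabla.initializer as I
import nnabla.parametric_functions as PF

def affine_norm(x, n_out, base_axis=1, with_bias=True,
                w_init_gain='affine', scope='affine_norm'):
    # Xavier-uniform initialization scaled by a nonlinearity-dependent gain
    gains = {'affine': 1.0, 'sigmoid': 1.0, 'tanh': 5.0 / 3.0}
    n_in = int(np.prod(x.shape[base_axis:]))
    limit = gains[w_init_gain] * np.sqrt(6.0 / (n_in + n_out))
    with nn.parameter_scope(scope):
        return PF.affine(x, n_out, base_axis=base_axis, with_bias=with_bias,
                         w_init=I.UniformInitializer((-limit, limit)))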
Code example #29
def feature_transform_net(
        feature: nn.Variable,
        train: bool,
        K: int = 64) -> Tuple[nn.Variable, Dict[str, nn.Variable]]:
    """T net, create transformation matrix

    Args:
        feature (nn.Variable): feature, shape(batch, number of points, 1, K)
        train (bool): training flag
        K (int): transformation matrix size, default is 64.

    Returns:
        Tuple[nn.Variable, Dict[str, nn.Variable]]: transformation matrix and internal variables
    """
    batch_size, num_points, *_ = feature.shape
    # B*H(=num_points)*W(=dim)*C(=K) to B*C(=K)*H(=num_points)*W(=dim)
    feature = F.transpose(feature, (0, 3, 1, 2))
    with nn.parameter_scope("conv1"):
        conv_h1 = PF.convolution(feature,
                                 64, (1, 1),
                                 stride=(1, 1),
                                 with_bias=False)
        conv_h1 = PF.batch_normalization(conv_h1, batch_stat=train)
        conv_h1 = F.relu(conv_h1)

    with nn.parameter_scope("conv2"):
        conv_h2 = PF.convolution(conv_h1,
                                 128, (1, 1),
                                 stride=(1, 1),
                                 with_bias=False)
        conv_h2 = PF.batch_normalization(conv_h2, batch_stat=train)
        conv_h2 = F.relu(conv_h2)

    with nn.parameter_scope("conv3"):
        conv_h3 = PF.convolution(conv_h2,
                                 1024, (1, 1),
                                 stride=(1, 1),
                                 with_bias=False)
        conv_h3 = PF.batch_normalization(conv_h3, batch_stat=train)
        conv_h3 = F.relu(conv_h3)

    pool_h = F.max_pooling(conv_h3, (num_points, 1))
    pool_h = F.reshape(pool_h, (batch_size, -1))

    with nn.parameter_scope("affine1"):
        affine_h1 = PF.affine(pool_h, 512, with_bias=False)
        affine_h1 = PF.batch_normalization(affine_h1, batch_stat=train)
        affine_h1 = F.relu(affine_h1)

    with nn.parameter_scope("affine2"):
        affine_h2 = PF.affine(affine_h1, 256, with_bias=False)
        affine_h2 = PF.batch_normalization(affine_h2, batch_stat=train)
        affine_h2 = F.relu(affine_h2)

    with nn.parameter_scope("affine3"):
        transform_h = PF.affine(affine_h2, K * K)
        eye_mat = nn.Variable.from_numpy_array(
            np.eye(K, dtype=np.float32).flatten())
        eye_mat = F.reshape(eye_mat, (1, K * K))
        transform_h = transform_h + eye_mat

    transform_h = F.reshape(transform_h, (batch_size, K, K))
    return transform_h, {
        "conv_h1": conv_h1,
        "conv_h2": conv_h2,
        "conv_h3": conv_h3,
        "pool_h": pool_h,
        "affine_h1": affine_h1,
        "affine_h2": affine_h2,
        "transform_h": transform_h,
    }
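
One common way to use the returned matrix, following the PointNet paper, is to
apply it with a batched matrix multiply and add an orthogonality regularizer on
the transform. A short usage sketch (the batch size, the 0.001 weight and the
variable names are illustrative, not taken from the repository):

import numpy as np
import nnabla as nn
import nnabla.functions as F

batch_size, num_points, K = 32, 1024, 64
feature = nn.Variable((batch_size, num_points, 1, K))

transform, _ = feature_transform_net(feature, train=True, K=K)   # (B, K, K)

# apply the transform to the (B, num_points, K) feature
aligned = F.batch_matmul(F.reshape(feature, (batch_size, num_points, K)), transform)

# regularize the transform toward orthogonality: || I - A A^T ||^2
identity = nn.Variable.from_numpy_array(
    np.tile(np.eye(K, dtype=np.float32), (batch_size, 1, 1)))
gram = F.batch_matmul(transform, transform, transpose_b=True)
reg_loss = 0.001 * F.mean(F.sum(F.squared_error(gram, identity), axis=(1, 2)))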
Code example #30
File: model.py Project: sony/ai-research-code
    def __call__(self, inp):
        '''
        Define D3Net
        '''

        valid_signal_idx = self.hparams['valid_signal_idx']
        band_split_idxs = self.hparams['band_split_idxs'] + \
            [self.hparams['valid_signal_idx']]

        inp = F.transpose(inp, (0, 2, 1, 3))

        scaled_inp = (inp - self.in_offset) / self.in_scale

        max_final_k = 0
        for k in self.hparams['dens_k']:
            if max_final_k < k[-1]:
                max_final_k = k[-1]
        i = 0
        band_idx_start = 0
        band_out = []
        band_dense_out = []

        # Low ~ middle bands
        for num_init_features, dens_k, num_layer_block, b_n_block, comp_rates in zip(
                self.hparams['num_init_features'], self.hparams['dens_k'],
                self.hparams['num_layer_blocks'], self.hparams['b_n_blocks'],
                self.hparams['comp_rates']):
            x_band = scaled_inp[:, :, :, band_idx_start:band_split_idxs[i]]
            x_band = self.conv2d(x_band,
                                 num_init_features,
                                 kernel_size=3,
                                 stride=1,
                                 name='features_init/%s' % i,
                                 pad=1)
            dense_band = self.md3_block_ds(x_band,
                                           num_init_features,
                                           dens_k,
                                           num_layer_block,
                                           b_n_block,
                                           comp_rates,
                                           name='dense_band/%s' % i)
            band_dense_out.append(dense_band[::-1])
            if max_final_k > self.hparams['dens_k'][i][-1]:
                h = self.batch_norm(band_dense_out[-1][0],
                                    name='match_fm_conv/%s/norm' % i)
                out = self.conv2d(h,
                                  max_final_k,
                                  kernel_size=1,
                                  stride=1,
                                  name='match_fm_conv/%s/conv' % i)
                band_out.append(out)
            else:
                band_out.append(band_dense_out[-1][0])
            band_idx_start = band_split_idxs[i]
            i += 1

        # full bands
        full = self.conv2d(scaled_inp[:, :, :, :valid_signal_idx],
                           self.hparams['f_num_init_features'],
                           kernel_size=3,
                           stride=1,
                           name='features_init_full',
                           pad=1)
        full = self.md3_block_ds(full,
                                 self.hparams['f_num_init_features'],
                                 self.hparams['f_dens_k'],
                                 self.hparams['f_num_layer_block'],
                                 self.hparams['f_n_blocks'],
                                 self.hparams['f_comp_rates'],
                                 name='dense_full')

        # concat low~middle bands and then with full bands
        concat_bands = F.concatenate(*band_out, axis=3)
        concat_full = F.concatenate(*[concat_bands, full[-1]], axis=1)

        # Final dense block
        final = self.dilated_dense_block_2(concat_full,
                                           self.hparams['ttl_dens_k'],
                                           self.hparams['ttl_num_layer_block'],
                                           scope_name='final_dense')

        # Define BNC_Gate : Batch-Normalization, Convolution and Sigmoid Gate
        with nn.parameter_scope('out_gate'):
            bn_out = self.batch_norm(final, name='bn')
            gate = F.sigmoid(
                self.conv2d(bn_out,
                            self.hparams['n_channels'],
                            kernel_size=1,
                            stride=1,
                            name='conv_gate/conv'))
            filt = self.conv2d(bn_out,
                               self.hparams['n_channels'],
                               kernel_size=1,
                               stride=1,
                               name='conv_filt/conv')

        out = gate * filt
        out = out * self.decode_scale + self.decode_bias
        out = F.relu(out)
        out = F.concatenate(*[out, inp[:, :, :, valid_signal_idx:]], axis=3)
        out = F.transpose(out, (0, 2, 1, 3))
        return out
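
The conv2d, batch_norm and block helpers used above are methods of the D3Net
model class and are not shown here. As a rough standalone sketch of what the
conv2d wrapper might reduce to (scoping and defaults are assumptions):

import nnabla as nn
import nnabla.parametric_functions as PF

def conv2d(x, out_channels, kernel_size, stride, name, pad=0, bias=True):
    # square-kernel 2D convolution under its own parameter scope
    with nn.parameter_scope(name):
        return PF.convolution(x, out_channels,
                              kernel=(kernel_size, kernel_size),
                              stride=(stride, stride),
                              pad=(pad, pad),
                              with_bias=bias)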
Code example #31
def pointnet_feature_extraction(
        point_cloud: nn.Variable,
        train: bool) -> Tuple[nn.Variable, Dict[str, nn.Variable]]:
    """pointnet feature extraction proposed by Charles R. Qi et. al.
        See: https://arxiv.org/pdf/1612.00593.pdf

    Args:
        point_cloud (nn.Variable): point cloud, shape(batch, number of points, 3)
        train (bool): training flag

    Returns:
        Tuple[nn.Variable, Dict[str, nn.Variable]]: pointnet feature and internal variables
    """
    batch_size, num_points, _ = point_cloud.shape

    with nn.parameter_scope("tnet1"):
        point_cloud_transformation_mat, _ = point_cloud_transform_net(
            point_cloud, train)

    transformed_point_cloud = F.batch_matmul(point_cloud,
                                             point_cloud_transformation_mat)
    # expand dim to B*C(=K)*H(=num_points)*W(=dim)
    input_point_cloud = F.reshape(transformed_point_cloud,
                                  (batch_size, 1, num_points, 3))

    with nn.parameter_scope("conv1"):
        conv_h1 = PF.convolution(input_point_cloud,
                                 64, (1, 3),
                                 stride=(1, 1),
                                 with_bias=False)
        conv_h1 = PF.batch_normalization(conv_h1, batch_stat=train)
        conv_h1 = F.relu(conv_h1)
        conv_h1 = F.transpose(conv_h1, (0, 2, 3, 1))

    with nn.parameter_scope("tnet2"):
        feature_transformation_mat, _ = feature_transform_net(conv_h1,
                                                              train,
                                                              K=64)

    transformed_feature = F.batch_matmul(conv_h1[:, :, 0, :],
                                         feature_transformation_mat)
    # expand dim to B*H(=num_points)*W(=dim)*C(=K)
    input_feature = F.reshape(transformed_feature,
                              (batch_size, num_points, 1, 64))
    # B*H(=num_points)*W(=dim)*C(=K) to B*C(=K)*H(=num_points)*W(=dim)
    input_feature = F.transpose(input_feature, (0, 3, 1, 2))

    with nn.parameter_scope("conv2"):
        conv_h2 = PF.convolution(input_feature,
                                 128, (1, 1),
                                 stride=(1, 1),
                                 with_bias=False)
        conv_h2 = PF.batch_normalization(conv_h2, batch_stat=train)
        conv_h2 = F.relu(conv_h2)

    with nn.parameter_scope("conv3"):
        conv_h3 = PF.convolution(conv_h2,
                                 1024, (1, 1),
                                 stride=(1, 1),
                                 with_bias=False)
        conv_h3 = PF.batch_normalization(conv_h3, batch_stat=train)
        conv_h3 = F.relu(conv_h3)

    pool_h = F.max_pooling(conv_h3, (num_points, 1))
    pool_h = F.reshape(pool_h, (batch_size, -1))

    return pool_h, {
        "transformed_point_cloud": transformed_point_cloud,
        "point_cloud_transformation_mat": point_cloud_transformation_mat,
        "conv_h1": conv_h1,
        "conv_h2": conv_h2,
        "feature_transformation_mat": feature_transformation_mat,
        "transformed_feature": transformed_feature,
        "conv_h3": conv_h3,
        "pool_h": pool_h,
    }
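
To turn the global feature into a classifier, as in the original PointNet, a few
fully connected layers can be stacked on top. A rough sketch (the layer sizes and
the 40-class output follow the common ModelNet40 setup and are illustrative):

import nnabla as nn
import nnabla.functions as F
import nnabla.parametric_functions as PF

point_cloud = nn.Variable((32, 1024, 3))
global_feature, _ = pointnet_feature_extraction(point_cloud, train=True)  # (32, 1024)

with nn.parameter_scope('classifier'):
    h = F.relu(PF.batch_normalization(
        PF.affine(global_feature, 512, name='fc1'), batch_stat=True, name='bn1'))
    h = F.relu(PF.batch_normalization(
        PF.affine(h, 256, name='fc2'), batch_stat=True, name='bn2'))
    logits = PF.affine(h, 40, name='fc3')   # e.g. 40 ModelNet40 classes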