Example #1
def get_direction_grid(height, width, focal_length, return_ij_2d_grid=False):
    """Forms a mesh grid for a given height and width and assumes the camera position to be fixed at the center of the the grid 
    (with a sufficiently large enough offset in z direction). Based on the prefixed camera position, 
    computes ray direction for every point in the grid.

    Args:
        height (int): Height of the image/grid
        width (int): Width of the image/grid
        focal_length (float): Camera focal length (calibrated intrinsics)
        return_ij_2d_grid (bool): If True, return the 2D (i, j) index grid instead of ray directions.

    Returns:
        directions (nn.Variable or nn.NdArray): Shape is (height, width, 3) - direction of projected ray for every grid point.
    """
    x = F.arange(0, width)
    y = F.arange(0, height)

    xx, yy = F.meshgrid(x, y)

    if return_ij_2d_grid:
        return F.stack(*list(F.meshgrid(x, y, ij_indexing=True)), axis=2)

    directions = F.stack((xx - width * 0.5) / focal_length,
                         -(yy - height * 0.5) / focal_length,
                         F.constant(-1, xx.shape),
                         axis=2)
    return directions
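
A minimal usage sketch (sizes invented for illustration; assumes the nnabla imports the function itself relies on):

import nnabla as nn
import nnabla.functions as F  # required by get_direction_grid

dirs = get_direction_grid(4, 6, focal_length=50.0)
print(dirs.shape)  # (4, 6, 3): one ray direction per grid point
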
Example #2
def create_fixed_length_rnn(xs0, h0, w0, w, b, num_layers, nonlinearity,
                            num_directions, with_bias):
    # xs : [T, B, I]
    # h0 : [L, D, B, H]
    # c0 : [L, D, B, H]
    # w0 : [D, H, I+H]
    # w : [L-1, D, H, D * H + H]
    # b : [L, D, H]

    batch_size = xs0.shape[1]
    hidden_size = h0.shape[3]

    if xs0.shape[0] == 1:
        xs = [xs0[0]]
    else:
        xs = F.split(xs0, axis=0)
    hn = []
    for i in range(num_layers):
        wi = w0
        if i > 0:
            wi = w[i - 1]
        # wi : [D, H, ?]
        # Forward direction
        hif = h0[i, 0]  # [B, H]
        wif = wi[0]
        bif = None
        if with_bias:
            bif = b[i, 0]
        hs = []
        for j, x in enumerate(xs):
            # x : [B, I]
            hif = rnn(x, hif, wif, bif, nonlinearity, with_bias)
            hs.append(hif)
        hn.append(hif)

        if num_directions == 1:
            xs = hs
            continue

        # Backward direction
        hib = h0[i, 1]  # [B, H]
        wib = wi[1]
        bib = None
        if with_bias:
            bib = b[i, 1]
        for k, x in enumerate(reversed(xs)):
            j = len(xs) - 1 - k
            # x : [B, I]
            hib = rnn(x, hib, wib, bib, nonlinearity, with_bias)
            hs[j] = F.concatenate(hs[j], hib, axis=1)
        hn.append(hib)
        xs = hs

    ys = xs  # list of [B, HD]
    ys = F.stack(*ys, axis=0)  # [T, B, HD]
    hn = F.reshape(F.stack(*hn, axis=0),
                   (num_layers, num_directions, batch_size,
                    hidden_size))  # LD list of [B, H] --> [L, D, B, H]
    return ys, hn
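
The layout comments at the top imply the following construction; a hypothetical shape sketch (sizes invented, and the per-step rnn() function must be defined elsewhere in the source):

import nnabla as nn

T, B, I, H, L, D = 5, 2, 3, 4, 2, 2  # time, batch, input, hidden, layers, directions
xs0 = nn.Variable((T, B, I))
h0 = nn.Variable((L, D, B, H))
w0 = nn.Variable((D, H, I + H))
w = nn.Variable((L - 1, D, H, D * H + H))
b = nn.Variable((L, D, H))
# create_fixed_length_rnn(xs0, h0, w0, w, b, L, 'tanh', D, True)
# would return ys of shape [T, B, H*D] and hn of shape [L, D, B, H].
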
Example #3
    def _build(self):
        # infer variable
        self.infer_obs_t = nn.Variable((1, 4, 84, 84))
        # inference output
        self.infer_qs_t = self.q_function(self.infer_obs_t, self.num_actions,
                                          self.num_heads, 'q_func')
        self.infer_all = F.sink(*self.infer_qs_t)

        # train variables
        self.obss_t = nn.Variable((self.batch_size, 4, 84, 84))
        self.acts_t = nn.Variable((self.batch_size, 1))
        self.rews_tp1 = nn.Variable((self.batch_size, 1))
        self.obss_tp1 = nn.Variable((self.batch_size, 4, 84, 84))
        self.ters_tp1 = nn.Variable((self.batch_size, 1))
        self.weights = nn.Variable((self.batch_size, self.num_heads))

        # training output
        qs_t = self.q_function(self.obss_t, self.num_actions, self.num_heads,
                               'q_func')
        qs_tp1 = self.q_function(self.obss_tp1, self.num_actions,
                                 self.num_heads, 'target')
        stacked_qs_t = F.transpose(F.stack(*qs_t), [1, 0, 2])
        stacked_qs_tp1 = F.transpose(F.stack(*qs_tp1), [1, 0, 2])

        # select one dimension
        a_one_hot = F.reshape(F.one_hot(self.acts_t, (self.num_actions, )),
                              (-1, 1, self.num_actions))
        # mask output
        q_t_selected = F.sum(stacked_qs_t * a_one_hot, axis=2)
        q_tp1_best = F.max(stacked_qs_tp1, axis=2)
        q_tp1_best.need_grad = False

        # reward clipping
        clipped_rews_tp1 = clip_by_value(self.rews_tp1, -1.0, 1.0)

        # loss calculation
        y = clipped_rews_tp1 + self.gamma * q_tp1_best * (1.0 - self.ters_tp1)
        td = F.huber_loss(q_t_selected, y)
        self.loss = F.mean(F.sum(td * self.weights, axis=1))

        # optimizer
        self.solver = S.RMSprop(self.lr, 0.95, 1e-2)

        # weights and biases
        with nn.parameter_scope('q_func'):
            self.params = nn.get_parameters()
            self.head_params = []
            for i in range(self.num_heads):
                with nn.parameter_scope('head%d' % i):
                    self.head_params.append(nn.get_parameters())
            with nn.parameter_scope('shared'):
                self.shared_params = nn.get_parameters()
        with nn.parameter_scope('target'):
            self.target_params = nn.get_parameters()

        # set q function parameters to solver
        self.solver.set_parameters(self.params)
Example #4
def bicubic_four(inputs, scope='bicubic_four'):
    """
    Equivalent to tf.image.resize_bicubic(inputs, (h*4, w*4)) for a fixed ratio of 4 (TF API <= 1.13).
    In TF 2.0 tf.image.resize_bicubic behaves differently; the old behavior is available as tf.compat.v1.image.resize_bicubic.
    Based on **Parallel Catmull-Rom Spline Interpolation Algorithm for Image Zooming Based on CUDA** [Wu et al.].
    """
    with nn.parameter_scope(scope):
        b, h, w, c = inputs.shape

        p_inputs = F.concatenate(inputs[:, :1, :, :], inputs,
                                 axis=1)  # pad top
        p_inputs = F.concatenate(p_inputs[:, :, :1, :], p_inputs,
                                 axis=2)  # pad left
        p_inputs = F.concatenate(p_inputs,
                                 p_inputs[:, -1:, :, :],
                                 p_inputs[:, -1:, :, :],
                                 axis=1)  # pad bottom
        p_inputs = F.concatenate(p_inputs,
                                 p_inputs[:, :, -1:, :],
                                 p_inputs[:, :, -1:, :],
                                 axis=2)  # pad right

        hi_res_bin = [p_inputs[:, bi:bi + h, :, :] for bi in range(4)]
        r = 0.75
        mat = np.float32([[0, 1, 0, 0], [-r, 0, r, 0],
                          [2 * r, r - 3, 3 - 2 * r, -r], [-r, 2 - r, r - 2,
                                                          r]])
        weights = [
            np.float32([1.0, t, t * t, t * t * t]).dot(mat)
            for t in [0.0, 0.25, 0.5, 0.75]
        ]

        hi_res_array = []  # [hi_res_bin[1]]
        for hi in range(4):
            cur_wei = weights[hi]
            cur_data = cur_wei[0] * hi_res_bin[0] + cur_wei[1] * hi_res_bin[1] + \
                cur_wei[2] * hi_res_bin[2] + cur_wei[3] * hi_res_bin[3]
            hi_res_array.append(cur_data)
        hi_res_y = F.stack(*hi_res_array, axis=2)  # shape (b,h,4,w,c)
        hi_res_y = F.reshape(hi_res_y, (b, h * 4, w + 3, c))
        hi_res_bin = [hi_res_y[:, :, bj:bj + w, :] for bj in range(4)]

        hi_res_array = []  # [hi_res_bin[1]]
        for hj in range(4):
            cur_wei = weights[hj]
            cur_data = cur_wei[0] * hi_res_bin[0] + cur_wei[1] * hi_res_bin[1] + \
                cur_wei[2] * hi_res_bin[2] + cur_wei[3] * hi_res_bin[3]
            hi_res_array.append(cur_data)
        hi_res = F.stack(*hi_res_array, axis=3)  # shape (b,h*4,w,4,c)
        hi_res = F.reshape(hi_res, (b, h * 4, w * 4, c))

    return hi_res
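
A shape-level usage sketch (input is NHWC, as the padding along axes 1 and 2 implies; the Catmull-Rom weights then expand both spatial dimensions by 4):

import numpy as np
import nnabla as nn
import nnabla.functions as F  # all required by bicubic_four

x = nn.Variable((1, 16, 16, 3))  # NHWC
y = bicubic_four(x)
print(y.shape)  # (1, 64, 64, 3)
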
Example #5
def simple_rnn(inputs, units, return_sequences=False, fix_parameters=False):
    '''
    A vanilla recurrent neural network layer
    Args:
        inputs (nnabla.Variable): A shape of [B, SentenceLength, EmbeddingSize].
        units (int): Dimensionality of the output space.
        return_sequences (bool): Whether to return the full output sequence, or only the last output.
        fix_parameters (bool): Fix parameters (set need_grad=False).
    Returns:
        nn.Variable: A shape of [B, SentenceLength, units] if return_sequences is True,
        or
        nn.Variable: A shape of [B, units] otherwise.
    '''

    hs = []
    batch_size = inputs.shape[0]
    sentence_length = inputs.shape[1]
    h0 = nn.Variable.from_numpy_array(np.zeros((batch_size, units)))

    inputs = F.split(inputs, axis=1) # split in the direction of sequence

    h = h0
    for x in inputs:
        h = F.tanh(PF.affine(F.concatenate(x, h, axis=1), units, fix_parameters=fix_parameters))
        hs.append(h)

    if return_sequences:
        hs = F.stack(*hs, axis=1)
        return hs
    else:
        return hs[-1]
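
A graph-construction sketch under assumed shapes (PF.affine creates its parameters on first call, so no data is needed to infer the output shape):

import numpy as np
import nnabla as nn
import nnabla.functions as F
import nnabla.parametric_functions as PF  # all used by simple_rnn

x = nn.Variable((8, 10, 32))  # [B, SentenceLength, EmbeddingSize]
y = simple_rnn(x, units=64, return_sequences=True)
print(y.shape)  # (8, 10, 64)
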
Example #6
def simple_rnn(inputs: nn.Variable, units: int, mask: Optional[nn.Variable] = None,
               return_sequences: bool = False, fix_parameters=False) -> nn.Variable:
    '''
    A vanilla recurrent neural network layer
    Args:
        inputs (nnabla.Variable): A shape of [batch_size, length, embedding_size].
        units (int): Dimensionality of the output space.
        mask (nnabla.Variable): A shape of [batch_size, length, 1].
        return_sequences (bool): Whether to return the full output sequence, or only the last output.
        fix_parameters (bool): Fix parameters (set need_grad=False).
    Returns:
        nn.Variable: A shape of [batch_size, length, units] if return_sequences is True,
        or
        nn.Variable: A shape of [batch_size, units] otherwise.
    '''

    hs = []
    batch_size, length, embedding_size = inputs.shape
    h0 = F.constant(0, shape=(batch_size, units))

    h = h0

    if mask is None:
        mask = F.constant(1, shape=(batch_size, length, 1))

    for x, cond in zip(F.split(inputs, axis=1), F.split(mask, axis=1)):
        h_t = F.tanh(PF.affine(F.concatenate(x, h, axis=1), units, fix_parameters=fix_parameters))
        h = where(cond, h_t, h)
        hs.append(h)

    if return_sequences:
        hs = F.stack(*hs, axis=1)
        return hs
    else:
        return hs[-1]
Example #7
def lstm(inputs: nn.Variable, units: int, mask: Optional[nn.Variable] = None, initial_state: Tuple[nn.Variable, nn.Variable] = None,
         return_sequences: bool = False, return_state: bool = False, fix_parameters: bool = False) -> nn.Variable:
    '''
    A long short-term memory
    Args:
        inputs (nnabla.Variable): A shape of [batch_size, length, embedding_size].
        units (int): Dimensionality of the output space.
        mask (nnabla.Variable): A shape of [batch_size, length].
        initial_state ([nnabla.Variable, nnabla.Variable]): A tuple of an initial cell and an initial hidden state.
        return_sequences (bool): Whether to return the full output sequence, or only the last output.
        return_state (bool): Whether to also return the last state, which consists of the cell and the hidden state.
        fix_parameters (bool): Fix parameters (set need_grad=False).
    Returns:
        nn.Variable: A shape of [batch_size, length, units] if return_sequences is True,
        or
        nn.Variable: A shape of [batch_size, units] otherwise.
    '''
    
    batch_size, length, embedding_size = inputs.shape

    if initial_state is None:
        c0 = F.constant(0, shape=(batch_size, units))
        h0 = F.constant(0, shape=(batch_size, units))
    else:
        assert type(initial_state) is tuple or type(initial_state) is list, \
               'initial_state must be a tuple or a list.'
        assert len(initial_state) == 2, \
               'initial_state must have only two states.'

        c0, h0 = initial_state

        assert c0.shape == h0.shape, 'shapes of initial_state must be same.'
        assert c0.shape[0] == batch_size, \
               'batch size of initial_state ({0}) is different from that of inputs ({1}).'.format(c0.shape[0], batch_size)
        assert c0.shape[1] == units, \
               'units size of initial_state ({0}) is different from that of units of args ({1}).'.format(c0.shape[1], units)

    cell = c0
    hidden = h0

    hs = []

    if mask is None:
        mask = F.constant(1, shape=(batch_size, length, 1))
    for x, cond in zip(F.split(inputs, axis=1), F.split(mask, axis=1)):
        cell_t, hidden_t = lstm_cell(x, cell, hidden)
        cell = where(cond, cell_t, cell)
        hidden = where(cond, hidden_t, hidden)
        hs.append(hidden)

    if return_sequences:
        ret = F.stack(*hs, axis=1)
    else:
        ret = hs[-1]

    if return_state:
        return ret, cell, hidden
    else:
        return ret
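
A hypothetical call sketch; lstm_cell and where come from elsewhere in this source, so the call itself is shown commented:

import nnabla as nn

x = nn.Variable((8, 20, 128))   # [batch_size, length, embedding_size]
mask = nn.Variable((8, 20, 1))  # 1 where a real token is present, 0 at padding
# out = lstm(x, units=256, mask=mask, return_sequences=True)  # -> (8, 20, 256)
# Masked steps keep the previous cell/hidden state, so padded positions
# simply carry the last valid state forward.
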
Example #8
def stack(xs, axis=0):
    if len(xs) == 1:
        s = list(xs[0].shape)
        s.insert(axis, 1)
        xs[0] = F.broadcast(xs[0], xs[0].shape)
        return F.reshape(xs[0], s)
    else:
        return F.stack(*xs, axis=axis)
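
Usage sketch: the wrapper mirrors F.stack but also accepts a single-element list, inserting the new axis via an explicit reshape instead:

import nnabla as nn
import nnabla.functions as F

a = nn.Variable((2, 3))
b = nn.Variable((2, 3))
print(stack([a]).shape)     # (1, 2, 3) via broadcast + reshape
print(stack([a, b]).shape)  # (2, 2, 3) via F.stack
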
Example #9
def guided_filter(img, r, eps):
    """
    Edge preserving filter
    """
    img2 = F.concatenate(img, img * img, axis=3)
    img2 = box_filter(img2, r)
    mean = F.split(img2, axis=3)
    mean_i = F.stack(mean[0], mean[1], mean[2], axis=3)
    mean_ii = F.stack(mean[3], mean[4], mean[5], axis=3)
    var_i = mean_ii - mean_i * mean_i
    a = var_i / (var_i + eps)
    b = mean_i - a * mean_i
    ab = F.concatenate(a, b, axis=3)
    ab = box_filter(ab, r)
    mean_ab = F.split(ab, axis=3)
    mean_a = F.stack(mean_ab[0], mean_ab[1], mean_ab[2], axis=3)
    mean_b = F.stack(mean_ab[3], mean_ab[4], mean_ab[5], axis=3)
    q = mean_a * img + mean_b
    return q
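
box_filter is defined elsewhere in this source; a hypothetical stand-in that preserves the NHWC shape (assuming nnabla's average_pooling supports channel_last, as recent versions do) makes the snippet self-contained:

import nnabla as nn
import nnabla.functions as F

def box_filter(x, r):
    # Hypothetical stand-in: mean over a (2r+1)x(2r+1) window, stride 1,
    # padded so the spatial shape is preserved (NHWC layout).
    k = 2 * r + 1
    return F.average_pooling(x, (k, k), stride=(1, 1), pad=(r, r), channel_last=True)

img = nn.Variable((1, 64, 64, 3))  # NHWC, 3 channels as the axis=3 splits require
q = guided_filter(img, r=4, eps=1e-3)
print(q.shape)  # (1, 64, 64, 3)
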
Example #10
    def call(self, input):
        if self._mode == 'full':
            out = F.stack(*[op(input) for op in self._ops], axis=0)
            out = F.mul2(out, F.softmax(self._alpha, axis=0))
            return F.sum(out, axis=0)

        # update active index
        self._update_active_index()

        return self._ops[self._active](input)
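
The 'full' branch is the continuous relaxation used in DARTS-style architecture search: a softmax over the architecture parameters weights every candidate op. A standalone sketch of just that weighted sum (shapes invented):

import nnabla as nn
import nnabla.functions as F

ops_out = [nn.Variable((4, 8)) for _ in range(3)]  # outputs of 3 candidate ops
alpha = nn.Variable((3, 1, 1))                     # one weight per op
mixed = F.mul2(F.stack(*ops_out, axis=0),          # (3, 4, 8)
               F.softmax(alpha, axis=0))           # weights broadcast over (4, 8)
y = F.sum(mixed, axis=0)                           # (4, 8)
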
Example #11
def sample_pdf(bins, weights, N_samples, det=False):
    """Sample additional points for training fine network

    Args:
      bins: int. Height in pixels.
      weights: int. Width in pixels.
      N_samples: float. Focal length of pinhole camera.
      det

    Returns:
      samples: array of shape [batch_size, 3]. Depth samples for fine network
    """
    weights += 1e-5
    pdf = weights / F.sum(weights, axis=-1, keepdims=True)

    cdf = F.cumsum(pdf, axis=-1)
    # if isinstance(pdf, nn.Variable):
    #     cdf = nn.Variable.from_numpy_array(tf.math.cumsum(pdf.d, axis=-1))
    # else:
    #     cdf = nn.Variable.from_numpy_array(tf.math.cumsum(pdf.data, axis=-1)).data
    cdf = F.concatenate(F.constant(0, cdf[..., :1].shape), cdf, axis=-1)

    if det:
        u = F.arange(0., 1., 1 / N_samples)
        u = F.broadcast(u[None, :], cdf.shape[:-1] + (N_samples, ))
        u = u.data if isinstance(cdf, nn.NdArray) else u
    else:
        u = F.rand(shape=cdf.shape[:-1] + (N_samples, ))

    indices = F.searchsorted(cdf, u, right=True)
    # if isinstance(cdf, nn.Variable):
    #     indices = nn.Variable.from_numpy_array(
    #         tf.searchsorted(cdf.d, u.d, side='right').numpy())
    # else:
    #     indices = nn.Variable.from_numpy_array(
    #         tf.searchsorted(cdf.data, u.data, side='right').numpy())
    below = F.maximum_scalar(indices - 1, 0)
    above = F.minimum_scalar(indices, cdf.shape[-1] - 1)
    indices_g = F.stack(below, above, axis=below.ndim)
    cdf_g = F.gather(cdf,
                     indices_g,
                     axis=-1,
                     batch_dims=len(indices_g.shape) - 2)
    bins_g = F.gather(bins,
                      indices_g,
                      axis=-1,
                      batch_dims=len(indices_g.shape) - 2)

    denom = (cdf_g[..., 1] - cdf_g[..., 0])
    denom = F.where(F.less_scalar(denom, 1e-5), F.constant(1, denom.shape),
                    denom)
    t = (u - cdf_g[..., 0]) / denom
    samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])

    return samples
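
A shape-level sketch of the NeRF-style hierarchical sampling: bins holds n+1 edges, weights holds n per-bin weights, and the output has one depth per requested sample (sizes illustrative):

import nnabla as nn
import nnabla.functions as F

bins = nn.Variable((1024, 6))     # 5 bins -> 6 edges per ray
weights = nn.Variable((1024, 5))  # unnormalized per-bin weights
samples = sample_pdf(bins, weights, N_samples=8, det=True)
print(samples.shape)  # (1024, 8)
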
Example #12
def lstm(inputs, units, initial_state=None, return_sequences=False, return_state=False, fix_parameters=False):
    '''
    A long short-term memory
    Args:
        inputs (nnabla.Variable): A shape of [B, SentenceLength, EmbeddingSize].
        units (int): Dimensionality of the output space.
        initial_state ([nnabla.Variable, nnabla.Variable]): A tuple of an initial cell and an initial hidden state.
        return_sequences (bool): Whether to return the full output sequence, or only the last output.
        return_state (bool): Whether to also return the last state, which consists of the cell and the hidden state.
        fix_parameters (bool): Fix parameters (set need_grad=False).
    Returns:
        nn.Variable: A shape of [B, SentenceLength, units] if return_sequences is True,
        or
        nn.Variable: A shape of [B, units] otherwise.
    '''
    
    batch_size = inputs.shape[0]

    if initial_state is None:
        c0 = nn.Variable.from_numpy_array(np.zeros((batch_size, units)))
        h0 = nn.Variable.from_numpy_array(np.zeros((batch_size, units)))
    else:
        assert type(initial_state) is tuple or type(initial_state) is list, \
               'initial_state must be a tuple or a list.'
        assert len(initial_state) == 2, \
               'initial_state must have only two states.'

        c0, h0 = initial_state

        assert c0.shape == h0.shape, 'shapes of initial_state must be same.'
        assert c0.shape[0] == batch_size, \
               'batch size of initial_state ({0}) is different from that of inputs ({1}).'.format(c0.shape[0], batch_size)
        assert c0.shape[1] == units, \
               'units size of initial_state ({0}) is different from that of units of args ({1}).'.format(c0.shape[1], units)

    cell = c0
    hidden = h0

    hs = []

    for x in F.split(inputs, axis=1):
        cell, hidden = lstm_cell(x, cell, hidden)
        hs.append(hidden)

    if return_sequences:
        ret = F.stack(*hs, axis=1)
    else:
        ret = hs[-1]

    if return_state:
        return ret, cell, hidden
    else:
        return ret
Example #13
    def network(self, x_in, name='LSTM', n_hidden=32):
        hlist = []
        for x_i in F.split(x_in, axis=1):
            self._h, self._c = self._lstm_cell(name, n_hidden, x_i, self._h, self._c)
            with nn.parameter_scope(name + '_Affine_2'):
                self._h = PF.affine(self._h, (self._cols_size,))
            hlist.append(self._h)
        h = F.stack(*hlist, axis=1)
        h = F.slice(h, start=[0, h.shape[1] - self._x_output_length, 0],
                    stop=[self._batch_size, h.shape[1], self._cols_size],
                    step=[1, 1, 1])
        return h
Example #14
def lab2rgb(input):
    L, a, b = F.split(input, axis=1)
    y = (L + 16.0) / 116.0
    x = (a / 500.0) + y
    z = y - (b / 200.0)
    neg_mask = F.less_scalar(z, 0).apply(need_grad=False)
    z = z * F.logical_not(neg_mask)
    mask_Y = F.greater_scalar(y, 0.2068966).apply(need_grad=False)
    mask_X = F.greater_scalar(x, 0.2068966).apply(need_grad=False)
    mask_Z = F.greater_scalar(z, 0.2068966).apply(need_grad=False)
    Y_1 = (y ** 3) * mask_Y
    Y_2 = L / (116. * 7.787) * F.logical_not(mask_Y)
    var_Y = Y_1 + Y_2

    X_1 = (x ** 3) * mask_X
    X_2 = (x - 16. / 116.) / 7.787 * F.logical_not(mask_X)
    var_X = X_1 + X_2

    Z_1 = (z ** 3) * mask_Z
    Z_2 = (z - 16. / 116.) / 7.787 * F.logical_not(mask_Z)
    var_Z = Z_1 + Z_2

    X = 0.95047 * var_X
    Y = 1.00000 * var_Y
    Z = 1.08883 * var_Z

    var_R = X * 3.2406 + Y * -1.5372 + Z * -0.4986
    var_G = X * -0.9689 + Y * 1.8758 + Z * 0.0415
    var_B = X * 0.0557 + Y * -0.2040 + Z * 1.0570

    mask_R = F.greater_scalar(var_R, 0.0031308).apply(need_grad=False)
    n_mask_R = F.logical_not(mask_R)
    R_1 = (1.055 * (F.maximum2(var_R, n_mask_R) ** (1 / 2.4)) - 0.055) * mask_R
    R_2 = (12.92 * var_R) * n_mask_R
    var_R = R_1 + R_2

    mask_G = F.greater_scalar(var_G, 0.0031308).apply(need_grad=False)
    n_mask_G = F.logical_not(mask_G)
    G_1 = (1.055 * (F.maximum2(var_G, n_mask_G) ** (1 / 2.4)) - 0.055) * mask_G
    G_2 = (12.92 * var_G) * n_mask_G
    var_G = G_1 + G_2

    mask_B = F.greater_scalar(var_B, 0.0031308).apply(need_grad=False)
    n_mask_B = F.logical_not(mask_B)
    B_1 = (1.055 * (F.maximum2(var_B, n_mask_B) ** (1 / 2.4)) - 0.055) * mask_B
    B_2 = (12.92 * var_B) * n_mask_B
    var_B = B_1 + B_2
    return F.stack(var_R, var_G, var_B, axis=1)
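
The axis=1 split implies a channel-first Lab input; a shape sketch:

import nnabla as nn
import nnabla.functions as F

lab = nn.Variable((4, 3, 32, 32))  # [B, (L, a, b), H, W]
rgb = lab2rgb(lab)
print(rgb.shape)  # (4, 3, 32, 32)
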
Example #15
def split_backward(inputs, axis=0):
    """
    Args:
      inputs (list of nn.Variable): Incoming gradients of the split outputs, followed by the input of the forward function.
      axis (int): Axis along which the forward split was performed.

    Return:
      nn.Variable: Gradient with respect to the input of the forward split.
    """
    dy = inputs[0:-1]
    x0 = inputs[-1]
    D = len(x0.shape)
    axis = positive_axis(axis, D)
    dx = F.stack(*dy, axis=axis)
    return dx
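
Since F.split removes the given axis, its gradient simply stacks the upstream gradients back along that axis. A sketch, with a hypothetical stand-in for the module's positive_axis helper:

import nnabla as nn
import nnabla.functions as F

def positive_axis(axis, ndim):
    # Hypothetical stand-in: map a possibly negative axis to its
    # non-negative equivalent.
    return axis if axis >= 0 else axis + ndim

dys = [nn.Variable((2, 4)) for _ in range(3)]  # grads of the 3 split outputs
x0 = nn.Variable((2, 3, 4))                    # forward input (for its shape)
dx = split_backward(dys + [x0], axis=1)
print(dx.shape)  # (2, 3, 4)
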
Example #16
def ndc_rays(H, W, focal, near, rays_o, rays_d):
    """Normalized device coordinate rays.

    Space such that the canvas is a cube with sides [-1, 1] in each axis.

    Args:
      H (int): Height in pixels.
      W (int): Width in pixels.
      focal (float):  Focal length of pinhole camera.
      near (float): Near depth bound for the scene.
      rays_o (nn.Variable or nn.NdArray): shape [batch_size, 3]. Camera origin.
      rays_d (nn.Variable or nn.NdArray): shape [batch_size, 3]. Ray direction.

    Returns:
      rays_o: array of shape [batch_size, 3]. Camera origin in NDC.
      rays_d: array of shape [batch_size, 3]. Ray direction in NDC.
    """
    # Shift ray origins to near plane
    t = -(near + rays_o[..., 2]) / (rays_d[..., 2] + 1e-5)
    rays_o = rays_o + t[..., None] * rays_d

    # Projection
    o0 = -1. / (W / (2. * focal)) * rays_o[..., 0] / rays_o[..., 2]
    o1 = -1. / (H / (2. * focal)) * rays_o[..., 1] / rays_o[..., 2]
    o2 = 1. + 2. * near / rays_o[..., 2]

    d0 = -1./(W/(2.*focal)) * \
        (rays_d[..., 0]/rays_d[..., 2] - rays_o[..., 0]/rays_o[..., 2])
    d1 = -1./(H/(2.*focal)) * \
        (rays_d[..., 1]/rays_d[..., 2] - rays_o[..., 1]/rays_o[..., 2])
    d2 = -2. * near / rays_o[..., 2]

    rays_o = F.stack(o0, o1, o2, axis=-1)
    rays_d = F.stack(d0, d1, d2, axis=-1)

    return rays_o, rays_d
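
A shape sketch with the kind of values NeRF uses for LLFF scenes (numbers illustrative):

import nnabla as nn
import nnabla.functions as F

rays_o = nn.Variable((1024, 3))
rays_d = nn.Variable((1024, 3))
o_ndc, d_ndc = ndc_rays(H=378, W=504, focal=407.6, near=1.0,
                        rays_o=rays_o, rays_d=rays_d)
print(o_ndc.shape, d_ndc.shape)  # (1024, 3) (1024, 3)
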
Example #17
def build_cost_volume(limg, rimg, maxdisp):
    left_stack = []
    right_stack = []
    for i in range(int(maxdisp / 4)):
        sliced_limg = limg[:, :, :, i:]
        sliced_rimg = rimg[:, :, :, :limg.shape[3] - i]
        if i == 0:
            padded_limg = sliced_limg
            padded_rimg = sliced_rimg
        else:
            # Pad i pixels on the left edge
            # The shape of padded_* becomes [B, C, H, W]
            padded_limg = F.pad(sliced_limg, (i, 0))
            padded_rimg = F.pad(sliced_rimg, (i, 0))

        left_stack.append(padded_limg)
        right_stack.append(padded_rimg)

    left_stacked = F.stack(*left_stack, axis=2)  # [B, C, D, H, W]
    right_stacked = F.stack(*right_stack, axis=2)  # [B, C, D, H, W]

    cost_volume = F.concatenate(left_stacked, right_stacked,
                                axis=1)  # [B, 2C, D, H, W]
    return cost_volume
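
A shape sketch (channel-first input, as the axis annotations state; D = maxdisp / 4 disparity planes):

import nnabla as nn
import nnabla.functions as F

limg = nn.Variable((1, 32, 64, 128))  # [B, C, H, W]
rimg = nn.Variable((1, 32, 64, 128))
cv = build_cost_volume(limg, rimg, maxdisp=16)
print(cv.shape)  # (1, 64, 4, 64, 128): [B, 2C, D, H, W]
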
Example #18
def LSTM(inputs,
         units,
         initial_state=None,
         return_sequences=False,
         return_state=False,
         name='lstm'):

    batch_size = inputs.shape[0]

    if initial_state is None:

        c0 = nn.Variable.from_numpy_array(np.zeros((batch_size, units)),
                                          need_grad=True)
        h0 = nn.Variable.from_numpy_array(np.zeros((batch_size, units)),
                                          need_grad=True)
    else:
        assert type(initial_state) is tuple or type(initial_state) is list, \
               'initial_state must be a tuple or a list.'
        assert len(initial_state) == 2, \
               'initial_state must have only two states.'

        c0, h0 = initial_state

        assert c0.shape == h0.shape, 'shapes of initial_state must be same.'
        assert c0.shape[0] == batch_size, \
               'batch size of initial_state ({0}) is different from that of inputs ({1}).'.format(c0.shape[0], batch_size)
        assert c0.shape[1] == units, \
               'units size of initial_state ({0}) is different from that of units of args ({1}).'.format(c0.shape[1], units)

    cell = c0
    hidden = h0

    hs = []

    for x in F.split(inputs, axis=1):
        with nn.parameter_scope(name):
            cell, hidden = LSTMCell(x, cell, hidden)
        hs.append(hidden)

    if return_sequences:
        ret = F.stack(*hs, axis=1)
    else:
        ret = hs[-1]

    if return_state:
        return ret, cell, hidden
    else:
        return ret
Example #19
    def backward_impl(self, inputs, outputs, prop_down, accum):
        # inputs: [inputs_fwd_graph] + [inputs_bwd_graph] or
        # [inputs_fwd_graph] + [outputs_fwd_graph] + [inputs_bwd_graph]

        # Args
        axis = self.forward_func.info.args["axis"]

        # Compute
        ## w.r.t. dy
        if prop_down[-1]:
            g_dy = inputs[-1].grad
            g_dy_ = F.stack(*[o.grad for o in outputs], axis=axis)
            if accum[-1]:
                g_dy += g_dy_
            else:
                g_dy.copy_from(g_dy_)
Example #20
def create_network(batchsize, imheight, imwidth, args):
    import gc
    gc.collect()
    nnabla_ext.cuda.clear_memory_cache()

    anchors = args.num_anchors
    classes = args.num_classes
    yolo_x = nn.Variable((batchsize, 3, imheight, imwidth))
    yolo_features = yolov2.yolov2(yolo_x, anchors, classes, test=False)

    nB = yolo_features.shape[0]
    nA = args.num_anchors
    nC = args.num_classes
    nH = yolo_features.shape[2]
    nW = yolo_features.shape[3]

    output = yolo_features.get_unlinked_variable(need_grad=True)
    # TODO: Workaround until v1.0.2.
    # Explicitly enable grad since need_grad option above didn't work.
    output.need_grad = True

    output = F.reshape(output, (nB, nA, (5 + nC), nH, nW))
    output_splitted = F.split(output, 2)
    x, y, w, h, conf = [v.reshape((nB, nA, nH, nW))
                        for v in output_splitted[0:5]]
    x, y, conf = map(F.sigmoid, [x, y, conf])

    cls = F.stack(*output_splitted[5:], axis=2)
    cls = cls.reshape((nB*nA, nC, nH*nW))
    cls = F.transpose(cls, [0, 2, 1]).reshape((nB*nA*nH*nW, nC))

    tx, ty, tw, th, tconf, coord_mask, conf_mask_sq = [
        nn.Variable(v.shape) for v in [x, y, w, h, conf, x, conf]]
    cls_ones, cls_mask = [nn.Variable(cls.shape) for _ in range(2)]
    tcls, cls_mask_bb = [nn.Variable((cls.shape[0], 1)) for _ in range(2)]

    coord_mask_sq = F.pow_scalar(coord_mask, 2)
    loss_x = args.coord_scale * F.sum(F.squared_error(x, tx) * coord_mask_sq)
    loss_y = args.coord_scale * F.sum(F.squared_error(y, ty) * coord_mask_sq)
    loss_w = args.coord_scale * F.sum(F.squared_error(w, tw) * coord_mask_sq)
    loss_h = args.coord_scale * F.sum(F.squared_error(h, th) * coord_mask_sq)
    loss_conf = F.sum(F.squared_error(conf, tconf) * conf_mask_sq)
    loss_cls = args.class_scale * \
        F.sum(cls_mask_bb * F.softmax_cross_entropy(cls + cls_ones - cls_mask, tcls))
    loss_nnabla = loss_x + loss_y + loss_w + loss_h + loss_conf + loss_cls

    return yolo_x, yolo_features, (x, y, w, h, conf, cls), (tx, ty, tw, th, tconf, coord_mask, conf_mask_sq, cls_ones, cls_mask, tcls, cls_mask_bb), loss_nnabla
Example #21
def get_d_data(conf, flow_hr, gen_outputs, r_targets, rnn_length):
    """
    Prepare data for the temporal discriminators.
    """
    # 3 frames form one entry; the last (rnn_length % 3) frames are discarded
    t_size = int(3 * (rnn_length // 3))
    t_gen_output = F.reshape(
        gen_outputs[:, :t_size, :, :, :],
        (conf.train.batch_size * t_size, conf.train.crop_size * 4,
         conf.train.crop_size * 4, 3),
        inplace=False)
    t_targets = F.reshape(
        r_targets[:, :t_size, :, :, :],
        (conf.train.batch_size * t_size, conf.train.crop_size * 4,
         conf.train.crop_size * 4, 3),
        inplace=False)
    t_batch = conf.train.batch_size * t_size // 3
    t_inputs_v_pre_batch = F.identity(
        flow_hr[:, 0:t_size:3, :, :, :])  # forward motion reused,
    t_inputs_v_batch = nn.Variable(t_inputs_v_pre_batch.shape)
    # no motion for middle frames
    t_inputs_v_batch.data.zero()
    t_inputs_v_nxt_batch = F.identity(
        flow_hr[:, -2:-1 - t_size:-3, :, :, :])  # backward motion

    t_vel = F.stack(
        *[t_inputs_v_pre_batch, t_inputs_v_batch, t_inputs_v_nxt_batch],
        axis=2)
    # batch, t_size/3, 3, FLAGS.crop_size*4, FLAGS.crop_size*4, 2
    t_vel = F.reshape(t_vel,
                      (conf.train.batch_size * t_size,
                       conf.train.crop_size * 4, conf.train.crop_size * 4, 2),
                      inplace=False)
    # Stop gradient to fnet from discriminator, details in TecoGAN supplemental paper
    t_vel.need_grad = False

    disc_data = collections.namedtuple(
        'disc_data', 't_vel, t_gen_output, t_batch, t_targets, t_size')
    return disc_data(t_vel=t_vel,
                     t_gen_output=t_gen_output,
                     t_batch=t_batch,
                     t_targets=t_targets,
                     t_size=t_size)
Example #22
def easy_pcd(feature_p1, feature_p2, n_filt, name):
    """
    Easy 3-level pyramid cascade alignment.
    input: features (feature_p1, feature_p2)
    feature size: f1 = f2 = [B, C, H, W]
    """

    with nn.parameter_scope(name):
        # L1: level 1, original spatial size
        l1_fea = F.stack(*[feature_p1, feature_p2], axis=1)
        batch, num_frames, channels, height, width = l1_fea.shape
        l1_fea = l1_fea.reshape((-1, channels, height, width))

        # L2: level 2, 1/2 spatial size
        l2_fea = F.leaky_relu(conv2d(l1_fea, n_filt, 3, 2, 1, bias=True, name='fea_l2_conv1')
                              )
        l2_fea = F.leaky_relu(conv2d(l2_fea, n_filt, 3, 1, 1, bias=True, name='fea_l2_conv2')
                              )

        # L3: level 3, 1/4 spatial size
        l3_fea = F.leaky_relu(conv2d(l2_fea, n_filt, 3, 2, 1, bias=True, name='fea_l3_conv1')
                              )
        l3_fea = F.leaky_relu(conv2d(l3_fea, n_filt, 3, 1, 1, bias=True, name='fea_l3_conv2')
                              )

        l1_fea = F.reshape(l1_fea, (batch, num_frames, -1,
                                    height, width), inplace=False)
        l2_fea = F.reshape(l2_fea, (batch, num_frames, -1,
                                    height // 2, width // 2), inplace=False)
        l3_fea = F.reshape(l3_fea, (batch, num_frames, -1,
                                    height // 4, width // 4), inplace=False)

        fea1 = [l1_fea[:, 0, :, :, :],
                l2_fea[:, 0, :, :, :], l3_fea[:, 0, :, :, :]]
        fea2 = [l1_fea[:, 1, :, :, :],
                l2_fea[:, 1, :, :, :], l3_fea[:, 1, :, :, :]]

        aligned_fea = pcd_align(fea1, fea2)
        fusion_fea = conv2d(aligned_fea, n_filt, 1, 1,
                            0, bias=True, name='fusion')

    return fusion_fea
Example #23
    def call(self, x, y):
        hp = self.hp
        results = []
        with nn.parameter_scope('layer_0'):
            x = F.pad(x, (0, 0, 7, 7), 'reflect')
            x = wn_conv(x, hp.ndf, (15,))
            x = F.leaky_relu(x, 0.2, inplace=True)
            results.append(x)

        nf = hp.ndf
        stride = hp.downsamp_factor

        for i in range(1, hp.n_layers_D + 1):
            nf_prev = nf
            nf = min(nf * stride, 1024)
            with nn.parameter_scope(f'layer_{i}'):
                x = wn_conv(
                    x, nf, (stride * 10 + 1,),
                    stride=(stride,),
                    pad=(stride * 5,),
                    group=nf_prev // 4,
                )
                x = F.leaky_relu(x, 0.2, inplace=True)
                results.append(x)

        with nn.parameter_scope(f'layer_{hp.n_layers_D + 1}'):
            nf = min(nf * 2, 1024)
            x = wn_conv(x, nf, kernel=(5,), pad=(2,))
            x = F.leaky_relu(x, 0.2, inplace=True)
            results.append(x)

        with nn.parameter_scope(f'layer_{hp.n_layers_D + 2}'):
            x = wn_conv(x, hp.n_speakers, kernel=(3,), pad=(1,))
            if y is not None:
                idx = F.stack(
                    F.arange(0, hp.batch_size),
                    y.reshape((hp.batch_size,))
                )
                x = F.gather_nd(x, idx)
            results.append(x)

        return results
Example #24
    def __call__(self, img0, img1, normalize=False, mean_batch=False):
        """
            Args:
               img0, img1 (Variable): Variables holding (batches of) images.
               normalize (bool): If True, assumes inputs are in [0., 1.] and scales them to [-1., +1.].
                                 If False, assumes inputs are already in [-1., +1.].
        """
        assert img0.shape == img1.shape, "img0 and img1 have different shape."
        assert isinstance(img0, nn.Variable), "img0 is not Variable."
        assert isinstance(img1, nn.Variable), "img1 is not Variable."

        if normalize:
            # scales the input between [-1., +1.]
            img0 = 2 * img0 - 1
            img1 = 2 * img1 - 1

        if self.apply_scale:
            img0 = (img0 - self._shift) / self._scale
            img1 = (img1 - self._shift) / self._scale

        dists = compute_each_feat_dist(img0,
                                       img1,
                                       feat_extractor=self.feat_extractor)

        if self.spatial:
            # note that this upsampling method is different from the original LPIPS.
            # in the original implementation, it is torch.nn.upsample(mode="bilinear")
            dists = [
                F.interpolate(dist * (1. * img0.shape[2] / dist.shape[2]),
                              output_size=img0.shape[2:]) for dist in dists
            ]
        else:
            dists = [
                F.mean(dist, axis=[2, 3], keepdims=True) for dist in dists
            ]
        # returns N scores ((N, 1, 1, 1))
        lpips_val = F.sum(F.stack(*dists), axis=0)

        if mean_batch:
            lpips_val = F.mean(lpips_val, axis=0)

        return lpips_val
Example #25
def upscale_four(inputs, scope='upscale_four'):
    """
    Mimics the TensorFlow bilinear upscaling for a fixed ratio of 4.
    """
    with nn.parameter_scope(scope):
        b, h, w, c = inputs.shape

        p_inputs = F.concatenate(
            inputs, inputs[:, -1:, :, :], axis=1)  # pad bottom
        p_inputs = F.concatenate(
            p_inputs, p_inputs[:, :, -1:, :], axis=2)  # pad right

        hi_res_bin = [
            [
                    inputs,  # top-left
                    p_inputs[:, :-1, 1:, :]  # top-right
            ],
            [
                    p_inputs[:, 1:, :-1, :],  # bottom-left
                    p_inputs[:, 1:, 1:, :]  # bottom-right
            ]
            ]

        hi_res_array = []
        for hi in range(4):
            for wj in range(4):
                hi_res_array.append(
                        hi_res_bin[0][0] *
                            (1.0 - 0.25 * hi) * (1.0 - 0.25 * wj)
                        + hi_res_bin[0][1] * (1.0 - 0.25 * hi) * (0.25 * wj)
                        + hi_res_bin[1][0] * (0.25 * hi) * (1.0 - 0.25 * wj)
                        + hi_res_bin[1][1] * (0.25 * hi) * (0.25 * wj)
                        )

        hi_res = F.stack(*hi_res_array, axis=3)  # shape (b,h,w,16,c)
        hi_res_reshape = F.reshape(hi_res, (b, h, w, 4, 4, c))
        hi_res_reshape = F.transpose(hi_res_reshape, (0, 1, 3, 2, 4, 5))
        hi_res_reshape = F.reshape(hi_res_reshape, (b, h*4, w*4, c))

    return hi_res_reshape
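
A shape sketch (NHWC input, mirroring the TF behavior it mimics):

import nnabla as nn
import nnabla.functions as F

x = nn.Variable((1, 16, 16, 3))  # NHWC
y = upscale_four(x)
print(y.shape)  # (1, 64, 64, 3)
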
Example #26
def deformable_conv_lstm(input_tensor, n_filt, kernel_size):
    """
    Deformable convolution LSTM cell definition.
    """

    hidden_state_h = nn.Variable(
        (input_tensor.shape[0], n_filt, input_tensor.shape[3], input_tensor.shape[4]))
    hidden_state_c = nn.Variable(
        (input_tensor.shape[0], n_filt, input_tensor.shape[3], input_tensor.shape[4]))
    hidden_state_h.data.zero()
    hidden_state_c.data.zero()
    seq_len = input_tensor.shape[1]
    output_inner = []

    for t_idx in range(seq_len):
        in_tensor = input_tensor[:, t_idx, :, :, :]
        h_temp = easy_pcd(in_tensor, hidden_state_h, n_filt, 'pcd_h')
        c_temp = easy_pcd(in_tensor, hidden_state_c, n_filt, 'pcd_c')
        hidden_state_h, hidden_state_c = conv_lstm_cell(
            in_tensor, [h_temp, c_temp], n_filt, kernel_size)
        output_inner.append(hidden_state_h)
    layer_output = F.stack(*output_inner, axis=1)
    return layer_output
Example #27
def get_generator_output(conf, rnn_length, r_inputs, flow_hr, scope_name):
    """
    Return the generated HR frames
    """
    # list for all outputs
    gen_outputs = []

    # for the first frame, concat with zeros
    input0 = F.concatenate(
        r_inputs[:, 0, :, :, :],
        F.constant(0, (conf.train.batch_size, conf.train.crop_size,
                       conf.train.crop_size, 3 * 4 * 4)))
    with nn.parameter_scope(scope_name + "generator"):
        gen_pre_output = generator(input0, 3, conf.train.num_resblock)
    gen_outputs.append(gen_pre_output)  # append generated HR frame-0

    for frame_i in range(rnn_length - 1):
        cur_flow = flow_hr[:, frame_i, :, :, :]
        # warp the previously generated frame
        gen_pre_output_warp = warp_by_flow(gen_pre_output, cur_flow)
        gen_pre_output_warp = F.identity(deprocess(gen_pre_output_warp))
        # apply space-to-depth transform
        gen_pre_output_warp = space_to_depth(gen_pre_output_warp)
        # pack it as the recurrent input
        inputs = F.concatenate(r_inputs[:, frame_i + 1, :, :, :],
                               gen_pre_output_warp)
        # super-resolution part
        with nn.parameter_scope(scope_name + "generator"):
            gen_output = generator(inputs, 3, conf.train.num_resblock)
        gen_outputs.append(gen_output)
        gen_pre_output = gen_output

    # gen_outputs, a list, len = frame, shape = (batch, FLAGS.crop_size*4, FLAGS.crop_size*4, 3)
    gen_outputs = F.stack(*gen_outputs, axis=1)
    # gen_outputs, nn.Variable with shape = (batch, frame, FLAGS.crop_size*4, FLAGS.crop_size*4, 3)

    return gen_outputs
Example #28
def constructing_cell(args,
                      ops,
                      which_cell,
                      cell_prev_prev,
                      cell_prev,
                      output_filter,
                      is_reduced_curr,
                      is_reduced_prev,
                      test=False):
    """
        Constructing one cell.
        input:
            args: arguments set by user.
            ops: operations used in the network.
            which_cell: int. Index of the cell currently being constructed.
            cell_prev_prev: Variable. Output of the cell before the previous cell.
            cell_prev: Variable. Output of the previous cell.
            output_filter: the number of filters used for this cell.
            is_reduced_curr: bool. True if the current cell is a reduction cell.
            is_reduced_prev: bool. True if the previous cell is a reduction cell.
            test: bool. True if the network is for validation.
    """

    # If True, all the parameters in batch_normalizations won't be updated.
    is_search = True

    if is_reduced_curr:
        keyname_basis = "alpha_reduction"
        output_shape = (cell_prev.shape[0], output_filter,
                        cell_prev.shape[2] // 2, cell_prev.shape[3] // 2)
    else:
        keyname_basis = "alpha_normal"
        output_shape = (cell_prev.shape[0], output_filter, cell_prev.shape[2],
                        cell_prev.shape[3])

    if is_reduced_prev:
        scope = "fr{}".format(which_cell)
        cell_prev_prev = factorized_reduction(cell_prev_prev, output_filter,
                                              scope, test, is_search)
    else:
        scope = "preprocess_cell{}_node{}".format(which_cell, 0)
        cell_prev_prev = conv1x1(cell_prev_prev, output_filter, scope, test,
                                 is_search)

    scope = "preprocess_cell{}_node{}".format(which_cell, 1)
    cell_prev = conv1x1(cell_prev, output_filter, scope, test, is_search)

    num_of_nodes = args.num_nodes

    # latter_nodes are all the intermediate nodes,
    # except for 2 input nodes and 1 output node.
    latter_nodes = [
        nn.Variable(output_shape) for _ in range(num_of_nodes - 2 - 1)
    ]
    for v in latter_nodes:
        v.d = 0  # initialize.

    num_of_ops = len(ops)

    # prepare a list to store all nodes.
    nodes = [cell_prev_prev, cell_prev] + latter_nodes
    for i in range(num_of_nodes - 2):
        successors = [_ for _ in range(i + 1, num_of_nodes - 1)]
        for j in successors:
            if j == 1:
                continue
            from_node, to_node = i, j
            scope = "cell{}/node{}_{}".format(which_cell, from_node, to_node)

            stacked_x = num_of_ops * (nodes[i], )
            stacked_x = tuple([
                op(x, output_filter, scope + "/ops{}".format(op_id), i,
                   is_reduced_curr, test, is_search) for x, op, op_id in zip(
                       stacked_x, tuple(ops.values()), tuple(ops.keys()))
            ])
            y = F.stack(*stacked_x, axis=0)

            alpha_name = keyname_basis + "_{}_{}".format(i, j)
            current_alpha = nn.parameter.get_parameter_or_create(
                alpha_name, (num_of_ops, ) + (1, 1, 1, 1))
            alpha_prob = F.softmax(current_alpha, axis=0)
            y = F.mul2(y, alpha_prob)
            if i == 0:
                nodes[j] = F.sum(y, axis=0)
            else:
                nodes[j] = F.add2(nodes[j], F.sum(y, axis=0))

    intermediate_nodes = nodes[2:num_of_nodes - 1]
    output = F.concatenate(*intermediate_nodes, axis=1)

    is_reduced_prev = is_reduced_curr
    return output, is_reduced_curr, is_reduced_prev, output_filter
Example #29
    def __call__(self, x, test=False):

        fft_real, fft_imag = STFT(x, n_fft=self.n_fft, n_hop=self.n_hop)
        x_theta = F.atan2(fft_imag, fft_real)

        x = Spectrogram(fft_real,
                        fft_imag,
                        power=self.power,
                        mono=(self.nb_channels == 1))

        nb_frames, nb_samples, nb_channels, nb_bins = x.shape

        mix_spec = F.identity(x)
        x = x[..., :self.nb_bins]

        # clone
        x_bass = F.identity(x)
        x_drums = F.identity(x)
        x_vocals = F.identity(x)
        x_other = F.identity(x)

        # shift and scale input to mean=0 std=1 (across all bins)
        x_bass += F.reshape(self.input_mean_bass,
                            shape=(1, 1, 1, self.nb_bins),
                            inplace=False)
        x_drums += F.reshape(self.input_mean_drums,
                             shape=(1, 1, 1, self.nb_bins),
                             inplace=False)
        x_vocals += F.reshape(self.input_mean_vocals,
                              shape=(1, 1, 1, self.nb_bins),
                              inplace=False)
        x_other += F.reshape(self.input_mean_other,
                             shape=(1, 1, 1, self.nb_bins),
                             inplace=False)

        x_bass *= F.reshape(self.input_scale_bass,
                            shape=(1, 1, 1, self.nb_bins),
                            inplace=False)
        x_drums *= F.reshape(self.input_scale_drums,
                             shape=(1, 1, 1, self.nb_bins),
                             inplace=False)
        x_vocals *= F.reshape(self.input_scale_vocals,
                              shape=(1, 1, 1, self.nb_bins),
                              inplace=False)
        x_other *= F.reshape(self.input_scale_other,
                             shape=(1, 1, 1, self.nb_bins),
                             inplace=False)

        # encode and normalize every instance in a batch
        x_bass = self.fc_bn(x_bass,
                            self.hidden_size,
                            "fc1_bass",
                            test,
                            activation='tanh')
        x_drums = self.fc_bn(x_drums,
                             self.hidden_size,
                             "fc1_drums",
                             test,
                             activation='tanh')
        x_vocals = self.fc_bn(x_vocals,
                              self.hidden_size,
                              "fc1_vocals",
                              test,
                              activation='tanh')
        x_other = self.fc_bn(x_other,
                             self.hidden_size,
                             "fc1_other",
                             test,
                             activation='tanh')

        # Average the sources
        cross_1 = (x_bass + x_drums + x_vocals + x_other) / 4.0

        # apply 3-layers of stacked LSTM
        lstm_out_bass = self.lstm(cross_1, nb_samples, "lstm_bass", test)
        lstm_out_drums = self.lstm(cross_1, nb_samples, "lstm_drums", test)
        lstm_out_vocals = self.lstm(cross_1, nb_samples, "lstm_vocals", test)
        lstm_out_other = self.lstm(cross_1, nb_samples, "lstm_other", test)

        # lstm skip connection
        x_bass = F.concatenate(x_bass, lstm_out_bass)
        x_drums = F.concatenate(x_drums, lstm_out_drums)
        x_vocals = F.concatenate(x_vocals, lstm_out_vocals)
        x_other = F.concatenate(x_other, lstm_out_other)

        cross_2 = (x_bass + x_drums + x_vocals + x_other) / 4.0

        # first dense stage + batch norm
        x_bass = self.fc_bn(cross_2,
                            self.hidden_size,
                            "fc2_bass",
                            test,
                            activation='relu')
        x_drums = self.fc_bn(cross_2,
                             self.hidden_size,
                             "fc2_drums",
                             test,
                             activation='relu')
        x_vocals = self.fc_bn(cross_2,
                              self.hidden_size,
                              "fc2_vocals",
                              test,
                              activation='relu')
        x_other = self.fc_bn(cross_2,
                             self.hidden_size,
                             "fc2_other",
                             test,
                             activation='relu')

        # second dense stage + batch norm
        x_bass = self.fc_bn(x_bass, nb_channels * nb_bins, "fc3_bass", test)
        x_drums = self.fc_bn(x_drums, nb_channels * nb_bins, "fc3_drums", test)
        x_vocals = self.fc_bn(x_vocals, nb_channels * nb_bins, "fc3_vocals",
                              test)
        x_other = self.fc_bn(x_other, nb_channels * nb_bins, "fc3_other", test)

        # reshape back to original dim
        x_bass = F.reshape(
            x_bass, (nb_frames, nb_samples, nb_channels, self.nb_output_bins))
        x_drums = F.reshape(
            x_drums, (nb_frames, nb_samples, nb_channels, self.nb_output_bins))
        x_vocals = F.reshape(
            x_vocals,
            (nb_frames, nb_samples, nb_channels, self.nb_output_bins))
        x_other = F.reshape(
            x_other, (nb_frames, nb_samples, nb_channels, self.nb_output_bins))

        # apply output scaling
        x_bass *= F.reshape(self.output_scale_bass,
                            shape=(1, 1, 1, self.nb_output_bins),
                            inplace=False)
        x_drums *= F.reshape(self.output_scale_drums,
                             shape=(1, 1, 1, self.nb_output_bins),
                             inplace=False)
        x_vocals *= F.reshape(self.output_scale_vocals,
                              shape=(1, 1, 1, self.nb_output_bins),
                              inplace=False)
        x_other *= F.reshape(self.output_scale_other,
                             shape=(1, 1, 1, self.nb_output_bins),
                             inplace=False)

        x_bass += F.reshape(self.output_mean_bass,
                            shape=(1, 1, 1, self.nb_output_bins),
                            inplace=False)
        x_drums += F.reshape(self.output_mean_drums,
                             shape=(1, 1, 1, self.nb_output_bins),
                             inplace=False)
        x_vocals += F.reshape(self.output_mean_vocals,
                              shape=(1, 1, 1, self.nb_output_bins),
                              inplace=False)
        x_other += F.reshape(self.output_mean_other,
                             shape=(1, 1, 1, self.nb_output_bins),
                             inplace=False)

        # since our output is non-negative, we can apply RELU
        mask_bass = F.relu(x_bass)
        mask_drums = F.relu(x_drums)
        mask_vocals = F.relu(x_vocals)
        mask_other = F.relu(x_other)

        # (Frames, Bsize, Channels, Fbins)
        x_bass = mask_bass * mix_spec
        x_drums = mask_drums * mix_spec
        x_vocals = mask_vocals * mix_spec
        x_other = mask_other * mix_spec

        if not self.is_predict:
            tmp = F.stack(*[x_bass, x_drums, x_vocals, x_other], axis=0)
            # (4(sources), Frames, Bsize(16), 2(channels), Fbins) ==> (4, Bsize, Channels, Fbins, Frames)
            tmp = F.transpose(tmp, (0, 2, 3, 4, 1))
            pred_r, pred_i = [], []
            for i in range(tmp.shape[0]):
                pred_r.append(tmp[i] * F.cos(x_theta))
                pred_i.append(tmp[i] * F.sin(x_theta))
            pred_r = F.stack(*pred_r, axis=0)
            pred_i = F.stack(*pred_i, axis=0)
            pred_r = F.reshape(pred_r,
                               (4 * nb_samples * nb_channels, 2049, nb_frames))
            pred_i = F.reshape(pred_i,
                               (4 * nb_samples * nb_channels, 2049, nb_frames))
            pred = istft(pred_r,
                         pred_i,
                         self.n_fft,
                         self.n_hop,
                         self.n_fft,
                         window_type='hanning',
                         center=True)
            pred = F.reshape(pred, (4, nb_samples, nb_channels, -1))

        else:
            pred = None

        return mix_spec, F.concatenate(mask_bass,
                                       mask_drums,
                                       mask_vocals,
                                       mask_other,
                                       axis=2), pred
Example #30
    def call(self, memory, decoder_inputs=None):
        r"""Return mel-spectrograms, gate outputs and an attention matrix.

        Args:
            memory (nn.Variable): A 3D tensor of shape (B, T, C).
            decoder_inputs (nn.Variable, optional): A 3D tensor with shape of (B, T/r, r*n_mels).
                Shifted log melspectrogram of sound files. Defaults to None.

        Returns:
            nn.Variable: The synthetic mel-spectrograms of shape (B, Ty/r, r*n_mels).
            nn.Variable: The gate outputs of shape (B, Ty).
            nn.Variable: The attention matrix of shape (B, Tx, Ty).
        """
        hp = self._hparams
        mel_shape = hp.n_mels * hp.r

        # initialize decoder states
        decoder_input = F.constant(shape=(hp.batch_size, 1, mel_shape))
        decoder_hidden = F.constant(shape=(1, 1, hp.batch_size,
                                           hp.decoder_rnn_dim))
        decoder_cell = F.constant(shape=(1, 1, hp.batch_size,
                                         hp.decoder_rnn_dim))

        # initialize attention states
        attention_weights = F.constant(shape=(hp.batch_size, 1, hp.text_len))
        attention_weights_cum = F.constant(shape=(hp.batch_size, 1,
                                                  hp.text_len))
        attention_context = F.constant(shape=(hp.batch_size, 1,
                                              hp.encoder_embedding_dim))
        attention_hidden = F.constant(shape=(1, 1, hp.batch_size,
                                             hp.attention_rnn_dim))
        attention_cell = F.constant(shape=(1, 1, hp.batch_size,
                                           hp.attention_rnn_dim))

        # store outputs
        mel_outputs, gate_outputs, alignments = [], [], []

        for i in range(hp.mel_len):
            if i > 0:
                decoder_input = (mel_outputs[-1] if decoder_inputs is None else
                                 decoder_inputs[:, i - 1:i, :])
                if decoder_inputs is None:
                    decoder_input = decoder_input[None, ...]
            # decoder of shape (B, 1, prenet_channels=256)
            decoder_input = prenet(decoder_input,
                                   hp.prenet_channels,
                                   is_training=self.training,
                                   scope='prenet')

            with nn.parameter_scope('attention_rnn'):
                # cell_input of shape (B, 1, prenet_channels[-1] + C=768)
                cell_input = F.concatenate(decoder_input,
                                           attention_context,
                                           axis=2)
                _, attention_hidden, attention_cell = PF.lstm(
                    F.transpose(cell_input, (1, 0, 2)),
                    attention_hidden,
                    attention_cell,
                    training=self.training,
                    name='lstm_attention'
                )  # (1, 1, B, attention_hidden), (1, 1, B, attention_hidden)
                if self.training:
                    attention_hidden = F.dropout(attention_hidden,
                                                 hp.p_attention_dropout)

            with nn.parameter_scope('location_attention'):
                attention_weights_cat = F.concatenate(attention_weights,
                                                      attention_weights_cum,
                                                      axis=1)
                attention_context, attention_weights = location_sensitive_attention(
                    F.transpose(attention_hidden[0], (1, 0, 2)),
                    memory,
                    attention_weights_cat,
                    attention_location_kernel_size=hp.attention_location_kernel_size,
                    attention_n_filters=hp.attention_location_n_filters,
                    attention_dim=hp.attention_dim,
                    is_training=self.training,
                    scope='ls_attention')
                attention_weights_cum += attention_weights
                alignments.append(attention_weights)

            with nn.parameter_scope('decoder_rnn'):
                # (1, B, attention_rnn_dim + encoder_embedding_dim)
                inp_decoder = F.concatenate(attention_hidden[0],
                                            F.transpose(
                                                attention_context, (1, 0, 2)),
                                            axis=2)
                _, decoder_hidden, decoder_cell = PF.lstm(
                    inp_decoder,
                    decoder_hidden,
                    decoder_cell,
                    training=self.training,
                    name='lstm_decoder')
                if self.training:
                    decoder_hidden = F.dropout(decoder_hidden,
                                               hp.p_decoder_dropout)

            with nn.parameter_scope('projection'):
                proj_input = F.concatenate(
                    decoder_hidden[0, 0],
                    F.reshape(attention_context, (hp.batch_size, -1),
                              inplace=False),
                    axis=1)  # (B, decoder_rnn_dim + encoder_embedding_dim)
                decoder_output = affine_norm(proj_input,
                                             mel_shape,
                                             base_axis=1,
                                             with_bias=True,
                                             w_init_gain='affine',
                                             scope='affine')
                mel_outputs.append(decoder_output)

            with nn.parameter_scope('gate_prediction'):
                gate_prediction = affine_norm(proj_input,
                                              1,
                                              base_axis=1,
                                              with_bias=True,
                                              w_init_gain='sigmoid',
                                              scope='affine')
                gate_outputs.append(gate_prediction)

        # (B, T2, n_mels*r)
        mel_outputs = F.stack(*mel_outputs, axis=1)
        gate_outputs = F.concatenate(*gate_outputs, axis=1)  # (B, T2)
        alignments = F.concatenate(*alignments, axis=1)  # (B, T1, T2)

        return mel_outputs, gate_outputs, alignments