# The snippets below assume `import numpy as np`, `import mxnet as mx`,
# `from mxnet import nd`, and `import mxnet.ndarray as F`.
def forward(self, x, qst):
        with x.context:
            # (batch, 25 grid cells, 2 coords), allocated on x's device
            self.coord_tensor = F.zeros((x.shape[0], 25, 2))

        # prepare coord tensor
        def cvt_coord(i):
            # map cell 0..24 of the 5x5 grid to normalized (row, col) in [-1, 1]
            return [(i // 5 - 2) / 2., (i % 5 - 2) / 2.]

        for i in range(25):
            self.coord_tensor[:, i, :] = F.array(cvt_coord(i))

        #input size = (64 * 3 * 75 * 75)
        x = self.conv(x)  ## x = (64 * 24 * 5 * 5)

        ##g part
        mb = x.shape[0]
        n_channels = x.shape[1]
        d = x.shape[2]

        x_flat = x.reshape(shape=(mb, n_channels, d * d))
        x_flat = F.swapaxes(x_flat, 1, 2)  ## (64 * 25 * 24)

        ##add coordinates
        x_flat = F.concat(x_flat, self.coord_tensor, dim=2)

        ##add question
        qst = qst.expand_dims(1)
        qst = F.repeat(qst, repeats=25, axis=1)
        qst = qst.expand_dims(2)

        # cast all pairs against each other
        x_i = x_flat.expand_dims(1)
        x_i = F.repeat(x_i, repeats=25, axis=1)

        x_j = x_flat.expand_dims(2)
        x_j = F.concat(x_j, qst, dim=3)
        x_j = F.repeat(x_j, repeats=25, axis=2)

        #concatenate all
        x_full = F.concat(x_i, x_j, dim=3)

        #reshape and apply the g network to every pair
        x_ = x_full.reshape((-1, 63))  # 2 * (24 + 2) object dims + 11 question dims
        x_ = self.g_fc1(x_)
        x_ = self.g_fc2(x_)
        x_ = self.g_fc3(x_)
        x_ = self.g_fc4(x_)

        x_g = x_.reshape((mb, -1, 256))
        x_g = x_g.sum(1)

        ##### f part #######
        x_f = self.f_fc1(x_g)

        return self.fcout(x_f)
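The all-pairs construction above hinges on expand_dims plus nd.repeat: tiling the object list along two different axes yields every (i, j) combination. A minimal sketch of the same trick on a toy batch (shapes are illustrative, not the model's):

# Minimal sketch of the pairwise tiling used in forward() above.
from mxnet import nd

objs = nd.arange(2 * 3 * 4).reshape((2, 3, 4))           # (batch, n_obj, feat)
x_i = nd.repeat(objs.expand_dims(1), repeats=3, axis=1)  # (2, 3, 3, 4)
x_j = nd.repeat(objs.expand_dims(2), repeats=3, axis=2)  # (2, 3, 3, 4)
pairs = nd.concat(x_i, x_j, dim=3)                       # (2, 3, 3, 8): one row per (i, j) pair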
Example #2
def fftfilt_nd(x, params):
    # Overlap-add FFT filtering (a port of MATLAB's fftfilt) for MXNet NDArrays.
    (b, m, nx, nb, L, nfft) = params
    ctx = x.context  # run everything on x's device

    B = nd.contrib.fft(data=nd.concatenate(
        [b.T, nd.zeros(shape=(1, (nfft - b.size)), ctx=ctx)], axis=1))
    if b.size == 1:
        B = B.T  # make sure fft of B is a column (might be a row if b is scalar)
    if b.shape[1] == 1:
        B = nd.repeat(data=B, repeats=x.shape[1],
                      axis=0)  # replicate the column B
    if x.shape[1] == 1:
        x = nd.repeat(data=x, repeats=b.shape[1],
                      axis=1)  # replicate the column x
    # nd.contrib.fft interleaves real/imaginary parts along the last axis
    B_re = nd.slice(data=B, begin=(0, 0), end=(None, None), step=(1, 2))
    B_im = nd.slice(data=B, begin=(0, 1), end=(None, None), step=(1, 2))
    y = nd.zeros_like(x.T)

    istart = 1  # 1-based block indexing, following the MATLAB original
    while istart <= nx:
        iend = min(istart + L - 1, nx)
        # current block of samples, transposed to (n_cols, block_len);
        # a single trailing sample is simply a length-1 block
        temp = nd.slice(x, begin=istart - 1, end=iend).T
        X = nd.contrib.fft(data=nd.concatenate([
            temp,
            nd.zeros(shape=(temp.shape[0], (nfft - temp.shape[1])), ctx=ctx)
        ], axis=1))
        X_re = nd.slice(data=X, begin=(0, 0), end=(None, None), step=(1, 2))
        X_im = nd.slice(data=X, begin=(0, 1), end=(None, None), step=(1, 2))

        XprodB_re = (X_re * B_re - X_im * B_im)
        XprodB_im = (X_re * B_im + X_im * B_re)
        Ytemp = nd.zeros((X.shape[0], X.shape[1]), ctx=ctx)
        Ytemp[:, ::2] = XprodB_re
        Ytemp[:, 1::2] = XprodB_im
        Y = nd.contrib.ifft(Ytemp / nfft)  # contrib ifft returns only the real part

        yend = min(nx, istart + nfft - 1)

        y[:, istart - 1:yend] = nd.slice(
            data=y, begin=(0, istart - 1), end=(None, yend),
            step=(1, 1)) + nd.slice(
                data=Y, begin=(0, 0), end=(None, yend - istart + 1),
                step=(1, 1))
        istart += L
    # y = real(y)  (MATLAB leftover; the ifft above already returns the real part)

    return y
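nd.contrib.fft returns the spectrum with real and imaginary parts interleaved along the last axis, which is why the function separates them with strided nd.slice calls. A small sketch of that layout (illustrative; note that MXNet's contrib FFT ops run on GPU only):

# Sketch: separating the interleaved re/im output of nd.contrib.fft.
import mxnet as mx
from mxnet import nd

ctx = mx.gpu()                                    # contrib.fft requires a GPU context
sig = nd.array([[1.0, 2.0, 3.0, 4.0]], ctx=ctx)   # one row, n = 4 samples
spec = nd.contrib.fft(data=sig)                   # shape (1, 2 * n): re0, im0, re1, im1, ...
spec_re = nd.slice(spec, begin=(0, 0), end=(None, None), step=(1, 2))  # (1, 4)
spec_im = nd.slice(spec, begin=(0, 1), end=(None, None), step=(1, 2))  # (1, 4)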
Example #3
def verify_l2normalization_rewrite(shape, eps, mode):
    assert len(shape) == 4  # NCHW
    data_np = np.random.uniform(size=shape)
    x = nd.array(data_np)

    # org op
    y = nd.L2Normalization(x, eps=eps, mode=mode)

    # rewrite op
    z = nd.broadcast_mul(x, x)
    if mode == "channel":
        axis = [1]
    elif mode == "instance":
        axis = [1, 2, 3]
    elif mode == "spatial":
        axis = [2, 3]
    else:
        assert False, "not a valid `mode` type: %s" % mode
    z = nd.sum(z, axis=axis)
    eps_tensor = nd.array([eps])
    z = nd.broadcast_add(z, eps_tensor)
    z = nd.sqrt(z)
    for i in axis:
        z = nd.expand_dims(z, axis=i)
        z = nd.repeat(z, repeats=shape[i], axis=i)
    z = nd.broadcast_div(x, z)
    print(z.shape)

    # compare
    assert z.shape == y.shape
    zn, zp = get_norm(z)
    yn, yp = get_norm(y)
    rn = np.linalg.norm(zp - yp)
    print(zn, yn, rn)
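get_norm is not defined in this listing; below is a hypothetical stand-in inferred from how it is used (a scalar norm plus the raw values as numpy), together with a sample call. Both are sketches, not code from the surrounding repository:

# Hypothetical stand-in for the get_norm helper, inferred from usage above.
import numpy as np

def get_norm(arr):
    arr_np = arr.asnumpy().flatten()
    return np.linalg.norm(arr_np), arr_np

verify_l2normalization_rewrite((2, 3, 4, 4), eps=1e-10, mode="channel")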
Example #4
def repeat(input, repeats, dim):
    # nd.repeat only accepts an int `repeats`; per-element counts (an NDArray)
    # fall back to numpy, and the result is copied back to the input's
    # context and dtype.
    if isinstance(repeats, nd.NDArray):
        return nd.array(np.repeat(input.asnumpy(), repeats.asnumpy(),
                                  axis=dim),
                        ctx=input.context,
                        dtype=input.dtype)
    else:
        return nd.repeat(input, repeats, axis=dim)
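A quick illustration of the two branches (values arbitrary): an integer count stays in MXNet, while an NDArray of per-element counts round-trips through numpy.

# Illustration only.
from mxnet import nd

a = nd.array([[1, 2], [3, 4]])
print(repeat(a, 2, dim=0).shape)                                # (4, 2)
print(repeat(a, nd.array([1, 3], dtype='int32'), dim=0).shape)  # (4, 2): row 0 once, row 1 thrice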
Example #5
def repeat(input, repeats, dim):
    return nd.repeat(input, repeats, axis=dim)
Example #6
def verify_broadcast_like_dynamic(xshp, wshp, lhs_axes, rhs_axes):
    x_np = np.random.uniform(size=xshp)
    w_np = np.random.uniform(size=wshp)
    x = nd.array(x_np)
    w = nd.array(w_np)

    # org op
    y = nd.broadcast_like(x, w,
        lhs_axes=lhs_axes, rhs_axes=rhs_axes)
    print(y.shape)

    # rewrite op
    xndims, wndims = len(xshp), len(wshp)
    if lhs_axes is None or rhs_axes is None:
        assert xndims == wndims and lhs_axes is None \
            and rhs_axes is None
        z = _broadcast_like(x, w)
    else:
        lhs_axes, lndims = list(lhs_axes), len(lhs_axes)
        rhs_axes, rndims = list(rhs_axes), len(rhs_axes)
        assert lndims == rndims > 0

        lhs_axes = tuple([v+xndims if v<0 else v for v in lhs_axes])
        assert all([0<=v<xndims for v in list(lhs_axes)])

        rhs_axes = tuple([v+wndims if v<0 else v for v in rhs_axes])
        assert all([0<=v<wndims for v in list(rhs_axes)])

        assert all([xshp[lhs_axes[i]] == 1 for i in range(lndims)])

        batch_axes = [0]
        flg = all([batch_axis not in rhs_axes \
            for batch_axis in batch_axes])
        if flg:
            cnts = {v: wshp[rhs_axes[i]] \
                for i, v in enumerate(lhs_axes)}
            reps = tuple([cnts[v] if v in lhs_axes else 1 \
                for v in range(xndims)])
            z = nd.tile(x, reps=reps)
        else:
            axis_map = {}
            for i, v in enumerate(lhs_axes):
                axis_map[v] = rhs_axes[i]
            for batch_axis in batch_axes:
                assert sum([1 if v == batch_axis else 0 \
                    for k, v in axis_map.items()]) <= 1, \
                    ("multiple broadcast on batch_axis: %s, "
                     "which is not supported by dynamic shape fusion.") % batch_axis
            assert wndims < 6, \
                "slice can manipulate at most 5d"

            # reduce shape to 1 for non-broadcast dimensions
            begin = tuple([0]*wndims)
            end = tuple([wshp[v] if v in axis_map.values() else 1 \
                for v in range(wndims)])
            w = nd.slice(w, begin=begin, end=end)

            # decompose duplicate mappings k1->v, k2->v into k1->v, k2->v2
            # by inserting a repeated axis into w, so each source axis maps
            # to a distinct target axis
            while True:
                vs, flag, paxis_map = set(), True, axis_map
                for pk, pv in paxis_map.items():
                    if pv not in vs:
                        vs.add(pv)
                        continue
                    flag = False
                    axis_map = {k: (v+1 if v>pv or k==pk else v) \
                        for k, v in axis_map.items()}
                    w = nd.expand_dims(w, axis=pv)
                    w = nd.repeat(w, axis=pv, repeats=wshp[pv])
                    wshp = wshp[:pv] + (wshp[pv],) + wshp[pv:]
                    break
                if flag:
                    break
            wndims = len(wshp)

            # trim wndims if not equal to xndims
            v = 0
            while wndims > xndims:
                while v in axis_map.values():
                    v += 1
                w = nd.squeeze(w, axis=v)
                wndims -= 1
                axis_map = {k: (nv-1 if nv > v else nv) \
                    for k, nv in axis_map.items()}
            while wndims < xndims:
                w = nd.expand_dims(w, axis=wndims)
                wndims += 1
            axes = list(range(wndims))
            while True:
                dels = [k for k, v in axis_map.items() if k==v]
                for k in dels:
                    del axis_map[k]
                if not axis_map:
                    break
                keys = list(axis_map.keys())
                k, v = keys[0], axis_map[keys[0]]
                axes[k], axes[v] = axes[v], axes[k]
                for nk in keys:
                    nv = axis_map[nk]
                    if nv == k:
                        axis_map[nk] = v
                    elif nv == v:
                        axis_map[nk] = k
            axes = tuple(axes)
            if axes != tuple(range(wndims)):
                assert wndims < 7, \
                    "transpose can manipulate at most 6d"
                w = nd.transpose(w, axes=axes)
            z = _broadcast_like(x, w)
    print(z.shape)

    # compare
    assert z.shape == y.shape
    zn, zp = get_norm(z)
    yn, yp = get_norm(y)
    rn = np.linalg.norm(zp-yp)
    print(zn, yn, rn)
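A sample invocation of the verifier (shapes illustrative; _broadcast_like and get_norm come from the surrounding repository):

# Broadcast a (2, 1, 1, 4) tensor along axes 1 and 2 to match a (2, 3, 5, 4)
# reference; the rewrite takes the nd.tile fast path since axis 0 is not broadcast.
verify_broadcast_like_dynamic((2, 1, 1, 4), (2, 3, 5, 4), (1, 2), (1, 2))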