Example #1
    def define_graph(self,
                     inp,
                     weight,
                     bias,
                     grad,
                     inp_grad,
                     weight_grad,
                     bias_grad,
                     optimizer,
                     optimizer_kwargs,
                     stride=1,
                     pad=0,
                     dilation=1):
        # Compute the extra output padding needed so that the transposed
        # convolution below reproduces the original input's spatial size.
        min_sizes = []
        k = len(grad.shape) - 2

        for d in range(k):
            min_sizes.append((grad.shape[d + 2] - 1) * stride - 2 * pad +
                             (weight.shape[-1] - 1) * dilation + 1)

        grad_input_padding = tuple(inp.shape[-k + d] - min_sizes[d]
                                   for d in range(k))
        assert grad_input_padding[0] == grad_input_padding[1]
        # Input gradient: transposed convolution of the output gradient with
        # the weights.
        pm.conv_transpose_bias(grad,
                               weight,
                               bias,
                               inp_grad,
                               stride=stride,
                               pad=pad,
                               out_pad=grad_input_padding[0])
        inp_indices = tuple(pm.index(0, s - 1) for s in inp.shape)
        grad_indices = tuple(pm.index(0, s - 1) for s in grad.shape)
        weight_indices = tuple(pm.index(0, s - 1) for s in weight.shape)
        inp_transposed = pm.temp(name=f"transposed_{inp.name}",
                                 shape=(inp.shape[1], inp.shape[0],
                                        inp.shape[2], inp.shape[3]))
        grad_transposed = pm.state(name=f"transposed_{grad.name}",
                                   shape=(grad.shape[1], grad.shape[0],
                                          grad.shape[2], grad.shape[3]))
        wgt_grad_transposed = pm.temp(name=f"transposed_{weight.name}",
                                      shape=(weight.shape[1], weight.shape[0],
                                             weight.shape[2], weight.shape[3]))
        # Weight gradient: convolve the input (batch/channel axes swapped) with
        # the output gradient (batch/channel axes swapped), then transpose back.
        pm.tensor_transpose(inp, inp_transposed, perm=(1, 0, 2, 3))
        pm.tensor_transpose(grad, grad_transposed, perm=(1, 0, 2, 3))
        pm.conv(inp_transposed,
                grad_transposed,
                wgt_grad_transposed,
                stride=dilation,
                pad=pad,
                dilation=stride)
        pm.tensor_transpose(wgt_grad_transposed,
                            weight_grad,
                            perm=(1, 0, 2, 3))
        # Weight update
        OPTIMIZERS[optimizer](weight, weight_grad, **optimizer_kwargs)
        # Bias gradient and update
        pm.reduce_sum(grad, bias_grad)
        OPTIMIZERS[optimizer](bias, bias_grad, **optimizer_kwargs)
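For a quick sanity check of the output-padding arithmetic above, here is a worked example with hypothetical sizes (a 1x3x32x32 input, 3x3 kernel, stride 2, pad 1); it is only a reference calculation, not part of the graph:

# Hypothetical sizes: input 1x3x32x32, 3x3 kernel, stride 2, pad 1,
# so the forward conv output (and hence `grad`) is 1x8x16x16.
inp_shape, k_size = (1, 3, 32, 32), 3
grad_shape, stride, pad, dilation = (1, 8, 16, 16), 2, 1, 1
min_sizes = [(grad_shape[d + 2] - 1) * stride - 2 * pad +
             (k_size - 1) * dilation + 1 for d in range(2)]          # [31, 31]
out_pad = tuple(inp_shape[-2 + d] - min_sizes[d] for d in range(2))  # (1, 1)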
Example #2
    def define_graph(self, x, w, y, y_pred, mu, m):
        i = pm.index(0, (m - 1).set_name("m-1"), name="i")
        h = pm.temp(name="h", shape=(m, ))
        # Prediction: h = sigmoid(sum_i x[i] * w[i])
        h = pm.sigmoid(pm.sum([i], (x[i] * w[i]), name="h"))
        d = (h - y).set_name("h-y")
        g = (d * x[i]).set_name("d*x")
        # SGD step on the weights
        w[i] = w[i] - mu * g[i]
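A minimal NumPy sketch of the same logistic-regression SGD step, using hypothetical data (the graph above expresses the identical update symbolically):

import numpy as np

m = 4
x, w = np.random.rand(m), np.random.rand(m)
y, mu = 1.0, 0.1
h = 1.0 / (1.0 + np.exp(-np.dot(x, w)))  # h = sigmoid(sum_i x[i] * w[i])
w = w - mu * (h - y) * x                 # w[i] -= mu * (h - y) * x[i]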
Example #3
    def populate_temp(self, node):
        # Requires `from itertools import product` and `import numpy as np`.
        if node.shape != pm.DEFAULT_SHAPES[0]:
            # Expand a multi-dimensional temp into one scalar temp per element.
            indices = list(product(*tuple([np.arange(i) for i in node.shape])))
            for i in indices:
                x = pm.temp(graph=node,
                            name=f"{node.name}{i}",
                            root_name=node.name,
                            shape=(1, ))
                self.stored_objects[id(x)] = x
Example #4
def dilate(var: pm.placeholder, strides, name=None):
    n = len(var.shape)
    assert len(strides) == n
    out_shape = ()
    nz_indices = ()
    shape_idx = ()

    for i in range(n):
        out_shape += ((var.shape[i] - 1) * strides[i] + 1, )
        nz_indices += (pm.index(0, out_shape[i] - 1, stride=strides[i]), )
        shape_idx += (pm.index(0, out_shape[i] - 1), )

    padded = pm.temp(name=name, shape=out_shape)
    padded[shape_idx] = 0
    # Presumed completion: scatter the original values into the strided
    # (non-zero) positions and return the dilated tensor.
    var_idx = tuple(pm.index(0, var.shape[i] - 1) for i in range(n))
    padded[nz_indices] = var[var_idx]
    return padded
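For reference, a short NumPy sketch of the dilation this helper is presumably building (hypothetical 2-D input; zeros are interleaved so element i lands at position i * stride):

import numpy as np

var = np.arange(6).reshape(2, 3)
strides = (2, 3)
out_shape = tuple((s - 1) * st + 1 for s, st in zip(var.shape, strides))  # (3, 7)
dilated = np.zeros(out_shape, dtype=var.dtype)
dilated[::strides[0], ::strides[1]] = var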
Example #5
def get_gemm(a,
             b,
             c=None,
             shape=None,
             name=None,
             alpha=None,
             beta=None,
             transA=None,
             transB=None,
             out=None):
    if not out:
        out = pm.output(shape=shape, name=name)
    if transB:
        # Fold the transpose into the weight shape so the gemm call itself
        # does not need to transpose B.
        assert len(b.shape) == 2
        b.shape = (b.shape[1], b.shape[0])
        transB = False

    if c:
        pm.gemm(a,
                b,
                c,
                out,
                alpha=alpha,
                beta=beta,
                transA=transA,
                transB=transB,
                strict_shapes=True)
    else:
        # No bias provided: use a zero-initialized temporary as C.
        t_c = pm.temp(shape=shape)
        i = pm.index(0, shape[0] - 1)
        j = pm.index(0, shape[1] - 1)
        t_c[i, j] = 0
        pm.gemm(a,
                b,
                t_c,
                out,
                alpha=alpha,
                beta=beta,
                transA=transA,
                transB=transB,
                strict_shapes=True)
    return out
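A NumPy reference for the intended GEMM semantics, assuming pm.gemm follows the usual ONNX-style convention out = alpha * (A @ B) + beta * C, with C defaulting to zeros when no bias is given (hypothetical shapes):

import numpy as np

a, b = np.random.rand(2, 3), np.random.rand(3, 4)
c = np.zeros((2, 4))                 # stands in for the zeroed t_c above
alpha, beta = 1.0, 1.0
out = alpha * (a @ b) + beta * c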
Example #6
    def define_graph(self, x, y, alpha, beta, bias, nsize):
        n = pm.index(0, x.shape[0] - 1)
        c = pm.index(0, x.shape[1] - 1)
        h = pm.index(0, x.shape[2] - 1)
        w = pm.index(0, x.shape[3] - 1)
        c_ = pm.index(0, x.shape[1] - 1)
        ext = pm.temp(name="extended", shape=tuple([*x.shape, x.shape[-3]]))

        bounds = pm.output(name="bounds", shape=(x.shape[1], x.shape[1]))
        radius = nsize // 2
        hbool = ((((x.shape[1] > (c + radius + 1)) * (c + radius)) +
                  (x.shape[1] <= (c + radius + 1)) * (x.shape[1] - 1)) >= c_)
        lbool = ((((c - radius) > 0) * (c - radius)) +
                 (((c - radius) <= 0) * 0) <= c_)
        bounds[c, c_] = hbool * lbool
        ext[n, c, h, w, c_] = x[n, c_, h, w] * bounds[c, c_]
        # y[n, c, h, w] = x[n,c,h,w] / ((bias + (alpha/nsize) * pm.sum([c_], ext[n, c, h, w, c_]**2))**beta)
        y[n, c, h, w] = x[n, c, h, w] / (
            (bias +
             (alpha / nsize) * pm.sum([c_], ext[n, c, h, w, c_]**2))**beta)
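A NumPy sketch of the cross-channel LRN computed above, assuming the usual definition y = x / (bias + (alpha / nsize) * window_sum(x^2)) ** beta with the channel window clamped to the valid range (hypothetical input):

import numpy as np

x = np.random.rand(1, 8, 4, 4)
alpha, beta, bias, nsize = 1e-4, 0.75, 1.0, 5
radius = nsize // 2
sq = x ** 2
y = np.empty_like(x)
for c in range(x.shape[1]):
    lo = max(0, c - radius)
    hi = min(x.shape[1] - 1, c + radius)
    window = sq[:, lo:hi + 1].sum(axis=1)
    y[:, c] = x[:, c] / (bias + (alpha / nsize) * window) ** beta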
Example #7
    def define_graph(self, data, wgt, out, stride=1, pad=0, out_pad=0):

        n, c, h, w = data.shape
        dim_in, dim_out, kh, kw = wgt.shape
        sh, sw = stride - 1, stride - 1

        # Dilate the input by interleaving (stride - 1) zeros between elements:
        # view every element as its own 1x1 pixel, ...
        y = pm.temp(name=f"{data.name}_reshaped", shape=(n * c, h * w, 1, 1))
        pm.tensor_reshape(data, y, y.shape)

        # ... pad each pixel with zeros on the bottom/right, ...
        y1 = pm.temp(name=f"{data.name}_reshaped1")
        pm.tensor_pad(y, y1, ((0, 0), (0, 0), (0, sh), (0, sw)))

        y2 = pm.temp(name=f"{data.name}_reshaped2",
                     shape=(n * c, h, w, 1 + sh, 1 + sw))
        pm.tensor_reshape(y1, y2, y2.shape)

        y3 = pm.temp(name=f"{data.name}_permuted",
                     shape=(n * c, h, 1 + sh, w, 1 + sw))
        pm.tensor_transpose(y2, y3, (0, 1, 3, 2, 4))

        y4 = pm.temp(name=f"{data.name}_reshaped3",
                     shape=(n, c, h * (1 + sh), w * (1 + sw)))
        pm.tensor_reshape(y3, y4, y4.shape)
        ph, pw = kh - pad - 1, kw - pad - 1

        w_perm = pm.temp(name=f"{wgt.name}_perm",
                         shape=(wgt.shape[1], wgt.shape[0], wgt.shape[2],
                                wgt.shape[3]))
        w_perm_flip = pm.temp(name=f"{wgt.name}_flip",
                              shape=(wgt.shape[1], wgt.shape[0], wgt.shape[2],
                                     wgt.shape[3]))

        # Swap the in/out channel axes of the weights and flip the kernel
        # spatially before the plain convolution.
        pm.tensor_transpose(wgt, w_perm, (1, 0, 2, 3))
        pm.tensor_flip(w_perm, w_perm_flip, (2, 3))

        y5 = pm.temp(name=f"{data.name}_pad2")
        pm.tensor_pad(y4, y5, ((0, 0), (0, 0), (pw, pw - sw + out_pad),
                               (ph, ph - sh + out_pad)))
        pm.conv(y5, w_perm_flip, out, pad=0, stride=1)
Example #8
    def define_graph(self, inp, out, kh, kw, stride=1, pad=0):
        oh = ((inp.shape[2] + 2 * pad - kh) // stride + 1)
        ow = ((inp.shape[3] + 2 * pad - kw) // stride + 1)
        out.set_shape((inp.shape[0], inp.shape[1], oh, ow))

        b = pm.index(0, inp.shape[0] - 1, name="b")
        c = pm.index(0, inp.shape[1] - 1, name="c")
        y = pm.index(0, oh - 1, name="y")
        x = pm.index(0, ow - 1, name="x")
        m = pm.index(0, kh - 1, name="m")
        n = pm.index(0, kw - 1, name="n_")
        ihp = (inp.shape[2] + pad * 2)
        iwp = inp.shape[3] + pad * 2
        ihp_ = pm.index(0, ihp - 1, name="ihp")
        iwp_ = pm.index(0, iwp - 1, name="iwp")
        iy = pm.index(0, inp.shape[2] - 1, name="iy")
        ix = pm.index(0, inp.shape[3] - 1, name="ix")
        # Zero-pad the input, then average each kh x kw window.
        padded = pm.temp(shape=(inp.shape[0], inp.shape[1], ihp, iwp))
        padded[b, c, ihp_, iwp_] = 0
        padded[b, c, iy + pad, ix + pad] = inp[b, c, iy, ix]
        out[b, c, y,
            x] = ((1 / (kh * kw)) *
                  pm.sum([m, n], padded[b, c, stride * y + m, stride * x + n]))
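For reference, the same average pooling written directly in NumPy (hypothetical input; same output-size formula and zero padding as above):

import numpy as np

inp = np.random.rand(1, 2, 6, 6)
kh, kw, stride, pad = 2, 2, 2, 0
n_, c_, ih, iw = inp.shape
oh = (ih + 2 * pad - kh) // stride + 1
ow = (iw + 2 * pad - kw) // stride + 1
padded = np.zeros((n_, c_, ih + 2 * pad, iw + 2 * pad))
padded[:, :, pad:pad + ih, pad:pad + iw] = inp
out = np.empty((n_, c_, oh, ow))
for y in range(oh):
    for x in range(ow):
        out[:, :, y, x] = padded[:, :, stride * y:stride * y + kh,
                                 stride * x:stride * x + kw].mean(axis=(2, 3))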
Example #9
    def define_graph(self, data, out, kh, kw, stride=(1, 1), pad=(0, 0)):
        sx, sy = stride
        oh = ((data.shape[-2] + 2 * pad[0] - kh) // stride[0] + 1)
        ow = ((data.shape[-1] + 2 * pad[1] - kw) // stride[1] + 1)

        y = pm.index(0, oh - 1, name="y")
        x = pm.index(0, ow - 1, name="x")
        m = pm.index(0, kh - 1, name="m")
        n = pm.index(0, kw - 1, name="n_")
        ihp = (data.shape[-2] + pad[0] * 2)
        iwp = data.shape[-1] + pad[1] * 2
        ihp_ = pm.index(0, ihp - 1, name="ihp")
        iwp_ = pm.index(0, iwp - 1, name="iwp")
        iy = pm.index(0, data.shape[-2] - 1, name="iy")
        ix = pm.index(0, data.shape[-1] - 1, name="ix")

        if len(data.shape) > 3:
            b = pm.index(0, data.shape[0] - 1, name="b")
            c = pm.index(0, data.shape[1] - 1, name="c")

            o_indices = [b, c]
            p_shape = (data.shape[0], data.shape[1], ihp, iwp)
            out.set_shape((data.shape[0], data.shape[1], oh, ow))

        else:
            c = pm.index(0, data.shape[0] - 1, name="c")
            o_indices = [c]
            p_shape = (data.shape[0], ihp, iwp)
            out.set_shape((data.shape[0], oh, ow))
        o_indices = tuple(o_indices)
        padded = pm.temp(shape=p_shape)
        padded[o_indices + (ihp_, iwp_)] = 0
        padded[o_indices + (iy + pad[0], ix + pad[1])] = data[o_indices +
                                                              (iy, ix)]
        out[o_indices + (y, x)] = pm.sum(
            [m, n], padded[o_indices + (sx * y + m, sy * x + n)]) * (1 /
                                                                     (kh * kw))
Example #10
    def define_graph(self, z, y, loss, reduction="mean"):
        a = pm.temp(name=f"temp_{y.name}", shape=z.shape)

        i = pm.index(0, z.shape[1] - 1, name="i")
        indices = [
            pm.index(0, s - 1, name=f"{z.name}[{i}]")
            for i, s in enumerate(z.shape)
        ]
        indices[1] = i
        indices = tuple(indices)
        maxes = pm.max([i], z[indices], name="maxes")
        exp_val = pm.exp((z[indices] - maxes[indices[0]]))
        lse_stable = pm.log(pm.sum([i], exp_val[indices], name="testing_lse"),
                            name="lse_stable")
        a[indices] = z[indices] - maxes[indices[0]] - lse_stable[indices[0]]
        # Negative log-likelihood of the target class for each sample.
        gathered = pm.gather_elements(a,
                                      pm.reshape(y, (a.shape[0], 1),
                                                 name="reshaped1"),
                                      axis=1,
                                      shape=(y.shape[0], ),
                                      name="gathered_elem")
        reshaped = pm.reshape(-1 * gathered, (y.shape[0], ),
                              name="other_reshape")
        idx = (pm.index(0, a.shape[0] - 1), )
        if reduction == "none":
            loss.set_shape(reshaped.shape)
            loss[idx] = reshaped[idx]
        elif reduction == "mean":
            loss.set_shape((1, ))
            denom = 1
            for s in reshaped.shape:
                denom = denom * s
            loss[0] = pm.sum([idx[0]], reshaped[idx],
                             name="test_sum_name") / denom
        elif reduction == "sum":
            loss.set_shape((1, ))
            loss[0] = pm.sum([idx[0]], reshaped[idx])
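A NumPy sketch of the same loss with hypothetical logits and targets: a numerically stable log-softmax over the class axis, negative log-likelihood of the target class, then the chosen reduction:

import numpy as np

z = np.random.rand(4, 10)                      # logits
y = np.array([3, 1, 0, 7])                     # target class indices
maxes = z.max(axis=1, keepdims=True)
log_softmax = z - maxes - np.log(np.exp(z - maxes).sum(axis=1, keepdims=True))
nll = -log_softmax[np.arange(z.shape[0]), y]
loss = nll.mean()                              # "mean"; use .sum() or nll itself otherwise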
Example #11
    def define_graph(self, data, out, kh, kw, stride=(1, 1), pad=(0, 0)):

        oh = ((data.shape[-2] + 2 * pad[0] - kh) // stride[0] + 1)
        ow = ((data.shape[-1] + 2 * pad[1] - kw) // stride[1] + 1)

        y = pm.index(0, oh - 1)
        x = pm.index(0, ow - 1)
        m = pm.index(0, kh - 1)
        n = pm.index(0, kw - 1)
        ihp = (data.shape[-2] + pad[0] * 2)
        iwp = data.shape[-1] + pad[1] * 2
        ihp_ = pm.index(0, ihp - 1, name="ihp")
        iwp_ = pm.index(0, iwp - 1, name="iwp")
        iy = pm.index(0, data.shape[-2] - 1, name="iy")
        ix = pm.index(0, data.shape[-1] - 1, name="ix")

        if len(data.shape) > 3:
            b = pm.index(0, data.shape[0] - 1, name="b")
            c = pm.index(0, data.shape[1] - 1, name="c")

            o_indices = (b, c)
            p_shape = (data.shape[0], data.shape[1], ihp, iwp)
            out.set_shape((data.shape[0], data.shape[1], oh, ow))

        else:
            c = pm.index(0, data.shape[0] - 1, name="c")
            o_indices = (c, )
            p_shape = (data.shape[0], ihp, iwp)
            out.set_shape((data.shape[0], oh, ow))

        padded = pm.temp(shape=p_shape)
        padded[o_indices + (ihp_, iwp_)] = 0
        padded[o_indices + (iy + pad[0], ix + pad[1])] = data[o_indices +
                                                              (iy, ix)]
        # Max over each kh x kw window of the zero-padded input.
        out[o_indices + (y, x)] = pm.max([m, n],
                                         padded[o_indices +
                                                (stride[0] * y + m,
                                                 stride[1] * x + n)])
Example #12
    def define_graph(self, data, w, out, stride=1, pad=0, dilation=1):

        if not isinstance(stride, (tuple, list)):
            stride_h = stride_w = stride
        else:
            stride_h, stride_w = stride

        if not isinstance(dilation, (tuple, list)):
            dilation_h = dilation_w = dilation
        else:
            dilation_h, dilation_w = dilation

        if not isinstance(pad, (tuple, list)):
            pad = (pad, pad)

        batch, in_channel, in_height, in_width = data.shape
        num_filter, channel, kernel_h, kernel_w = w.shape
        # compute the output shape
        dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
        dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
        pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
            pad, (dilated_kernel_h, dilated_kernel_w))
        out_channel = num_filter
        oh = (in_height - dilated_kernel_h + pad_top +
              pad_down) // stride_h + 1
        ow = (in_width - dilated_kernel_w + pad_left +
              pad_right) // stride_w + 1
        pad_before = [0, 0, pad_top, pad_left]
        pad_after = [0, 0, pad_down, pad_right]
        c = pm.index(0, w.shape[0] - 1, name="c")
        y = pm.index(0, oh - 1, name="y_")
        x = pm.index(0, ow - 1, name="x_")
        dy = pm.index(0, w.shape[2] - 1, name="dy")
        dx = pm.index(0, w.shape[3] - 1, name="dx")
        iy = pm.index(0, data.shape[-2] - 1, name="iy")
        ix = pm.index(0, data.shape[-1] - 1, name="ix")
        k = pm.index(0, data.shape[-3] - 1, name="k")
        ihp = data.shape[-2] + pad_top + pad_down
        iwp = data.shape[-1] + pad_left + pad_right
        ihp_ = pm.index(0, ihp - 1, name="ihp")
        iwp_ = pm.index(0, iwp - 1, name="iwp")
        if len(data.shape) > 3:
            b = pm.index(0, data.shape[0] - 1, name="b")
            o_indices = (b, c)
            p_indices = (
                b,
                k,
            )
            p_shape = (data.shape[0], data.shape[1], ihp, iwp)
            out.set_shape((data.shape[0], w.shape[0], oh, ow))
        else:
            o_indices = (c, )
            p_indices = (k, )
            p_shape = (data.shape[0], ihp, iwp)
            out.set_shape((w.shape[0], oh, ow))

        padded = pm.temp(shape=p_shape)
        padded[p_indices + (ihp_, iwp_)] = 0

        padded[p_indices + (iy + pad_top, ix + pad_left)] = data[p_indices +
                                                                 (iy, ix)]

        # out[o_indices + (y, x)] = pm.sum([dy, dx, k], (padded[p_indices + (dy + stride*y, dx + stride*x)] * w[c, k, dy, dx])) + bias[c]

        out[o_indices + (y, x)] = pm.sum(
            [dy, dx, k],
            (padded[p_indices +
                    (dy * dilation_h + stride_h * y,
                     dx * dilation_w + stride_w * x)] * w[c, k, dy, dx]))
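For reference, a direct NumPy implementation of the dilated convolution above with hypothetical shapes, assuming symmetric padding (i.e. get_pad_tuple yields the same amount on every side):

import numpy as np

data = np.random.rand(1, 3, 8, 8)
w = np.random.rand(4, 3, 3, 3)
stride_h = stride_w = 1
dilation_h = dilation_w = 2
pad = 1
dkh = (w.shape[2] - 1) * dilation_h + 1
dkw = (w.shape[3] - 1) * dilation_w + 1
oh = (data.shape[2] - dkh + 2 * pad) // stride_h + 1
ow = (data.shape[3] - dkw + 2 * pad) // stride_w + 1
padded = np.zeros((data.shape[0], data.shape[1],
                   data.shape[2] + 2 * pad, data.shape[3] + 2 * pad))
padded[:, :, pad:pad + data.shape[2], pad:pad + data.shape[3]] = data
out = np.zeros((data.shape[0], w.shape[0], oh, ow))
for c in range(w.shape[0]):
    for y in range(oh):
        for x in range(ow):
            patch = padded[:, :,
                           stride_h * y:stride_h * y + dkh:dilation_h,
                           stride_w * x:stride_w * x + dkw:dilation_w]
            out[:, c, y, x] = (patch * w[c]).sum(axis=(1, 2, 3))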