def define_graph(self, inp, weight, bias, grad, inp_grad, weight_grad, bias_grad, optimizer, optimizer_kwargs,
                 stride=1, pad=0, dilation=1):
    """Build the backward pass of a biased convolution and apply the optimizer.

    Three gradients are produced:

    * ``inp_grad`` through ``pm.conv_transpose_bias`` applied to ``grad``,
    * ``weight_grad`` through a convolution of the axis-swapped input with
      the axis-swapped output gradient,
    * ``bias_grad`` through ``pm.reduce_sum`` over ``grad``.

    ``weight`` and ``bias`` are then each passed, with their gradient, to
    ``OPTIMIZERS[optimizer]``.

    Args:
        inp: Forward input. Indexed as a 4-axis tensor here
            (presumably N, C, H, W -- TODO confirm layout).
        weight: Convolution kernel, indexed as a 4-axis tensor here.
        bias: Convolution bias.
        grad: Gradient with respect to the forward output (4 axes).
        inp_grad: Destination for the input gradient.
        weight_grad: Destination for the kernel gradient.
        bias_grad: Destination for the bias gradient.
        optimizer: Key into the module-level ``OPTIMIZERS`` table.
        optimizer_kwargs: Keyword arguments forwarded to the optimizer.
        stride: Scalar stride of the forward convolution.
        pad: Scalar padding of the forward convolution.
        dilation: Scalar dilation of the forward convolution.

    Raises:
        AssertionError: If the two spatial axes require different amounts
            of output padding (only a single value can be forwarded).
    """
    spatial_rank = len(grad.shape) - 2

    # Smallest input extent, per spatial axis, that yields an output of
    # grad's extent.  The kernel extent is taken from the matching axis
    # (weight.shape[axis + 2]) so non-square kernels are handled correctly.
    min_sizes = [
        (grad.shape[axis + 2] - 1) * stride
        - 2 * pad
        + (weight.shape[axis + 2] - 1) * dilation
        + 1
        for axis in range(spatial_rank)
    ]

    # Whatever the real input has beyond that minimum must be restored as
    # output padding of the transposed convolution.
    grad_input_padding = tuple(
        inp.shape[-spatial_rank + axis] - min_sizes[axis]
        for axis in range(spatial_rank)
    )
    assert grad_input_padding[0] == grad_input_padding[1]

    # Gradient with respect to the input.
    pm.conv_transpose_bias(grad, weight, bias, inp_grad,
                           stride=stride, pad=pad, out_pad=grad_input_padding[0])

    # NOTE(review): these index tuples are not referenced again in this
    # method.  They are kept because constructing pm.index objects may
    # register nodes in the graph -- TODO confirm before removing.
    inp_indices = tuple(pm.index(0, s - 1) for s in inp.shape)
    grad_indices = tuple(pm.index(0, s - 1) for s in grad.shape)
    weight_indices = tuple(pm.index(0, s - 1) for s in weight.shape)

    # Gradient with respect to the kernel: swap the first two axes of the
    # input and of the output gradient, convolve them, then swap back.
    inp_transposed = pm.temp(
        name=f"transposed_{inp.name}",
        shape=(inp.shape[1], inp.shape[0], inp.shape[2], inp.shape[3]))
    grad_transposed = pm.state(
        name=f"transposed_{grad.name}",
        shape=(grad.shape[1], grad.shape[0], grad.shape[2], grad.shape[3]))
    wgt_grad_transposed = pm.temp(
        name=f"transposed_{weight.name}",
        shape=(weight.shape[1], weight.shape[0], weight.shape[2], weight.shape[3]))

    pm.tensor_transpose(inp, inp_transposed, perm=(1, 0, 2, 3))
    pm.tensor_transpose(grad, grad_transposed, perm=(1, 0, 2, 3))
    # The forward stride is passed as dilation and the forward dilation as
    # stride for this convolution.
    pm.conv(inp_transposed, grad_transposed, wgt_grad_transposed,
            stride=dilation, pad=pad, dilation=stride)
    pm.tensor_transpose(wgt_grad_transposed, weight_grad, perm=(1, 0, 2, 3))

    # Weight update
    OPTIMIZERS[optimizer](weight, weight_grad, **optimizer_kwargs)

    # Bias gradient and update
    pm.reduce_sum(grad, bias_grad)
    OPTIMIZERS[optimizer](bias, bias_grad, **optimizer_kwargs)
def define_graph(self, x, w, y, y_pred, mu, m):
    """Define one weight-update step of a sigmoid (logistic) model.

    The graph computes ``h = sigmoid(sum_i x[i] * w[i])``, the error
    ``h - y``, the per-element product of that error with ``x[i]``, and
    finally rewrites ``w[i]`` as ``w[i] - mu * g[i]``.

    Args:
        x: Input vector, indexed by ``i``.
        w: Weight vector, indexed by ``i``; updated by this graph.
        y: Target value subtracted from the activation.
        y_pred: Not referenced in this method.
        mu: Step size multiplying the gradient.
        m: Number of elements; ``m - 1`` is the upper bound of ``i``.
    """
    last = (m - 1).set_name("m-1")
    i = pm.index(0, last, name="i")

    # A temp named "h" is created first and the local is then rebound to the
    # sigmoid node below, exactly as the graph was originally declared.
    h = pm.temp(name="h", shape=(m))
    weighted_sum = pm.sum([i], (x[i] * w[i]), name="h")
    h = pm.sigmoid(weighted_sum)

    error = (h - y).set_name("h-y")
    gradient = (error * x[i]).set_name("d*x")

    w[i] = w[i] - mu * gradient[i]
def populate_temp(self, node):
    """Create and remember one single-element temp per coordinate of ``node``.

    Nothing happens when ``node.shape`` equals ``pm.DEFAULT_SHAPES[0]``.
    Otherwise, for every coordinate tuple in the Cartesian product of
    ``np.arange(extent)`` over the node's axes, a ``pm.temp`` of shape
    ``(1,)`` is created inside ``node`` and stored in
    ``self.stored_objects`` under its ``id``.

    Args:
        node: Node whose elements should be materialised as temps.
    """
    if node.shape != pm.DEFAULT_SHAPES[0]:
        axis_ranges = [np.arange(extent) for extent in node.shape]
        for coord in product(*axis_ranges):
            element = pm.temp(graph=node,
                              name=f"{node.name}{coord}",
                              root_name=node.name,
                              shape=(1, ))
            self.stored_objects[id(element)] = element
def dilate(var: pm.placeholder, strides, name=None):
    """Allocate a zeroed temp large enough to hold ``var`` dilated by ``strides``.

    Along each axis the output extent is ``(extent - 1) * stride + 1``.  A
    temp of that shape is created and assigned zero over its full index
    space, and then once more using only the first axis index.

    NOTE(review): the function returns nothing and the strided indices it
    builds are never used, so it appears unfinished -- confirm with the
    author before relying on it.

    Args:
        var: Placeholder to dilate; only its ``shape`` is read.
        strides: One stride per axis of ``var``.
        name: Name given to the created temp.

    Raises:
        AssertionError: If ``strides`` does not have one entry per axis.
    """
    rank = len(var.shape)
    assert len(strides) == rank

    dilated_dims = []
    strided_idx = []
    dense_idx = []
    for axis, extent in enumerate(var.shape):
        step = strides[axis]
        size = (extent - 1) * step + 1
        dilated_dims.append(size)
        # Strided index first, dense index second, for every axis.
        strided_idx.append(pm.index(0, size - 1, stride=step))
        dense_idx.append(pm.index(0, size - 1))

    padded = pm.temp(name=name, shape=tuple(dilated_dims))
    padded[tuple(dense_idx)] = 0
    padded[dense_idx[0]] = 0
def get_gemm(a, b, c=None, shape=None, name=None, alpha=None, beta=None, transA=None, transB=None, out=None):
    """Add a ``pm.gemm`` node and return its output.

    Args:
        a: Left operand.
        b: Right operand.  When ``transB`` is truthy its ``shape`` attribute
            is swapped in place and ``transB`` is forwarded as ``False``.
        c: Optional addend.  When falsy, a temp of ``shape`` is created and
            set to zero over its first two axes, and used instead.
        shape: Shape of the output (and of the zero addend when needed).
        name: Name for the output created when ``out`` is falsy.
        alpha: Forwarded to ``pm.gemm``.
        beta: Forwarded to ``pm.gemm``.
        transA: Forwarded to ``pm.gemm``.
        transB: See ``b``.
        out: Existing output to write into; created when falsy.

    Returns:
        The output node that the gemm writes to.

    Raises:
        AssertionError: If ``transB`` is truthy and ``b`` is not 2-D.
    """
    if not out:
        out = pm.output(shape=shape, name=name)

    if transB:
        assert len(b.shape) == 2
        rows, cols = b.shape[0], b.shape[1]
        b.shape = (cols, rows)
        transB = False

    if c:
        addend = c
    else:
        # No addend supplied: use an explicit zero tensor in its place.
        addend = pm.temp(shape=shape)
        row = pm.index(0, shape[0] - 1)
        col = pm.index(0, shape[1] - 1)
        addend[row, col] = 0

    pm.gemm(a, b, addend, out,
            alpha=alpha, beta=beta, transA=transA, transB=transB,
            strict_shapes=True)
    return out
def define_graph(self, x, y, alpha, beta, bias, nsize):
    """Define local response normalisation across axis 1 of a 4-axis input.

    For every channel ``c`` a 0/1 mask ``bounds[c, c_]`` marks the
    neighbouring channels ``c_`` that lie between ``max(c - radius, 0)``
    and ``min(c + radius, C - 1)``, with ``radius = nsize // 2``.  The
    output is::

        y = x / (bias + (alpha / nsize) * sum_c_(masked x ** 2)) ** beta

    Args:
        x: Input tensor, indexed with four indices here.
        y: Output tensor, written with the same four indices.
        alpha: Scale applied to the windowed sum of squares.
        beta: Exponent applied to the denominator.
        bias: Additive constant in the denominator.
        nsize: Window size; also divides ``alpha``.
    """
    # Index creation order: the four tensor axes, then the neighbour channel.
    n_i, c_i, h_i, w_i = (pm.index(0, x.shape[axis] - 1) for axis in range(4))
    nbr = pm.index(0, x.shape[1] - 1)

    ext = pm.temp(name="extended", shape=(*x.shape, x.shape[-3]))
    bounds = pm.output(name="bounds", shape=(x.shape[1], x.shape[1]))
    radius = nsize // 2
    channels = x.shape[1]

    # Upper window edge: c + radius when it is in range, else the last channel.
    upper_inside = (channels > (c_i + radius + 1)) * (c_i + radius)
    upper_clamped = (channels <= (c_i + radius + 1)) * (channels - 1)
    hbool = (upper_inside + upper_clamped) >= nbr

    # Lower window edge: c - radius when positive, else zero.
    lower_inside = ((c_i - radius) > 0) * (c_i - radius)
    lower_clamped = ((c_i - radius) <= 0) * 0
    lbool = (lower_inside + lower_clamped) <= nbr

    bounds[c_i, nbr] = hbool * lbool
    ext[n_i, c_i, h_i, w_i, nbr] = x[n_i, nbr, h_i, w_i] * bounds[c_i, nbr]

    numerator = x[n_i, c_i, h_i, w_i]
    denominator = (bias + (alpha / nsize) * pm.sum([nbr], ext[n_i, c_i, h_i, w_i, nbr]**2))**beta
    y[n_i, c_i, h_i, w_i] = numerator / denominator
def define_graph(self, data, wgt, out, stride=1, pad=0, out_pad=0):
    """Define a transposed convolution as padding + an ordinary convolution.

    Steps built into the graph:

    1. Every input element is followed by ``stride - 1`` padded entries
       along both spatial axes (reshape -> pad -> reshape -> transpose ->
       reshape).  Presumably ``pm.tensor_pad`` fills with zeros -- TODO
       confirm.
    2. The kernel has its first two axes swapped and its two spatial axes
       flipped.
    3. The expanded input is padded by ``k - pad - 1`` before each spatial
       axis and by ``k - pad - 1 - (stride - 1) + out_pad`` after it, then
       convolved with the flipped kernel at stride 1 and no padding.

    Args:
        data: Input tensor with exactly four axes, unpacked as n, c, h, w.
        wgt: Kernel with exactly four axes; the last two are kh, kw.
        out: Output node of the final convolution.
        stride: Scalar stride of the transposed convolution.
        pad: Scalar padding of the transposed convolution.
        out_pad: Extra padding added after each spatial axis.
    """
    n, c, h, w = data.shape
    _, _, kh, kw = wgt.shape
    # Number of entries inserted after each element along height / width.
    sh, sw = stride - 1, stride - 1

    # Turn every element into its own 1x1 cell so that padding the cell
    # appends entries after the element.
    cells = pm.temp(name=f"{data.name}_reshaped", shape=(n * c, h * w, 1, 1))
    pm.tensor_reshape(data, cells, cells.shape)

    padded_cells = pm.temp(name=f"{data.name}_reshaped1")
    pm.tensor_pad(cells, padded_cells, ((0, 0), (0, 0), (0, sh), (0, sw)))

    # Re-interleave the cells back into image layout.
    grid = pm.temp(name=f"{data.name}_reshaped2", shape=(n * c, h, w, 1 + sh, 1 + sw))
    pm.tensor_reshape(padded_cells, grid, grid.shape)

    interleaved = pm.temp(name=f"{data.name}_permuted", shape=(n * c, h, 1 + sh, w, 1 + sw))
    pm.tensor_transpose(grid, interleaved, (0, 1, 3, 2, 4))

    expanded = pm.temp(name=f"{data.name}_reshaped3", shape=(n, c, h * (1 + sh), w * (1 + sw)))
    pm.tensor_reshape(interleaved, expanded, expanded.shape)

    # Border padding along height / width for the equivalent convolution.
    ph, pw = kh - pad - 1, kw - pad - 1

    swapped_shape = (wgt.shape[1], wgt.shape[0], wgt.shape[2], wgt.shape[3])
    w_perm = pm.temp(name=f"{wgt.name}_perm", shape=swapped_shape)
    w_perm_flip = pm.temp(name=f"{wgt.name}_flip", shape=swapped_shape)
    pm.tensor_transpose(wgt, w_perm, (1, 0, 2, 3))
    pm.tensor_flip(w_perm, w_perm_flip, (2, 3))

    bordered = pm.temp(name=f"{data.name}_pad2")
    # Axis 2 is height and axis 3 is width (same order as the first
    # tensor_pad call above), so the height amounts go first.  The trailing
    # side subtracts the entries already appended after the last element.
    pm.tensor_pad(expanded, bordered,
                  ((0, 0), (0, 0),
                   (ph, ph - sh + out_pad),
                   (pw, pw - sw + out_pad)))

    pm.conv(bordered, w_perm_flip, out, pad=0, stride=1)
def define_graph(self, inp, out, kh, kw, stride=1, pad=0):
    """Define average pooling over the last two axes of a 4-axis input.

    The input is copied into a zero-initialised buffer enlarged by ``pad``
    on every spatial side; each output element is ``1 / (kh * kw)`` times
    the sum over a ``kh`` x ``kw`` window of that buffer.

    Args:
        inp: Input tensor with four axes.
        out: Output node; its shape is set here.
        kh: Window height.
        kw: Window width.
        stride: Scalar step between windows on both axes.
        pad: Scalar padding applied to both sides of both spatial axes.
    """
    batch, channels = inp.shape[0], inp.shape[1]
    in_h, in_w = inp.shape[2], inp.shape[3]

    oh = ((in_h + 2 * pad - kh) // stride + 1)
    ow = ((in_w + 2 * pad - kw) // stride + 1)
    out.set_shape((batch, channels, oh, ow))

    # Output-space indices.
    b = pm.index(0, batch - 1, name="b")
    c = pm.index(0, channels - 1, name="c")
    y = pm.index(0, oh - 1, name="y")
    x = pm.index(0, ow - 1, name="x")
    # Window offsets.
    m = pm.index(0, kh - 1, name="m")
    n = pm.index(0, kw - 1, name="n_")

    # Padded-buffer extents and indices.
    ihp = (in_h + pad * 2)
    iwp = in_w + pad * 2
    ihp_ = pm.index(0, ihp - 1, name="ihp")
    iwp_ = pm.index(0, iwp - 1, name="iwp")
    # Unpadded input indices.
    iy = pm.index(0, in_h - 1, name="iy")
    ix = pm.index(0, in_w - 1, name="ix")

    padded = pm.temp(shape=(batch, channels, ihp, iwp))
    padded[b, c, ihp_, iwp_] = 0
    padded[b, c, iy + pad, ix + pad] = inp[b, c, iy, ix]

    scale = 1 / (kh * kw)
    window = padded[b, c, stride * y + m, stride * x + n]
    out[b, c, y, x] = (scale * pm.sum([m, n], window))
def define_graph(self, data, out, kh, kw, stride=(1, 1), pad=(0, 0)):
    """Define average pooling over the last two axes of ``data``.

    Inputs with more than three axes are treated as having two leading
    axes; otherwise one leading axis is used.  The input is copied into a
    zero-initialised padded buffer and each output element is the window
    sum multiplied by ``1 / (kh * kw)``.

    Args:
        data: Input tensor.
        out: Output node; its shape is set here.
        kh: Window height.
        kw: Window width.
        stride: Pair of steps; element 0 applies to the second-to-last
            axis, element 1 to the last axis.
        pad: Pair of paddings, ordered like ``stride``.
    """
    step_rows, step_cols = stride
    in_h, in_w = data.shape[-2], data.shape[-1]

    oh = ((in_h + 2 * pad[0] - kh) // stride[0] + 1)
    ow = ((in_w + 2 * pad[1] - kw) // stride[1] + 1)

    # Output-space indices, then window offsets.
    y = pm.index(0, oh - 1, name="y")
    x = pm.index(0, ow - 1, name="x")
    m = pm.index(0, kh - 1, name="m")
    n = pm.index(0, kw - 1, name="n_")

    # Padded-buffer extents and indices, then unpadded input indices.
    ihp = (in_h + pad[0] * 2)
    iwp = in_w + pad[1] * 2
    ihp_ = pm.index(0, ihp - 1, name="ihp")
    iwp_ = pm.index(0, iwp - 1, name="iwp")
    iy = pm.index(0, in_h - 1, name="iy")
    ix = pm.index(0, in_w - 1, name="ix")

    if len(data.shape) > 3:
        b = pm.index(0, data.shape[0] - 1, name="b")
        c = pm.index(0, data.shape[1] - 1, name="c")
        lead = (b, c)
        lead_shape = (data.shape[0], data.shape[1])
    else:
        c = pm.index(0, data.shape[0] - 1, name="c")
        lead = (c, )
        lead_shape = (data.shape[0], )
    out.set_shape(lead_shape + (oh, ow))

    padded = pm.temp(shape=lead_shape + (ihp, iwp))
    padded[lead + (ihp_, iwp_)] = 0
    padded[lead + (iy + pad[0], ix + pad[1])] = data[lead + (iy, ix)]

    window_sum = pm.sum([m, n], padded[lead + (step_rows * y + m, step_cols * x + n)])
    out[lead + (y, x)] = window_sum * (1 / (kh * kw))
def define_graph(self, z, y, loss, reduction="mean"):
    """Define a cross-entropy loss from scores ``z`` and target indices ``y``.

    The graph forms ``z - max(z) - log(sum(exp(z - max(z))))`` with the max
    and sum taken over the index of axis 1, gathers the entries selected by
    ``y`` along axis 1, negates them and reduces the result.

    Args:
        z: Score tensor; axis 1 is the reduced axis.
        y: Target indices, reshaped to ``(z.shape[0], 1)`` for the gather.
        loss: Output node; its shape is set according to ``reduction``.
        reduction: ``"none"`` keeps one value per row, ``"mean"`` divides
            the sum by the number of values, ``"sum"`` adds them up.  Any
            other value leaves ``loss`` unwritten.
    """
    log_probs = pm.temp(name=f"temp_{y.name}", shape=z.shape)
    cls = pm.index(0, z.shape[1] - 1, name="i")

    # One index per axis of z; the axis-1 entry is then replaced by ``cls``.
    axes = []
    for axis, extent in enumerate(z.shape):
        axes.append(pm.index(0, extent - 1, name=f"{z.name}[{axis}]"))
    axes[1] = cls
    axes = tuple(axes)
    row = axes[0]

    # Subtracting the maximum before exponentiating keeps exp() bounded.
    maxes = pm.max([cls], z[axes], name="maxes")
    exp_val = pm.exp((z[axes] - maxes[row]))
    exp_total = pm.sum([cls], exp_val[axes], name="testing_lse")
    lse_stable = pm.log(exp_total, name="lse_stable")
    log_probs[axes] = z[axes] - maxes[row] - lse_stable[row]

    targets = pm.reshape(y, (log_probs.shape[0], 1), name="reshaped1")
    gathered = pm.gather_elements(log_probs, targets, axis=1,
                                  shape=(y.shape[0], ), name="gathered_elem")
    reshaped = pm.reshape(-1 * gathered, (y.shape[0], ), name="other_reshape")

    idx = (pm.index(0, log_probs.shape[0] - 1), )

    if reduction == "none":
        loss.set_shape(reshaped.shape)
        loss[idx] = reshaped[idx]
    elif reduction == "mean":
        loss.set_shape((1, ))
        count = 1
        for extent in reshaped.shape:
            count = count * extent
        loss[0] = pm.sum([idx[0]], reshaped[idx], name="test_sum_name") / count
    elif reduction == "sum":
        loss.set_shape((1, ))
        loss[0] = pm.sum([idx[0]], reshaped[idx])
def define_graph(self, data, out, kh, kw, stride=(1, 1), pad=(0, 0)):
    """Define max pooling over the last two axes of ``data``.

    Inputs with more than three axes are treated as having two leading
    axes; otherwise one leading axis is used.  The input is copied into a
    zero-initialised padded buffer and each output element is the maximum
    over a ``kh`` x ``kw`` window of that buffer.

    Index keys are built as flat tuples (``lead + (row, col)``), the same
    form used by the other pooling and convolution definitions in this
    file, instead of nesting the leading-index tuple inside the key.

    Args:
        data: Input tensor.
        out: Output node; its shape is set here.
        kh: Window height.
        kw: Window width.
        stride: Pair of steps; element 0 applies to the second-to-last
            axis, element 1 to the last axis.
        pad: Pair of paddings, ordered like ``stride``.
    """
    in_h, in_w = data.shape[-2], data.shape[-1]

    oh = ((in_h + 2 * pad[0] - kh) // stride[0] + 1)
    ow = ((in_w + 2 * pad[1] - kw) // stride[1] + 1)

    # Output-space indices, then window offsets.
    y = pm.index(0, oh - 1)
    x = pm.index(0, ow - 1)
    m = pm.index(0, kh - 1)
    n = pm.index(0, kw - 1)

    # Padded-buffer extents and indices, then unpadded input indices.
    ihp = (in_h + pad[0] * 2)
    iwp = in_w + pad[1] * 2
    ihp_ = pm.index(0, ihp - 1, name="ihp")
    iwp_ = pm.index(0, iwp - 1, name="iwp")
    iy = pm.index(0, in_h - 1, name="iy")
    ix = pm.index(0, in_w - 1, name="ix")

    if len(data.shape) > 3:
        b = pm.index(0, data.shape[0] - 1, name="b")
        c = pm.index(0, data.shape[1] - 1, name="c")
        o_indices = (b, c)
        p_shape = (data.shape[0], data.shape[1], ihp, iwp)
        out.set_shape((data.shape[0], data.shape[1], oh, ow))
    else:
        c = pm.index(0, data.shape[0] - 1, name="c")
        o_indices = (c, )
        p_shape = (data.shape[0], ihp, iwp)
        out.set_shape((data.shape[0], oh, ow))

    padded = pm.temp(shape=p_shape)
    padded[o_indices + (ihp_, iwp_)] = 0
    padded[o_indices + (iy + pad[0], ix + pad[1])] = data[o_indices + (iy, ix)]
    out[o_indices + (y, x)] = pm.max(
        [m, n],
        padded[o_indices + (stride[0] * y + m, stride[1] * x + n)])
def define_graph(self, data, w, out, stride=1, pad=0, dilation=1):
    """Define a (dilated, strided) convolution without bias.

    The input is copied into a zero-initialised padded buffer and each
    output element is the sum, over kernel rows, kernel columns and input
    channels, of the buffer entry times the matching kernel entry.

    Args:
        data: Input tensor.  With more than three axes the first axis is a
            leading (batch) axis; the last three are channel, height, width.
        w: Kernel with exactly four axes; axis 0 is the output channel and
            axes 2 and 3 are the kernel height and width.
        out: Output node; its shape is set here.
        stride: Scalar, or a ``(height, width)`` pair.
        pad: Scalar (expanded to a pair), or a tuple/list passed unchanged
            to ``get_pad_tuple``.
        dilation: Scalar, or a ``(height, width)`` pair.
    """
    # Each hyper-parameter is normalised on its OWN type; a scalar stride
    # may be combined with a tuple dilation or padding and vice versa.
    if isinstance(stride, (tuple, list)):
        stride_h, stride_w = stride
    else:
        stride_h = stride_w = stride

    if isinstance(dilation, (tuple, list)):
        dilation_h, dilation_w = dilation
    else:
        dilation_h = dilation_w = dilation

    if not isinstance(pad, (tuple, list)):
        pad = (pad, pad)

    # Negative indices keep the rank-3 (no leading axis) branch reachable.
    in_height, in_width = data.shape[-2], data.shape[-1]
    _, _, kernel_h, kernel_w = w.shape

    # compute the output shape
    dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
    dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
        pad, (dilated_kernel_h, dilated_kernel_w))
    oh = (in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1
    ow = (in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1

    # Output channel and output spatial indices.
    c = pm.index(0, w.shape[0] - 1, name="c")
    y = pm.index(0, oh - 1, name="y_")
    x = pm.index(0, ow - 1, name="x_")
    # Kernel offsets.
    dy = pm.index(0, w.shape[2] - 1, name="dy")
    dx = pm.index(0, w.shape[3] - 1, name="dx")
    # Unpadded input indices and input-channel index.
    iy = pm.index(0, in_height - 1, name="iy")
    ix = pm.index(0, in_width - 1, name="ix")
    k = pm.index(0, data.shape[-3] - 1, name="k")
    # Padded-buffer extents and indices.
    ihp = in_height + pad_top + pad_down
    iwp = in_width + pad_left + pad_right
    ihp_ = pm.index(0, ihp - 1, name="ihp")
    iwp_ = pm.index(0, iwp - 1, name="iwp")

    if len(data.shape) > 3:
        b = pm.index(0, data.shape[0] - 1, name="b")
        o_indices = (b, c)
        p_indices = (b, k)
        p_shape = (data.shape[0], data.shape[1], ihp, iwp)
        out.set_shape((data.shape[0], w.shape[0], oh, ow))
    else:
        o_indices = (c, )
        p_indices = (k, )
        p_shape = (data.shape[0], ihp, iwp)
        out.set_shape((w.shape[0], oh, ow))

    padded = pm.temp(shape=p_shape)
    padded[p_indices + (ihp_, iwp_)] = 0
    padded[p_indices + (iy + pad_top, ix + pad_left)] = data[p_indices + (iy, ix)]

    # The per-axis strides are used so that tuple strides work as well.
    out[o_indices + (y, x)] = pm.sum(
        [dy, dx, k],
        (padded[p_indices + (dy * dilation_h + stride_h * y,
                             dx * dilation_w + stride_w * x)] * w[c, k, dy, dx]))