def define_graph(self, data, out, axes=(0, ), keepdims=True):
    """Build a max-reduction of ``data`` over the dimensions in ``axes``.

    One ``pm.index`` is created per dimension of ``data.shape``. The indices of
    the reduced dimensions form the reduction bounds. ``out`` is written at the
    indices of the remaining dimensions.

    Args:
        data: Input node.
        out: Output node, indexed by the non-reduced dimensions of ``data``.
        axes: Dimension or dimensions to reduce, given as a single int or an
            iterable of ints. Negative values count from the end.
        keepdims: Accepted for interface compatibility. This graph definition
            does not read it.

    Raises:
        IndexError: if an entry of ``axes`` is out of range for ``data.shape``.
    """
    rank = len(data.shape)
    if isinstance(axes, int):
        axes = (axes, )
    # Normalize negative axes (range(...)[a] also rejects out-of-range ones).
    # Without this, a negative axis would not match the membership test
    # below, and the reduced dimension would leak into the output index.
    axes = tuple(range(rank)[a] for a in axes)
    indices = tuple(pm.index(0, s - 1) for s in data.shape)
    sum_idx = tuple(indices[a] for a in axes)
    out_idx = tuple(idx for k, idx in enumerate(indices) if k not in axes)
    # NOTE(review): the bounds are passed as a one-element list wrapping the
    # tuple of reduced indices. Sibling graphs instead pass a flat list of
    # index nodes. Kept as-is; confirm pm.max accepts this form.
    out[out_idx] = pm.max([sum_idx], data[indices])
def define_graph(self, data, out, axis=0):
    """Build a numerically stable log-softmax of ``data`` along ``axis``.

    Computes ``out = data - m - log(sum(exp(data - m)))``. Here ``m`` is the
    max of ``data`` along ``axis``, and the sum also runs along ``axis``.

    Args:
        data: Input node.
        out: Output node. Its shape is set to ``data.shape``.
        axis: Dimension to normalize over. Negative values count from the end.

    Raises:
        IndexError: if ``axis`` is out of range for ``data.shape``.
    """
    out.set_shape(data.shape)
    # range(...)[axis] normalizes a negative axis and rejects out-of-range ones.
    axis = range(len(data.shape))[axis]
    i = pm.index(0, data.shape[axis] - 1, name="i")
    indices = [
        pm.index(0, s - 1, name=f"{data.name}[{k}]")
        for k, s in enumerate(data.shape)
    ]
    indices[axis] = i
    indices = tuple(indices)
    # The reductions collapse ``axis``, so their results must be indexed by
    # the remaining dimensions. indices[0] alone is only correct for 2-D
    # input reduced along axis 1.
    kept = [idx for k, idx in enumerate(indices) if k != axis]
    if len(kept) == 1:
        red = kept[0]
    elif kept:
        red = tuple(kept)
    else:
        # NOTE(review): 1-D input leaves no dimension to index by. This keeps
        # the previous indexing; confirm how pm indexes a fully reduced node.
        red = indices[0]
    maxes = pm.max([i], data[indices], name="maxes")
    lse_stable = pm.log(
        pm.sum([i], pm.exp((data[indices] - maxes[red]))),
        name="lse_stable")
    out[indices] = data[indices] - maxes[red] - lse_stable[red]
def define_graph(self, data, out, axis=0):
    """Build a softmax of ``data`` along ``axis``.

    Subtracts ``pm.max`` over ``axis`` before exponentiating. Then divides each
    exponential by the sum of the exponentials along ``axis``. That sum uses a
    separate index ``j``, so it is independent of the element index ``i``.

    Args:
        data: Input node.
        out: Output node. Its shape is set to ``data.shape``.
        axis: Dimension to normalize over.
    """
    out.set_shape(data.shape)
    i = pm.index(0, data.shape[axis] - 1, name="i")
    j = pm.index(0, data.shape[axis] - 1, name="j")
    indices = [
        pm.index(0, s - 1, name=f"{data.name}[{k}]")
        for k, s in enumerate(data.shape)
    ]
    # Copy rather than alias. With plain assignment, ``indices[axis] = i``
    # below would overwrite ``j``, and the denominator would sum over ``i``.
    indices_denom = list(indices)
    indices_denom[axis] = j
    indices[axis] = i
    indices = tuple(indices)
    indices_denom = tuple(indices_denom)
    # NOTE(review): mval is used unindexed, unlike the log-softmax graph,
    # which indexes its max node. Confirm that pm broadcasts this as intended.
    mval = pm.max([i], data[indices], name="max_test")
    e_x = pm.exp((data[indices] - mval), name="e_x")
    out[indices] = e_x[indices] / pm.sum(
        [indices_denom[axis]], e_x[indices_denom], name="denom")
def define_graph(self, z, y, loss, reduction="mean"):
    """Build a cross-entropy loss from logits ``z`` and labels ``y``.

    Steps:
    1. Write a stable log-softmax of ``z`` along dimension 1 into a temporary.
    2. Gather each row's entry at the position given by ``y``, reshaped to
       ``(rows, 1)``, along axis 1.
    3. Negate the result and apply ``reduction``.

    The log-softmax indexes its reductions by ``indices[0]``. It therefore
    presumably expects ``z`` to be 2-D (rows, classes); TODO confirm.

    Args:
        z: Logits node. Dimension 1 is the normalized dimension.
        y: Label node with one entry per row of ``z``.
        loss: Output node. Its shape is ``(y.shape[0],)`` for "none", and
            ``(1,)`` for "mean" and "sum".
        reduction: One of "none", "mean" or "sum".

    Raises:
        ValueError: if ``reduction`` is not a supported value. Previously this
            silently left ``loss`` unset.
    """
    if reduction not in ("none", "mean", "sum"):
        raise ValueError(
            f"Unsupported reduction {reduction!r}; "
            f"expected 'none', 'mean' or 'sum'")
    a = pm.temp(name=f"temp_{y.name}", shape=z.shape)
    i = pm.index(0, z.shape[1] - 1, name="i")
    indices = [
        pm.index(0, s - 1, name=f"{z.name}[{k}]")
        for k, s in enumerate(z.shape)
    ]
    indices[1] = i
    indices = tuple(indices)
    # Stable log-softmax: subtract the per-row max before exponentiating.
    maxes = pm.max([i], z[indices], name="maxes")
    exp_val = pm.exp((z[indices] - maxes[indices[0]]))
    lse_stable = pm.log(pm.sum([i], exp_val[indices], name="testing_lse"),
                        name="lse_stable")
    a[indices] = z[indices] - maxes[indices[0]] - lse_stable[indices[0]]
    gathered = pm.gather_elements(a,
                                  pm.reshape(y, (a.shape[0], 1),
                                             name="reshaped1"),
                                  axis=1,
                                  shape=(y.shape[0], ),
                                  name="gathered_elem")
    reshaped = pm.reshape(-1 * gathered, (y.shape[0], ),
                          name="other_reshape")
    idx = (pm.index(0, a.shape[0] - 1), )
    if reduction == "none":
        loss.set_shape(reshaped.shape)
        loss[idx] = reshaped[idx]
    elif reduction == "mean":
        loss.set_shape((1, ))
        # Number of elements being averaged.
        denom = 1
        for s in reshaped.shape:
            denom = denom * s
        loss[0] = pm.sum([idx[0]], reshaped[idx],
                         name="test_sum_name") / denom
    else:  # reduction == "sum" (validated above)
        loss.set_shape((1, ))
        loss[0] = pm.sum([idx[0]], reshaped[idx])
def define_graph(self, data, out, kh, kw, stride=(1, 1), pad=(0, 0)):
    """Build a 2-D max pooling of ``data`` with a ``kh`` x ``kw`` window.

    Steps:
    1. Copy the input into a temporary that is zero-padded by ``pad[0]`` rows
       and ``pad[1]`` columns on each side.
    2. Set each output element ``(y, x)`` to the max over the window whose
       top-left corner is at ``(stride[0] * y, stride[1] * x)`` in the padded
       temporary.

    Args:
        data: Input node. The last two dimensions are height and width. Inputs
            with more than 3 dimensions are treated as (batch, channel, H, W).
            Otherwise the input is treated as (channel, H, W).
        out: Output node. Its shape is set to the leading dimensions plus
            ``(oh, ow)``.
        kh, kw: Window height and width.
        stride: Vertical and horizontal stride.
        pad: Vertical and horizontal padding applied to each side.
    """
    in_h = data.shape[-2]
    in_w = data.shape[-1]
    out_h = (in_h + 2 * pad[0] - kh) // stride[0] + 1
    out_w = (in_w + 2 * pad[1] - kw) // stride[1] + 1
    padded_h = in_h + 2 * pad[0]
    padded_w = in_w + 2 * pad[1]

    # Output position, window offset, padded position and input position.
    oy = pm.index(0, out_h - 1)
    ox = pm.index(0, out_w - 1)
    ky = pm.index(0, kh - 1)
    kx = pm.index(0, kw - 1)
    py = pm.index(0, padded_h - 1, name="ihp")
    px = pm.index(0, padded_w - 1, name="iwp")
    iy = pm.index(0, in_h - 1, name="iy")
    ix = pm.index(0, in_w - 1, name="ix")

    # Leading (non-spatial) indices and their extents.
    if len(data.shape) > 3:
        lead = (pm.index(0, data.shape[0] - 1, name="b"),
                pm.index(0, data.shape[1] - 1, name="c"))
        lead_dims = (data.shape[0], data.shape[1])
    else:
        lead = (pm.index(0, data.shape[0] - 1, name="c"), )
        lead_dims = (data.shape[0], )
    out.set_shape(lead_dims + (out_h, out_w))

    padded = pm.temp(shape=lead_dims + (padded_h, padded_w))
    # NOTE(review): the border is filled with 0. A window that overlaps the
    # padding therefore never yields a max below 0. ONNX MaxPool ignores padded
    # cells instead; confirm this is intended for negative inputs.
    padded[lead, py, px] = 0
    padded[lead, iy + pad[0], ix + pad[1]] = data[lead, iy, ix]
    window = padded[lead, stride[0] * oy + ky, stride[1] * ox + kx]
    out[lead, oy, ox] = pm.max([ky, kx], window)