def define_graph(self, data, out, axis=0):
    """Build a numerically stable log-softmax of ``data`` along ``axis``.

    Emits ``out = data - max - log(sum(exp(data - max)))``, where ``max`` and
    ``sum`` reduce over ``axis``. Subtracting the max before ``exp`` keeps the
    exponentials from overflowing.

    Args:
        data: Input node; ``out`` is given the same shape.
        out: Output node, written element-wise.
        axis: Dimension to normalize over. It is used as a Python sequence
            index, so negative values count from the end.
    """
    out.set_shape(data.shape)
    i = pm.index(0, data.shape[axis] - 1, name="i")
    indices = [
        pm.index(0, s - 1, name=f"{data.name}[{d}]")
        for d, s in enumerate(data.shape)
    ]
    indices[axis] = i
    indices = tuple(indices)

    # The max/sum over ``i`` drop ``axis``, so their results must be read
    # back with the indices of the dimensions that survive. Always using
    # ``indices[0]`` only did that for 2-D input with axis=1. With the default
    # axis=0 it indexed the reduced node by the reduction index itself.
    axis_pos = axis % len(indices)
    surviving = tuple(idx for d, idx in enumerate(indices) if d != axis_pos)
    if len(surviving) == 1:
        # Bare index, exactly as the 2-D case was indexed before.
        red_key = surviving[0]
    elif surviving:
        red_key = surviving
    else:
        # NOTE(review): rank-1 input leaves no index after the reduction;
        # this keeps the previous indexing unchanged. Confirm intent.
        red_key = indices[0]

    maxes = pm.max([i], data[indices], name="maxes")
    lse_stable = pm.log(
        pm.sum([i], pm.exp((data[indices] - maxes[red_key]))),
        name="lse_stable",
    )
    out[indices] = data[indices] - maxes[red_key] - lse_stable[red_key]
def define_graph(self, z, y, loss, reduction="mean"):
    """Build the cross-entropy of logits ``z`` against target indices ``y``.

    ``z`` is log-softmax-normalized over dimension 1 into a temporary node,
    using the stable form that subtracts the per-row max before ``exp``. Each
    row's entry at the index given by ``y`` is picked with
    ``pm.gather_elements``, negated, and then reduced according to
    ``reduction``.

    Args:
        z: Logits node. Dimension 1 is normalized over. The ``indices[0]``
            reads and the ``(a.shape[0], 1)`` reshape below presumably
            assume ``z`` is 2-D with ``z.shape[0] == y.shape[0]``; confirm
            with callers.
        y: Target-index node, reshaped to ``(z.shape[0], 1)`` for the gather.
        loss: Output node. Its shape is set here: ``(y.shape[0],)`` for
            ``"none"``, ``(1,)`` for ``"mean"`` and ``"sum"``.
        reduction: ``"none"`` (per-row loss), ``"mean"`` (sum divided by
            the number of rows) or ``"sum"``.

    Raises:
        ValueError: If ``reduction`` is not one of the values above.
    """
    # Validate before building any nodes. Previously an unknown value fell
    # through every branch and left ``loss`` without a shape or a value.
    if reduction not in ("none", "mean", "sum"):
        raise ValueError(
            f"reduction must be 'none', 'mean' or 'sum', got {reduction!r}"
        )

    # Log-softmax of ``z`` over dimension 1, stored in a temporary node.
    a = pm.temp(name=f"temp_{y.name}", shape=z.shape)
    i = pm.index(0, z.shape[1] - 1, name="i")
    indices = [
        pm.index(0, s - 1, name=f"{z.name}[{d}]")
        for d, s in enumerate(z.shape)
    ]
    indices[1] = i
    indices = tuple(indices)
    # Subtract the per-row max before exp so the exponentials cannot overflow.
    maxes = pm.max([i], z[indices], name="maxes")
    exp_val = pm.exp((z[indices] - maxes[indices[0]]))
    lse_stable = pm.log(
        pm.sum([i], exp_val[indices], name="testing_lse"), name="lse_stable"
    )
    a[indices] = z[indices] - maxes[indices[0]] - lse_stable[indices[0]]

    # Pick each row's log-probability at its target index, then negate it.
    gathered = pm.gather_elements(
        a,
        pm.reshape(y, (a.shape[0], 1), name="reshaped1"),
        axis=1,
        shape=(y.shape[0],),
        name="gathered_elem",
    )
    reshaped = pm.reshape(-1 * gathered, (y.shape[0],), name="other_reshape")

    idx = (pm.index(0, a.shape[0] - 1),)
    if reduction == "none":
        loss.set_shape(reshaped.shape)
        loss[idx] = reshaped[idx]
    elif reduction == "mean":
        loss.set_shape((1,))
        denom = 1
        for s in reshaped.shape:
            denom = denom * s
        loss[0] = pm.sum([idx[0]], reshaped[idx], name="test_sum_name") / denom
    else:  # reduction == "sum" (guaranteed by the check at the top)
        loss.set_shape((1,))
        loss[0] = pm.sum([idx[0]], reshaped[idx])
def define_graph(self, x, out):
    """Apply ``pm.log`` element-wise, writing ``log(x[...])`` into ``out``.

    Args:
        x: Input node, read at the same indices as ``out``.
        out: Output node; its current shape determines the index space.
    """
    # NOTE(review): presumably one index per dimension of ``out``. Confirm
    # against the definition of ``_get_single_node_indices``.
    elem = _get_single_node_indices(out, shape=out.shape)
    log_x = pm.log(x[elem])
    out[elem] = log_x