コード例 #1
0
 def define_graph(self, data, out, axis=0):
     """Build a numerically stable log-softmax of ``data`` along ``axis``.

     Each output element is ``data - max - log(sum(exp(data - max)))``, where
     ``max`` and ``sum`` reduce over ``axis``. Subtracting the max before
     ``exp`` keeps the exponentials from overflowing.

     Args:
         data: Input node; ``out`` is given the same shape.
         out: Output node, assigned element-wise.
         axis: Dimension to normalize over. Negative values count from the
             end, as with Python sequence indexing.
     """
     out.set_shape(data.shape)
     # ``i`` walks the reduction axis; every other dimension gets its own index.
     i = pm.index(0, data.shape[axis] - 1, name="i")
     indices = [
         pm.index(0, s - 1, name=f"{data.name}[{dim}]")
         for dim, s in enumerate(data.shape)
     ]
     indices[axis] = i
     indices = tuple(indices)

     # ``maxes`` and ``lse_stable`` have ``axis`` reduced away, so they must be
     # looked up by the indices of the *remaining* dimensions. Always using
     # ``indices[0]`` picked the reduction index ``i`` itself when axis == 0
     # and was only correct for 2-D input with axis == 1.
     axis_pos = axis % len(indices)
     kept = tuple(idx for dim, idx in enumerate(indices) if dim != axis_pos)

     def _per_slice(reduced):
         # Index a reduced node at the current non-reduced position. A single
         # kept index is passed bare, exactly as in the 2-D form.
         if len(kept) == 1:
             return reduced[kept[0]]
         if kept:
             return reduced[kept]
         # NOTE(review): for 1-D input there is no remaining dimension, so the
         # fully reduced node is used unindexed — confirm polymath accepts this.
         return reduced

     maxes = pm.max([i], data[indices], name="maxes")
     lse_stable = pm.log(
         pm.sum([i], pm.exp(data[indices] - _per_slice(maxes))),
         name="lse_stable",
     )
     out[indices] = data[indices] - _per_slice(maxes) - _per_slice(lse_stable)
コード例 #2
0
    def define_graph(self, z, y, loss, reduction="mean"):
        """Build a cross-entropy loss from logits ``z`` and labels ``y``.

        ``z`` goes through a numerically stable log-softmax over dimension 1
        into a temporary ``a``. ``a`` is then gathered along dimension 1 at the
        positions in ``y``, negated, and reduced according to ``reduction``.

        Args:
            z: Logits node. Only dims 0 and 1 are used when indexing the
                reductions, so this presumably expects 2-D (batch, classes)
                input — TODO confirm.
            y: Label node, reshaped to ``(z.shape[0], 1)`` for the gather.
            loss: Output node; its shape is set here.
            reduction: ``"none"`` (one value per example), ``"mean"`` or
                ``"sum"``.

        Raises:
            ValueError: If ``reduction`` is not one of the supported values.
        """
        if reduction not in ("none", "mean", "sum"):
            # An unknown value used to fall through every branch and leave
            # ``loss`` without a shape or value; fail before building nodes.
            raise ValueError(
                f"reduction must be 'none', 'mean' or 'sum', got {reduction!r}")

        a = pm.temp(name=f"temp_{y.name}", shape=z.shape)

        # Stable log-softmax of ``z`` over dimension 1, written into ``a``.
        i = pm.index(0, z.shape[1] - 1, name="i")
        indices = [
            pm.index(0, s - 1, name=f"{z.name}[{dim}]")
            for dim, s in enumerate(z.shape)
        ]
        indices[1] = i
        indices = tuple(indices)
        maxes = pm.max([i], z[indices], name="maxes")
        exp_val = pm.exp((z[indices] - maxes[indices[0]]))
        lse_stable = pm.log(pm.sum([i], exp_val[indices], name="testing_lse"),
                            name="lse_stable")
        a[indices] = z[indices] - maxes[indices[0]] - lse_stable[indices[0]]

        # Gather ``a`` along dimension 1 at the positions given by ``y``.
        gathered = pm.gather_elements(a,
                                      pm.reshape(y, (a.shape[0], 1),
                                                 name="reshaped1"),
                                      axis=1,
                                      shape=(y.shape[0], ),
                                      name="gathered_elem")
        # Negated gathered log-probabilities, flattened to one per label.
        reshaped = pm.reshape(-1 * gathered, (y.shape[0], ),
                              name="other_reshape")
        idx = (pm.index(0, a.shape[0] - 1), )
        if reduction == "none":
            loss.set_shape(reshaped.shape)
            loss[idx] = reshaped[idx]
        elif reduction == "mean":
            loss.set_shape((1, ))
            # Divide by the total element count of ``reshaped``.
            denom = 1
            for s in reshaped.shape:
                denom = denom * s
            loss[0] = pm.sum([idx[0]], reshaped[idx],
                             name="test_sum_name") / denom
        else:  # reduction == "sum" (validated above)
            loss.set_shape((1, ))
            loss[0] = pm.sum([idx[0]], reshaped[idx])
コード例 #3
0
 def define_graph(self, x, out):
     """Assign ``pm.log`` of ``x`` into ``out``, element by element.

     The index tuple is obtained from ``_get_single_node_indices`` for
     ``out`` with ``out.shape``. The same indices address both ``x`` and
     ``out``.
     """
     out_idx = _get_single_node_indices(out, shape=out.shape)
     log_of_x = pm.log(x[out_idx])
     out[out_idx] = log_of_x