Example #1
0
 def define_graph(self, data, out, axes=(0, ), keepdims=True):
     """Max-reduce ``data`` over ``axes`` and write the result into ``out``.

     One ``pm.index`` is created per dimension of ``data``. The indices of
     the reduced axes are handed to ``pm.max``, and ``out`` is indexed by the
     indices of the remaining (non-reduced) axes.

     Args:
         data: input node to reduce.
         out: output node, indexed by the indices of the non-reduced axes.
         axes: a single axis (int) or an iterable of axes to reduce over.
             Negative values count from the last dimension.
         keepdims: accepted for interface compatibility but not read here;
             ``out`` is always indexed by the non-reduced axes only.

     Raises:
         IndexError: if an axis lies outside ``[-rank, rank)``.
     """
     rank = len(data.shape)
     if isinstance(axes, int):
         axes = (axes, )
     normalized = []
     for a in axes:
         if not -rank <= a < rank:
             raise IndexError(f"axis {a} is out of range for rank {rank}")
         # Map negative axes onto [0, rank) so the ``d not in axes`` test
         # below excludes exactly the dimensions that ``indices[a]`` selects.
         normalized.append(a % rank)
     axes = tuple(normalized)

     indices = tuple([pm.index(0, s - 1) for s in data.shape])
     reduce_idx = tuple(indices[a] for a in axes)
     out_idx = tuple(idx for d, idx in enumerate(indices) if d not in axes)
     out[out_idx] = pm.max([reduce_idx], data[indices])
Example #2
0
 def define_graph(self, data, out, axis=0):
     """Numerically stable log-softmax of ``data`` along ``axis``.

     Builds ``out = (data - max) - log(sum(exp(data - max)))``, where ``max``
     and ``sum`` both reduce over the index ``i`` placed at ``axis``.

     Args:
         data: input node.
         out: output node; its shape is set to ``data.shape``.
         axis: dimension to normalize over; negative values count from the
             last dimension.

     Raises:
         IndexError: if ``axis`` lies outside ``[-rank, rank)``.
     """
     out.set_shape(data.shape)
     rank = len(data.shape)
     if not -rank <= axis < rank:
         raise IndexError(f"axis {axis} is out of range for rank {rank}")
     axis = axis % rank

     i = pm.index(0, data.shape[axis] - 1, name="i")
     indices = [
         pm.index(0, s - 1, name=f"{data.name}[{dim}]")
         for dim, s in enumerate(data.shape)
     ]
     indices[axis] = i
     indices = tuple(indices)

     # The reductions over ``i`` are indexed by every index except ``i``
     # (indexing by ``indices[0]`` is only right for 2-D data with axis=1).
     rest = tuple(idx for d, idx in enumerate(indices) if d != axis)

     def at_rest(node):
         # A single remaining index is passed bare; a full reduction
         # (1-D input) is used unindexed.
         if not rest:
             return node
         return node[rest[0]] if len(rest) == 1 else node[rest]

     maxes = pm.max([i], data[indices], name="maxes")
     lse_stable = pm.log(pm.sum([i],
                                pm.exp((data[indices] - at_rest(maxes)))),
                         name="lse_stable")
     out[indices] = data[indices] - at_rest(maxes) - at_rest(lse_stable)
Example #3
0
 def define_graph(self, data, out, axis=0):
     """Softmax of ``data`` along ``axis``.

     Builds ``out[.., i, ..] = e_x[.., i, ..] / sum_j e_x[.., j, ..]`` with
     ``e_x = exp(data - max_i data)``. The numerator uses index ``i`` at
     ``axis`` and the denominator sums over a separate index ``j``.

     Args:
         data: input node.
         out: output node; its shape is set to ``data.shape``.
         axis: dimension to normalize over.
     """
     out.set_shape(data.shape)
     i = pm.index(0, data.shape[axis] - 1, name="i")
     j = pm.index(0, data.shape[axis] - 1, name="j")
     base = [
         pm.index(0, s - 1, name=f"{data.name}[{dim}]")
         for dim, s in enumerate(data.shape)
     ]
     # Two independent copies: the denominator must keep its own reduction
     # index ``j`` rather than sharing the list (and thus ``i``) with the
     # numerator indices.
     indices_denom = list(base)
     indices_denom[axis] = j
     indices = list(base)
     indices[axis] = i
     indices = tuple(indices)
     indices_denom = tuple(indices_denom)
     mval = pm.max([i], data[indices], name="max_test")
     e_x = pm.exp((data[indices] - mval), name="e_x")
     out[indices] = e_x[indices] / pm.sum(
         [indices_denom[axis]], e_x[indices_denom], name="denom")
Example #4
0
    def define_graph(self, z, y, loss, reduction="mean"):
        """Cross-entropy loss: log-softmax of ``z`` followed by NLL on ``y``.

        A numerically stable log-softmax of ``z`` over axis 1 is written into
        a temporary node. The entry selected by ``y`` is gathered from each
        row, negated, and then optionally reduced.

        Args:
            z: logits node. NOTE(review): indexing the axis-1 reductions by
                ``indices[0]`` assumes ``z`` is 2-D ``(batch, classes)``.
            y: target node, reshaped to ``(z.shape[0], 1)`` and used as the
                gather index along axis 1.
            loss: output node; shape ``(y.shape[0],)`` for ``"none"``,
                ``(1,)`` otherwise.
            reduction: ``"none"`` (per-row losses), ``"mean"`` (sum divided
                by the element count) or ``"sum"``.

        Raises:
            ValueError: if ``reduction`` is not one of the values above.
        """
        # Reject unknown modes up front instead of silently building a graph
        # that never writes ``loss``.
        if reduction not in ("none", "mean", "sum"):
            raise ValueError(
                f"reduction must be 'none', 'mean' or 'sum', got {reduction!r}")

        a = pm.temp(name=f"temp_{y.name}", shape=z.shape)

        i = pm.index(0, z.shape[1] - 1, name="i")
        indices = [
            pm.index(0, s - 1, name=f"{z.name}[{dim}]")
            for dim, s in enumerate(z.shape)
        ]
        indices[1] = i
        indices = tuple(indices)
        # log-softmax over axis 1, shifted by the row max for stability.
        maxes = pm.max([i], z[indices], name="maxes")
        exp_val = pm.exp((z[indices] - maxes[indices[0]]))
        lse_stable = pm.log(pm.sum([i], exp_val[indices], name="testing_lse"),
                            name="lse_stable")
        a[indices] = z[indices] - maxes[indices[0]] - lse_stable[indices[0]]
        gathered = pm.gather_elements(a,
                                      pm.reshape(y, (a.shape[0], 1),
                                                 name="reshaped1"),
                                      axis=1,
                                      shape=(y.shape[0], ),
                                      name="gathered_elem")
        # Negate to turn log-probabilities into per-row losses.
        reshaped = pm.reshape(-1 * gathered, (y.shape[0], ),
                              name="other_reshape")
        idx = (pm.index(0, a.shape[0] - 1), )
        if reduction == "none":
            loss.set_shape(reshaped.shape)
            loss[idx] = reshaped[idx]
        elif reduction == "mean":
            loss.set_shape((1, ))
            # Number of elements being averaged.
            denom = 1
            for s in reshaped.shape:
                denom = denom * s
            loss[0] = pm.sum([idx[0]], reshaped[idx],
                             name="test_sum_name") / denom
        else:  # reduction == "sum"
            loss.set_shape((1, ))
            loss[0] = pm.sum([idx[0]], reshaped[idx])
Example #5
0
    def define_graph(self, data, out, kh, kw, stride=(1, 1), pad=(0, 0)):
        """Max pooling over the last two dimensions of ``data``.

        ``data`` is copied into a temporary that is zero-initialized and
        enlarged by ``pad`` on both sides of its last two dimensions. Each
        output element is then the maximum over a ``kh`` x ``kw`` window
        whose origin advances by ``stride``.

        Args:
            data: input node. With more than 3 dimensions the first two are
                treated as (batch, channel); otherwise the first is treated
                as channel.
            out: output node; its shape is set here.
            kh: window height.
            kw: window width.
            stride: (vertical, horizontal) window step.
            pad: (vertical, horizontal) padding added to each side.
        """
        in_h = data.shape[-2]
        in_w = data.shape[-1]
        pad_h, pad_w = pad[0], pad[1]
        stride_h, stride_w = stride[0], stride[1]

        out_h = (in_h + 2 * pad_h - kh) // stride_h + 1
        out_w = (in_w + 2 * pad_w - kw) // stride_w + 1

        # Output-position and in-window indices.
        oy = pm.index(0, out_h - 1)
        ox = pm.index(0, out_w - 1)
        ky = pm.index(0, kh - 1)
        kx = pm.index(0, kw - 1)

        padded_h = in_h + pad_h * 2
        padded_w = in_w + pad_w * 2
        all_py = pm.index(0, padded_h - 1, name="ihp")
        all_px = pm.index(0, padded_w - 1, name="iwp")
        iy = pm.index(0, in_h - 1, name="iy")
        ix = pm.index(0, in_w - 1, name="ix")

        if len(data.shape) <= 3:
            ch = pm.index(0, data.shape[0] - 1, name="c")
            lead = (ch, )
        else:
            bt = pm.index(0, data.shape[0] - 1, name="b")
            ch = pm.index(0, data.shape[1] - 1, name="c")
            lead = (bt, ch)
        lead_shape = tuple(data.shape[d] for d in range(len(lead)))
        out.set_shape(lead_shape + (out_h, out_w))

        padded = pm.temp(shape=lead_shape + (padded_h, padded_w))
        padded[lead, all_py, all_px] = 0
        padded[lead, iy + pad_h, ix + pad_w] = data[lead, iy, ix]
        window = padded[lead, stride_h * oy + ky, stride_w * ox + kx]
        out[lead, oy, ox] = pm.max([ky, kx], window)