Example #1
    def forward(self, x):
        # Reorder the five axes and merge the first two post-permute axes
        # into a single batch dimension.
        x = x.permute(0, 2, 4, 1, 3).contiguous()
        sr, sb, sc, ss, si = tuple(x.size())
        x = x.view(sr * sb, sc, ss, si)

        # Buffer holding the full rollout of SIZE blended states.
        result = Variable(cast(np.zeros([sr * sb, 8, SIZE, INPUT + OUTPUT])))

        # Initial state: observed window concatenated with its guessed
        # complement along the feature axis.
        init = x[:, :, 0:WINDOW, :]
        guess = self.guess(init.contiguous())
        state = th.cat((init, guess), dim=3)
        result[:, :, 0::SIZE, :] = state[:, :, 0::SIZE, :]

        for i in range(1, SIZE, 1):
            # Progress trace.
            print('-----------------------------')
            print('idx:', i)
            sys.stdout.flush()

            # Evolve the current state one step and split off its input part.
            statenx = self.evolve(state, w=1)
            input = statenx[:, :, :, :INPUT]
            if i < SIZE - WINDOW:
                # Observations still available: slide the window and re-guess.
                init = x[:, :, i:WINDOW + i, :]
                guess = self.guess(init.contiguous())
                stategs = th.cat((init, guess), dim=3)
            else:
                # Past the observed range: guess from the evolved inputs.
                guess = self.guess(input.contiguous())
                stategs = th.cat((input, guess), dim=3)

            # Blend the evolved and guessed states with a learned ratio.
            ratio = self.ratio(th.cat([statenx, stategs], dim=1))
            state = ratio * statenx + (1 - ratio) * stategs

            # Store the first frame of the blended state at slot i of the rollout.
            result[:, :, i::SIZE, :] = state[:, :, 0::SIZE, :]

        return result
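
The permute/view prologue above (repeated in Example #4) only reorders the five axes and merges the first two post-permute axes into one batch axis. A shape-only sketch, with arbitrary, illustrative dimension sizes:

import torch as th

x = th.zeros(2, 5, 3, 7, 4)                # illustrative 5-D input
x = x.permute(0, 2, 4, 1, 3).contiguous()  # now (2, 3, 4, 5, 7)
sr, sb, sc, ss, si = tuple(x.size())
x = x.view(sr * sb, sc, ss, si)            # merge first two axes: (6, 4, 5, 7)
print(x.size())                            # torch.Size([6, 4, 5, 7])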
Example #2
    def __init__(self):
        super(Evolve, self).__init__()
        w = SIZE
        c = 8
        d = c * w

        # Fully connected relation graph over BODYCOUNT nodes (no self-loops):
        # one-hot receiver/sender encodings for every ordered pair.
        off_diag = np.ones([BODYCOUNT, BODYCOUNT]) - np.eye(BODYCOUNT)
        self.rel_rec = Variable(
            cast(
                np.array(encode_onehot(np.where(off_diag)[1]),
                         dtype=np.float32)))
        self.rel_send = Variable(
            cast(
                np.array(encode_onehot(np.where(off_diag)[0]),
                         dtype=np.float32)))

        self.encoder = MLPEncoder(d, 2048, 1)
        self.decoder = MLPDecoder(c, 1, 2048, 2048, 2048)
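
Both Evolve constructors assume an encode_onehot helper that turns the receiver/sender indices of the fully connected relation graph into one-hot rows. The helper itself is not shown in these examples; a minimal sketch, assuming the common NRI-style implementation:

import numpy as np


def encode_onehot(labels):
    # One identity-matrix row per distinct label, looked up per entry.
    classes = sorted(set(labels))
    classes_dict = {c: np.identity(len(classes))[i, :]
                    for i, c in enumerate(classes)}
    return np.array(list(map(classes_dict.get, labels)), dtype=np.int32)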
Example #3
    def __init__(self):
        super(Evolve, self).__init__()
        n = INPUT + OUTPUT
        w = WINDOW
        c = 8
        d = c * w

        # Same construction as Example #2, but the relation graph spans the
        # INPUT + OUTPUT nodes instead of BODYCOUNT bodies.
        off_diag = np.ones([n, n]) - np.eye(n)
        self.rel_rec = Variable(
            cast(
                np.array(encode_onehot(np.where(off_diag)[1]),
                         dtype=np.float32)))
        self.rel_send = Variable(
            cast(
                np.array(encode_onehot(np.where(off_diag)[0]),
                         dtype=np.float32)))

        self.encoder = MLPEncoder(d, 2048, 1)
        self.decoder = MLPDecoder(c, 1, 2048, 2048, 2048)
Example #4
    def forward(self, x):
        # Same permute/flatten prologue as in Example #1.
        x = x.permute(0, 2, 4, 1, 3).contiguous()
        sr, sb, sc, ss, si = tuple(x.size())
        state = x.view(sr * sb, sc, ss, si)
        result = Variable(cast(np.zeros([sr * sb, 8, 3 * SIZE, BODYCOUNT])))
        # Evolve for 4 * SIZE steps; discard the first SIZE steps as burn-in
        # and record the first frame of each of the remaining 3 * SIZE states.
        for i in range(4 * SIZE):
            state = self.evolve(state, w=1)
            if i >= SIZE:
                result[:, :, i - SIZE, :] = state[:, :, 0, :]

        return result
Example #5
def divergence_th(xs, ys):
    # Per-position divergence between two batches of 3-D position fields,
    # measured relative to the reference positions in the module-level `sun`.
    sz = xs.size()
    b = sz[0]
    v = sz[2] * sz[3]
    s = Variable(cast(sun), requires_grad=False)
    # Tile the reference positions along the batch dimension.
    s = th.cat([s for _ in range(b // s.size()[0])], dim=0)

    # Flatten the spatial grid and shift both fields so `sun` is the origin.
    xs = xs.permute(0, 2, 3, 1).contiguous().view(b, v, 3)
    ys = ys.permute(0, 2, 3, 1).contiguous().view(b, v, 3)
    xs = xs - s
    ys = ys - s
    rx = th.norm(xs, p=2, dim=-1, keepdim=True)
    ry = th.norm(ys, p=2, dim=-1, keepdim=True)
    ux = xs / rx
    uy = ys / ry
    # Angular term: 1 - cosine of the angle between the two unit vectors.
    da = 1 - th.bmm(ux.view(b * v, 1, 3), uy.view(b * v, 3, 1)).view(b, v, 1)
    # Radial term: squared difference of the distances to the reference.
    dr = ((rx - ry) * (rx - ry)).view(b, v, 1)
    return th.sum(da + dr, dim=2)
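
A minimal usage sketch for divergence_th; cast and sun are stand-ins defined only for this example, since their real definitions and shapes are not shown in the source:

import numpy as np
import torch as th
from torch.autograd import Variable


def cast(a):
    # Stand-in for the module's cast(): numpy array -> float32 torch tensor.
    return th.from_numpy(np.asarray(a, dtype=np.float32))


# Assumed reference positions, shaped (1, H * W, 3) so they tile over the batch.
sun = np.zeros([1, 16 * 16, 3], dtype=np.float32)
sun[..., 0] = 1.0

xs = Variable(cast(np.random.randn(4, 3, 16, 16)))
ys = Variable(cast(np.random.randn(4, 3, 16, 16)))
print(divergence_th(xs, ys).size())  # torch.Size([4, 256])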