Example #1
0
    def __init__(self,
                 n_in,
                 n_out,
                 kernel_size,
                 period,
                 key_pick='hash',
                 learn_key=False):
        """Complex-valued conv layer with a bank of time-indexed keys.

        Args:
            n_in: number of input channels.
            n_out: number of output channels.
            kernel_size: int or pair of ints giving the spatial kernel.
            period: number of distinct keys stored in ``self.o``.
            key_pick: name of the key-selection strategy used at forward time.
            learn_key: when True, the key tensor ``self.o`` is trainable.
        """
        super(HashConvSpCh, self).__init__()
        self.key_pick = key_pick
        self.period = period

        # Normalize a scalar kernel size to (k, k) before computing fan-in.
        if isinstance(kernel_size, int):
            kernel_size = [kernel_size, kernel_size]
        fan_in = n_in
        for k in kernel_size:
            fan_in *= k
        stdv = 1. / math.sqrt(fan_in)

        # Weights and keys are drawn in polar form (magnitude, phase) and
        # converted to a stacked (real, imag) pair via from_polar.
        w_r = torch.empty(n_out, n_in, kernel_size[0],
                          kernel_size[1]).uniform_(-stdv, stdv)
        w_phi = torch.Tensor(*w_r.shape).uniform_(-np.pi, np.pi)
        o_r = torch.ones(*((period, 1) + w_r.shape[1:]))
        o_phi = torch.Tensor(*o_r.shape).uniform_(-np.pi, np.pi)

        self.w = nn.Parameter(torch.stack(from_polar(w_r, w_phi)))
        # Fix: was ``torch.torch.empty`` — it worked only because the torch
        # module re-exports itself; use the plain ``torch.empty`` spelling.
        self.bias = nn.Parameter(torch.empty(n_out, ).uniform_(-stdv, stdv))

        o = torch.stack(from_polar(o_r, o_phi))
        self.o = nn.Parameter(o, requires_grad=learn_key)
Example #2
0
    def __init__(self, n_in, n_out, period, key_pick='hash', learn_key=True):
        """Complex linear transform with a bank of ``period`` keys.

        ``n_in``/``n_out`` size the weight matrix; ``learn_key=False``
        freezes the key tensor ``self.o``.
        """
        super(HashTransform, self).__init__()
        self.key_pick = key_pick

        # Magnitudes: Xavier-normal for the weight, unit length for the keys.
        weight_mag = nn.init.xavier_normal_(torch.empty(n_in, n_out))
        key_mag = torch.ones(period, n_in)
        # Phases: uniform over the full circle for both.
        weight_phase = torch.Tensor(n_in, n_out).uniform_(-np.pi, np.pi)
        key_phase = torch.Tensor(period, n_in).uniform_(-np.pi, np.pi)

        # Stored as stacked (real, imag) tensors.
        self.w = nn.Parameter(torch.stack(from_polar(weight_mag, weight_phase)))
        self.o = nn.Parameter(torch.stack(from_polar(key_mag, key_phase)))
        self.o.requires_grad = learn_key
Example #3
0
    def __init__(self, n_in, n_out, period, key_pick='hash', learn_key=True):
        """Complex linear layer whose keys carry fixed Fourier phases.

        Input dimension ``i`` gets the constant phase 2*pi*(i+1)/period,
        so each dimension rotates at its own frequency across time steps.
        ``learn_key=False`` freezes the key tensor ``self.o``.
        """
        super(FourierLinear, self).__init__()
        self.key_pick = key_pick

        weight_mag = nn.init.xavier_normal_(torch.empty(n_in, n_out))
        weight_phase = torch.Tensor(n_in, n_out).uniform_(-np.pi, np.pi)

        # Unit-magnitude keys with deterministic (not random) phases: every
        # column holds a single frequency, repeated along the period axis.
        key_mag = torch.ones(period, n_in)
        key_phase = torch.Tensor(period, n_in)
        for dim in range(n_in):
            key_phase[:, dim] = (2 * np.pi * (dim + 1)) / period

        self.w = nn.Parameter(torch.stack(from_polar(weight_mag, weight_phase)))
        self.bias = nn.Parameter(torch.zeros(n_out))
        self.o = nn.Parameter(torch.stack(from_polar(key_mag, key_phase)))
        if not learn_key:
            self.o.requires_grad = False
Example #4
0
def pick_key(pick_method, keys, time):
    """Select one complex key from ``keys`` for the given time step.

    ``keys`` appears to be a stacked (real, imag) tensor whose dim 1 indexes
    the ``period`` distinct keys (see the ``keys[:, t]`` branches) — TODO
    confirm against the modules that build ``self.o``.  Returns a single
    stacked complex key.

    Raises:
        NotImplementedError: if ``pick_method`` is not a known strategy.
    """
    if pick_method == 'hash':
        # Plain modular lookup: one fixed key per time step.
        net_time = int(time) % keys.shape[1]
        o = keys[:, net_time]
    elif pick_method == 'local_mix':
        # Average the phases of the previous/current/next keys; magnitude is
        # the mean over the stacked axis.
        # NOTE(review): other branches index keys along dim 1, but here
        # ``o_phi`` is indexed along dim 0 — verify to_polar's output layout.
        center_time = int(time)
        b_time = (center_time - 1) % keys.shape[1]
        m_time = center_time % keys.shape[1]
        e_time = (center_time + 1) % keys.shape[1]
        o_r, o_phi = to_polar(keys)
        o = torch.stack(
            from_polar(o_r.mean(0),
                       (o_phi[b_time] + o_phi[m_time] + o_phi[e_time]) / 3))
    elif pick_method == 'local_mult':
        # Complex product of the previous, current and next keys.
        center_time = int(time)
        b_time = (center_time - 1) % keys.shape[1]
        m_time = center_time % keys.shape[1]
        e_time = (center_time + 1) % keys.shape[1]
        o = cmul(cmul(keys[:, b_time], keys[:, m_time]), keys[:, e_time])
    elif pick_method == 'temp_mix':
        # Softmax-weighted blend over all keys; the one-hot logit sharpens as
        # ``time`` grows (temperature ~ time/1000).
        # NOTE(review): ``.cuda()`` hard-codes GPU execution here.
        net_time = torch.tensor([int(time) % keys.shape[1]])
        key_logit = torch.zeros(1, keys.shape[1]).scatter_(
            1, net_time.unsqueeze(1), 1. / (time / 1000. + 1e-5))
        key_prob = F.softmax(key_logit, 1).cuda()
        o_r, o_phi = to_polar(keys)
        o_r_pick = torch.matmul(key_prob.squeeze(), o_r)
        o_phi_pick = torch.matmul(key_prob.squeeze(), o_phi)
        o = torch.stack(from_polar(o_r_pick, o_phi_pick))
    elif pick_method == 'random':
        # Ignore ``time``: uniform-random key each call (non-deterministic).
        net_time = np.random.randint(keys.shape[1])
        o = keys[:, net_time]
    elif pick_method == 'cosine':
        # Cosine interpolation between the first two keys only; ``mix``
        # sweeps from 1 to 0 as time advances over one period.
        # NOTE(review): ``.cuda()`` hard-codes GPU execution here too.
        omega = (int(time) % keys.shape[1]) * np.pi / keys.shape[1]
        mix = (torch.cos(torch.tensor(omega).cuda()) + 1) / 2.
        o = (mix * keys[:, 0]) + ((1. - mix) * keys[:, 1])
    elif pick_method == 'triangle_multiply':
        # Cumulative phase sum of keys 0..t (phase of a running product).
        net_time = int(time) % keys.shape[1]
        o_r, o_phi = to_polar(keys)
        o = torch.stack(from_polar(o_r.mean(0), o_phi[:net_time + 1].sum(0)))
    elif pick_method == 'one_power':
        # Single base key raised to the t-th power: phase scaled by t.
        # Note ``time`` is not cast to int here, unlike the other branches.
        net_time = time % keys.shape[1]
        o_r, o_phi = to_polar(keys)
        o = torch.stack(from_polar(o_r[0], net_time * o_phi[0]))
    else:
        raise NotImplementedError

    return o
Example #5
0
    def __init__(self,
                 n_chin,
                 n_chout,
                 kernel_size,
                 period,
                 key_pick='hash',
                 learn_key=True):
        """Complex conv layer with a per-output-channel key bank.

        Args:
            n_chin: number of input channels.
            n_chout: number of output channels.
            kernel_size: iterable of spatial kernel dims.
            period: number of distinct keys stored in ``self.o``.
            key_pick: name of the key-selection strategy used at forward time.
            learn_key: when True, the key tensor ``self.o`` is trainable.
        """
        super(HashConv, self).__init__()
        self.key_pick = key_pick

        # Fan-in = n_chin * prod(kernel_size), used for uniform init bounds.
        n = n_chin
        for k in kernel_size:
            n *= k
        stdv = 1. / math.sqrt(n)

        # Polar-form init, converted to stacked (real, imag) via from_polar.
        w_r = torch.Tensor(n_chout, n_chin, *kernel_size).uniform_(-stdv, stdv)
        w_phi = torch.Tensor(n_chout, n_chin,
                             *kernel_size).uniform_(-np.pi, np.pi)
        o_r = torch.ones(period, n_chout)
        o_phi = torch.Tensor(period, n_chout).uniform_(-np.pi, np.pi)

        self.w = nn.Parameter(torch.stack(from_polar(w_r, w_phi)))
        self.bias = nn.Parameter(torch.Tensor(n_chout).uniform_(-stdv, stdv))
        self.o = nn.Parameter(torch.stack(from_polar(o_r, o_phi)))
        # Fix: ``learn_key`` was accepted but silently ignored; freeze the
        # key like the sibling classes (HashTransform, FourierLinear) do.
        if not learn_key:
            self.o.requires_grad = False
Example #6
0
    def __init__(self, n_chin, n_chout, kernel_size, padding=0, init_mult=1.):
        """Complex convolution with uniformly initialized polar weights.

        ``init_mult`` scales the weight magnitudes at initialization;
        ``padding`` is stored for use at forward time.
        """
        super(ComplexConv, self).__init__()

        # Fan-in = n_chin * prod(kernel_size) drives the init bound.
        fan_in = n_chin
        for dim in kernel_size:
            fan_in *= dim
        bound = 1. / math.sqrt(fan_in)

        shape = (n_chout, n_chin) + tuple(kernel_size)
        magnitude = torch.Tensor(*shape).uniform_(-bound, bound)
        phase = torch.Tensor(*shape).uniform_(-np.pi, np.pi)

        # Weight is kept as a stacked (real, imag) pair; bias is real-valued.
        self.w = nn.Parameter(torch.stack(from_polar(init_mult * magnitude,
                                                     phase)))
        self.bias = nn.Parameter(torch.Tensor(n_chout).uniform_(-bound, bound))
        self.padding = padding
Example #7
0
    def forward(self, x_a, x_b, time):
        """Rotate the complex input by the time-powered key, then apply the
        complex weight matrix.

        Returns an (out_real, out_imag) pair; the bias is real-valued and is
        added to the real part only.
        """
        step = time % self.o.shape[1]
        key_mag, key_phase = to_polar(self.o)
        # Raising the base key to the ``step``-th power scales its phase.
        key = torch.stack(from_polar(key_mag[0], step * key_phase[0]))
        key_re = key[0].unsqueeze(0)
        key_im = key[1].unsqueeze(0)

        # (x_a + i*x_b) * (key_re + i*key_im)
        mixed_re = x_a * key_re - x_b * key_im
        mixed_im = x_a * key_im + x_b * key_re

        # Complex matrix product with self.w stored as (real, imag).
        w_re = self.w[0]
        w_im = self.w[1]
        out_re = torch.mm(mixed_re, w_re) - torch.mm(mixed_im, w_im)
        out_im = torch.mm(mixed_im, w_re) + torch.mm(mixed_re, w_im)
        return out_re + self.bias, out_im
Example #8
0
 def __init__(self, n_in, n_out, period):
     """Complex linear layer holding a separate weight per time step.

     Stores ``period`` complex (n_in, n_out) weight matrices as a stacked
     (real, imag) parameter, plus one shared real bias of size ``n_out``.
     """
     super(SwapComplexLinear, self).__init__()
     magnitude = nn.init.xavier_normal_(torch.empty(period, n_in, n_out))
     phase = torch.Tensor(period, n_in, n_out).uniform_(-np.pi, np.pi)
     self.w = nn.Parameter(torch.stack(from_polar(magnitude, phase)))
     self.bias = nn.Parameter(torch.zeros(n_out))