Example 1
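All of the snippets below assume the usual PyTorch imports; several also use numpy, math, functools/operator, and the legacy torch.autograd.Variable API (a few additionally need e.g. torch.distributions.MultivariateNormal or torch.nn.modules.utils._pair). A minimal common preamble, assumed here since the original listing strips imports, would be:

import math
import operator
from functools import reduce

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.autograd import Variable
from torch.nn import Parameter, init
from torch.nn.modules.module import Module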
class InstanceLayerNormalization(nn.Module):
    def __init__(self, in_ch):
        super(InstanceLayerNormalization, self).__init__()

        self.ro = Parameter(torch.Tensor(1, in_ch, 1, 1))
        self.gamma = Parameter(torch.Tensor(1, in_ch, 1, 1))
        self.beta = Parameter(torch.Tensor(1, in_ch, 1, 1))
        self.ro.data.fill_(0.0)
        self.gamma.data.fill_(1.0)
        self.beta.data.fill_(0.0)

    def forward(self, x):
        # instance statistics: per-sample, per-channel over the spatial dims
        i_mean = torch.mean(x, dim=[2, 3], keepdim=True)
        i_var = torch.var(x, dim=[2, 3], keepdim=True)
        i_std = torch.sqrt(i_var + 1e-5)
        i_h = (x - i_mean) / i_std

        # layer statistics: per-sample over channels and spatial dims
        l_mean = torch.mean(x, dim=[1, 2, 3], keepdim=True)
        l_var = torch.var(x, dim=[1, 2, 3], keepdim=True)
        l_std = torch.sqrt(l_var + 1e-5)
        l_h = (x - l_mean) / l_std

        h = self.ro.expand(x.size(0), -1, -1, -1) * i_h + (1 - self.ro.expand(x.size(0), -1, -1, -1)) * l_h
        h = h * self.gamma.expand(x.size(0), -1, -1, -1) + self.beta.expand(x.size(0), -1, -1, -1)

        return h
Example 2
class adaILN(nn.Module):
    def __init__(self, num_features, eps=1e-5):
        super(adaILN, self).__init__()
        # num_features = depth (number of channels) of the feature maps
        self.eps = eps
        self.rho = Parameter(torch.Tensor(1, num_features, 1, 1))
        self.rho.data.fill_(0.9)

    def forward(self, input, gamma, beta):
        in_mean = torch.mean(input, dim=[2, 3], keepdim=True)
        in_var = torch.var(input, dim=[2, 3], keepdim=True)
        out_in = (input - in_mean) / torch.sqrt(in_var + self.eps)  # instance normalization
        ln_mean = torch.mean(input, dim=[1, 2, 3], keepdim=True)
        ln_var = torch.var(input, dim=[1, 2, 3], keepdim=True)
        out_ln = (input - ln_mean) / torch.sqrt(ln_var + self.eps)  # layer normalization
        out = self.rho.expand(input.shape[0], -1, -1, -1) * out_in + (
            1 - self.rho.expand(input.shape[0], -1, -1, -1)) * out_ln
        out = out * gamma.unsqueeze(2).unsqueeze(3) + beta.unsqueeze(
            2).unsqueeze(3)

        return out
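A minimal usage sketch under assumed shapes (in U-GAT-IT, gamma and beta are per-channel vectors predicted by fully connected layers from the attention features):

norm = adaILN(num_features=256)
feat = torch.randn(4, 256, 64, 64)   # (N, C, H, W)
gamma = torch.randn(4, 256)          # per-sample, per-channel scale
beta = torch.randn(4, 256)           # per-sample, per-channel shift
out = norm(feat, gamma, beta)        # same shape as feat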
Example 3
class adaILN(nn.Module):
    def __init__(self, num_features, eps=1e-5):
        super(adaILN, self).__init__()
        self.eps = eps
        # adaILN's rho parameter, used to dynamically adjust the mix of LN and IN
        self.rho = Parameter(torch.Tensor(1, num_features, 1, 1))
        self.rho.data.fill_(0.9)

    def forward(self, input, gamma, beta):
        # first compute both normalizations
        in_mean = torch.mean(input, dim=[2, 3], keepdim=True)
        in_var = torch.var(input, dim=[2, 3], keepdim=True)
        out_in = (input - in_mean) / torch.sqrt(in_var + self.eps)
        ln_mean = torch.mean(input, dim=[1, 2, 3], keepdim=True)
        ln_var = torch.var(input, dim=[1, 2, 3], keepdim=True)
        out_ln = (input - ln_mean) / torch.sqrt(ln_var + self.eps)
        # blend the two normalizations (IN and LN) using the learned rho
        out = self.rho.expand(input.shape[0], -1, -1, -1) * out_in + (
            1 - self.rho.expand(input.shape[0], -1, -1, -1)) * out_ln
        # broadcast gamma/beta over the spatial dims
        out = out * gamma.unsqueeze(2).unsqueeze(3) + beta.unsqueeze(2).unsqueeze(3)

        return out
Example 4
class SoftAdaLIN(nn.Module):
    def __init__(self, num_features, eps=1e-5):
        super(SoftAdaLIN, self).__init__()
        self.norm = adaLIN(num_features, eps)

        self.w_gamma = Parameter(torch.zeros(1, num_features))
        self.w_beta = Parameter(torch.zeros(1, num_features))

        self.c_gamma = nn.Sequential(nn.Linear(num_features, num_features),
                                     nn.ReLU(True),
                                     nn.Linear(num_features, num_features))
        self.c_beta = nn.Sequential(nn.Linear(num_features, num_features),
                                    nn.ReLU(True),
                                    nn.Linear(num_features, num_features))
        self.s_gamma = nn.Linear(num_features, num_features)
        self.s_beta = nn.Linear(num_features, num_features)

    def forward(self, x, content_features, style_features):
        content_gamma, content_beta = self.c_gamma(content_features), self.c_beta(content_features)
        style_gamma, style_beta = self.s_gamma(style_features), self.s_beta(style_features)

        w_gamma, w_beta = self.w_gamma.expand(x.shape[0], -1), self.w_beta.expand(x.shape[0], -1)
        soft_gamma = (1. - w_gamma) * style_gamma + w_gamma * content_gamma
        soft_beta = (1. - w_beta) * style_beta + w_beta * content_beta

        out = self.norm(x, soft_gamma, soft_beta)
        return out
Example 5
class ILN(nn.Module):
    def __init__(self, num_features, eps=1e-5):
        super(ILN, self).__init__()
        self.eps = eps
        self.rho = Parameter(torch.Tensor(1, num_features, 1, 1))
        self.gamma = Parameter(torch.Tensor(1, num_features, 1, 1))
        self.beta = Parameter(torch.Tensor(1, num_features, 1, 1))
        self.rho.data.fill_(0.0)
        self.gamma.data.fill_(1.0)
        self.beta.data.fill_(0.0)

    def forward(self, input):
        in_mean = torch.mean(input, dim=[2, 3], keepdim=True)
        in_var = torch.var(input, dim=[2, 3], keepdim=True)
        out_in = (input - in_mean) / torch.sqrt(in_var + self.eps)
        ln_mean = torch.mean(input, dim=[1, 2, 3], keepdim=True)
        ln_var = torch.var(input, dim=[1, 2, 3], keepdim=True)
        out_ln = (input - ln_mean) / torch.sqrt(ln_var + self.eps)
        out = self.rho.expand(input.shape[0], -1, -1, -1) * out_in + (
            1 - self.rho.expand(input.shape[0], -1, -1, -1)) * out_ln
        out = out * self.gamma.expand(input.shape[0], -1, -1,
                                      -1) + self.beta.expand(
                                          input.shape[0], -1, -1, -1)

        return out
Example 6
class ILN(nn.Module):
    """ILN (Instance Layer Normalization)"""
    def __init__(self, num_features, eps=1e-5):
        super(ILN, self).__init__()

        self.eps = eps
        self.rho = Parameter(torch.Tensor(1, num_features, 1, 1))
        self.gamma = Parameter(torch.Tensor(1, num_features, 1, 1))
        self.beta = Parameter(torch.Tensor(1, num_features, 1, 1))

        self.rho.data.fill_(0.0)
        self.gamma.data.fill_(1.0)
        self.beta.data.fill_(0.0)

    def forward(self, x):
        in_mean = torch.mean(x, dim=[2, 3], keepdim=True)
        in_var = torch.var(x, dim=[2, 3], keepdim=True)
        out_in = (x - in_mean) / torch.sqrt(in_var + self.eps)

        ln_mean = torch.mean(x, dim=[1, 2, 3], keepdim=True)
        ln_var = torch.var(x, dim=[1, 2, 3], keepdim=True)
        out_ln = (x - ln_mean) / torch.sqrt(ln_var + self.eps)

        out = out_in * self.rho.expand(x.shape[0], -1, -1, -1) + out_ln * (
            1 - self.rho.expand(x.shape[0], -1, -1, -1))
        out = out * self.gamma.expand(
            x.shape[0], -1, -1, -1) + self.beta.expand(x.shape[0], -1, -1, -1)
        return out
Example 7
class adaILN(nn.Module):
    def __init__(self, num_features, eps=1e-5):
        super(adaILN, self).__init__()
        self.eps = eps
        self.rho = Parameter(torch.Tensor(1, num_features, 1, 1))
        self.rho.data.fill_(0.9)

    def forward(self, input, gamma, beta):
        in_mean = torch.mean(input, dim=[2, 3], keepdim=True)
        in_var = torch.var(input, dim=[2, 3], keepdim=True)
        out_in = (input - in_mean) / torch.sqrt(in_var + self.eps)
        ln_mean = torch.mean(input, dim=[1, 2, 3], keepdim=True)
        ln_var = torch.var(input, dim=[1, 2, 3], keepdim=True)
        out_ln = (input - ln_mean) / torch.sqrt(ln_var + self.eps)
        out = self.rho.expand(input.shape[0], -1, -1, -1) * out_in \
            + (1-self.rho.expand(input.shape[0], -1, -1, -1)) * out_ln
        out = out * gamma.unsqueeze(2).unsqueeze(3) + beta.unsqueeze(
            2).unsqueeze(3)

        return out
Example 8
class LSTMNet(nn.Module):
    """
    lstm模型
    """
    def __init__(self, input_size, hidden_size, output_size, num_layers,
                 future_len):
        super(LSTMNet, self).__init__()

        self.lstm = nn.LSTM(input_size, hidden_size, num_layers)
        self.fc = nn.Linear(hidden_size, output_size)

        # the LSTM's initial state is also learned
        self.h0 = Parameter(torch.randn(num_layers, 1, hidden_size),
                            requires_grad=True)
        self.c0 = Parameter(torch.randn(num_layers, 1, hidden_size),
                            requires_grad=True)

        self.hidden_size = hidden_size
        self.output_size = output_size
        self.num_layers = num_layers
        # number of future steps to predict consecutively
        self.future_len = future_len

    def forward(self, x, future_time):
        """

        :param x:  (seq_len, batch, input_size)
        :param future_time: (future_len-1, batch, 24) 未来要预测的时间长度
        :return: y (future_len, batch, output_size)
        """
        batch_size = x.size(1)

        hidden = (self.h0.expand(self.num_layers, batch_size,
                                 self.hidden_size).contiguous(),
                  self.c0.expand(self.num_layers, batch_size,
                                 self.hidden_size).contiguous())
        output, hidden = self.lstm(x, hidden)

        ys = []
        output = self.fc(output[-1, :, :].view(-1, self.hidden_size)).view(
            1, -1, self.output_size)
        ys.append(output)

        y = output
        for i in range(self.future_len - 1):
            now_time = future_time[i:i + 1, :, :]
            x = torch.cat([y, now_time], dim=2)

            y, hidden = self.lstm(x, hidden)
            y = self.fc(y.view(-1, self.hidden_size)).view(
                1, -1,
                self.output_size)  # (seq_len x batch_size x output_size)
            ys.append(y)

        ys = torch.cat(ys)
        return ys
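A minimal smoke test under assumed sizes. The 24-dim time features are hypothetical here; the only real constraint is that output_size + 24 equals input_size, so that torch.cat([y, now_time], dim=2) can be fed back into the LSTM:

net = LSTMNet(input_size=25, hidden_size=32, output_size=1, num_layers=2, future_len=5)
x = torch.randn(10, 8, 25)           # (seq_len, batch, input_size)
future_time = torch.randn(4, 8, 24)  # (future_len-1, batch, 24)
ys = net(x, future_time)             # (future_len, batch, output_size) = (5, 8, 1)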
Example 9
class LikelihoodLoss(nn.Module):
    def __init__(self,
                 classweight=1,
                 cls=range(10),
                 ndim=10,
                 sigmaform='identical'):
        super(LikelihoodLoss, self).__init__()
        nclass = len(cls)
        self.cls = cls
        if np.isscalar(classweight):
            self.classweight = np.ones(nclass) * classweight
        else:
            self.classweight = classweight
        self.nclass = nclass
        self.sigmaform = sigmaform
        self.ndim = ndim
        self.mu = Parameter(torch.randn(nclass, ndim))
        if sigmaform == 'identical':
            # isotropic covariance: one scalar sigma per class
            self.sigma = Parameter(torch.rand(nclass))
            self.sig = self.sigma
        elif sigmaform == 'diagonal':
            # diagonal covariance: one sigma per class and per dimension
            self.sigma = Parameter(torch.rand(nclass, ndim))
            self.sig = self.sigma
        elif sigmaform == 'share':
            self.sigma = Parameter(torch.rand(1))
            self.sig = self.sigma.expand(nclass)
        elif sigmaform == 'sharediag':
            self.sigma = Parameter(torch.rand(1, ndim))
            self.sig = self.sigma.expand(nclass, ndim)

    def forward(self, input, target):

        if self.sigmaform == 'share':
            self.sig = self.sigma.expand(self.nclass)
        elif self.sigmaform == 'sharediag':
            self.sig = self.sigma.expand(self.nclass, self.ndim)
        else:
            self.sig = self.sigma
        loss = 0

        for idx, cls in enumerate(self.cls):
            if (target == cls).any():
                loss = loss-torch.tensor(self.classweight[idx]).cuda() * \
                    gauss_logpdf(input[target==cls], self.mu[idx], self.sig[idx]).mean()

        return loss
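gauss_logpdf is not defined in this snippet; a plausible stand-in for the scalar- and diagonal-sigma cases (an assumption, not the original helper) is:

def gauss_logpdf(x, mu, sig):
    # log N(x; mu, sig^2), evaluated per sample and summed over dimensions
    # x: (n, ndim); mu: (ndim,); sig: scalar tensor or (ndim,)
    var = sig ** 2
    return (-0.5 * ((x - mu) ** 2 / var + torch.log(2 * math.pi * var))).sum(dim=1)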
Example 10
class BatchGraphSUM(Module):

    def __init__(self, in_features, out_features, bias=True):
        super(BatchGraphSUM, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.Tensor(in_features, out_features))
        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
            init.constant_(self.bias, 0)
        else:
            self.register_parameter('bias', None)
        init.xavier_uniform_(self.weight)

    def forward(self, x, lap, layer, model_type):
        expand_weight = self.weight.expand(x.shape[0], -1, -1)
        support = torch.bmm(x, expand_weight)
        if layer < 2:
            output = torch.bmm(lap, support)
        else:
            output = support
        if self.bias is not None:
            return output + self.bias
        else:
            return output
Example 11
class Net(nn.Module):
    def __init__(self, X, Y, hidden_layer_sizes):
        super(Net, self).__init__()

        # Initialize linear layer with least squares solution
        X_ = np.hstack([X, np.ones((X.shape[0], 1))])
        Theta = np.linalg.solve(X_.T.dot(X_), X_.T.dot(Y))

        self.lin = nn.Linear(X.shape[1], Y.shape[1])
        W, b = self.lin.parameters()
        W.data = torch.Tensor(Theta[:-1, :].T)
        b.data = torch.Tensor(Theta[-1, :])

        # Set up non-linear network of
        # Linear -> BatchNorm -> ReLU -> Dropout layers
        layer_sizes = [X.shape[1]] + hidden_layer_sizes
        layers = reduce(operator.add, [[
            nn.Linear(a, b),
            nn.BatchNorm1d(b),
            nn.ReLU(),
            nn.Dropout(p=0.2)
        ] for a, b in zip(layer_sizes[0:-1], layer_sizes[1:])])
        layers += [nn.Linear(layer_sizes[-1], Y.shape[1])]
        self.net = nn.Sequential(*layers)
        self.sig = Parameter(torch.ones(1, Y.shape[1], device=DEVICE))

    def forward(self, x):
        return self.lin(x) + self.net(x), \
            self.sig.expand(x.size(0), self.sig.size(1))

    def set_sig(self, X, Y):
        Y_pred = self.lin(X) + self.net(X)
        var = torch.mean((Y_pred - Y)**2, 0)
        self.sig.data = torch.sqrt(var).data.unsqueeze(0)
Example 12
class Embedding(nn.Module):
    def __init__(self,
                 inputs,
                 patches,
                 hidden_size,
                 transformer,
                 num_classes=10,
                 classifier='gap'):
        super(Embedding, self).__init__()
        c, h, w = inputs
        fh, fw = patches
        gh, gw = h // fh, w // fw
        self.classifier = classifier
        seq_len = gh * gw if classifier != 'token' else gh * gw + 1
        self.conv = nn.Conv2d(c,
                              hidden_size,
                              kernel_size=(fh, fw),
                              stride=(fh, fw))
        if self.classifier == 'token':
            self.cls = Parameter(torch.zeros(1, 1, hidden_size))
        self.add_pos_embed = AddPositionEmbs(seq_len, hidden_size)

    def forward(self, x):
        x = self.conv(x)
        n, c, h, w = x.shape
        x = torch.reshape(x, [n, c, h * w])
        x = x.permute(0, 2, 1)
        if self.classifier == 'token':
            cls = self.cls.expand(n, -1, -1)
            x = torch.cat([cls, x], dim=1)
        x = self.add_pos_embed(x)
        return x
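AddPositionEmbs is external to this snippet; below is a shape-compatible stub plus a smoke test under assumed sizes (both are assumptions, and the unused transformer argument is passed as None):

class AddPositionEmbs(nn.Module):
    # hypothetical stub: learned position embeddings added to the token sequence
    def __init__(self, seq_len, hidden_size):
        super().__init__()
        self.pos = Parameter(torch.zeros(1, seq_len, hidden_size))

    def forward(self, x):
        return x + self.pos

embed = Embedding(inputs=(3, 32, 32), patches=(4, 4), hidden_size=192,
                  transformer=None, classifier='token')
imgs = torch.randn(8, 3, 32, 32)
tokens = embed(imgs)   # (8, 1 + (32 // 4) ** 2, 192) = (8, 65, 192)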
Example 13
class BatchGraphConvolution(Module):
    def __init__(self, in_features, out_features, bias=True):
        super(BatchGraphConvolution, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.Tensor(in_features, out_features))
        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
            init.constant_(self.bias, 0)
        else:
            self.register_parameter('bias', None)
        init.xavier_uniform_(self.weight)

    def forward(self, x, adj):
        expand_weight = self.weight.expand(x.shape[0], -1, -1)
        output = torch.bmm(adj, torch.bmm(x, expand_weight))
        if self.bias is not None:
            return output + self.bias
        else:
            return output

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
               + str(self.in_features) + ' -> ' \
               + str(self.out_features) + ')'
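A minimal usage sketch with batched dense adjacency matrices (assumed shapes; in practice adj is usually normalized beforehand):

layer = BatchGraphConvolution(in_features=16, out_features=8)
x = torch.randn(4, 10, 16)    # (batch, nodes, in_features)
adj = torch.rand(4, 10, 10)   # (batch, nodes, nodes)
out = layer(x, adj)           # (4, 10, 8)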
Example 14
class _DBN(nn.Module):
    _version = 2

    def __init__(self, num_features):
        super(_DBN, self).__init__()
        self.num_features = num_features
        #self.weight = Parameter(torch.Tensor(num_features))
        self.weight = Parameter(torch.Tensor(1))
        self.bias = Parameter(torch.Tensor(self.num_features))
        self.running_mean = torch.zeros(self.num_features).cuda()
        self.running_var = torch.ones(self.num_features).cuda()
        self.num_batches_tracked = torch.tensor(0, dtype=torch.long).cuda()
        self.momentum = 0.9
        self.eps = 1e-5
        self.reset_parameters()

    def reset_parameters(self):
        self.weight.data.uniform_()

    def _check_input_dim(self, input):
        raise NotImplementedError

    def forward(self, input):
        self._check_input_dim(input)
        #self.weight_expand = self.weight
        self.weight_expand = self.weight.expand(self.num_features)
        #self.weight_expand = torch.ones(self.num_features).cuda()
        if self.training:
            #sample_mean = input.transpose(0,1).contiguous().view(self.num_features,-1).mean(1)
            sample_var = input.transpose(0, 1).contiguous().view(
                self.num_features, -1).var(1)
            sample_var = sample_var.mean().unsqueeze(0).expand(
                self.num_features)
            sample_mean = torch.zeros(self.num_features).cuda()
            #sample_var = torch.ones(self.num_features).cuda()

            out_ = (input - sample_mean.unsqueeze(0).unsqueeze(2).unsqueeze(3)
                    ) / torch.sqrt(
                        sample_var.unsqueeze(0).unsqueeze(2).unsqueeze(3) +
                        self.eps)

            self.running_mean = self.momentum * self.running_mean + (
                1 - self.momentum) * sample_mean
            self.running_var = self.momentum * self.running_var + (
                1 - self.momentum) * sample_var

            out = self.weight_expand.unsqueeze(0).unsqueeze(2).unsqueeze(
                3) * out_ + self.bias.unsqueeze(0).unsqueeze(2).unsqueeze(3)
        else:
            scale = self.weight_expand.unsqueeze(0).unsqueeze(2).unsqueeze(
                3) / torch.sqrt(
                    self.running_var.unsqueeze(0).unsqueeze(2).unsqueeze(3) +
                    self.eps)
            out = input * scale + self.bias.unsqueeze(0).unsqueeze(
                2).unsqueeze(3) - self.running_mean.unsqueeze(0).unsqueeze(
                    2).unsqueeze(3) * scale

        return out
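Note that running_mean, running_var and num_batches_tracked are plain .cuda() tensors rather than registered buffers, so they will not follow the module through .to(device) calls and will not appear in its state_dict; registering them with self.register_buffer('running_mean', torch.zeros(num_features)) (and likewise for the others) would be the idiomatic alternative.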
Example 15
class SimlarityLayer(Module):
    """
    Render similarity matrix of GCN as a learnable layer
    """
    def __init__(self, node_num):
        super(SimlarityLayer, self).__init__()
        self.node_num = node_num
        self.cnt = 0
        self.delta = 0.75
        self.weight = Parameter(torch.FloatTensor(node_num, node_num))
        self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        # stdv = 1. / math.sqrt(self.weight.size(1))
        # self.weight.data.uniform_(-stdv, stdv)
        # if self.bias is not None:
        #     self.bias.data.uniform_(-stdv, stdv)
        # nn.init.kaiming_normal_(self.weight, mode='fan_out', nonlinearity='relu')
        nn.init.constant_(self.weight, 1)

    def forward(self, input, valid_mask):
        if self.cnt < 3:
            print(self.weight)
            self.cnt += 1
        adj_mat = F.softmax(self.weight.expand(valid_mask.size(0),
                                               self.weight.size(0),
                                               self.weight.size(1)),
                            dim=2)
        adj_mat = adj_mat * valid_mask
        adj_mat = \
            adj_mat + torch.eye(self.weight.size(0)).expand(valid_mask.size(0), self.weight.size(0), self.weight.size(1)).cuda()
        adj_mat_nor = self.diag_normalization(adj_mat)
        output = torch.matmul(adj_mat_nor, input)
        # output = self.delta*output + (1-self.delta)*input     # weighted blend of post-GCN and pre-GCN features
        # output = torch.matmul(self.weight, input)
        return output

    def diag_normalization(self, adj_mat):
        # precompute the normalized adjacency A~ following the paper's formula
        batch_size, node_num = adj_mat.size(0), adj_mat.size(1)
        D_mat = torch.zeros(batch_size, node_num, node_num)
        # adjMat = adjMat * valid_mask

        row_sum = adj_mat.sum(dim=2)
        D_mat[:, range(node_num), range(node_num)] = row_sum
        # for batch_id in range(batch_size):
        #     D[batch_id] = torch.diag(row_sum[batch_id])
        D_mat = torch.pow(D_mat, -0.5)

        inf_mask = torch.isinf(D_mat)
        D_mat[inf_mask] = 0
        D_mat = D_mat.float().cuda()

        adj_mat_nor = torch.matmul(torch.matmul(D_mat, adj_mat), D_mat)

        return adj_mat_nor
Example 16
class MFDeep1(nn.Module):
    is_train = True

    def __init__(self,
                 n_users,
                 n_items,
                 n_dim,
                 n_obs,
                 lub=1.,
                 lib=1.,
                 luv=1.,
                 liv=1.,
                 lmat=1.0,
                 loss=nn.MSELoss(reduction='sum')):
        super(MFDeep1, self).__init__()
        self.embed_user = VariationalBiasedEmbedding(n_users,
                                                     n_dim,
                                                     lb=lub,
                                                     lv=luv,
                                                     n_obs=n_obs)
        self.embed_item = VariationalBiasedEmbedding(n_items,
                                                     n_dim,
                                                     lb=lib,
                                                     lv=liv,
                                                     n_obs=n_obs)
        self.lin1 = nn.Linear(n_dim, n_dim, bias=True)
        self.lin2 = nn.Linear(n_dim, n_dim, bias=True)
        self.glob_bias = Parameter(torch.Tensor(1, 1))
        self.n_obs = n_obs
        self.lossf = loss
        self.lmat = lmat
        self.glob_bias.data[:] = 1e-6

    def forward(self, u, i):
        bias = self.glob_bias.expand(len(u), 1).squeeze()
        bu, vu = self.embed_user(u)
        bi, vi = self.embed_item(i)
        x1 = vu * self.lin1(vi)
        x2 = selu(self.lin2(x1))
        logodds = bias + bi + bu + x2.sum(dim=1)  # + x0.sum(dim=1)
        return logodds

    def loss(self, prediction, target):
        n_batches = self.n_obs * 1.0 / target.size()[0]
        llh = self.lossf(prediction, target)
        mat1 = self.lin1.weight @ torch.t(self.lin1.weight)
        mat2 = self.lin2.weight @ torch.t(self.lin2.weight)
        eye1 = Variable(torch.eye(*mat1.size()))
        eye2 = Variable(torch.eye(*mat2.size()))
        diff1 = ((mat1 - eye1)**2.0).sum() * self.lmat
        diff2 = ((mat2 - eye2)**2.0).sum() * self.lmat
        reg = (diff1 + diff2 + self.embed_user.prior() +
               self.embed_item.prior())
        return llh + reg / n_batches
Example 17
class SpOptNetEq(nn.Module):
    def __init__(self, n, Qpenalty, trueInit=False):
        super().__init__()
        nx = (n**2)**3
        self.nx = nx

        iTensor = torch.cuda.LongTensor
        dTensor = torch.cuda.DoubleTensor

        self.Qi = iTensor([range(nx), range(nx)])
        self.Qv = Variable(dTensor(nx).fill_(Qpenalty))
        self.Qsz = torch.Size([nx, nx])

        self.Gi = iTensor([range(nx), range(nx)])
        self.Gv = Variable(dTensor(nx).fill_(-1.0))
        self.Gsz = torch.Size([nx, nx])
        self.h = Variable(torch.zeros(nx).double().cuda())

        t = get_sudoku_matrix(n)
        neq = t.shape[0]
        if trueInit:
            I = t != 0
            self.Av = Parameter(dTensor(t[I]))
            Ai_np = np.nonzero(t)
            self.Ai = torch.stack((torch.LongTensor(Ai_np[0]),
                                   torch.LongTensor(Ai_np[1]))).cuda()
            self.Asz = torch.Size([neq, nx])
        else:
            # TODO: This is very dense:
            self.Ai = torch.stack(
                (iTensor(list(range(neq))).unsqueeze(1).repeat(1, nx).view(-1),
                 iTensor(list(range(nx))).repeat(neq)))
            self.Av = Parameter(dTensor(neq * nx).uniform_())
            self.Asz = torch.Size([neq, nx])
        self.b = Variable(torch.ones(neq).double().cuda())

    def forward(self, puzzles):
        nBatch = puzzles.size(0)

        p = -puzzles.view(nBatch, -1).double()

        return SpQPFunction(
            self.Qi,
            self.Qsz,
            self.Gi,
            self.Gsz,
            self.Ai,
            self.Asz,
            verbose=-1)(self.Qv.expand(nBatch, self.Qv.size(0)), p,
                        self.Gv.expand(nBatch, self.Gv.size(0)),
                        self.h.expand(nBatch, self.h.size(0)),
                        self.Av.expand(nBatch, self.Av.size(0)),
                        self.b.expand(nBatch,
                                      self.b.size(0))).float().view_as(puzzles)
Example 18
class MFClassic(nn.Module):
    def __init__(self,
                 n_users,
                 n_items,
                 dim,
                 n_obs,
                 reg_user_bas=1.,
                 reg_item_bas=1.,
                 reg_user_vec=1.,
                 reg_item_vec=1.):
        super(MFClassic, self).__init__()
        # User & item biases
        self.glob_bas = Parameter(torch.Tensor(1, 1))
        self.user_bas = nn.Embedding(n_users, 1)
        self.item_bas = nn.Embedding(n_items, 1)
        # User & item vectors
        self.user_vec = nn.Embedding(n_users, dim)
        self.item_vec = nn.Embedding(n_items, dim)
        self.n_obs = n_obs
        self.reg_user_bas = reg_user_bas
        self.reg_user_vec = reg_user_vec
        self.reg_item_bas = reg_item_bas
        self.reg_item_vec = reg_item_vec

    def forward(self, user_idx, item_idx):
        batchsize = len(user_idx)
        glob_bas = self.glob_bas.expand(batchsize, 1).squeeze()
        user_bas = self.user_bas(user_idx).squeeze()
        item_bas = self.item_bas(item_idx).squeeze()
        user_vec = self.user_vec(user_idx)
        item_vec = self.item_vec(item_idx)
        intx = (user_vec * item_vec).sum(dim=1)
        score = glob_bas + user_bas + item_bas + intx
        return score

    def loss(self, prediction, target):
        # Measure likelihood of target rating given prediction
        # Same as a logistic / sigmoid loss -- it compares a score
        # that ranges from -inf to +inf with a binary outcome of 0 or 1
        llh = F.binary_cross_entropy_with_logits(prediction, target)
        # L2 regularization of weights with custom coefficients
        # for each piece
        prior = (self.user_bas.weight.sum()**2. * self.reg_user_bas +
                 self.item_bas.weight.sum()**2. * self.reg_item_bas +
                 self.user_vec.weight.sum()**2. * self.reg_user_vec +
                 self.item_vec.weight.sum()**2. * self.reg_item_vec)
        # Since we're computing in minibatches but the prior is computed
        # once over a single pass thru dataset, adjust the prior loss s.t.
        # it's in proportion to the number of minibatches. This is degenerate
        # with the regularization coefficients, so not necessary, but it
        # means our initial guesses for the regularization coefficients are around 1.0
        n_minibatches = self.n_obs * 1.0 / target.size()[0]
        prior_weighted = prior / n_minibatches
        return llh + prior_weighted
Example 19
class SmoothUnit(nn.Module):
    def __init__(self, in_feat: int, out_feat: int):
        super(SmoothUnit, self).__init__()
        self.linear = nn.Linear(in_feat, out_feat)
        self.log_var = Parameter(torch.FloatTensor(out_feat))
        self.log_var.data.uniform_(-1, 1)

    def forward(self, x):
        mu = self.linear(x)
        logvar = self.log_var.expand(mu.shape[0], -1)
        std = torch.exp(0.5 * logvar)  # torch.normal expects a standard deviation, not a variance
        return torch.normal(mu, std).detach(), (mu, logvar)
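Because the sample is detached, gradients reach the linear layer and log_var only through the returned (mu, logvar) pair, never through the sample itself; if gradients through the sample were wanted, the usual reparameterization mu + torch.exp(0.5 * logvar) * torch.randn_like(mu) (without the detach) would be the standard alternative.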
Example 20
class AdaILN(nn.Module):
    def __init__(self, in_ch):
        super(AdaILN, self).__init__()

        self.ro = Parameter(torch.Tensor(1, in_ch, 1, 1))
        self.ro.data.fill_(0.9)

    def forward(self, x, gamma, beta):
        # instance statistics over the spatial dims
        i_mean = torch.mean(x, dim=[2, 3], keepdim=True)
        i_var = torch.var(x, dim=[2, 3], keepdim=True)
        i_std = torch.sqrt(i_var + 1e-5)
        i_h = (x - i_mean) / i_std

        # layer statistics over channels and spatial dims
        l_mean = torch.mean(x, dim=[1, 2, 3], keepdim=True)
        l_var = torch.var(x, dim=[1, 2, 3], keepdim=True)
        l_std = torch.sqrt(l_var + 1e-5)
        l_h = (x - l_mean) / l_std

        h = self.ro.expand(x.size(0), -1, -1, -1) * i_h + (1 - self.ro.expand(x.size(0), -1, -1, -1)) * l_h
        h = h * gamma.unsqueeze(2).unsqueeze(3) + beta.unsqueeze(2).unsqueeze(3)

        return h
Example 21
class AdaLIN(nn.Module):
    def __init__(self, dim, eps=1e-5):
        super(AdaLIN, self).__init__()

        self.eps = eps
        self.rho = Parameter(torch.Tensor(1, dim, 1, 1))
        self.rho.data.fill_(0.9)

    def forward(self, x, gamma, beta):
        in_mean = torch.mean(x, dim=[2, 3], keepdim=True)
        in_var = torch.var(x, dim=[2, 3], keepdim=True)
        out_in = (x - in_mean) / torch.sqrt(in_var + self.eps)
        ln_mean = torch.mean(x, dim=[1, 2, 3], keepdim=True)
        ln_var = torch.var(x, dim=[1, 2, 3], keepdim=True)
        out_ln = (x - ln_mean) / torch.sqrt(ln_var + self.eps)
        out = self.rho.expand(x.shape[0], -1, -1, -1) * out_in + (
            1 - self.rho.expand(x.shape[0], -1, -1, -1)) * out_ln
        out = out * gamma + beta
        return out
Example 22
class AdaILN(nn.Module):
    """AdaILN (Adaptive Instance Layer Normalization)"""
    def __init__(self, num_features, eps=1e-5):
        super(AdaILN, self).__init__()

        self.eps = eps
        self.rho = Parameter(torch.Tensor(1, num_features, 1, 1))
        self.rho.data.fill_(0.9)

    def forward(self, x, gamma, beta):
        in_mean = torch.mean(x, dim=[2, 3], keepdim=True)
        in_var = torch.var(x, dim=[2, 3], keepdim=True)
        out_in = (x - in_mean) / torch.sqrt(in_var + self.eps)

        ln_mean = torch.mean(x, dim=[1, 2, 3], keepdim=True)
        ln_var = torch.var(x, dim=[1, 2, 3], keepdim=True)
        out_ln = (x - ln_mean) / torch.sqrt(ln_var + self.eps)

        out = out_in * self.rho.expand(x.shape[0], -1, -1, -1) + out_ln * (
            1 - self.rho.expand(x.shape[0], -1, -1, -1))
        out = out * gamma.unsqueeze(dim=2).unsqueeze(dim=3) + beta.unsqueeze(
            dim=2).unsqueeze(dim=3)
        return out
Example 23
class GraphConvolution(Module):

    def __init__(self, in_features, out_features):
        super(GraphConvolution, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.Tensor(in_features, out_features))
        self.bias = Parameter(torch.Tensor(out_features))
        init.constant_(self.bias, 0)
        init.xavier_uniform_(self.weight)

    def forward(self, x, lap):
        expand_weight = self.weight.expand(x.shape[0], -1, -1)
        support = torch.bmm(x, expand_weight)
        output = torch.bmm(lap, support)
        return output + self.bias
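Since .expand only creates a broadcast view, torch.bmm(x, self.weight.expand(...)) computes the same result as torch.matmul(x, self.weight), which broadcasts the 2-D weight over the batch on its own; the expand variant simply makes the batched shapes explicit.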
Example 24
class FeatureRegression(nn.Module):
    def __init__(self, output_dim=6, use_cuda=True):
        super(FeatureRegression, self).__init__()
        self.align = CorrelationAlign()
        self.preconv = nn.Sequential(
            nn.Conv2d(841, 128, kernel_size=1, padding=0),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=7, padding=0),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
        )
        self.conv = nn.Sequential(
            nn.Conv2d(128 + 5, 128, kernel_size=1, padding=0),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
        )
        self.proj = nn.Sequential(
            nn.Conv2d(128, 128, kernel_size=1, padding=0),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
        )
        self.linear = nn.Linear(128, output_dim)
        self.att = Attention(128, 64)
        self.weight = Parameter(torch.ones(1, 5, 9, 9))
        self.weight.data.uniform_(-1, 1)
        if use_cuda:
            self.preconv.cuda()
            self.conv.cuda()
            self.cuda()
            self.linear.cuda()

    def forward(self, x):
        x = self.align(x)
        x = self.preconv(x)
        x_ = x
        x = torch.cat((self.weight.expand(x.size(0), 5, 9, 9), x), 1)
        x = self.conv(x)
        x = self.att(x_, self.proj(x))
        x = self.linear(x)
        return x
Example 25
class GraphConvLayer(Module):
    """ This class is taken from https://github.com/tkipf/pygcn.
    It implements one graph convolution layer.
    """
    def __init__(self, in_features, out_features, bias=True):
        super(GraphConvLayer, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.FloatTensor(in_features, out_features))
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)

    def forward(self, input, adj):
        b = input.size(0)
        hidden = torch.bmm(
            input, self.weight.expand(b, self.in_features, self.out_features))
        output = torch.bmm(adj, hidden)
        return output
Example 26
class GraphConvolution(Module):
    """
    Simple GCN layer, similar to https://arxiv.org/abs/1609.02907
    """
    def __init__(self, in_features, out_features, bias=True):
        super(GraphConvolution, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.FloatTensor(in_features, out_features))
        if bias:
            self.bias = Parameter(torch.FloatTensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self, val=None):
        if val is None:
            fan = self.in_features + self.out_features
            spread = math.sqrt(2.0) * math.sqrt(2.0 / fan)
        else:
            spread = val
        self.weight.data.uniform_(-spread, spread)
        if self.bias is not None:
            self.bias.data.uniform_(-spread, spread)

    def forward(self, input, adj):
        support = torch.bmm(
            input,
            self.weight.expand(input.size(0), self.in_features,
                               self.out_features))
        output = torch.bmm(adj, support)
        if self.bias is not None:
            return output + self.bias
        else:
            return output

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
               + str(self.in_features) + ' -> ' \
               + str(self.out_features) + ')'
Example 27
class VMF(nn.Module):
    def __init__(self, n_users, n_items, n_dim, n_obs, lub=1.,
                 lib=1., luv=1., liv=1., loss=nn.MSELoss):
        super(VMF, self).__init__()
        self.embed_user = VBE(n_users, n_dim, lb=lub, lv=luv)
        self.embed_item = VBE(n_items, n_dim, lb=lib, lv=liv)
        self.glob_bias = Parameter(torch.Tensor(1, 1))
        self.n_obs = n_obs
        self.lossf = loss()

    def forward(self, input):
        u, i = input
        bias = self.glob_bias.expand(len(u), 1).squeeze()
        bu, vu = self.embed_user(u)
        bi, vi = self.embed_item(i)
        intx = (vu * vi).sum(dim=1)
        logodds = bias + bi + bu + intx
        return logodds

    def loss(self, prediction, target):
        n_batches = self.n_obs * 1.0 / target.size()[0]
        llh = self.lossf(prediction, target)
        reg = (self.embed_user.prior() + self.embed_item.prior()) / n_batches
        return llh + reg
Example 28
class DPConv(nn.Module):
    def __init__(self, in_planes, out_planes, k=3, stride=1):
        super(DPConv, self).__init__()
        self.in_planes = in_planes
        self.out_planes = out_planes
        self.k = k
        self.stride = stride
        self.perturbation = nn.Conv2d(in_planes, out_planes, kernel_size=k, stride=stride,
                              padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_planes)
        self.bn2 = nn.BatchNorm2d(out_planes)

        self.distribution_zoom = Parameter(torch.ones(1)) # for mask size [-1,0,1]*zoom
        self.distribution_var = Parameter(torch.ones(out_planes,in_planes,1)) # for normal distribution variance.
        # learnable affine parameters for the distribution response
        self.distribution_scale = Parameter(torch.ones(out_planes, in_planes, 1, 1))
        self.distribution_bias = Parameter(torch.zeros(out_planes, in_planes, 1, 1))

        self.register_buffer('normal_loc', torch.zeros(2))
        self.register_buffer('mask', self._get_mask())


    def forward(self, input):
        param = self._init_distribution()
        distribution_out = self._distribution_conv(input, param)
        return self.bn1(distribution_out) + self.bn2(self.perturbation(input))

    def _get_mask(self):
        mask = (self.perturbation.weight[0, 0, :, :] != -999).nonzero()
        mask = mask.reshape(self.k, self.k, 2)
        mask = mask.unsqueeze(dim=2).unsqueeze(dim=2) - self.k // 2  # assume square, lazy.
        return mask

    def _init_distribution(self):
        normal_scal = self.distribution_var.expand((self.out_planes, self.in_planes, 2))
        # scale_tril = torch.ones(self.out_planes * self.in_planes, 2)
        # normal_scal = 1 * scale_tril
        m = MultivariateNormal(loc=self.normal_loc, scale_tril=(normal_scal).diag_embed())
        # print(self.mask.device)
        y = m.log_prob(self.mask * self.distribution_zoom).exp()
        y = y.permute(2, 3, 0, 1)
        if self.distribution_zoom == 1:
            y = -(y - F.adaptive_avg_pool2d(y, 1))
            std = math.sqrt(2) / math.sqrt(self.out_planes * self.k * self.k)
            y = math.sqrt(std / y.var())*y
        else:
            y = -(y - self.distribution_bias)
            y = self.distribution_scale * y
        return y

    def _distribution_conv(self, input, weight, bias=None,
                           padding=1, dilation=1, groups=1):
        stride = _pair(self.stride)
        padding = _pair(padding)
        dilation = _pair(dilation)
        return F.conv2d(input, weight, bias, stride,
                        padding, dilation, groups)
Example 29
class Seq2Seq(nn.Module):
    def __init__(self, encode_ntoken, decode_ntoken,
            input_size, hidden_size,
            input_max_len, output_max_len,
            batch_size,
            nlayers=1, bias=False, attention=True, dropout_p=0.5,
            batch_first=False):
        super(Seq2Seq, self).__init__()
        self.dropout_p = dropout_p
        if dropout_p > 0:
            self.dropout = nn.Dropout(p=dropout_p)

        # encoder stack
        self.enc_embedding = nn.Embedding(encode_ntoken, input_size)
        self.encoder = nn.LSTM(input_size, hidden_size, nlayers, bias=bias, batch_first=batch_first)

        # decoder stack
        self.dec_embedding = nn.Embedding(decode_ntoken, input_size)
        self.decoder = nn.LSTM(hidden_size, hidden_size, nlayers, bias=bias, batch_first=batch_first)
        if attention:
            self.attn_enc_linear = nn.Linear(hidden_size, hidden_size)
            self.attn_dec_linear = nn.Linear(hidden_size, hidden_size)
            self.attn_linear = nn.Linear(hidden_size*2, hidden_size)
            self.att_weight = Parameter(torch.Tensor(hidden_size, 1))
            self.attn_tanh = nn.Tanh()
            self.attn_conv = nn.Conv2d(1, hidden_size, (1,hidden_size))
        self.linear = nn.Linear(hidden_size, decode_ntoken, bias=True)

        self.softmax = nn.LogSoftmax(dim=1)

        self.decode_ntoken = decode_ntoken
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.batch_size = batch_size
        self.nlayers = nlayers
        self.attention = attention
        self.input_max_len = input_max_len
        self.output_max_len = output_max_len
        self.batch_first = batch_first

    def init_weights(self, initrange):
        for param in self.parameters():
            param.data.uniform_(-initrange, initrange)

    def attention_func(self, encoder_outputs, decoder_hidden):
        #calculate U*h_j
        att_dec_tmp = self.attn_dec_linear(decoder_hidden)
        sz = encoder_outputs.size()
        # To calculate W1 * h_t we use a 1-by-1 convolution, need to reshape before.
        attn_value = self.attn_conv(encoder_outputs.view(sz[0], 1, sz[1], sz[2]))
        attn_value = attn_value.squeeze().transpose(1,2)

        ss = torch.stack([att_dec_tmp]*sz[0])
        attn_value = self.attn_tanh(attn_value + ss)

        # input_len * batch * 1 -->  batch * input_len * 1
        attn_value2 = torch.bmm(attn_value, self.att_weight.expand(self.hidden_size, self.input_max_len).t().unsqueeze(2))
        attention = nn.Softmax(dim=1)(torch.transpose(attn_value2.squeeze(), 0, 1))

        hidden = torch.bmm(torch.transpose(torch.transpose(encoder_outputs, 0, 1), 1, 2), attention.unsqueeze(2))       
        return hidden.squeeze()

    def attention_func1(self, encoder_outputs, decoder_hidden):
        #att_enc_tmp = Variable(torch.Tensor(self.input_size, self.batch_size, self.hidden_size))
        att_dec_tmp = self.attn_dec_linear(decoder_hidden)
        attn_value = Variable(encoder_outputs.data.new(self.input_max_len, self.batch_size, self.hidden_size))
        for i in range(encoder_outputs.size()[0]):
            attn_value[i] = self.attn_tanh(self.attn_enc_linear(encoder_outputs[i]) + att_dec_tmp)

        # input_len * batch * 1 -->  batch * input_len * 1
        attn_value2 = torch.bmm(attn_value, self.att_weight.expand(self.hidden_size, self.input_max_len).t().unsqueeze(2))
        attention = nn.Softmax(dim=1)(torch.transpose(attn_value2.squeeze(), 0, 1))

        #hidden = encoder_outputs
        #for i in range(1, att_enc_tmp.size()[0]):
        #    hidden = hidden + att_enc_tmp[i].attention[i]
        #    (batch_size * hidden_len * input_len ) * (batch_size * input_len) = batch_size * hidden_len

        hidden = torch.bmm(torch.transpose(torch.transpose(encoder_outputs, 0, 1), 1, 2), attention.unsqueeze(2))       
        return hidden.squeeze()



    def encode(self, encoder_inputs):
        weight = next(self.parameters()).data
        init_state = (Variable(weight.new(self.nlayers, self.batch_size, self.hidden_size).zero_()),
                Variable(weight.new(self.nlayers, self.batch_size, self.hidden_size).zero_()))
        embedding = self.enc_embedding(encoder_inputs)
        if self.dropout_p > 0:
            embedding = self.dropout(embedding)
        encoder_outputs, encoder_state = self.encoder(embedding, init_state)
        return encoder_outputs, encoder_state

    def decode(self, encoder_outputs, encoder_state, decoder_inputs, feed_previous):
        pred = []
        state = encoder_state
        if feed_previous:
            if self.batch_first:
                embedding = self.dec_embedding(decoder_inputs[:,0].unsqueeze(1))
            else:
                embedding = self.dec_embedding(decoder_inputs[0].unsqueeze(0))
            for time in range(1, self.output_max_len):
                #state = repackage_state(state)
                # batch_size * 1 * embedding_size
                output, state = self.decoder(embedding, state)
                # print(output.size())

                att_out = self.attention_func(encoder_outputs, output.squeeze())
                softmax = self.predict(output.squeeze(), att_out)
                # feed previous
                decoder_input = softmax.max(1)[1]
                embedding = self.dec_embedding(decoder_input.squeeze().unsqueeze(0))
                if self.batch_first:
                    embedding = torch.transpose(embedding, 0, 1)
                pred.append(softmax)
        else:
            embedding = self.dec_embedding(decoder_inputs)
            if self.dropout_p > 0:
                embedding = self.dropout(embedding)
            outputs, _ = self.decoder(embedding, state)
            # print(outputs.size())

            # if self.batch_first:
                # for batch in range(self.batch_size):
                    # output = outputs[batch,1:,:]
                    # softmax = self.predict(output, encoder_outputs)
                    # pred.append(softmax)
            # else:
            for time_step in range(self.output_max_len - 1):
                if self.batch_first:
                    output = outputs[:,time_step,:]
                else:
                    output = outputs[time_step]
                #softmax = self.predict(output, encoder_outputs)
                att_out = self.attention_func(encoder_outputs, output)
                softmax = self.predict(output, att_out)
                pred.append(softmax)
        return pred

    def predict(self, dec_output, att_out):

        if self.dropout_p > 0:
            dec_output = self.dropout(dec_output)
            att_out = self.dropout(att_out)
        x = self.attn_linear(torch.cat((dec_output,att_out),1))
        linear = self.linear(x)
        softmax = self.softmax(linear)
        return softmax

    def forward(self, encoder_inputs, decoder_inputs, feed_previous=False):
        # encoding
        encoder_outputs, encoder_state = self.encode(encoder_inputs)

        # decoding
        pred = self.decode(encoder_outputs, encoder_state, decoder_inputs, feed_previous)
        return pred
Example 30
class ConvLSTMCell(nn.Module):
    '''
    Generate a convolutional LSTM cell
    copied and modified from https://github.com/Atcold/pytorch-CortexNet/blob/master/model/ConvLSTMCell.py
    '''
    def __init__(self,
                 input_size,
                 hidden_size,
                 kernel_size=5,
                 stride=1,
                 padding=2,
                 train_init_state=False,
                 height=None,
                 width=None):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.train_init_state = train_init_state
        self.height = height
        self.width = width

        # lstm gates
        self.gates = nn.Conv2d(input_size + hidden_size,
                               4 * hidden_size,
                               kernel_size=kernel_size,
                               stride=stride,
                               padding=padding)

        # initial states
        if self.train_init_state:
            assert self.height and self.width
            self.init_hidden = Parameter(
                torch.zeros(1, self.hidden_size, self.height, self.width))
            self.init_cell = Parameter(
                torch.zeros(1, self.hidden_size, self.height, self.width))

    def init_state(self, batch_size, spatial_size):
        state_size = [batch_size, self.hidden_size] + list(spatial_size)
        if self.train_init_state:
            return (self.init_hidden.expand(state_size),
                    self.init_cell.expand(state_size))
        else:
            weight = next(self.parameters())
            return (weight.new_zeros(state_size), weight.new_zeros(state_size))

    def forward(self, input, prev_state):
        ''' forward stacked rnn one time step
        Input:
            input: batch_size x input_size x height x width
            prev_state: (hidden, cell) of each ConvLSTM
        Output:
            output: hidden (of new_state), batch_size x hidden_size x hidden_height x hidden_width
            new_state: (hidden, cell) of each ConvLSTM
        '''

        # get batch and spatial sizes
        batch_size = input.data.size(0)
        spatial_size = input.data.size()[2:]

        # generate empty prev_state, if None is provided
        if prev_state is None:
            prev_state = self.init_state(batch_size, spatial_size)

        prev_hidden, prev_cell = prev_state

        # data size is [batch, channel, height, width]
        stacked_inputs = torch.cat((input, prev_hidden), 1)
        outputs = self.gates(stacked_inputs)

        # chunk across channel dimension
        in_gate, remember_gate, out_gate, cell_gate = outputs.chunk(4, 1)

        # apply sigmoid non linearity
        in_gate = torch.sigmoid(in_gate)
        remember_gate = torch.sigmoid(remember_gate)
        out_gate = torch.sigmoid(out_gate)

        # apply tanh non linearity
        cell_gate = torch.tanh(cell_gate)

        # compute current cell and hidden state
        cell = (remember_gate * prev_cell) + (in_gate * cell_gate)
        hidden = out_gate * torch.tanh(cell)

        # pack output
        new_state = (hidden, cell)
        output = hidden
        return output, new_state
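A minimal sketch rolling the cell over a short sequence, with assumed sizes:

cell = ConvLSTMCell(input_size=3, hidden_size=8)
seq = torch.randn(6, 2, 3, 16, 16)    # (time, batch, channels, height, width)
state = None
for t in range(seq.size(0)):
    out, state = cell(seq[t], state)  # out: (2, 8, 16, 16)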