class Agent(nn.Module):
    def __init__(self, score_method, agent_state_size, answer_node_emb_size):
        super(Agent, self).__init__()
        self.policy_network = Attn(score_method, answer_node_emb_size,
                                   agent_state_size)
        self.value_network = nn.Sequential(
            nn.Linear(agent_state_size, agent_state_size // 2), nn.Tanh(),
            nn.Linear(agent_state_size // 2, 1))
        # two vectors representing the fraud and non-fraud actions
        self.fraud_embed = Parameter(torch.Tensor(answer_node_emb_size))
        self.fraud_embed.data.uniform_(-1, 1)
        self.non_fraud_embed = Parameter(torch.Tensor(answer_node_emb_size))
        self.non_fraud_embed.data.uniform_(-1, 1)

    def forward(self, agent_state, answer_nodes, graph_node_embedding):
        """
        :param agent_state: (batch_size, agent_state_size)
        :param answer_nodes: (batch_size, answer_node_num)
        :param graph_node_embedding: (batch_size, node_num, node_feature_size)
        :return:
        values: (batch_size,)
        logits: (batch, answer_node_num + 2)
        """
        values = self.value_network(agent_state).squeeze(-1)

        batch_size = agent_state.shape[0]
        answer_node_embedding = batch_embedding_lookup(graph_node_embedding,
                                                       answer_nodes)
        actions_embedding = torch.cat(
            (answer_node_embedding, self.fraud_embed.repeat(batch_size, 1, 1),
             self.non_fraud_embed.repeat(batch_size, 1, 1)),
            dim=1)
        logits = self.policy_network(actions_embedding, agent_state)

        return values, logits
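
The fraud/non-fraud vectors above are plain 1-D Parameters that repeat lifts into a per-batch action slot; a minimal, self-contained sketch of that broadcast (toy sizes, not from the original repo):

import torch
from torch import nn

emb_size, batch_size = 8, 4
fraud_embed = nn.Parameter(torch.empty(emb_size).uniform_(-1, 1))

# a 1-D tensor repeated with three counts is treated as (1, 1, emb_size),
# so this yields one extra "action" row per batch element: (4, 1, 8)
fraud_action = fraud_embed.repeat(batch_size, 1, 1)
answer_emb = torch.randn(batch_size, 5, emb_size)       # 5 candidate nodes
actions = torch.cat((answer_emb, fraud_action), dim=1)  # (4, 6, 8)
print(actions.shape)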
class Manager(nn.Module):
    def __init__(self, score_method, manager_state_size, worker_state_size):
        super(Manager, self).__init__()
        self.policy_network = Attn(score_method, worker_state_size,
                                   manager_state_size)
        self.value_network = nn.Sequential(
            nn.Linear(manager_state_size, manager_state_size // 2), nn.Tanh(),
            nn.Linear(manager_state_size // 2, 1))

        # two vectors representing the fraud and non-fraud actions
        self.fraud_embed = Parameter(torch.Tensor(worker_state_size))
        self.fraud_embed.data.uniform_(-1, 1)
        self.non_fraud_embed = Parameter(torch.Tensor(worker_state_size))
        self.non_fraud_embed.data.uniform_(-1, 1)

    def forward(self, manager_state, workers_state):
        """
        :param manager_state: (batch_size, manager_state_size)
        :param workers_state: (batch_size, personal_node_num, worker_state_size)
        :return:
        values: (batch_size,)
        logits: (batch_size, personal_node_num + 2)
        """
        values = self.value_network(manager_state).squeeze(-1)

        batch_size = manager_state.shape[0]
        actions_embedding = torch.cat(
            (workers_state, self.fraud_embed.repeat(batch_size, 1, 1),
             self.non_fraud_embed.repeat(batch_size, 1, 1)),
            dim=1)
        logits = self.policy_network(actions_embedding, manager_state)

        return values, logits
Example 3
class PathAttention(nn.Module):
    def __init__(self):
        super(PathAttention, self).__init__()
        # embed_size and head are module-level globals in the source file
        self.w1 = Parameter(torch.FloatTensor(embed_size * 4, embed_size * 2))
        self.w2 = Parameter(torch.FloatTensor(head, embed_size * 4))
        self.norm_layer = nn.Softmax(dim=2)
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.w1.size(1))
        self.w1.data.uniform_(-stdv, stdv)
        stdv = 1. / math.sqrt(self.w2.size(1))
        self.w2.data.uniform_(-stdv, stdv)

    def forward(self, value, mask):
        # project the values: (batch, embed_size * 4, seq_len)
        x = F.relu(
            torch.bmm(self.w1.repeat(value.shape[0], 1, 1),
                      value.transpose(1, 2)))
        # per-head attention scores: (batch, head, seq_len)
        x = torch.bmm(self.w2.repeat(x.shape[0], 1, 1), x)
        mask = torch.unsqueeze(mask, 1).repeat(1, head, 1)
        # additive mask: positions with mask == 1 get -1e9, so softmax sends them to ~0
        x = -1e9 * mask + x
        x = self.norm_layer(x)
        # attention-weighted sum of the values: (batch, head, embed_size * 2)
        x = torch.bmm(x, value)
        return x
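
The `-1e9 * mask + x` step above implements additive masking before the softmax; a small standalone illustration (toy shapes, assuming mask == 1 marks positions to hide):

import torch

scores = torch.randn(1, 2, 4)             # (batch, head, seq_len)
mask = torch.tensor([[0., 0., 1., 1.]])   # 1 = masked position
mask = mask.unsqueeze(1).repeat(1, 2, 1)  # broadcast over heads

weights = torch.softmax(scores - 1e9 * mask, dim=2)
print(weights)  # masked columns get (numerically) zero attention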
Example 4
class LSTMForwardEncoder(nn.Module):
    def __init__(self, input_size, hidden_size, train_init_state=True):
        super().__init__()
        self.hidden_size = hidden_size
        self.rnn = nn.LSTMCell(input_size, hidden_size)
        self.train_init_state = train_init_state

        if self.train_init_state:
            self.init_hidden = Parameter(torch.zeros(1, self.hidden_size))
            self.init_cell = Parameter(torch.zeros(1, self.hidden_size))

    def init_state(self, batch_size):
        if self.train_init_state:
            return (self.init_hidden.repeat(batch_size,1),
                    self.init_cell.repeat(batch_size,1))
        else:
            weight = next(self.parameters())
            return (weight.new_zeros(batch_size, self.hidden_size),
                    weight.new_zeros(batch_size, self.hidden_size))

    def forward(self, sequence):
        seq_len = len(sequence)
        batch_size = sequence[0].size(0)
        state_fwd = self.init_state(batch_size)

        outputs_fwd = []
        for t in range(seq_len):
            state_fwd = self.rnn(sequence[t], state_fwd)
            outputs_fwd += [state_fwd[0]]

        return outputs_fwd
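
Assuming the class above is defined, a quick usage sketch: feed a list of per-step batches and get one hidden state per step back.

import torch

enc = LSTMForwardEncoder(input_size=16, hidden_size=32)
seq = [torch.randn(4, 16) for _ in range(10)]  # 10 steps, batch of 4
outs = enc(seq)
print(len(outs), outs[0].shape)  # 10 torch.Size([4, 32])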
Example 5
class InstanceNorm2d(Module):
    def __init__(self, num_features):
        """ only support batch_size 1 in training phase, but no this limit in testing phase"""
        super(InstanceNorm2d, self).__init__()
        self.num_features = num_features
        self.weight = Parameter(torch.ones(num_features))
        self.bias = Parameter(torch.zeros(num_features))
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))
        self.reset_parameters()

    def reset_parameters(self):
        self.running_mean.zero_()
        self.running_var.fill_(1)
        self.weight.data.fill_(1)
        self.bias.data.zero_()

    def forward(self, input):
        B, C, H, W = input.size()
        input_view = input.view(1, B * C, H, W)
        output_view = F.batch_norm(input_view,
                                   self.running_mean.repeat(B),
                                   self.running_var.repeat(B),
                                   self.weight.repeat(B),
                                   self.bias.repeat(B),
                                   training=True,
                                   momentum=0.,
                                   eps=1e-5)
        output = output_view.view(B, C, H, W)
        return output
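
Since forward folds the batch into the channel axis and normalizes with batch statistics, the result should match torch's built-in instance norm; a hedged sanity check (assumes the class above is defined):

import torch
import torch.nn as nn

x = torch.randn(3, 4, 8, 8)
custom = InstanceNorm2d(4)
builtin = nn.InstanceNorm2d(4, affine=True, eps=1e-5)
print(torch.allclose(custom(x), builtin(x), atol=1e-5))  # expected: True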
class Workers(nn.Module):
    def __init__(self, score_method, worker_state_size, answer_node_emb_size):
        super(Workers, self).__init__()
        self.policy_networks = Attn(score_method, answer_node_emb_size,
                                    worker_state_size)
        self.value_network = nn.Sequential(
            nn.Linear(worker_state_size, worker_state_size // 2), nn.Tanh(),
            nn.Linear(worker_state_size // 2, 1))

        # two vectors representing the fraud and non-fraud actions
        self.fraud_embed = Parameter(torch.Tensor(answer_node_emb_size))
        self.fraud_embed.data.uniform_(-1, 1)
        self.non_fraud_embed = Parameter(torch.Tensor(answer_node_emb_size))
        self.non_fraud_embed.data.uniform_(-1, 1)

    def forward(self, workers_state, answer_nodes, graph_node_embedding):
        """
        :param workers_state: (batch_size, personal_node_num, worker_state_size)
        :param answer_nodes: (batch_size, personal_node_num, answer_node_num)
        :param graph_node_embedding: (batch_size, node_num, node_feature_size)
        :return:
        values: (batch_size, personal_node_num)
        logits: (batch_size, personal_node_num, answer_node_num + 2)
        """
        values = self.value_network(workers_state).squeeze(-1)

        batch_size = answer_nodes.shape[0]
        personal_node_num = answer_nodes.shape[1]
        answer_node_num = answer_nodes.shape[2]

        answer_nodes = answer_nodes.reshape(batch_size, -1)
        answer_node_embedding = batch_embedding_lookup(graph_node_embedding,
                                                       answer_nodes)
        answer_node_embedding = answer_node_embedding.reshape(
            batch_size, personal_node_num, answer_node_num, -1)
        actions_embedding = torch.cat(
            (answer_node_embedding,
             self.fraud_embed.repeat(batch_size, personal_node_num, 1, 1),
             self.non_fraud_embed.repeat(batch_size, personal_node_num, 1, 1)),
            dim=2)

        workers_state = workers_state.reshape(batch_size * personal_node_num,
                                              -1)
        actions_embedding = actions_embedding.reshape(
            batch_size * personal_node_num, answer_node_num + 2, -1)

        logits = self.policy_networks(actions_embedding, workers_state)
        logits = logits.reshape(batch_size, personal_node_num, -1)
        return values, logits
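
batch_embedding_lookup is an external helper in this repo; a minimal gather-based sketch consistent with how Agent and Workers call it (an assumption, not the original implementation):

import torch

def batch_embedding_lookup(embeddings, indices):
    # embeddings: (B, node_num, D), indices: (B, K) -> (B, K, D)
    B, K = indices.shape
    D = embeddings.size(-1)
    idx = indices.unsqueeze(-1).expand(B, K, D)
    return embeddings.gather(1, idx)

emb = torch.randn(2, 10, 4)
idx = torch.randint(0, 10, (2, 3))
print(batch_embedding_lookup(emb, idx).shape)  # torch.Size([2, 3, 4])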
Example 7
def get_sinusoid_encoding_table_vocab(n_src_vocab, d_hid, padding_idx=None):
    ''' Sinusoid position encoding table '''
    def cal_angle(position, hid_idx):
        return 1 / np.power(10000, 2 * (hid_idx // 2) / n_src_vocab)

    def get_posi_angle_vec(position):
        return [cal_angle(position, hid_j) for hid_j in range(n_src_vocab)]

    # sinusoid_table = np.array([get_posi_angle_vec(pos_i) for pos_i in range(n_src_vocab)])
    sinusoid_table = np.array([get_posi_angle_vec(1)])
    sinusoid_table = torch.FloatTensor(sinusoid_table)

    enc_output_phase = Parameter(sinusoid_table, requires_grad=True)  # imaginary-part vector

    sinusoid_table = enc_output_phase.repeat(d_hid, 1)
    sinusoid_table = sinusoid_table.reshape(n_src_vocab, d_hid)

    # sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
    # sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1

    # if padding_idx is not None:
    #     # zero vector for padding dimension
    #     sinusoid_table[padding_idx] = 1.

    return torch.FloatTensor(sinusoid_table)
Example 8
class SequenceBias(nn.Module):
    """ Adds one bias element to the end of the sequence
    Args:
        embed_dim: Embedding dimension

    Shape:
        - Input: (L, N, E), where
            L - sequence length, N - batch size, E - embedding dimension
        - Output: (L+1, N, E), where
            L - sequence length, N - batch size, E - embedding dimension

    Attributes:
        bias:   the learnable bias of the module of shape (E),
            where E - embedding dimension

    Examples::

        >>> m = SequenceBias(16)
        >>> input = torch.randn(20, 4, 16)
        >>> output = m(input)
        >>> print(output.size())
        torch.Size([21, 4, 16])
    """
    def __init__(self, embed_dim):
        super(SequenceBias, self).__init__()

        self.bias = Parameter(torch.empty(embed_dim))
        self._reset_parameters()

    def _reset_parameters(self):
        nn.init.normal_(self.bias)

    def forward(self, x):
        _, bsz, _ = x.shape
        return torch.cat([x, self.bias.repeat(1, bsz, 1)])
Example 9
class TransformerNet(M.Model):
    def initialize(self,
                   num_enc,
                   num_heads,
                   dim_per_head,
                   latent_token=True,
                   drop=0.2):
        self.latent_token = latent_token
        self.posemb = PositionalEmbedding()
        self.trans_blocks = nn.ModuleList()
        for i in range(num_enc):
            self.trans_blocks.append(
                Transformer(num_heads, dim_per_head, drop=drop))

    def build(self, *inputs):
        indim = inputs[0].shape[2]
        if self.latent_token:
            self.token = Parameter(torch.zeros(1, 1, indim))

    def forward(self, x):
        if self.latent_token:
            token = self.token.repeat(x.shape[0], 1, 1)
            x = torch.cat([token, x], dim=1)
        x = self.posemb(x)
        for trans in self.trans_blocks:
            x = trans(x)
        return x
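
The latent_token path prepends one learned token to every sequence (CLS-style); a self-contained sketch of just that step with toy dimensions:

import torch
from torch import nn

indim, batch = 8, 4
token = nn.Parameter(torch.zeros(1, 1, indim))

x = torch.randn(batch, 10, indim)
x = torch.cat([token.repeat(batch, 1, 1), x], dim=1)
print(x.shape)  # torch.Size([4, 11, 8])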
Example 10
class BN2d_slow(nn.Module):
    def __init__(self, num_features, momentum=0.01):
        super(BN2d_slow, self).__init__()
        self.num_features = num_features
        self.weight = Parameter(torch.Tensor(num_features))
        self.bias = Parameter(torch.Tensor(num_features))
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.zeros(num_features))
        self.eps = 1e-5
        self.momentum = momentum

        self.running_mean.zero_()
        self.running_var.fill_(1)
        self.weight.data.uniform_()
        self.bias.data.zero_()

    def forward(self, x):
        nB = x.data.size(0)
        nC = x.data.size(1)
        nH = x.data.size(2)
        nW = x.data.size(3)
        samples = nB * nH * nW
        y = x.view(nB, nC, nH * nW).transpose(1, 2).contiguous().view(-1, nC)
        if self.training:
            m = Variable(y.mean(0).data, requires_grad=False)
            v = Variable(y.var(0).data, requires_grad=False)
            self.running_mean = (
                1 - self.momentum
            ) * self.running_mean + self.momentum * m.data.view(-1)
            self.running_var = (
                1 - self.momentum
            ) * self.running_var + self.momentum * v.data.view(-1)
            m = m.repeat(samples, 1)
            v = v.repeat(samples, 1) * (samples - 1.0) / samples
        else:
            m = Variable(self.running_mean.repeat(samples, 1),
                         requires_grad=False)
            v = Variable(self.running_var.repeat(samples, 1),
                         requires_grad=False)
        w = self.weight.repeat(samples, 1)
        b = self.bias.repeat(samples, 1)
        y = (y - m) / (v + self.eps).sqrt() * w + b
        y = y.view(nB, nH * nW,
                   nC).transpose(1, 2).contiguous().view(nB, nC, nH, nW)
        return y
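
In training mode the module should agree with the built-in batch norm (same eps, biased variance, affine parameters copied over); a hedged check assuming the class above is defined:

import torch
import torch.nn as nn

x = torch.randn(4, 3, 5, 5)
slow = BN2d_slow(3)
fast = nn.BatchNorm2d(3, eps=1e-5)
fast.weight.data.copy_(slow.weight.data)
fast.bias.data.copy_(slow.bias.data)
print(torch.allclose(slow(x), fast(x), atol=1e-4))  # expected: True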
Example 11
class MaskLayer(nn.Module):
    def __init__(self):
        super(MaskLayer, self).__init__()

        self.mask = Parameter(torch.ones(1, 128))

    def forward(self, x):
        return x * self.mask.repeat(x.size(0), 1)
Example 12
class GeneralizedMeanPoolingManyP(GeneralizedMeanPooling):
  """ One p for each filter
  """
  def __init__(self, n_filters, norm=3, output_size=1, eps=1e-6):
    super().__init__(norm, output_size, eps)
    self.p = Parameter(torch.ones(n_filters) * norm)

  def forward(self, x):
    p = self.p.repeat(x.size(0), 1).unsqueeze(-1).unsqueeze(-1)
    x = x.clamp(min=self.eps).pow(p)
    return F.adaptive_avg_pool2d(x, self.output_size).pow(1. / p)
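
With p = 1, generalized-mean pooling reduces to plain average pooling, which gives a quick functional check of the clamp/pow/avg-pool/pow pipeline above (standalone, toy sizes):

import torch
import torch.nn.functional as F

x = torch.rand(2, 3, 7, 7) + 0.1  # keep values above eps
p, eps = 1.0, 1e-6

gem = F.adaptive_avg_pool2d(x.clamp(min=eps).pow(p), 1).pow(1. / p)
avg = F.adaptive_avg_pool2d(x, 1)
print(torch.allclose(gem, avg, atol=1e-6))  # True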
class GraphConvolution(Module):
    """
    Simple GCN layer, similar to https://arxiv.org/abs/1609.02907
    """
    def __init__(self, in_features, out_features, bias=False):
        super(GraphConvolution, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        # keep the Parameter a leaf tensor; move the whole module to the GPU
        # with .to()/.cuda() rather than calling .cuda() on the Parameter itself
        self.weight = Parameter(torch.Tensor(in_features, out_features))
        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)  # random init
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)  # random init

    def forward(self, input, adj):
        # (batch, n_nodes, in_features) x (batch, in_features, out_features)
        weight_matrix = self.weight.repeat(input.shape[0], 1, 1)
        support = torch.bmm(input, weight_matrix)
        output = SparseMM(adj)(support)
        if self.bias is not None:
            return output + self.bias  # broadcasts over the leading dims
        else:
            return output

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
               + str(self.in_features) + ' -> ' \
               + str(self.out_features) + ')'
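
SparseMM is an external autograd helper here; a dense stand-in with torch.bmm shows the same batched propagation (toy sizes, an assumption rather than the repo's code):

import torch
from torch import nn

B, N, f_in, f_out = 2, 5, 8, 4
x = torch.randn(B, N, f_in)
adj = torch.softmax(torch.randn(B, N, N), dim=-1)  # toy row-normalized adjacency
weight = nn.Parameter(torch.empty(f_in, f_out).uniform_(-0.35, 0.35))

support = torch.bmm(x, weight.repeat(B, 1, 1))  # (B, N, f_out)
out = torch.bmm(adj, support)                   # aggregate neighbour features
print(out.shape)  # torch.Size([2, 5, 4])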
Example 14
class SequenceBias(nn.Module):
    r"""
    Adds one bias element to the end of the sequence.
    so if the input has a shape ``(L, N, E)``, where
    ``L`` is the sequence length, ``N`` is the batch size, and ``E`` is
    the embedding dimension, the output will have a shape
    ``(L+1, N, E)``.

    Attributes
    ------------
    bias: :class:`torch.nn.parameter.Parameter`
        the learnable bias of the module of shape ``(E)``,
        where ``E`` is the embedding dimension.

    Example
    -------
        >>> m = SequenceBias(16)
        >>> input = torch.randn(20, 4, 16)
        >>> output = m(input)
        >>> print(output.size())
        torch.Size([21, 4, 16])
    """

    def __init__(self, embed_dim):
        r"""
        Parameters
        ----------
        embed_dim: int
            Embedding dimension
        """
        super(SequenceBias, self).__init__()

        self.bias = Parameter(torch.empty(embed_dim))
        self._reset_parameters()

    def _reset_parameters(self):
        r"""
        Assigns normally distributed random values to the bias.
        """
        nn.init.normal_(self.bias)

    def forward(self, x):
        _, bsz, _ = x.shape
        return torch.cat([x, self.bias.repeat(1, bsz, 1)])
Example 15
class Gaussian3DBase(nn.Module):
    def __init__(self, w, n_gaussian, depth):
        super(Gaussian3DBase, self).__init__()
        assert 1 == w % 2, "'w' must be odd (3, 5, 7, 9, 11, ...)"
        assert 1 == depth % 2, "'depth' must be odd (3, 5, 7, 9, 11, ...)"

        self.xes = torch.FloatTensor(range(int(-w / 2),
                                           int(w / 2) + 1)).unsqueeze(-1)**2
        self.xes = self.xes.repeat(depth, self.xes.size(0), 1, n_gaussian)
        self.yes = self.xes.transpose(1, 2)
        self.zes = torch.FloatTensor(range(
            int(-depth / 2),
            int(depth / 2) + 1)).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)**2

        self.xes = Parameter(self.xes, requires_grad=False)
        self.yes = Parameter(self.yes, requires_grad=False)
        self.zes = Parameter(self.zes.repeat(1, self.xes.size(1),
                                             self.xes.size(2), n_gaussian),
                             requires_grad=False)
Example 16
File: network.py Project: zlapp/ec
class Network(nn.Module):
    """
    Todo:
    - Beam search
    - check if this is right? attend during P->FC rather than during softmax->P?
    - allow length 0 inputs/targets
    - give n_examples as input to FC
    - Initialise new weights randomly, rather than as zeroes
    """
    def __init__(self,
                 input_vocabulary,
                 target_vocabulary,
                 hidden_size=512,
                 embedding_size=128,
                 cell_type="LSTM"):
        """
        :param list input_vocabulary: list of possible inputs
        :param list target_vocabulary: list of possible targets
        """
        super(Network, self).__init__()
        self.h_input_encoder_size = hidden_size
        self.h_output_encoder_size = hidden_size
        self.h_decoder_size = hidden_size
        self.embedding_size = embedding_size
        self.input_vocabulary = input_vocabulary
        self.target_vocabulary = target_vocabulary
        # Number of tokens in input vocabulary
        self.v_input = len(input_vocabulary)
        # Number of tokens in target vocabulary
        self.v_target = len(target_vocabulary)

        self.cell_type = cell_type
        if cell_type == 'GRU':
            self.input_encoder_cell = nn.GRUCell(
                input_size=self.v_input + 1,
                hidden_size=self.h_input_encoder_size,
                bias=True)
            self.input_encoder_init = Parameter(
                torch.rand(1, self.h_input_encoder_size))
            self.output_encoder_cell = nn.GRUCell(
                input_size=self.v_input + 1 + self.h_input_encoder_size,
                hidden_size=self.h_output_encoder_size,
                bias=True)
            self.decoder_cell = nn.GRUCell(input_size=self.v_target + 1,
                                           hidden_size=self.h_decoder_size,
                                           bias=True)
        if cell_type == 'LSTM':
            self.input_encoder_cell = nn.LSTMCell(
                input_size=self.v_input + 1,
                hidden_size=self.h_input_encoder_size,
                bias=True)
            self.input_encoder_init = nn.ParameterList([
                Parameter(torch.rand(1, self.h_input_encoder_size)),
                Parameter(torch.rand(1, self.h_input_encoder_size))
            ])
            self.output_encoder_cell = nn.LSTMCell(
                input_size=self.v_input + 1 + self.h_input_encoder_size,
                hidden_size=self.h_output_encoder_size,
                bias=True)
            self.output_encoder_init_c = Parameter(
                torch.rand(1, self.h_output_encoder_size))
            self.decoder_cell = nn.LSTMCell(input_size=self.v_target + 1,
                                            hidden_size=self.h_decoder_size,
                                            bias=True)
            self.decoder_init_c = Parameter(torch.rand(1, self.h_decoder_size))

        self.W = nn.Linear(self.h_output_encoder_size + self.h_decoder_size,
                           self.embedding_size)
        self.V = nn.Linear(self.embedding_size, self.v_target + 1)
        self.input_A = nn.Bilinear(self.h_input_encoder_size,
                                   self.h_output_encoder_size,
                                   1,
                                   bias=False)
        self.output_A = nn.Bilinear(self.h_output_encoder_size,
                                    self.h_decoder_size,
                                    1,
                                    bias=False)
        self.input_EOS = torch.zeros(1, self.v_input + 1)
        self.input_EOS[:, -1] = 1
        self.input_EOS = Parameter(self.input_EOS)
        self.output_EOS = torch.zeros(1, self.v_input + 1)
        self.output_EOS[:, -1] = 1
        self.output_EOS = Parameter(self.output_EOS)
        self.target_EOS = torch.zeros(1, self.v_target + 1)
        self.target_EOS[:, -1] = 1
        self.target_EOS = Parameter(self.target_EOS)

    def __getstate__(self):
        if hasattr(self, 'opt'):
            return dict([(k, v)
                         for k, v in self.__dict__.items() if k != 'opt'] +
                        [('optstate', self.opt.state_dict())])
            # return {**{k:v for k,v in self.__dict__.items() if k is not 'opt'},
            #         'optstate': self.opt.state_dict()}
        else:
            return self.__dict__

    def __setstate__(self, state):
        self.__dict__.update(state)
        # Legacy:
        if isinstance(self.input_encoder_init, tuple):
            self.input_encoder_init = nn.ParameterList(
                list(self.input_encoder_init))

    def clear_optimiser(self):
        if hasattr(self, 'opt'):
            del self.opt
        if hasattr(self, 'optstate'):
            del self.optstate

    def get_optimiser(self):
        self.opt = torch.optim.Adam(self.parameters(), lr=0.001)
        if hasattr(self, 'optstate'):
            self.opt.load_state_dict(self.optstate)

    def optimiser_step(self, inputs, outputs, target):
        if not hasattr(self, 'opt'):
            self.get_optimiser()
        score = self.score(inputs, outputs, target, autograd=True).mean()
        (-score).backward()
        self.opt.step()
        self.opt.zero_grad()
        return score.data[0]

    def set_target_vocabulary(self, target_vocabulary):
        if target_vocabulary == self.target_vocabulary:
            return

        V_weight = []
        V_bias = []
        decoder_ih = []

        for i in range(len(target_vocabulary)):
            if target_vocabulary[i] in self.target_vocabulary:
                j = self.target_vocabulary.index(target_vocabulary[i])
                V_weight.append(self.V.weight.data[j:j + 1])
                V_bias.append(self.V.bias.data[j:j + 1])
                decoder_ih.append(self.decoder_cell.weight_ih.data[:, j:j + 1])
            else:
                V_weight.append(torch.zeros(1, self.V.weight.size(1)))
                V_bias.append(torch.ones(1) * -10)
                decoder_ih.append(
                    torch.zeros(self.decoder_cell.weight_ih.data.size(0), 1))

        V_weight.append(self.V.weight.data[-1:])
        V_bias.append(self.V.bias.data[-1:])
        decoder_ih.append(self.decoder_cell.weight_ih.data[:, -1:])

        self.target_vocabulary = target_vocabulary
        self.v_target = len(target_vocabulary)
        self.target_EOS.data = torch.zeros(1, self.v_target + 1)
        self.target_EOS.data[:, -1] = 1

        self.V.weight.data = torch.cat(V_weight, dim=0)
        self.V.bias.data = torch.cat(V_bias, dim=0)
        self.V.out_features = self.V.bias.data.size(0)

        self.decoder_cell.weight_ih.data = torch.cat(decoder_ih, dim=1)
        self.decoder_cell.input_size = self.decoder_cell.weight_ih.data.size(1)

        self.clear_optimiser()

    def input_encoder_get_init(self, batch_size):
        if self.cell_type == "GRU":
            return self.input_encoder_init.repeat(batch_size, 1)
        if self.cell_type == "LSTM":
            return tuple(
                x.repeat(batch_size, 1) for x in self.input_encoder_init)

    def output_encoder_get_init(self, input_encoder_h):
        if self.cell_type == "GRU":
            return input_encoder_h
        if self.cell_type == "LSTM":
            return (input_encoder_h,
                    self.output_encoder_init_c.repeat(input_encoder_h.size(0),
                                                      1))

    def decoder_get_init(self, output_encoder_h):
        if self.cell_type == "GRU":
            return output_encoder_h
        if self.cell_type == "LSTM":
            return (output_encoder_h,
                    self.decoder_init_c.repeat(output_encoder_h.size(0), 1))

    def cell_get_h(self, cell_state):
        if self.cell_type == "GRU":
            return cell_state
        if self.cell_type == "LSTM":
            return cell_state[0]

    def score(self, inputs, outputs, target, autograd=False):
        inputs = self.inputsToTensors(inputs)
        outputs = self.inputsToTensors(outputs)
        target = self.targetToTensor(target)
        target, score = self.run(inputs, outputs, target=target, mode="score")
        # target = self.tensorToOutput(target)
        if autograd:
            return score
        else:
            return score.data

    def sample(self, inputs, outputs):
        inputs = self.inputsToTensors(inputs)
        outputs = self.inputsToTensors(outputs)
        target, score = self.run(inputs, outputs, mode="sample")
        target = self.tensorToOutput(target)
        return target

    def sampleAndScore(self, inputs, outputs, nRepeats=None):
        inputs = self.inputsToTensors(inputs)
        outputs = self.inputsToTensors(outputs)
        if nRepeats is None:
            target, score = self.run(inputs, outputs, mode="sample")
            target = self.tensorToOutput(target)
            return target, score.data
        else:
            target = []
            score = []
            for i in range(nRepeats):
                # print("repeat %d" % i)
                t, s = self.run(inputs, outputs, mode="sample")
                t = self.tensorToOutput(t)
                target.extend(t)
                score.extend(list(s.data))
            return target, score

    def run(self, inputs, outputs, target=None, mode="sample"):
        """
        :param mode: "score" returns log p(target|input), "sample" returns target ~ p(-|input)
        :param List[LongTensor] inputs: n_examples * (max_length_input * batch_size)
        :param List[LongTensor] target: max_length_target * batch_size
        """
        assert ((mode == "score" and target is not None) or mode == "sample")

        n_examples = len(inputs)
        max_length_input = [inputs[j].size(0) for j in range(n_examples)]
        max_length_output = [outputs[j].size(0) for j in range(n_examples)]
        max_length_target = target.size(0) if target is not None else 10
        batch_size = inputs[0].size(1)

        score = Variable(torch.zeros(batch_size))
        inputs_scatter = [
            Variable(
                torch.zeros(max_length_input[j], batch_size,
                            self.v_input + 1).scatter_(2, inputs[j][:, :,
                                                                    None], 1))
            for j in range(n_examples)
        ]  # n_examples * (max_length_input * batch_size * v_input+1)
        outputs_scatter = [
            Variable(
                torch.zeros(max_length_output[j], batch_size,
                            self.v_input + 1).scatter_(2, outputs[j][:, :,
                                                                     None], 1))
            for j in range(n_examples)
        ]  # n_examples * (max_length_output * batch_size * v_input+1)
        if target is not None:
            target_scatter = Variable(
                torch.zeros(
                    max_length_target, batch_size, self.v_target + 1).scatter_(
                        2, target[:, :, None],
                        1))  # max_length_target * batch_size * v_target+1

        # -------------- Input Encoder -------------

        # n_examples * (max_length_input * batch_size * h_encoder_size)
        input_H = []
        input_embeddings = []  # h for example at INPUT_EOS
        # 0 until (and including) INPUT_EOS, then -inf
        input_attention_mask = []
        for j in range(n_examples):
            active = torch.Tensor(max_length_input[j], batch_size).byte()
            active[0, :] = 1
            state = self.input_encoder_get_init(batch_size)
            hs = []
            for i in range(max_length_input[j]):
                state = self.input_encoder_cell(inputs_scatter[j][i, :, :],
                                                state)
                if i + 1 < max_length_input[j]:
                    active[i + 1, :] = active[i, :] * \
                        (inputs[j][i, :] != self.v_input)
                h = self.cell_get_h(state)
                hs.append(h[None, :, :])
            input_H.append(torch.cat(hs, 0))
            embedding_idx = active.sum(0).long() - 1
            embedding = input_H[j].gather(
                0,
                Variable(embedding_idx[None, :, None].repeat(
                    1, 1, self.h_input_encoder_size)))[0]
            input_embeddings.append(embedding)
            input_attention_mask.append(Variable(active.float().log()))

        # -------------- Output Encoder -------------

        def input_attend(j, h_out):
            """
            'general' attention from https://arxiv.org/pdf/1508.04025.pdf
            :param j: Index of example
            :param h_out: batch_size * h_output_encoder_size
            """
            scores = self.input_A(
                input_H[j].view(max_length_input[j] * batch_size,
                                self.h_input_encoder_size),
                h_out.view(batch_size, self.h_output_encoder_size).repeat(
                    max_length_input[j],
                    1)).view(max_length_input[j],
                             batch_size) + input_attention_mask[j]
            c = (F.softmax(scores[:, :, None], dim=0) * input_H[j]).sum(0)
            return c

        # n_examples * (max_length_input * batch_size * h_encoder_size)
        output_H = []
        output_embeddings = []  # h for example at INPUT_EOS
        # 0 until (and including) INPUT_EOS, then -inf
        output_attention_mask = []
        for j in range(n_examples):
            active = torch.Tensor(max_length_output[j], batch_size).byte()
            active[0, :] = 1
            state = self.output_encoder_get_init(input_embeddings[j])
            hs = []
            h = self.cell_get_h(state)
            for i in range(max_length_output[j]):
                state = self.output_encoder_cell(
                    torch.cat(
                        [outputs_scatter[j][i, :, :],
                         input_attend(j, h)], 1), state)
                if i + 1 < max_length_output[j]:
                    active[i + 1, :] = active[i, :] * \
                        (outputs[j][i, :] != self.v_input)
                h = self.cell_get_h(state)
                hs.append(h[None, :, :])
            output_H.append(torch.cat(hs, 0))
            embedding_idx = active.sum(0).long() - 1
            embedding = output_H[j].gather(
                0,
                Variable(embedding_idx[None, :, None].repeat(
                    1, 1, self.h_output_encoder_size)))[0]
            output_embeddings.append(embedding)
            output_attention_mask.append(Variable(active.float().log()))

        # ------------------ Decoder -----------------

        def output_attend(j, h_dec):
            """
            'general' attention from https://arxiv.org/pdf/1508.04025.pdf
            :param j: Index of example
            :param h_dec: batch_size * h_decoder_size
            """
            scores = self.output_A(
                output_H[j].view(max_length_output[j] * batch_size,
                                 self.h_output_encoder_size),
                h_dec.view(batch_size, self.h_decoder_size).repeat(
                    max_length_output[j],
                    1)).view(max_length_output[j],
                             batch_size) + output_attention_mask[j]
            c = (F.softmax(scores[:, :, None], dim=0) * output_H[j]).sum(0)
            return c

        # Multi-example pooling: Figure 3, https://arxiv.org/pdf/1703.07469.pdf
        target = target if mode == "score" else torch.zeros(
            max_length_target, batch_size).long()
        decoder_states = [
            self.decoder_get_init(output_embeddings[j])
            for j in range(n_examples)
        ]  # P
        active = torch.ones(batch_size).byte()
        for i in range(max_length_target):
            FC = []
            for j in range(n_examples):
                h = self.cell_get_h(decoder_states[j])
                p_aug = torch.cat([h, output_attend(j, h)], 1)
                FC.append(F.tanh(self.W(p_aug)[None, :, :]))
            # batch_size * embedding_size
            m = torch.max(torch.cat(FC, 0), 0)[0]
            logsoftmax = F.log_softmax(self.V(m), dim=1)
            if mode == "sample":
                target[i, :] = torch.multinomial(logsoftmax.data.exp(), 1)[:,
                                                                           0]
            score = score + \
                choose(logsoftmax, target[i, :]) * Variable(active.float())
            active *= (target[i, :] != self.v_target)
            for j in range(n_examples):
                if mode == "score":
                    target_char_scatter = target_scatter[i, :, :]
                elif mode == "sample":
                    target_char_scatter = Variable(
                        torch.zeros(batch_size, self.v_target + 1).scatter_(
                            1, target[i, :, None], 1))
                decoder_states[j] = self.decoder_cell(target_char_scatter,
                                                      decoder_states[j])
        return target, score

    def inputsToTensors(self, inputss):
        """
        :param inputss: size = nBatch * nExamples
        """
        tensors = []
        for j in range(len(inputss[0])):
            inputs = [x[j] for x in inputss]
            maxlen = max(len(s) for s in inputs)
            t = torch.ones(1 if maxlen == 0 else maxlen + 1,
                           len(inputs)).long() * self.v_input
            for i in range(len(inputs)):
                s = inputs[i]
                if len(s) > 0:
                    t[:len(s), i] = torch.LongTensor(
                        [self.input_vocabulary.index(x) for x in s])
            tensors.append(t)
        return tensors

    def targetToTensor(self, targets):
        """
        :param targets:
        """
        maxlen = max(len(s) for s in targets)
        t = torch.ones(1 if maxlen == 0 else maxlen + 1,
                       len(targets)).long() * self.v_target
        for i in range(len(targets)):
            s = targets[i]
            if len(s) > 0:
                t[:len(s), i] = torch.LongTensor(
                    [self.target_vocabulary.index(x) for x in s])
        return t

    def tensorToOutput(self, tensor):
        """
        :param tensor: max_length * batch_size
        """
        out = []
        for i in range(tensor.size(1)):
            l = tensor[:, i].tolist()
            if l[0] == self.v_target:
                out.append([])
            elif self.v_target in l:
                final = tensor[:, i].tolist().index(self.v_target)
                out.append(
                    [self.target_vocabulary[x] for x in tensor[:final, i]])
            else:
                out.append([self.target_vocabulary[x] for x in tensor[:, i]])
        return out
Example 17
class LinearGroupNJ(Module):
    """Fully Connected Group Normal-Jeffrey's layer (aka Group Variational Dropout).

    References:
    [1] Kingma, Diederik P., Tim Salimans, and Max Welling. "Variational dropout and the local reparameterization trick." NIPS (2015).
    [2] Molchanov, Dmitry, Arsenii Ashukha, and Dmitry Vetrov. "Variational Dropout Sparsifies Deep Neural Networks." ICML (2017).
    [3] Louizos, Christos, Karen Ullrich, and Max Welling. "Bayesian Compression for Deep Learning." NIPS (2017).
    """

    def __init__(self, in_features, out_features, cuda=False, init_weight=None, init_bias=None, clip_var=None):

        super(LinearGroupNJ, self).__init__()
        self.cuda = cuda
        self.in_features = in_features
        self.out_features = out_features
        self.clip_var = clip_var
        self.deterministic = False  # flag is used for compressed inference
        # trainable params according to Eq.(6)
        # dropout params
        self.z_mu = Parameter(torch.Tensor(in_features))
        self.z_logvar = Parameter(torch.Tensor(in_features))  # = z_mu^2 * alpha
        # weight params
        self.weight_mu = Parameter(torch.Tensor(out_features, in_features))
        self.weight_logvar = Parameter(torch.Tensor(out_features, in_features))

        self.bias_mu = Parameter(torch.Tensor(out_features))
        self.bias_logvar = Parameter(torch.Tensor(out_features))

        # init params either random or with pretrained net
        self.reset_parameters(init_weight, init_bias)

        # activations for kl
        self.sigmoid = nn.Sigmoid()
        self.softplus = nn.Softplus()

        # numerical stability param
        self.epsilon = 1e-8

    def reset_parameters(self, init_weight, init_bias):
        # init means
        stdv = 1. / math.sqrt(self.weight_mu.size(1))

        self.z_mu.data.normal_(1, 1e-2)

        if init_weight is not None:
            self.weight_mu.data = torch.Tensor(init_weight)
        else:
            self.weight_mu.data.normal_(0, stdv)

        if init_bias is not None:
            self.bias_mu.data = torch.Tensor(init_bias)
        else:
            self.bias_mu.data.fill_(0)

        # init logvars
        self.z_logvar.data.normal_(-9, 1e-2)
        self.weight_logvar.data.normal_(-9, 1e-2)
        self.bias_logvar.data.normal_(-9, 1e-2)

    def clip_variances(self):
        if self.clip_var:
            self.weight_logvar.data.clamp_(max=math.log(self.clip_var))
            self.bias_logvar.data.clamp_(max=math.log(self.clip_var))

    def get_log_dropout_rates(self):
        log_alpha = self.z_logvar - torch.log(self.z_mu.pow(2) + self.epsilon)
        return log_alpha

    def compute_posterior_params(self):
        weight_var, z_var = self.weight_logvar.exp(), self.z_logvar.exp()
        self.post_weight_var = self.z_mu.pow(2) * weight_var + z_var * self.weight_mu.pow(2) + z_var * weight_var
        self.post_weight_mu = self.weight_mu * self.z_mu
        return self.post_weight_mu, self.post_weight_var

    def forward(self, x):
        if self.deterministic:
            assert self.training == False, "Flag deterministic is True. This should not be used in training."
            return F.linear(x, self.post_weight_mu, self.bias_mu)

        batch_size = x.size()[0]
        # compute z  
        # note that we reparametrise according to [2] Eq. (11) (not [1])
        z = reparametrize(self.z_mu.repeat(batch_size, 1), self.z_logvar.repeat(batch_size, 1), sampling=self.training)

        # apply local reparametrisation trick see [1] Eq. (6)
        # to the parametrisation given in [3] Eq. (6)
        xz = x * z
        mu_activations = F.linear(xz, self.weight_mu, self.bias_mu)
        var_activations = F.linear(xz.pow(2), self.weight_logvar.exp(), self.bias_logvar.exp())

        return reparametrize(mu_activations, var_activations.log(), sampling=self.training)

    def kl_divergence(self):
        # KL(q(z)||p(z))
        # we use the kl divergence approximation given by [2] Eq.(14)
        k1, k2, k3 = 0.63576, 1.87320, 1.48695
        log_alpha = self.get_log_dropout_rates()
        KLD = -torch.sum(k1 * self.sigmoid(k2 + k3 * log_alpha) - 0.5 * self.softplus(-log_alpha) - k1)

        # KL(q(w|z)||p(w|z))
        # we use the kl divergence given by [3] Eq.(8)
        KLD_element = -0.5 * self.weight_logvar + 0.5 * (self.weight_logvar.exp() + self.weight_mu.pow(2)) - 0.5
        KLD += torch.sum(KLD_element)

        # KL bias
        KLD_element = -0.5 * self.bias_logvar + 0.5 * (self.bias_logvar.exp() + self.bias_mu.pow(2)) - 0.5
        KLD += torch.sum(KLD_element)

        return KLD

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
               + str(self.in_features) + ' -> ' \
               + str(self.out_features) + ')'
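
The forward pass leans on an external reparametrize helper; a minimal sketch consistent with its call sites here (mu, logvar, sampling flag) under the usual reparameterization trick (an assumption, not the repo's code):

import torch

def reparametrize(mu, logvar, sampling=True):
    # w = mu + sigma * eps with eps ~ N(0, I) while training; mean at test time
    if sampling:
        std = logvar.mul(0.5).exp()
        return mu + std * torch.randn_like(std)
    return mu

mu, logvar = torch.zeros(3), torch.zeros(3)
print(reparametrize(mu, logvar).shape)  # torch.Size([3])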
Example 18
class EBP_binaryNet(nn.Module):
    def __init__(self, H, drop_prb):
        super(EBP_binaryNet, self).__init__()
        self.drop_prob = drop_prb
        self.sq2pi = 0.797884560

        self.hidden = H
        self.D_out = 10
        self.w0 = Parameter(torch.Tensor(28 * 28, self.hidden))
        stdv = 1. / math.sqrt(self.w0.data.size(1))
        self.w0.data.uniform_(-stdv, stdv)
        self.w1 = Parameter(torch.Tensor(self.hidden, self.hidden))
        stdv = 1. / math.sqrt(self.w1.data.size(1))
        self.w1.data.uniform_(-stdv, stdv)

        self.w2 = Parameter(torch.Tensor(self.hidden, self.hidden))
        stdv = 1. / math.sqrt(self.w2.data.size(1))
        self.w2.data.uniform_(-stdv, stdv)

        self.w3 = Parameter(torch.Tensor(self.hidden, self.hidden))
        stdv = 1. / math.sqrt(self.w3.data.size(1))
        self.w3.data.uniform_(-stdv, stdv)

        self.wlast = Parameter(torch.Tensor(self.hidden, self.D_out))
        stdv = 1. / math.sqrt(self.wlast.data.size(1))
        self.wlast.data.uniform_(-stdv, stdv)

        self.th0 = Parameter(torch.zeros(1,self.hidden))
        self.th1 = Parameter(torch.zeros(1,self.hidden))
        self.th2 = Parameter(torch.zeros(1,self.hidden))
        self.thlast = Parameter(torch.zeros(1,self.D_out))


    def EBP_layer(self, xbar, xcov,m,th):
        # receives neuron means and covariances, returns next-layer means and covariances
        bn = nn.BatchNorm1d(xbar.size()[1], affine=False)
        bn.cuda()
        M, H= xbar.size()
        sigma = torch.diag(torch.sum(1 - m ** 2, 1)).repeat(M, 1, 1)
        tem = sigma.clone().resize(M, H * H)
        diagsig2 = tem[:, ::(H + 1)]

        hbar = xbar.mm(m) + th.repeat(xbar.size()[0], 1)  # numerator of input to sigmoid non-linearity
        h = self.sq2pi * hbar / torch.sqrt(diagsig2)
        xbar_next = torch.tanh(bn(h))  # this is equal to 2*torch.sigmoid(2*h1)-1 - NEED THE 2 in the argument!

        # x covariance across layer 2
        xc2 = (1 - xbar_next ** 2)
        xcov_next = Variable(torch.eye(H).cuda())[None, :, :] * (1 - xbar_next[:, None, :] ** 2)

        return xbar_next, xcov_next

    def expected_loss(self, target, forward_result):
        (a2, logprobs_out) = forward_result
        return F.nll_loss(logprobs_out, target)


    def forward(self, x, target):
        m0 = 2 * torch.sigmoid(self.w0) - 1
        m1 = 2 * torch.sigmoid(self.w1) - 1
        m2 = 2 * torch.sigmoid(self.w2) - 1
        m3 = 2 * torch.sigmoid(self.w3) - 1
        mlast = 2 * torch.sigmoid(self.wlast) - 1
        sq2pi = 0.797884560
        dtype = torch.FloatTensor

        H = self.hidden
        D_out = self.D_out
        x = x.view(-1, 28 * 28)
        y = target[:, None]
        M = x.size()[0]
        M_double = M*1.0
        #x0_do = do0(x)
        x0_do = F.dropout(x,p=0.2, training=self.training)
        #diagsig1 = Variable(torch.cuda.FloatTensor(M, H))
        sigma_1 = torch.diag(torch.sum(1 - m0 ** 2, 0)).repeat(M, 1, 1)  # + sigma_1[:,:,m]
        #for m in range(M):
        #    diagsig1[m, :] = torch.diag(sigma_1[m, :, :])
        tem = sigma_1.clone().resize(M, H * H)
        diagsig1 = tem[:, ::(H + 1)]
        h1bar = x0_do.mm(m0) + self.th0.repeat(x.size()[0], 1)  # numerator of input to sigmoid non-linearity
        h1 = sq2pi * h1bar / torch.sqrt(diagsig1)  #

        bn = nn.BatchNorm1d(h1.size()[1], affine=False)
        bn.cuda()
        x1bar = torch.tanh(bn(h1))  #
        ey =  Variable(torch.eye(H).cuda())
        xcov_1 = ey[None, :, :] * ( 1 - x1bar[:, None, :] ** 2)  # diagonal of the layer covariance - ie. the var of neuron i

        '''NEW LAYER FUNCTION'''
        #x2bar, xcov_2 = self.EBP_layer(do1(x1bar), xcov_1,m1)
        x2bar, xcov_2 = self.EBP_layer(F.dropout(x1bar,p =  self.drop_prob,training=self.training), xcov_1,m1,self.th1)
        #x3bar, xcov_3 = self.EBP_layer(do2(x2bar), xcov_2,m2)
        #x4bar, xcov_4 = self.EBP_layer(do3(x2bar), xcov_2,m3)
        x4bar, xcov_4 = self.EBP_layer(F.dropout(x2bar, p =  self.drop_prob,training=self.training), xcov_2,m3, self.th1)

        hlastbar = (x4bar.mm(mlast) + self.thlast.repeat(x1bar.size()[0], 1))

        y_temp = torch.FloatTensor(M, 10)
        y_temp.zero_()

        y_onehot = Variable(y_temp.scatter_(1, y.data.cpu(), 1))
        logprobs_out = F.log_softmax(hlastbar, dim=1)
        val, ind = torch.max(hlastbar, 1)
        tem = y.type(dtype) - ind.type(dtype)[:, None]
        fraction_correct = (M_double - torch.sum((tem != 0)).type(dtype)) / M_double
        expected_loss =  self.expected_loss(target, (hlastbar, logprobs_out))
        return ((hlastbar, logprobs_out)), expected_loss, fraction_correct
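
EBP_layer's comment relies on the identity tanh(x) = 2*sigmoid(2x) - 1; a one-line numeric check:

import torch

x = torch.linspace(-4, 4, 9)
print(torch.allclose(torch.tanh(x), 2 * torch.sigmoid(2 * x) - 1))  # True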
class CompProbModel(torch.nn.Module):
    def __init__(self,
                 a_max=7.25,
                 s_max=9.25,
                 avg_ball_speed=20.0,
                 tti_sigma=0.5,
                 tti_lambda_off=1.0,
                 tti_lambda_def=1.0,
                 ppc_alpha=1.0,
                 tuning=None,
                 use_ppc=False):
        super().__init__()
        # define self.tuning
        self.tuning = tuning

        # define parameters and whether or not to optimize
        self.tti_sigma = Parameter(
            torch.tensor([tti_sigma]),
            requires_grad=(self.tuning == TuningParam.sigma)).float()
        self.tti_lambda_off = Parameter(
            torch.tensor([tti_lambda_off]),
            requires_grad=(self.tuning == TuningParam.lamb)).float()
        self.tti_lambda_def = Parameter(
            torch.tensor([tti_lambda_def]),
            requires_grad=(self.tuning == TuningParam.lamb)).float()
        self.ppc_alpha = Parameter(
            torch.tensor([ppc_alpha]),
            requires_grad=(self.tuning == TuningParam.alpha)).float()
        self.a_max = Parameter(
            torch.tensor([a_max]),
            requires_grad=(self.tuning == TuningParam.av)).float()
        self.s_max = Parameter(
            torch.tensor([s_max]),
            requires_grad=(self.tuning == TuningParam.av)).float()
        self.reax_t = Parameter(torch.tensor([0.2])).float()
        self.avg_ball_speed = Parameter(torch.tensor([avg_ball_speed]),
                                        requires_grad=False).float()
        self.g = Parameter(torch.tensor([10.72468]),
                           requires_grad=False)  # yards/s^2
        self.z_max = Parameter(torch.tensor([3.]), requires_grad=False)
        self.z_min = Parameter(torch.tensor([0.]), requires_grad=False)
        self.use_ppc = use_ppc
        self.zero_cuda = Parameter(torch.tensor([0.0], dtype=torch.float32),
                                   requires_grad=False)

        # define field grid
        self.x = torch.linspace(0.5, 119.5, 120).float()
        self.y = torch.linspace(-0.5, 53.5, 55).float()
        self.y[0] = -0.2
        self.yy, self.xx = torch.meshgrid(self.y, self.x)
        self.field_locs = Parameter(torch.flatten(torch.stack(
            (self.xx, self.yy), dim=-1),
                                                  end_dim=-2),
                                    requires_grad=False)  # (F, 2)
        self.T = Parameter(torch.linspace(0.1, 4, 40),
                           requires_grad=False)  # (T,)

        # for hist trans prob
        self.hist_x_min, self.hist_x_max = -9, 70
        self.hist_y_min, self.hist_y_max = -39, 40
        self.hist_t_min, self.hist_t_max = 10, 63
        self.T_given_Ls_df = pd.read_pickle('in/T_given_L.pkl')

    def get_hist_trans_prob(self, frame):
        B = len(frame)
        """ P(L|t) """
        ball_start = frame[:, 0, 8:10]  # (B, 2)
        ball_start_ind = torch.round(ball_start).long()
        reach_vecs = self.field_locs.unsqueeze(0) - ball_start.unsqueeze(
            1)  # (B, F, 2)
        # mask for zeroing out parts of the field that are too far to be thrown to per the L_given_t model
        L_t_mask = torch.zeros(B, *self.xx.shape)  # (B, Y, X)
        b_zeros = torch.zeros(ball_start_ind.shape[0])
        b_ones = torch.ones(ball_start_ind.shape[0])
        for bb in range(B):
            L_t_mask[bb, max(0, ball_start_ind[bb,1]+self.hist_y_min):\
                        min(len(self.y)-1, ball_start_ind[bb,1]+self.hist_y_max),\
                     max(0, ball_start_ind[bb,0]+self.hist_x_min):\
                        min(len(self.x)-1, ball_start_ind[bb,0]+self.hist_x_max)] = 1.
        L_t_mask = L_t_mask.flatten(1)  # (B, F)
        L_given_t = L_t_mask  #changed L_given_t to uniform after discussion
        # renormalize since part of L|t may have been off field
        L_given_t /= L_given_t.sum(1, keepdim=True)  # (B, F)
        """ P(T|L) """
        # we find T|L for sufficiently close spots (1 < L <= 60)
        reach_dist_int = torch.round(torch.linalg.norm(
            reach_vecs, dim=-1)).long()  # (B, F)
        reach_dist_in_bounds_idx = (reach_dist_int > 1) & (reach_dist_int <=
                                                           60)
        reach_dist_in_bounds = reach_dist_int[
            reach_dist_in_bounds_idx]  # 1d tensor
        T_given_L_subset = torch.from_numpy(self.T_given_Ls_df.set_index('pass_dist').loc[reach_dist_in_bounds, 'p'].to_numpy()).float()\
            .reshape(-1, len(self.T))  # (BF~, T) ; BF~ is subset of B*F that is in [1, 60] yds from ball
        T_given_L = torch.zeros(B * len(self.field_locs),
                                len(self.T))  # (B, F, T)
        # fill in the subset of values computed above
        T_given_L[reach_dist_in_bounds_idx.flatten()] = T_given_L_subset
        T_given_L = T_given_L.reshape(B, len(self.field_locs), -1)  # (B, F, T)

        L_T_given_t = L_given_t[..., None] * T_given_L  # (B, F, T)
        L_T_given_t /= L_T_given_t.sum(
            (1, 2), keepdim=True
        )  # normalize all passes after some have been chopped off
        return L_T_given_t  # (B, F, T)

    def get_ppc_off(self, frame, p_int):
        assert self.use_ppc, 'Call made to get_ppc_off while use_ppc setting is False'
        B = frame.shape[0]
        J = p_int.shape[-1]
        ball_start = frame[:, 0, 8:10]  # (B, 2)
        player_teams = frame[:, :, 7]  # (B, J)
        reach_vecs = self.field_locs.unsqueeze(0) - ball_start.unsqueeze(
            1)  # B, F, 2
        # trajectory integration
        dx = reach_vecs[:, :, 0]  #B, F
        dy = reach_vecs[:, :, 1]  #B, F
        vx = dx[:, :, None] / self.T[None, None, :]  #F, T
        vy = dy[:, :, None] / self.T[None, None, :]  #F, T
        vz_0 = (self.T * self.g) / 2  #T

        # note that idx (i, j, k) into below arrays is invalid when j < k
        traj_ts = self.T.repeat(len(self.field_locs), len(self.T),
                                1)  #(F, T, T)
        traj_locs_x_idx = torch.round(
            torch.clip((ball_start[:, 0, None, None, None] +
                        vx.unsqueeze(-1) * self.T), 0,
                       len(self.x) - 1)).int()  # B, F, T, T
        traj_locs_y_idx = torch.round(
            torch.clip((ball_start[:, 1, None, None, None] +
                        vy.unsqueeze(-1) * self.T), 0,
                       len(self.y) - 1)).int()  # B, F, T, T
        traj_locs_z = 2.0 + vz_0.view(
            1, -1, 1) * traj_ts - 0.5 * self.g * traj_ts * traj_ts  #F, T, T
        lambda_z = torch.where(
            (traj_locs_z < self.z_max) & (traj_locs_z > self.z_min), 1,
            0)  #F, T, T
        path_idxs = (traj_locs_y_idx * self.x.shape[0] +
                     traj_locs_x_idx).long().reshape(B, -1)  # (B, F*T*T)
        # 10*traj_ts - 1 converts the times into indices - hacky
        traj_t_idxs = (10 * traj_ts - 1).long().repeat(B, 1, 1, 1).reshape(
            B, -1)  # (B, F*T*T)
        p_int_traj = torch.stack([p_int[bb, path_idxs[bb], traj_t_idxs[bb], :] for bb in range(B)])\
                        .reshape(*traj_locs_x_idx.shape, -1) * lambda_z.unsqueeze(-1)  # B, F, T, T, J
        p_int_traj_sum = p_int_traj.sum(dim=-1, keepdim=True)  # B, F, T, T, J
        norm_factor = torch.maximum(torch.ones_like(p_int_traj_sum),
                                    p_int_traj_sum)  # B, F, T, T
        p_int_traj_norm = p_int_traj / norm_factor  # B, F, T, T, J

        # independent int probs at each point on trajectory
        all_p_int_traj = torch.sum(p_int_traj_norm, dim=-1)  # B, F, T, T
        # off_p_int_traj = torch.sum((player_teams == 1)[:,None,None,None] * p_int_traj_norm, dim=-1)  # B, F, T, T
        # def_p_int_traj = torch.sum((player_teams == 0)[:,None,None,None] * p_int_traj_norm, dim=-1)  # B, F, T, T
        ind_p_int_traj = p_int_traj_norm  #use for analyzing specific players; # B, F, T, T, J

        # calc decaying residual probs after you take away p_int on earlier times in the traj
        compl_all_p_int_traj = 1 - all_p_int_traj  # B, F, T, T
        remaining_compl_p_int_traj = torch.cumprod(compl_all_p_int_traj,
                                                   dim=-1)  # B, F, T, T
        # maximum 0 because if it goes negative the pass has been caught by then and theres no residual probability
        shift_compl_cumsum = torch.roll(remaining_compl_p_int_traj, 1,
                                        dims=-1)  # B, F, T, T
        shift_compl_cumsum[:, :, :, 0] = 1

        # multiply residual prob by p_int at that location and lambda
        lambda_all = self.tti_lambda_off * player_teams + self.tti_lambda_def * (
            1 - player_teams)  # B, J
        # off_completion_prob_dt = shift_compl_cumsum * off_p_int_traj  # B, F, T, T
        # def_completion_prob_dt = shift_compl_cumsum * def_p_int_traj  # B, F, T, T
        # all_completion_prob_dt = off_completion_prob_dt + def_completion_prob_dt  # B, F, T, T
        ind_completion_prob_dt = shift_compl_cumsum.unsqueeze(
            -1) * ind_p_int_traj  # F, T, T, J

        # now accumulate values over total traj for each team and take at T=t
        # all_completion_prob = torch.cumsum(all_completion_prob_dt, dim=-1)  # B, F, T, T
        # off_completion_prob = torch.cumsum(off_completion_prob_dt, dim=-1)  # B, F, T, T
        # def_completion_prob = torch.cumsum(def_completion_prob_dt, dim=-1)  # B, F, T, T
        ind_completion_prob = torch.cumsum(ind_completion_prob_dt,
                                           dim=-2)  # B, F, T, T, J

        # this einsum takes the diagonal values over the last two axes where T = t
        # this takes care of the t > T issue.
        # ppc_all = torch.einsum('...ii->...i', all_completion_prob)  # B, F, T
        # ppc_off = torch.einsum('...ii->...i', off_completion_prob)  # B, F, T
        # ppc_def = torch.einsum('...ii->...i', def_completion_prob)  # B, F, T
        ppc_ind = torch.einsum('...iij->...ij',
                               ind_completion_prob)  # B, F, T, J
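        # e.g. for x of (hypothetical) shape (B, F, 3, 3, J),
        # torch.einsum('...iij->...ij', x) returns y with
        # y[b, f, t] == x[b, f, t, t] for t in 0..2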
        ppc_ind *= lambda_all[:, None, None, :]
        # no_p_int_pass = 1-ppc_all  # B, F, T

        ppc_off = torch.sum(ppc_ind * player_teams[:, None, None, :],
                            dim=-1)  # B, F, T
        ppc_def = torch.sum(ppc_ind * (1 - player_teams)[:, None, None, :],
                            dim=-1)  # B, F, T

        # assert torch.allclose(all_p_int_pass, off_p_int_pass + def_p_int_pass, atol=0.01)
        # assert torch.allclose(all_p_int_pass, ind_p_int_pass.sum(-1), atol=0.01)
        # return off_p_int_pass, def_p_int_pass, ind_p_int_pass
        return ppc_off, ppc_def, ppc_ind

    def forward(self, frame):
        v_x_r = frame[:, :, 5] * self.reax_t + frame[:, :, 3]
        v_y_r = frame[:, :, 6] * self.reax_t + frame[:, :, 4]
        v_r_mag = torch.norm(torch.stack([v_x_r, v_y_r], dim=-1), dim=-1)
        v_r_theta = torch.atan2(v_y_r, v_x_r)

        x_r = (frame[:, :, 1] + frame[:, :, 3] * self.reax_t +
               0.5 * frame[:, :, 5] * self.reax_t**2)
        y_r = (frame[:, :, 2] + frame[:, :, 4] * self.reax_t +
               0.5 * frame[:, :, 6] * self.reax_t**2)

        # get each player's team, location, and velocity
        player_teams = frame[:, :, 7]  # B, J
        reaction_player_locs = torch.stack([x_r, y_r], dim=-1)  # (B, J, 2)
        reaction_player_vels = torch.stack([v_x_r, v_y_r], dim=-1)  # (B, J, 2)

        # calculate each player's distance from each field location
        int_d_vec = self.field_locs.unsqueeze(1).unsqueeze(
            0) - reaction_player_locs.unsqueeze(1)  #F, J, 2
        int_d_mag = torch.norm(int_d_vec, dim=-1)  # F, J
        int_d_theta = torch.atan2(int_d_vec[..., 1], int_d_vec[..., 0])  # F, J

        # take dot product of velocity and direction
        int_s0 = torch.clamp(
            torch.sum(int_d_vec * reaction_player_vels.unsqueeze(1), dim=-1) /
            int_d_mag, -1 * self.s_max.item(), self.s_max.item())  #F, J

        # calculate time it takes for each player to reach each field position accounting for their current velocity and acceleration
        t_lt_smax = (self.s_max - int_s0) / self.a_max  #F, J,
        d_lt_smax = t_lt_smax * ((int_s0 + self.s_max) / 2)  #F, J,

        # if accelerating would overshoot, then t = -v0/a + sqrt(v0^2/a^2 + 2x/a) (from kinematics)
        t_lt_smax = torch.where(d_lt_smax > int_d_mag, -int_s0 / self.a_max + \
                torch.sqrt((int_s0 / self.a_max) ** 2 + 2 * int_d_mag / self.a_max), t_lt_smax) # F, J
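        # derivation: solving d = v0*t + 0.5*a*t**2 for t and keeping the
        # positive root gives t = -v0/a + sqrt((v0/a)**2 + 2*d/a)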
        d_lt_smax = torch.max(torch.min(d_lt_smax, int_d_mag),
                              torch.zeros_like(d_lt_smax))  # F, J

        d_at_smax = int_d_mag - d_lt_smax  #F, J,
        t_at_smax = d_at_smax / self.s_max  #F, J,
        t_tot = self.reax_t + t_lt_smax + t_at_smax  # F, J,

        # get true pass (tof and ball_end) to tune on (subtract 1 from tof, add 1 to y for correct indexing)
        tof = torch.round(frame[:, 0, -1]).long().view(-1, 1, 1, 1).repeat(
            1, t_tot.size(1), 1, t_tot.size(-1)) - 1

        # ball ind
        ball_end_x = frame[:, 0, -3].int()
        ball_end_y = frame[:, 0, -2].int() + 1
        ball_field_ind = (ball_end_y * self.x.shape[0] +
                          ball_end_x).long().view(-1, 1, 1).repeat(
                              1, 1, t_tot.size(-1))

        if self.tuning == TuningParam.av:
            # collapse extra dims
            tof = self.T[tof[:, 0, 0, 0]].float()

            # select field in for all the position and velocity values calculated previously
            t_lt_smax = torch.gather(t_lt_smax, 1,
                                     ball_field_ind).squeeze()  # J,
            d_lt_smax = torch.gather(d_lt_smax, 1,
                                     ball_field_ind).squeeze()  # J,
            d_at_smax = torch.gather(d_at_smax, 1, ball_field_ind).squeeze()
            t_at_smax = torch.gather(t_at_smax, 1, ball_field_ind).squeeze()
            t_tot = torch.gather(t_tot, 1, ball_field_ind).squeeze()
            int_s0 = torch.gather(int_s0, 1, ball_field_ind).squeeze()

            int_d_theta = torch.gather(int_d_theta, 1,
                                       ball_field_ind).squeeze()
            int_d_mag = torch.gather(int_d_mag, 1, ball_field_ind).squeeze()

            # projected locations at t = tof, f = ball_field_ind
            d_proj = torch.where(tof.unsqueeze(-1) <= self.reax_t, self.zero_cuda,
                    torch.where(tof.unsqueeze(-1) <= (t_lt_smax + self.reax_t),
                    (int_s0 * (tof.unsqueeze(-1) - self.reax_t)) + 0.5 * self.a_max \
                            * (tof.unsqueeze(-1) - self.reax_t) ** 2,
                    torch.where(tof.unsqueeze(-1) <= (t_lt_smax + t_at_smax + self.reax_t),
                    (d_lt_smax + (d_at_smax * (tof.unsqueeze(-1) - t_lt_smax - self.reax_t))),
                    int_d_mag))) # J,

            d_proj = torch.minimum(d_proj, int_d_mag)

            x_proj = reaction_player_locs[..., 0] + d_proj * torch.cos(
                int_d_theta)  # J
            y_proj = reaction_player_locs[..., 1] + d_proj * torch.sin(
                int_d_theta)  # J

            # mask x_proj and y_proj (only want loss on closest off and def players)
            player_mask = frame[:, :, -4]
            masked_x = player_mask * x_proj
            masked_y = player_mask * y_proj

            return torch.stack([masked_x, masked_y], dim=-1)  # J, 2

        # subtract the arrival time (t_tot) from time of flight of ball
        int_dT = self.T.view(1, 1, -1, 1) - t_tot.unsqueeze(2)  #F, T, J

        # calculate interception probability for each player, field loc, time of flight (logistic function)
        p_int = torch.sigmoid(
            (3.14 / (1.732 * self.tti_sigma)) * int_dT)  # (B, F, T, J)
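        # 3.14 ~ pi and 1.732 ~ sqrt(3): a logistic CDF with slope
        # pi/(sqrt(3)*sigma) has standard deviation sigma = tti_sigma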

        if self.tuning == TuningParam.sigma:
            p_int = torch.gather(p_int, 2, tof).squeeze()
            p_int = torch.gather(p_int, 1, ball_field_ind).squeeze()
            return p_int

        elif self.tuning == TuningParam.alpha:
            h_trans_prob = self.get_hist_trans_prob(frame)  # (B, F, T)
            if self.use_ppc:
                ppc_off, *_ = self.get_ppc_off(frame, p_int)
                trans_prob = h_trans_prob * torch.pow(
                    ppc_off, self.ppc_alpha)  # (B, F, T)
            else:
                # p_int summed over all offensive players
                p_int_off = torch.sum(p_int * (player_teams == 1),
                                      dim=-1)  # (B, F, T)
                trans_prob = h_trans_prob * torch.pow(p_int_off,
                                                      self.ppc_alpha)  # (B,)
            trans_prob /= trans_prob.sum(dim=(1, 2), keepdim=True)  # (B, F, T)
            # index into true pass. [...,0] necessary on indices because no J dimension
            trans_prob_throw = torch.gather(trans_prob, 2, tof[...,
                                                               0]).squeeze()
            trans_prob_throw = torch.gather(
                trans_prob_throw, 1, ball_field_ind[..., 0]).squeeze()  # (B,)
            return trans_prob_throw

        elif self.tuning == TuningParam.lamb:
            assert self.use_ppc, 'need to use ppc to tune lambda'
            *_, ppc_ind = self.get_ppc_off(frame,
                                           p_int)  # ppc_ind: (B, F, T, J)
            ppc_ind_throw = torch.gather(ppc_ind, 2, tof).squeeze()  # B, F, J
            ppc_ind_throw = torch.gather(ppc_ind_throw, 1,
                                         ball_field_ind).squeeze()  # B, J
            return ppc_ind_throw
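
# A self-contained sketch (hypothetical sizes, not part of the example above)
# of the torch.gather pattern the tuning branches use to pick one field index
# per batch element out of a (B, F, T) tensor:
import torch

B, F_, T = 2, 5, 3
trans_prob = torch.rand(B, F_, T)
ball_field_ind = torch.tensor([1, 4]).view(-1, 1, 1).repeat(1, 1, T)  # (B, 1, T)
picked = torch.gather(trans_prob, 1, ball_field_ind).squeeze(1)  # (B, T)
assert torch.equal(picked[0], trans_prob[0, 1])
assert torch.equal(picked[1], trans_prob[1, 4])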
Example 20
class MyTransformer(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.backbone = MyResnet(args)
        pthfile = args.pretrained_dir
        state_dict = torch.load(pthfile)
        state_dict.pop('fc.weight')
        state_dict.pop('fc.bias')
        self.backbone.load_state_dict(state_dict, False)

        self.transformer = build_transformer(args)
        self.class_embed = nn.Linear(args.hidden_dim, args.class_number)
        # self.position_embedding = build_position_encoding(args.d_model, args.src_len, args.pos_emb_type)
        self.position_embedding = BartLearnedPositionalEmbedding(args.max_position_embedding, args.d_model)
        self.Embedding = nn.Embedding(args.class_number, args.d_model)
        self.d_model = args.d_model
        self.src_len = args.src_len
        self.hidden_dim = args.hidden_dim
        self.start = Parameter(torch.rand(self.hidden_dim), requires_grad=False)

    def pack(self, embedding, target, target_embedding):
        bs, spb, hd = embedding.size()
        start = self.start.repeat(bs, 1, 1).to(target.device)  # bs x 1 x d_model
        target_key_embedding = torch.cat((start, target_embedding), dim=1)  # bs x src_len x d_model
        # pos_embed = self.position_embedding(embedding)
        # target_pos_embed = self.position_embedding(target_key_embedding)
        pos_embed = self.position_embedding(embedding.size()[:2]).unsqueeze(0).repeat(bs, 1, 1)
        target_pos_embed = self.position_embedding(target_key_embedding.size()[:2]).unsqueeze(0).repeat(bs, 1, 1)
        return embedding, target_embedding, target_key_embedding, pos_embed, target_pos_embed

    def permute(self, embedding, target_embedding, target_key_embedding, pos_embed, target_pos_embed):
        embedding = embedding.permute(1, 0, 2)
        target_embedding = target_embedding.permute(1, 0, 2)
        target_key_embedding = target_key_embedding.permute(1, 0, 2)
        pos_embed = pos_embed.permute(1, 0, 2)
        target_pos_embed = target_pos_embed.permute(1, 0, 2)
        return embedding, target_embedding, target_key_embedding, pos_embed, target_pos_embed

    def forward(self, x: Tensor, target, test_mode_target=None):
        """
        :param x: bs x src_len x C x H x W
        :param target: bs x src_len
        :param test_mode_target: batch_size x src_len
        :return:
        """
        embedding = self.backbone(x)  # bs x src_len x d_model
        bs, src_len, d_model = embedding.size()
        target = target[:, :-1]  # shift the target right by one position

        if isinstance(test_mode_target, torch.Tensor):
            target_embedding = torch.zeros((bs, src_len - 1, d_model)).to(embedding.device)
            for i in range(test_mode_target.size(0)):
                for j in range(test_mode_target.size(1)):
                    target_embedding[i, j, :] = get_onehot_label(test_mode_target[i, j], self.d_model)
        else:
            target_embedding = get_onehot_label(target, self.d_model)  # bs x src-1 x d_model
        embedding, target_embedding, target_key_embedding, pos_embed, target_pos_embed = self.pack(embedding, target,
                                                                                                   target_embedding)
        embedding, target_embedding, target_key_embedding, pos_embed, target_pos_embed = self.permute(embedding,
                                                                                                      target_embedding,
                                                                                                      target_key_embedding,
                                                                                                      pos_embed,
                                                                                                      target_pos_embed)
        spb, bs, hidden_dim = embedding.size()
        tgt_mask = torch.triu(torch.ones(spb, spb).to(embedding.device).to(torch.bool), 1)
        src_mask = torch.triu(torch.ones(spb, spb).to(embedding.device).to(torch.bool), 1)
        memory_mask = torch.triu(torch.ones(spb, spb).to(embedding.device).to(torch.bool), 1)

        output = self.transformer(encoder_src=embedding,
                                  decoder_src=target_key_embedding,
                                  decoder_key=target_key_embedding,
                                  pos_embed=pos_embed,
                                  target_pos_embed=target_pos_embed,
                                  target_key_pos_embed=target_pos_embed,
                                  tgt_mask=tgt_mask,
                                  src_mask=src_mask,
                                  memory_mask=memory_mask)  # decoder_number, *,*,*
        # output is (decoder_number, bs, src_len, d_model); take the last decoder's output
        output = output[-1]
        # output = output.permute(1, 0, 2)
        outputs_class = self.class_embed(output)
        return outputs_class
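
# A minimal, self-contained check (hypothetical size spb=4) of the causal
# masks built in MyTransformer.forward: torch.triu(..., 1) marks the strictly
# upper triangle True, i.e. the future positions attention must not see.
import torch

spb = 4
tgt_mask = torch.triu(torch.ones(spb, spb, dtype=torch.bool), 1)
# tensor([[False,  True,  True,  True],
#         [False, False,  True,  True],
#         [False, False, False,  True],
#         [False, False, False, False]])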
Example 21
class Image_RobustFill(nn.Module):
    def __init__(self,
                 target_vocabulary,
                 hidden_size=512,
                 embedding_size=128,
                 cell_type="LSTM",
                 input_size=(3, 256, 256)):
        """
        :param: input_vocabularies: List containing a vocabulary list for each input. E.g. if learning a function f:A->B from (a,b) pairs, input_vocabularies has length 2
        :param: target_vocabulary: Vocabulary list for output
        """
        super(Image_RobustFill, self).__init__()
        self.n_encoders = 1

        self.hidden_size = hidden_size
        self.embedding_size = embedding_size
        self.input_vocabularies = [None]  #input_vocabularies
        self.target_vocabulary = target_vocabulary
        self._refreshVocabularyIndex()
        self.v_inputs = None  #[len(x) for x in input_vocabularies] # Number of tokens in input vocabularies
        self.v_target = len(
            target_vocabulary)  # Number of tokens in target vocabulary

        self.no_inputs = len(self.input_vocabularies) == 0

        self.cell_type = cell_type
        if cell_type == 'GRU':
            self.encoder_init_h = Parameter(torch.rand(1, self.hidden_size))
            # self.encoder_cells = nn.ModuleList(
            #     [nn.GRUCell(input_size=self.v_inputs[0]+1, hidden_size=self.hidden_size, bias=True)] +
            #     [nn.GRUCell(input_size=self.v_inputs[i]+1+self.hidden_size, hidden_size=self.hidden_size, bias=True) for i in range(1, self.n_encoders)]
            # )
            self.decoder_cell = nn.GRUCell(input_size=self.v_target + 1,
                                           hidden_size=self.hidden_size,
                                           bias=True)
        if cell_type == 'LSTM':
            self.encoder_init_h = Parameter(torch.rand(
                1, self.hidden_size))  # also used for decoder if self.no_inputs=True
            # self.encoder_init_cs = nn.ParameterList(
            #     [Parameter(torch.rand(1, self.hidden_size)) for i in range(len(self.v_inputs))]
            # )

            # self.encoder_cells = nn.ModuleList()
            # for i in range(self.n_encoders):
            #     input_size = self.v_inputs[i] + 1 + (self.hidden_size if i>0 else 0)
            #     self.encoder_cells.append(nn.LSTMCell(input_size=input_size, hidden_size=self.hidden_size, bias=True))
            self.decoder_cell = nn.LSTMCell(input_size=self.v_target + 1,
                                            hidden_size=self.hidden_size,
                                            bias=True)
            self.decoder_init_c = Parameter(torch.rand(1, self.hidden_size))

        self.W = nn.Linear(
            self.hidden_size if self.no_inputs else 2 * self.hidden_size,
            self.embedding_size)
        self.V = nn.Linear(self.embedding_size, self.v_target + 1)

        #self.As = nn.ModuleList([nn.Bilinear(self.hidden_size, self.hidden_size, 1, bias=False) for i in range(self.n_encoders)])

        #image encoder:
        self.conv1 = nn.Conv2d(3,
                               8,
                               kernel_size=(3, 3),
                               padding=(3, 3),
                               stride=(1, 1))
        self.conv2 = nn.Conv2d(8,
                               16,
                               kernel_size=(3, 3),
                               padding=(1, 1),
                               stride=(1, 1))
        self.conv3 = nn.Conv2d(16,
                               16,
                               kernel_size=(3, 3),
                               padding=(1, 1),
                               stride=(1, 1))
        self.conv4 = nn.Conv2d(16,
                               16,
                               kernel_size=(3, 3),
                               padding=(1, 1),
                               stride=(1, 1))
        #self.conv4 = nn.Conv2d(256, 512, kernel_size=(3, 3),
        #                        padding=(1, 1), stride=(1, 1))
        self.batch_norm1 = nn.BatchNorm2d(8)
        self.batch_norm2 = nn.BatchNorm2d(16)

        self.img_feat_to_embedding = nn.Sequential(
            nn.Linear(16 * 16 * 16, 64), nn.ReLU(), nn.Linear(64, 64),
            nn.ReLU(), nn.Linear(64, self.hidden_size))

        #attention params:
        self.h_to_32_linear = nn.Linear(self.hidden_size, 32)
        self.img_to_32 = nn.Linear(16 * 16 * 16, 32)

        self.fc_loc = nn.Linear(32 + 32, 3 * 2)
        self.fc_loc.weight.data.zero_()
        self.fc_loc.bias.data.copy_(
            torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))
        self.img_feat_to_context = nn.Sequential(
            nn.Linear(16 * 16 * 16, 128), nn.ReLU(), nn.Linear(128, 128),
            nn.ReLU(), nn.Linear(128, self.hidden_size))

    def with_target_vocabulary(self, target_vocabulary):
        """
        Returns a new network which modifies this one by changing the target vocabulary
        """
        if target_vocabulary == self.target_vocabulary:
            return self

        V_weight = []
        V_bias = []
        decoder_ih = []

        for i in range(len(target_vocabulary)):
            if target_vocabulary[i] in self.target_vocabulary:
                j = self.target_vocabulary.index(target_vocabulary[i])
                V_weight.append(self.V.weight.data[j:j + 1])
                V_bias.append(self.V.bias.data[j:j + 1])
                decoder_ih.append(self.decoder_cell.weight_ih.data[:, j:j + 1])
            else:
                V_weight.append(self._zeros(1, self.V.weight.size(1)))
                V_bias.append(self._ones(1) * -10)
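                # a -10 bias makes unseen tokens start out near-zero probability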
                decoder_ih.append(
                    self._zeros(self.decoder_cell.weight_ih.data.size(0), 1))

        V_weight.append(self.V.weight.data[-1:])
        V_bias.append(self.V.bias.data[-1:])
        decoder_ih.append(self.decoder_cell.weight_ih.data[:, -1:])

        self.target_vocabulary = target_vocabulary
        self.v_target = len(target_vocabulary)

        self.V.weight.data = torch.cat(V_weight, dim=0)
        self.V.bias.data = torch.cat(V_bias, dim=0)
        self.V.out_features = self.V.bias.data.size(0)

        self.decoder_cell.weight_ih.data = torch.cat(decoder_ih, dim=1)
        self.decoder_cell.input_size = self.decoder_cell.weight_ih.data.size(1)

        self._clear_optimiser()
        self._refreshVocabularyIndex()
        return copy.deepcopy(self)

    def optimiser_step(self, batch_inputs, batch_target):
        """
        Perform a single step of SGD
        """
        if not hasattr(self, 'opt'): self._get_optimiser()
        self.opt.zero_grad()
        score = self.score(batch_inputs, batch_target, autograd=True).mean()
        (-score).backward()
        self.opt.step()

        return score.data.item()

    def score(self, batch_inputs, batch_target, autograd=False):
        #inputs = self._inputsToTensors(batch_inputs)
        inputs = torch.stack(tuple(torch.tensor(b) for b in batch_inputs),
                             dim=0).float()
        if next(self.parameters()).is_cuda:
            inputs = inputs.cuda()
        inputs = [[inputs]]
        #print("INPUTS SHAPE", inputs[0][0].shape)
        target = self._targetToTensor(batch_target)
        _, score = self._run(inputs, target=target, mode="score")
        if autograd:
            return score
        else:
            return score.data

    def sample(self, batch_inputs=None, n_samples=None):
        assert batch_inputs is not None or n_samples is not None
        #inputs = self._inputsToTensors(batch_inputs)
        inputs = torch.stack(tuple(torch.tensor(b) for b in batch_inputs),
                             dim=0).float()
        if next(self.parameters()).is_cuda:
            inputs = inputs.cuda()
        inputs = [[inputs]]

        target, score = self._run(inputs, mode="sample", n_samples=n_samples)
        target = self._tensorToOutput(target)
        return target

    def sampleAndScore(self, batch_inputs=None, n_samples=None, nRepeats=None):
        assert batch_inputs is not None or n_samples is not None
        #inputs = self._inputsToTensors(batch_inputs)
        inputs = [[batch_inputs]]
        if nRepeats is None:
            target, score = self._run(inputs,
                                      mode="sample",
                                      n_samples=n_samples)
            target = self._tensorToOutput(target)
            return target, score.data
        else:
            target = []
            score = []
            for i in range(nRepeats):
                t, s = self._run(inputs, mode="sample", n_samples=n_samples)
                t = self._tensorToOutput(t)
                target.extend(t)
                score.extend(list(s.data))
            return target, score

    def _refreshVocabularyIndex(self):
        # self.input_vocabularies_index = [
        #     {self.input_vocabularies[i][j]: j for j in range(len(self.input_vocabularies[i]))}
        #     for i in range(len(self.input_vocabularies))
        # ]
        self.target_vocabulary_index = {
            self.target_vocabulary[j]: j
            for j in range(len(self.target_vocabulary))
        }

    def __getstate__(self):
        if hasattr(self, 'opt'):
            return dict([(k, v)
                         for k, v in self.__dict__.items() if k != 'opt'] +
                        [('optstate', self.opt.state_dict())])
        else:
            return self.__dict__

    def __setstate__(self, state):
        self.__dict__.update(state)
        if hasattr(self, 'optstate'): self._fix_optstate()

    def _ones(self, *args, **kwargs):
        if next(self.parameters()).is_cuda:
            return torch.ones(*args, **kwargs).cuda()
        else:
            return torch.ones(*args, **kwargs)

    def _zeros(self, *args, **kwargs):
        if next(self.parameters()).is_cuda:
            return torch.zeros(*args, **kwargs).cuda()
        else:
            return torch.zeros(*args, **kwargs)

    def _clear_optimiser(self):
        if hasattr(self, 'opt'): del self.opt
        if hasattr(self, 'optstate'): del self.optstate

    def _get_optimiser(self):
        self.opt = torch.optim.Adam(self.parameters(), lr=0.001)
        if hasattr(self, 'optstate'): self.opt.load_state_dict(self.optstate)

    def _fix_optstate(self):
        # ensure optstate tensors live on the same device (CPU/CUDA) as the params
        is_cuda = next(self.parameters()).is_cuda
        for state in self.optstate['state'].values():
            for k, v in state.items():
                if torch.is_tensor(v):
                    state[k] = v.cuda() if is_cuda else v.cpu()

    def cuda(self, *args, **kwargs):
        if hasattr(self, 'opt'): del self.opt
        if hasattr(self, 'optstate'): self._fix_optstate()
        super(Image_RobustFill, self).cuda(*args, **kwargs)

    def cpu(self, *args, **kwargs):
        if hasattr(self, 'opt'): del self.opt
        if hasattr(self, 'optstate'): self._fix_optstate()
        super(Image_RobustFill, self).cpu(*args, **kwargs)

    def _encoder_get_init(self, encoder_idx, h=None, batch_size=None):
        if h is None: h = self.encoder_init_h.repeat(batch_size, 1)
        if self.cell_type == "GRU": return h
        if self.cell_type == "LSTM":
            return (h, self.encoder_init_cs[encoder_idx].repeat(batch_size, 1))

    def _decoder_get_init(self, h=None, batch_size=None):
        if h is None:
            assert self.no_inputs
            h = self.encoder_init_h.repeat(batch_size, 1)

        if self.cell_type == "GRU": return h
        if self.cell_type == "LSTM":
            return (h, self.decoder_init_c.repeat(h.size(0), 1))

    def _cell_get_h(self, cell_state):
        if self.cell_type == "GRU": return cell_state
        if self.cell_type == "LSTM": return cell_state[0]

    def _run(self, inputs, target=None, mode="sample", n_samples=None):
        """
        :param mode: "score" or "sample"
        :param list[list[LongTensor]] inputs: n_encoders * n_examples * (max length * batch_size) - change last part to batch_size * 1 x 28 x 28 or whatever it is
        :param list[LongTensor] target: max length * batch_size
        Returns output and score
        """
        assert ((mode == "score" and target is not None) or mode == "sample")

        input_width = inputs[0][0].shape[-1]

        if self.no_inputs:
            batch_size = target.size(1) if mode == "score" else n_samples
        else:
            batch_size = inputs[0][0].size(0)  # will reformulate this
            n_examples = len(inputs[0])

            #max_length_inputs = [[inputs[i][j].size(0) for j in range(n_examples)] for i in range(self.n_encoders)]
            max_length_inputs = [[3 * 3 for j in range(n_examples)]
                                 for i in range(self.n_encoders)]  #TODO

            # inputs_scatter = [
            #     [   Variable(self._zeros(max_length_inputs[i][j], batch_size, self.v_inputs[i]+1).scatter_(2, inputs[i][j][:, :, None], 1))
            #         for j in range(n_examples)
            #     ] for i in range(self.n_encoders)
            # ]  # n_encoders * n_examples * (max_length_input * batch_size * v_input+1)

        max_length_target = target.size(
            0) if target is not None else 50  #CHANGED
        score = Variable(self._zeros(batch_size))
        if target is not None:
            target_scatter = Variable(
                self._zeros(
                    max_length_target, batch_size, self.v_target + 1).scatter_(
                        2, target[:, :, None],
                        1))  # max_length_target * batch_size * v_target+1

        H = [
        ]  # n_encoders * n_examples * (max_length_input * batch_size * h_encoder_size)
        embeddings = []  # n_encoders * (h for example at INPUT_EOS)

        #attention_mask = [] # n_encoders * (0 until (and including) INPUT_EOS, then -inf)

        # def attend(i, j, h):
        #     """
        #     'general' attention from https://arxiv.org/pdf/1508.04025.pdf
        #     :param i: which encoder is doing the attending (or self.n_encoders for the decoder)
        #     :param j: Index of example
        #     :param h: batch_size * hidden_size
        #     """
        #     assert(i != 0)
        #     scores = self.As[i-1](
        #         H[i-1][j].view(max_length_inputs[i-1][j] * batch_size, self.hidden_size),
        #         h.view(batch_size, self.hidden_size).repeat(max_length_inputs[i-1][j], 1)
        #     ).view(max_length_inputs[i-1][j], batch_size)
        #     c = (F.softmax(scores[:, :, None], dim=0) * H[i-1][j]).sum(0)
        #     return c

        def attend(i, j, h):
            """
            spatial transformer attn.
            H[i-1][j] should be the image itself 
            """
            assert (i != 0)
            img = H[i - 1][j]
            linear_img = img.view(-1, img.size(1) * img.size(2) * img.size(3))
            theta = torch.cat((F.relu(
                self.h_to_32_linear(h)), F.relu(self.img_to_32(linear_img))),
                              1)  # (b, 64): decoder state and image features
            theta = self.fc_loc(theta)

            # build a 2x3 affine transform from theta and resample the image

            theta = theta.view(-1, 2, 3)
            grid = F.affine_grid(theta, img.size())
            transformed_img = F.grid_sample(img, grid)
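            # grid is (b, H, W, 2) sampling coordinates in [-1, 1];
            # grid_sample bilinearly resamples img at those locations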

            linear_transformed_img = transformed_img.view(
                -1,
                transformed_img.size(1) * transformed_img.size(2) *
                transformed_img.size(3))

            c = self.img_feat_to_context(
                linear_transformed_img
            )  # we will do b x 16 x 16 x 16 to 64 to 512
            return c

        # -------------- Image Encoders -------------
        #assume one input image:
        ii = 0

        for j in range(n_examples):
            assert j == 0
            _H = []
            _embeddings = []
            num_attention = 32

            x = inputs[ii][j]

            out = F.relu(self.batch_norm1(self.conv1(x)))
            out = F.max_pool2d(out, 2)  #b x 8 x 16 x 16

            out = F.relu(self.conv2(out))  #b x 16 x 16 x 16
            out = F.max_pool2d(out, 2)
            out = F.relu(self.conv3(out))
            if input_width == 256: out = F.max_pool2d(out, 2)
            out = F.relu(self.batch_norm2(self.conv4(out)))  #b x 16 x 16 x 16
            if input_width == 256: out = F.max_pool2d(out, 2)

            #out = F.max_pool2d(out, 2) #b x 128 x 7 x 7
            #out = F.max_pool2d(out,2) #b x 256 x 3 x 3 or 4 x 4
            #out = F.relu(self.batch_norm2(self.conv4(out))) #b x 512 x 3 x 3 or 4 x 4
            #todo: make embedding (b x 512) from b x 16 x 16 x 16

            #out = out.view(out.size(0), self.hidden_size, -1) #b x 512 x (16x16)

            #out = F.relu(self.fc1(out))
            #out = F.relu(self.fc2(out))
            #out = F.relu(self.fc3(out))

            #out = out.permute(2,0,1).contiguous()
            #out = torch.zeros(out.size()).cuda()

            _H.append(out)

            lin_out = out.view(-1, out.size(1) * out.size(2) * out.size(3))

            embedding = self.img_feat_to_embedding(
                lin_out)  # flatten conv features into the per-example embedding
            _embeddings.append(embedding)

        H.append(_H)
        embeddings.append(_embeddings)

        # -------------- Encoders -------------
        # for i in range(len(self.input_vocabularies)):
        #     H.append([])
        #     embeddings.append([])
        #     attention_mask.append([])

        #     for j in range(n_examples):
        #         active = self._ones(max_length_inputs[i][j], batch_size).byte()
        #         state = self._encoder_get_init(i, batch_size=batch_size, h=embeddings[i-1][j] if i>0 else None)
        #         hs = []
        #         h = self._cell_get_h(state)
        #         for k in range(max_length_inputs[i][j]):
        #             if i==0:
        #                 state = self.encoder_cells[i](inputs_scatter[i][j][k, :, :], state)
        #             else:
        #                 state = self.encoder_cells[i](torch.cat([inputs_scatter[i][j][k, :, :], attend(i, j, h)], 1), state)
        #             if k+1 < max_length_inputs[i][j]: active[k+1, :] = active[k, :] * (inputs[i][j][k, :] != self.v_inputs[i])
        #             h = self._cell_get_h(state)
        #             hs.append(h[None, :, :])
        #         H[i].append(torch.cat(hs, 0))
        #         embedding_idx = active.sum(0).long() - 1
        #         embedding = H[i][j].gather(0, Variable(embedding_idx[None, :, None].repeat(1, 1, self.hidden_size)))[0]
        #         embeddings[i].append(embedding)
        #         #embedding.size() == batchsize x hidden_size
        #         attention_mask[i].append(Variable(active.float().log()))

        # ------------------ Decoder -----------------
        # Multi-example pooling: Figure 3, https://arxiv.org/pdf/1703.07469.pdf
        target = target if mode == "score" else self._zeros(
            max_length_target, batch_size).long()
        if self.no_inputs:
            decoder_states = [self._decoder_get_init(batch_size=batch_size)]
        else:
            decoder_states = [
                self._decoder_get_init(embeddings[self.n_encoders - 1][j])
                for j in range(n_examples)
            ]  #P
        active = self._ones(batch_size).byte()
        for k in range(max_length_target):
            FC = []
            for j in range(1 if self.no_inputs else n_examples):
                h = self._cell_get_h(decoder_states[j])
                p_aug = h if self.no_inputs else torch.cat(
                    [h, attend(self.n_encoders, j, h)], 1)
                FC.append(torch.tanh(self.W(p_aug)[None, :, :]))
            m = torch.max(torch.cat(FC, 0),
                          0)[0]  # batch_size * embedding_size
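            # element-wise max across per-example decoder states: the
            # multi-example pooling from Fig. 3 of the paper cited above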
            logsoftmax = F.log_softmax(self.V(m), dim=1)
            if mode == "sample":
                target[k, :] = torch.multinomial(logsoftmax.data.exp(), 1)[:, 0]
            score = score + choose(logsoftmax, target[k, :]) * Variable(
                active.float())
            active *= (target[k, :] != self.v_target)
            for j in range(1 if self.no_inputs else n_examples):
                if mode == "score":
                    target_char_scatter = target_scatter[k, :, :]
                elif mode == "sample":
                    target_char_scatter = Variable(
                        self._zeros(batch_size, self.v_target + 1).scatter_(
                            1, target[k, :, None], 1))
                decoder_states[j] = self.decoder_cell(target_char_scatter,
                                                      decoder_states[j])
        return target, score

    def _inputsToTensors(self, inputsss):
        """
        :param inputs: size = nBatch * nExamples * nEncoders (or nBatch*nExamples is n_encoders=1)
        Returns nEncoders * nExamples tensors of size nBatch * max_len
        """
        #print("WARNING: you have hit a depricated function, _inputsToTensors")

        # if self.n_encoders == 0: return []
        # tensors = []
        # for i in range(self.n_encoders):
        #     tensors.append([])
        #     for j in range(len(inputsss[0])):
        #         if self.n_encoders == 1: inputs = [x[j] for x in inputsss]
        #         else: inputs = [x[j][i] for x in inputsss]

        #         maxlen = max(len(s) for s in inputs)
        #         t = self._ones(maxlen+1, len(inputs)).long()*self.v_inputs[i]
        #         for k in range(len(inputs)):
        #             s = inputs[k]
        #             if len(s)>0: t[:len(s), k] = torch.LongTensor([self.input_vocabularies_index[i][x] for x in s])
        #         tensors[i].append(t)

        #assert inputsss.shape == bx500
        return [inputsss]

    def _targetToTensor(self, targets):
        """
        :param targets: 
        """
        maxlen = max(len(s) for s in targets)
        t = self._ones(maxlen + 1, len(targets)).long() * self.v_target
        for i in range(len(targets)):
            s = targets[i]
            if len(s) > 0:
                t[:len(s), i] = torch.LongTensor(
                    [self.target_vocabulary_index[x] for x in s])
        return t

    def _tensorToOutput(self, tensor):
        """
        :param tensor: max_length * batch_size
        """
        out = []
        for i in range(tensor.size(1)):
            l = tensor[:, i].tolist()
            if l[0] == self.v_target:
                out.append(tuple())
            elif self.v_target in l:
                final = tensor[:, i].tolist().index(self.v_target)
                out.append(
                    tuple(self.target_vocabulary[x]
                          for x in tensor[:final, i]))
            else:
                out.append(
                    tuple(self.target_vocabulary[x] for x in tensor[:, i]))
        return out
class LinearGroupNJ(BayesianLayers):
    """Fully Connected Group Normal-Jeffrey's layer (aka Group Variational Dropout).

    References:
    [1] Kingma, Diederik P., Tim Salimans, and Max Welling. "Variational dropout and the local reparameterization trick." NIPS (2015).
    [2] Molchanov, Dmitry, Arsenii Ashukha, and Dmitry Vetrov. "Variational Dropout Sparsifies Deep Neural Networks." ICML (2017).
    [3] Louizos, Christos, Karen Ullrich, and Max Welling. "Bayesian Compression for Deep Learning." NIPS (2017).
    """

    def __init__(self, in_features, out_features, cuda=False, init_weight=None, init_bias=None, clip_var=None):

        super(LinearGroupNJ, self).__init__()
        self.cuda = cuda
        self.in_features = in_features
        self.out_features = out_features
        self.clip_var = clip_var
        self.deterministic = False  # flag is used for compressed inference
        # trainable params according to Eq.(6)
        # dropout params
        self.z_mu = Parameter(torch.Tensor(in_features))
        self.z_logvar = Parameter(torch.Tensor(in_features))  # = z_mu^2 * alpha
        # weight params
        self.weight_mu = Parameter(torch.Tensor(out_features, in_features))
        self.weight_logvar = Parameter(torch.Tensor(out_features, in_features))

        self.bias_mu = Parameter(torch.Tensor(out_features))
        self.bias_logvar = Parameter(torch.Tensor(out_features))

        # init params either random or with pretrained net
        self.reset_parameters(init_weight, init_bias)

        # activations for kl
        self.sigmoid = nn.Sigmoid()
        self.softplus = nn.Softplus()

        # numerical stability param
        self.epsilon = 1e-8

    def reset_parameters(self, init_weight, init_bias):
        # init means
        stdv = 1. / math.sqrt(self.weight_mu.size(1))

        self.z_mu.data.normal_(1, 1e-2)

        if init_weight is not None:
            self.weight_mu.data = torch.Tensor(init_weight)
        else:
            self.weight_mu.data.normal_(0, stdv)

        if init_bias is not None:
            self.bias_mu.data = torch.Tensor(init_bias)
        else:
            self.bias_mu.data.fill_(0)

        # init logvars
        self.z_logvar.data.normal_(-9, 1e-2)
        self.weight_logvar.data.normal_(-9, 1e-2)
        self.bias_logvar.data.normal_(-9, 1e-2)

    def clip_variances(self):
        if self.clip_var:
            self.weight_logvar.data.clamp_(max=math.log(self.clip_var))
            self.bias_logvar.data.clamp_(max=math.log(self.clip_var))

    def get_log_dropout_rates(self):
        log_alpha = self.z_logvar - torch.log(self.z_mu.pow(2) + self.epsilon)
        return log_alpha

    def compute_posterior_params(self):
        weight_var, z_var = self.weight_logvar.exp(), self.z_logvar.exp()
        self.post_weight_var = self.z_mu.pow(2) * weight_var + z_var * self.weight_mu.pow(2) + z_var * weight_var
        self.post_weight_mu = self.weight_mu * self.z_mu
        # print("self.z_mu.pow(2): ", self.z_mu.pow(2).size())
        # print("weight_var: ", weight_var.size())
        # print("z_var: ", z_var.size())
        # print("self.weight_mu.pow(2): ", self.weight_mu.pow(2).size())
        # print("weight_var: ", weight_var.size())
        # print("post_weight_mu: ", self.post_weight_mu.size())
        # print("post_weight_var: ", self.post_weight_var.size())
        return self.post_weight_mu, self.post_weight_var

    def forward(self, x):
        if self.deterministic:
            assert not self.training, "Flag deterministic is True. This should not be used in training."
            return F.linear(x, self.post_weight_mu, self.bias_mu)

        batch_size = x.size()[0]
        # compute z
        # note that we reparametrise according to [2] Eq. (11) (not [1])
        z = reparametrize(self.z_mu.repeat(batch_size, 1), self.z_logvar.repeat(batch_size, 1), sampling=self.training,
                          cuda=self.cuda)

        # apply local reparametrisation trick see [1] Eq. (6)
        # to the parametrisation given in [3] Eq. (6)
        xz = x * z
        mu_activations = F.linear(xz, self.weight_mu, self.bias_mu)
        var_activations = F.linear(xz.pow(2), self.weight_logvar.exp(), self.bias_logvar.exp())

        return reparametrize(mu_activations, var_activations.log(), sampling=self.training, cuda=self.cuda)

    def kl_divergence(self):
        # KL(q(z)||p(z))
        # we use the kl divergence approximation given by [2] Eq.(14)
        k1, k2, k3 = 0.63576, 1.87320, 1.48695
        log_alpha = self.get_log_dropout_rates()
        KLD = -torch.sum(k1 * self.sigmoid(k2 + k3 * log_alpha) - 0.5 * self.softplus(-log_alpha) - k1)

        # KL(q(w|z)||p(w|z))
        # we use the kl divergence given by [3] Eq.(8)
        KLD_element = -0.5 * self.weight_logvar + 0.5 * (self.weight_logvar.exp() + self.weight_mu.pow(2)) - 0.5
        KLD += torch.sum(KLD_element)

        # KL bias
        KLD_element = -0.5 * self.bias_logvar + 0.5 * (self.bias_logvar.exp() + self.bias_mu.pow(2)) - 0.5
        KLD += torch.sum(KLD_element)

        return KLD

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
               + str(self.in_features) + ' -> ' \
               + str(self.out_features) + ')'
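
# A self-contained sketch of the local reparameterisation trick used in
# LinearGroupNJ.forward. The `reparametrize` helper below is a stand-in,
# assumed to match the one the class imports; sizes are hypothetical.
import torch
import torch.nn.functional as F

def reparametrize(mu, logvar, sampling=True):
    if not sampling:
        return mu
    std = logvar.mul(0.5).exp()
    return mu + std * torch.randn_like(std)

x = torch.randn(8, 20)  # batch of activations
weight_mu = torch.randn(5, 20)
weight_logvar = torch.full((5, 20), -9.0)
mu_act = F.linear(x, weight_mu)  # mean of the pre-activations
var_act = F.linear(x.pow(2), weight_logvar.exp())  # variance of the pre-activations
out = reparametrize(mu_act, var_act.log())  # one Gaussian sample per activation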
Example 23
class Deasfn(nn.Module):
    def __init__(self):
        super(Deasfn, self).__init__()
        self.in_channels = 30

        self.FreqEncoder = nn.ModuleList()
        for i in range(1, 6):
            self.FreqEncoder.append(
                nn.Sequential(
                    nn.ConstantPad3d((0, 0, 0, 0, 5 * i - 1, 0), 0),
                    nn.Conv3d(30, 64, kernel_size=(5 * i, 1, 1), padding=0),
                    nn.BatchNorm3d(64), nn.ReLU(inplace=True),
                    nn.Conv3d(64, 128, kernel_size=(25, 1, 1)),
                    nn.BatchNorm3d(128), nn.ReLU(inplace=True)))

        self.SpatialEncoder = nn.Sequential(
            self.make_layer(ResidualBlock3D, 128, 4), nn.AvgPool3d((25, 1, 1)))

        self.hidden = Parameter(
            torch.randn(1, 1, (len(self.FreqEncoder) + 1) * 128))
        self.rnn = torch.nn.GRU(input_size=(len(self.FreqEncoder) + 1) * 128,
                                hidden_size=(len(self.FreqEncoder) + 1) * 128)
        self.attention_block = AttentionLayer(
            (len(self.FreqEncoder) + 1) * 128)

        self.decoder = nn.Sequential(
            ConvBlock((len(self.FreqEncoder) + 1) * 128, 79, 3, 2),
            ConvBlock(79, 38, 3, 1),
            ConvBlock(38, 19, 3, 1),
        )
        self.DecoderJH = nn.Sequential(
            ConvBlock(19, 19, 3, 1), ConvBlock(19, 19, 3, 1),
            nn.Conv2d(19, 19, kernel_size=3, stride=1, padding=1, bias=False))
        self.DecoderPAF = nn.Sequential(
            ConvBlock(19, 38, 3, 1), ConvBlock(38, 38, 3, 1),
            nn.Conv2d(38, 38, kernel_size=3, stride=1, padding=1, bias=False))

    def weights_init(self):
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear, nn.Conv3d)):
                nn.init.xavier_normal_(m.weight)
            elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm3d)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def make_layer(self, block, out_channels, blocks, stride=1):
        downsample = None
        if (stride != 1) or (self.in_channels !=
                             out_channels * block.expansion):
            downsample = nn.Sequential(
                conv3x3x3(self.in_channels,
                          out_channels * block.expansion,
                          stride=stride),
                nn.BatchNorm3d(out_channels * block.expansion))
        layers = []
        layers.append(block(self.in_channels, out_channels, stride,
                            downsample))
        self.in_channels = out_channels * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.in_channels, out_channels))
        return nn.Sequential(*layers)

    def forward(self, x):
        b, t, c, f, h, w = x.size()
        x = x.permute(1, 0, 2, 3, 4, 5).reshape(-1, c, f, h, w)

        # Spatial Encoder
        SpatialFeat = self.SpatialEncoder(x).squeeze(2)

        # Frequency Encoder
        FreqFeat = []
        for i in range(len(self.FreqEncoder)):
            FreqFeat.append(self.FreqEncoder[i](x).squeeze(2))
        FreqFeat = torch.cat(FreqFeat, dim=1)

        # Concatenate all the feature
        DualFeat = torch.cat([SpatialFeat, FreqFeat], dim=1)

        # Evolving attention module
        AttentionMask = []
        for i in range(t):
            AttentionMask.append(
                self.attention_block(DualFeat[i * b:(i + 1) * b]).view(b, -1))
        attention = torch.stack(AttentionMask, dim=0)
        h0 = self.hidden.repeat(1, b, 1)
        attention = self.rnn(attention, h0)[0]
        attention = attention.view(
            b * t, -1).unsqueeze(-1).unsqueeze(-1).expand_as(DualFeat)
        EvolvingFeat = DualFeat * attention

        # Decoder
        EvolvingFeat = F.interpolate(EvolvingFeat,
                                     size=[92, 124],
                                     mode='bilinear',
                                     align_corners=False)
        EvolvingFeat = self.decoder(EvolvingFeat)
        JH = self.DecoderJH(EvolvingFeat)
        PAF = self.DecoderPAF(EvolvingFeat)

        return JH, PAF
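
# A minimal, self-contained sketch (hypothetical sizes) of the recurrence in
# Deasfn.forward: a GRU whose initial hidden state is a learned parameter
# repeated across the batch.
import torch
import torch.nn as nn
from torch.nn import Parameter

t, b, d = 4, 2, 16
hidden = Parameter(torch.randn(1, 1, d))
rnn = nn.GRU(input_size=d, hidden_size=d)
attention = torch.randn(t, b, d)  # per-frame attention features
h0 = hidden.repeat(1, b, 1)  # (1, b, d)
out, _ = rnn(attention, h0)  # out: (t, b, d)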
Example 24
class RobustFill(nn.Module):
    def __init__(self,
                 input_vocabularies,
                 target_vocabulary,
                 hidden_size=512,
                 embedding_size=128,
                 cell_type="LSTM"):
        """
        :param: input_vocabularies: List containing a vocabulary list for each input. E.g. if learning a function f:A->B from (a,b) pairs, input_vocabularies has length 2
        :param: target_vocabulary: Vocabulary list for output
        """
        super(RobustFill, self).__init__()
        self.n_encoders = len(input_vocabularies)

        self.hidden_size = hidden_size
        self.embedding_size = embedding_size
        self.input_vocabularies = input_vocabularies
        self.target_vocabulary = target_vocabulary
        self._refreshVocabularyIndex()
        self.v_inputs = [len(x) for x in input_vocabularies
                         ]  # Number of tokens in input vocabularies
        self.v_target = len(
            target_vocabulary)  # Number of tokens in target vocabulary

        self.cell_type = cell_type
        if cell_type == 'GRU':
            self.encoder_init_h = Parameter(torch.rand(1, self.hidden_size))
            self.encoder_cells = nn.ModuleList([
                nn.GRUCell(input_size=self.v_inputs[0] + 1,
                           hidden_size=self.hidden_size,
                           bias=True)
            ] + [
                nn.GRUCell(input_size=self.v_inputs[i] + 1 + self.hidden_size,
                           hidden_size=self.hidden_size,
                           bias=True) for i in range(1, self.n_encoders)
            ])
            self.decoder_cell = nn.GRUCell(input_size=self.v_target + 1,
                                           hidden_size=self.hidden_size,
                                           bias=True)
        if cell_type == 'LSTM':
            self.encoder_init_h = Parameter(torch.rand(1, self.hidden_size))
            self.encoder_init_cs = nn.ParameterList([
                Parameter(torch.rand(1, self.hidden_size))
                for i in range(len(self.v_inputs))
            ])
            self.encoder_cells = nn.ModuleList([
                nn.LSTMCell(input_size=self.v_inputs[0] + 1,
                            hidden_size=self.hidden_size,
                            bias=True)
            ] + [
                nn.LSTMCell(input_size=self.v_inputs[i] + 1 + self.hidden_size,
                            hidden_size=self.hidden_size,
                            bias=True) for i in range(1, self.n_encoders)
            ])
            self.decoder_cell = nn.LSTMCell(input_size=self.v_target + 1,
                                            hidden_size=self.hidden_size,
                                            bias=True)
            self.decoder_init_c = Parameter(torch.rand(1, self.hidden_size))

        self.W = nn.Linear(self.hidden_size + self.hidden_size,
                           self.embedding_size)
        self.V = nn.Linear(self.embedding_size, self.v_target + 1)

        self.As = nn.ModuleList([
            nn.Bilinear(self.hidden_size, self.hidden_size, 1, bias=False)
            for i in range(self.n_encoders)
        ])

    def with_target_vocabulary(self, target_vocabulary):
        """
        Returns a new network which modifies this one by changing the target vocabulary
        """
        if target_vocabulary == self.target_vocabulary:
            return self

        V_weight = []
        V_bias = []
        decoder_ih = []

        for i in range(len(target_vocabulary)):
            if target_vocabulary[i] in self.target_vocabulary:
                j = self.target_vocabulary.index(target_vocabulary[i])
                V_weight.append(self.V.weight.data[j:j + 1])
                V_bias.append(self.V.bias.data[j:j + 1])
                decoder_ih.append(self.decoder_cell.weight_ih.data[:, j:j + 1])
            else:
                V_weight.append(self._zeros(1, self.V.weight.size(1)))
                V_bias.append(self._ones(1) * -10)
                decoder_ih.append(
                    self._zeros(self.decoder_cell.weight_ih.data.size(0), 1))

        V_weight.append(self.V.weight.data[-1:])
        V_bias.append(self.V.bias.data[-1:])
        decoder_ih.append(self.decoder_cell.weight_ih.data[:, -1:])

        self.target_vocabulary = target_vocabulary
        self.v_target = len(target_vocabulary)

        self.V.weight.data = torch.cat(V_weight, dim=0)
        self.V.bias.data = torch.cat(V_bias, dim=0)
        self.V.out_features = self.V.bias.data.size(0)

        self.decoder_cell.weight_ih.data = torch.cat(decoder_ih, dim=1)
        self.decoder_cell.input_size = self.decoder_cell.weight_ih.data.size(1)

        self._clear_optimiser()
        self._refreshVocabularyIndex()
        return copy.deepcopy(self)

    def optimiser_step(self, batch_inputs, batch_target):
        """
        Perform a single step of SGD
        """
        if not hasattr(self, 'opt'): self._get_optimiser()
        self.opt.zero_grad()
        score = self.score(batch_inputs, batch_target, autograd=True).mean()
        (-score).backward()
        self.opt.step()

        return score.data.item()

    def score(self, batch_inputs, batch_target, autograd=False):
        inputs = self._inputsToTensors(batch_inputs)
        target = self._targetToTensor(batch_target)
        _, score = self._run(inputs, target=target, mode="score")
        if autograd:
            return score
        else:
            return score.data

    def sample(self, batch_inputs):
        inputs = self._inputsToTensors(batch_inputs)
        target, score = self._run(inputs, mode="sample")
        target = self._tensorToOutput(target)
        return target

    def sampleAndScore(self, batch_inputs, nRepeats=None):
        inputs = self._inputsToTensors(batch_inputs)
        if nRepeats is None:
            target, score = self._run(inputs, mode="sample")
            target = self._tensorToOutput(target)
            return target, score.data
        else:
            target = []
            score = []
            for i in range(nRepeats):
                # print("repeat %d" % i)
                t, s = self._run(inputs, mode="sample")
                t = self._tensorToOutput(t)
                target.extend(t)
                score.extend(list(s.data))
            return target, score

    def _refreshVocabularyIndex(self):
        self.input_vocabularies_index = [{
            self.input_vocabularies[i][j]: j
            for j in range(len(self.input_vocabularies[i]))
        } for i in range(len(self.input_vocabularies))]
        self.target_vocabulary_index = {
            self.target_vocabulary[j]: j
            for j in range(len(self.target_vocabulary))
        }

    def __getstate__(self):
        if hasattr(self, 'opt'):
            return dict([(k, v)
                         for k, v in self.__dict__.items() if k != 'opt'] +
                        [('optstate', self.opt.state_dict())])
        else:
            return self.__dict__

    def __setstate__(self, state):
        self.__dict__.update(state)
        if hasattr(self, 'optstate'): self._fix_optstate()

    def _ones(self, *args, **kwargs):
        if next(self.parameters()).is_cuda:
            return torch.ones(*args, **kwargs).cuda()
        else:
            return torch.ones(*args, **kwargs)

    def _zeros(self, *args, **kwargs):
        if next(self.parameters()).is_cuda:
            return torch.zeros(*args, **kwargs).cuda()
        else:
            return torch.zeros(*args, **kwargs)

    def _clear_optimiser(self):
        if hasattr(self, 'opt'): del self.opt
        if hasattr(self, 'optstate'): del self.optstate

    def _get_optimiser(self):
        self.opt = torch.optim.Adam(self.parameters(), lr=0.001)
        if hasattr(self, 'optstate'): self.opt.load_state_dict(self.optstate)

    def _fix_optstate(self):
        # ensure optstate tensors live on the same device (CPU/CUDA) as the params
        is_cuda = next(self.parameters()).is_cuda
        for state in self.optstate['state'].values():
            for k, v in state.items():
                if torch.is_tensor(v):
                    state[k] = v.cuda() if is_cuda else v.cpu()

    def cuda(self, *args, **kwargs):
        if hasattr(self, 'opt'): del self.opt
        if hasattr(self, 'optstate'): self._fix_optstate()
        super(RobustFill, self).cuda(*args, **kwargs)

    def cpu(self, *args, **kwargs):
        if hasattr(self, 'opt'): del self.opt
        if hasattr(self, 'optstate'): self._fix_optstate()
        super(RobustFill, self).cpu(*args, **kwargs)

    def _encoder_get_init(self, encoder_idx, h=None, batch_size=None):
        if h is None: h = self.encoder_init_h.repeat(batch_size, 1)
        if self.cell_type == "GRU": return h
        if self.cell_type == "LSTM":
            return (h, self.encoder_init_cs[encoder_idx].repeat(batch_size, 1))

    def _decoder_get_init(self, h):
        if self.cell_type == "GRU": return h
        if self.cell_type == "LSTM":
            return (h, self.decoder_init_c.repeat(h.size(0), 1))

    def _cell_get_h(self, cell_state):
        if self.cell_type == "GRU": return cell_state
        if self.cell_type == "LSTM": return cell_state[0]

    def _run(self, inputs, target=None, mode="sample"):
        """
        :param mode: "score" or "sample"
        :param list[list[LongTensor]] inputs: n_encoders * n_examples * (max length * batch_size)
        :param list[LongTensor] target: max length * batch_size
        Returns output and score
        """
        assert ((mode == "score" and target is not None) or mode == "sample")

        n_examples = len(inputs[0])
        max_length_inputs = [[inputs[i][j].size(0) for j in range(n_examples)]
                             for i in range(self.n_encoders)]
        max_length_target = target.size(0) if target is not None else 10
        batch_size = inputs[0][0].size(1)

        score = Variable(self._zeros(batch_size))
        inputs_scatter = [
            [
                Variable(
                    self._zeros(max_length_inputs[i][j], batch_size,
                                self.v_inputs[i] + 1).scatter_(
                                    2, inputs[i][j][:, :, None], 1))
                for j in range(n_examples)
            ] for i in range(self.n_encoders)
        ]  # n_encoders * n_examples * (max_length_input * batch_size * v_input+1)
        if target is not None:
            target_scatter = Variable(
                self._zeros(
                    max_length_target, batch_size, self.v_target + 1).scatter_(
                        2, target[:, :, None],
                        1))  # max_length_target * batch_size * v_target+1

        H = []  # n_encoders * n_examples * (max_length_input * batch_size * h_encoder_size)
        embeddings = []  # n_encoders * (h for example at INPUT_EOS)
        attention_mask = []  # n_encoders * (0 until and including INPUT_EOS, then -inf)

        def attend(i, j, h):
            """
            'general' attention from https://arxiv.org/pdf/1508.04025.pdf
            :param i: which encoder is doing the attending (or self.n_encoders for the decoder)
            :param j: Index of example
            :param h: batch_size * hidden_size
            """
            assert (i != 0)
            scores = self.As[i - 1](H[i - 1][j].view(
                max_length_inputs[i - 1][j] * batch_size,
                self.hidden_size), h.view(batch_size, self.hidden_size).repeat(
                    max_length_inputs[i - 1][j], 1)).view(
                        max_length_inputs[i - 1][j],
                        batch_size) + attention_mask[i - 1][j]
            c = (F.softmax(scores[:, :, None], dim=0) * H[i - 1][j]).sum(0)
            return c
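
        # In Luong-style 'general' attention the score for encoder step t is
        # h_t^T W h; attention_mask adds -inf past INPUT_EOS so the softmax
        # over time ignores padding, and c is the attention-weighted sum of
        # the encoder states (one context vector per batch element).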

        # -------------- Encoders -------------
        for i in range(len(self.input_vocabularies)):
            H.append([])
            embeddings.append([])
            attention_mask.append([])

            for j in range(n_examples):
                active = self._ones(max_length_inputs[i][j], batch_size).byte()
                state = self._encoder_get_init(
                    i,
                    batch_size=batch_size,
                    h=embeddings[i - 1][j] if i > 0 else None)
                hs = []
                h = self._cell_get_h(state)
                for k in range(max_length_inputs[i][j]):
                    if i == 0:
                        state = self.encoder_cells[i](
                            inputs_scatter[i][j][k, :, :], state)
                    else:
                        state = self.encoder_cells[i](torch.cat(
                            [inputs_scatter[i][j][k, :, :],
                             attend(i, j, h)], 1), state)
                    if k + 1 < max_length_inputs[i][j]:
                        active[k + 1, :] = active[k, :] * (inputs[i][j][k, :]
                                                           != self.v_inputs[i])
                    h = self._cell_get_h(state)
                    hs.append(h[None, :, :])
                H[i].append(torch.cat(hs, 0))
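                # 'active' flags steps up to and including INPUT_EOS, so
                # active.sum(0) - 1 is the index of the EOS hidden state;
                # gather extracts that per-example embedding from H.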
                embedding_idx = active.sum(0).long() - 1
                embedding = H[i][j].gather(
                    0,
                    Variable(embedding_idx[None, :, None].repeat(
                        1, 1, self.hidden_size)))[0]
                embeddings[i].append(embedding)
                attention_mask[i].append(Variable(active.float().log()))

        # ------------------ Decoder -----------------
        # Multi-example pooling: Figure 3, https://arxiv.org/pdf/1703.07469.pdf
        target = target if mode == "score" else self._zeros(
            max_length_target, batch_size).long()
        decoder_states = [
            self._decoder_get_init(embeddings[self.n_encoders - 1][j])
            for j in range(n_examples)
        ]
        active = self._ones(batch_size).byte()
        for k in range(max_length_target):
            FC = []
            for j in range(n_examples):
                h = self._cell_get_h(decoder_states[j])
                p_aug = torch.cat([h, attend(self.n_encoders, j, h)], 1)
                FC.append(torch.tanh(self.W(p_aug)[None, :, :]))
            m = torch.max(torch.cat(FC, 0), 0)[0]  # batch_size * embedding_size
            logsoftmax = F.log_softmax(self.V(m), dim=1)
            if mode == "sample":
                target[k, :] = torch.multinomial(logsoftmax.data.exp(), 1)[:,
                                                                           0]
            score = score + choose(logsoftmax, target[k, :]) * Variable(
                active.float())
            active *= (target[k, :] != self.v_target)
            for j in range(n_examples):
                if mode == "score":
                    target_char_scatter = target_scatter[k, :, :]
                elif mode == "sample":
                    target_char_scatter = Variable(
                        self._zeros(batch_size, self.v_target + 1).scatter_(
                            1, target[k, :, None], 1))
                decoder_states[j] = self.decoder_cell(target_char_scatter,
                                                      decoder_states[j])
        return target, score

    def _inputsToTensors(self, inputsss):
        """
        :param inputsss: size = nBatch * nExamples * nEncoders
        Returns nEncoders * nExamples tensors of size (max_len + 1) * nBatch
        """
        tensors = []
        for i in range(self.n_encoders):
            tensors.append([])
            for j in range(len(inputsss[0])):
                inputs = [x[j][i] for x in inputsss]
                maxlen = max(len(s) for s in inputs)
                t = self._ones(maxlen + 1,
                               len(inputs)).long() * self.v_inputs[i]
                for k in range(len(inputs)):
                    s = inputs[k]
                    if len(s) > 0:
                        t[:len(s), k] = torch.LongTensor(
                            [self.input_vocabularies_index[i][x] for x in s])
                tensors[i].append(t)
        return tensors

    def _targetToTensor(self, targets):
        """
        :param targets: list of target token sequences
        Returns a (max_len + 1) * nBatch LongTensor padded with v_target
        """
        maxlen = max(len(s) for s in targets)
        t = self._ones(maxlen + 1, len(targets)).long() * self.v_target
        for i in range(len(targets)):
            s = targets[i]
            if len(s) > 0:
                t[:len(s), i] = torch.LongTensor(
                    [self.target_vocabulary_index[x] for x in s])
        return t

    def _tensorToOutput(self, tensor):
        """
        :param tensor: max_length * batch_size
        """
        out = []
        for i in range(tensor.size(1)):
            l = tensor[:, i].tolist()
            if l[0] == self.v_target:
                out.append(tuple())
            elif self.v_target in l:
                final = l.index(self.v_target)
                out.append(
                    tuple(self.target_vocabulary[x]
                          for x in tensor[:final, i]))
            else:
                out.append(
                    tuple(self.target_vocabulary[x] for x in tensor[:, i]))
        return out
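
A minimal, self-contained sketch (illustrative values, not from the model
above) of the scatter_ one-hot trick that _run uses to turn token-id tensors
into one-hot network inputs:

import torch

tokens = torch.tensor([[1, 2], [0, 3]])  # (max_len=2, batch=2) token ids
vocab = 3                                # ids 0..2; id 3 plays the EOS role
one_hot = torch.zeros(2, 2, vocab + 1).scatter_(2, tokens[:, :, None], 1)
assert one_hot[0, 1, 2] == 1             # tokens[0, 1] == 2 -> one-hot at 2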
Example 25
class MultiheadAttention(Module):
    r"""Allows the model to jointly attend to information
    from different representation subspaces.
    See reference: Attention Is All You Need

    .. math::
        \text{MultiHead}(Q, K, V) = \text{Concat}(head_1, \dots, head_h)W^O
        \quad \text{where} \quad head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)

    Args:
        embed_dim: total dimension of the model.
        num_heads: parallel attention heads.
        add_bias_kv: add bias to the key and value sequences at dim=0.
        add_zero_attn: add a new batch of zeros to the key and 
                       value sequences at dim=1.

    Examples::

        >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
        >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
    """
    def __init__(self,
                 embed_dim,
                 num_heads,
                 dropout=0.,
                 bias=True,
                 add_bias_kv=False,
                 add_zero_attn=False):
        super(MultiheadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
        self.scaling = self.head_dim**-0.5

        self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))
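        # rows [0:E) project queries, [E:2E) keys and [2E:3E) values; packing
        # all three projections into one matrix lets _in_proj_qkv below do the
        # self-attention projection with a single matmul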
        if bias:
            self.in_proj_bias = Parameter(torch.empty(3 * embed_dim))
        else:
            self.register_parameter('in_proj_bias', None)
        self.out_proj = Linear(embed_dim, embed_dim, bias=bias)

        if add_bias_kv:
            self.bias_k = Parameter(torch.empty(1, 1, embed_dim))
            self.bias_v = Parameter(torch.empty(1, 1, embed_dim))
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn

        self._reset_parameters()

    def _reset_parameters(self):
        xavier_uniform_(self.in_proj_weight[:self.embed_dim, :])
        xavier_uniform_(self.in_proj_weight[self.embed_dim:(self.embed_dim *
                                                            2), :])
        xavier_uniform_(self.in_proj_weight[(self.embed_dim * 2):, :])

        xavier_uniform_(self.out_proj.weight)
        if self.in_proj_bias is not None:
            constant_(self.in_proj_bias, 0.)
            constant_(self.out_proj.bias, 0.)
        if self.bias_k is not None:
            xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            xavier_normal_(self.bias_v)

    @weak_script_method
    def forward(self,
                query,
                key,
                value,
                key_padding_mask=None,
                incremental_state=None,
                need_weights=True,
                static_kv=False,
                attn_mask=None):
        r"""
    Args:
        query, key, value: map a query and a set of key-value pairs to an output. 
            See "Attention Is All You Need" for more details. 
        key_padding_mask: if provided, specified padding elements in the key will 
            be ignored by the attention.
        incremental_state: if provided, previous time steps are cached.
        need_weights: output attn_output_weights.
        static_kv: if true, key and value are static. The key and value in previous 
            states will be used.
        attn_mask: mask that prevents attention to certain positions.

    Shape:
        - Inputs:

        - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is 
          the embedding dimension.
        - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is 
          the embedding dimension.
        - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is 
          the embedding dimension.
        - key_padding_mask: :math:`(N, S)`, ByteTensor, where N is the batch size, S is the source sequence length.
        - incremental_state: a dictionary used for storing states.
        - attn_mask: :math:`(L, L)` where L is the target sequence length.

        - Outputs:

        - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, 
          E is the embedding dimension.
        - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
          L is the target sequence length, S is the source sequence length.
        """
        qkv_same = query.data_ptr() == key.data_ptr() == value.data_ptr()
        kv_same = key.data_ptr() == value.data_ptr()

        tgt_len, bsz, embed_dim = query.size()
        assert embed_dim == self.embed_dim
        assert list(query.size()) == [tgt_len, bsz, embed_dim]
        assert key.size() == value.size()

        if incremental_state is not None:
            saved_state = self._get_input_buffer(incremental_state)
            if 'prev_key' in saved_state:
                # previous time steps are cached - no need to recompute
                # key and value if they are static
                if static_kv:
                    assert kv_same and not qkv_same
                    key = value = None
        else:
            saved_state = None

        if qkv_same:
            # self-attention
            q, k, v = self._in_proj_qkv(query)
        elif kv_same:
            # encoder-decoder attention
            q = self._in_proj_q(query)
            if key is None:
                assert value is None
                k = v = None
            else:
                k, v = self._in_proj_kv(key)
        else:
            q = self._in_proj_q(query)
            k = self._in_proj_k(key)
            v = self._in_proj_v(value)
        q *= self.scaling

        if self.bias_k is not None:
            assert self.bias_v is not None
            k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [attn_mask,
                     attn_mask.new_zeros(attn_mask.size(0), 1)],
                    dim=1)
            if key_padding_mask is not None:
                key_padding_mask = torch.cat([
                    key_padding_mask,
                    key_padding_mask.new_zeros(key_padding_mask.size(0), 1)
                ],
                                             dim=1)

        q = q.contiguous().view(tgt_len, bsz * self.num_heads,
                                self.head_dim).transpose(0, 1)
        if k is not None:
            k = k.contiguous().view(-1, bsz * self.num_heads,
                                    self.head_dim).transpose(0, 1)
        if v is not None:
            v = v.contiguous().view(-1, bsz * self.num_heads,
                                    self.head_dim).transpose(0, 1)

        if saved_state is not None:
            # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
            if 'prev_key' in saved_state:
                prev_key = saved_state['prev_key'].view(
                    bsz * self.num_heads, -1, self.head_dim)
                if static_kv:
                    k = prev_key
                else:
                    k = torch.cat((prev_key, k), dim=1)
            if 'prev_value' in saved_state:
                prev_value = saved_state['prev_value'].view(
                    bsz * self.num_heads, -1, self.head_dim)
                if static_kv:
                    v = prev_value
                else:
                    v = torch.cat((prev_value, v), dim=1)
            saved_state['prev_key'] = k.view(bsz, self.num_heads, -1,
                                             self.head_dim)
            saved_state['prev_value'] = v.view(bsz, self.num_heads, -1,
                                               self.head_dim)

            self._set_input_buffer(incremental_state, saved_state)

        src_len = k.size(1)

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len

        if self.add_zero_attn:
            src_len += 1
            k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])],
                          dim=1)
            v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])],
                          dim=1)
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [attn_mask,
                     attn_mask.new_zeros(attn_mask.size(0), 1)],
                    dim=1)
            if key_padding_mask is not None:
                key_padding_mask = torch.cat([
                    key_padding_mask,
                    torch.zeros(key_padding_mask.size(0),
                                1).type_as(key_padding_mask)
                ],
                                             dim=1)

        attn_output_weights = torch.bmm(q, k.transpose(1, 2))
        assert list(attn_output_weights.size()) == [
            bsz * self.num_heads, tgt_len, src_len
        ]

        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(0)
            attn_output_weights += attn_mask

        if key_padding_mask is not None:
            attn_output_weights = attn_output_weights.view(
                bsz, self.num_heads, tgt_len, src_len)
            attn_output_weights = attn_output_weights.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2),
                float('-inf'),
            )
            attn_output_weights = attn_output_weights.view(
                bsz * self.num_heads, tgt_len, src_len)

        attn_output_weights = F.softmax(
            attn_output_weights.float(),
            dim=-1,
            dtype=torch.float32 if attn_output_weights.dtype == torch.float16
            else attn_output_weights.dtype)
        attn_output_weights = F.dropout(attn_output_weights,
                                        p=self.dropout,
                                        training=self.training)

        attn_output = torch.bmm(attn_output_weights, v)
        assert list(attn_output.size()) == [
            bsz * self.num_heads, tgt_len, self.head_dim
        ]
        attn_output = attn_output.transpose(0, 1).contiguous().view(
            tgt_len, bsz, embed_dim)
        attn_output = self.out_proj(attn_output)

        if need_weights:
            # average attention weights over heads
            attn_output_weights = attn_output_weights.view(
                bsz, self.num_heads, tgt_len, src_len)
            attn_output_weights = attn_output_weights.sum(
                dim=1) / self.num_heads
        else:
            attn_output_weights = None

        return attn_output, attn_output_weights

    def _in_proj_qkv(self, query):
        return self._in_proj(query).chunk(3, dim=-1)

    def _in_proj_kv(self, key):
        return self._in_proj(key, start=self.embed_dim).chunk(2, dim=-1)

    def _in_proj_q(self, query):
        return self._in_proj(query, end=self.embed_dim)

    def _in_proj_k(self, key):
        return self._in_proj(key, start=self.embed_dim, end=2 * self.embed_dim)

    def _in_proj_v(self, value):
        return self._in_proj(value, start=2 * self.embed_dim)

    def _in_proj(self, input, start=0, end=None):
        weight = self.in_proj_weight
        bias = self.in_proj_bias
        weight = weight[start:end, :]
        if bias is not None:
            bias = bias[start:end]
        return F.linear(input, weight, bias)
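
A quick shape check for the module above (sizes are illustrative; assumes the
class and the imports the example itself relies on are in scope, on the same
PyTorch version the example targets):

import torch

mha = MultiheadAttention(embed_dim=16, num_heads=4)
q = torch.randn(5, 2, 16)   # (tgt_len, batch, embed_dim)
kv = torch.randn(7, 2, 16)  # (src_len, batch, embed_dim)
out, weights = mha(q, kv, kv)
assert out.shape == (5, 2, 16) and weights.shape == (2, 5, 7)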
Example 26
class MVG_binaryNet(nn.Module):
    def __init__(self, H1, H2, dropprob, scale):
        super(MVG_binaryNet, self).__init__()

        self.sq2pi = 0.797884560
        self.drop_prob = dropprob
        self.hidden1 = H1
        self.hidden2 = H2
        self.D_out = 1
        self.scale = scale
        self.w0 = Parameter(torch.Tensor(28 * 28, self.hidden1))
        stdv = 1. / math.sqrt(self.w0.data.size(1))
        self.w0.data = self.scale * self.w0.data.uniform_(-stdv, stdv)

        self.w1 = Parameter(torch.Tensor(self.hidden1, self.hidden2))
        stdv = 1. / math.sqrt(self.w1.data.size(1))
        self.w1.data = self.scale * self.w1.data.uniform_(-stdv, stdv)

        self.w2 = Parameter(torch.Tensor(self.hidden2, self.hidden2))
        stdv = 1. / math.sqrt(self.w2.data.size(1))
        self.w2.data = self.scale * self.w2.data.uniform_(-stdv, stdv)

        self.w3 = Parameter(torch.Tensor(self.hidden2, self.hidden2))
        stdv = 1. / math.sqrt(self.w3.data.size(1))
        self.w3.data = self.scale * self.w3.data.uniform_(-stdv, stdv)

        self.w4 = Parameter(torch.Tensor(self.hidden2, self.hidden2))
        stdv = 1. / math.sqrt(self.w4.data.size(1))
        self.w4.data = self.scale * self.w4.data.uniform_(-stdv, stdv)

        self.wlast = Parameter(torch.Tensor(self.hidden2, self.D_out))
        stdv = 1. / math.sqrt(self.wlast.data.size(1))
        self.wlast.data = self.scale * self.wlast.data.uniform_(-stdv, stdv)

        self.th0 = Parameter(torch.zeros(1, self.hidden1))
        self.th1 = Parameter(torch.zeros(1, self.hidden2))
        self.th2 = Parameter(torch.zeros(1, self.hidden2))
        self.th3 = Parameter(torch.zeros(1, self.hidden2))
        self.th4 = Parameter(torch.zeros(1, self.hidden2))
        self.thlast = Parameter(torch.zeros(1, self.D_out))

    def MVG_layer(self, xbar, xcov, m, th):
        # receives neuron means and covariances, returns next-layer means and covariances
        M = xbar.size()[0]
        H, H2 = m.size()
        #bn = nn.BatchNorm1d(H2, affine=False)
        sigma = torch.t(m)[None, :, :].repeat(M, 1, 1).bmm(xcov.clone().bmm(
            m.repeat(M, 1, 1))) + torch.diag(torch.sum(1 - m**2, 0)).repeat(
                M, 1, 1)
        tem = sigma.clone().resize(M, H2 * H2)
        diagsig2 = tem[:, ::(H2 + 1)]

        hbar = xbar.mm(m) + th.repeat(xbar.size()[0], 1)  # numerator of the input to the nonlinearity
        h = self.sq2pi * hbar / torch.sqrt(diagsig2)
        xbar_next = torch.tanh(h)  # tanh(h) = 2*sigmoid(2*h) - 1, hence the 2 in the argument

        # x covariance across layer 2
        ey = Variable(torch.eye(H2).cuda())
        #xc2 = (1 - xbar_next ** 2)
        #xcov_next = self.sq2pi * sigma * xc2[:, :, None] * xc2[:, None, :] / torch.sqrt(diagsig2[:, :, None] * diagsig2[:, None, :]) + ey[None, :, :] * (
        #    1 - xbar_next[:, None, :] ** 2)

        xc2cop = (1 - xbar_next**2) / torch.sqrt(diagsig2)
        xcov_next = (self.sq2pi * sigma * xc2cop[:, :, None] * xc2cop[:, None, :]
                     + ey[None, :, :] * (1 - xbar_next[:, None, :]**2))

        return xbar_next, xcov_next
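
    # MVG_layer propagates a Gaussian approximation of the binary-weight
    # network through tanh: sigma is the pre-activation covariance induced by
    # the weight means m, and sq2pi * hbar / sqrt(diag(sigma)) rescales the
    # pre-activation mean before tanh (the same probit-style rescaling used by
    # EBP_binaryNetRelaxed below, without the additive 1 in the denominator).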

    def forward(self, x, target):
        m0 = 2 * torch.sigmoid(self.w0) - 1
        m1 = 2 * torch.sigmoid(self.w1) - 1
        m2 = 2 * torch.sigmoid(self.w2) - 1
        m3 = 2 * torch.sigmoid(self.w3) - 1
        m4 = 2 * torch.sigmoid(self.w4) - 1
        mlast = 2 * torch.sigmoid(self.wlast) - 1
        sq2pi = 0.797884560
        dtype = torch.FloatTensor

        H = self.hidden1
        D_out = self.D_out
        x = x.view(-1, 28 * 28)
        y = target[:, None]
        M = x.size()[0]
        M_double = M * 1.0
        #bn0 = nn.BatchNorm1d(x.size()[1], affine=False)
        x0_do = F.dropout(x, p=self.drop_prob, training=self.training)

        sigma_1 = torch.diag(torch.sum(1 - m0**2,
                                       0)).repeat(M, 1, 1)  # + sigma_1[:,:,m]
        tem = sigma_1.clone().resize(M, H * H)
        diagsig1 = tem[:, ::(H + 1)]
        h1bar = x0_do.mm(m0) + self.th0.repeat(
            x.size()[0], 1)  # numerator of input to sigmoid non-linearity
        h1 = sq2pi * h1bar / torch.sqrt(diagsig1)  #

        #bn1 = nn.BatchNorm1d(h1.size()[1], affine=False)

        x1bar = torch.tanh(h1)
        x1bar_d0 = F.dropout(x1bar, p=self.drop_prob, training=self.training)

        ey = Variable(torch.eye(H).cuda())
        xcov_1 = ey[None, :, :] * (1 - x1bar_d0[:, None, :]**2)  # diag cov neurons layer 1
        '''NEW LAYER FUNCTION'''
        x2bar, xcov_2 = self.MVG_layer(x1bar_d0, xcov_1, m1, self.th1)
        x3bar, xcov_3 = self.MVG_layer(
            F.dropout(x2bar, p=self.drop_prob, training=self.training), xcov_2,
            m2, self.th2)
        x4bar, xcov_4 = self.MVG_layer(
            F.dropout(x3bar, p=self.drop_prob, training=self.training), xcov_3,
            m3, self.th3)
        #x5bar, xcov_5 = self.MVG_layer(F.dropout(x4bar, p=self.drop_prob, training=self.training), xcov_4, m4, self.th4)

        H, H2 = mlast.size()
        sigmalast = torch.t(mlast)[None, :, :].repeat(M, 1, 1).bmm(
            xcov_4.clone().bmm(mlast.repeat(M, 1, 1))) + torch.diag(
                torch.sum(1 - mlast**2, 0)).repeat(M, 1, 1)
        tem = sigmalast.clone().resize(M, H2 * H2)
        diagsiglast = tem[:, ::(H2 + 1)]

        hlastbar = x4bar.mm(mlast) + self.thlast.repeat(x1bar.size()[0], 1)
        hlast = sq2pi * hlastbar / torch.sqrt(diagsiglast)

        loss_binary = nn.BCELoss()
        expected_loss = loss_binary(torch.squeeze(torch.sigmoid(hlast)),
                                    torch.squeeze(y[:, 0]).type(dtype).cuda())
        logprobs_out = F.logsigmoid(hlast)  # log(sigmoid(hlast)), computed stably
        pred = (torch.sigmoid(hlast) > 0.5).type(dtype)
        a = torch.abs((pred - y.type(dtype)))
        fraction_correct = (M_double - torch.sum(a)) / M_double

        return ((hlastbar, logprobs_out,
                 xcov_4)), expected_loss, fraction_correct
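
The tanh comments in MVG_layer rely on the identity tanh(h) = 2*sigmoid(2h) - 1;
a one-line sanity check:

import torch

h = torch.linspace(-3, 3, 7)
assert torch.allclose(torch.tanh(h), 2 * torch.sigmoid(2 * h) - 1)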
Example 27
class GCNmfConv(nn.Module):
    def __init__(self,
                 in_features,
                 out_features,
                 data,
                 n_components,
                 dropout,
                 bias=True):
        super(GCNmfConv, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.n_components = n_components
        self.dropout = dropout
        self.features = data.features.numpy()
        self.logp = Parameter(torch.FloatTensor(n_components))
        self.means = Parameter(torch.FloatTensor(n_components, in_features))
        self.logvars = Parameter(torch.FloatTensor(n_components, in_features))
        self.weight = Parameter(torch.FloatTensor(in_features, out_features))
        self.adj2 = torch.mul(data.adj, data.adj).to(device)
        self.gmm = None
        if bias:
            self.bias = Parameter(torch.FloatTensor(out_features))
        else:
            self.register_parameter('bias', None)

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.weight.data, gain=1.414)
        if self.bias is not None:
            self.bias.data.fill_(0)
        self.gmm = init_gmm(self.features, self.n_components)
        self.logp.data = torch.FloatTensor(np.log(
            self.gmm.weights_)).to(device)
        self.means.data = torch.FloatTensor(self.gmm.means_).to(device)
        self.logvars.data = torch.FloatTensor(np.log(
            self.gmm.covariances_)).to(device)

    def calc_responsibility(self, mean_mat, variances):
        dim = self.in_features
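        # log N(x | mu_k, diag(sigma_k^2)) per component k; adding log pi_k and
        # softmaxing over components (dim=0) yields the posterior responsibility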
        log_n = (- 1 / 2) *\
            torch.sum(torch.pow(mean_mat - self.means.unsqueeze(1), 2) / variances.unsqueeze(1), 2)\
            - (dim / 2) * np.log(2 * np.pi) - (1 / 2) * torch.sum(self.logvars)
        log_prob = self.logp.unsqueeze(1) + log_n
        return torch.softmax(log_prob, dim=0)

    def forward(self, x, adj):
        x_imp = x.repeat(self.n_components, 1, 1)
        x_isnan = torch.isnan(x_imp)
        variances = torch.exp(self.logvars)
        mean_mat = torch.where(
            x_isnan,
            self.means.repeat((x.size(0), 1, 1)).permute(1, 0, 2), x_imp)
        var_mat = torch.where(
            x_isnan,
            variances.repeat((x.size(0), 1, 1)).permute(1, 0, 2),
            torch.zeros(size=x_imp.size(), device=device, requires_grad=True))

        # dropout
        dropmat = F.dropout(torch.ones_like(mean_mat),
                            self.dropout,
                            training=self.training)
        mean_mat = mean_mat * dropmat
        var_mat = var_mat * dropmat

        transform_x = torch.matmul(mean_mat, self.weight)
        if self.bias is not None:
            transform_x = torch.add(transform_x, self.bias)
        transform_covs = torch.matmul(var_mat, self.weight * self.weight)
        conv_x = []
        conv_covs = []
        for component_x in transform_x:
            conv_x.append(torch.spmm(adj, component_x))
        for component_covs in transform_covs:
            conv_covs.append(torch.spmm(self.adj2, component_covs))
        transform_x = torch.stack(conv_x, dim=0)
        transform_covs = torch.stack(conv_covs, dim=0)
        expected_x = ex_relu(transform_x, transform_covs)

        # calculate responsibility
        gamma = self.calc_responsibility(mean_mat, variances)
        expected_x = torch.sum(expected_x * gamma.unsqueeze(2), dim=0)
        return expected_x
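
A standalone sketch of the responsibility computation above, for a tiny
2-component diagonal GMM in one dimension (all names and values illustrative):

import math
import torch

logp = torch.log(torch.tensor([0.5, 0.5]))  # mixing priors
means = torch.tensor([[0.0], [4.0]])        # (n_components, in_features)
logvars = torch.zeros(2, 1)                 # unit variances

x = torch.tensor([[0.1], [3.9]])            # (batch, in_features)
log_n = (-0.5 * ((x[None] - means[:, None]) ** 2 / logvars.exp()[:, None]).sum(-1)
         - 0.5 * math.log(2 * math.pi) - 0.5 * logvars.sum(-1)[:, None])
gamma = torch.softmax(logp[:, None] + log_n, dim=0)  # (n_components, batch)
assert gamma[0, 0] > 0.99 and gamma[1, 1] > 0.99     # each point claims its cluster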
Example 28
class EBP_binaryNetRelaxed(nn.Module):
    def __init__(self, H, drop_prb, scale):
        super(EBP_binaryNetRelaxed, self).__init__()

        self.drop_prob = drop_prb
        self.pion2 = 1.570796326794
        self.sq2pi = 0.797884560
        self.beta = 1.0
        self.hidden = H
        self.D_out = 1
        self.scale = scale

        self.w0 = Parameter(torch.Tensor(28 * 28, self.hidden))
        stdv = 1. / math.sqrt(self.w0.data.size(1))
        self.w0.data = self.scale * self.w0.data.uniform_(-stdv, stdv)

        self.w1 = Parameter(torch.Tensor(self.hidden, self.hidden))
        stdv = 1. / math.sqrt(self.w1.data.size(1))
        self.w1.data = self.scale * self.w1.data.uniform_(-stdv, stdv)

        self.w2 = Parameter(torch.Tensor(self.hidden, self.hidden))
        stdv = 1. / math.sqrt(self.w2.data.size(1))
        self.w2.data = self.scale * self.w2.data.uniform_(-stdv, stdv)

        self.w3 = Parameter(torch.Tensor(self.hidden, self.hidden))
        stdv = 1. / math.sqrt(self.w3.data.size(1))
        self.w3.data = self.scale * self.w3.data.uniform_(-stdv, stdv)

        self.w4 = Parameter(torch.Tensor(self.hidden, self.hidden))
        stdv = 1. / math.sqrt(self.w4.data.size(1))
        self.w4.data = self.scale * self.w4.data.uniform_(-stdv, stdv)

        self.wlast = Parameter(torch.Tensor(self.hidden, self.D_out))
        stdv = 1. / math.sqrt(self.wlast.data.size(1))
        self.wlast.data = self.scale * self.wlast.data.uniform_(-stdv, stdv)

        self.th0 = Parameter(torch.zeros(1, self.hidden))
        self.th1 = Parameter(torch.zeros(1, self.hidden))
        self.th2 = Parameter(torch.zeros(1, self.hidden))
        self.th3 = Parameter(torch.zeros(1, self.hidden))
        self.th4 = Parameter(torch.zeros(1, self.hidden))
        self.thlast = Parameter(torch.zeros(1, self.D_out))

    def EBP_layer(self, xbar, xcov, m, th):
        # receives neuron means and covariances, returns next-layer means and covariances
        M = xbar.size()[0]
        H, H2 = m.size()
        #bn = nn.BatchNorm1d(xbar.size()[1], affine=False)
        sigma = torch.t(m)[None, :, :].repeat(M, 1, 1).bmm(xcov.clone().bmm(
            m.repeat(M, 1, 1))) + torch.diag(torch.sum(1 - m**2, 0)).repeat(
                M, 1, 1)
        tem = sigma.clone().resize(M, H2 * H2)
        diagsig2 = tem[:, ::(H2 + 1)]

        hbar = xbar.mm(m) + th.repeat(xbar.size()[0], 1)  # numerator of the input to the nonlinearity
        h = self.beta * hbar / torch.sqrt(1 + self.pion2 * self.beta**2 * diagsig2)
        xbar_next = torch.tanh(h)  # tanh(h) = 2*sigmoid(2*h) - 1, hence the 2 in the argument

        # x covariance across the next layer (diagonal approximation); m is
        # square here (hidden -> hidden), so H == H2
        xcov_next = Variable(
            torch.eye(H2))[None, :, :] * (1 - xbar_next[:, None, :]**2)

        return xbar_next, xcov_next

    def expected_loss(self, target, forward_result):
        (a2, logprobs_out) = forward_result
        return F.nll_loss(logprobs_out, target)

    def forward(self, x, target):
        m0 = 2 * torch.sigmoid(self.w0) - 1
        m1 = 2 * torch.sigmoid(self.w1) - 1
        m2 = 2 * torch.sigmoid(self.w2) - 1
        m3 = 2 * torch.sigmoid(self.w3) - 1
        m4 = 2 * torch.sigmoid(self.w4) - 1
        mlast = 2 * torch.sigmoid(self.wlast) - 1
        sq2pi = 0.797884560
        dtype = torch.FloatTensor

        H = self.hidden
        D_out = self.D_out
        x = x.view(-1, 28 * 28)
        y = target[:, None]
        M = x.size()[0]
        M_double = M * 1.0
        x0_do = F.dropout(x, p=self.drop_prob, training=self.training)
        sigma_1 = torch.diag(torch.sum(1 - m0**2,
                                       0)).repeat(M, 1, 1)  # + sigma_1[:,:,m]

        tem = sigma_1.clone().resize(M, H * H)
        diagsig1 = tem[:, ::(H + 1)]
        h1bar = x0_do.mm(m0) + self.th0.repeat(
            x.size()[0], 1)  # numerator of input to sigmoid non-linearity
        h1 = self.beta * h1bar / torch.sqrt(
            1 + self.pion2 * self.beta**2 * diagsig1)  #

        #bn = nn.BatchNorm1d(h1.size()[1], affine=False)
        x1bar = torch.tanh(h1)  #
        ey = Variable(torch.eye(H))
        x1bar_d0 = F.dropout(x1bar, p=self.drop_prob, training=self.training)

        xcov_1 = ey[None, :, :] * (1 - x1bar_d0[:, None, :]**2)  # diagonal of the layer covariance, i.e. the variance of neuron i
        '''NEW LAYER FUNCTION'''
        x2bar, xcov_2 = self.EBP_layer(x1bar_d0, xcov_1, m1, self.th1)
        x3bar, xcov_3 = self.EBP_layer(
            F.dropout(x2bar, p=self.drop_prob, training=self.training), xcov_2,
            m2, self.th2)
        x4bar, xcov_4 = self.EBP_layer(
            F.dropout(x3bar, p=self.drop_prob, training=self.training), xcov_3,
            m3, self.th3)
        #x5bar, xcov_5 = self.EBP_layer(F.dropout(x4bar, p=self.drop_prob, training=self.training), xcov_4, m4, self.th4)
        hlastbar = (x4bar.mm(mlast) + self.thlast.repeat(x1bar.size()[0], 1))

        logprobs_out = F.log_softmax(hlastbar, dim=1)
        val, ind = torch.max(hlastbar, 1)
        tem = y.type(dtype) - ind.type(dtype)[:, None]
        fraction_correct = (M_double - torch.sum(
            (tem != 0)).type(dtype)) / M_double
        expected_loss = self.expected_loss(target, (hlastbar, logprobs_out))
        return ((hlastbar, logprobs_out,
                 xcov_4)), expected_loss, fraction_correct
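
EBP_layer's rescaling h = beta * hbar / sqrt(1 + (pi/2) * beta^2 * sigma^2)
matches the standard probit approximation E[tanh(x)] ~ tanh(mu / sqrt(1 + (pi/2) * var))
for x ~ N(mu, var); a Monte-Carlo sanity check with a loose tolerance:

import math
import torch

torch.manual_seed(0)
mean, var = 0.7, 2.0
samples = torch.randn(200000) * math.sqrt(var) + mean
approx = math.tanh(mean / math.sqrt(1 + (math.pi / 2) * var))
assert abs(samples.tanh().mean().item() - approx) < 0.03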
Example 29
class gaussianMixtureModel(Module):
    r"""Applies a linear transformation to the incoming data: :math:`y = Ax + b`
    Args:
        in_features: size of each input sample
        out_features: size of each output sample
        bias: If set to False, the layer will not learn an additive bias.
            Default: ``True``
    Shape:
        - Input: :math:`(N, *, in\_features)` where :math:`*` means any number of
          additional dimensions
        - Output: :math:`(N, *, out\_features)` where all but the last dimension
          are the same shape as the input.
    Attributes:
        weight: the learnable weights of the module of shape
            `(out_features x in_features)`
        bias:   the learnable bias of the module of shape `(out_features)`
    Examples::
        >>> m = nn.Linear(20, 30)
        >>> input = torch.randn(128, 20)
        >>> output = m(input)
        >>> print(output.size())
    """
    def __init__(self,
                 latent_dim,
                 cluster_num,
                 batch_size,
                 opt=None,
                 bias=False):
        super(gaussianMixtureModel, self).__init__()
        self.latent_dim = latent_dim
        self.cluster_num = cluster_num
        self.batch_size = batch_size
        self.opt = opt
        self.is_first_ff = True
        #self.input_data_dim = input_data_dim
        #self.alpha = alpha
        #self.target_dict_size = target_dict_size
        #weight4loss = torch.ones(target_dict_size)
        #self.cross_entropy_loss = nn.NLLLoss(weight, size_average=False)
        self.cluster_mean = Parameter(torch.Tensor(latent_dim, cluster_num))
        self.cluster_variance_sq_unnorm = Parameter(
            torch.Tensor(latent_dim, cluster_num))
        self.cluster_prior = Parameter(torch.Tensor(cluster_num))
        if bias:
            self.cluster_bias = Parameter(torch.Tensor(cluster_num))
        else:
            self.register_parameter('cluster_bias', None)
        self.reset_parameters()
        print("init cluster_prior:", self.cluster_prior)

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.cluster_num)
        #torch.nn.init.constant_(self.cluster_mean, 0)
        torch.nn.init.constant_(self.cluster_variance_sq_unnorm, 1.0)
        torch.nn.init.constant_(self.cluster_prior, 1.0 / self.cluster_num)
        print("in reset_parameters init cluster_prior:", self.cluster_prior)
        #self.cluster_mean.data.constant_(0)
        self.cluster_mean.data.uniform_(-stdv, stdv)
        #self.cluster_variance_sq.data.constant_(1.0)
        #self.cluster_prior.data.uniform_(-stdv, stdv)
        #self.cluster_prior.data.constant_(1.0 / self.cluster_num)
        if self.cluster_bias is not None:
            self.cluster_bias.data.uniform_(-stdv, stdv)

    def forward(self, z_mean, z_log_variance_sq, z):
        if self.is_first_ff:
            inverse_sigmoid = lambda x: 0 - numpy.log(1 / x - 1)
            #torch.nn.init.constant_(self.cluster_mean, 0)
            stdv = 1. / math.sqrt(self.cluster_num)
            self.cluster_mean.data.uniform_(-stdv, stdv)
            if self.opt != None and self.opt.use_normalize_in_gmm:
                torch.nn.init.constant_(self.cluster_variance_sq_unnorm,
                                        inverse_sigmoid(0.99))
                torch.nn.init.constant_(
                    self.cluster_prior,
                    inverse_sigmoid(1.0 / self.cluster_num))
            else:
                torch.nn.init.constant_(self.cluster_variance_sq_unnorm, stdv)
                #torch.nn.init.constant_(self.cluster_variance_sq_unnorm, 1)
                torch.nn.init.constant_(self.cluster_prior,
                                        1.0 / self.cluster_num)
            self.is_first_ff = False
        if self.opt != None and self.opt.debug_mode >= 4:
            print("cluster_prior:", self.cluster_prior)
            print("cluster_mean:", self.cluster_mean)
            print("cluster_variance_sq_unnorm:",
                  self.cluster_variance_sq_unnorm)
        #print("z_mean:", z_mean.size(), "z_log_variance_sq:", z_log_variance_sq.size(), "z:", z.size())
        # shape
        self.batch_size = z_mean.size()[0]
        cluster_mean_duplicate = self.cluster_mean.repeat(
            self.batch_size, 1, 1)
        if self.opt != None and self.opt.use_normalize_in_gmm:
            #cluster_prior_prob = torch.nn.functional.sigmoid(self.cluster_prior)
            cluster_prior_prob = torch.nn.functional.softmax(
                self.cluster_prior)
            cluster_variance_sq = torch.nn.functional.sigmoid(
                self.cluster_variance_sq_unnorm)
            #cluster_variance_sq = torch.nn.functional.relu(self.cluster_variance_sq_unnorm) # soft relu is the best
        else:
            cluster_prior_prob = self.cluster_prior
            cluster_variance_sq = self.cluster_variance_sq_unnorm
        cluster_variance_sq_duplicate = cluster_variance_sq.repeat(
            self.batch_size, 1, 1)
        cluster_prior_duplicate = cluster_prior_prob.repeat(
            self.latent_dim, 1).repeat(self.batch_size, 1, 1)
        cluster_prior_duplicate_2D = cluster_prior_prob.repeat(
            self.batch_size, 1)

        z_mean_duplicate = z_mean.repeat(self.cluster_num, 1,
                                         1).permute(1, 2, 0)
        z_log_variance_sq_duplicate = z_log_variance_sq.repeat(
            self.cluster_num, 1, 1).permute(1, 2, 0)
        z_duplicate = z.repeat(self.cluster_num, 1, 1).permute(1, 2, 0)
        # prob
        #print("z_duplicate:", z_duplicate)
        #print("cluster_mean_duplicate:", cluster_mean_duplicate)
        if self.opt != None and self.opt.debug_mode >= 3:
            print("z size:", z.size())
            print("z_mean size:", z_mean.size())
            print("z_log_variance_sq size:", z_log_variance_sq.size())
            print("z_duplicate size:", z_duplicate.size())
            print("z_mean_duplicate size:", z_mean_duplicate.size())
            print("z_log_variance_sq_duplicate size:",
                  z_log_variance_sq_duplicate.size())

            print("cluster_mean_duplicate size:",
                  cluster_mean_duplicate.size())
            print("cluster_variance_sq_duplicate size:",
                  cluster_variance_sq_duplicate.size())
            print("cluster_prior_duplicate size:",
                  cluster_prior_duplicate.size())
            print("cluster_prior_duplicate_2D size:",
                  cluster_prior_duplicate_2D.size())
        if self.opt != None and self.opt.debug_mode >= 4:
            print("z:", z)
            print("z_mean:", z_mean)
            print("z_log_variance_sq:", z_log_variance_sq)
            print("z_duplicate:", z_duplicate)
            print("z_mean_duplicate:", z_mean_duplicate)
            print("z_log_variance_sq_duplicate:", z_log_variance_sq_duplicate)

            print("cluster_mean_duplicate:", cluster_mean_duplicate)
            print("cluster_variance_sq_duplicate:",
                  cluster_variance_sq_duplicate)
            print("cluster_prior:", self.cluster_prior)
            print("cluster_prior_prob:", cluster_prior_prob)
            print("cluster_prior_duplicate:", cluster_prior_duplicate)
            print("cluster_prior_duplicate_2D:", cluster_prior_duplicate_2D)
        #tmpa = cluster_mean_duplicate - z_log_variance_sq_duplicate.cuda()
        tmpa = z_duplicate - cluster_mean_duplicate
        tmpb = tmpa * tmpa
        terms = torch.log(cluster_prior_duplicate) \
            - 0.5 * torch.log(2 * math.pi * cluster_variance_sq_duplicate) \
            - tmpb / (2 * cluster_variance_sq_duplicate)
        P_c_given_x_unnorm = torch.exp(Utils.sum_with_axis(terms, [1])) + 1e-10
        #print(P_c_given_x_unnorm)
        #print(sum_with_axis(P_c_given_x_unnorm, [-1]))
        P_c_given_x = Utils.myMatrixDivVector(P_c_given_x_unnorm, \
            Utils.sum_with_axis(P_c_given_x_unnorm, [-1]))
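        # P_c_given_x holds the mixture responsibilities q(c | z): exp of the
        # per-component log joint, normalised over components (the 1e-10 guards
        # against rows where every exp underflows to zero)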

        # loss
        P_c_given_x_duplicate = P_c_given_x.repeat(self.latent_dim, 1,
                                                   1).permute(1, 0, 2)
        #cross_entropy_loss = alpha * self.input_data_dim * self.cross_entropy_loss()
        factor1 = 0.5 * P_c_given_x_duplicate
        #tmp1 = self.latent_dim * math.log(math.pi * 2)
        tmp1 = 0
        tmp2 = torch.log(cluster_variance_sq_duplicate)
        tmp3 = torch.exp(
            z_log_variance_sq_duplicate) / cluster_variance_sq_duplicate
        tmp4 = z_mean_duplicate - cluster_mean_duplicate
        tmp5 = tmp4 * tmp4 / cluster_variance_sq_duplicate
        #tmp111 = tmp1 + tmp2
        #tmp112 = tmp111 + tmp3
        #tmp113 = tmp112 + tmp5
        #second_term = sum_with_axis(tmp113, [1, 2])
        second_term_unfold = factor1 * (tmp1 + tmp2 + tmp3 + tmp5)
        second_term = Utils.sum_with_axis(second_term_unfold, [1, 2])
        tmp6 = Utils.sum_with_axis(P_c_given_x * torch.log(P_c_given_x), [1])
        tmp7 = Utils.sum_with_axis(
            P_c_given_x * torch.log(cluster_prior_duplicate_2D), [1])
        third_term_KL_div = tmp7 - tmp6
        #third_term_KL_div = tmp6 - tmp7
        fourth_term = 0.5 * Utils.sum_with_axis(z_log_variance_sq + 1, [1])
        #loss_without_reconstruct = 0 - second_term + third_term_KL_div * self.latent_dim / 2 + fourth_term
        #loss_without_reconstruct = 0 - second_term + fourth_term
        loss_without_reconstruct = 0 - second_term + third_term_KL_div + fourth_term

        negative_loss_without_reconstruct = 0 - loss_without_reconstruct

        if self.opt != None and self.opt.debug_mode >= 3:
            print("size terms:", terms.size(), "P_c_given_x_duplicate:",
                  P_c_given_x_duplicate.size(), "P_c_given_x_unnorm:",
                  P_c_given_x_unnorm.size(), "P_c_given_x", P_c_given_x.size(),
                  "second_term:", second_term.size(), "third_term_KL_div:",
                  third_term_KL_div.size(), "forth_term:", forth_term.size(),
                  "nagetive_loss_without_reconstruct:",
                  nagetive_loss_without_reconstruct.size())
            print("tmp2:", tmp2.size(), "tmp3:", tmp3.size(), "tmp4:",
                  tmp4.size(), "tmp5:", tmp5.size(), "tmp6:",
                  tmp6.size(), "tmp7:", tmp7.size(), "second_term:",
                  second_term.size(), "third_term_KL_div:",
                  third_term_KL_div.size(), "z_log_variance_sq:",
                  z_log_variance_sq.size())
            #print("tmp211:", tmp211.size(), "forth_term:", forth_term.size())
        if self.opt != None and self.opt.debug_mode >= 5:
            print("sum_with_axis(terms, [1]):",
                  Utils.sum_with_axis(terms, [1]))
            print("tmpa:", tmpa)
            print("tmpb:", tmpb)
            print("tmpb / (2 * cluster_variance_sq_duplicate):",
                  tmpb / (2 * cluster_variance_sq_duplicate))
            print("torch.log(cluster_prior_duplicate):",
                  torch.log(cluster_prior_duplicate))
            print(
                "0.5 * torch.log(2 * math.pi * cluster_variance_sq_duplicate):",
                0.5 * torch.log(2 * math.pi * cluster_variance_sq_duplicate))
        if self.opt != None and self.opt.debug_mode >= 4:
            print("terms:", terms, "P_c_given_x_duplicate:",
                  P_c_given_x_duplicate, "P_c_given_x_unnorm:",
                  P_c_given_x_unnorm, "P_c_given_x", P_c_given_x,
                  "second_term:", second_term, "third_term_KL_div:",
                  third_term_KL_div, "forth_term:", forth_term,
                  "nagetive_loss_without_reconstruct:",
                  nagetive_loss_without_reconstruct)
            print("tmp1:", tmp1, "tmp2:", tmp2, "tmp3:", tmp3, "tmp4:", tmp4,
                  "tmp5:", tmp5, "tmp6:", tmp6, "tmp7:", tmp7, "second_term:",
                  second_term, "third_term_KL_div:", third_term_KL_div,
                  "z_log_variance_sq:", z_log_variance_sq)
            #print("tmp211:", tmp211, "forth_term:", forth_term)
        return P_c_given_x, nagetive_loss_without_reconstruct
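
The duplication pattern used throughout the forward pass above, shown on a
tiny tensor: tile a (batch, latent) matrix across a trailing cluster axis.

import torch

z = torch.arange(6.).view(2, 3)             # (batch=2, latent_dim=3)
z_dup = z.repeat(4, 1, 1).permute(1, 2, 0)  # (batch, latent_dim, cluster_num=4)
assert z_dup.shape == (2, 3, 4)
assert torch.equal(z_dup[:, :, 0], z)       # every cluster slice equals z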

Example 30
class CoCoAttention(Module):
    r"""Allows the model to jointly attend to information
    from different representation subspaces.
    See reference: Attention Is All You Need

    .. math::
        \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
        \text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)

    Args:
        embed_dim: total dimension of the model
        num_heads: parallel attention layers, or heads

    Examples::

        >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
        >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
    """
    def __init__(self,
                 embed_dim,
                 num_heads,
                 dropout=0.,
                 bias=True,
                 add_bias_kv=False,
                 add_zero_attn=False):
        super(CoCoAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        #self.head_dim = embed_dim // num_heads
        self.head_dim = 32 // num_heads
        #assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
        self.scaling = self.head_dim**-0.5

        #self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))
        self.in_proj_weight = Parameter(torch.empty(3 * 32, embed_dim))
        if bias:
            #self.in_proj_bias = Parameter(torch.empty(3 * embed_dim))
            self.in_proj_bias = Parameter(torch.empty(3 * 32))
        else:
            self.register_parameter('in_proj_bias', None)
        #self.out_proj = Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = Linear(32, 1, bias=bias)
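        # unlike the stock module above, q/k/v are projected into a fixed
        # 32-dim space (split across the heads) and the attended result is
        # mapped to a single scalar per position, so the module acts as a scorer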

        if add_bias_kv:
            self.bias_k = Parameter(torch.empty(1, 1, embed_dim))
            self.bias_v = Parameter(torch.empty(1, 1, embed_dim))
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn

        self._reset_parameters()

    def _reset_parameters(self):
        # in_proj_weight has 3 * 32 rows (q, k, v blocks of 32 each), so
        # initialise each 32-row block separately rather than slicing by
        # embed_dim as in the stock module
        xavier_uniform_(self.in_proj_weight[:32, :])
        xavier_uniform_(self.in_proj_weight[32:(32 * 2), :])
        xavier_uniform_(self.in_proj_weight[(32 * 2):, :])

        xavier_uniform_(self.out_proj.weight)
        if self.in_proj_bias is not None:
            constant_(self.in_proj_bias, 0.)
            constant_(self.out_proj.bias, 0.)
        if self.bias_k is not None:
            xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            xavier_normal_(self.bias_v)

    def forward(self,
                query,
                key,
                value,
                key_padding_mask=None,
                incremental_state=None,
                need_weights=True,
                static_kv=False,
                attn_mask=None):
        """
        Inputs of forward function
            query: [target length, batch size, embed dim]
            key: [sequence length, batch size, embed dim]
            value: [sequence length, batch size, embed dim]
            key_padding_mask: if True, mask padding based on batch size
            incremental_state: if provided, previous time steps are cached
            need_weights: output attn_output_weights
            static_kv: key and value are static

        Outputs of forward function
            attn_output: [target length, batch size, embed dim]
            attn_output_weights: [batch size, target length, sequence length]
        """
        qkv_same = query.data_ptr() == key.data_ptr() == value.data_ptr()
        kv_same = key.data_ptr() == value.data_ptr()

        tgt_len, bsz, embed_dim = query.size()
        assert embed_dim == self.embed_dim
        assert list(query.size()) == [tgt_len, bsz, embed_dim]
        assert key.size() == value.size()

        if incremental_state is not None:
            saved_state = self._get_input_buffer(incremental_state)
            if 'prev_key' in saved_state:
                # previous time steps are cached - no need to recompute
                # key and value if they are static
                if static_kv:
                    assert kv_same and not qkv_same
                    key = value = None
        else:
            saved_state = None

        if qkv_same:
            # self-attention
            q, k, v = self._in_proj_qkv(query)
        elif kv_same:
            # encoder-decoder attention
            q = self._in_proj_q(query)
            if key is None:
                assert value is None
                k = v = None
            else:
                k, v = self._in_proj_kv(key)
        else:
            q = self._in_proj_q(query)
            k = self._in_proj_k(key)
            v = self._in_proj_v(value)
        q *= self.scaling

        if self.bias_k is not None:
            assert self.bias_v is not None
            k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [attn_mask,
                     attn_mask.new_zeros(attn_mask.size(0), 1)],
                    dim=1)
            if key_padding_mask is not None:
                key_padding_mask = torch.cat([
                    key_padding_mask,
                    key_padding_mask.new_zeros(key_padding_mask.size(0), 1)
                ],
                                             dim=1)

        q = q.contiguous().view(tgt_len, bsz * self.num_heads,
                                self.head_dim).transpose(0, 1)
        if k is not None:
            k = k.contiguous().view(-1, bsz * self.num_heads,
                                    self.head_dim).transpose(0, 1)
        if v is not None:
            v = v.contiguous().view(-1, bsz * self.num_heads,
                                    self.head_dim).transpose(0, 1)

        if saved_state is not None:
            # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
            if 'prev_key' in saved_state:
                prev_key = saved_state['prev_key'].view(
                    bsz * self.num_heads, -1, self.head_dim)
                if static_kv:
                    k = prev_key
                else:
                    k = torch.cat((prev_key, k), dim=1)
            if 'prev_value' in saved_state:
                prev_value = saved_state['prev_value'].view(
                    bsz * self.num_heads, -1, self.head_dim)
                if static_kv:
                    v = prev_value
                else:
                    v = torch.cat((prev_value, v), dim=1)
            saved_state['prev_key'] = k.view(bsz, self.num_heads, -1,
                                             self.head_dim)
            saved_state['prev_value'] = v.view(bsz, self.num_heads, -1,
                                               self.head_dim)

            self._set_input_buffer(incremental_state, saved_state)

        src_len = k.size(1)

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len

        if self.add_zero_attn:
            src_len += 1
            k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])],
                          dim=1)
            v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])],
                          dim=1)
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [attn_mask,
                     attn_mask.new_zeros(attn_mask.size(0), 1)],
                    dim=1)
            if key_padding_mask is not None:
                key_padding_mask = torch.cat([
                    key_padding_mask,
                    torch.zeros(key_padding_mask.size(0),
                                1).type_as(key_padding_mask)
                ],
                                             dim=1)

        attn_output_weights = torch.bmm(q, k.transpose(1, 2))
        assert list(attn_output_weights.size()) == [
            bsz * self.num_heads, tgt_len, src_len
        ]

        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(0)
            attn_output_weights += attn_mask

        if key_padding_mask is not None:
            attn_output_weights = attn_output_weights.view(
                bsz, self.num_heads, tgt_len, src_len)
            attn_output_weights = attn_output_weights.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2),
                float('-inf'),
            )
            attn_output_weights = attn_output_weights.view(
                bsz * self.num_heads, tgt_len, src_len)

        attn_output_weights = F.softmax(
            attn_output_weights.float(),
            dim=-1,
            dtype=torch.float32 if attn_output_weights.dtype == torch.float16
            else attn_output_weights.dtype)
        attn_output_weights = F.dropout(attn_output_weights,
                                        p=self.dropout,
                                        training=self.training)

        attn_output = torch.bmm(attn_output_weights, v)
        assert list(attn_output.size()) == [
            bsz * self.num_heads, tgt_len, self.head_dim
        ]
        #attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
        attn_output = attn_output.transpose(0, 1).contiguous().view(
            tgt_len, bsz, 32)
        attn_output = self.out_proj(attn_output)

        if need_weights:
            # average attention weights over heads
            attn_output_weights = attn_output_weights.view(
                bsz, self.num_heads, tgt_len, src_len)
            attn_output_weights = attn_output_weights.sum(
                dim=1) / self.num_heads
        else:
            attn_output_weights = None

        return attn_output, attn_output_weights

    def _in_proj_qkv(self, query):
        return self._in_proj(query).chunk(3, dim=-1)

    def _in_proj_kv(self, key):
        #return self._in_proj(key, start=self.embed_dim).chunk(2, dim=-1)
        return self._in_proj(key, start=32).chunk(2, dim=-1)

    def _in_proj_q(self, query):
        #return self._in_proj(query, end=self.embed_dim)
        return self._in_proj(query, end=32)

    def _in_proj_k(self, key):
        #return self._in_proj(key, start=self.embed_dim, end=2 * self.embed_dim)
        return self._in_proj(key, start=32, end=2 * 32)

    def _in_proj_v(self, value):
        #return self._in_proj(value, start=2 * self.embed_dim)
        return self._in_proj(value, start=2 * 32)

    def _in_proj(self, input, start=0, end=None):
        weight = self.in_proj_weight
        bias = self.in_proj_bias
        weight = weight[start:end, :]
        if bias is not None:
            bias = bias[start:end]
        return F.linear(input, weight, bias)
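
A shape sketch for CoCoAttention (illustrative sizes; embed_dim need not equal
32, since the in-projections map it into the fixed 32-dim space; assumes the
class and its imports are in scope):

import torch

coco = CoCoAttention(embed_dim=64, num_heads=4)
x = torch.randn(5, 2, 64)   # (seq_len, batch, embed_dim), self-attention
out, weights = coco(x, x, x)
assert out.shape == (5, 2, 1) and weights.shape == (2, 5, 5)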
Example 31
class MVG_binaryNet(nn.Module):
    def __init__(self, H1, H2, dropprob, scale):
        super(MVG_binaryNet, self).__init__()

        self.sq2pi = 0.797884560
        self.drop_prob = dropprob
        self.hidden1 = H1
        self.hidden2 = H2
        self.D_out = 10
        self.scale = scale
        self.w0 = Parameter(torch.Tensor(28 * 28, self.hidden1))
        stdv = 1. / math.sqrt(self.w0.data.size(1))
        self.w0.data = self.scale * self.w0.data.uniform_(-stdv, stdv)

        self.w1 = Parameter(torch.Tensor(self.hidden1, self.hidden2))
        stdv = 1. / math.sqrt(self.w1.data.size(1))
        self.w1.data = self.scale * self.w1.data.uniform_(-stdv, stdv)

        self.w2 = Parameter(torch.Tensor(self.hidden2, self.hidden2))
        stdv = 1. / math.sqrt(self.w2.data.size(1))
        self.w2.data = self.scale * self.w2.data.uniform_(-stdv, stdv)

        self.wlast = Parameter(torch.Tensor(self.hidden2, self.D_out))
        stdv = 1. / math.sqrt(self.wlast.data.size(1))
        self.wlast.data = self.scale * self.wlast.data.uniform_(-stdv, stdv)

        self.th0 = Parameter(torch.zeros(1, self.hidden1))
        self.th1 = Parameter(torch.zeros(1, self.hidden2))
        self.th2 = Parameter(torch.zeros(1, self.hidden2))
        self.thlast = Parameter(torch.zeros(1, self.D_out))

    def expected_loss(self, target, forward_result):
        (a2, logprobs_out) = forward_result
        return F.nll_loss(logprobs_out, target)

    def MVG_layer(self, xbar, xcov, m, th):
        # receives neuron means and covariances, returns next-layer means and covariances
        M = xbar.size()[0]
        H, H2 = m.size()
        sigma = torch.t(m)[None, :, :].repeat(M, 1, 1).bmm(
            xcov.clone().bmm(m.repeat(M, 1, 1))) + torch.diag(
                torch.sum(1 - m ** 2, 0)).repeat(M, 1, 1)
        # batched diagonal of sigma, shape (M, H2)
        diagsig2 = sigma.reshape(M, H2 * H2)[:, ::H2 + 1]

        hbar = xbar.mm(m) + th.repeat(xbar.size()[0], 1)  # numerator of the input to the non-linearity
        h = self.sq2pi * hbar / torch.sqrt(diagsig2)
        xbar_next = torch.tanh(h)  # equal to 2 * sigmoid(2 * h) - 1

        ey = torch.eye(H2)
        xc2cop = (1 - xbar_next ** 2) / torch.sqrt(diagsig2)
        xcov_next = (self.sq2pi * sigma * xc2cop[:, :, None] * xc2cop[:, None, :]
                     + ey[None, :, :] * (1 - xbar_next[:, None, :] ** 2))

        return xbar_next, xcov_next
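
    # Shape summary for MVG_layer, with M = batch size and an H -> H2 layer
    # (sizes inferred from the code above): xbar (M, H), xcov (M, H, H),
    # m (H, H2), th (1, H2); sigma (M, H2, H2) is the pre-activation
    # covariance and diagsig2 (M, H2) its diagonal; the method returns
    # xbar_next (M, H2) and xcov_next (M, H2, H2).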

    def forward(self, x, target):
        m0 = 2 * torch.sigmoid(self.w0) - 1
        m1 = 2 * torch.sigmoid(self.w1) - 1
        m2 = 2 * torch.sigmoid(self.w2) - 1
        mlast = 2 * torch.sigmoid(self.wlast) - 1
        sq2pi = self.sq2pi  # sqrt(2 / pi)
        dtype = torch.FloatTensor

        H = self.hidden1
        D_out = self.D_out
        x = x.view(-1, 28 * 28)
        y = target[:, None]
        M = x.size()[0]
        M_double = M * 1.0
        x0_do = F.dropout(x, p=self.drop_prob, training=self.training)

        sigma_1 = torch.diag(torch.sum(1 - m0 ** 2, 0)).repeat(M, 1, 1)
        # batched diagonal of sigma_1, shape (M, H)
        diagsig1 = sigma_1.reshape(M, H * H)[:, ::H + 1]
        h1bar = x0_do.mm(m0) + self.th0.repeat(x.size()[0], 1)  # numerator of the input to the non-linearity
        h1 = sq2pi * h1bar / torch.sqrt(diagsig1)

        x1bar = torch.tanh(h1)
        x1bar_d0 = F.dropout(x1bar, p=self.drop_prob, training=self.training)

        ey = torch.eye(H)
        xcov_1 = ey[None, :, :] * (1 - x1bar_d0[:, None, :] ** 2)  # diagonal covariance of the layer-1 neurons

        # propagate means and covariances through the remaining MVG layers
        x2bar, xcov_2 = self.MVG_layer(x1bar_d0, xcov_1, m1, self.th1)
        x3bar, xcov_3 = self.MVG_layer(
            F.dropout(x2bar, p=self.drop_prob, training=self.training),
            xcov_2, m2, self.th2)
        hlastbar = x3bar.mm(mlast) + self.thlast.repeat(x1bar.size()[0], 1)

        logprobs_out = F.log_softmax(hlastbar, dim=1)
        val, ind = torch.max(hlastbar, 1)
        tem = y.type(dtype) - ind.type(dtype)[:, None]
        fraction_correct = (M_double - torch.sum(tem != 0).type(dtype)) / M_double
        expected_loss = self.expected_loss(target, (hlastbar, logprobs_out))
        return (hlastbar, logprobs_out, xcov_3), expected_loss, fraction_correct
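
A minimal smoke test for the model above (the hyperparameter values and the dummy batch are illustrative assumptions, not taken from the original example):

model = MVG_binaryNet(H1=128, H2=64, dropprob=0.2, scale=1.0)
x = torch.randn(16, 1, 28, 28)        # dummy MNIST-shaped batch
target = torch.randint(0, 10, (16,))
(hlastbar, logprobs, xcov), loss, frac_correct = model(x, target)
loss.backward()
print(loss.item(), frac_correct.item())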