Example #1
# These snippets assume torch plus several project-local helpers (Module,
# Affine, LayerNorm, utils.scope) defined elsewhere in the same codebase.
import torch
import torch.nn as nn

class HighwayLSTMCell(Module):
    def __init__(self, input_size, output_size, name="lstm"):
        super(HighwayLSTMCell, self).__init__(name=name)

        self.input_size = input_size
        self.output_size = output_size

        with utils.scope(name):
            self.gates = Affine(input_size + output_size,
                                5 * output_size,
                                name="gates")
            self.trans = Affine(input_size, output_size, name="trans")

        self.reset_parameters()

    def forward(self, x, state):
        c, h = state

        gates = self.gates(torch.cat([x, h], 1))
        combined = torch.reshape(gates, [-1, 5, self.output_size])
        i, j, f, o, t = torch.unbind(combined, 1)
        i, f, o = torch.sigmoid(i), torch.sigmoid(f), torch.sigmoid(o)
        t = torch.sigmoid(t)

        new_c = f * c + i * torch.tanh(j)
        tmp_h = o * torch.tanh(new_c)
        new_h = t * tmp_h + (1.0 - t) * self.trans(x)

        return new_h, (new_c, new_h)

    def init_state(self, batch_size, dtype, device):
        c = torch.zeros([batch_size, self.output_size],
                        dtype=dtype,
                        device=device)
        h = torch.zeros([batch_size, self.output_size],
                        dtype=dtype,
                        device=device)
        return c, h

    def mask_state(self, state, prev_state, mask):
        c, h = state
        prev_c, prev_h = prev_state
        mask = mask[:, None]
        new_c = mask * c + (1.0 - mask) * prev_c
        new_h = mask * h + (1.0 - mask) * prev_h
        return new_c, new_h

    def reset_parameters(self, initializer="orthogonal"):
        if initializer == "uniform_scaling":
            nn.init.xavier_uniform_(self.gates.weight)
            nn.init.constant_(self.gates.bias, 0.0)
            # Also initialize the highway transform, as the orthogonal
            # branch below does.
            nn.init.xavier_uniform_(self.trans.weight)
            nn.init.constant_(self.trans.bias, 0.0)
        elif initializer == "uniform":
            nn.init.uniform_(self.gates.weight, -0.04, 0.04)
            nn.init.uniform_(self.gates.bias, -0.04, 0.04)
            nn.init.uniform_(self.trans.weight, -0.04, 0.04)
            nn.init.uniform_(self.trans.bias, -0.04, 0.04)
        elif initializer == "orthogonal":
            self.gates.orthogonal_initialize()
            self.trans.orthogonal_initialize()
        else:
            raise ValueError("Unknown initializer %s" % initializer)
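
A minimal usage sketch for the cell above. The sizes, batch shape and mask convention here are hypothetical; Module, Affine and utils.scope come from the surrounding codebase.

# Hypothetical smoke test: step the cell over a padded batch.
cell = HighwayLSTMCell(input_size=32, output_size=64)
x = torch.randn(8, 10, 32)    # [batch, time, input_size]
mask = torch.ones(8, 10)      # 1.0 = real token, 0.0 = padding
state = cell.init_state(8, x.dtype, x.device)

outputs = []
for t in range(x.shape[1]):
    prev_state = state
    output, state = cell.forward(x[:, t], state)
    # Keep the previous state wherever this timestep is padding.
    state = cell.mask_state(state, prev_state, mask[:, t])
    outputs.append(output)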
Example #2
    def __init__(self, input_size, output_size, name="lstm"):
        super(HighwayLSTMCell, self).__init__(name=name)

        self.input_size = input_size
        self.output_size = output_size

        with utils.scope(name):
            self.gates = Affine(input_size + output_size,
                                5 * output_size,
                                name="gates")
            self.trans = Affine(input_size, output_size, name="trans")

        self.reset_parameters()
Example #3
    def __init__(self, input_size, output_size, k=2, num_cells=4, name="lstm"):
        super(DynamicLSTMCell, self).__init__(name=name)

        self.input_size = input_size
        self.output_size = output_size
        self.num_cells = num_cells
        self.k = k

        with utils.scope(name):
            self.gates = Affine(input_size + output_size,
                                4 * output_size * num_cells,
                                name="gates")
            self.topk_gate = Affine(input_size + output_size,
                                    num_cells,
                                    name="controller")

        self.reset_parameters()
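
Only the constructor is shown in this example. A rough sketch of how the "controller" projection could route among the num_cells fused gate blocks follows; the top-k mixing here is an assumption, not the source's actual forward pass.

# Hypothetical routing sketch over the fused gate projection.
cell = DynamicLSTMCell(input_size=32, output_size=64, k=2, num_cells=4)
x, h = torch.randn(8, 32), torch.zeros(8, 64)
features = torch.cat([x, h], 1)                         # [batch, in + out]
scores = torch.softmax(cell.topk_gate(features), -1)    # [batch, num_cells]
weights, indices = torch.topk(scores, cell.k, dim=-1)   # keep k of num_cells
gates = cell.gates(features)                            # [batch, 4 * out * num_cells]
gates = torch.reshape(gates, [-1, cell.num_cells, 4, cell.output_size])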
Example #4
    def __init__(self,
                 hidden_size,
                 num_heads,
                 dropout=0.0,
                 name="multihead_attention"):
        super(MultiHeadAttention, self).__init__(name=name)

        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.dropout = dropout

        with utils.scope(name):
            self.qkv_transform = Affine(hidden_size,
                                        3 * hidden_size,
                                        name="qkv_transform")
            self.o_transform = Affine(hidden_size,
                                      hidden_size,
                                      name="o_transform")

        self.reset_parameters()
Example #5
    def __init__(self,
                 input_size,
                 output_size,
                 normalization=False,
                 activation=torch.tanh,
                 name="lstm"):
        super(LSTMCell, self).__init__(name=name)

        self.input_size = input_size
        self.output_size = output_size
        self.activation = activation

        with utils.scope(name):
            self.gates = Affine(input_size + output_size,
                                4 * output_size,
                                name="gates")
            if normalization:
                self.layer_norm = LayerNorm([4, output_size])
            else:
                self.layer_norm = None

        self.reset_parameters()
Example #6
    def __init__(self,
                 input_size,
                 output_size,
                 normalization=False,
                 name="gru"):
        super(GRUCell, self).__init__(name=name)

        self.input_size = input_size
        self.output_size = output_size

        with utils.scope(name):
            self.reset_gate = Affine(input_size + output_size,
                                     output_size,
                                     bias=False,
                                     name="reset_gate")
            self.update_gate = Affine(input_size + output_size,
                                      output_size,
                                      bias=False,
                                      name="update_gate")
            self.transform = Affine(input_size + output_size,
                                    output_size,
                                    name="transform")

        # Matching the other cells; assumes GRUCell defines
        # reset_parameters() like its siblings (not shown in this fragment).
        self.reset_parameters()
Example #7
    def __init__(self,
                 input_size,
                 hidden_size,
                 output_size=None,
                 dropout=0.0,
                 name="feed_forward"):
        super(FeedForward, self).__init__(name=name)

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size or input_size
        self.dropout = dropout

        with utils.scope(name):
            self.input_transform = Affine(input_size,
                                          hidden_size,
                                          name="input_transform")
            self.output_transform = Affine(hidden_size,
                                           self.output_size,
                                           name="output_transform")

        self.reset_parameters()
Example #8
class FeedForward(Module):
    def __init__(self,
                 input_size,
                 hidden_size,
                 output_size=None,
                 dropout=0.0,
                 name="feed_forward"):
        super(FeedForward, self).__init__(name=name)

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size or input_size
        self.dropout = dropout

        with utils.scope(name):
            self.input_transform = Affine(input_size,
                                          hidden_size,
                                          name="input_transform")
            self.output_transform = Affine(hidden_size,
                                           self.output_size,
                                           name="output_transform")

        self.reset_parameters()

    def forward(self, x):
        h = nn.functional.relu(self.input_transform(x))
        h = nn.functional.dropout(h, self.dropout, self.training)
        return self.output_transform(h)

    def reset_parameters(self, initializer="orthogonal"):
        if initializer == "orthogonal":
            self.input_transform.orthogonal_initialize()
            self.output_transform.orthogonal_initialize()
        else:
            nn.init.xavier_uniform_(self.input_transform.weight)
            nn.init.xavier_uniform_(self.output_transform.weight)
            nn.init.constant_(self.input_transform.bias, 0.0)
            nn.init.constant_(self.output_transform.bias, 0.0)
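
A minimal, hypothetical call of the block above, assuming Affine accepts extra leading dimensions the way nn.Linear does:

# Position-wise feed-forward over a [batch, length, input_size] tensor.
ffn = FeedForward(input_size=512, hidden_size=2048, dropout=0.1)
x = torch.randn(8, 20, 512)
y = ffn.forward(x)    # [8, 20, 512]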
Example #9
class MultiHeadAttention(Module):
    def __init__(self,
                 hidden_size,
                 num_heads,
                 dropout=0.0,
                 name="multihead_attention"):
        super(MultiHeadAttention, self).__init__(name=name)

        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.dropout = dropout
        self.weights = None

        with utils.scope(name):
            self.qkv_transform = Affine(hidden_size,
                                        3 * hidden_size,
                                        name="qkv_transform")
            self.o_transform = Affine(hidden_size,
                                      hidden_size,
                                      name="o_transform")

        self.reset_parameters()

    def forward(self, query, bias):
        qkv = self.qkv_transform(query)
        q, k, v = torch.split(qkv, self.hidden_size, dim=-1)

        # split heads
        qh = self.split_heads(q, self.num_heads)
        kh = self.split_heads(k, self.num_heads)
        vh = self.split_heads(v, self.num_heads)

        # scale query
        qh = qh * (self.hidden_size // self.num_heads)**-0.5

        # dot-product attention
        kh = torch.transpose(kh, -2, -1)
        logits = torch.matmul(qh, kh)

        if bias is not None:
            logits = logits + bias

        weights = torch.softmax(logits, dim=-1)
        self.weights = nn.functional.dropout(weights,
                                             p=self.dropout,
                                             training=self.training)

        x = torch.matmul(self.weights, vh)

        # combine heads
        output = self.o_transform(self.combine_heads(x))

        return output

    def reset_parameters(self, initializer="orthogonal"):
        if initializer == "orthogonal":
            self.qkv_transform.orthogonal_initialize()
            self.o_transform.orthogonal_initialize()
        else:
            # 6 / (4 * hidden_size) -> 6 / (2 * hidden_size)
            nn.init.xavier_uniform_(self.qkv_transform.weight)
            nn.init.xavier_uniform_(self.o_transform.weight)
            nn.init.constant_(self.qkv_transform.bias, 0.0)
            nn.init.constant_(self.o_transform.bias, 0.0)

    @staticmethod
    def split_heads(x, heads):
        batch = x.shape[0]
        length = x.shape[1]
        channels = x.shape[2]

        y = torch.reshape(x, [batch, length, heads, channels // heads])
        return torch.transpose(y, 2, 1)

    @staticmethod
    def combine_heads(x):
        batch = x.shape[0]
        heads = x.shape[1]
        length = x.shape[2]
        channels = x.shape[3]

        y = torch.transpose(x, 2, 1)

        return torch.reshape(y, [batch, length, heads * channels])
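
A hypothetical self-attention call. The bias argument is an additive mask: large negative values on padded key positions, broadcast over heads and query positions so masked logits vanish under softmax.

# Hypothetical usage of the attention block above.
attn = MultiHeadAttention(hidden_size=512, num_heads=8, dropout=0.1)
x = torch.randn(8, 20, 512)       # [batch, length, hidden_size]
key_mask = torch.ones(8, 20)      # 1.0 = real token, 0.0 = padding
bias = (1.0 - key_mask) * -1e9    # additive penalty on padded keys
bias = bias[:, None, None, :]     # [batch, 1 head, 1 query, length]
y = attn.forward(x, bias)         # [8, 20, 512]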
Example #10
class LSTMCell(Module):
    def __init__(self,
                 input_size,
                 output_size,
                 normalization=False,
                 activation=torch.tanh,
                 name="lstm"):
        super(LSTMCell, self).__init__(name=name)

        self.input_size = input_size
        self.output_size = output_size
        self.activation = activation

        with utils.scope(name):
            self.gates = Affine(input_size + output_size,
                                4 * output_size,
                                name="gates")
            if normalization:
                self.layer_norm = LayerNorm([4, output_size])
            else:
                self.layer_norm = None

        self.reset_parameters()

    def forward(self, x, state):
        c, h = state

        gates = self.gates(torch.cat([x, h], 1))

        if self.layer_norm is not None:
            combined = self.layer_norm(
                torch.reshape(gates, [-1, 4, self.output_size]))
        else:
            combined = torch.reshape(gates, [-1, 4, self.output_size])

        i, j, f, o = torch.unbind(combined, 1)
        i, f, o = torch.sigmoid(i), torch.sigmoid(f), torch.sigmoid(o)

        new_c = f * c + i * torch.tanh(j)

        if self.activation is None:
            # Do not use tanh activation
            new_h = o * new_c
        else:
            new_h = o * self.activation(new_c)

        return new_h, (new_c, new_h)

    def init_state(self, batch_size, dtype, device):
        c = torch.zeros([batch_size, self.output_size],
                        dtype=dtype,
                        device=device)
        h = torch.zeros([batch_size, self.output_size],
                        dtype=dtype,
                        device=device)
        return c, h

    def mask_state(self, state, prev_state, mask):
        c, h = state
        prev_c, prev_h = prev_state
        mask = mask[:, None]
        new_c = mask * c + (1.0 - mask) * prev_c
        new_h = mask * h + (1.0 - mask) * prev_h
        return new_c, new_h

    def reset_parameters(self, initializer="orthogonal"):
        if initializer == "uniform_scaling":
            nn.init.xavier_uniform_(self.gates.weight)
            nn.init.constant_(self.gates.bias, 0.0)
        elif initializer == "uniform":
            nn.init.uniform_(self.gates.weight, -0.04, 0.04)
            nn.init.uniform_(self.gates.bias, -0.04, 0.04)
        elif initializer == "orthogonal":
            self.gates.orthogonal_initialize()
        else:
            raise ValueError("Unknown initializer %s" % initializer)
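
A minimal, hypothetical run of the layer-normalized variant:

# One step of the cell with layer normalization enabled.
cell = LSTMCell(input_size=32, output_size=64, normalization=True)
x = torch.randn(8, 32)
state = cell.init_state(8, x.dtype, x.device)
output, state = cell.forward(x, state)    # output: [8, 64]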