import torch
import torch.nn as nn

# NOTE: Module, Affine, LayerNorm and utils (providing the utils.scope(name)
# context manager) are package-local helpers; these import paths are an
# assumption and may need adjusting to the surrounding project layout.
from .module import Module
from .affine import Affine
from .layer_norm import LayerNorm
from . import utils


class HighwayLSTMCell(Module):

    def __init__(self, input_size, output_size, name="lstm"):
        super(HighwayLSTMCell, self).__init__(name=name)

        self.input_size = input_size
        self.output_size = output_size

        with utils.scope(name):
            self.gates = Affine(input_size + output_size, 5 * output_size,
                                name="gates")
            self.trans = Affine(input_size, output_size, name="trans")

        self.reset_parameters()

    def forward(self, x, state):
        c, h = state

        gates = self.gates(torch.cat([x, h], 1))
        combined = torch.reshape(gates, [-1, 5, self.output_size])
        i, j, f, o, t = torch.unbind(combined, 1)
        i, f, o = torch.sigmoid(i), torch.sigmoid(f), torch.sigmoid(o)
        t = torch.sigmoid(t)

        new_c = f * c + i * torch.tanh(j)
        tmp_h = o * torch.tanh(new_c)
        # Highway connection: gate between the LSTM output and a linear
        # transform of the input.
        new_h = t * tmp_h + (1.0 - t) * self.trans(x)

        return new_h, (new_c, new_h)

    def init_state(self, batch_size, dtype, device):
        c = torch.zeros([batch_size, self.output_size], dtype=dtype,
                        device=device)
        h = torch.zeros([batch_size, self.output_size], dtype=dtype,
                        device=device)
        return c, h

    def mask_state(self, state, prev_state, mask):
        c, h = state
        prev_c, prev_h = prev_state
        mask = mask[:, None]
        new_c = mask * c + (1.0 - mask) * prev_c
        new_h = mask * h + (1.0 - mask) * prev_h
        return new_c, new_h

    def reset_parameters(self, initializer="orthogonal"):
        if initializer == "uniform_scaling":
            nn.init.xavier_uniform_(self.gates.weight)
            nn.init.constant_(self.gates.bias, 0.0)
        elif initializer == "uniform":
            nn.init.uniform_(self.gates.weight, -0.04, 0.04)
            nn.init.uniform_(self.gates.bias, -0.04, 0.04)
        elif initializer == "orthogonal":
            self.gates.orthogonal_initialize()
            self.trans.orthogonal_initialize()
        else:
            raise ValueError("Unknown initializer %s" % initializer)

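# Usage sketch (illustrative, not part of the original module): stepping a
# HighwayLSTMCell over a padded batch and resetting the state at padded
# positions with mask_state. Shapes and names below are assumptions.
def _example_highway_lstm(inputs, mask):
    # inputs: float tensor [batch, seq_len, input_size]
    # mask:   float tensor [batch, seq_len], 1.0 for tokens, 0.0 for padding
    batch, seq_len, input_size = inputs.shape
    cell = HighwayLSTMCell(input_size, output_size=input_size)
    state = cell.init_state(batch, inputs.dtype, inputs.device)
    outputs = []

    for t in range(seq_len):
        prev_state = state
        out, state = cell(inputs[:, t], state)
        # Keep the previous state wherever this timestep is padding.
        state = cell.mask_state(state, prev_state, mask[:, t])
        outputs.append(out)

    return torch.stack(outputs, dim=1)  # [batch, seq_len, output_size]
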
class DynamicLSTMCell(Module):

    def __init__(self, input_size, output_size, k=2, num_cells=4,
                 name="lstm"):
        super(DynamicLSTMCell, self).__init__(name=name)

        self.input_size = input_size
        self.output_size = output_size
        self.num_cells = num_cells
        self.k = k

        with utils.scope(name):
            # One block of 4 * output_size gate pre-activations per candidate
            # cell, plus a controller that scores the candidate cells.
            self.gates = Affine(input_size + output_size,
                                4 * output_size * num_cells, name="gates")
            self.topk_gate = Affine(input_size + output_size, num_cells,
                                    name="controller")

        self.reset_parameters()

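# Sketch (assumption): only __init__ of DynamicLSTMCell is shown above, so the
# step below illustrates how the controller *could* combine the num_cells
# candidate gate blocks via top-k selection; it is not the original
# implementation, and it assumes reset_parameters() is defined elsewhere.
def _dynamic_lstm_step(cell, x, state):
    c, h = state
    features = torch.cat([x, h], dim=1)

    # Per-cell gate pre-activations: [batch, num_cells, 4, output_size].
    gates = cell.gates(features).reshape(-1, cell.num_cells, 4,
                                         cell.output_size)

    # Controller scores each candidate cell; keep the top-k and mix them.
    scores = cell.topk_gate(features)                        # [batch, num_cells]
    top_scores, top_idx = torch.topk(scores, cell.k, dim=1)  # [batch, k]
    weights = torch.softmax(top_scores, dim=1)

    idx = top_idx[:, :, None, None].expand(-1, -1, 4, cell.output_size)
    selected = torch.gather(gates, 1, idx)                   # [batch, k, 4, output_size]
    combined = (weights[:, :, None, None] * selected).sum(dim=1)

    # Standard LSTM update on the mixed gate block.
    i, j, f, o = torch.unbind(combined, 1)
    i, f, o = torch.sigmoid(i), torch.sigmoid(f), torch.sigmoid(o)
    new_c = f * c + i * torch.tanh(j)
    new_h = o * torch.tanh(new_c)
    return new_h, (new_c, new_h)
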
class GRUCell(Module):

    def __init__(self, input_size, output_size, normalization=False,
                 name="gru"):
        super(GRUCell, self).__init__(name=name)

        self.input_size = input_size
        self.output_size = output_size

        with utils.scope(name):
            self.reset_gate = Affine(input_size + output_size, output_size,
                                     bias=False, name="reset_gate")
            self.update_gate = Affine(input_size + output_size, output_size,
                                      bias=False, name="update_gate")
            self.transform = Affine(input_size + output_size, output_size,
                                    name="transform")

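# Sketch (assumption): only __init__ of GRUCell is shown above, so the step
# below illustrates the conventional GRU equations using its three affine
# layers; it is not necessarily the original forward implementation.
def _gru_step(cell, x, h):
    r = torch.sigmoid(cell.reset_gate(torch.cat([x, h], dim=-1)))   # reset gate
    u = torch.sigmoid(cell.update_gate(torch.cat([x, h], dim=-1)))  # update gate
    candidate = torch.tanh(cell.transform(torch.cat([x, r * h], dim=-1)))
    return (1.0 - u) * h + u * candidate
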
class FeedForward(Module):

    def __init__(self, input_size, hidden_size, output_size=None, dropout=0.0,
                 name="feed_forward"):
        super(FeedForward, self).__init__(name=name)

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size or input_size
        self.dropout = dropout

        with utils.scope(name):
            self.input_transform = Affine(input_size, hidden_size,
                                          name="input_transform")
            self.output_transform = Affine(hidden_size, self.output_size,
                                           name="output_transform")

        self.reset_parameters()

    def forward(self, x):
        h = nn.functional.relu(self.input_transform(x))
        h = nn.functional.dropout(h, self.dropout, self.training)
        return self.output_transform(h)

    def reset_parameters(self, initializer="orthogonal"):
        if initializer == "orthogonal":
            self.input_transform.orthogonal_initialize()
            self.output_transform.orthogonal_initialize()
        else:
            nn.init.xavier_uniform_(self.input_transform.weight)
            nn.init.xavier_uniform_(self.output_transform.weight)
            nn.init.constant_(self.input_transform.bias, 0.0)
            nn.init.constant_(self.output_transform.bias, 0.0)

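# Usage sketch (illustrative, not part of the original module): applying the
# position-wise feed-forward block to a [batch, length, hidden] tensor. The
# sizes below are assumptions.
def _example_feed_forward(x):
    # x: [batch, length, 512]
    ffn = FeedForward(input_size=512, hidden_size=2048, dropout=0.1)
    return ffn(x)  # [batch, length, 512]; output_size defaults to input_size
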
class MultiHeadAttention(Module):

    def __init__(self, hidden_size, num_heads, dropout=0.0,
                 name="multihead_attention"):
        super(MultiHeadAttention, self).__init__(name=name)

        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.dropout = dropout
        self.weights = None

        with utils.scope(name):
            self.qkv_transform = Affine(hidden_size, 3 * hidden_size,
                                        name="qkv_transform")
            self.o_transform = Affine(hidden_size, hidden_size,
                                      name="o_transform")

        self.reset_parameters()

    def forward(self, query, bias):
        qkv = self.qkv_transform(query)
        q, k, v = torch.split(qkv, self.hidden_size, dim=-1)

        # split heads
        qh = self.split_heads(q, self.num_heads)
        kh = self.split_heads(k, self.num_heads)
        vh = self.split_heads(v, self.num_heads)

        # scale query
        qh = qh * (self.hidden_size // self.num_heads) ** -0.5

        # dot-product attention
        kh = torch.transpose(kh, -2, -1)
        logits = torch.matmul(qh, kh)

        if bias is not None:
            logits = logits + bias

        self.weights = torch.nn.functional.dropout(
            torch.softmax(logits, dim=-1), p=self.dropout,
            training=self.training)

        x = torch.matmul(self.weights, vh)

        # combine heads
        output = self.o_transform(self.combine_heads(x))

        return output

    def reset_parameters(self, initializer="orthogonal"):
        if initializer == "orthogonal":
            self.qkv_transform.orthogonal_initialize()
            self.o_transform.orthogonal_initialize()
        else:
            # 6 / (4 * hidden_size) -> 6 / (2 * hidden_size)
            nn.init.xavier_uniform_(self.qkv_transform.weight)
            nn.init.xavier_uniform_(self.o_transform.weight)
            nn.init.constant_(self.qkv_transform.bias, 0.0)
            nn.init.constant_(self.o_transform.bias, 0.0)

    @staticmethod
    def split_heads(x, heads):
        batch = x.shape[0]
        length = x.shape[1]
        channels = x.shape[2]

        y = torch.reshape(x, [batch, length, heads, channels // heads])
        return torch.transpose(y, 2, 1)

    @staticmethod
    def combine_heads(x):
        batch = x.shape[0]
        heads = x.shape[1]
        length = x.shape[2]
        channels = x.shape[3]

        y = torch.transpose(x, 2, 1)
        return torch.reshape(y, [batch, length, heads * channels])

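# Usage sketch (illustrative, not part of the original module): self-attention
# over a padded batch. `bias` is an additive mask with large negative values
# at padded positions, matching how it is added to the logits above. Shapes
# and sizes are assumptions.
def _example_self_attention(x, mask):
    # x:    [batch, length, 512]
    # mask: [batch, length], 1.0 for tokens, 0.0 for padding
    attention = MultiHeadAttention(hidden_size=512, num_heads=8, dropout=0.1)
    bias = (1.0 - mask)[:, None, None, :] * -1e9  # [batch, 1, 1, length]
    return attention(x, bias)                     # [batch, length, 512]
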
class LSTMCell(Module):

    def __init__(self, input_size, output_size, normalization=False,
                 activation=torch.tanh, name="lstm"):
        super(LSTMCell, self).__init__(name=name)

        self.input_size = input_size
        self.output_size = output_size
        self.activation = activation

        with utils.scope(name):
            self.gates = Affine(input_size + output_size, 4 * output_size,
                                name="gates")

            if normalization:
                self.layer_norm = LayerNorm([4, output_size])
            else:
                self.layer_norm = None

        self.reset_parameters()

    def forward(self, x, state):
        c, h = state

        gates = self.gates(torch.cat([x, h], 1))

        if self.layer_norm is not None:
            combined = self.layer_norm(
                torch.reshape(gates, [-1, 4, self.output_size]))
        else:
            combined = torch.reshape(gates, [-1, 4, self.output_size])

        i, j, f, o = torch.unbind(combined, 1)
        i, f, o = torch.sigmoid(i), torch.sigmoid(f), torch.sigmoid(o)

        new_c = f * c + i * torch.tanh(j)

        if self.activation is None:
            # Do not use tanh activation
            new_h = o * new_c
        else:
            new_h = o * self.activation(new_c)

        return new_h, (new_c, new_h)

    def init_state(self, batch_size, dtype, device):
        c = torch.zeros([batch_size, self.output_size], dtype=dtype,
                        device=device)
        h = torch.zeros([batch_size, self.output_size], dtype=dtype,
                        device=device)
        return c, h

    def mask_state(self, state, prev_state, mask):
        c, h = state
        prev_c, prev_h = prev_state
        mask = mask[:, None]
        new_c = mask * c + (1.0 - mask) * prev_c
        new_h = mask * h + (1.0 - mask) * prev_h
        return new_c, new_h

    def reset_parameters(self, initializer="orthogonal"):
        if initializer == "uniform_scaling":
            nn.init.xavier_uniform_(self.gates.weight)
            nn.init.constant_(self.gates.bias, 0.0)
        elif initializer == "uniform":
            nn.init.uniform_(self.gates.weight, -0.04, 0.04)
            nn.init.uniform_(self.gates.bias, -0.04, 0.04)
        elif initializer == "orthogonal":
            self.gates.orthogonal_initialize()
        else:
            raise ValueError("Unknown initializer %s" % initializer)

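# Usage sketch (illustrative, not part of the original module): a single step
# of the layer-normalized LSTM variant. Sizes are assumptions.
def _example_lstm_step(x):
    # x: [batch, 256]
    cell = LSTMCell(input_size=256, output_size=512, normalization=True)
    state = cell.init_state(x.shape[0], x.dtype, x.device)
    output, state = cell(x, state)
    return output  # [batch, 512]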