def test_encoder_cache(normalize_before):
    """Check that cached incremental encoding matches a full recomputation.

    Two levels are checked: a single encoder layer fed random continuous
    features, and the whole encoder fed random token ids through
    ``forward_one_step``. In both cases the output computed with a cache built
    from the prefix (all positions but the last) must match the output
    computed without a cache.
    """
    adim = 4
    idim = 5
    encoder = Encoder(
        idim=idim,
        attention_dim=adim,
        linear_units=3,
        num_blocks=2,
        normalize_before=normalize_before,
        dropout_rate=0.0,
        input_layer="embed",
    )
    first_layer = encoder.encoders[0]
    feats = torch.randn(2, 5, adim)
    full_mask = subsequent_mask(feats.shape[1]).unsqueeze(0)
    # mask for the prefix that drops the final time step
    prefix_mask = full_mask[:, :-1, :-1]
    encoder.eval()
    with torch.no_grad():
        # --- single layer ---
        expected = first_layer(feats, full_mask, None)[0]
        layer_cache = first_layer(feats[:, :-1], prefix_mask, None)[0]
        actual = first_layer(feats, full_mask, cache=layer_cache)[0]
        numpy.testing.assert_allclose(expected.numpy(), actual.numpy(), rtol=1e-5)

        # --- whole encoder (token-id input through the embed layer) ---
        tokens = torch.randint(0, idim, feats.shape[:2])
        expected = encoder.forward_one_step(tokens, full_mask)[0]
        _, _, enc_cache = encoder.forward_one_step(tokens[:, :-1], prefix_mask)
        actual, _, _ = encoder.forward_one_step(tokens, full_mask, cache=enc_cache)
        numpy.testing.assert_allclose(expected.numpy(), actual.numpy(), rtol=1e-5)
memory = torch.randn(2, 500, adim) mask = subsequent_mask(xlen).unsqueeze(0) result = {"cached": [], "baseline": []} n_avg = 10 for key, value in result.items(): cache = None print(key) for i in range(xlen): x = xs[:, :i + 1] m = mask[:, :i + 1, :i + 1] start = time() for _ in range(n_avg): with torch.no_grad(): if key == "baseline": cache = None if model == "decoder": y, new_cache = decoder.forward_one_step(x, m, memory, cache=cache) else: y, _, new_cache = encoder.forward_one_step(x, m, cache=cache) if key == "cached": cache = new_cache dur = (time() - start) / n_avg value.append(dur) plt.plot(range(xlen), value, label=key) plt.xlabel("hypothesis length") plt.ylabel("average time [sec]") plt.grid() plt.legend() plt.savefig(f"benchmark_{model}.png")
class TransformerLM(nn.Module, LMInterface):
    """Transformer language model.

    Token ids are embedded, passed through a causally masked ``Encoder``
    stack, and projected to vocabulary logits by a linear layer. Token id 0
    is treated as padding by the masks built in this class.
    """

    @staticmethod
    def add_arguments(parser):
        """Add arguments to command line argument parser.

        Args:
            parser: argparse-style parser to extend (uses ``add_argument``).

        Returns:
            The same parser object, for chaining.

        """
        parser.add_argument('--layer', type=int, default=4,
                            help='Number of hidden layers')
        parser.add_argument('--unit', type=int, default=1024,
                            help='Number of hidden units in feedforward layer')
        parser.add_argument('--att-unit', type=int, default=256,
                            help='Number of hidden units in attention layer')
        parser.add_argument('--embed-unit', type=int, default=128,
                            help='Number of hidden units in embedding layer')
        parser.add_argument('--head', type=int, default=2,
                            help='Number of multi head attention')
        parser.add_argument('--dropout-rate', type=float, default=0.5,
                            help='dropout probability')
        parser.add_argument('--pos-enc', default="sinusoidal",
                            choices=["sinusoidal", "none"],
                            help='positional encoding')
        return parser

    def __init__(self, n_vocab, args):
        """Initialize class.

        Args:
            n_vocab (int): The size of the vocabulary
            args (argparse.Namespace): configurations.
                see py:method:`add_arguments`

        Raises:
            ValueError: If ``args.pos_enc`` is neither "sinusoidal" nor "none".

        """
        nn.Module.__init__(self)
        if args.pos_enc == "sinusoidal":
            pos_enc_class = PositionalEncoding
        elif args.pos_enc == "none":
            # Stand-in for PositionalEncoding: accepts (and ignores) any
            # constructor arguments and returns an empty nn.Sequential, which
            # passes its input through unchanged. The underscored parameter
            # names avoid shadowing the outer ``args``.
            def pos_enc_class(*_args, **_kwargs):
                return nn.Sequential()  # identity
        else:
            raise ValueError(f"unknown pos-enc option: {args.pos_enc}")

        self.embed = nn.Embedding(n_vocab, args.embed_unit)
        self.encoder = Encoder(idim=args.embed_unit, attention_dim=args.att_unit,
                               attention_heads=args.head, linear_units=args.unit,
                               num_blocks=args.layer, dropout_rate=args.dropout_rate,
                               input_layer="linear", pos_enc_class=pos_enc_class)
        self.decoder = nn.Linear(args.att_unit, n_vocab)

    def _target_mask(self, ys_in_pad):
        """Combine the non-padding mask of ``ys_in_pad`` with a causal mask.

        For ``ys_in_pad`` of shape (batch, len) the result broadcasts to
        (batch, len, len), assuming ``subsequent_mask`` returns a (len, len)
        mask. Positions holding id 0 are masked out.
        """
        ys_mask = ys_in_pad != 0
        m = subsequent_mask(ys_mask.size(-1), device=ys_mask.device).unsqueeze(0)
        return ys_mask.unsqueeze(-2) & m

    def forward(
        self, x: torch.Tensor, t: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute LM loss value from buffer sequences.

        Args:
            x (torch.Tensor): Input ids. (batch, len)
            t (torch.Tensor): Target ids. (batch, len)

        Returns:
            tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Tuple of
                loss to backward (scalar),
                negative log-likelihood of t: -log p(t) (scalar) and
                the number of non-padding (non-zero) elements in x (scalar)

        Notes:
            The last two return values are used in perplexity:
            p(t)^{-1/n} = exp(-log p(t) / n)

        """
        xm = x != 0
        h, _ = self.encoder(self.embed(x), self._target_mask(x))
        y = self.decoder(h)
        # reshape (rather than view) so non-contiguous targets are accepted
        loss = F.cross_entropy(y.reshape(-1, y.shape[-1]), t.reshape(-1),
                               reduction="none")
        # positions whose input id is 0 (padding) contribute nothing
        mask = xm.to(dtype=loss.dtype)
        logp = loss * mask.view(-1)
        logp = logp.sum()
        count = mask.sum()
        return logp / count, logp, count

    def score(self, y: torch.Tensor, state: Any, x: torch.Tensor) -> Tuple[torch.Tensor, Any]:
        """Score new token.

        Args:
            y (torch.Tensor): 1D torch.int64 prefix tokens.
            state: Scorer state for prefix tokens
            x (torch.Tensor): encoder feature that generates ys (unused here).

        Returns:
            tuple[torch.Tensor, Any]: Tuple of
                torch.float32 scores for next token (n_vocab)
                and next state for ys

        """
        y = y.unsqueeze(0)
        h, _, cache = self.encoder.forward_one_step(
            self.embed(y), self._target_mask(y), cache=state)
        # only the last position predicts the next token
        h = self.decoder(h[:, -1])
        logp = h.log_softmax(dim=-1).squeeze(0)
        return logp, cache
class TransformerLM(AbsLM):
    """Transformer language model.

    Token ids are embedded, passed through a causally masked ``Encoder``
    stack, and projected to vocabulary logits by a linear layer. Token id 0
    is treated as padding by the masks built in this class.
    """

    def __init__(
        self,
        vocab_size: int,
        pos_enc: str = None,
        embed_unit: int = 128,
        att_unit: int = 256,
        head: int = 2,
        unit: int = 1024,
        layer: int = 4,
        dropout_rate: float = 0.5,
    ):
        """Initialize the model.

        Args:
            vocab_size: Size of the token vocabulary.
            pos_enc: "sinusoidal" to use PositionalEncoding; None or "none"
                to use no positional encoding.
            embed_unit: Embedding dimension (Encoder input dimension).
            att_unit: Attention dimension of the Encoder.
            head: Number of attention heads.
            unit: Number of hidden units in the feedforward layers.
            layer: Number of Encoder blocks.
            dropout_rate: Dropout probability.

        Raises:
            ValueError: If ``pos_enc`` is not one of the accepted values.
        """
        super().__init__()
        if pos_enc == "sinusoidal":
            pos_enc_class = PositionalEncoding
        elif pos_enc is None or pos_enc == "none":
            # Stand-in for PositionalEncoding: accepts (and ignores) any
            # constructor arguments and returns an empty nn.Sequential, which
            # passes its input through unchanged.
            def pos_enc_class(*args, **kwargs):
                return nn.Sequential()  # identity

        else:
            raise ValueError(f"unknown pos-enc option: {pos_enc}")

        self.embed = nn.Embedding(vocab_size, embed_unit)
        self.encoder = Encoder(
            idim=embed_unit,
            attention_dim=att_unit,
            attention_heads=head,
            linear_units=unit,
            num_blocks=layer,
            dropout_rate=dropout_rate,
            input_layer="linear",
            pos_enc_class=pos_enc_class,
        )
        self.decoder = nn.Linear(att_unit, vocab_size)

    def _target_mask(self, ys_in_pad):
        """Combine the non-padding mask of ``ys_in_pad`` with a causal mask.

        For ``ys_in_pad`` of shape (batch, len) the result broadcasts to
        (batch, len, len), assuming ``subsequent_mask`` returns a (len, len)
        mask. Positions holding id 0 are masked out.
        """
        ys_mask = ys_in_pad != 0
        m = subsequent_mask(ys_mask.size(-1), device=ys_mask.device).unsqueeze(0)
        return ys_mask.unsqueeze(-2) & m

    def forward(self, input: torch.Tensor, hidden: None) -> Tuple[torch.Tensor, None]:
        """Compute next-token logits for a batch of id sequences.

        Args:
            input (torch.Tensor): Input ids. (batch, len)
            hidden: Unused by this model (no recurrent state).

        Returns:
            tuple[torch.Tensor, None]: Vocabulary logits for every position,
                (batch, len, vocab_size), and None.
        """
        x = self.embed(input)
        mask = self._target_mask(input)
        h, _ = self.encoder(x, mask)
        y = self.decoder(h)
        return y, None

    def score(
        self, y: torch.Tensor, state: Any, x: torch.Tensor
    ) -> Tuple[torch.Tensor, Any]:
        """Score new token.

        Args:
            y (torch.Tensor): 1D torch.int64 prefix tokens.
            state: Scorer state for prefix tokens
            x (torch.Tensor): encoder feature that generates ys (unused here).

        Returns:
            tuple[torch.Tensor, Any]: Tuple of
                torch.float32 scores for next token (vocab_size)
                and next state for ys

        """
        y = y.unsqueeze(0)
        h, _, cache = self.encoder.forward_one_step(
            self.embed(y), self._target_mask(y), cache=state
        )
        # only the last position predicts the next token
        h = self.decoder(h[:, -1])
        logp = h.log_softmax(dim=-1).squeeze(0)
        return logp, cache

    def batch_score(
        self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
    ) -> Tuple[torch.Tensor, List[Any]]:
        """Score new token batch.

        Args:
            ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
            states (List[Any]): Scorer states for prefix tokens.
            xs (torch.Tensor):
                The encoder feature that generates ys (n_batch, xlen, n_feat).
                Unused here.

        Returns:
            tuple[torch.Tensor, List[Any]]: Tuple of
                batchfied scores for next token with shape of `(n_batch, vocab_size)`
                and next state list for ys.

        """
        # merge states
        n_batch = len(ys)
        n_layers = len(self.encoder.encoders)
        if states[0] is None:
            batch_state = None
        else:
            # transpose state of [batch, layer] into [layer, batch]
            batch_state = [
                torch.stack([states[b][i] for b in range(n_batch)])
                for i in range(n_layers)
            ]

        # batch decoding
        h, _, new_states = self.encoder.forward_one_step(
            self.embed(ys), self._target_mask(ys), cache=batch_state
        )
        h = self.decoder(h[:, -1])
        logp = h.log_softmax(dim=-1)

        # transpose state of [layer, batch] into [batch, layer]
        state_list = [
            [new_states[i][b] for i in range(n_layers)] for b in range(n_batch)
        ]
        return logp, state_list
class TransformerLM(nn.Module, LMInterface, BatchScorerInterface):
    """Transformer language model.

    Token ids are embedded, passed through a causally masked ``Encoder``
    stack, and projected to vocabulary logits by a linear layer. Token id 0
    is treated as padding by the masks built in this class.
    """

    @staticmethod
    def add_arguments(parser):
        """Add arguments to command line argument parser.

        Args:
            parser: argparse-style parser to extend (uses ``add_argument``).

        Returns:
            The same parser object, for chaining.

        """
        parser.add_argument(
            "--layer", type=int, default=4, help="Number of hidden layers"
        )
        parser.add_argument(
            "--unit",
            type=int,
            default=1024,
            help="Number of hidden units in feedforward layer",
        )
        parser.add_argument(
            "--att-unit",
            type=int,
            default=256,
            help="Number of hidden units in attention layer",
        )
        parser.add_argument(
            "--embed-unit",
            type=int,
            default=128,
            help="Number of hidden units in embedding layer",
        )
        parser.add_argument(
            "--head", type=int, default=2, help="Number of multi head attention"
        )
        parser.add_argument(
            "--dropout-rate", type=float, default=0.5, help="dropout probability"
        )
        parser.add_argument(
            "--pos-enc",
            default="sinusoidal",
            choices=["sinusoidal", "none"],
            help="positional encoding",
        )
        return parser

    def __init__(self, n_vocab, args):
        """Initialize class.

        Args:
            n_vocab (int): The size of the vocabulary
            args (argparse.Namespace): configurations.
                see py:method:`add_arguments`

        Raises:
            ValueError: If ``args.pos_enc`` is neither "sinusoidal" nor "none".

        """
        nn.Module.__init__(self)
        if args.pos_enc == "sinusoidal":
            pos_enc_class = PositionalEncoding
        elif args.pos_enc == "none":
            # Stand-in for PositionalEncoding: accepts (and ignores) any
            # constructor arguments and returns an empty nn.Sequential, which
            # passes its input through unchanged. The underscored parameter
            # names avoid shadowing the outer ``args``.
            def pos_enc_class(*_args, **_kwargs):
                return nn.Sequential()  # identity

        else:
            raise ValueError(f"unknown pos-enc option: {args.pos_enc}")

        self.embed = nn.Embedding(n_vocab, args.embed_unit)
        self.encoder = Encoder(
            idim=args.embed_unit,
            attention_dim=args.att_unit,
            attention_heads=args.head,
            linear_units=args.unit,
            num_blocks=args.layer,
            dropout_rate=args.dropout_rate,
            input_layer="linear",
            pos_enc_class=pos_enc_class,
        )
        self.decoder = nn.Linear(args.att_unit, n_vocab)

    def _target_mask(self, ys_in_pad):
        """Combine the non-padding mask of ``ys_in_pad`` with a causal mask.

        For ``ys_in_pad`` of shape (batch, len) the result broadcasts to
        (batch, len, len), assuming ``subsequent_mask`` returns a (len, len)
        mask. Positions holding id 0 are masked out.
        """
        ys_mask = ys_in_pad != 0
        m = subsequent_mask(ys_mask.size(-1), device=ys_mask.device).unsqueeze(0)
        return ys_mask.unsqueeze(-2) & m

    def forward(
        self, x: torch.Tensor, t: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute LM loss value from buffer sequences.

        Args:
            x (torch.Tensor): Input ids. (batch, len)
            t (torch.Tensor): Target ids. (batch, len)

        Returns:
            tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Tuple of
                loss to backward (scalar),
                negative log-likelihood of t: -log p(t) (scalar) and
                the number of non-padding (non-zero) elements in x (scalar)

        Notes:
            The last two return values are used
            in perplexity: p(t)^{-1/n} = exp(-log p(t) / n)

        """
        xm = x != 0
        h, _ = self.encoder(self.embed(x), self._target_mask(x))
        y = self.decoder(h)
        # reshape (rather than view) so non-contiguous targets are accepted
        loss = F.cross_entropy(
            y.reshape(-1, y.shape[-1]), t.reshape(-1), reduction="none"
        )
        # positions whose input id is 0 (padding) contribute nothing
        mask = xm.to(dtype=loss.dtype)
        logp = loss * mask.view(-1)
        logp = logp.sum()
        count = mask.sum()
        return logp / count, logp, count

    def score(
        self, y: torch.Tensor, state: Any, x: torch.Tensor
    ) -> Tuple[torch.Tensor, Any]:
        """Score new token.

        Args:
            y (torch.Tensor): 1D torch.int64 prefix tokens.
            state: Scorer state for prefix tokens
            x (torch.Tensor): encoder feature that generates ys (unused here).

        Returns:
            tuple[torch.Tensor, Any]: Tuple of
                torch.float32 scores for next token (n_vocab)
                and next state for ys

        """
        y = y.unsqueeze(0)
        h, _, cache = self.encoder.forward_one_step(
            self.embed(y), self._target_mask(y), cache=state
        )
        # only the last position predicts the next token
        h = self.decoder(h[:, -1])
        logp = h.log_softmax(dim=-1).squeeze(0)
        return logp, cache

    # batch beam search API (see BatchScorerInterface)
    def batch_score(
        self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
    ) -> Tuple[torch.Tensor, List[Any]]:
        """Score new token batch (required).

        Args:
            ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
            states (List[Any]): Scorer states for prefix tokens.
            xs (torch.Tensor):
                The encoder feature that generates ys (n_batch, xlen, n_feat).
                Unused here.

        Returns:
            tuple[torch.Tensor, List[Any]]: Tuple of
                batchfied scores for next token with shape of `(n_batch, n_vocab)`
                and next state list for ys.

        """
        # merge states
        n_batch = len(ys)
        n_layers = len(self.encoder.encoders)
        if states[0] is None:
            batch_state = None
        else:
            # transpose state of [batch, layer] into [layer, batch]
            batch_state = [
                torch.stack([states[b][i] for b in range(n_batch)])
                for i in range(n_layers)
            ]

        # batch decoding
        h, _, new_states = self.encoder.forward_one_step(
            self.embed(ys), self._target_mask(ys), cache=batch_state
        )
        h = self.decoder(h[:, -1])
        logp = h.log_softmax(dim=-1)

        # transpose state of [layer, batch] into [batch, layer]
        state_list = [
            [new_states[i][b] for i in range(n_layers)] for b in range(n_batch)
        ]
        return logp, state_list