Example #1
    def __init__(self, idim, args):
        super(Encoder, self).__init__()
        if args.transformer_input_layer == "linear":
            self.input_layer = torch.nn.Sequential(
                torch.nn.Linear(idim, args.adim),
                torch.nn.LayerNorm(args.adim),
                torch.nn.Dropout(args.dropout_rate), torch.nn.ReLU(),
                PositionalEncoding(args.adim, args.dropout_rate))
        elif args.transformer_input_layer == "conv2d":
            self.input_layer = Conv2dSubsampling(idim, args.adim,
                                                 args.dropout_rate)
        elif args.transformer_input_layer == "embed":
            self.input_layer = torch.nn.Sequential(
                torch.nn.Embedding(idim, args.adim),
                PositionalEncoding(args.adim, args.dropout_rate))
        else:
            raise ValueError("unknown input_layer: " +
                             args.transformer_input_layer)

        self.encoders = repeat(
            args.elayers,
            lambda: EncoderLayer(
                args.adim,
                MultiHeadedAttention(args.aheads, args.adim,
                                     args.transformer_attn_dropout_rate),
                PositionwiseFeedForward(args.adim, args.eunits,
                                        args.dropout_rate),
                args.dropout_rate))
        self.norm = LayerNorm(args.adim)
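A minimal, hypothetical way to construct the Encoder above, assuming Encoder and its submodules are importable as in the snippet; the Namespace fields are exactly the ones the constructor reads, while the values are placeholders rather than settings from any recipe:

# Hypothetical usage of the Encoder defined above (placeholder values).
import argparse

args = argparse.Namespace(
    adim=256, aheads=4, elayers=6, eunits=2048,
    dropout_rate=0.1, transformer_attn_dropout_rate=0.0,
    transformer_input_layer="conv2d",
)
encoder = Encoder(idim=83, args=args)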
Example #2
def test_compatibility():
    """Regression test for #1121"""
    x = torch.rand(2, 3, 4)

    legacy_net = torch.nn.Sequential(
        LegacyPositionalEncoding(4, 0.0), torch.nn.Linear(4, 2)
    )

    latest_net = torch.nn.Sequential(PositionalEncoding(4, 0.0), torch.nn.Linear(4, 2))

    latest_net.load_state_dict(legacy_net.state_dict())
    legacy = legacy_net(x)
    latest = latest_net(x)
    assert torch.allclose(legacy, latest)

    legacy_net = torch.nn.Sequential(
        LegacyScaledPositionalEncoding(4, 0.0), torch.nn.Linear(4, 2)
    )

    latest_net = torch.nn.Sequential(
        ScaledPositionalEncoding(4, 0.0), torch.nn.Linear(4, 2)
    )

    latest_net.load_state_dict(legacy_net.state_dict())
    legacy = legacy_net(x)
    latest = latest_net(x)
    assert torch.allclose(legacy, latest)
Example #3
    def __init__(self, n_vocab, args):
        """Initialize class.

        Args:
            n_vocab (int): The size of the vocabulary
            args (argparse.Namespace): configurations. See :py:meth:`add_arguments`.

        """
        nn.Module.__init__(self)
        self.model_type = 'Transformer'
        self.src_mask = None
        self.encoder = Encoder(n_vocab,
                               args.att_unit,
                               args.head,
                               args.unit,
                               args.layer,
                               args.dropout_rate,
                               args.dropout_rate,
                               args.dropout_rate,
                               input_layer="embed")
        # replace the positional encoding so its maximum length follows args.posenc_len
        self.encoder.embed[1] = PositionalEncoding(args.att_unit,
                                                   args.dropout_rate,
                                                   args.posenc_len)
        self.decoder = nn.Linear(args.att_unit, n_vocab)
Example #4
    def __init__(self, idim, odim, dropout_rate):
        super(Conv2dSubsampling, self).__init__()
        self.conv = torch.nn.Sequential(torch.nn.Conv2d(1, odim, 3, 2),
                                        torch.nn.ReLU(),
                                        torch.nn.Conv2d(odim, odim, 3, 2),
                                        torch.nn.ReLU())
        self.out = torch.nn.Sequential(
            torch.nn.Linear(odim * ((idim - 3) // 4), odim),
            PositionalEncoding(odim, dropout_rate))
Example #5
    def __init__(self, idim, odim, dropout_rate):
        """Construct a Conv2dSubsampling object."""
        super(Conv2dSubsampling, self).__init__()
        self.conv = torch.nn.Sequential(torch.nn.Conv2d(1, odim, 3, 2),
                                        torch.nn.ReLU(),
                                        torch.nn.Conv2d(odim, odim, 3, 2),
                                        torch.nn.ReLU())
        self.out = torch.nn.Sequential(
            torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim),
            PositionalEncoding(odim, dropout_rate))
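The feature-dimension term passed to the Linear layer follows from the two kernel-3, stride-2 convolutions with no padding: each maps a length L to (L - 3) // 2 + 1 = (L - 1) // 2, so applying both gives ((idim - 1) // 2 - 1) // 2 (arithmetically the same value as the older (idim - 3) // 4 form in Example #4). A small standalone check, independent of the classes above:

import torch

# Two kernel-3, stride-2 convolutions shrink the frequency axis from idim
# to ((idim - 1) // 2 - 1) // 2, which is the Linear layer's expected size.
idim, odim = 83, 16
x = torch.rand(1, 1, 100, idim)  # (batch, channel, time, freq)
conv = torch.nn.Sequential(
    torch.nn.Conv2d(1, odim, 3, 2), torch.nn.ReLU(),
    torch.nn.Conv2d(odim, odim, 3, 2), torch.nn.ReLU(),
)
assert conv(x).size(3) == ((idim - 1) // 2 - 1) // 2  # 20 when idim=83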
Example #6
    def __init__(self, idim, odim, dropout_rate, pos_enc=None):
        """Construct a Conv2dSubsampling6 object."""
        super(Conv2dSubsampling6, self).__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(1, odim, 3, 2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(odim, odim, 5, 3),
            torch.nn.ReLU(),
        )
        self.out = torch.nn.Sequential(
            torch.nn.Linear(odim * (((idim - 1) // 2 - 2) // 3), odim),
            pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate),
        )
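Conv2dSubsampling6 replaces the second convolution with a kernel-5, stride-3 one, for an overall reduction of roughly 2 x 3 = 6; that step maps a length L to (L - 5) // 3 + 1 = (L - 2) // 3, which is where ((idim - 1) // 2 - 2) // 3 comes from. The same kind of standalone check:

import torch

# The kernel-5, stride-3 convolution maps a length L to (L - 2) // 3,
# so the frequency axis ends at ((idim - 1) // 2 - 2) // 3.
idim, odim = 83, 16
x = torch.rand(1, 1, 100, idim)  # (batch, channel, time, freq)
conv = torch.nn.Sequential(
    torch.nn.Conv2d(1, odim, 3, 2), torch.nn.ReLU(),
    torch.nn.Conv2d(odim, odim, 5, 3), torch.nn.ReLU(),
)
assert conv(x).size(3) == ((idim - 1) // 2 - 2) // 3  # 13 when idim=83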
Example #7
    def __init__(self, odim, dim):
        super(DecoderConv1dPosition, self).__init__()

        self.embed = torch.nn.Embedding(odim, 512)
        self.position = PositionalEncoding(odim, 1)
        self.conv1 = torch.nn.Conv1d(512, 512, 3, stride=1)
        self.norm1 = torch.nn.LayerNorm(512)
        self.conv2 = torch.nn.Conv1d(512, 512, 3, stride=1)
        self.norm2 = torch.nn.LayerNorm(512)
        self.conv3 = torch.nn.Conv1d(512, 512, 3, stride=1)
        self.norm3 = torch.nn.LayerNorm(512)

        self.out = torch.nn.Linear(512, dim)
Example #8
    def __init__(self,
                 n_head,
                 d_model,
                 d_head,
                 pos_ff,
                 att_type,
                 dropout,
                 dropatt,
                 pre_lnorm,
                 tgt_len=None,
                 ext_len=0,
                 mem_len=0,
                 future_len=0,
                 rel_pos=True):
        super(EncoderLayer, self).__init__()
        self.register_buffer('mems', None)
        self.n_head = n_head
        self.d_head = d_head
        self.d_model = d_model
        self.mem_len = mem_len
        self.rel_pos = rel_pos
        self.future_len = future_len
        self.tgt_len = tgt_len
        # select the attention module according to att_type
        if att_type == "mta":
            self.att = MultiHeadedAttention(n_head, d_model, dropatt)
        elif att_type == "win":
            self.att = WinMultiHeadedAttention(n_head, d_model, dropatt)
        elif att_type == "smooth":
            self.att = SmoothMultiHeadedAttention(n_head, d_model, dropatt)
        elif att_type == "rel":
            self.att = RelMultiHeadedAttention(n_head, d_model, dropatt)
        else:
            raise ValueError("unknown attention type: " + att_type)

        self.layer = CashEncoderLayer(d_model,
                                      self.att,
                                      pos_ff,
                                      dropout,
                                      pre_lnorm,
                                      concat_after=False)

        self.drop = nn.Dropout(dropout)
        self.ext_len = ext_len
        if rel_pos:
            self.re_pos_embed = PositionalEncoding(self.d_model, dropout)
        else:
            self.re_pos_embed = None
Example #9
    def __init__(self, odim, args):
        super(Decoder, self).__init__()
        self.embed = torch.nn.Sequential(
            torch.nn.Embedding(odim, args.adim),
            PositionalEncoding(args.adim, args.dropout_rate)
        )
        self.decoders = repeat(
            args.dlayers,
            lambda: DecoderLayer(
                args.adim,
                MultiHeadedAttention(args.aheads, args.adim, args.transformer_attn_dropout_rate),
                MultiHeadedAttention(args.aheads, args.adim, args.transformer_attn_dropout_rate),
                PositionwiseFeedForward(args.adim, args.dunits, args.dropout_rate),
                args.dropout_rate
            )
        )
        self.output_norm = LayerNorm(args.adim)
        self.output_layer = torch.nn.Linear(args.adim, odim)
Example #10
def test_pe_extendable(dtype, device):
    if device == "cuda" and not torch.cuda.is_available():
        pytest.skip("no cuda device is available")
    dtype = getattr(torch, dtype)
    dim = 2
    pe = PositionalEncoding(dim, 0.0, 3).to(dtype=dtype, device=device)
    x = torch.rand(2, 3, dim, dtype=dtype, device=device)
    y = pe(x)
    init_cache = pe.pe

    # a second input within the initial max_len reuses the cache built at init
    x = torch.rand(2, 3, dim, dtype=dtype, device=device)
    y = pe(x)
    assert pe.pe is init_cache

    # an input longer than the initial max_len (3) forces the cached table to be extended
    x = torch.rand(2, 5, dim, dtype=dtype, device=device)
    y = pe(x)

    sd = pe.state_dict()
    assert len(sd) == 0, "PositionalEncoding should save nothing"
    pe2 = PositionalEncoding(dim, 0.0, 3).to(dtype=dtype, device=device)
    pe2.load_state_dict(sd)
    y2 = pe2(x)
    assert torch.allclose(y, y2)
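The empty state_dict assertion holds because the sinusoidal table is deterministic: it can always be rebuilt from the sequence length and dimension, so there is nothing to persist. A minimal sketch of such a table (a generic sinusoidal encoding, not a copy of the ESPnet implementation):

import math
import torch

# Generic sinusoidal table: even columns use sin, odd columns use cos.
# Because it depends only on (max_len, dim), it can be regenerated on load
# instead of being stored in the state_dict.
def sinusoid_table(max_len, dim):
    pos = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)
    div = torch.exp(torch.arange(0, dim, 2, dtype=torch.float32)
                    * -(math.log(10000.0) / dim))
    pe = torch.zeros(max_len, dim)
    pe[:, 0::2] = torch.sin(pos * div)
    pe[:, 1::2] = torch.cos(pos * div)
    return pe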