def __init__(self, idim, args):
    super(Encoder, self).__init__()
    if args.transformer_input_layer == "linear":
        self.input_layer = torch.nn.Sequential(
            torch.nn.Linear(idim, args.adim),
            torch.nn.LayerNorm(args.adim),
            torch.nn.Dropout(args.dropout_rate),
            torch.nn.ReLU(),
            PositionalEncoding(args.adim, args.dropout_rate))
    elif args.transformer_input_layer == "conv2d":
        self.input_layer = Conv2dSubsampling(idim, args.adim, args.dropout_rate)
    elif args.transformer_input_layer == "embed":
        self.input_layer = torch.nn.Sequential(
            torch.nn.Embedding(idim, args.adim),
            PositionalEncoding(args.adim, args.dropout_rate))
    else:
        raise ValueError("unknown input_layer: " + args.transformer_input_layer)
    self.encoders = repeat(
        args.elayers,
        lambda: EncoderLayer(
            args.adim,
            MultiHeadedAttention(args.aheads, args.adim,
                                 args.transformer_attn_dropout_rate),
            PositionwiseFeedForward(args.adim, args.eunits, args.dropout_rate),
            args.dropout_rate))
    self.norm = LayerNorm(args.adim)
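A hedged usage sketch (not from the source): constructing the Encoder above with a hypothetical argparse.Namespace. The attribute names are the ones this constructor actually reads; the concrete values and the input dimension are illustrative assumptions only.

import argparse

# Hypothetical configuration; every value below is an assumption for illustration.
args = argparse.Namespace(
    transformer_input_layer="conv2d",   # one of "linear", "conv2d", "embed"
    adim=256,                           # attention / model dimension
    aheads=4,                           # number of attention heads
    eunits=2048,                        # position-wise feed-forward inner size
    elayers=6,                          # number of encoder layers
    dropout_rate=0.1,
    transformer_attn_dropout_rate=0.0,
)
encoder = Encoder(83, args)             # 83-dim input features, also an assumption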
def test_compatibility():
    """Regression test for #1121."""
    x = torch.rand(2, 3, 4)

    # The current PositionalEncoding must load checkpoints written by the
    # legacy implementation and produce identical outputs.
    legacy_net = torch.nn.Sequential(
        LegacyPositionalEncoding(4, 0.0),
        torch.nn.Linear(4, 2))
    latest_net = torch.nn.Sequential(
        PositionalEncoding(4, 0.0),
        torch.nn.Linear(4, 2))
    latest_net.load_state_dict(legacy_net.state_dict())
    legacy = legacy_net(x)
    latest = latest_net(x)
    assert torch.allclose(legacy, latest)

    # Same check for the scaled variant.
    legacy_net = torch.nn.Sequential(
        LegacyScaledPositionalEncoding(4, 0.0),
        torch.nn.Linear(4, 2))
    latest_net = torch.nn.Sequential(
        ScaledPositionalEncoding(4, 0.0),
        torch.nn.Linear(4, 2))
    latest_net.load_state_dict(legacy_net.state_dict())
    legacy = legacy_net(x)
    latest = latest_net(x)
    assert torch.allclose(legacy, latest)
def __init__(self, n_vocab, args):
    """Initialize class.

    Args:
        n_vocab (int): The size of the vocabulary.
        args (argparse.Namespace): configurations. see py:method:`add_arguments`

    """
    nn.Module.__init__(self)
    self.model_type = 'Transformer'
    self.src_mask = None
    self.encoder = Encoder(
        n_vocab, args.att_unit, args.head, args.unit, args.layer,
        args.dropout_rate, args.dropout_rate, args.dropout_rate,
        input_layer="embed")
    # reset posenc
    self.encoder.embed[1] = PositionalEncoding(
        args.att_unit, args.dropout_rate, args.posenc_len)
    self.decoder = nn.Linear(args.att_unit, n_vocab)
def __init__(self, idim, odim, dropout_rate):
    super(Conv2dSubsampling, self).__init__()
    self.conv = torch.nn.Sequential(
        torch.nn.Conv2d(1, odim, 3, 2),
        torch.nn.ReLU(),
        torch.nn.Conv2d(odim, odim, 3, 2),
        torch.nn.ReLU())
    self.out = torch.nn.Sequential(
        torch.nn.Linear(odim * ((idim - 3) // 4), odim),
        PositionalEncoding(odim, dropout_rate))
def __init__(self, idim, odim, dropout_rate): """Construct an Conv2dSubsampling object.""" super(Conv2dSubsampling, self).__init__() self.conv = torch.nn.Sequential(torch.nn.Conv2d(1, odim, 3, 2), torch.nn.ReLU(), torch.nn.Conv2d(odim, odim, 3, 2), torch.nn.ReLU()) self.out = torch.nn.Sequential( torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim), PositionalEncoding(odim, dropout_rate))
def __init__(self, idim, odim, dropout_rate, pos_enc=None):
    """Construct a Conv2dSubsampling6 object."""
    super(Conv2dSubsampling6, self).__init__()
    self.conv = torch.nn.Sequential(
        torch.nn.Conv2d(1, odim, 3, 2),
        torch.nn.ReLU(),
        torch.nn.Conv2d(odim, odim, 5, 3),
        torch.nn.ReLU(),
    )
    self.out = torch.nn.Sequential(
        torch.nn.Linear(odim * (((idim - 1) // 2 - 2) // 3), odim),
        pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate),
    )
def __init__(self, odim, dim):
    super(DecoderConv1dPosition, self).__init__()
    self.embed = torch.nn.Embedding(odim, 512)
    self.position = PositionalEncoding(odim, 1)
    self.conv1 = torch.nn.Conv1d(512, 512, 3, stride=1)
    self.norm1 = torch.nn.LayerNorm(512)
    self.conv2 = torch.nn.Conv1d(512, 512, 3, stride=1)
    self.norm2 = torch.nn.LayerNorm(512)
    self.conv3 = torch.nn.Conv1d(512, 512, 3, stride=1)
    self.norm3 = torch.nn.LayerNorm(512)
    self.out = torch.nn.Linear(512, dim)
def __init__(self, n_head, d_model, d_head, pos_ff, att_type, dropout, dropatt,
             pre_lnorm, tgt_len=None, ext_len=0, mem_len=0, future_len=0,
             rel_pos=True):
    super(EncoderLayer, self).__init__()
    self.register_buffer('mems', None)
    self.n_head = n_head
    self.d_head = d_head
    self.d_model = d_model
    self.mem_len = mem_len
    self.rel_pos = rel_pos
    self.future_len = future_len
    self.tgt_len = tgt_len
    if att_type == "mta":
        self.att = MultiHeadedAttention(n_head, d_model, dropatt)
    elif att_type == "win":
        self.att = WinMultiHeadedAttention(n_head, d_model, dropatt)
    elif att_type == "smooth":
        self.att = SmoothMultiHeadedAttention(n_head, d_model, dropatt)
    elif att_type == "rel":
        self.att = RelMultiHeadedAttention(n_head, d_model, dropatt)
    else:
        raise ValueError("unknown attention type: " + att_type)
    self.layer = CashEncoderLayer(d_model, self.att, pos_ff, dropout,
                                  pre_lnorm, concat_after=False)
    self.drop = nn.Dropout(dropout)
    self.ext_len = ext_len
    if rel_pos:
        self.re_pos_embed = PositionalEncoding(self.d_model, dropout)
    else:
        self.re_pos_embed = None
def __init__(self, odim, args):
    super(Decoder, self).__init__()
    self.embed = torch.nn.Sequential(
        torch.nn.Embedding(odim, args.adim),
        PositionalEncoding(args.adim, args.dropout_rate)
    )
    self.decoders = repeat(
        args.dlayers,
        lambda: DecoderLayer(
            args.adim,
            MultiHeadedAttention(args.aheads, args.adim, args.transformer_attn_dropout_rate),
            MultiHeadedAttention(args.aheads, args.adim, args.transformer_attn_dropout_rate),
            PositionwiseFeedForward(args.adim, args.dunits, args.dropout_rate),
            args.dropout_rate
        )
    )
    self.output_norm = LayerNorm(args.adim)
    self.output_layer = torch.nn.Linear(args.adim, odim)
def test_pe_extendable(dtype, device):
    if device == "cuda" and not torch.cuda.is_available():
        pytest.skip("no cuda device is available")
    dtype = getattr(torch, dtype)
    dim = 2
    pe = PositionalEncoding(dim, 0.0, 3).to(dtype=dtype, device=device)
    x = torch.rand(2, 3, dim, dtype=dtype, device=device)
    y = pe(x)
    init_cache = pe.pe

    # test not extended from init
    x = torch.rand(2, 3, dim, dtype=dtype, device=device)
    y = pe(x)
    assert pe.pe is init_cache

    x = torch.rand(2, 5, dim, dtype=dtype, device=device)
    y = pe(x)

    sd = pe.state_dict()
    assert len(sd) == 0, "PositionalEncoding should save nothing"
    pe2 = PositionalEncoding(dim, 0.0, 3).to(dtype=dtype, device=device)
    pe2.load_state_dict(sd)
    y2 = pe2(x)
    assert torch.allclose(y, y2)
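The behaviour this test pins down (a sinusoidal cache that is extended lazily, reused when it is already long enough, and never written to the checkpoint) can be shown with a minimal sketch. This is an assumption-laden reimplementation, not the library's PositionalEncoding; the class name, the max_len default, and the sqrt(d_model) input scaling are mine.

import math
import torch

class SketchPositionalEncoding(torch.nn.Module):
    """Minimal sketch of a sinusoidal positional encoding with a lazily
    extended cache that is never stored in the checkpoint (assumption-based,
    not the library's code)."""

    def __init__(self, d_model, dropout_rate, max_len=5000):
        super().__init__()
        self.d_model = d_model
        self.xscale = math.sqrt(d_model)           # common Transformer input scaling
        self.dropout = torch.nn.Dropout(p=dropout_rate)
        self.pe = None                             # plain attribute, so state_dict() stays empty
        self._extend(torch.zeros(1, max_len, d_model))

    def _extend(self, x):
        # Reuse the cached table when it is long enough; otherwise rebuild it.
        if self.pe is not None and self.pe.size(1) >= x.size(1):
            if self.pe.dtype != x.dtype or self.pe.device != x.device:
                self.pe = self.pe.to(dtype=x.dtype, device=x.device)
            return
        position = torch.arange(x.size(1), dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.d_model))
        pe = torch.zeros(x.size(1), self.d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.pe = pe.unsqueeze(0).to(device=x.device, dtype=x.dtype)

    def forward(self, x):
        self._extend(x)
        return self.dropout(x * self.xscale + self.pe[:, :x.size(1)])

Keeping pe as a plain attribute rather than a registered buffer or parameter is what makes state_dict() empty, which is exactly what the assertion in the test above checks.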