Code example #1
    def __init__(self, config, dataset):
        ''' Initialize the Transformer '''
        super(NewTransformer, self).__init__()

        self.dataset = dataset
        self.embedding = TokenEmbedding(
            dataset.vocab_size,
            config.embedding_size,
            padding_idx=self.padding_idx
        )
        self.position_embedding = PositionEmbedding(config.embedding_size)
        self.dropout = nn.Dropout(config.dropout_p, inplace=True)

        # Uniq attn attributes
        self.attn_ofs_uniq = list(set(
            config.enc_attn_offset + config.dec_attn_offset + config.enc_dec_attn_offset))
        self.attn_std_uniq = list(set(
            config.enc_attn_std + config.dec_attn_std + config.enc_dec_attn_std))

        # Allow for overriding the encoders and decoders in derived classes
        self.encoders = self.create_encoders(config)
        self.decoders = self.create_decoders(config)

        self.label_smoothing = LabelSmoothingLoss(
            config.label_smoothing or 0,
            ignore_index=self.padding_idx,
            reduction='none'
        )
        self.cross_entropy = nn.CrossEntropyLoss(
            ignore_index=self.padding_idx,
            reduction='none'
        )
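The attn_ofs_uniq and attn_std_uniq attributes above simply deduplicate the per-layer offset and standard-deviation lists shared by the encoder, decoder, and encoder-decoder attention. A minimal sketch with made-up per-layer values (the config field names come from the snippet, the numbers do not):

    # One entry per layer (hypothetical values)
    enc_attn_offset = [0, 0, 1, 1]
    dec_attn_offset = [0, -1, -1, -1]
    enc_dec_attn_offset = [0, 0, 0, 0]

    attn_ofs_uniq = list(set(enc_attn_offset + dec_attn_offset + enc_dec_attn_offset))
    print(sorted(attn_ofs_uniq))  # [-1, 0, 1]; set() itself guarantees no particular order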
Code example #2
    def __init__(self, config, dataset):
        ''' Initialize the Transformer '''
        super(ProbeNewTransformer, self).__init__()

        self.dataset = dataset
        self.span = config.span
        self.embedding = TokenEmbedding(
            dataset.vocab_size,
            config.embedding_size,
            padding_idx=self.padding_idx
        )
        self.position_embedding = PositionEmbedding(config.embedding_size)
        self.dropout = nn.Dropout(config.dropout_p, inplace=True)

        # Allow for overriding the encoders and decoders in derived classes
        self.encoders = type(self).create_encoders(config)
        self.decoders = self.create_decoders(config)

        self.label_smoothing = LabelSmoothingLoss(
            config.label_smoothing or 0,
            ignore_index=self.padding_idx,
            reduction='none'
        )
        self.cross_entropy = nn.CrossEntropyLoss(
            ignore_index=self.padding_idx,
            reduction='none'
        )
Code example #3
    def __init__(self, config, dataset):
        ''' Initialize the NPLM '''
        super(NPLM, self).__init__()

        self.dataset = dataset
        
        self.adaptive = config.adaptive
        # ngm: number of tokens concatenated with their full embeddings
        # wsz: window size used to average the long-term context
        self.ngm, self.wsz = config.context_config                  
        self.long_term_block = 0 if self.ngm > 0 and self.wsz == -1 else \
                                    (config.batch_length - self.ngm) // self.wsz

        self.dim_concat_embs = self.ngm * config.embedding_size + self.long_term_block * config.embedding_size

        self.embedding = TokenEmbedding(
                dataset.vocab_size,
                config.embedding_size,
                config.model_size, 
                config.cutoffs,
                emb_std=config.emb_std,
                proj_std=config.proj_std,
                div_val=config.div_val,
                padding_idx=self.padding_idx,
                do_proj=config.do_proj
            )

        if self.adaptive:
            self.adaptive_softmax = AdaptiveSoftmax(self.dataset.vocab_size, config.embedding_size, config.embedding_size, 
                                                    config.cutoffs, div_val=config.div_val)

            self.tie_weights = config.tie_weights
            self.tie_projs = config.tie_projs

            if self.tie_weights:
                for i in range(len(self.adaptive_softmax.out_layers)):
                    self.adaptive_softmax.out_layers[i].weight = self.embedding.emb_layers[i].weight

            if self.tie_projs:
                for i in range(1, len(self.adaptive_softmax.out_projs)):
                    if config.div_val == 1 and config.model_size != config.embedding_size:
                        self.adaptive_softmax.out_projs[i] = self.embedding.emb_projs[0]
                    elif config.div_val != 1:
                        self.adaptive_softmax.out_projs[i] = self.embedding.emb_projs[i]

        self.layers = self.create_layers(config)
        self.position_embedding = PositionEmbedding(config.model_size) # only used in transformer-N
        self.label_smoothing = LabelSmoothingLoss(
            config.label_smoothing or 0,
            ignore_index=self.padding_idx,
            reduction='none'
        )
        self.cross_entropy = nn.CrossEntropyLoss(
            ignore_index=self.padding_idx,
            reduction='none'
        )

        self.dropout = nn.Dropout(config.dropout_p, inplace=True)

        self.config = config
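The context-window arithmetic above determines how wide the concatenated input to the first NPLM layer is: ngm tokens keep their full embeddings and, unless wsz is -1, the remaining batch_length - ngm positions are averaged in blocks of wsz. A quick sanity check with made-up numbers (batch_length, embedding_size, and context_config here are assumptions, not values from the project):

    # Hypothetical settings
    batch_length, embedding_size = 512, 256
    ngm, wsz = 4, 32  # context_config: 4 full-embedding tokens, 32-token averaging windows

    # Mirrors the expressions in __init__; wsz == -1 disables the long-term blocks
    long_term_block = 0 if ngm > 0 and wsz == -1 else (batch_length - ngm) // wsz
    dim_concat_embs = ngm * embedding_size + long_term_block * embedding_size

    print(long_term_block)   # 15
    print(dim_concat_embs)   # (4 + 15) * 256 = 4864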
Code example #4
File: transformer.py  Project: SimengSun/revisit-nplm
    def __init__(self, config, dataset):
        ''' Initialize the Transformer '''
        super(Transformer, self).__init__()

        self.dataset = dataset
        self.config = config

        self.adaptive = config.adaptive

        self.embedding = TokenEmbedding(dataset.vocab_size,
                                        config.embedding_size,
                                        config.model_size,
                                        config.cutoffs,
                                        emb_std=config.emb_std,
                                        proj_std=config.proj_std,
                                        div_val=config.div_val,
                                        padding_idx=self.padding_idx,
                                        do_proj=config.do_proj)

        if self.adaptive:
            self.adaptive_softmax = AdaptiveSoftmax(self.dataset.vocab_size,
                                                    config.embedding_size,
                                                    config.model_size,
                                                    config.cutoffs,
                                                    div_val=config.div_val)

            self.tie_weights = config.tie_weights
            self.tie_projs = config.tie_projs

            if self.tie_weights:
                for i in range(len(self.adaptive_softmax.out_layers)):
                    self.adaptive_softmax.out_layers[
                        i].weight = self.embedding.emb_layers[i].weight

            if self.tie_projs:
                for i in range(1, len(self.adaptive_softmax.out_projs)):
                    if config.div_val == 1 and config.model_size != config.embedding_size:
                        self.adaptive_softmax.out_projs[
                            i] = self.embedding.emb_projs[0]
                    elif config.div_val != 1:
                        self.adaptive_softmax.out_projs[
                            i] = self.embedding.emb_projs[i]

        self.position_embedding = PositionEmbedding(config.embedding_size)
        self.dropout = nn.Dropout(config.dropout_p, inplace=True)

        if len(config.no_attention) == 1:
            config.no_attention = config.no_attention * config.num_layers
        assert len(config.no_attention) == config.num_layers

        self.layers = self.create_layers(config)

        self.label_smoothing = LabelSmoothingLoss(
            config.label_smoothing or 0,
            ignore_index=self.padding_idx,
            reduction='none')
        self.cross_entropy = nn.CrossEntropyLoss(ignore_index=self.padding_idx,
                                                 reduction='none')
Code example #5
    def __init__(self, config, dataset):
        ''' Initialize the Transformer '''
        super(InterleaveFixedPosEmbEncoderOnlyTransformer, self).__init__()

        self.dataset = dataset
        self.embedding = TokenEmbedding(dataset.vocab_size,
                                        config.embedding_size,
                                        padding_idx=self.padding_idx)
        self.position_embedding = PositionEmbedding(config.embedding_size)
        self.num_layers = config.num_layers

        encoder_positional_embedding_list = []
        for i in range(self.num_layers // 2):
            position_embedding_encoder = LearnedPositionalEmbedding(
                dataset.max_input_length, config.embedding_size,
                self.padding_idx)
            nn.init.normal_(position_embedding_encoder.weight,
                            mean=0,
                            std=config.embedding_size**-0.5)
            if self.padding_idx is not None:
                nn.init.constant_(
                    position_embedding_encoder.weight[self.padding_idx], 0)
            encoder_positional_embedding_list.append(
                position_embedding_encoder)

        self.encoder_positional_embeddings = nn.ModuleList(
            encoder_positional_embedding_list)

        self.position_embedding_decoder = LearnedPositionalEmbedding(
            dataset.max_target_length, config.embedding_size, self.padding_idx)
        nn.init.normal_(self.position_embedding_decoder.weight,
                        mean=0,
                        std=config.embedding_size**-0.5)
        if self.padding_idx is not None:
            nn.init.constant_(
                self.position_embedding_decoder.weight[self.padding_idx], 0)

        self.dropout = nn.Dropout(config.dropout_p, inplace=True)

        # Uniq attn attributes
        self.attn_ofs_uniq = list(
            set(config.enc_attn_offset + config.dec_attn_offset +
                config.enc_dec_attn_offset))
        self.attn_std_uniq = list(
            set(config.enc_attn_std + config.dec_attn_std +
                config.enc_dec_attn_std))

        # Allow for overriding the encoders and decoders in derived classes
        self.encoders = self.create_encoders(config)
        self.decoders = self.create_decoders(config)

        self.label_smoothing = LabelSmoothingLoss(
            config.label_smoothing or 0,
            ignore_index=self.padding_idx,
            reduction='none')
        self.cross_entropy = nn.CrossEntropyLoss(ignore_index=self.padding_idx,
                                                 reduction='none')
Code example #6
    def __init__(self, config, dataset):
        ''' Initialize the ParseTransformer '''
        super(ParseTransformer, self).__init__(config, dataset)

        self.span = 1
        args = [config.num_heads, config.embedding_size, config.hidden_dim]
        self.annotation_decoders = nn.ModuleList([
            TransformerDecoderLayer(*args, dropout_p=config.dropout_p)
            for _ in range(config.parse_num_layers)
        ])

        self.annotation_embedding = TokenEmbedding(
            dataset.annotation_vocab_size,
            config.embedding_size,
            padding_idx=self.annotation_padding_idx)
        self.annotation_cross_entropy = nn.CrossEntropyLoss(
            ignore_index=self.annotation_padding_idx, reduction='none')
Code example #7
class Transformer(nn.Module):
    ''' The Transformer module '''
    def __init__(self, config, dataset):
        ''' Initialize the Transformer '''
        super(Transformer, self).__init__()

        self.dataset = dataset
        self.span = config.span
        self.embedding = TokenEmbedding(dataset.vocab_size,
                                        config.embedding_size,
                                        padding_idx=self.padding_idx)
        self.position_embedding = PositionEmbedding(config.embedding_size)
        self.dropout = nn.Dropout(config.dropout_p, inplace=True)

        # Allow for overriding the encoders and decoders in derived classes
        self.encoders = type(self).create_encoders(config)
        self.decoders = type(self).create_decoders(config)

        self.label_smoothing = LabelSmoothingLoss(
            config.label_smoothing or 0,
            ignore_index=self.padding_idx,
            reduction='none')
        self.cross_entropy = nn.CrossEntropyLoss(ignore_index=self.padding_idx,
                                                 reduction='none')

    @classmethod
    def create_encoders(cls, config):
        ''' Create the transformer encoders '''
        kwargs = {'dropout_p': config.dropout_p}
        args = [config.num_heads, config.embedding_size, config.hidden_dim]
        return nn.ModuleList([
            TransformerEncoderLayer(*args, **kwargs)
            for _ in range(config.num_layers)
        ])

    @classmethod
    def create_decoders(cls, config):
        ''' Create the transformer decoders '''
        kwargs = {'dropout_p': config.dropout_p, 'span': config.span}
        args = [config.num_heads, config.embedding_size, config.hidden_dim]
        return nn.ModuleList([
            TransformerDecoderLayer(*args, **kwargs)
            for _ in range(config.num_layers)
        ])

    @property
    def sos_idx(self):
        ''' Return the sos index '''
        return self.dataset.sos_idx

    @property
    def padding_idx(self):
        ''' Return the padding index '''
        return self.dataset.padding_idx

    def translator(self, config):
        ''' Get a translator for this model '''
        return Translator(config, self, self.dataset)

    def reset_named_parameters(self, modules):
        ''' Reset the parameters of the specified modules '''
        if 'encoder' in modules:
            for encoder in self.encoders:
                encoder.reset_parameters()
        if 'decoder' in modules:
            for decoder in self.decoders:
                decoder.reset_parameters()
        if 'embeddings' in modules:
            self.embedding.reset_parameters()

    def forward(self, batch):  # pylint:disable=arguments-differ
        ''' A batch of inputs and targets '''
        decoded = self.decode(
            self.encode(batch['inputs']),
            right_shift(right_shift(batch['targets']),
                        shift=self.span - 1,
                        fill=self.sos_idx),
        )

        logits = decoded['logits']
        dims = list(range(1, logits.dim()))
        targets = left_shift(batch['targets'])
        nll = self.cross_entropy(logits, targets).sum(dims[:-1])
        smoothed_nll = self.label_smoothing(logits, targets).sum(dims)

        return smoothed_nll, nll

    def encode(self, inputs):
        ''' Encode the inputs '''
        encoded = {
            'state': self.embed(inputs, self.embedding),
            'mask': inputs.eq(self.padding_idx)
        }
        for encoder in self.encoders:
            encoded = encoder(encoded)

        return encoded

    def decode(self,
               encoded,
               targets,
               decoders=None,
               embedding=None,
               cache=None,
               mask=None):
        ''' Decode the encoded sequence to the targets '''
        if decoders is None:
            decoders = self.decoders

        if embedding is None:
            embedding = self.embedding

        decoded = {
            'cache': cache,
            'state': self.embed(targets, embedding),
            'mask': targets.eq(self.padding_idx) if mask is None else mask
        }
        for decoder in decoders:
            decoded = decoder(decoded, encoded)

        # compute projection to the vocabulary
        state = decoded['state']
        if cache is not None:
            state = state[:, -self.span:]

        return {
            'cache': decoded.get('cache'),
            'logits': embedding(state, transpose=True).transpose(2, 1),  # transpose to B x C x ...
        }

    def embed(self, inputs, token_embedding):
        ''' Embed the given inputs '''
        return self.dropout(
            token_embedding(inputs) + self.position_embedding(inputs))
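Two details of the class above are easy to miss: the mask passed through encode()/decode() is simply inputs.eq(padding_idx), and decode() returns logits transposed to B x C x T, which is the layout nn.CrossEntropyLoss expects for sequences. A self-contained illustration with dummy shapes (the sizes and tokens are made up; only the masking and the loss reduction mirror the snippet):

    import torch
    import torch.nn as nn

    padding_idx = 0
    inputs = torch.tensor([[5, 7, 9, 0, 0],
                           [3, 4, 0, 0, 0]])
    mask = inputs.eq(padding_idx)  # True at padded positions, consumed by the attention layers
    print(mask)

    batch, vocab_size, length = 2, 11, 5
    logits = torch.randn(batch, vocab_size, length)          # B x C x T, as returned by decode()
    targets = torch.randint(1, vocab_size, (batch, length))
    targets[mask] = padding_idx                               # padded targets are ignored below

    loss_fn = nn.CrossEntropyLoss(ignore_index=padding_idx, reduction='none')
    nll = loss_fn(logits, targets).sum(dim=-1)                # mirrors .sum(dims[:-1]) in forward()
    print(nll.shape)                                          # torch.Size([2]): one NLL per sequence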
Code example #8
class NewTransformer(nn.Module):
    ''' The New Transformer module '''
    def __init__(self, config, dataset):
        ''' Initialize the Transformer '''
        super(NewTransformer, self).__init__()

        self.dataset = dataset
        self.embedding = TokenEmbedding(
            dataset.vocab_size,
            config.embedding_size,
            padding_idx=self.padding_idx
        )
        self.position_embedding = PositionEmbedding(config.embedding_size)
        self.dropout = nn.Dropout(config.dropout_p, inplace=True)

        # Uniq attn attributes
        self.attn_ofs_uniq = list(set(
            config.enc_attn_offset + config.dec_attn_offset + config.enc_dec_attn_offset))
        self.attn_std_uniq = list(set(
            config.enc_attn_std + config.dec_attn_std + config.enc_dec_attn_std))

        # Allow for overriding the encoders and decoders in derived classes
        self.encoders = self.create_encoders(config)
        self.decoders = self.create_decoders(config)

        self.label_smoothing = LabelSmoothingLoss(
            config.label_smoothing or 0,
            ignore_index=self.padding_idx,
            reduction='none'
        )
        self.cross_entropy = nn.CrossEntropyLoss(
            ignore_index=self.padding_idx,
            reduction='none'
        )

    def create_encoders(self, config):
        ''' Create the transformer encoders '''
        kwargs = {'dropout_p': config.dropout_p}

        if config.ffn_layer == -1:
            config.ffn_layer = [1] * config.num_layers
        assert len(config.ffn_layer) == config.num_layers

        attn_config = {'attn_type': config.enc_attn_type,
                       'attn_std': config.enc_attn_std,
                       'attn_offset': config.enc_attn_offset,
                       'num_layers': config.num_layers,
                       'num_heads': config.num_heads,
                       'which_attn': 'encoder',
                       'attn_threshold': config.enc_attn_threshold,
                       'attn_window': config.enc_attn_window,
                       'attn_impl': config.enc_attn_impl,
                       'ffn_layer': config.ffn_layer,
                       'attn_ofs_uniq': self.attn_ofs_uniq,
                       'attn_std_uniq': self.attn_std_uniq}
        args = [attn_config, config.num_heads, config.embedding_size, config.hidden_dim]
        encoders = nn.ModuleList([
            TransformerEncoderLayer(*args, layer_i, **kwargs)
            for layer_i in range(config.num_layers)
        ])

        return encoders

    def create_decoders(self, config):
        ''' Create the transformer decoders '''
        kwargs = {'dropout_p': config.dropout_p}

        if config.ffn_layer == -1:
            config.ffn_layer = [1] * config.num_layers
        assert len(config.ffn_layer) == config.num_layers

        dec_attn_config = {'attn_type': config.dec_attn_type,
                           'attn_std': config.dec_attn_std,
                           'attn_offset': config.dec_attn_offset,
                           'num_layers': config.num_layers,
                           'num_heads': config.num_heads,
                           'which_attn': 'decoder',
                           'attn_threshold': config.dec_attn_threshold,
                           'attn_window': config.dec_attn_window,
                           'attn_impl': config.dec_attn_impl,
                           'ffn_layer': config.ffn_layer,
                           'attn_ofs_uniq': self.attn_ofs_uniq,
                           'attn_std_uniq': self.attn_std_uniq
                           }
        enc_dec_attn_config = {'attn_type': config.enc_dec_attn_type,
                               'attn_std': config.enc_dec_attn_std,
                               'attn_offset': config.enc_dec_attn_offset,
                               'num_layers': config.num_layers,
                               'num_heads': config.num_heads,
                               'word_count_ratio': self.dataset.word_count_ratio,
                               'which_attn': 'source',
                               'enc_dec_attn_layer': config.enc_dec_attn_layer,
                               'enc_dec_attn_num_heads': config.enc_dec_attn_num_heads,
                               'attn_threshold': config.enc_dec_attn_threshold,
                               'attn_window': config.enc_dec_attn_window,
                               'attn_impl': config.enc_dec_attn_impl,
                               'ffn_layer': config.ffn_layer,
                               'attn_ofs_uniq': self.attn_ofs_uniq,
                               'attn_std_uniq': self.attn_std_uniq
                               }
        args = [dec_attn_config, enc_dec_attn_config, config.num_heads, config.embedding_size, config.hidden_dim]
        decoders = nn.ModuleList([
            TransformerDecoderLayer(*args, layer_i, **kwargs)
            for layer_i in range(config.num_layers)
        ])

        return decoders


    @property
    def sos_idx(self):
        ''' Return the sos index '''
        return self.dataset.sos_idx

    @property
    def padding_idx(self):
        ''' Return the padding index '''
        return self.dataset.padding_idx

    def translator(self, config):
        ''' Get a translator for this model '''
        return Translator(config, self, self.dataset)

    def reset_named_parameters(self, modules):
        ''' Reset the parameters of the specified modules '''
        if 'encoder' in modules:
            for encoder in self.encoders:
                encoder.reset_parameters()
        if 'decoder' in modules:
            for decoder in self.decoders:
                decoder.reset_parameters()
        if 'embeddings' in modules:
            self.embedding.reset_parameters()

    def forward(self, batch): # pylint:disable=arguments-differ
        ''' A batch of inputs and targets '''
        decoded = self.decode(
            self.encode(batch['inputs']),
            right_shift(batch['targets']),
            input_lens=batch['input_lens']
        )

        logits = decoded['logits']
        dims = list(range(1, logits.dim()))
        targets = left_shift(batch['targets'])
        nll = self.cross_entropy(logits, targets).sum(dims[:-1])
        smoothed_nll = self.label_smoothing(logits, targets).sum(dims)
        return smoothed_nll, nll

    def encode(self, inputs):
        ''' Encode the inputs '''
        word_embedding = self.embed(inputs, self.embedding)
        encoded = {
            'state': word_embedding,
            'mask': inputs.eq(self.padding_idx)
        }
        for i, encoder in enumerate(self.encoders):
            encoded = encoder(encoded, i)

        return encoded

    def decode(self, encoded, targets, decoders=None, embedding=None, cache=None, mask=None, input_lens=None):
        ''' Decode the encoded sequence to the targets '''
        if decoders is None:
            decoders = self.decoders

        if embedding is None:
            embedding = self.embedding

        word_embedding = self.embed(targets, embedding)

        decoded = {
            'cache': cache,
            'state': word_embedding,
            'mask': targets.eq(self.padding_idx) if mask is None else mask
        }
        for i, decoder in enumerate(decoders):
            # print("i", i)
            decoded = decoder(decoded, encoded, i)

        # compute projection to the vocabulary
        state = decoded['state']
        if cache is not None:
            state = state[:, -1:]

        return {
            'cache': decoded.get('cache'),
            'logits': embedding(state, transpose=True).transpose(2, 1),  # transpose to B x C x ...
        }

    def embed(self, inputs, token_embedding):
        ''' Embed the given inputs '''
        return self.dropout(token_embedding(inputs) + self.position_embedding(inputs))
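Note that decode() above only projects the most recent position to the vocabulary when a cache is passed in (incremental generation), whereas the span-based variant in Code example #7 slices the last span positions. A trivial shape check of that slice with a dummy state tensor:

    import torch

    state = torch.randn(2, 7, 16)  # batch x decoded_length x embedding_size (dummy sizes)
    newest = state[:, -1:]         # same slice as decode() when cache is not None
    print(newest.shape)            # torch.Size([2, 1, 16])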
Code example #9
class NPLM(nn.Module):
    ''' The neural probabilistic LM module '''
    def __init__(self, config, dataset):
        ''' Initialize the NPLM '''
        super(NPLM, self).__init__()

        self.dataset = dataset
        
        self.adaptive = config.adaptive
        # ngm: number of tokens concatenated with their full embeddings
        # wsz: window size used to average the long-term context
        self.ngm, self.wsz = config.context_config                  
        self.long_term_block = 0 if self.ngm > 0 and self.wsz == -1 else \
                                    (config.batch_length - self.ngm) // self.wsz

        self.dim_concat_embs = self.ngm * config.embedding_size + self.long_term_block * config.embedding_size

        self.embedding = TokenEmbedding(
                dataset.vocab_size,
                config.embedding_size,
                config.model_size, 
                config.cutoffs,
                emb_std=config.emb_std,
                proj_std=config.proj_std,
                div_val=config.div_val,
                padding_idx=self.padding_idx,
                do_proj=config.do_proj
            )

        if self.adaptive:
            self.adaptive_softmax = AdaptiveSoftmax(self.dataset.vocab_size, config.embedding_size, config.embedding_size, 
                                                    config.cutoffs, div_val=config.div_val)

            self.tie_weights = config.tie_weights
            self.tie_projs = config.tie_projs

            if self.tie_weights:
                for i in range(len(self.adaptive_softmax.out_layers)):
                    self.adaptive_softmax.out_layers[i].weight = self.embedding.emb_layers[i].weight

            if self.tie_projs:
                for i in range(1, len(self.adaptive_softmax.out_projs)):
                    if config.div_val == 1 and config.model_size != config.embedding_size:
                        self.adaptive_softmax.out_projs[i] = self.embedding.emb_projs[0]
                    elif config.div_val != 1:
                        self.adaptive_softmax.out_projs[i] = self.embedding.emb_projs[i]

        self.layers = self.create_layers(config)
        self.position_embedding = PositionEmbedding(config.model_size) # only used in transformer-N
        self.label_smoothing = LabelSmoothingLoss(
            config.label_smoothing or 0,
            ignore_index=self.padding_idx,
            reduction='none'
        )
        self.cross_entropy = nn.CrossEntropyLoss(
            ignore_index=self.padding_idx,
            reduction='none'
        )

        self.dropout = nn.Dropout(config.dropout_p, inplace=True)

        self.config = config


    @classmethod
    def create_layers(cls, config):
        ''' Create the NPLM decoders '''
        kwargs = {'dropout_p': config.dropout_p}                    # sublayer kwargs

        args = [config, config.num_heads, config.embedding_size, config.hidden_dim]

        layers = nn.ModuleList([
            NPLMLayer(*args, layer_i, **kwargs)
            for layer_i in range(config.num_layers)
        ])

        return layers

    @property
    def padding_idx(self):
        return self.dataset.padding_idx

    @property
    def eos_idx(self):
        return  self.dataset.eos_idx

    def reset_named_parameters(self, modules):

        if 'layers' in modules:
            for layer in self.layers:
                layer.reset_parameters()

        if 'embeddings' in modules:
            self.embedding.reset_parameters()

    def forward(self, batch): # pylint:disable=arguments-differ

        batch = batch.t()
        targets = left_shift(batch)
        decoded = self.decode(right_shift(batch))

        state = decoded['state']

        if not self.adaptive:
            logits = self.embedding(state, reverse=True).transpose(2, 1)
            dims = list(range(1, logits.dim()))
            nll = self.cross_entropy(logits, targets).view(-1)
            smoothed_nll = self.label_smoothing(logits, targets).sum(dims)

            if not self.config.return_rank:
                return smoothed_nll, nll

            else:
                logits = logits.transpose(2, 1)
                assert targets.shape[0] == 1
                targets = targets.squeeze(0)
                target_logits = logits[:, range(targets.shape[0]), targets]
                rank = (logits > target_logits.unsqueeze(-1)).sum(dim=-1)
                return rank, nll

        else:
            state = state.view(-1, state.shape[-1]) # (bsz*L, embed_dim)
            targets = targets.contiguous().view(-1) # (bsz*L, )

            if not self.config.return_rank:
                nll = self.adaptive_softmax(state, targets, keep_order=True)
                smoothed_nll = nll
                return smoothed_nll, nll

            else:
                nll, rank = self.adaptive_softmax(state, targets, keep_order=True, return_rank=True)
                return rank, nll


    def decode(self, batch, cache=None):
        ''' Decode the batch through the NPLM layers '''
        word_embedding = self.embed(batch, self.embedding)

        decoded = {
            'cache': cache,
            'state': word_embedding,
        }

        # concat layer
        decoded = self.layers[0](decoded, layer_i=0)
        global_mem = self.layers[0].global_mem

        # regular layers
        for i, decoder in enumerate(self.layers[1:]):
            decoded = decoder(decoded, layer_i=i+1, global_mem=global_mem)

        # compute projection to the vocabulary
        state = decoded['state']
        if cache is not None:
            state = state[:, -1:]       # fetch newly generated tok

        return {
            'cache': decoded.get('cache'),
            'state': state,            # bs x L x dim_emb or bs x L x hidden_dim
        }

    def embed(self, inputs, token_embedding):
        ''' Embed the given inputs, no position embedding '''
        if self.config.TFN:
            return self.dropout(token_embedding(inputs) + self.position_embedding(inputs))
        else:
            return self.dropout(token_embedding(inputs))
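The tie_weights / tie_projs branch above reuses the embedding matrices (and projections) as the output layers of the adaptive softmax, so each shared matrix is stored and trained only once. The same trick with plain nn.Embedding / nn.Linear, as a minimal sketch (the adaptive classes themselves are not shown in this file):

    import torch.nn as nn

    vocab, dim = 10, 4
    emb = nn.Embedding(vocab, dim)
    out = nn.Linear(dim, vocab, bias=False)

    # Same assignment pattern as the tie_weights loop: both modules now hold one Parameter
    out.weight = emb.weight
    print(out.weight is emb.weight)                  # True

    params = {id(p): p for m in (emb, out) for p in m.parameters()}
    print(sum(p.numel() for p in params.values()))   # 40, not 80: the matrix is shared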
Code example #10
class ProbeNewTransformer(nn.Module):
    ''' The New Transformer module '''
    def __init__(self, config, dataset):
        ''' Initialize the Transformer '''
        super(ProbeNewTransformer, self).__init__()

        self.dataset = dataset
        self.span = config.span
        self.embedding = TokenEmbedding(
            dataset.vocab_size,
            config.embedding_size,
            padding_idx=self.padding_idx
        )
        self.position_embedding = PositionEmbedding(config.embedding_size)
        self.dropout = nn.Dropout(config.dropout_p, inplace=True)

        # Allow for overriding the encoders and decoders in derived classes
        self.encoders = type(self).create_encoders(config)
        self.decoders = self.create_decoders(config)

        self.label_smoothing = LabelSmoothingLoss(
            config.label_smoothing or 0,
            ignore_index=self.padding_idx,
            reduction='none'
        )
        self.cross_entropy = nn.CrossEntropyLoss(
            ignore_index=self.padding_idx,
            reduction='none'
        )

    @classmethod
    def create_encoders(cls, config):
        ''' Create the transformer encoders '''
        kwargs = {'dropout_p': config.dropout_p}
        attn_config = {'attn_type': config.attn_type,
                       'attn_position': config.attn_position,
                       'attn_param': config.attn_param,
                       'attn_displacement': config.attn_displacement,
                       'num_layers': config.num_layers,
                       'num_heads': config.num_heads,
                       'attn_concat': config.attn_concat,
                       'which_attn': 'encoder',
                       'attn_weights': config.attn_weights,
                       'attn_score': config.attn_score,
                       'attn_bins': config.attn_bins,
                       'attn_threshold': config.attn_threshold,
                       'attn_window': config.attn_window}
        args = [attn_config, config.num_heads, config.embedding_size, config.hidden_dim]
        return nn.ModuleList([
            TransformerEncoderLayer(*args, **kwargs)
            for _ in range(config.num_layers)
        ])

    # @classmethod
    def create_decoders(self, config):
        ''' Create the transformer decoders '''
        kwargs = {'dropout_p': config.dropout_p, 'span': config.span}
        dec_attn_config = {'attn_type': config.dec_attn_type,
                           'attn_position': config.dec_attn_position,
                           'attn_param': config.dec_attn_param,
                           'attn_displacement': config.dec_attn_displacement,
                           'num_layers': config.num_layers,
                           'num_heads': config.num_heads,
                           'attn_concat': config.dec_attn_concat,
                           'which_attn': 'decoder',
                           'attn_weights': config.dec_attn_weights,
                           'attn_score': config.dec_attn_score,
                           'attn_bins': config.dec_attn_bins,
                           'attn_threshold': config.dec_attn_threshold,
                           'attn_window': config.dec_attn_window}
        enc_dec_attn_config = {'attn_type': config.enc_dec_attn_type,
                               'attn_position': config.enc_dec_attn_position,
                               'attn_param': config.enc_dec_attn_param,
                               'attn_displacement': config.enc_dec_attn_displacement,
                               'num_layers': config.num_layers,
                               'num_heads': config.num_heads,
                               'word_count_ratio': self.dataset.word_count_ratio,
                               'attn_concat': config.enc_dec_attn_concat,
                               'which_attn': 'source',
                               'attn_weights': config.enc_dec_attn_weights,
                               'attn_score': config.enc_dec_attn_score,
                               'attn_bins': config.enc_dec_attn_bins,
                               'enc_dec_attn_layer': config.enc_dec_attn_layer,
                               'enc_dec_attn_num_heads': config.enc_dec_attn_num_heads,
                               'attn_threshold': config.enc_dec_attn_threshold,
                               'attn_window': config.enc_dec_attn_window
                               }
        args = [dec_attn_config, enc_dec_attn_config, config.num_heads, config.embedding_size, config.hidden_dim]
        return nn.ModuleList([
            TransformerDecoderLayer(*args, layer_i, **kwargs)
            for layer_i in range(config.num_layers)
        ])

    @property
    def sos_idx(self):
        ''' Return the sos index '''
        return self.dataset.sos_idx

    @property
    def padding_idx(self):
        ''' Return the padding index '''
        return self.dataset.padding_idx

    def translator(self, config):
        ''' Get a translator for this model '''
        return ProbeNewTranslator(config, self, self.dataset)

    def reset_named_parameters(self, modules):
        ''' Reset the parameters of the specified modules '''
        if 'encoder' in modules:
            for encoder in self.encoders:
                encoder.reset_parameters()
        if 'decoder' in modules:
            for decoder in self.decoders:
                decoder.reset_parameters()
        if 'embeddings' in modules:
            self.embedding.reset_parameters()

    def forward(self, batch): # pylint:disable=arguments-differ
        ''' A batch of inputs and targets '''

        encoded, encoder_attn_weights_tensor = self.encode(batch['inputs'])

        decoded = self.decode(
            encoded,
            right_shift(right_shift(batch['targets']), shift=self.span - 1, fill=self.sos_idx),
            input_lens=batch['input_lens']
        )

        logits = decoded['logits']
        dims = list(range(1, logits.dim()))
        targets = left_shift(batch['targets'])
        nll = self.cross_entropy(logits, targets).sum(dims[:-1])
        smoothed_nll = self.label_smoothing(logits, targets).sum(dims)

        return {'smoothed_nll': smoothed_nll,
            'nll': nll,
            'encoder_attn_weights_tensor': encoder_attn_weights_tensor,
            'decoder_attn_weights_tensor': decoded['decoder_attn_weights_tensor'],
            'enc_dec_attn_weights_tensor': decoded['enc_dec_attn_weights_tensor']}

    def encode(self, inputs):
        ''' Encode the inputs '''
        word_embedding = self.embed(inputs, self.embedding)
        encoded = {
            'state': word_embedding,
            'mask': inputs.eq(self.padding_idx)
        }
        encoder_attn_weights_list = []
        for i, encoder in enumerate(self.encoders):
            encoded = encoder(encoded, i, word_embedding)
            encoder_attn_weights_list.append(encoded['encoder_attn_weights'])

        encoder_attn_weights_tensor = torch.stack(encoder_attn_weights_list)
            
        return encoded, encoder_attn_weights_tensor

    def decode(self, encoded, targets, decoders=None, embedding=None, cache=None, mask=None, input_lens=None):
        ''' Decode the encoded sequence to the targets '''
        if decoders is None:
            decoders = self.decoders

        if embedding is None:
            embedding = self.embedding

        word_embedding = self.embed(targets, embedding)

        decoded = {
            'cache': cache,
            'state': word_embedding,
            'mask': targets.eq(self.padding_idx) if mask is None else mask,
            'input_lens': input_lens
        }
        decoder_attn_weights_list = []
        enc_dec_attn_weights_list = []
        for i, decoder in enumerate(decoders):
            # print("i", i)
            decoded = decoder(decoded, encoded, i, word_embedding)
            if 'enc_dec_attn_weights' not in decoded:
                decoder_attn_weights_list.append(decoded['decoder_attn_weights'])
                # enc_dec_attn_weights_list.append([])
            else:
                decoder_attn_weights_list.append(decoded['decoder_attn_weights'])
                enc_dec_attn_weights_list.append(decoded['enc_dec_attn_weights'])

        decoder_attn_weights_tensor = torch.stack(decoder_attn_weights_list)
        enc_dec_attn_weights_tensor = torch.stack(enc_dec_attn_weights_list)

        # compute projection to the vocabulary
        state = decoded['state']
        if cache is not None:
            state = state[:, -self.span:]

        return {
            'cache': decoded.get('cache'),
            'logits': embedding(state, transpose=True).transpose(2, 1),  # transpose to B x C x ...
            'decoder_attn_weights_tensor': decoder_attn_weights_tensor,
            'enc_dec_attn_weights_tensor': enc_dec_attn_weights_tensor
        }

    def embed(self, inputs, token_embedding):
        ''' Embed the given inputs '''
        return self.dropout(token_embedding(inputs) + self.position_embedding(inputs))
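encode() and decode() above collect one attention-weight tensor per layer and torch.stack them, so the probe returns single tensors whose leading dimension is the layer index. A quick shape check with dummy tensors (the per-layer shape batch x heads x target_len x source_len is an assumption about what the modified layers return, not something shown in this snippet):

    import torch

    num_layers, batch, heads, tgt_len, src_len = 6, 2, 8, 5, 7
    per_layer = [torch.rand(batch, heads, tgt_len, src_len) for _ in range(num_layers)]
    stacked = torch.stack(per_layer)
    print(stacked.shape)  # torch.Size([6, 2, 8, 5, 7]); stacked[i] holds layer i's weights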
Code example #11
File: transformer.py  Project: SimengSun/revisit-nplm
class Transformer(nn.Module):
    ''' The Transformer LM module '''
    def __init__(self, config, dataset):
        ''' Initialize the Transformer '''
        super(Transformer, self).__init__()

        self.dataset = dataset
        self.config = config

        self.adaptive = config.adaptive

        self.embedding = TokenEmbedding(dataset.vocab_size,
                                        config.embedding_size,
                                        config.model_size,
                                        config.cutoffs,
                                        emb_std=config.emb_std,
                                        proj_std=config.proj_std,
                                        div_val=config.div_val,
                                        padding_idx=self.padding_idx,
                                        do_proj=config.do_proj)

        if self.adaptive:
            self.adaptive_softmax = AdaptiveSoftmax(self.dataset.vocab_size,
                                                    config.embedding_size,
                                                    config.model_size,
                                                    config.cutoffs,
                                                    div_val=config.div_val)

            self.tie_weights = config.tie_weights
            self.tie_projs = config.tie_projs

            if self.tie_weights:
                for i in range(len(self.adaptive_softmax.out_layers)):
                    self.adaptive_softmax.out_layers[
                        i].weight = self.embedding.emb_layers[i].weight

            if self.tie_projs:
                for i in range(1, len(self.adaptive_softmax.out_projs)):
                    if config.div_val == 1 and config.model_size != config.embedding_size:
                        self.adaptive_softmax.out_projs[
                            i] = self.embedding.emb_projs[0]
                    elif config.div_val != 1:
                        self.adaptive_softmax.out_projs[
                            i] = self.embedding.emb_projs[i]

        self.position_embedding = PositionEmbedding(config.embedding_size)
        self.dropout = nn.Dropout(config.dropout_p, inplace=True)

        if len(config.no_attention) == 1:
            config.no_attention = config.no_attention * config.num_layers
        assert len(config.no_attention) == config.num_layers

        self.layers = self.create_layers(config)

        self.label_smoothing = LabelSmoothingLoss(
            config.label_smoothing or 0,
            ignore_index=self.padding_idx,
            reduction='none')
        self.cross_entropy = nn.CrossEntropyLoss(ignore_index=self.padding_idx,
                                                 reduction='none')

    @classmethod
    def create_layers(cls, config):
        ''' Create the transformer decoders '''
        kwargs = {'dropout_p': config.dropout_p}  # sublayer kwargs

        args = [config, config.num_heads, config.model_size, config.hidden_dim]

        layers = nn.ModuleList([
            TransformerLayer(*args, layer_i, **kwargs)
            for layer_i in range(config.num_layers)
        ])

        return layers

    @property
    def padding_idx(self):
        return self.dataset.padding_idx

    @property
    def eos_idx(self):
        return self.dataset.eos_idx

    def reset_named_parameters(self, modules):

        if 'layers' in modules:
            for layer in self.layers:
                layer.reset_parameters()

        if 'embeddings' in modules:
            self.embedding.reset_parameters()

    def forward(self, batch, global_mask=None):  # pylint:disable=arguments-differ
        ''' batch: length x bsz'''

        batch = batch.transpose(1, 0)
        targets = left_shift(batch)
        decoded = self.decode(right_shift(batch), global_mask=global_mask)

        state = decoded['state']
        if not self.adaptive:
            logits = self.embedding(state, reverse=True).transpose(2, 1)
            dims = list(range(1, logits.dim()))
            nll = self.cross_entropy(logits, targets).view(-1)
            smoothed_nll = self.label_smoothing(logits, targets).sum(dims)

            if not self.config.return_rank:
                return smoothed_nll, nll

            else:
                logits = logits.transpose(2, 1)
                assert targets.shape[0] == 1
                targets = targets.squeeze(0)
                target_logits = logits[:, range(targets.shape[0]), targets]
                rank = (logits > target_logits.unsqueeze(-1)).sum(dim=-1)
                return rank, nll

        else:
            if self.config.batch_length < state.size(1):
                state = state[:, -self.config.batch_length:].contiguous()
                targets = targets[:, -self.config.batch_length:].contiguous()

            state = state.view(-1, state.shape[-1])  # (bsz*L, embed_dim)
            targets = targets.contiguous().view(-1)  # (bsz*L, )

            if not self.config.return_rank:
                nll = self.adaptive_softmax(state, targets, keep_order=True)
                smoothed_nll = nll
                return smoothed_nll, nll

            else:
                nll, rank = self.adaptive_softmax(state,
                                                  targets,
                                                  keep_order=True,
                                                  return_rank=True)
                return rank, nll

    def decode(self, batch, cache=None, global_mask=None):
        ''' Decode the batch through the transformer layers '''

        bsz, L = batch.shape
        word_embedding = self.embed(batch, self.embedding)

        decoded = {
            'state': word_embedding,
        }

        decoded['state'][batch == self.padding_idx] = 0

        for i, decoder in enumerate(self.layers):
            decoded = decoder(decoded, layer_i=i, global_mask=global_mask)

        return {
            'state': decoded['state'],  # bs x L x hidden_dim
        }

    def embed(self, inputs, token_embedding):
        ''' Embed the given inputs '''
        return self.dropout(
            token_embedding(inputs) + self.position_embedding(inputs))
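When config.return_rank is set, the non-adaptive branch of forward() above reports, for every position, how many vocabulary entries score higher than the gold token (so 0 means the model's top prediction is the gold token). The same computation re-created with hand-picked numbers:

    import torch

    logits = torch.tensor([[[0.1, 2.0, 0.3, 0.0, 1.0],    # position 0
                            [1.5, 0.2, 0.7, 3.0, 0.1]]])  # position 1; shape (1, L, V)
    targets = torch.tensor([[1, 2]])                       # gold token ids per position

    assert targets.shape[0] == 1
    targets = targets.squeeze(0)
    target_logits = logits[:, range(targets.shape[0]), targets]
    rank = (logits > target_logits.unsqueeze(-1)).sum(dim=-1)
    print(rank)  # tensor([[0, 2]]): gold is top-1 at position 0, third-best at position 1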