Example #1
    def forward(self, input, *args):

        if self.optimized == 2 or not input.is_cuda:
            hidden = F.linear(input, self.in_proj_weight, self.in_proj_bias)
            # hidden = F.relu(hidden, inplace=True)
            hidden = self.function(hidden)
            if self.variational:
                hidden = variational_dropout(hidden,
                                             p=self.dropout,
                                             training=self.training)
            else:
                hidden = F.dropout(hidden,
                                   p=self.dropout,
                                   training=self.training)
            hidden = F.linear(hidden, self.out_proj_weight, self.out_proj_bias)
        else:
            # Apex MLP does not support dropout so instead we use dropconnect
            # Theoretically they should yield similar results
            weights = [
                F.dropout(self.in_proj_weight,
                          p=self.dropout,
                          training=self.training), self.out_proj_weight
            ]
            biases = [
                F.dropout(self.in_proj_bias,
                          p=self.dropout,
                          training=self.training), self.out_proj_bias
            ]
            seq_len, bsz, hidden_size = input.size(0), input.size(1), input.size(2)
            hidden = self.fast_mlp_func(True, 1, input.view(seq_len * bsz, -1),
                                        *weights, *biases)
            hidden = hidden.view(seq_len, bsz, hidden_size)

        return hidden
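
The comment in the fused branch explains the trade-off: Apex's fused MLP kernel has no activation dropout, so the example drops entries of the input projection weights instead (dropconnect) and argues the two should behave similarly. A minimal stand-alone sketch of both variants in plain PyTorch; the shapes and dropout rate below are illustrative, not taken from the original module:

import torch
import torch.nn.functional as F

x = torch.randn(10, 4, 512)                       # (seq_len, bsz, model_dim), illustrative
w_in, b_in = torch.randn(2048, 512), torch.randn(2048)
w_out, b_out = torch.randn(512, 2048), torch.randn(512)
p = 0.1

# Variant 1: dropout on the hidden activations (the non-fused branch above)
h = F.dropout(F.relu(F.linear(x, w_in, b_in)), p=p, training=True)
y_dropout = F.linear(h, w_out, b_out)

# Variant 2: dropconnect, i.e. dropout applied to the input projection weights,
# which is what the snippet feeds into the fused MLP instead
w_in_dc = F.dropout(w_in, p=p, training=True)
y_dropconnect = F.linear(F.relu(F.linear(x, w_in_dc, b_in)), w_out, b_out)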
Example #2
    def forward(self,
                input,
                pos,
                key_padding_mask=None,
                attn_mask=None,
                incremental=False,
                incremental_cache=None,
                cleaning=False):
        q = self.layer_norm(input)
        attn, coverage = self.attn(q,
                                   pos,
                                   key_padding_mask=key_padding_mask,
                                   attn_mask=attn_mask,
                                   incremental=incremental,
                                   incremental_cache=incremental_cache)

        if not self.variational:
            o = F.dropout(attn,
                          p=self.residual_dropout,
                          training=self.training,
                          inplace=False)
        else:
            o = variational_dropout(attn,
                                    p=self.residual_dropout,
                                    inplace=False,
                                    training=self.training)

        if cleaning:
            del q, attn
        return o, coverage
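
variational_dropout is used throughout these examples but never defined in them. In the usual "variational" (locked) dropout formulation, a single mask is sampled per sequence and shared across all time steps of a (seq_len, bsz, hidden) tensor. The helper below is a sketch under that assumption, not the repository's actual implementation:

import torch

def variational_dropout(x, p=0.5, training=True, inplace=False):
    # One Bernoulli mask per (batch, feature) position, broadcast over the
    # time dimension; x is assumed to be (seq_len, bsz, hidden) as above.
    if not training or p == 0.0:
        return x
    mask = x.new_empty(1, x.size(1), x.size(2)).bernoulli_(1 - p) / (1 - p)
    return x.mul_(mask) if inplace else x * mask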
Example #3
    def forward(self, input, indices=None):

        len_x, bsz = input.size(0), input.size(1)
        ensemble = self.r_in.size(0)

        if self.training:
            with torch.no_grad():
                indices = torch.arange(0, bsz, device=input.device, dtype=torch.long)
                indices = torch.remainder(indices, ensemble)

            r_in = torch.index_select(self.r_in, 0, indices)
            s_in = torch.index_select(self.s_in, 0, indices)
            r_out = torch.index_select(self.r_out, 0, indices)
            s_out = torch.index_select(self.s_out, 0, indices)

            input = torch.mul(input, r_in)
            input = F.linear(input, self.in_proj_weight, self.in_proj_bias)
            input = torch.mul(input, s_in)

            input = F.relu(input)
            if self.variational:
                input = variational_dropout(input, p=self.dropout, training=self.training)
            else:
                input = F.dropout(input, p=self.dropout, training=self.training)

            input = torch.mul(input, r_out)
            input = F.linear(input, self.out_proj_weight, self.out_proj_bias)
            input = torch.mul(input, s_out)

            return input
        else:
            input = input.repeat(1, ensemble, 1).view(len_x, ensemble, bsz, input.size(-1))
            input = torch.mul(input, self.r_in.unsqueeze(1))
            input = F.linear(input, self.in_proj_weight, self.in_proj_bias)
            input = torch.mul(input, self.s_in.unsqueeze(1))

            input = F.relu(input)

            input = torch.mul(input, self.r_out.unsqueeze(1))
            input = F.linear(input, self.out_proj_weight, self.out_proj_bias)
            input = torch.mul(input, self.s_out.unsqueeze(1))

            input = torch.mean(input, dim=1)

            return input
        # hidden = self.input_linear(input, indices)
        # hidden = F.relu(hidden)
        # if self.variational:
        #     hidden = variational_dropout(hidden, p=self.dropout, training=self.training)
        # else:
        #     hidden = F.dropout(hidden, p=self.dropout, training=self.training)
        # hidden = self.output_linear(hidden, indices)

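Example #3 looks like a BatchEnsemble-style layer: every ensemble member shares the full projection weights and owns only rank-1 scaling vectors (r_in, s_in, r_out, s_out), selected per batch element during training and averaged over at inference. Scaling the input by r and the output by s is equivalent to giving each member its own weight W * outer(s, r); a quick check with illustrative shapes:

import torch
import torch.nn.functional as F

x = torch.randn(4, 16)        # (bsz, in_features), illustrative
W = torch.randn(32, 16)       # shared weight matrix
r = torch.randn(16)           # one member's input scaling vector
s = torch.randn(32)           # one member's output scaling vector

y_scaled = torch.mul(F.linear(torch.mul(x, r), W), s)   # what the forward above computes
y_rank1 = F.linear(x, W * torch.outer(s, r))            # same result via an explicit rank-1 weight
assert torch.allclose(y_scaled, y_rank1, atol=1e-5)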
Example #4
    def forward(self, input, factor):

        factor = self.factor_map(factor).squeeze()

        in_proj_weight = torch.mv(self.in_proj_weight.view(-1, self.factor_size), factor)\
            .view(self.in_proj_weight.size(0), self.in_proj_weight.size(1))
        out_proj_weight = torch.mv(self.out_proj_weight.view(-1, self.factor_size), factor)\
            .view(self.out_proj_weight.size(0), self.out_proj_weight.size(1))

        in_proj_bias = torch.mv(self.in_proj_bias, factor)
        out_proj_bias = torch.mv(self.out_proj_bias, factor)

        if self.optimized == 2 or not input.is_cuda:
            hidden = F.linear(input, in_proj_weight, in_proj_bias)
            hidden = torch.relu(hidden)
            if self.variational:
                hidden = variational_dropout(hidden,
                                             p=self.dropout,
                                             training=self.training)
            else:
                hidden = F.dropout(hidden,
                                   p=self.dropout,
                                   training=self.training)
            hidden = F.linear(hidden, out_proj_weight, out_proj_bias)
        else:
            # Weight dropout (dropconnect) is used here instead of activation
            # dropout because Apex MLP does not support dropout
            weights = [
                F.dropout(in_proj_weight,
                          p=self.dropout,
                          training=self.training),
                F.dropout(out_proj_weight,
                          p=self.dropout,
                          training=self.training)
            ]
            biases = [
                F.dropout(in_proj_bias, p=self.dropout,
                          training=self.training),
                F.dropout(out_proj_bias,
                          p=self.dropout,
                          training=self.training)
            ]
            seq_len, bsz, hidden_size = input.size(0), input.size(1), input.size(2)
            hidden = self.fast_mlp_func(True, 1, input.view(seq_len * bsz, -1),
                                        *weights, *biases)
            hidden = hidden.view(seq_len, bsz, hidden_size)

        return hidden
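
Example #4 mixes factor_size base weight matrices into one language- or domain-specific weight: flattening to (out * in, factor_size) and applying torch.mv with the factor vector is just a weighted sum over the last axis. A small check of that reshaping trick, with illustrative shapes:

import torch

factor_size = 3
W = torch.randn(8, 4, factor_size)     # (out_features, in_features, factor_size), illustrative
f = torch.randn(factor_size)           # mixing coefficients, e.g. the output of factor_map

mixed = torch.mv(W.view(-1, factor_size), f).view(8, 4)    # the snippet's formulation
explicit = (W * f).sum(dim=-1)                              # explicit weighted sum over factors
assert torch.allclose(mixed, explicit, atol=1e-6)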
Example #5
    def forward(self, input, sample=False, calculate_log_probs=False):

        calculate_log_probs = calculate_log_probs or self.training
        sample = sample or self.training
        # Monte Carlo step:
        # sample the weights from the variational posterior distribution q(w)
        sampled_weights, log_variational_posterior = self.weight.sample(
            sample, calculate_log_probs)

        in_proj_weight, out_proj_weight, in_proj_bias, out_proj_bias = \
            unflatten(sampled_weights, self.indices, self.shapes)

        if self.optimized == 2 or not input.is_cuda:
            hidden = F.linear(input, in_proj_weight, in_proj_bias)
            hidden = F.relu(hidden, inplace=True)
            if self.variational:
                hidden = variational_dropout(hidden,
                                             p=self.dropout,
                                             training=self.training)
            else:
                hidden = F.dropout(hidden,
                                   p=self.dropout,
                                   training=self.training)
            hidden = F.linear(hidden, out_proj_weight, out_proj_bias)
        else:
            # Apex MLP does not support dropout; the sampled weights are
            # passed in directly here, so this path skips activation dropout
            weights = [in_proj_weight, out_proj_weight]
            biases = [in_proj_bias, out_proj_bias]
            seq_len, bsz, hidden_size = input.size(0), input.size(1), input.size(2)
            # True = bias, 1 = relu
            hidden = self.fast_mlp_func(True, 1, input.view(seq_len * bsz, -1),
                                        *weights, *biases)
            hidden = hidden.view(seq_len, bsz, hidden_size)

        if calculate_log_probs:
            # KL Divergence between prior and (variational) posterior
            self.log_variational_posterior = log_variational_posterior

            self.log_prior = self.weight_prior.log_prob(sampled_weights)

        return hidden
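
Example #5 is a Bayes-by-backprop style layer: each forward pass draws one Monte Carlo sample of the weights and records log q(w) and log p(w). A training loop would typically combine these into the KL term of the ELBO; the sketch below shows one way that bookkeeping might be consumed, where model, criterion and the 1/num_batches weighting are assumptions rather than code from the snippet:

def elbo_loss(model, output, target, criterion, num_batches):
    # Data term plus a Monte Carlo estimate of KL(q(w) || p(w)),
    # conventionally scaled down by the number of minibatches.
    nll = criterion(output, target)
    kl = 0.0
    for layer in model.modules():
        # Layers like the one above expose these after a training forward pass
        if hasattr(layer, 'log_variational_posterior') and hasattr(layer, 'log_prior'):
            kl = kl + (layer.log_variational_posterior - layer.log_prior)
    return nll + kl / num_batches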
Example #6
    def forward(self, input, cleaning=False):

        x_norm = self.layer_norm(input)
        x_ff = self.feedforward(x_norm)

        if not self.variational:
            o = F.dropout(x_ff,
                          p=self.residual_dropout,
                          training=self.training,
                          inplace=False)
        else:
            o = variational_dropout(x_ff,
                                    p=self.residual_dropout,
                                    inplace=False,
                                    training=self.training)

        if cleaning:
            del x_norm, x_ff

        return o
Example #7
    def forward(self,
                input,
                context,
                pos_emb,
                mask_tgt,
                mask_src,
                src_lang=None,
                tgt_lang=None,
                incremental=False,
                incremental_cache=None,
                reuse_source=True,
                mems=None):
        """ Self attention layer
            layernorm > attn > dropout > residual
        """

        if incremental and incremental_cache is None:
            incremental_cache = dict()

        coin = True
        if self.training and self.death_rate > 0:
            coin = (torch.rand(1)[0].item() >= self.death_rate)

        if coin:
            # input and context should be time first ?
            if mems is not None and mems.size(0) > 0:
                mems = self.preprocess_attn(mems)
            else:
                mems = None

            if self.macaron:
                out = self.mcr_feedforward(self.preprocess_mcr_ffn(input),
                                           src_lang)

                if self.training and self.death_rate > 0:
                    out = out / (1 - self.death_rate)

                if not self.variational:
                    out = F.dropout(out,
                                    p=self.dropout,
                                    training=self.training)
                else:
                    out = variational_dropout(out,
                                              p=self.dropout,
                                              training=self.training)

                input = input + self.ffn_scale * out

            query = self.preprocess_attn(input)

            if self.mfw:
                out, _ = self.multihead_tgt(
                    query,
                    pos_emb,
                    tgt_lang,
                    None,
                    mask_tgt,
                    mems=mems,
                    incremental=incremental,
                    incremental_cache=incremental_cache)
            else:
                out, _ = self.multihead_tgt(
                    query,
                    pos_emb,
                    None,
                    mask_tgt,
                    mems=mems,
                    incremental=incremental,
                    incremental_cache=incremental_cache)

            # rescaling before residual
            if self.training and self.death_rate > 0:
                out = out / (1 - self.death_rate)

            input = self.postprocess_attn(out, input)
            """ Context Attention layer 
                layernorm > attn > dropout > residual
            """
            if not self.ignore_source:
                query = self.preprocess_src_attn(input)
                incremental_source = incremental and reuse_source

                if self.mfw:
                    out, coverage = self.multihead_src(
                        query,
                        context,
                        context,
                        src_lang,
                        tgt_lang,
                        mask_src,
                        incremental=incremental_source,
                        incremental_cache=incremental_cache)
                else:
                    out, coverage = self.multihead_src(
                        query,
                        context,
                        context,
                        mask_src,
                        incremental=incremental_source,
                        incremental_cache=incremental_cache)

                # rescaling before residual
                if self.training and self.death_rate > 0:
                    out = out / (1 - self.death_rate)

                input = self.postprocess_src_attn(out, input)
            else:
                coverage = None
            """ Feed forward layer 
                layernorm > ffn > dropout > residual
            """
            out = self.feedforward(self.preprocess_ffn(input), tgt_lang)

            # rescaling before residual
            if self.training and self.death_rate > 0:
                out = out / (1 - self.death_rate)

            if not self.variational:
                out = F.dropout(out, p=self.dropout, training=self.training)
            else:
                out = variational_dropout(out,
                                          p=self.dropout,
                                          training=self.training)

            input = input + self.ffn_scale * out
        else:
            coverage = None

        return input, coverage, incremental_cache
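
Examples #7, #9 and #10 all gate their sub-layers with a death_rate coin flip and divide the surviving output by (1 - death_rate). This is stochastic depth: the rescaling keeps the expected residual contribution the same at training and inference time. A stripped-down version of the pattern for a single residual sub-layer:

import torch

def stochastic_depth_residual(x, sublayer, death_rate, training):
    # During training, skip the whole sub-layer with probability death_rate;
    # rescale survivors so the expected output matches the inference path.
    if training and death_rate > 0:
        if torch.rand(1).item() < death_rate:
            return x                              # layer dropped: identity
        return x + sublayer(x) / (1 - death_rate)
    return x + sublayer(x)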
Example #8
    def forward(self, input, *args, **kwargs):

        if self.fused and input.is_cuda and not self.autograd:

            # if autocast is enabled: manually cast the function args to half precision
            # (for some reason custom_fwd(...) doesn't work here)
            with autocast(enabled=False):
                weights = [
                    self.in_proj_weight.half(),
                    self.out_proj_weight.half()
                ]
                biases = [self.in_proj_bias.half(), self.out_proj_bias.half()]

                seq_len, bsz, hidden_size = input.size(0), input.size(1), input.size(2)

                dropout = self.dropout if self.training else 0.0

                if self.fused_dropout_add:
                    res_dropout = self.res_dropout if self.training else 0.0
                    hidden = self.fused_function(
                        dropout, res_dropout,
                        input.half().view(seq_len * bsz, -1), *weights,
                        *biases).type_as(input)
                else:
                    recompute = onmt.constants.recompute
                    hidden = self.fused_function(
                        dropout, recompute,
                        input.half().view(seq_len * bsz, -1), *weights,
                        *biases).type_as(input)
                hidden = hidden.view(seq_len, bsz, hidden_size)

                # verification code (only with dropout = 0.0)
                # with torch.no_grad():
                #     hidden_ = F.linear(self.act(F.linear(input, self.in_proj_weight, self.in_proj_bias)),
                #                        self.out_proj_weight, self.out_proj_bias).type_as(hidden)
                #
                #     if self.fused_dropout_add:
                #         hidden_.add_(input)
                #
                #     comp = torch.allclose(hidden, hidden_, rtol=1e-02, atol=1e-03)
                #     if not comp:
                #         print("Warning! The fused function doesn't match the PyTorch function.")
                #         print(hidden - hidden_)

        else:
            if self.autograd:
                hidden = self.linear_in(input)
            else:
                hidden = F.linear(input, self.in_proj_weight,
                                  self.in_proj_bias)

            if self.glu and self.activation != 'sigmoid':
                hidden, gate = hidden.chunk(2, dim=-1)
                hidden = self.act(hidden) * gate
            else:  # otherwise apply self.act directly (which may itself be the GLU function)
                hidden = self.act(hidden)

            if self.glu or self.activation != 'relu':
                if self.variational:
                    hidden = variational_dropout(
                        hidden,
                        p=self.dropout,
                        training=self.training,
                        inplace=self.activation
                        in ['silu', 'relu', 'swish', 'gelu'])
                else:
                    hidden = F.dropout(hidden,
                                       p=self.dropout,
                                       training=self.training,
                                       inplace=self.activation
                                       in ['silu', 'relu', 'swish', 'gelu'])

            if self.autograd:
                hidden = self.linear_out(hidden)
            else:
                hidden = F.linear(hidden, self.out_proj_weight,
                                  self.out_proj_bias)

        if self.dropout_residual:
            if not self.fused_dropout_add:
                if not self.variational:
                    hidden = F.dropout(hidden,
                                       p=self.res_dropout,
                                       training=self.training) + input
                else:
                    hidden = variational_dropout(
                        hidden, p=self.res_dropout, training=self.training) + input

        return hidden
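
When self.glu is set, Example #8 doubles the inner projection and splits it into a value half and a gate half, multiplying the activated value by the gate, i.e. a GLU-variant feed-forward in the spirit of GEGLU/SwiGLU. A minimal stand-alone version of just that gating step; the sizes and the choice of silu are illustrative:

import torch
import torch.nn.functional as F

x = torch.randn(10, 4, 512)                        # (seq_len, bsz, model_dim), illustrative
w_in = torch.randn(2 * 2048, 512)                  # doubled inner dimension to hold the gate
w_out = torch.randn(512, 2048)

hidden, gate = F.linear(x, w_in).chunk(2, dim=-1)  # split into value and gate halves
hidden = F.silu(hidden) * gate                     # activation on the value half, as in the snippet
y = F.linear(hidden, w_out)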
Example #9
    def forward(self,
                input,
                pos_emb,
                attn_mask,
                incremental=False,
                incremental_cache=None,
                mems=None,
                src_lang=None):

        assert incremental is False
        assert incremental_cache is None

        coin = True
        if self.training and self.death_rate > 0:
            coin = (torch.rand(1)[0].item() >= self.death_rate)
            ffn_scale = self.ffn_scale / (1 - self.death_rate)

        else:
            ffn_scale = self.ffn_scale

        if coin:
            out = self.mcr_feedforward(self.preprocess_mcr_ffn(input),
                                       src_lang)

            out = out * ffn_scale

            if not self.variational:
                out = F.dropout(out, p=self.dropout, training=self.training)
            else:
                out = variational_dropout(out,
                                          p=self.dropout,
                                          training=self.training)

            input = input + out

            # attention
            attn_input = self.preprocess_attn(input)
            out, _ = self.attn(attn_input, pos_emb, attn_mask, None)

            if self.training and self.death_rate > 0:
                out = out / (1 - self.death_rate)

            input = self.postprocess_attn(out, input)

            # convolution
            conv_input = self.preprocess_conv(input)
            out = self.conv(conv_input)

            if self.training and self.death_rate > 0:
                out = out / (1 - self.death_rate)

            input = self.postprocess_conv(out, input)

            # last ffn
            out = self.feedforward(self.preprocess_ffn(input), src_lang)

            out = out * ffn_scale

            if not self.variational:
                out = F.dropout(out, p=self.dropout, training=self.training)
            else:
                out = variational_dropout(out,
                                          p=self.dropout,
                                          training=self.training)

            input = input + out

            return input

        return input
Example #10
    def forward(self,
                input,
                pos_emb,
                attn_mask,
                incremental=False,
                incremental_cache=None,
                mems=None,
                src_lang=None):

        if incremental and incremental_cache is None:
            incremental_cache = dict()

        coin = True
        if self.training and self.death_rate > 0:
            coin = (torch.rand(1)[0].item() >= self.death_rate)

        if coin:
            if self.macaron:
                out = self.mcr_feedforward(self.preprocess_mcr_ffn(input),
                                           src_lang)

                if self.training and self.death_rate > 0:
                    out = out / (1 - self.death_rate)

                if not self.variational:
                    out = F.dropout(out,
                                    p=self.dropout,
                                    training=self.training)
                else:
                    out = variational_dropout(out,
                                              p=self.dropout,
                                              training=self.training)

                input = input + self.ffn_scale * out

            query = self.preprocess_attn(input)

            if self.mfw:
                out, _ = self.multihead(query,
                                        pos_emb,
                                        src_lang,
                                        attn_mask,
                                        None,
                                        mems=mems,
                                        incremental=incremental,
                                        incremental_cache=incremental_cache)
            else:
                out, _ = self.multihead(query,
                                        pos_emb,
                                        attn_mask,
                                        None,
                                        mems=mems,
                                        incremental=incremental,
                                        incremental_cache=incremental_cache)

            # rescaling before residual
            if self.training and self.death_rate > 0:
                out = out / (1 - self.death_rate)

            input = self.postprocess_attn(out, input)
            """ Feed forward layer 
                layernorm > ffn > dropout > residual
            """
            out = self.feedforward(self.preprocess_ffn(input), src_lang)

            # rescaling before residual
            if self.training and self.death_rate > 0:
                out = out / (1 - self.death_rate)

            if not self.variational:
                out = F.dropout(out, p=self.dropout, training=self.training)
            else:
                out = variational_dropout(out,
                                          p=self.dropout,
                                          training=self.training)
            input = input + self.ffn_scale * out

        if incremental:
            return input, incremental_cache

        return input