Example #1
    def input_node_func(self, nodes):
        # Node UDF (DGL-style): project the node feature 'feat' with self.input, then apply SiLU.
        return {'feat': F.silu(self.input(nodes.data['feat']))}
Example #2
    def input_edge_func(self, edges):
        # Edge UDF (DGL-style): project the edge feature 'd' with self.edge_input, then apply SiLU.
        return {'d': F.silu(self.edge_input(edges.data['d']))}
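
Examples #1 and #2 follow DGL's user-defined-function signature (nodes.data[...] / edges.data[...]). Below is a minimal, self-contained sketch of how such a node UDF is typically applied with apply_nodes; the toy graph, feature sizes, and the standalone input_node_func are illustrative assumptions, not from the original source:

import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F

# A tiny 3-node directed graph with random 16-dim node features.
g = dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([1, 2, 0])))
g.ndata['feat'] = torch.randn(3, 16)

lin = nn.Linear(16, 32)

def input_node_func(nodes):
    # Same shape of UDF as above, written as a free function for this sketch.
    return {'feat': F.silu(lin(nodes.data['feat']))}

g.apply_nodes(input_node_func)
print(g.ndata['feat'].shape)  # torch.Size([3, 32])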
Example #3
        except TypeError:
            pass

        return x

    # Repeat the single value n_item times.
    return tuple(repeat(x, n_item))


tuple2 = lambda x: ensure_tuple(x, 2)


# Activations rescaled by constant gains so that unit-variance Gaussian inputs
# keep approximately unit variance at the output (NF-Net-style gains).
activations = {
    "identity": lambda x: x,
    "gelu": lambda x: F.gelu(x) * 1.7015043497085571,
    "relu": lambda x: F.relu(x) * 1.7139588594436646,
    "silu": lambda x: F.silu(x) * 1.7881293296813965,
}


# Thin nn.Module wrapper so a named, scaled activation can be used like any other layer.
class ScaledActivation(nn.Module):
    def __init__(self, activation):
        super().__init__()

        self.name = activation
        self.activation = activations[activation]

    def forward(self, input):
        return self.activation(input)

    def __repr__(self):
        return f"ScaledActivation({self.name})"
Example #4
    def forward(self, inputs, levels, *args):
        out = self.project_in(inputs)
        # Run the projected input through self.unet conditioned on levels,
        # apply SiLU, then project with self.out.
        unet = self.out(func.silu(self.unet(out, levels)))
        # Per-sample inner product between the U-Net output and the original input.
        quadratic = torch.einsum("bchw,bchw->b", unet, inputs)[:, None]
        time = self.time(levels)
        # Scale by softplus(factor(time)) (always positive) and add the offset Z(time).
        return func.softplus(self.factor(time)) * quadratic + self.Z(time)
Example #5
    def forward(self,
                input,
                pos,
                indices=None,
                key_padding_mask=None,
                attn_mask=None,
                mems=None,
                incremental=False,
                incremental_cache=None):

        # Select the per-index rank-1 factors (r_*, s_*) used for the additive weight update.
        if indices.size(0) == 1 and len(indices.shape) == 1:
            r_i = torch.index_select(self.r_i, 0, indices).squeeze(0)
            s_i = torch.index_select(self.s_i, 0, indices).squeeze(0)
            r_p = torch.index_select(self.r_p, 0, indices).squeeze(0)
            s_p = torch.index_select(self.s_p, 0, indices).squeeze(0)
            r_o = torch.index_select(self.r_o, 0, indices).squeeze(0)
            s_o = torch.index_select(self.s_o, 0, indices).squeeze(0)
        else:
            print(indices.size(), input.size())
            raise NotImplementedError

        # weight dropout
        in_proj_weight = F.dropout(self.in_proj_weight,
                                   p=self.dropout,
                                   training=self.training)
        pos_proj_weight = F.dropout(self.pos_proj_weight,
                                    p=self.dropout,
                                    training=self.training)
        out_proj_weight = F.dropout(self.out_proj_weight,
                                    p=self.dropout,
                                    training=self.training)

        if self.use_multiplicative:
            # Multiplicative factorization: scale each projection weight by a sum
            # of rank-1 outer products built from the (rm_*, sm_*) factors.
            rm_i = torch.index_select(self.rm_i, 0, indices).squeeze(0)
            sm_i = torch.index_select(self.sm_i, 0, indices).squeeze(0)
            rm_p = torch.index_select(self.rm_p, 0, indices).squeeze(0)
            sm_p = torch.index_select(self.sm_p, 0, indices).squeeze(0)
            rm_o = torch.index_select(self.rm_o, 0, indices).squeeze(0)
            sm_o = torch.index_select(self.sm_o, 0, indices).squeeze(0)
            # print(rm_i, sm_i)

            in_scale = torch.bmm(rm_i.unsqueeze(-1),
                                 sm_i.unsqueeze(1)).sum(dim=0)
            in_proj_weight = in_proj_weight * in_scale
            pos_proj_weight = pos_proj_weight * torch.bmm(
                rm_p.unsqueeze(-1), sm_p.unsqueeze(1)).sum(dim=0)
            out_proj_weight = out_proj_weight * torch.bmm(
                rm_o.unsqueeze(-1), sm_o.unsqueeze(1)).sum(dim=0)

        # Additive factorization: add a sum of rank-1 outer products (r, s)
        # to each projection weight.
        in_proj_weight = in_proj_weight + torch.bmm(
            r_i.unsqueeze(-1), s_i.unsqueeze(1)).sum(dim=0)
        pos_proj_weight = pos_proj_weight + torch.bmm(
            r_p.unsqueeze(-1), s_p.unsqueeze(1)).sum(dim=0)
        out_proj_weight = out_proj_weight + torch.bmm(
            r_o.unsqueeze(-1), s_o.unsqueeze(1)).sum(dim=0)

        # Optional nonlinearity applied to the generated projection weights.
        if self.mfw_activation == "none":
            in_proj_weight = in_proj_weight
        elif self.mfw_activation == "gelu":
            in_proj_weight = F.gelu(in_proj_weight)
            pos_proj_weight = F.gelu(pos_proj_weight)
            out_proj_weight = F.gelu(out_proj_weight)
        elif self.mfw_activation == "silu":
            in_proj_weight = F.silu(in_proj_weight)
            pos_proj_weight = F.silu(pos_proj_weight)
            out_proj_weight = F.silu(out_proj_weight)
        else:
            raise NotImplementedError

        if key_padding_mask is not None:
            assert (
                attn_mask is None
            ), "ERROR attn_mask and key_padding_mask should not be both defined!"
            mask = key_padding_mask
            if len(mask.shape) == 3:
                mask = mask.squeeze(0).transpose(0, 1)
        elif attn_mask is not None:
            mask = attn_mask
            if len(mask.shape) == 3:
                mask = mask.squeeze(-1)
        else:
            mask = None

        is_training = self.training

        outputs, coverage = self.attn_func(
            input, pos, attn_mask is not None, is_training, self.num_heads,
            in_proj_weight.t(), out_proj_weight.t(), pos_proj_weight.t(),
            self.in_proj_bias, self.out_proj_bias, self.pos_proj_bias,
            self.r_w_bias, self.r_r_bias, mask, self.dropout, incremental,
            incremental_cache, False, False)
        # last False is double precision

        return outputs, coverage
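
The core trick in the factorized weights above is torch.bmm(r.unsqueeze(-1), s.unsqueeze(1)).sum(dim=0), which builds each weight update as a sum of rank-1 outer products. A small standalone check of that identity (shapes are illustrative, not from the original source):

import torch

torch.manual_seed(0)
n_factors, d_in, d_out = 4, 8, 16
r = torch.randn(n_factors, d_in)
s = torch.randn(n_factors, d_out)

# bmm of (k, d_in, 1) with (k, 1, d_out) yields k outer products; summing over k
# gives the same matrix as an explicit sum of rank-1 terms.
update = torch.bmm(r.unsqueeze(-1), s.unsqueeze(1)).sum(dim=0)
reference = sum(torch.outer(r[k], s[k]) for k in range(n_factors))
print(torch.allclose(update, reference))  # True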
Example #6
def silu_nf(x):
    # SiLU scaled so that unit-variance Gaussian inputs keep roughly unit variance.
    return F.silu(x) * 1.7881293296813965
Example #7
    def forward(self, x):
        # silu(alpha * x) / alpha equals x * sigmoid(alpha * x), i.e. a Swish with a
        # trainable slope; the 1e-10 term only guards against division by zero.
        return F.silu(self.alpha * x, inplace=True) / (self.alpha + 1e-10)
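
The excerpt does not show where self.alpha comes from; below is a minimal self-contained sketch, assuming it is a single learnable scalar (the class name TrainableSwish and the initial value are illustrative guesses, not from the original source):

import torch
import torch.nn as nn
import torch.nn.functional as F


class TrainableSwish(nn.Module):  # hypothetical name, for illustration only
    def __init__(self, alpha_init: float = 1.0):
        super().__init__()
        # Assumed: one learnable slope shared across all channels.
        self.alpha = nn.Parameter(torch.tensor(alpha_init))

    def forward(self, x):
        return F.silu(self.alpha * x) / (self.alpha + 1e-10)


x = torch.randn(2, 3)
print(TrainableSwish()(x).shape)  # torch.Size([2, 3])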