Ejemplo n.º 1
0
 def __init__(self, in_size=64, out_size=64, cond_size=None,
              kernel_size=3, dilation=1, padding=None,
              activation=None, norm=None, cond_norm=False):
   """Build two convolutions, a 1x1 projection, a ReZero gate and two norms.

   Args:
     in_size: number of input channels.
     out_size: number of output channels of both convolutions and of the
       1x1 projection.
     cond_size: if truthy, size of a conditioning vector; ``self.cond`` is
       then a linear layer mapping it to ``out_size``, otherwise ``None``.
     kernel_size: kernel size of both convolutions.
     dilation: dilation of both convolutions.
     padding: padding of both convolutions; when ``None`` it defaults to
       ``kernel_size // 2 * dilation``.
     activation: callable stored as ``self.activation``; defaults to
       ``func.relu``.
     norm: factory building one normalisation module per convolution input.
       It is called as ``norm(size)``, or as ``norm(size, cond_size)`` when
       ``cond_norm`` is set. It defaults to a factory returning
       ``nn.Identity()``.
     cond_norm: whether ``norm`` additionally receives ``cond_size``.
   """
   super().__init__()
   # `is None` rather than `or`, so that an explicit padding=0 is honoured.
   if padding is None:
     padding = kernel_size // 2 * dilation
   self.cond = None
   self.cond_norm = cond_norm
   # The default factory accepts any arguments, so it also works with
   # cond_norm=True, where it is called as norm(size, cond_size).
   norm = norm or (lambda *args, **kwargs: nn.Identity())
   self.activation = activation or func.relu
   if cond_size:
     self.cond = nn.Linear(cond_size, out_size)
   # Pass dilation and padding by keyword: the fourth positional argument of
   # nn.Conv2d is `stride`, so passing dilation positionally would stride
   # instead of dilate.
   self.blocks = nn.ModuleList([
     nn.Conv2d(in_size, out_size, kernel_size,
               dilation=dilation, padding=padding),
     nn.Conv2d(out_size, out_size, kernel_size,
               dilation=dilation, padding=padding)
   ])
   self.project = nn.Conv2d(in_size, out_size, 1, bias=False)
   self.zero = ReZero(out_size)
   norm_extra = (cond_size,) if cond_norm else ()
   self.norms = nn.ModuleList([
     norm(in_size, *norm_extra),
     norm(out_size, *norm_extra)
   ])
Ejemplo n.º 2
0
 def __init__(self, size, N=1, drop=None, kernel_size=15, dilation=1):
     """Build a bottleneck convolution stack and a ReZero gate.

     ``self.block`` is ``nn.Sequential`` of:
     conv(size -> size // 2, kernel_size, dilated), ReLU,
     1x1 conv(size // 2 -> size // 2), ReLU, [Dropout(drop)],
     conv(size // 2 -> size, kernel_size, dilated).

     Args:
       size: number of input and output channels.
       N: spatial dimensionality; selects ``nn.Conv{N}d``.
       drop: dropout probability; the Dropout layer is inserted only when
         ``drop`` is truthy.
       kernel_size: kernel size of the first and last convolutions.
       dilation: dilation of the first and last convolutions.
     """
     super().__init__()
     self.conv = getattr(nn, f"Conv{N}d")
     half = size // 2
     # With stride 1 and an odd kernel size, this padding keeps the length
     # unchanged.
     padding = (kernel_size // 2) * dilation
     # Layers are created in the same order in both configurations, so the
     # Sequential indices match whether or not Dropout is present.
     layers = [
         self.conv(size, half, kernel_size,
                   dilation=dilation, padding=padding), nn.ReLU(),
         self.conv(half, half, 1), nn.ReLU(),
     ]
     if drop:
         layers.append(nn.Dropout(drop))
     layers.append(self.conv(half, size, kernel_size,
                             dilation=dilation, padding=padding))
     self.block = nn.Sequential(*layers)
     self.zero = ReZero(size)
Ejemplo n.º 3
0
 def __init__(self, depth=4, level_repeat=2, scale=4, base=32, z=32):
     """Build a top-level stem plus ``depth * level_repeat`` per-level layers.

     Top level: ``first`` (z -> base), ``first_mean`` / ``first_logvar``
     (base -> z) and zero-initialised ``(1, z)`` factors for each.

     Per level (one entry each, ``depth * level_repeat`` levels):
     ``blocks`` (ResBlock(base, base, 3, depth=2)), ``modifiers`` (1x1 conv,
     z -> base, no bias), ``zeros`` (ReZero(base)), ``mean`` / ``logvar``
     (z_project(2 * base, z)) and zero-initialised ``(1, z, 1, 1)``
     ``mean_factor`` / ``logvar_factor`` parameters.

     Args:
       depth: multiplied by ``level_repeat`` to get the number of levels.
       level_repeat: stored as ``self.level_repeat``.
       scale: stored as ``self.scale``.
       base: hidden channel count.
       z: latent channel count.
     """
     super().__init__()
     n_levels = depth * level_repeat

     def zero_factor(*shape):
         # A learnable tensor initialised to zeros.
         return nn.Parameter(torch.zeros(*shape, requires_grad=True))

     self.first = nn.Linear(z, base)
     self.first_mean = nn.Linear(base, z)
     self.first_mean_factor = zero_factor(1, z)
     self.first_logvar = nn.Linear(base, z)
     self.first_logvar_factor = zero_factor(1, z)

     # Each list is built in full before the next. Module construction
     # order determines how the RNG is consumed by initialisation, so keep
     # this order.
     self.blocks = nn.ModuleList(
         ResBlock(base, base, 3, depth=2) for _ in range(n_levels))
     self.modifiers = nn.ModuleList(
         nn.Conv2d(z, base, 1, bias=False) for _ in range(n_levels))
     self.zeros = nn.ModuleList(ReZero(base) for _ in range(n_levels))
     self.mean = nn.ModuleList(
         z_project(2 * base, z) for _ in range(n_levels))
     self.logvar = nn.ModuleList(
         z_project(2 * base, z) for _ in range(n_levels))
     self.mean_factor = nn.ParameterList(
         zero_factor(1, z, 1, 1) for _ in range(n_levels))
     self.logvar_factor = nn.ParameterList(
         zero_factor(1, z, 1, 1) for _ in range(n_levels))
     self.level_repeat = level_repeat
     self.scale = scale
Ejemplo n.º 4
0
 def __init__(self, size=128, heads=8, attention=cross_attention):
     """Create a CrossAttention layer with ``2 * size`` outputs and a ReZero gate.

     Args:
       size: ``size`` argument of CrossAttention.
       heads: number of attention heads.
       attention: attention function forwarded to CrossAttention
         (defaults to ``cross_attention``).
     """
     super().__init__()
     doubled = 2 * size
     self.cross_attention = CrossAttention(
         size=size, out_size=doubled, heads=heads, attention=attention)
     self.zero = ReZero(doubled)
Ejemplo n.º 5
0
 def __init__(self, size=128, heads=8):
     """Create a CrossAttention layer, a 1x1 Conv1d output and a ReZero gate.

     The CrossAttention layer, the Conv1d (``size -> 2 * size``, no bias) and
     the ReZero gate all use ``2 * size`` output features.

     Args:
       size: input feature size.
       heads: number of attention heads; also stored as ``self.heads``.
     """
     super().__init__()
     self.heads = heads
     wide = 2 * size
     self.cross_attention = CrossAttention(size=size, out_size=wide,
                                           heads=heads)
     self.output = nn.Conv1d(size, wide, 1, bias=False)
     self.zero = ReZero(wide)
Ejemplo n.º 6
0
 def __init__(self, hole, in_size=64, out_size=64,
              hidden_size=64, kernel_size=3, dilation=1,
              padding=None, depth=3, downscale=2,
              cond_size=None, activation=None):
   """Build the "into" and "outof" convolution stacks around ``hole``.

   ``hole`` is passed unchanged to the parent constructor.

   "into" layers: 1x1 conv ``in_size -> hidden_size``, ``depth`` convolutions
   ``hidden_size -> hidden_size``, ``depth`` ReZero gates, and a 1x1 conv
   ``hidden_size -> out_size``.

   "outof" layers: 1x1 conv ``out_size + hidden_size -> hidden_size``
   (presumably the hole output concatenated with a skip; forward not
   visible, TODO confirm), ``depth`` convolutions, ``depth`` ReZero gates,
   and a 1x1 conv ``hidden_size -> in_size``.

   Args:
     hole: forwarded to ``super().__init__``.
     in_size: input channel count (also the final output channel count).
     out_size: output channels of the "into" path.
     hidden_size: channel count inside both stacks.
     kernel_size: kernel size of the stacked convolutions.
     dilation: dilation of the stacked convolutions.
     padding: padding of the stacked convolutions; when ``None`` it
       defaults to ``kernel_size // 2 * dilation``.
     depth: number of convolutions (and ReZero gates) per stack.
     downscale: stored as ``self.downscale``.
     cond_size: if truthy, ``self.cond`` maps a vector of this size to
       ``hidden_size``; otherwise ``self.cond`` is ``None``.
     activation: stored as ``self.activation``; defaults to ``func.relu``.
   """
   super().__init__(hole)
   # `is None` rather than `or`, so that an explicit padding=0 is honoured.
   if padding is None:
     padding = kernel_size // 2 * dilation
   self.activation = activation or func.relu
   self.cond = None
   if cond_size:
     self.cond = nn.Linear(cond_size, hidden_size)
   self.downscale = downscale

   def conv_stack():
     # Build `depth` hidden convolutions. `dilation` must reach them: the
     # default padding is sized for it, and without it the spatial size
     # would grow whenever dilation > 1.
     return nn.ModuleList([
       nn.Conv2d(hidden_size, hidden_size, kernel_size,
                 dilation=dilation, padding=padding)
       for _ in range(depth)
     ])

   def gate_stack():
     # Build `depth` ReZero gates over hidden_size channels.
     return nn.ModuleList([ReZero(hidden_size) for _ in range(depth)])

   # Construction order matches the original, so initialisation consumes
   # the RNG identically.
   self.into_preprocess = nn.Conv2d(in_size, hidden_size, 1)
   self.into_postprocess = nn.Conv2d(hidden_size, out_size, 1)
   self.into_blocks = conv_stack()
   self.into_zeros = gate_stack()
   self.outof_preprocess = nn.Conv2d(
     out_size + hidden_size, hidden_size, 1
   )
   self.outof_postprocess = nn.Conv2d(hidden_size, in_size, 1)
   self.outof_blocks = conv_stack()
   self.outof_zeros = gate_stack()
Ejemplo n.º 7
0
 def __init__(self, in_size, out_size, kernel_size, depth=1):
     """Build a bottleneck block at a quarter of the input width.

     Layers: 1x1 conv ``in_size -> in_size // 4`` (no bias), ``depth``
     convolutions at ``in_size // 4`` channels with padding
     ``kernel_size // 2``, a 1x1 conv ``in_size // 4 -> out_size``
     (no bias), and a ReZero gate with ``initial_value=0.1``.

     Args:
       in_size: input channel count.
       out_size: output channel count.
       kernel_size: kernel size of the inner convolutions.
       depth: number of inner convolutions.
     """
     super().__init__()
     narrow = in_size // 4
     self.project_in = nn.Conv2d(in_size, narrow, 1, bias=False)
     self.project_out = nn.Conv2d(narrow, out_size, 1, bias=False)
     inner = []
     for _ in range(depth):
         inner.append(
             nn.Conv2d(narrow, narrow, kernel_size,
                       padding=kernel_size // 2))
     self.blocks = nn.ModuleList(inner)
     self.zero = ReZero(out_size, initial_value=0.1)
Ejemplo n.º 8
0
    def __init__(self,
                 size,
                 n_heads=8,
                 hidden_size=128,
                 attention_size=128,
                 value_size=128,
                 depth=2,
                 dropout=0.1):
        """Build an attention + feed-forward layer with a ReZero gate.

        Args:
            size: feature size; input and output size of both the
                attention module and the MLP.
            n_heads: number of attention heads.
            hidden_size: hidden width of the MLP.
            attention_size: ``attention_size`` of the attention module.
            value_size: passed as ``hidden_size`` of the attention module.
            depth: ``depth`` argument passed to ``MLP``.
            dropout: probability of both dropout layers.
        """
        super().__init__()
        self.attention = SequenceMultiHeadAttention(
            size,
            size,
            attention_size=attention_size,
            hidden_size=value_size,
            heads=n_heads)
        # Forward the caller's `depth` rather than a hard-coded 2. The default
        # is still 2, so default behaviour is unchanged.
        self.ff = MLP(size,
                      size,
                      hidden_size=hidden_size,
                      depth=depth,
                      batch_norm=False)
        self.rezero = ReZero(size)

        self.dropout_1 = nn.Dropout(dropout)
        self.dropout_2 = nn.Dropout(dropout)
Ejemplo n.º 9
0
 def __init__(self, size=128, heads=8):
     """Create a size-preserving CrossAttention layer and a ReZero gate.

     Args:
       size: input and output feature size of the attention layer.
       heads: number of attention heads.
     """
     super().__init__()
     settings = dict(size=size, out_size=size, heads=heads)
     self.attention = CrossAttention(**settings)
     self.zero = ReZero(size)