Example #1
# Pre-norm residual transformer block: FactoredAttention plus an MLP, with optional
# gradient checkpointing and a scaled residual connection.
class ResAttnBlock(nn.Module):
    def __init__(self, n_in, n_ctx, n_head,
                 attn_dropout=0.0, resid_dropout=0.0,
                 afn='quick_gelu', scale=True, mask=False,
                 zero_out=False, init_scale=1.0, res_scale=1.0,
                 m_attn=0.25, m_mlp=1.0,
                 checkpoint_attn=0, checkpoint_mlp=0,
                 attn_func=0, blocks=None, spread=None,
                 encoder_dims=None, prime_len=None):
        super().__init__()
        self.attn = FactoredAttention(n_in=n_in, n_ctx=n_ctx, n_state=int(m_attn * n_in), n_head=n_head,
                                      attn_dropout=attn_dropout, resid_dropout=resid_dropout,
                                      scale=scale, mask=mask,
                                      zero_out=zero_out, init_scale=init_scale,
                                      checkpoint_attn=checkpoint_attn,
                                      attn_func=attn_func, blocks=blocks, spread=spread,
                                      encoder_dims=encoder_dims, prime_len=prime_len)
        self.ln_0 = LayerNorm(n_in)
        self.mlp = MLP(n_in=n_in, n_state=int(m_mlp * n_in),
                       resid_dropout=resid_dropout,
                       afn=afn,
                       zero_out=zero_out, init_scale=init_scale)
        self.ln_1 = LayerNorm(n_in)
        self.res_scale = res_scale

        self.checkpoint_attn = checkpoint_attn
        self.checkpoint_mlp = checkpoint_mlp
        self.n_in = n_in
        self.attn_func = attn_func

    def forward(self, x, encoder_kv, sample=False):
        if sample:
            # Sampling path: run attention and MLP directly, with no gradient checkpointing.
            a = self.attn(self.ln_0(x), encoder_kv, sample)
            m = self.mlp(self.ln_1(x + a))
        else:
            if self.attn_func == 6:
                assert encoder_kv is not None
                a = checkpoint(lambda _x,_enc_kv,_s=sample: self.attn(self.ln_0(_x),_enc_kv,_s),
                               (x,encoder_kv),
                               (*self.attn.parameters(), *self.ln_0.parameters()),
                               self.checkpoint_attn == 3)  # 2 recomputes after the projections, and 1 recomputes after head splitting.
            else:
                assert encoder_kv is None
                a = checkpoint(lambda _x,_enc_kv=None,_s=sample: self.attn(self.ln_0(_x),_enc_kv,_s),
                               (x,),
                               (*self.attn.parameters(), *self.ln_0.parameters()),
                               self.checkpoint_attn == 3)  # 2 recomputes after the projections, and 1 recomputes after head splitting.
            m = checkpoint(lambda _x: self.mlp(self.ln_1(_x)), (x + a,),
                           (*self.mlp.parameters(), *self.ln_1.parameters()),
                           self.checkpoint_mlp == 1)
        # Residual connection; the block's contribution is optionally scaled by res_scale.
        if self.res_scale == 1.0:
            h = x + a + m
        else:
            h = x + self.res_scale * (a + m)
        return h
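
To make the data flow above concrete, here is a stripped-down PyTorch sketch of the same pre-norm residual pattern. It is a minimal stand-in, not the jukebox implementation: nn.MultiheadAttention and a two-layer MLP replace FactoredAttention and jukebox's MLP, checkpointing is omitted, and all sizes are illustrative.

import torch
import torch.nn as nn

class TinyResAttnBlock(nn.Module):
    # h = x + res_scale * (a + m), where a = attn(ln_0(x)) and m = mlp(ln_1(x + a))
    def __init__(self, n_in, n_head, m_mlp=1.0, res_scale=1.0):
        super().__init__()
        self.ln_0 = nn.LayerNorm(n_in)
        self.attn = nn.MultiheadAttention(n_in, n_head, batch_first=True)
        self.ln_1 = nn.LayerNorm(n_in)
        self.mlp = nn.Sequential(nn.Linear(n_in, int(m_mlp * n_in)),
                                 nn.GELU(),
                                 nn.Linear(int(m_mlp * n_in), n_in))
        self.res_scale = res_scale

    def forward(self, x):
        u = self.ln_0(x)
        a, _ = self.attn(u, u, u, need_weights=False)
        m = self.mlp(self.ln_1(x + a))
        return x + self.res_scale * (a + m)

x = torch.randn(2, 16, 64)                       # (batch, n_ctx, n_in)
print(TinyResAttnBlock(64, n_head=4)(x).shape)   # torch.Size([2, 16, 64])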
Example #2
# Embeds upper-level VQ codes and upsamples them to the current level's time
# resolution; in the jukebox codebase this is the Conditioner module.
class Conditioner(nn.Module):
    def __init__(self, input_shape, bins, down_t, stride_t, out_width, init_scale, zero_out, res_scale, **block_kwargs):
        super().__init__()
        self.x_shape = input_shape

        # Embedding
        self.width = out_width
        self.x_emb = nn.Embedding(bins, out_width)
        nn.init.normal_(self.x_emb.weight, std=0.02 * init_scale)

        # Conditioner: DecoderConvBock upsamples the embedded codes in time by stride_t ** down_t
        self.cond = DecoderConvBock(self.width, self.width, down_t, stride_t, **block_kwargs, zero_out=zero_out, res_scale=res_scale)
        self.ln = LayerNorm(self.width)
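
A minimal stand-in sketch of the shape bookkeeping this constructor sets up: a single nn.ConvTranspose1d plays the role of DecoderConvBock, and bins, width, down_t, and stride_t are illustrative values, not jukebox's configuration.

import torch
import torch.nn as nn

bins, width, down_t, stride_t = 512, 64, 1, 4     # illustrative sizes
emb = nn.Embedding(bins, width)
# DecoderConvBock upsamples time by a factor of stride_t ** down_t; one transposed
# conv with that stride mimics the overall shape change.
up = nn.ConvTranspose1d(width, width, kernel_size=stride_t ** down_t, stride=stride_t ** down_t)
ln = nn.LayerNorm(width)

tokens = torch.randint(0, bins, (2, 100))         # upper-level codes: (batch, T)
x = emb(tokens).transpose(1, 2)                   # (batch, width, T)
x = up(x).transpose(1, 2)                         # (batch, T * 4, width)
print(ln(x).shape)                                # torch.Size([2, 400, 64])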
Example #3
# Autoregressive prior over one VQ level, optionally conditioned on the level
# above, on labels, and on lyrics; in the jukebox codebase this is SimplePrior.
class SimplePrior(nn.Module):
    def __init__(self,
                 z_shapes,
                 l_bins,
                 encoder,
                 decoder,
                 level,
                 downs_t,
                 strides_t,
                 labels,
                 prior_kwargs,
                 x_cond_kwargs,
                 y_cond_kwargs,
                 prime_kwargs,
                 copy_input,
                 labels_v3=False,
                 merged_decoder=False,
                 single_enc_dec=False):
        super().__init__()

        self.use_tokens = prime_kwargs.pop('use_tokens')
        self.n_tokens = prime_kwargs.pop('n_tokens')
        self.prime_loss_fraction = prime_kwargs.pop('prime_loss_fraction')

        self.copy_input = copy_input
        if self.copy_input:
            prime_kwargs['bins'] = l_bins

        self.z_shapes = z_shapes
        self.levels = len(self.z_shapes)

        self.z_shape = self.z_shapes[level]

        self.level = level
        assert level < self.levels, f"Total levels {self.levels}, got level {level}"

        self.l_bins = l_bins

        # Passing functions instead of the vqvae module to avoid getting params
        self.encoder = encoder
        self.decoder = decoder

        # X conditioning
        self.x_cond = (level != (self.levels - 1))
        self.cond_level = level + 1

        # Y conditioning
        self.y_cond = labels

        self.single_enc_dec = single_enc_dec
        # X conditioning
        if self.x_cond:
            self.conditioner_blocks = nn.ModuleList()
            conditioner_block = lambda _level: Conditioner(
                input_shape=z_shapes[_level],
                bins=l_bins,
                down_t=downs_t[_level],
                stride_t=strides_t[_level],
                **x_cond_kwargs)
            if dist.get_rank() == 0: print("Conditioning on 1 above level(s)")
            self.conditioner_blocks.append(conditioner_block(self.cond_level))

        # Y conditioning
        if self.y_cond:
            self.n_time = self.z_shape[0]  # Assuming STFT=TF order and raw=T1 order, so T is first dim
            self.y_emb = LabelConditioner(n_time=self.n_time,
                                          include_time_signal=not self.x_cond,
                                          **y_cond_kwargs)

        # Lyric conditioning
        if single_enc_dec:
            # Single encoder-decoder transformer
            self.prior_shapes = [(self.n_tokens,), prior_kwargs.pop('input_shape')]
            self.prior_bins = [prime_kwargs['bins'], prior_kwargs.pop('bins')]
            self.prior_dims = [np.prod(shape) for shape in self.prior_shapes]
            self.prior_bins_shift = np.cumsum([0, *self.prior_bins])[:-1]
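            # Example with illustrative sizes: with a 79-token lyric vocab and 2048 VQ codes,
            # prior_bins == [79, 2048] and prior_bins_shift == [0, 79], so lyric tokens keep
            # their ids and music tokens are shifted by 79 into one merged vocabulary of 2127.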
            self.prior_width = prior_kwargs['width']
            print_once(f'Creating cond. autoregress with prior bins {self.prior_bins}, ')
            print_once(f'dims {self.prior_dims}, ')
            print_once(f'shift {self.prior_bins_shift}')
            print_once(f'input shape {sum(self.prior_dims)}')
            print_once(f'input bins {sum(self.prior_bins)}')
            print_once(f'Self copy is {self.copy_input}')

            self.prime_loss_dims, self.gen_loss_dims = self.prior_dims[0], self.prior_dims[1]
            self.total_loss_dims = self.prime_loss_dims + self.gen_loss_dims
            self.prior = ConditionalAutoregressive2D(
                input_shape=(sum(self.prior_dims), ),
                bins=sum(self.prior_bins),
                x_cond=(self.x_cond or self.y_cond),
                y_cond=True,
                prime_len=self.prime_loss_dims,
                **prior_kwargs)

        else:
            # Separate encoder-decoder transformer
            if self.n_tokens != 0 and self.use_tokens:
                from jukebox.transformer.ops import Conv1D
                prime_input_shape = (self.n_tokens, )
                self.prime_loss_dims = np.prod(prime_input_shape)
                self.prime_acts_width, self.prime_state_width = prime_kwargs['width'], prior_kwargs['width']
                self.prime_prior = ConditionalAutoregressive2D(
                    input_shape=prime_input_shape,
                    x_cond=False,
                    y_cond=False,
                    only_encode=True,
                    **prime_kwargs)
                self.prime_state_proj = Conv1D(
                    self.prime_acts_width,
                    self.prime_state_width,
                    init_scale=prime_kwargs['init_scale'])
                self.prime_state_ln = LayerNorm(self.prime_state_width)
                self.prime_bins = prime_kwargs['bins']
                self.prime_x_out = nn.Linear(self.prime_state_width,
                                             self.prime_bins,
                                             bias=False)
                nn.init.normal_(self.prime_x_out.weight,
                                std=0.02 * prior_kwargs['init_scale'])
            else:
                self.prime_loss_dims = 0
            self.gen_loss_dims = np.prod(self.z_shape)
            self.total_loss_dims = self.prime_loss_dims + self.gen_loss_dims
            self.prior = ConditionalAutoregressive2D(
                x_cond=(self.x_cond or self.y_cond),
                y_cond=self.y_cond,
                encoder_dims=self.prime_loss_dims,
                merged_decoder=merged_decoder,
                **prior_kwargs)

        self.n_ctx = self.gen_loss_dims
        self.downsamples = calculate_strides(strides_t, downs_t)
        self.cond_downsample = self.downsamples[level + 1] if level != self.levels - 1 else None
        self.raw_to_tokens = np.prod(self.downsamples[:level + 1])
        self.sample_length = self.n_ctx * self.raw_to_tokens
        if labels:
            self.labels_v3 = labels_v3
            self.labeller = Labeller(self.y_emb.max_bow_genre_size,
                                     self.n_tokens,
                                     self.sample_length,
                                     v3=self.labels_v3)

        print(f"Level:{level}, Cond downsample:{self.cond_downsample}, Raw to tokens:{self.raw_to_tokens}, Sample length:{self.sample_length}")
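
The bookkeeping at the end links the prior's context length to raw audio: each level contributes an upsampling factor of stride_t ** down_t, raw_to_tokens is the product of those factors up to and including the current level, and sample_length = n_ctx * raw_to_tokens. A small numeric sketch of that arithmetic, with illustrative hop factors (not necessarily the shipped jukebox configuration) and assuming, as the code above suggests, that calculate_strides returns the per-level factors stride ** down:

import numpy as np

strides_t, downs_t = (2, 2, 2), (3, 2, 2)                   # illustrative per-level strides/downsamples
downsamples = [s ** d for s, d in zip(strides_t, downs_t)]  # per-level hop factors: [8, 4, 4]

level, n_ctx = 1, 8192
raw_to_tokens = int(np.prod(downsamples[:level + 1]))       # 8 * 4 = 32 raw samples per VQ token
sample_length = n_ctx * raw_to_tokens                       # 8192 * 32 = 262144 raw samples
print(downsamples, raw_to_tokens, sample_length)            # [8, 4, 4] 32 262144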