Code Example #1
    def ancestral_sample(self, n_samples, z_conds=None, y=None):
        z = t.zeros((n_samples, self.n_ctx), dtype=t.long, device='cuda') + \
            t.arange(0, self.n_ctx, dtype=t.long, device='cuda').view(1, self.n_ctx)

        if z_conds is not None:
            z_cond = z_conds[0]
            assert_shape(z_cond, (n_samples, self.n_ctx // 4))
            assert (z // 4 == repeat(z_cond, 4, 1)).all(), \
                f'z: {z}, z_cond: {z_cond}, diff: {(z // 4) - repeat(z_cond, 4, 1)}'
        return z
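
Every example in this listing checks intermediate tensors with an assert_shape helper that is not part of the excerpts. Below is a minimal sketch of such a helper, assuming it simply compares the tensor's shape against the expected tuple; the repository's actual implementation may differ.

# Hypothetical sketch of the assert_shape helper used throughout these examples.
def assert_shape(x, exp_shape):
    assert x.shape == exp_shape, f"Expected {exp_shape}, got {tuple(x.shape)}"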
Code Example #2
    def primed_sample(self, n_samples, z, z_conds=None, y=None):
        prime = z.shape[1]
        assert_shape(z, (n_samples, prime))
        start = z[:, -1:] + 1
        z_rest = (t.arange(0, self.n_ctx - prime, dtype=t.long,
                           device='cuda').view(1, self.n_ctx - prime) +
                  start).view(n_samples, self.n_ctx - prime)
        z = t.cat([z, z_rest], dim=1)

        if z_conds is not None:
            z_cond = z_conds[0]
            assert_shape(z_cond, (n_samples, self.n_ctx // 4))
            assert (z // 4 == repeat(z_cond, 4, 1)).all(), \
                f'z: {z}, z_cond: {z_cond}, diff: {(z // 4) - repeat(z_cond, 4, 1)}'
        return z
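
Examples #1 and #2 also rely on a repeat(x, n, dim) helper to upsample the conditioning codes before comparing them with z. It is not shown here; the check (z // 4 == repeat(z_cond, 4, 1)).all() implies that each entry is repeated n times along the given dimension, so a minimal sketch (an assumption, not the repository's implementation) is:

import torch as t

# Assumed behaviour: repeat each entry n times along dim,
# e.g. repeat(t.tensor([[0, 1, 2]]), 2, 1) -> [[0, 0, 1, 1, 2, 2]].
def repeat(x, n, dim):
    return t.repeat_interleave(x, n, dim=dim)

With z = t.arange(16).view(1, 16) and z_cond = t.arange(4).view(1, 4), both z // 4 and repeat(z_cond, 4, 1) give [[0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3]], which is exactly what the asserts above verify.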
Code Example #3
    def forward(self, x):
        N, T = x.shape[0], x.shape[-1]
        emb = self.input_emb_width
        assert_shape(x, (N, emb, T))
        xs = []

        # 64, 32, ...
        iterator = zip(list(range(self.levels)), self.downs_t, self.strides_t)
        for level, down_t, stride_t in iterator:
            level_block = self.level_blocks[level]
            x = level_block(x)
            emb, T = self.output_emb_width, T // (stride_t ** down_t)
            assert_shape(x, (N, emb, T))
            xs.append(x)

        return xs
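
Example #3 is the encoder's forward pass: it returns one latent sequence per level, and the temporal length shrinks by stride_t ** down_t at each level. The arithmetic below is purely illustrative, using assumed hyperparameters (strides_t = (2, 2, 2), downs_t = (3, 2, 2)) rather than values taken from any particular config:

# Illustrative only: per-level lengths under assumed hyperparameters.
T = 8192
strides_t, downs_t = (2, 2, 2), (3, 2, 2)
lengths = []
for stride_t, down_t in zip(strides_t, downs_t):
    T = T // (stride_t ** down_t)   # same update as in Example #3
    lengths.append(T)
print(lengths)  # [1024, 256, 64]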
Code Example #4
    def get_encoder_kv(self, prime, fp16=False, sample=False):
        if self.n_tokens != 0 and self.use_tokens:
            if sample:
                self.prime_prior.cuda()
            N = prime.shape[0]
            prime_acts = self.prime_prior(prime, None, None, None, fp16=fp16)
            assert_shape(prime_acts,
                         (N, self.prime_loss_dims, self.prime_acts_width))
            assert prime_acts.dtype == t.float, f'Expected t.float, got {prime_acts.dtype}'
            encoder_kv = self.prime_state_ln(self.prime_state_proj(prime_acts))
            assert encoder_kv.dtype == t.float, f'Expected t.float, got {encoder_kv.dtype}'
            if sample:
                self.prime_prior.cpu()
                if fp16:
                    encoder_kv = encoder_kv.half()
        else:
            encoder_kv = None
        return encoder_kv
Code Example #5
    def forward(self, xs, all_levels=True):
        if all_levels:
            assert len(xs) == self.levels
        else:
            assert len(xs) == 1
        x = xs[-1]
        N, T = x.shape[0], x.shape[-1]
        emb = self.output_emb_width
        assert_shape(x, (N, emb, T))

        # 32, 64 ...
        iterator = reversed(list(zip(list(range(self.levels)), self.downs_t, self.strides_t)))
        for level, down_t, stride_t in iterator:
            level_block = self.level_blocks[level]
            x = level_block(x)
            emb, T = self.output_emb_width, T * (stride_t ** down_t)
            assert_shape(x, (N, emb, T))
            if level != 0 and all_levels:
                x = x + xs[level - 1]

        x = self.out(x)
        return x
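
Example #5 is the decoder counterpart of Example #3: it walks the levels in reverse, multiplies the length back up by stride_t ** down_t, and adds the skip input xs[level - 1] when all levels are provided. Under the same assumed hyperparameters as in the sketch after Example #3, the length bookkeeping reverses exactly:

# Illustrative only: undo the downsampling from the encoder sketch above.
T = 64
strides_t, downs_t = (2, 2, 2), (3, 2, 2)
for stride_t, down_t in reversed(list(zip(strides_t, downs_t))):
    T = T * (stride_t ** down_t)    # same update as in Example #5
print(T)  # 8192, back to the original input length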
Code Example #6
    def prior_preprocess(self, xs, conds):
        N = xs[0].shape[0]
        for i in range(len(xs)):
            x, shape, dims = xs[i], self.prior_shapes[i], self.prior_dims[i]
            bins, bins_shift = int(self.prior_bins[i]), int(self.prior_bins_shift[i])
            assert isinstance(x, t.cuda.LongTensor), x
            assert (0 <= x).all() and (x < bins).all()
            #assert_shape(x, (N, *shape))
            xs[i] = (xs[i] + bins_shift).view(N, -1)

        for i in range(len(conds)):
            cond, shape, dims = conds[i], self.prior_shapes[i], self.prior_dims[i]
            if cond is not None:
                assert_shape(cond, (N, dims, self.prior_width))
            else:
                conds[i] = t.zeros((N, dims, self.prior_width),
                                   dtype=t.float,
                                   device='cuda')

        return t.cat(xs, dim=1), t.cat(conds, dim=1)
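
Example #6 merges several token streams into one by shifting each source into its own range of a shared vocabulary (bins_shift) before concatenating along the time axis. The toy example below only illustrates that idea; the bin sizes and shifts are made up, whereas the real values come from prior_bins and prior_bins_shift:

import torch as t

# Toy illustration of bins_shift (values are made up).
# Source 0 uses ids 0..79, source 1 uses ids 0..2047; shifting source 1 by 80
# makes the two id ranges disjoint in the merged stream.
prior_bins_shift = [0, 80]
xs = [t.tensor([[3, 17, 42]]), t.tensor([[0, 5, 2047]])]
shifted = [x + shift for x, shift in zip(xs, prior_bins_shift)]
merged = t.cat(shifted, dim=1)
print(merged)  # tensor([[   3,   17,   42,   80,   85, 2127]])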
Code Example #7
    def forward(self, y):
        assert len(y.shape) == 2, f"Expected shape with 2 dims, got {y.shape}"
        assert y.shape[-1] == 4 + self.max_bow_genre_size, f"Expected shape (N,{4 + self.max_bow_genre_size}), got {y.shape}"
        assert isinstance(y, t.cuda.LongTensor), f"Expected dtype {t.cuda.LongTensor}, got {y.dtype}"
        N = y.shape[0]
        total_length, offset, length, artist, genre = y[:,0:1], y[:,1:2], y[:,2:3], y[:,3:4], y[:,4:]

        # Start embedding of length 1
        artist_emb = self.artist_emb(artist)
        # Empty genre slots are denoted by -1. We mask these out.
        mask = (genre >= 0).float().unsqueeze(2)
        genre_emb = (self.bow_genre_emb(genre.clamp(0)) * mask).sum(dim=1, keepdim=True)
        start_emb = genre_emb + artist_emb
        assert_shape(start_emb, (N, 1, self.out_width))

        # Pos embedding of length n_ctx
        if self.include_time_signal:
            start, end = offset, offset + length
            total_length, start, end = total_length.float(), start.float(), end.float()
            pos_emb = self.total_length_emb(total_length) + self.absolute_pos_emb(start, end) + self.relative_pos_emb(start/total_length, end/total_length)
            assert_shape(pos_emb, (N, self.n_time, self.out_width))
        else:
            pos_emb = None
        return start_emb, pos_emb
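
Example #7 unpacks the label tensor y column-wise: total_length, offset, length, artist id, and then max_bow_genre_size genre ids, with -1 marking empty genre slots. A hypothetical row under that layout (all values made up; the real forward pass expects a CUDA LongTensor):

import torch as t

# Hypothetical label row matching the layout unpacked in Example #7.
# Columns: total_length, offset, length, artist_id, genre ids padded with -1.
max_bow_genre_size = 4
y = t.tensor([[7_000_000, 1_500_000, 1_048_576, 42, 7, 13, -1, -1]])
total_length, offset, length, artist, genre = y[:, 0:1], y[:, 1:2], y[:, 2:3], y[:, 3:4], y[:, 4:]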
Code Example #8
    def forward(self, x, x_cond=None):
        N = x.shape[0]
        assert_shape(x, (N, *self.x_shape))
        if x_cond is not None:
            assert_shape(x_cond, (N, *self.x_shape, self.width))
        else:
            x_cond = 0.0
        # Embed x
        x = x.long()
        x = self.x_emb(x)
        assert_shape(x, (N, *self.x_shape, self.width))
        x = x + x_cond

        # Run conditioner
        x = self.preprocess(x)
        x = self.cond(x)
        x = self.postprocess(x)
        x = self.ln(x)
        return x
Code Example #9
    def sample(self,
               n_samples,
               z=None,
               z_conds=None,
               y=None,
               fp16=False,
               temp=1.0,
               top_k=0,
               top_p=0.0,
               chunk_size=None,
               sample_tokens=None):
        N = n_samples
        if z is not None:
            assert z.shape[0] == N, f"Expected shape ({N},**), got shape {z.shape}"
        if y is not None:
            assert y.shape[0] == N, f"Expected shape ({N},**), got shape {y.shape}"
        if z_conds is not None:
            for z_cond in z_conds:
                assert z_cond.shape[0] == N, f"Expected shape ({N},**), got shape {z_cond.shape}"

        no_past_context = (z is None or z.shape[1] == 0)
        if dist.get_rank() == 0:
            name = {True: 'Ancestral', False: 'Primed'}[no_past_context]
            print(
                f"{name} sampling {n_samples} samples with temp={temp}, top_k={top_k}, top_p={top_p}"
            )

        with t.no_grad():
            # Currently x_cond only uses immediately above layer
            x_cond, y_cond, prime = self.get_cond(z_conds, y)
            if self.single_enc_dec:
                # assert chunk_size % self.prime_loss_dims == 0. TODO: Check if needed
                if no_past_context:
                    z, x_cond = self.prior_preprocess([prime], [None, x_cond])
                else:
                    z, x_cond = self.prior_preprocess([prime, z],
                                                      [None, x_cond])
                if sample_tokens is not None:
                    sample_tokens += self.n_tokens
                z = self.prior.primed_sample(n_samples,
                                             z,
                                             x_cond,
                                             y_cond,
                                             fp16=fp16,
                                             temp=temp,
                                             top_k=top_k,
                                             top_p=top_p,
                                             chunk_size=chunk_size,
                                             sample_tokens=sample_tokens)
                z = self.prior_postprocess(z)
            else:
                encoder_kv = self.get_encoder_kv(prime, fp16=fp16, sample=True)
                if no_past_context:
                    z = self.prior.sample(n_samples,
                                          x_cond,
                                          y_cond,
                                          encoder_kv,
                                          fp16=fp16,
                                          temp=temp,
                                          top_k=top_k,
                                          top_p=top_p,
                                          sample_tokens=sample_tokens)
                else:
                    z = self.prior.primed_sample(n_samples,
                                                 z,
                                                 x_cond,
                                                 y_cond,
                                                 encoder_kv,
                                                 fp16=fp16,
                                                 temp=temp,
                                                 top_k=top_k,
                                                 top_p=top_p,
                                                 chunk_size=chunk_size,
                                                 sample_tokens=sample_tokens)
            if sample_tokens is None:
                assert_shape(z, (N, *self.z_shape))
        return z
Code Example #10
def get_alignment(x, zs, labels, prior, fp16, hps):
    level = hps.levels - 1 # Top level used
    n_ctx, n_tokens = prior.n_ctx, prior.n_tokens
    z = zs[level]
    bs, total_length = z.shape[0], z.shape[1]
    if total_length < n_ctx:
        padding_length = n_ctx - total_length
        z = t.cat([z, t.zeros(bs, n_ctx - total_length, dtype=z.dtype, device=z.device)], dim=1)
        total_length = z.shape[1]
    else:
        padding_length = 0

    hop_length = int(hps.hop_fraction[level]*prior.n_ctx)
    n_head = prior.prior.transformer.n_head
    alignment_head, alignment_layer = prior.alignment_head, prior.alignment_layer
    attn_layers = set([alignment_layer])
    alignment_hops = {}
    indices_hops = {}

    prior.cuda()
    empty_cache()
    for start in get_starts(total_length, n_ctx, hop_length):
        end = start + n_ctx

        # set y offset, sample_length and lyrics tokens
        y, indices_hop = prior.get_y(labels, start, get_indices=True)
        assert len(indices_hop) == bs
        for indices in indices_hop:
            assert len(indices) == n_tokens

        z_bs = t.chunk(z, bs, dim=0)
        y_bs = t.chunk(y, bs, dim=0)
        w_hops = []
        for z_i, y_i in zip(z_bs, y_bs):
            w_hop = prior.z_forward(z_i[:,start:end], [], y_i, fp16=fp16, get_attn_weights=attn_layers)
            assert len(w_hop) == 1
            w_hops.append(w_hop[0][:, alignment_head])
            del w_hop
        w = t.cat(w_hops, dim=0)
        del w_hops
        assert_shape(w, (bs, n_ctx, n_tokens))
        alignment_hop = w.float().cpu().numpy()
        assert_shape(alignment_hop, (bs, n_ctx, n_tokens))
        del w

        # alignment_hop has shape (bs, n_ctx, n_tokens)
        # indices_hop is a list of len=bs, each entry of len hps.n_tokens
        indices_hops[start] = indices_hop
        alignment_hops[start] = alignment_hop
    prior.cpu()
    empty_cache()

    # Combine attn for each hop into attn for full range
    # Use indices to place them into correct place for corresponding source tokens
    alignments = []
    for item in range(bs):
        # Note each item has different length lyrics
        full_tokens = labels['info'][item]['full_tokens']
        alignment = np.zeros((total_length, len(full_tokens) + 1))
        for start in reversed(get_starts(total_length, n_ctx, hop_length)):
            end = start + n_ctx
            alignment_hop = alignment_hops[start][item]
            indices = indices_hops[start][item]
            assert len(indices) == n_tokens
            assert alignment_hop.shape == (n_ctx, n_tokens)
            alignment[start:end,indices] = alignment_hop
        alignment = alignment[:total_length - padding_length,:-1] # remove token padding, and last lyric index
        alignments.append(alignment)
    return alignments
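
Example #10 slides an n_ctx-token window across the sequence using get_starts(total_length, n_ctx, hop_length), which is not shown above. The sketch below captures the behaviour the rest of the function depends on, namely that every start leaves room for a full n_ctx window; treat it as an assumption rather than the repository's exact implementation:

import numpy as np

# Assumed behaviour of get_starts: window starts spaced hop_length apart,
# with the last start clamped so start + n_ctx never runs past total_length.
def get_starts(total_length, n_ctx, hop_length):
    starts = []
    for start in np.arange(0, total_length, hop_length):
        if start + n_ctx >= total_length:
            start = total_length - n_ctx
        starts.append(int(start))
    return starts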