Example #1
    def get_input_size(self, output_size):
        """
        Computes the input size needed to produce the desired output_size.
        Warning: this function has side effects - it rewrites the GridRange
        geometry cached in self.vc via vconv.compute_inputs.
        """
        win_gr = vconv.GridRange((0, int(1e12)), (0, output_size), 1)
        vconv.compute_inputs(self.vc['end_grcc'], win_gr)
        return self.vc['beg'].parent.in_len()
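
A minimal usage sketch, assuming a constructed model exposing the method above (the names `model` and `raw_wav` and the sizes are illustrative, not from the source):

    # Size the raw input so the network emits exactly out_len timesteps.
    # Call this before reading any other in_len() values, since it rewrites
    # the geometry cached in model.vc.
    out_len = 2048
    in_len = model.get_input_size(out_len)
    wav_window = raw_wav[..., :in_len]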
Example #2
    def _init_geometry(self, n_win_batch):
        end_gr = vconv.GridRange((0, 100000), (0, n_win_batch), 1)
        end_vc = self.wavenet.vc['end_grcc']
        end_gr_actual = vconv.compute_inputs(end_vc, end_gr)

        mfcc_vc = self.wavenet.vc['beg'].parent
        beg_grcc_vc = self.wavenet.vc['beg_grcc']

        self.enc_in_len = mfcc_vc.in_len()
        self.enc_in_mel_len = self.embed_len = mfcc_vc.child.in_len()
        self.dec_in_len = beg_grcc_vc.in_len()

        # di: decoder input range, wi: full wav input range
        di = beg_grcc_vc.input_gr
        wi = mfcc_vc.input_gr

        # subrange of the wav input that the decoder consumes
        self.trim_dec_in = torch.tensor(
                [di.sub[0] - wi.sub[0], di.sub[1] - wi.sub[0]],
                dtype=torch.long)

        # subrange of the wav input which corresponds to the output
        self.trim_dec_out = torch.tensor(
                [end_gr.sub[0] - wi.sub[0], end_gr.sub[1] - wi.sub[0]],
                dtype=torch.long)

        # the upsampler output is trimmed to the decoder's input length
        self.wavenet.trim_ups_out = torch.tensor([0, beg_grcc_vc.in_len()],
                dtype=torch.long)

        self.wavenet.post_init(n_win_batch)
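
A sketch of how the trim tensors above would typically be applied (hypothetical usage; `wav` holding the full input window is an assumption, not from the source):

    # Slice the full wav window down to the decoder's conditioning input
    # and to the subrange that corresponds to the output.
    tdi, tdo = self.trim_dec_in, self.trim_dec_out
    dec_wav = wav[..., tdi[0]:tdi[1]]
    dec_target = wav[..., tdo[0]:tdo[1]]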
Example #3
    def post_init(self, n_win_batch):
        one_gr = vconv.GridRange((0, int(1e12)), (0, 1), 1)
        win_gr = vconv.GridRange((0, int(1e12)), (0, n_win_batch), 1)

        # compute the geometry for an n_win_batch-sized output window
        vconv.compute_inputs(self.vc['end_grcc'], win_gr)

        # di: decoder input range, wi: full wav input range
        di = self.vc['beg_grcc'].input_gr
        wi = self.vc['beg'].parent.input_gr

        # offsets into the wav input of the decoder's conditioning window
        self.wav_cond_offset = [
            int(di.sub[0] - wi.sub[0]),
            int(di.sub[1] - wi.sub[0])
        ]

        # recompute the geometry for a single output timestep
        vconv.compute_inputs(self.vc['end_grcc'], one_gr)

        for layer in self.conv_layers:
            layer.post_init()

        self.base_global_rf = self.conv_layers[0].global_rf
        self.n_win_batch = n_win_batch
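
A usage sketch (hypothetical; the name `net` and the window size are assumptions): post_init must run before sampling so that the cached geometry, wav_cond_offset, and receptive-field sizes are populated.

    net.post_init(n_win_batch=16)
    print(net.wav_cond_offset)   # [begin, end) offsets into the wav input
    print(net.base_global_rf)    # global receptive field of the first conv layer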
Example #4
    def sample(self, wav_onehot, lc_sparse, speaker_inds, jitter_index, n_rep):
        """
        Generate n_rep samples, using lc_sparse and speaker_inds for local
        and global conditioning.

        wav_onehot: full-length one-hot wav tensor
        lc_sparse: full-length local conditioning vector derived from
        wav_onehot
        speaker_inds: speaker indices used for global conditioning
        jitter_index: indices used to apply jitter to the local conditioning
        n_rep: number of repetitions to sample
        """
        # initialize model geometry
        mfcc_vc = self.vc['beg'].parent
        up_vc = self.vc['pre_upsample'].child
        beg_grcc_vc = self.vc['beg_grcc']
        end_vc = self.vc['end_grcc']

        # calculate the full output range
        wav_gr = vconv.GridRange((0, int(1e12)), (0, wav_onehot.size()[2]), 1)
        full_out_gr = vconv.output_range(mfcc_vc, end_vc, wav_gr)
        n_ts = full_out_gr.sub_length()

        # calculate the starting input range for a single timestep
        one_gr = vconv.GridRange((0, int(1e12)), (0, 1), 1)
        vconv.compute_inputs(end_vc, one_gr)

        # trim wav to start at the decoder's first conditioning sample
        wav_beg = int(beg_grcc_vc.input_gr.sub[0] - mfcc_vc.input_gr.sub[0])
        wav_onehot = wav_onehot[:, :, wav_beg:]

        # NOTE: the int() cast is needed here, though it is unclear why
        n_init_ts = int(beg_grcc_vc.in_len())

        # replicate the conditioning inputs n_rep times along the batch axis
        lc_sparse = lc_sparse.repeat(n_rep, 1, 1)
        jitter_index = jitter_index.repeat(n_rep, 1)
        speaker_inds = speaker_inds.repeat(n_rep)

        # precalculate the conditioning vector for all timesteps
        D1 = lc_sparse.size()[1]
        lc_jitter = torch.take(lc_sparse,
                jitter_index.unsqueeze(1).expand(-1, D1, -1))
        lc_conv = self.lc_conv(lc_jitter)
        lc_dense = self.lc_upsample(lc_conv)
        cond = self.cond(lc_dense, speaker_inds)
        n_ts = cond.size()[2]

        # hard-coded window of samples to resynthesize; batch element 0
        # keeps the original signal as a baseline throughout
        start_pos = 26000
        n_samples = 20000
        end_pos = start_pos + n_samples

        wav_onehot = wav_onehot.repeat(n_rep, 1, 1)

        # loop through timesteps: inrange is the sliding input window whose
        # right edge is the next position to be predicted
        inrange = torch.tensor((start_pos - n_init_ts, start_pos), dtype=torch.int32)
        end_ind = torch.tensor([end_pos], dtype=torch.int32)

        # Inefficient: this recalculates intermediate activations over the
        # entire receptive field at each step, rather than just the
        # advancing front.
        while not torch.equal(inrange[1], end_ind[0]):
            sig = self.base_layer(wav_onehot[:, :, inrange[0]:inrange[1]])
            sig, skp_sum = self.conv_layers[0](sig, cond[:, :, inrange[0]:inrange[1]])
            for layer in self.conv_layers[1:]:
                sig, skp = layer(sig, cond[:, :, inrange[0]:inrange[1]])
                skp_sum += skp

            post1 = self.post1(self.relu(skp_sum))
            quant = self.post2(self.relu(post1))
            cat = dcat.OneHotCategorical(logits=quant.squeeze(2))
            # batch element 0 keeps the original signal; elements 1: receive
            # the sampled value at the newly predicted position
            wav_onehot[1:, :, inrange[1]] = cat.sample()[1:, ...]
            inrange += 1
            if inrange[0] % 100 == 0:
                print(inrange, end_ind[0])

        # convert from one-hot back to quantized sample values
        quant_range = wav_onehot.new(list(range(self.n_quant)))
        wav = torch.matmul(wav_onehot.permute(0, 2, 1), quant_range)

        # diagnostic output: compare the original (batch 0) against one
        # synthesized repetition (batch 1) around the resynthesized window
        torch.set_printoptions(threshold=100000)
        pad = 5
        print('padding = {}'.format(pad))
        print('original')
        print(wav[0, start_pos - pad:end_pos + pad])
        print('synth')
        print(wav[1, start_pos - pad:end_pos + pad])

        print('synth range std: {}, baseline std: {}'.format(
            wav[:, start_pos:end_pos].std(), wav[:, end_pos:].std()))

        return wav
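
A sketch of a hypothetical driver for sample (the names and shapes are assumptions, not from the source). Batch element 0 passes through untouched, so comparing it against the remaining elements measures resynthesis fidelity:

    # Resynthesize three repetitions of one conditioned clip.
    wav = net.sample(wav_onehot, lc_sparse, speaker_inds, jitter_index, n_rep=3)
    baseline, synth = wav[0], wav[1:]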