Ejemplo n.º 1
0
    def set_geometry(self, n_sam_per_slice_req):
        '''Compute the relationship between the encoder input, decoder input,
        and input to the loss function.

        Generates receptive-field statistics for ``n_sam_per_slice_req``. It
        then passes two offset pairs to ``self.preprocess.set_geometry``:
        encoder input -> decoder input, and decoder input -> loss input.
        Finally it records ``self.input_size`` and ``self.output_size`` from
        the ``nv`` of the encoder source and the loss destination.
        '''
        self.rf.gen_stats(n_sam_per_slice_req)

        # The three window boundaries that must be mutually aligned.
        encoder_src = self.preprocess.rf.src
        decoder_dst = self.decoder.last_upsample_rf.dst
        loss_dst = self.rf.dst

        self.preprocess.set_geometry(
            rfield.offsets(encoder_src, decoder_dst),
            rfield.offsets(decoder_dst, loss_dst),
        )

        self.input_size = encoder_src.nv
        self.output_size = loss_dst.nv
Ejemplo n.º 2
0
 def set_geometry(self, beg_rf, end_rf):
     '''
     Constructs the transpose convolution which mimics the usage pattern
     of WaveNet's local conditioning vectors and output.

     Builds a condensed receptive field from ``beg_rf`` to ``end_rf`` and
     derives the kernel size, stride and padding from it. It then creates a
     frozen single-channel ``nn.ConvTranspose1d`` whose weights are all
     ``1 / self.rf.src.nv``. It also sets ``self.l_trim`` and ``self.r_trim``.
     '''
     rf = rfield.condensed(beg_rf, end_rf, self.name)
     self.rf = rf
     rf.gen_stats(rf)
     rf.init_nv(1)
     up_stride = rf.stride_ratio.denominator

     # Kernel width covers the full left/right offset of the field onto itself.
     left, right = rfield.offsets(rf, rf)
     kernel = left - right + 1

     # ConvTranspose1d implicitly pads each side by (kernel_size - 1 - padding).
     # Choose `padding` so the implicit pad equals the larger required side,
     # and remember how much excess to trim from each side afterwards.
     implicit_pad = max(rf.l_pad, rf.r_pad)
     self.l_trim = implicit_pad - rf.l_pad
     self.r_trim = implicit_pad - rf.r_pad

     self.tconv = nn.ConvTranspose1d(
         1,
         1,
         kernel,
         stride=up_stride,
         padding=kernel - 1 - implicit_pad,
         bias=False,
     )
     # Fixed averaging weights: not trained.
     self.tconv.weight.requires_grad = False
     nn.init.constant_(self.tconv.weight, 1.0 / rf.src.nv)
Ejemplo n.º 3
0
 def skip_lead(self):
     '''Return the distance from the start of this input to the start of
     the final stack output.

     Computed as the left offset from ``self.rf.dst`` to
     ``self.end_rf.dst``.

     Raises:
         RuntimeError: if ``self.end_rf`` is still None.
     '''
     if self.end_rf is not None:
         lead, _ = rfield.offsets(self.rf.dst, self.end_rf.dst)
         return lead
     raise RuntimeError('Must call init_end_rf() first')
Ejemplo n.º 4
0
    def set_geometry(self):
        '''Compute the timestep offsets between the window boundaries of the
        encoder input wav, decoder input wav, and supervising wav input to the
        loss function.

        For ``'vae'`` / ``'vqvae'`` bottlenecks the objective is also given
        the decoder's pre-upsample and last GRCC receptive fields. Both offset
        pairs are then handed to ``self.preprocess.set_geometry``.
        '''
        self.rf.gen_stats(self.preprocess.rf)
        dec = self.decoder
        if self.bn_type in ('vae', 'vqvae'):
            self.objective.set_geometry(dec.pre_upsample_rf, dec.last_grcc_rf)

        upsample_rf = dec.last_upsample_rf

        # Encoder side: from the preprocessing input to the output of the
        # decoder's last upsampling layer.
        encoder_offsets = rfield.offsets(self.preprocess.rf, upsample_rf)

        # Decoder side: from the wav input to the decoder output. This begins
        # *after* the upsampling layer because it concerns the wav input, not
        # the conditioning vectors.
        decoder_offsets = rfield.offsets(upsample_rf.next(), dec.rf)

        self.preprocess.set_geometry(encoder_offsets, decoder_offsets)
Ejemplo n.º 5
0
    def skip_lead(self):
        '''Return the distance from the start of this *output* to the start of
        the final stack output.

        The skip information is the *output* of ``self.rf``, not its input.
        The offset is therefore measured from ``self.rf.next()`` to
        ``self.end_rf``. Returns 0 when ``self.rf == self.end_rf``.

        Raises:
            RuntimeError: if ``self.end_rf`` is still None.
        '''
        end = self.end_rf
        if end is None:
            raise RuntimeError('Must call init_bound_rfs() first')
        if self.rf == end:
            return 0
        lead, _ = rfield.offsets(self.rf.next(), end)
        return lead
Ejemplo n.º 6
0
 def forward(self, x):
     '''
     B, C, T = n_batch, n_in_chan, n_win
     x: (B, C, T)

     Applies ``self.conv`` then ``self.relu``. When ``self.do_res`` is set,
     it adds the slice of ``x`` that lines up with the output window as a
     residual. Asserts that the input and output time lengths match
     ``self.rf.src.nv`` and ``self.rf.dst.nv``.
     '''
     assert self.rf.src.nv == x.shape[2]
     out = self.relu(self.conv(x))
     if self.do_res:
         # Crop the input to the output window; `r_off or None` turns an
         # r_off of 0 into an open-ended slice.
         l_off, r_off = rfield.offsets(self.rf.src, self.rf.dst)
         # Out-of-place add on purpose: torch's ReLU saves its output for the
         # backward pass, so mutating it in place (`out += ...`) makes
         # autograd raise "modified by an inplace operation" on backward.
         out = out + x[:, :, l_off:r_off or None]
     assert self.rf.dst.nv == out.shape[2]
     return out
Ejemplo n.º 7
0
 def cond_lead(self):
     '''Return the distance from the start of the overall stack input
     (``self.beg_rf``) to the start of this convolution (``self.rf``).'''
     left, _right = rfield.offsets(self.beg_rf, self.rf)
     return left