Example #1
    def train(self, index):
        ss = self.state
        ss.to(self.device)
        current_stats = {}

        # for resuming the learning rate
        sorted_lr_steps = sorted(self.learning_rates.keys())
        lr_index = util.greatest_lower_bound(sorted_lr_steps, ss.step)
        ss.update_learning_rate(self.learning_rates[sorted_lr_steps[lr_index]])

        if ss.model.bn_type != 'none':
            sorted_as_steps = sorted(self.anneal_schedule.keys())
            as_index = util.greatest_lower_bound(sorted_as_steps, ss.step)
            ss.model.objective.update_anneal_weight(
                self.anneal_schedule[sorted_as_steps[as_index]])

        if ss.model.bn_type in ('vqvae', 'vqvae-ema'):
            ss.model.init_codebook(self.data_iter, 10000)

        while ss.step < self.opts.max_steps:
            if ss.step in self.learning_rates:
                ss.update_learning_rate(self.learning_rates[ss.step])

            if ss.model.bn_type == 'vae' and ss.step in self.anneal_schedule:
                ss.model.objective.update_anneal_weight(
                    self.anneal_schedule[ss.step])

            loss = self.optim_step_fn()

            if ss.model.bn_type == 'vqvae-ema' and ss.step == 10000:
                ss.model.bottleneck.update_codebook()

            if ss.step % self.opts.progress_interval == 0:
                current_stats.update({
                    'step': ss.step,
                    'loss': loss,
                    'lrate': ss.optim.param_groups[0]['lr'],
                    'tprb_m': self.avg_prob_target(),
                    # 'pk_d_m': avg_peak_dist
                })
                current_stats.update(ss.model.objective.metrics)

                if ss.model.bn_type == 'vae':
                    current_stats['free_nats'] = ss.model.objective.free_nats
                    current_stats['anneal_weight'] = \
                            ss.model.objective.anneal_weight.item()

                if ss.model.bn_type in ('vqvae', 'vqvae-ema', 'ae', 'vae'):
                    current_stats.update(ss.model.encoder.metrics)

                netmisc.print_metrics(current_stats, index, 100)
                stderr.flush()

            if (ss.step % self.opts.save_interval == 0
                    and ss.step != self.start_step):
                self.save_checkpoint()
            ss.step += 1
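
train() above resumes a step-keyed schedule by looking up the value attached to the greatest step that does not exceed the current one. A minimal, self-contained sketch of that idiom follows, assuming util.greatest_lower_bound(a, x) returns the index of the last element of the sorted list a that is <= x; bisect.bisect_right(a, x) - 1 yields the same index. The names resume_value and schedule are hypothetical, chosen only for illustration.

import bisect

def resume_value(schedule, step):
    """Return the schedule value whose step key is the largest one <= step."""
    steps = sorted(schedule)
    # index of the last step <= `step`; assumes step >= min(steps), otherwise i is -1
    i = bisect.bisect_right(steps, step) - 1
    return schedule[steps[i]]

# Resuming at step 7500 picks the learning rate that was set at step 5000.
learning_rates = {0: 1e-3, 5000: 5e-4, 20000: 1e-4}
assert resume_value(learning_rates, 7500) == 5e-4
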
Example #2
    def next_slice(self):
        """Get a random slice of a file, together with its start position
        and ID.  Populates self.snd_slice, self.mel_slice, and self.mask"""
        picks = np.random.randint(0, self.n_total_samples, self.batch_size)
        for vpos, b in enumerate(picks):
            file_i = util.greatest_lower_bound(self.voffset, vpos)
            last_in = self.n_snd_elem[file_i] - 1
            last_out = self.n_samples[file_i] - 1
            sam_i = vpos - self.voffset[file_i]
            mel_in_b, mel_in_e = rf.get_rfield(self.mel_in, self.dec_out,
                                               sam_i, sam_i, last_out)
            dec_in_b, dec_in_e = rf.get_rfield(self.dec_in, self.dec_out,
                                               sam_i, sam_i, last_out)
            out_b, out_e = rf.get_ifield(self.ae_wav_in, self.dec_out,
                                         dec_in_b, dec_in_e, last_in)

            snd_off = self.snd_offset[file_i]
            mel_off = self.mel_offset[file_i]
            self.snd_slice[b] = self.snd_data[snd_off + dec_in_b:snd_off +
                                              dec_in_e + 1]
            self.mel_slice[b] = self.mel_data[mel_off + mel_in_b:mel_off +
                                              mel_in_e + 1]
            self.mask[b].zero_()
            self.mask[b, sam_i - out_b] = 1
            assert self.mask.size()[1] == out_e - out_b
Example #3
def graph_shadow(in_layer, out_layer, in_b, in_e):
    out_b, out_e = graph_ifield(in_layer, in_b, in_e)
    if out_b == out_e:
        return 0, 0
    # search through the in_layer until the matching position is found
    out_b_pos = out_layer[out_b].position
    out_e_pos = out_layer[out_e - 1].position
    positions = list(map(lambda n: n.position, in_layer))
    lb_b = util.greatest_lower_bound(positions, out_b_pos)
    for i in range(lb_b, len(in_layer)):
        n = in_layer[i]
        if n.position >= out_b_pos:
            shadow_b = i
            break
    lb_e = util.greatest_lower_bound(positions, out_e_pos)
    shadow_e = lb_e
    #for i in range(lb_e, len(in_layer)):
    #    n = in_layer[i]
    #    if n.position <= out_e_pos:
    #        shadow_e = i
    #        break
    return shadow_b, shadow_e + 1
Example #4
        def gen_fn():
            for iter_pos, vind in perm_gen:
                vpos = offset + vind * self.n_sample_win
                wav_file_ind = util.greatest_lower_bound(self.vstart, vpos)
                wav_off = vpos - self.vstart[wav_file_ind]

                # self.perm_gen_pos gives the position that will be yielded next
                self.perm_gen_pos = iter_pos + 1
                yield wav_file_ind, wav_off, vind, \
                        self.wav_ids[wav_file_ind], \
                        self.wav_buf[wav_file_ind][wav_off:wav_off + self.slice_size]
            # Reset state variables
            self.perm_gen_pos = 0
Example #5
        def gen_fn():
            # random state for self.perm is determined from
            # self.wav_buffer_rand_state, so don't need to store it.
            perm_gen = self.perm.permutation_gen_fn(
                self.perm_gen_pos, int(self.perm.n_items * self.frac_perm_use))
            for iter_pos, vind in perm_gen:
                vpos = self.offset + vind * self.n_sample_win
                wav_file_ind = util.greatest_lower_bound(self.vstart, vpos)
                wav_off = vpos - self.vstart[wav_file_ind]

                # self.perm_gen_pos gives the position that will be yielded next
                self.perm_gen_pos = iter_pos + 1
                yield wav_file_ind, wav_off, vind, \
                        self.wav_ids[wav_file_ind], \
                        self.wav_buf[wav_file_ind][wav_off:wav_off + self.slice_size]

            # We've exhausted the iterator, next position should be zero
            self.perm_gen_pos = 0
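
next_slice() and both gen_fn() variants above share one lookup: the wav files are concatenated into a single virtual buffer, each file's start offset is recorded, and a global position is resolved to (file index, offset within file) by finding the greatest start offset not exceeding that position. Below is a minimal sketch of that mapping; the names file_starts and locate are hypothetical stand-ins for self.vstart / self.voffset and the inline lookups above, and bisect stands in for util.greatest_lower_bound under the same assumption as before.

import bisect

def locate(file_starts, vpos):
    """Map a global sample position to (file_index, offset_in_file)."""
    file_i = bisect.bisect_right(file_starts, vpos) - 1   # last start offset <= vpos
    return file_i, vpos - file_starts[file_i]

# Three files of 100, 50, and 75 samples concatenated virtually.
file_starts = [0, 100, 150]
assert locate(file_starts, 120) == (1, 20)   # sample 120 is offset 20 in file 1
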
Example #6
    def train(self):
        hps = self.state.hps
        ss = self.state
        current_stats = {}
        writer_stats = {}

        # for resuming the learning rate
        sorted_lr_steps = sorted(self.learning_rates.keys())
        lr_index = util.greatest_lower_bound(sorted_lr_steps,
                                             ss.data.global_step)
        ss.update_learning_rate(self.learning_rates[sorted_lr_steps[lr_index]])

        if ss.model.bn_type != 'none':
            sorted_as_steps = sorted(self.anneal_schedule.keys())
            as_index = util.greatest_lower_bound(sorted_as_steps,
                                                 ss.data.global_step)
            ss.model.objective.update_anneal_weight(
                self.anneal_schedule[sorted_as_steps[as_index]])

        if ss.model.bn_type in ('vqvae', 'vqvae-ema'):
            ss.model.init_codebook(self.data_iter, 10000)

        start_time = time.time()

        for batch_num, batch in enumerate(self.device_loader):
            wav, mel, voice, jitter, position = batch
            global_step = len(ss.data.dataset) * position[0] + position[1]

            # print(f'replica {self.replica_index}, batch {batch_num}', file=stderr)
            # stderr.flush()
            if (batch_num % hps.save_interval == 0 and batch_num != 0):
                self.save_checkpoint(position)

            if hps.skip_loop_body:
                continue

            lr_index = util.greatest_lower_bound(sorted_lr_steps, global_step)
            ss.update_learning_rate(
                self.learning_rates[sorted_lr_steps[lr_index]])
            # if ss.data.global_step in self.learning_rates:
            # ss.update_learning_rate(self.learning_rates[ss.data.global_step])

            if ss.model.bn_type == 'vae' and global_step in self.anneal_schedule:
                ss.model.objective.update_anneal_weight(
                    self.anneal_schedule[global_step])

            ss.optim.zero_grad()
            quant, self.target, loss = self.state.model.run(
                wav, mel, voice, jitter)
            self.probs = self.softmax(quant)
            self.mel_enc_input = mel
            # print(f'after model.run', file=stderr)
            # stderr.flush()
            loss.backward()

            # print(f'after loss.backward()', file=stderr)
            # stderr.flush()

            if batch_num % hps.progress_interval == 0:
                pars_copy = [p.data.clone() for p in ss.model.parameters()]

            # print(f'after pars_copy', file=stderr)
            # stderr.flush()

            if self.is_tpu:
                xm.optimizer_step(ss.optim)
            else:
                ss.optim.step()

            ss.optim_step += 1

            if ss.model.bn_type == 'vqvae-ema' and ss.data.global_step == 10000:
                ss.model.bottleneck.update_codebook()

            tprb_m = self.avg_prob_target()

            if batch_num % hps.progress_interval == 0:
                iterator = zip(pars_copy, ss.model.named_parameters())
                uw_ratio = {
                    np[0]: t.norm(c - np[1].data) / c.norm()
                    for c, np in iterator
                }

                writer_stats.update({'uwr': uw_ratio})

                if self.is_tpu:
                    count = torch_xla._XLAC._xla_get_replication_devices_count(
                    )
                    loss_red, tprb_red = xm.all_reduce('sum', [loss, tprb_m],
                                                       scale=1.0 / count)
                    # loss_red = xm.all_reduce('all_loss', loss, reduce_mean)
                    # tprb_red = xm.all_reduce('all_tprb', tprb_m, reduce_mean)
                else:
                    loss_red = loss
                    tprb_red = tprb_m

                writer_stats.update({
                    'loss_r': loss_red,
                    'tprb_r': tprb_red,
                    'optim_step': ss.optim_step
                })

                current_stats.update({
                    'optim_step': ss.optim_step,
                    'gstep': global_step,
                    # 'gstep': ss.data.global_step,
                    'epoch': position[0],
                    'step': position[1],
                    # 'loss': loss,
                    'lrate': ss.optim.param_groups[0]['lr'],
                    # 'tprb_m': tprb_m,
                    # 'pk_d_m': avg_peak_dist
                })
                current_stats.update(ss.model.objective.metrics)

                if ss.model.bn_type == 'vae':
                    current_stats['free_nats'] = ss.model.objective.free_nats
                    current_stats['anneal_weight'] = \
                            ss.model.objective.anneal_weight.item()

                if ss.model.bn_type in ('vqvae', 'vqvae-ema', 'ae', 'vae'):
                    current_stats.update(ss.model.encoder.metrics)

                if self.is_tpu:
                    xm.add_step_closure(self.train_update,
                                        args=(writer_stats, current_stats))
                else:
                    self.train_update(writer_stats, current_stats)

                # if not self.is_tpu or xm.is_master_ordinal():
                # if batch_num in range(25, 50) or batch_num in range(75, 100):
                stderr.flush()
                elapsed = time.time() - start_time
Example #7
    def compute_n_items(cls, requested_n_items):
        ind = util.greatest_lower_bound(cls.primes, requested_n_items)
        if ind == -1:
            raise InvalidArgument
        return cls.primes[ind]
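
compute_n_items() above is the clearest statement of the contract these examples rely on: util.greatest_lower_bound returns an index into its sorted first argument, and -1 when no element is <= the query. A plausible implementation, consistent with every call shown here but not taken from the repository, is:

import bisect

def greatest_lower_bound(sorted_seq, x):
    """Index of the last element of sorted_seq that is <= x, or -1 if none."""
    return bisect.bisect_right(sorted_seq, x) - 1
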
Example #8
def main():
    if len(sys.argv) == 1 or sys.argv[1] not in ('new', 'resume'):
        print(parse_tools.top_usage, file=stderr)
        return

    mode = sys.argv[1]
    del sys.argv[1]
    if mode == 'new':
        opts = parse_tools.two_stage_parse(parse_tools.cold)
    elif mode == 'resume':
        opts = parse_tools.resume.parse_args()  

    opts.device = None
    if not opts.disable_cuda and torch.cuda.is_available():
        opts.device = torch.device('cuda')
    else:
        opts.device = torch.device('cpu') 

    ckpt_path = util.CheckpointPath(opts.ckpt_template)

    # Construct model
    if mode == 'new':
        # Initialize model
        pre_params = parse_tools.get_prefixed_items(vars(opts), 'pre_')
        enc_params = parse_tools.get_prefixed_items(vars(opts), 'enc_')
        bn_params = parse_tools.get_prefixed_items(vars(opts), 'bn_')
        dec_params = parse_tools.get_prefixed_items(vars(opts), 'dec_')

        # Initialize data
        sample_catalog = D.parse_sample_catalog(opts.sam_file)
        data = D.WavSlices(sample_catalog, pre_params['sample_rate'],
                opts.frac_permutation_use, opts.requested_wav_buf_sz)
        dec_params['n_speakers'] = data.num_speakers()

        #with torch.autograd.set_detect_anomaly(True):
        model = ae.AutoEncoder(pre_params, enc_params, bn_params, dec_params)
        print('Initializing model parameters', file=stderr)
        model.initialize_weights()

        # Construct overall state
        state = checkpoint.State(0, model, data)

    else:
        state = checkpoint.State()
        state.load(opts.ckpt_file)
        print('Restored model and data from {}'.format(opts.ckpt_file), file=stderr)

    state.model.set_geometry(opts.n_sam_per_slice)

    state.data.set_geometry(opts.n_batch, state.model.input_size,
            state.model.output_size)

    state.model.to(device=opts.device)

    #total_bytes = 0
    #for name, par in model.named_parameters():
    #    n_bytes = par.data.nelement() * par.data.element_size()
    #    total_bytes += n_bytes
    #    print(name, type(par.data), par.size(), n_bytes)
    #print('total_bytes: ', total_bytes)

    # Initialize optimizer
    # a list, not a generator, so it can be reused when the optimizer is rebuilt below
    model_params = list(state.model.parameters())
    metrics = ae.Metrics(state.model, None)
    batch_gen = state.data.batch_slice_gen_fn()

    #loss_fcn = state.model.loss_factory(state.data.batch_slice_gen_fn())

    # Start training
    print('Starting training...', file=stderr)
    print("Step\tLoss\tAvgProbTarget\tPeakDist\tAvgMax", file=stderr)
    stderr.flush()

    learning_rates = dict(zip(opts.learning_rate_steps, opts.learning_rate_rates))
    start_step = state.step
    if start_step not in learning_rates:
        # greatest_lower_bound returns an index into the step list
        ref_index = util.greatest_lower_bound(opts.learning_rate_steps, start_step)
        metrics.optim = torch.optim.Adam(params=model_params,
                lr=learning_rates[opts.learning_rate_steps[ref_index]])

    while state.step < opts.max_steps:
        if state.step in learning_rates:
            metrics.optim = torch.optim.Adam(params=model_params,
                    lr=learning_rates[state.step])
        # do 'pip install --upgrade scipy' if you get 'FutureWarning: ...'
        metrics.update(batch_gen)
        loss = metrics.optim.step(metrics.loss)
        avg_peak_dist = metrics.peak_dist()
        avg_max = metrics.avg_max()
        avg_prob_target = metrics.avg_prob_target()

        # Progress reporting
        if state.step % opts.progress_interval == 0:
            fmt = "{}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}"
            print(fmt.format(state.step, loss, avg_prob_target, avg_peak_dist,
                avg_max), file=stderr)
            stderr.flush()

        # Checkpointing
        if state.step % opts.save_interval == 0 and state.step != start_step:
            ckpt_file = ckpt_path.path(state.step)
            state.save(ckpt_file)
            print('Saved checkpoint to {}'.format(ckpt_file), file=stderr)

        state.step += 1