def train(self, index):
    ss = self.state
    ss.to(self.device)
    current_stats = {}

    # for resuming the learning rate
    sorted_lr_steps = sorted(self.learning_rates.keys())
    lr_index = util.greatest_lower_bound(sorted_lr_steps, ss.step)
    ss.update_learning_rate(self.learning_rates[sorted_lr_steps[lr_index]])

    if ss.model.bn_type != 'none':
        sorted_as_steps = sorted(self.anneal_schedule.keys())
        as_index = util.greatest_lower_bound(sorted_as_steps, ss.step)
        ss.model.objective.update_anneal_weight(
                self.anneal_schedule[sorted_as_steps[as_index]])

    if ss.model.bn_type in ('vqvae', 'vqvae-ema'):
        ss.model.init_codebook(self.data_iter, 10000)

    while ss.step < self.opts.max_steps:
        if ss.step in self.learning_rates:
            ss.update_learning_rate(self.learning_rates[ss.step])

        if ss.model.bn_type == 'vae' and ss.step in self.anneal_schedule:
            ss.model.objective.update_anneal_weight(
                    self.anneal_schedule[ss.step])

        loss = self.optim_step_fn()

        if ss.model.bn_type == 'vqvae-ema' and ss.step == 10000:
            ss.model.bottleneck.update_codebook()

        if ss.step % self.opts.progress_interval == 0:
            current_stats.update({
                'step': ss.step,
                'loss': loss,
                'lrate': ss.optim.param_groups[0]['lr'],
                'tprb_m': self.avg_prob_target(),
                # 'pk_d_m': avg_peak_dist
            })
            current_stats.update(ss.model.objective.metrics)

            if ss.model.bn_type == 'vae':
                current_stats['free_nats'] = ss.model.objective.free_nats
                current_stats['anneal_weight'] = \
                        ss.model.objective.anneal_weight.item()

            if ss.model.bn_type in ('vqvae', 'vqvae-ema', 'ae', 'vae'):
                current_stats.update(ss.model.encoder.metrics)

            netmisc.print_metrics(current_stats, index, 100)
            stderr.flush()

        if (ss.step % self.opts.save_interval == 0
                and ss.step != self.start_step):
            self.save_checkpoint()

        ss.step += 1
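# Illustrative sketch only: the schedule-resume logic above relies on
# util.greatest_lower_bound, whose implementation is not shown in this section.
# From its call sites (it returns an index into a sorted sequence, and -1 when
# no element qualifies, as in compute_n_items below), it presumably behaves
# like the following; the actual code in util may differ.
import bisect

def greatest_lower_bound_sketch(sorted_vals, query):
    """Assumed contract: index of the largest element <= query, or -1."""
    # bisect_right returns the insertion point after any elements equal to
    # query, so the element just before it is the greatest lower bound.
    return bisect.bisect_right(sorted_vals, query) - 1

# e.g. greatest_lower_bound_sketch([0, 5000, 20000], 12000) -> 1, so a run
# resumed at step 12000 picks up the learning rate scheduled for step 5000.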
def next_slice(self):
    """Get a random slice of a file, together with its start position and ID.
    Populates self.snd_slice, self.mel_slice, and self.mask"""
    # draw one random virtual sample position per batch element
    picks = np.random.randint(0, self.n_total_samples, self.batch_size)
    for b, vpos in enumerate(picks):
        file_i = util.greatest_lower_bound(self.voffset, vpos)
        last_in = self.n_snd_elem[file_i] - 1
        last_out = self.n_samples[file_i] - 1
        sam_i = vpos - self.voffset[file_i]

        # receptive fields of the mel and decoder inputs implied by output
        # position sam_i, and the output field implied by the decoder's wav
        # input window
        mel_in_b, mel_in_e = rf.get_rfield(self.mel_in, self.dec_out, sam_i,
                sam_i, last_out)
        dec_in_b, dec_in_e = rf.get_rfield(self.dec_in, self.dec_out, sam_i,
                sam_i, last_out)
        out_b, out_e = rf.get_ifield(self.ae_wav_in, self.dec_out, dec_in_b,
                dec_in_e, last_in)

        snd_off = self.snd_offset[file_i]
        mel_off = self.mel_offset[file_i]

        self.snd_slice[b] = self.snd_data[snd_off + dec_in_b:
                snd_off + dec_in_e + 1]
        self.mel_slice[b] = self.mel_data[mel_off + mel_in_b:
                mel_off + mel_in_e + 1]
        self.mask[b].zero_()
        self.mask[b, sam_i - out_b] = 1
        assert self.mask.size()[1] == out_e - out_b
def graph_shadow(in_layer, out_layer, in_b, in_e):
    out_b, out_e = graph_ifield(in_layer, in_b, in_e)
    if out_b == out_e:
        return 0, 0

    # search through the in_layer until the matching position is found
    out_b_pos = out_layer[out_b].position
    out_e_pos = out_layer[out_e - 1].position
    positions = list(map(lambda n: n.position, in_layer))

    lb_b = util.greatest_lower_bound(positions, out_b_pos)
    for i in range(lb_b, len(in_layer)):
        n = in_layer[i]
        if n.position >= out_b_pos:
            shadow_b = i
            break

    lb_e = util.greatest_lower_bound(positions, out_e_pos)
    shadow_e = lb_e
    # for i in range(lb_e, len(in_layer)):
    #     n = in_layer[i]
    #     if n.position <= out_e_pos:
    #         shadow_e = i
    #         break

    return shadow_b, shadow_e + 1
def gen_fn():
    for iter_pos, vind in perm_gen:
        vpos = offset + vind * self.n_sample_win
        wav_file_ind = util.greatest_lower_bound(self.vstart, vpos)
        wav_off = vpos - self.vstart[wav_file_ind]
        # self.perm_gen_pos gives the position that will be yielded next
        self.perm_gen_pos = iter_pos + 1
        yield wav_file_ind, wav_off, vind, \
                self.wav_ids[wav_file_ind], \
                self.wav_buf[wav_file_ind][wav_off:wav_off + self.slice_size]

    # Reset state variables
    self.perm_gen_pos = 0
def gen_fn():
    # random state for self.perm is determined from
    # self.wav_buffer_rand_state, so don't need to store it.
    perm_gen = self.perm.permutation_gen_fn(
            self.perm_gen_pos,
            int(self.perm.n_items * self.frac_perm_use))

    for iter_pos, vind in perm_gen:
        vpos = self.offset + vind * self.n_sample_win
        wav_file_ind = util.greatest_lower_bound(self.vstart, vpos)
        wav_off = vpos - self.vstart[wav_file_ind]
        # self.perm_gen_pos gives the position that will be yielded next
        self.perm_gen_pos = iter_pos + 1
        yield wav_file_ind, wav_off, vind, \
                self.wav_ids[wav_file_ind], \
                self.wav_buf[wav_file_ind][wav_off:wav_off + self.slice_size]

    # We've exhausted the iterator, so the next position should be zero
    self.perm_gen_pos = 0
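# Illustrative example with hypothetical values: how gen_fn maps a virtual
# sample position vpos back to a (file index, offset) pair.  self.vstart is
# assumed to hold each wav file's starting position in the concatenated
# virtual address space.
import bisect

vstart = [0, 16000, 48000]                             # hypothetical file starts
vpos = 20000                                           # virtual sample position
wav_file_ind = bisect.bisect_right(vstart, vpos) - 1   # greatest lower bound -> 1
wav_off = vpos - vstart[wav_file_ind]                  # offset within file -> 4000
# The yielded slice would then come from wav_buf[1][4000:4000 + slice_size].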
def train(self):
    hps = self.state.hps
    ss = self.state
    current_stats = {}
    writer_stats = {}

    # for resuming the learning rate
    sorted_lr_steps = sorted(self.learning_rates.keys())
    lr_index = util.greatest_lower_bound(sorted_lr_steps, ss.data.global_step)
    ss.update_learning_rate(self.learning_rates[sorted_lr_steps[lr_index]])

    if ss.model.bn_type != 'none':
        sorted_as_steps = sorted(self.anneal_schedule.keys())
        as_index = util.greatest_lower_bound(sorted_as_steps,
                ss.data.global_step)
        ss.model.objective.update_anneal_weight(
                self.anneal_schedule[sorted_as_steps[as_index]])

    if ss.model.bn_type in ('vqvae', 'vqvae-ema'):
        ss.model.init_codebook(self.data_iter, 10000)

    start_time = time.time()

    for batch_num, batch in enumerate(self.device_loader):
        wav, mel, voice, jitter, position = batch
        global_step = len(ss.data.dataset) * position[0] + position[1]
        # print(f'replica {self.replica_index}, batch {batch_num}', file=stderr)
        # stderr.flush()

        if batch_num % hps.save_interval == 0 and batch_num != 0:
            self.save_checkpoint(position)

        if hps.skip_loop_body:
            continue

        lr_index = util.greatest_lower_bound(sorted_lr_steps, global_step)
        ss.update_learning_rate(
                self.learning_rates[sorted_lr_steps[lr_index]])
        # if ss.data.global_step in self.learning_rates:
        #     ss.update_learning_rate(self.learning_rates[ss.data.global_step])

        if (ss.model.bn_type == 'vae'
                and ss.data.global_step in self.anneal_schedule):
            ss.model.objective.update_anneal_weight(
                    self.anneal_schedule[ss.data.global_step])

        ss.optim.zero_grad()
        quant, self.target, loss = self.state.model.run(
                wav, mel, voice, jitter)
        self.probs = self.softmax(quant)
        self.mel_enc_input = mel
        # print(f'after model.run', file=stderr)
        # stderr.flush()
        loss.backward()
        # print(f'after loss.backward()', file=stderr)
        # stderr.flush()

        if batch_num % hps.progress_interval == 0:
            pars_copy = [p.data.clone() for p in ss.model.parameters()]
            # print(f'after pars_copy', file=stderr)
            # stderr.flush()

        if self.is_tpu:
            xm.optimizer_step(ss.optim)
        else:
            ss.optim.step()
        ss.optim_step += 1

        if ss.model.bn_type == 'vqvae-ema' and ss.data.global_step == 10000:
            ss.model.bottleneck.update_codebook()

        tprb_m = self.avg_prob_target()

        if batch_num % hps.progress_interval == 0:
            # per-parameter ratio of update magnitude to weight magnitude
            iterator = zip(pars_copy, ss.model.named_parameters())
            uw_ratio = {np[0]: t.norm(c - np[1].data) / c.norm()
                        for c, np in iterator}
            writer_stats.update({'uwr': uw_ratio})

            if self.is_tpu:
                count = torch_xla._XLAC._xla_get_replication_devices_count()
                loss_red, tprb_red = xm.all_reduce('sum', [loss, tprb_m],
                        scale=1.0 / count)
                # loss_red = xm.all_reduce('all_loss', loss, reduce_mean)
                # tprb_red = xm.all_reduce('all_tprb', tprb_m, reduce_mean)
            else:
                loss_red = loss
                tprb_red = tprb_m

            writer_stats.update({
                'loss_r': loss_red,
                'tprb_r': tprb_red,
                'optim_step': ss.optim_step
            })
            current_stats.update({
                'optim_step': ss.optim_step,
                'gstep': global_step,
                # 'gstep': ss.data.global_step,
                'epoch': position[0],
                'step': position[1],
                # 'loss': loss,
                'lrate': ss.optim.param_groups[0]['lr'],
                # 'tprb_m': tprb_m,
                # 'pk_d_m': avg_peak_dist
            })
            current_stats.update(ss.model.objective.metrics)

            if ss.model.bn_type == 'vae':
                current_stats['free_nats'] = ss.model.objective.free_nats
                current_stats['anneal_weight'] = \
                        ss.model.objective.anneal_weight.item()

            if ss.model.bn_type in ('vqvae', 'vqvae-ema', 'ae', 'vae'):
                current_stats.update(ss.model.encoder.metrics)

            if self.is_tpu:
                xm.add_step_closure(self.train_update,
                        args=(writer_stats, current_stats))
            else:
                self.train_update(writer_stats, current_stats)

            # if not self.is_tpu or xm.is_master_ordinal():
            # if batch_num in range(25, 50) or batch_num in range(75, 100):
            stderr.flush()

    elapsed = time.time() - start_time
def compute_n_items(cls, requested_n_items):
    ind = util.greatest_lower_bound(cls.primes, requested_n_items)
    if ind == -1:
        raise InvalidArgument
    return cls.primes[ind]
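# Hypothetical usage, assuming cls.primes is a sorted list of precomputed
# primes: compute_n_items returns the largest prime that does not exceed the
# request, and raises when the request is below the smallest prime.
#
#   primes = [2, 3, 5, 7, ..., 97, 101, ...]
#   compute_n_items(100)  ->  97
#   compute_n_items(1)    ->  raises InvalidArgument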
def main():
    if len(sys.argv) == 1 or sys.argv[1] not in ('new', 'resume'):
        print(parse_tools.top_usage, file=stderr)
        return

    mode = sys.argv[1]
    del sys.argv[1]
    if mode == 'new':
        opts = parse_tools.two_stage_parse(parse_tools.cold)
    elif mode == 'resume':
        opts = parse_tools.resume.parse_args()

    opts.device = None
    if not opts.disable_cuda and torch.cuda.is_available():
        opts.device = torch.device('cuda')
    else:
        opts.device = torch.device('cpu')

    ckpt_path = util.CheckpointPath(opts.ckpt_template)

    # Construct model
    if mode == 'new':
        # Initialize model
        pre_params = parse_tools.get_prefixed_items(vars(opts), 'pre_')
        enc_params = parse_tools.get_prefixed_items(vars(opts), 'enc_')
        bn_params = parse_tools.get_prefixed_items(vars(opts), 'bn_')
        dec_params = parse_tools.get_prefixed_items(vars(opts), 'dec_')

        # Initialize data
        sample_catalog = D.parse_sample_catalog(opts.sam_file)
        data = D.WavSlices(sample_catalog, pre_params['sample_rate'],
                opts.frac_permutation_use, opts.requested_wav_buf_sz)
        dec_params['n_speakers'] = data.num_speakers()

        # with torch.autograd.set_detect_anomaly(True):
        model = ae.AutoEncoder(pre_params, enc_params, bn_params, dec_params)
        print('Initializing model parameters', file=stderr)
        model.initialize_weights()

        # Construct overall state
        state = checkpoint.State(0, model, data)
    else:
        state = checkpoint.State()
        state.load(opts.ckpt_file)
        print('Restored model and data from {}'.format(opts.ckpt_file),
                file=stderr)

    state.model.set_geometry(opts.n_sam_per_slice)
    state.data.set_geometry(opts.n_batch, state.model.input_size,
            state.model.output_size)
    state.model.to(device=opts.device)

    # total_bytes = 0
    # for name, par in model.named_parameters():
    #     n_bytes = par.data.nelement() * par.data.element_size()
    #     total_bytes += n_bytes
    #     print(name, type(par.data), par.size(), n_bytes)
    # print('total_bytes: ', total_bytes)

    # Initialize optimizer
    model_params = state.model.parameters()
    metrics = ae.Metrics(state.model, None)
    batch_gen = state.data.batch_slice_gen_fn()
    # loss_fcn = state.model.loss_factory(state.data.batch_slice_gen_fn())

    # Start training
    print('Starting training...', file=stderr)
    print("Step\tLoss\tAvgProbTarget\tPeakDist\tAvgMax", file=stderr)
    stderr.flush()

    learning_rates = dict(zip(opts.learning_rate_steps,
            opts.learning_rate_rates))

    start_step = state.step
    if start_step not in learning_rates:
        # greatest_lower_bound returns an index into the step list, so look up
        # the most recent scheduled step at or before start_step
        ref_index = util.greatest_lower_bound(opts.learning_rate_steps,
                start_step)
        metrics.optim = torch.optim.Adam(params=model_params,
                lr=learning_rates[opts.learning_rate_steps[ref_index]])

    while state.step < opts.max_steps:
        if state.step in learning_rates:
            metrics.optim = torch.optim.Adam(params=model_params,
                    lr=learning_rates[state.step])

        # do 'pip install --upgrade scipy' if you get 'FutureWarning: ...'
        metrics.update(batch_gen)
        loss = metrics.optim.step(metrics.loss)
        avg_peak_dist = metrics.peak_dist()
        avg_max = metrics.avg_max()
        avg_prob_target = metrics.avg_prob_target()

        # Progress reporting
        if state.step % opts.progress_interval == 0:
            fmt = "{}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}"
            print(fmt.format(state.step, loss, avg_prob_target,
                    avg_peak_dist, avg_max), file=stderr)
            stderr.flush()

        # Checkpointing
        if state.step % opts.save_interval == 0 and state.step != start_step:
            ckpt_file = ckpt_path.path(state.step)
            state.save(ckpt_file)
            print('Saved checkpoint to {}'.format(ckpt_file), file=stderr)

        state.step += 1