def __init__(self, opts, dataset):
    """Collect 'dec_'- and 'mi_'-prefixed options into init_args and initialize.

    The number of speakers is taken from the dataset and injected into the
    decoder parameter group before initialization.
    """
    opt_map = vars(opts)
    # Extract both prefixed parameter groups from the flat options namespace.
    mi_params = parse_tools.get_prefixed_items(opt_map, 'mi_')
    dec_params = parse_tools.get_prefixed_items(opt_map, 'dec_')
    dec_params['n_speakers'] = dataset.num_speakers()
    self.init_args = {
        'dec_params': dec_params,
        'mi_params': mi_params,
    }
    self._initialize()
def __init__(self, opts, dataset):
    """Collect encoder/bottleneck/decoder parameter groups and initialize.

    Speaker count and mel-channel count come from the dataset; the training
    flag is forwarded from opts.
    """
    opt_map = vars(opts)
    # Pull each prefixed parameter group out of the flat options namespace.
    groups = {
        prefix: parse_tools.get_prefixed_items(opt_map, prefix + '_')
        for prefix in ('enc', 'bn', 'dec')
    }
    groups['dec']['n_speakers'] = dataset.num_speakers()
    self.init_args = {
        'enc_params': groups['enc'],
        'bn_params': groups['bn'],
        'dec_params': groups['dec'],
        'n_mel_chan': dataset.num_mel_chan(),
        'training': opts.training,
    }
    self._initialize()
def __init__(self, opts):
    """Assemble batching and MFCC front-end arguments and initialize.

    Batch/window/jitter settings come straight from opts; the signal-analysis
    settings are read from the 'pre_'-prefixed option group.
    """
    pre = parse_tools.get_prefixed_items(vars(opts), 'pre_')
    args = {
        'batch_size': opts.n_batch,
        'window_batch_size': opts.n_win_batch,
        'jitter_prob': opts.jitter_prob,
    }
    # Forward the preprocessing settings under their own names.
    for key in ('sample_rate', 'mfcc_win_sz', 'mfcc_hop_sz', 'n_mels', 'n_mfcc'):
        args[key] = pre[key]
    self.init_args = args
    self._initialize()
def main():
    """Entry point: train the autoencoder, either from scratch ('new') or
    from a saved checkpoint ('resume').

    Usage: <prog> {new|resume} [args...]; all further parsing is delegated
    to parse_tools.  Progress and diagnostics are written to stderr.
    """
    # First positional argument selects the mode; anything else prints usage.
    if len(sys.argv) == 1 or sys.argv[1] not in ('new', 'resume'):
        print(parse_tools.top_usage, file=stderr)
        return

    print('Command line: ', ' '.join(sys.argv), file=stderr)
    stderr.flush()

    mode = sys.argv[1]
    # Remove the mode token so the argparse-based parsers see only their flags.
    del sys.argv[1]
    if mode == 'new':
        cold_parser = parse_tools.cold_parser()
        opts = parse_tools.two_stage_parse(cold_parser)
    elif mode == 'resume':
        resume_parser = parse_tools.resume_parser()
        opts = resume_parser.parse_args()

    # Device selection: CUDA if available and not explicitly disabled.
    opts.device = None
    if not opts.disable_cuda and torch.cuda.is_available():
        opts.device = torch.device('cuda')
        print('Using GPU', file=stderr)
    else:
        opts.device = torch.device('cpu')
        print('Using CPU', file=stderr)
    stderr.flush()

    ckpt_path = util.CheckpointPath(opts.ckpt_template)
    # Map of step -> learning rate; steps and rates are parallel option lists.
    learning_rates = dict(
            zip(opts.learning_rate_steps, opts.learning_rate_rates))

    # Construct model
    if mode == 'new':
        # Initialize model: split the flat option namespace into the
        # per-component parameter groups by prefix.
        pre_params = parse_tools.get_prefixed_items(vars(opts), 'pre_')
        enc_params = parse_tools.get_prefixed_items(vars(opts), 'enc_')
        bn_params = parse_tools.get_prefixed_items(vars(opts), 'bn_')
        dec_params = parse_tools.get_prefixed_items(vars(opts), 'dec_')

        # Initialize data
        data_source = data.Slice(opts.index_file_prefix,
                opts.max_gpu_data_bytes, opts.n_batch)
        # The decoder needs the speaker count, known only after data loads.
        dec_params['n_speakers'] = data_source.num_speakers()
        model = ae.AutoEncoder(pre_params, enc_params, bn_params, dec_params,
                opts.n_sam_per_slice)
        # NOTE(review): learning_rates[0] assumes step 0 is always present in
        # opts.learning_rate_steps — confirm the parser guarantees this.
        optim = torch.optim.Adam(params=model.parameters(),
                lr=learning_rates[0])
        state = checkpoint.State(0, model, data_source, optim)
    else:
        # Resume: model, data, and optimizer all come from the checkpoint.
        state = checkpoint.State()
        state.load(opts.ckpt_file)
        state.model.set_slice_size(opts.n_sam_per_slice)
        print('Restored model, data, and optim from {}'.format(opts.ckpt_file),
                file=stderr)
        #print('Data state: {}'.format(state.data), file=stderr)
        #print('Model state: {}'.format(state.model.checksum()))
        #print('Optim state: {}'.format(state.optim_checksum()))
        stderr.flush()

    start_step = state.step
    print('Model input size: {}'.format(state.model.input_size), file=stderr)
    stderr.flush()

    # set this to zero if you want to print out a logging header in resume
    # mode as well
    netmisc.set_print_iter(0)
    state.data.init_geometry(state.model.preprocess.rf, state.model)
    #state.data.set_geometry(opts.n_batch, state.model.input_size,
    #        state.model.output_size)
    state.to(device=opts.device)

    # Initialize optimizer
    metrics = ae.Metrics(state)
    batch_gen = state.data.batch_slice_gen_fn()

    #for p in list(state.model.encoder.parameters()):
    #    with torch.no_grad():
    #        p *= 1

    # Start training
    print('Training parameters used:', file=stderr)
    pprint(opts, stderr)
    state.init_torch_generator()
    while state.step < opts.max_steps:
        # Apply a scheduled learning-rate change at its designated step.
        if state.step in learning_rates:
            state.update_learning_rate(learning_rates[state.step])
        # do 'pip install --upgrade scipy' if you get 'FutureWarning: ...'
        # print('in main loop')
        # For VQ-VAE bottlenecks, periodically re-seed the codebook from the
        # current encoder output distribution early in training.
        if state.step in (1, 10, 50, 100, 300, 500) and \
                state.model.bn_type == 'vqvae':
            print('Reinitializing embed with current distribution',
                    file=stderr)
            stderr.flush()
            state.model.init_vq_embed(batch_gen)
        metrics.update(batch_gen)
        # Optimizer step takes the loss closure and returns the loss value.
        loss = metrics.state.optim.step(metrics.loss)
        avg_peak_dist = metrics.peak_dist()
        avg_max = metrics.avg_max()
        avg_prob_target = metrics.avg_prob_target()

        # Gradient-statistics dump, disabled; flip the condition to debug.
        if False:
            for n, p in list(state.model.encoder.named_parameters()):
                g = p.grad
                if g is None:
                    print('{:60s}\tNone'.format(n), file=stderr)
                else:
                    fmt = '{:s}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}'
                    print(fmt.format(n, g.max(), g.min(), g.mean(), g.std()),
                            file=stderr)

        # Progress reporting
        if state.step % opts.progress_interval == 0:
            current_stats = {
                    'step': state.step,
                    'loss': loss,
                    'tprb_m': avg_prob_target,
                    'pk_d_m': avg_peak_dist
                    }
            #fmt = "M\t{:d}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}"
            #print(fmt.format(state.step, loss, avg_prob_target, avg_peak_dist,
            #    avg_max), file=stderr)
            if state.model.bn_type == 'vqvae':
                current_stats.update(state.model.objective.metrics)
            netmisc.print_metrics(current_stats, 1000000)
            stderr.flush()

        state.step += 1

        # Checkpointing: at each save interval (but not on the step we
        # resumed at), and additionally right after the first step of a
        # fresh run.
        if ((state.step % opts.save_interval == 0 and
                state.step != start_step) or
                (mode == 'new' and state.step == 1)):
            ckpt_file = ckpt_path.path(state.step)
            state.save(ckpt_file)
            print('Saved checkpoint to {}'.format(ckpt_file), file=stderr)
            #print('Optim state: {}'.format(state.optim_checksum()),
            #        file=stderr)
            stderr.flush()
def __init__(self, mode, opts):
    """Build (or restore) model, dataset, and optimizer, then bind them to
    the selected hardware (GPU via CUDA, otherwise TPU via torch_xla).

    mode: 'new' to construct everything from opts, anything else to restore
    from opts.ckpt_file.
    """
    print('Initializing model and data source...', end='', file=stderr)
    stderr.flush()
    # Map of step -> learning rate; steps and rates are parallel option lists.
    self.learning_rates = dict(zip(opts.learning_rate_steps,
        opts.learning_rate_rates))
    self.opts = opts

    if mode == 'new':
        torch.manual_seed(opts.random_seed)
        # Split the flat option namespace into per-component groups by prefix.
        pre_par = parse_tools.get_prefixed_items(vars(opts), 'pre_')
        enc_par = parse_tools.get_prefixed_items(vars(opts), 'enc_')
        bn_par = parse_tools.get_prefixed_items(vars(opts), 'bn_')
        dec_par = parse_tools.get_prefixed_items(vars(opts), 'dec_')

        # Initialize data
        # jitter_prob belongs to the dataset, not the decoder, so pop it out.
        jprob = dec_par.pop('jitter_prob')
        dataset = data.Slice(opts.n_batch, opts.n_win_batch, jprob,
                pre_par['sample_rate'], pre_par['mfcc_win_sz'],
                pre_par['mfcc_hop_sz'], pre_par['n_mels'],
                pre_par['n_mfcc'])
        dataset.load_data(opts.dat_file)
        # The decoder needs the speaker count, known only after data loads.
        dec_par['n_speakers'] = dataset.num_speakers()
        model = ae.AutoEncoder(pre_par, enc_par, bn_par, dec_par,
                dataset.num_mel_chan(), training=True)
        # Cross-wire model and dataset; order shown here is model first.
        model.post_init(dataset)
        dataset.post_init(model)
        # NOTE(review): learning_rates[0] assumes step 0 is always present in
        # opts.learning_rate_steps — confirm the parser guarantees this.
        optim = torch.optim.Adam(params=model.parameters(),
                lr=self.learning_rates[0])
        self.state = checkpoint.State(0, model, dataset, optim)
        self.start_step = self.state.step
    else:
        # Resume: model, data, and optimizer all come from the checkpoint.
        self.state = checkpoint.State()
        self.state.load(opts.ckpt_file, opts.dat_file)
        self.start_step = self.state.step
        # print('Restored model, data, and optim from {}'.format(
        #     opts.ckpt_file), file=stderr)
    #print('Data state: {}'.format(state.data), file=stderr)
    #print('Model state: {}'.format(state.model.checksum()))
    #print('Optim state: {}'.format(state.optim_checksum()))
    stderr.flush()

    self.ckpt_path = util.CheckpointPath(self.opts.ckpt_template)
    self.quant = None
    self.target = None
    self.softmax = torch.nn.Softmax(1) # input to this is (B, Q, N)

    if self.opts.hwtype == 'GPU':
        self.device = torch.device('cuda')
        self.data_loader = self.state.data_loader
        self.data_loader.set_target_device(self.device)
        # Adam.step accepts a loss closure and returns the loss value.
        self.optim_step_fn = (lambda: self.state.optim.step(self.loss_fn))
        self.data_iter = GPULoaderIter(iter(self.data_loader))
    else:
        # TPU path: import torch_xla lazily so GPU-only installs still work.
        import torch_xla.core.xla_model as xm
        import torch_xla.distributed.parallel_loader as pl
        self.device = xm.xla_device()
        self.data_loader = pl.ParallelLoader(self.state.data_loader,
                [self.device])
        self.data_iter = TPULoaderIter(self.data_loader, self.device)
        # xm.optimizer_step handles the XLA mark_step/sync around the update.
        self.optim_step_fn = (lambda : xm.optimizer_step(self.state.optim,
            optimizer_args={'closure': self.loss_fn}))

    self.state.init_torch_generator()
    print('Done.', file=stderr)
    stderr.flush()
def main():
    """Entry point for an earlier training-script variant: train the
    autoencoder from scratch ('new') or from a checkpoint ('resume').

    Usage: <prog> {new|resume} [args...]; progress is printed to stderr as a
    tab-separated table.
    """
    # First positional argument selects the mode; anything else prints usage.
    if len(sys.argv) == 1 or sys.argv[1] not in ('new', 'resume'):
        print(parse_tools.top_usage, file=stderr)
        return

    mode = sys.argv[1]
    # Remove the mode token so the argparse-based parsers see only their flags.
    del sys.argv[1]
    if mode == 'new':
        opts = parse_tools.two_stage_parse(parse_tools.cold)
    elif mode == 'resume':
        opts = parse_tools.resume.parse_args()

    # Device selection: CUDA if available and not explicitly disabled.
    opts.device = None
    if not opts.disable_cuda and torch.cuda.is_available():
        opts.device = torch.device('cuda')
    else:
        opts.device = torch.device('cpu')

    ckpt_path = util.CheckpointPath(opts.ckpt_template)

    # Construct model
    if mode == 'new':
        # Initialize model: split the flat option namespace into the
        # per-component parameter groups by prefix.
        pre_params = parse_tools.get_prefixed_items(vars(opts), 'pre_')
        enc_params = parse_tools.get_prefixed_items(vars(opts), 'enc_')
        bn_params = parse_tools.get_prefixed_items(vars(opts), 'bn_')
        dec_params = parse_tools.get_prefixed_items(vars(opts), 'dec_')

        # Initialize data
        sample_catalog = D.parse_sample_catalog(opts.sam_file)
        data = D.WavSlices(sample_catalog, pre_params['sample_rate'],
                opts.frac_permutation_use, opts.requested_wav_buf_sz)
        # The decoder needs the speaker count, known only after data loads.
        dec_params['n_speakers'] = data.num_speakers()

        #with torch.autograd.set_detect_anomaly(True):
        model = ae.AutoEncoder(pre_params, enc_params, bn_params, dec_params)
        print('Initializing model parameters', file=stderr)
        model.initialize_weights()

        # Construct overall state
        state = checkpoint.State(0, model, data)
    else:
        state = checkpoint.State()
        state.load(opts.ckpt_file)
        print('Restored model and data from {}'.format(opts.ckpt_file),
                file=stderr)

    # Geometry is (re)computed for both modes from the current slice size.
    state.model.set_geometry(opts.n_sam_per_slice)
    state.data.set_geometry(opts.n_batch, state.model.input_size,
            state.model.output_size)
    state.model.to(device=opts.device)

    #total_bytes = 0
    #for name, par in model.named_parameters():
    #    n_bytes = par.data.nelement() * par.data.element_size()
    #    total_bytes += n_bytes
    #    print(name, type(par.data), par.size(), n_bytes)
    #print('total_bytes: ', total_bytes)

    # Initialize optimizer
    model_params = state.model.parameters()
    metrics = ae.Metrics(state.model, None)
    batch_gen = state.data.batch_slice_gen_fn()
    #loss_fcn = state.model.loss_factory(state.data.batch_slice_gen_fn())

    # Start training
    print('Starting training...', file=stderr)
    print("Step\tLoss\tAvgProbTarget\tPeakDist\tAvgMax", file=stderr)
    stderr.flush()

    # Map of step -> learning rate; steps and rates are parallel option lists.
    learning_rates = dict(zip(opts.learning_rate_steps,
        opts.learning_rate_rates))
    start_step = state.step
    # When resuming between schedule boundaries, use the rate of the nearest
    # schedule step at or below start_step.
    if start_step not in learning_rates:
        ref_step = util.greatest_lower_bound(opts.learning_rate_steps,
                start_step)
        metrics.optim = torch.optim.Adam(params=model_params,
                lr=learning_rates[ref_step])

    while state.step < opts.max_steps:
        # On a schedule boundary, replace the optimizer with one at the new
        # rate (this also resets Adam's moment estimates).
        if state.step in learning_rates:
            metrics.optim = torch.optim.Adam(params=model_params,
                    lr=learning_rates[state.step])
        # do 'pip install --upgrade scipy' if you get 'FutureWarning: ...'
        metrics.update(batch_gen)
        # Optimizer step takes the loss closure and returns the loss value.
        loss = metrics.optim.step(metrics.loss)
        avg_peak_dist = metrics.peak_dist()
        avg_max = metrics.avg_max()
        avg_prob_target = metrics.avg_prob_target()

        # Progress reporting
        if state.step % opts.progress_interval == 0:
            fmt = "{}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}"
            print(fmt.format(state.step, loss, avg_prob_target,
                avg_peak_dist, avg_max), file=stderr)
            stderr.flush()

        # Checkpointing (skipped on the step we resumed at).
        if state.step % opts.save_interval == 0 and state.step != start_step:
            ckpt_file = ckpt_path.path(state.step)
            state.save(ckpt_file)
            print('Saved checkpoint to {}'.format(ckpt_file), file=stderr)

        state.step += 1