def main():
    """Command-line entry point: train a new model or resume from a checkpoint.

    Usage: <prog> (new|resume) [options...]. The subcommand is stripped from
    sys.argv before the remaining options are parsed by the sub-parser.
    """
    # Guard clause: a valid subcommand is mandatory.
    if len(sys.argv) == 1 or sys.argv[1] not in ('new', 'resume'):
        print(parse_tools.top_usage, file=stderr)
        return

    print('Command line: ', ' '.join(sys.argv), file=stderr)
    stderr.flush()

    mode = sys.argv[1]
    del sys.argv[1]  # the sub-parsers below must not see the subcommand token

    if mode == 'new':
        opts = parse_tools.two_stage_parse(parse_tools.cold_parser())
    else:  # mode == 'resume' (guaranteed by the guard above)
        opts = parse_tools.resume_parser().parse_args()

    # Validate the requested hardware before committing to training.
    hw = opts.hwtype
    if hw == 'GPU':
        if not torch.cuda.is_available():
            raise RuntimeError('GPU requested but not available')
    elif hw in ('TPU', 'TPU-single'):
        # Imported lazily so non-TPU runs don't require torch_xla.
        import torch_xla.distributed.xla_multiprocessing as xmp
    else:
        raise RuntimeError(
            'Invalid device {} requested. Must be GPU or TPU'.format(hw))

    print('Using {}'.format(hw), file=stderr)
    stderr.flush()

    # Start training
    print('Training parameters used:', file=stderr)
    pprint(opts, stderr)

    # set this to zero if you want to print out a logging header in resume
    # mode as well
    netmisc.set_print_iter(0)

    if hw == 'TPU':
        # Multiprocess TPU path: each spawned process builds its own Metrics
        # and trains as worker `index`; 'fork' keeps module state in children.
        def _worker(index, mode, opts):
            ae.Metrics(mode, opts).train(index)
        xmp.spawn(_worker, args=(mode, opts), nprocs=1, start_method='fork')
    else:
        # 'GPU' and 'TPU-single' both train in-process as worker 0.
        ae.Metrics(mode, opts).train(0)
def _mp_fn(index, mode, opts):
    """Per-process entry point for xmp.spawn: train as worker `index`."""
    ae.Metrics(mode, opts).train(index)
def main():
    """Command-line entry point: train a new model or resume from a checkpoint.

    Usage: <prog> (new|resume) [options...]. Parses options, builds or
    restores a checkpoint.State (model + data + optimizer), then runs the
    training loop with periodic progress reporting and checkpointing.
    """
    # Guard clause: a valid subcommand is mandatory.
    if len(sys.argv) == 1 or sys.argv[1] not in ('new', 'resume'):
        print(parse_tools.top_usage, file=stderr)
        return

    print('Command line: ', ' '.join(sys.argv), file=stderr)
    stderr.flush()

    mode = sys.argv[1]
    # Strip the subcommand so the sub-parsers below see only their own options.
    del sys.argv[1]
    if mode == 'new':
        cold_parser = parse_tools.cold_parser()
        opts = parse_tools.two_stage_parse(cold_parser)
    elif mode == 'resume':
        resume_parser = parse_tools.resume_parser()
        opts = resume_parser.parse_args()

    # Device selection: CUDA unless disabled or unavailable, else CPU.
    opts.device = None
    if not opts.disable_cuda and torch.cuda.is_available():
        opts.device = torch.device('cuda')
        print('Using GPU', file=stderr)
    else:
        opts.device = torch.device('cpu')
        print('Using CPU', file=stderr)
    stderr.flush()

    ckpt_path = util.CheckpointPath(opts.ckpt_template)
    # Map step -> learning rate; consulted each iteration to apply LR schedule.
    learning_rates = dict(
            zip(opts.learning_rate_steps, opts.learning_rate_rates))

    # Construct model
    if mode == 'new':
        # Initialize model
        # Options are grouped by prefix into per-submodule parameter dicts.
        pre_params = parse_tools.get_prefixed_items(vars(opts), 'pre_')
        enc_params = parse_tools.get_prefixed_items(vars(opts), 'enc_')
        bn_params = parse_tools.get_prefixed_items(vars(opts), 'bn_')
        dec_params = parse_tools.get_prefixed_items(vars(opts), 'dec_')

        # Initialize data
        data_source = data.Slice(opts.index_file_prefix,
                opts.max_gpu_data_bytes, opts.n_batch)
        # Decoder needs the speaker count, known only after loading the data.
        dec_params['n_speakers'] = data_source.num_speakers()
        model = ae.AutoEncoder(pre_params, enc_params, bn_params, dec_params,
                opts.n_sam_per_slice)
        optim = torch.optim.Adam(params=model.parameters(),
                lr=learning_rates[0])
        # Fresh state starts at step 0.
        state = checkpoint.State(0, model, data_source, optim)
    else:
        # Resume: restore model, data and optimizer state from the checkpoint.
        state = checkpoint.State()
        state.load(opts.ckpt_file)
        state.model.set_slice_size(opts.n_sam_per_slice)
        print('Restored model, data, and optim from {}'.format(opts.ckpt_file),
                file=stderr)
        #print('Data state: {}'.format(state.data), file=stderr)
        #print('Model state: {}'.format(state.model.checksum()))
        #print('Optim state: {}'.format(state.optim_checksum()))
        stderr.flush()

    # Remember where we started so the checkpoint logic can skip re-saving
    # the step we just restored from.
    start_step = state.step
    print('Model input size: {}'.format(state.model.input_size), file=stderr)
    stderr.flush()

    # set this to zero if you want to print out a logging header in resume
    # mode as well
    netmisc.set_print_iter(0)

    state.data.init_geometry(state.model.preprocess.rf, state.model)
    #state.data.set_geometry(opts.n_batch, state.model.input_size,
    #        state.model.output_size)
    state.to(device=opts.device)

    # Initialize optimizer
    metrics = ae.Metrics(state)
    batch_gen = state.data.batch_slice_gen_fn()
    #for p in list(state.model.encoder.parameters()):
    #    with torch.no_grad():
    #        p *= 1

    # Start training
    print('Training parameters used:', file=stderr)
    pprint(opts, stderr)

    state.init_torch_generator()
    while state.step < opts.max_steps:
        # Apply a scheduled learning-rate change when the step hits a boundary.
        if state.step in learning_rates:
            state.update_learning_rate(learning_rates[state.step])
        # do 'pip install --upgrade scipy' if you get 'FutureWarning: ...'
        # print('in main loop')
        # At a few fixed early steps, re-seed the VQ-VAE codebook from the
        # current encoder output distribution (VQ-VAE bottleneck only).
        if state.step in (1, 10, 50, 100, 300, 500) \
                and state.model.bn_type == 'vqvae':
            print('Reinitializing embed with current distribution',
                    file=stderr)
            stderr.flush()
            state.model.init_vq_embed(batch_gen)
        # One training step: compute metrics/loss on a batch, then optimize.
        metrics.update(batch_gen)
        loss = metrics.state.optim.step(metrics.loss)
        avg_peak_dist = metrics.peak_dist()
        avg_max = metrics.avg_max()
        avg_prob_target = metrics.avg_prob_target()
        # Disabled debug block: dump per-parameter gradient statistics.
        if False:
            for n, p in list(state.model.encoder.named_parameters()):
                g = p.grad
                if g is None:
                    print('{:60s}\tNone'.format(n), file=stderr)
                else:
                    fmt = '{:s}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}'
                    print(fmt.format(n, g.max(), g.min(), g.mean(), g.std()),
                            file=stderr)

        # Progress reporting
        if state.step % opts.progress_interval == 0:
            current_stats = {
                    'step': state.step,
                    'loss': loss,
                    'tprb_m': avg_prob_target,
                    'pk_d_m': avg_peak_dist
                    }
            #fmt = "M\t{:d}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}"
            #print(fmt.format(state.step, loss, avg_prob_target, avg_peak_dist,
            #    avg_max), file=stderr)
            if state.model.bn_type == 'vqvae':
                current_stats.update(state.model.objective.metrics)
            netmisc.print_metrics(current_stats, 1000000)
            stderr.flush()

        # NOTE: step is incremented BEFORE the checkpoint test below, so a
        # save at step S reflects S completed steps.
        state.step += 1

        # Checkpointing
        # Save on the save_interval (but not the step we resumed from), and
        # additionally after the very first step of a fresh run.
        if ((state.step % opts.save_interval == 0 and
                state.step != start_step) or
                (mode == 'new' and state.step == 1)):
            ckpt_file = ckpt_path.path(state.step)
            state.save(ckpt_file)
            print('Saved checkpoint to {}'.format(ckpt_file), file=stderr)
            #print('Optim state: {}'.format(state.optim_checksum()),
            #        file=stderr)
            stderr.flush()
def main():
    """Command-line entry point: train a new model or resume from a checkpoint.

    Usage: <prog> (new|resume) [options...]. Builds or restores a
    checkpoint.State (model + data), configures geometry and device, then
    runs the training loop with progress reporting and checkpointing.
    """
    # Guard clause: a valid subcommand is mandatory.
    if len(sys.argv) == 1 or sys.argv[1] not in ('new', 'resume'):
        print(parse_tools.top_usage, file=stderr)
        return

    mode = sys.argv[1]
    # Strip the subcommand so the sub-parsers below see only their own options.
    del sys.argv[1]
    if mode == 'new':
        opts = parse_tools.two_stage_parse(parse_tools.cold)
    elif mode == 'resume':
        opts = parse_tools.resume.parse_args()

    # Device selection: CUDA unless disabled or unavailable, else CPU.
    opts.device = None
    if not opts.disable_cuda and torch.cuda.is_available():
        opts.device = torch.device('cuda')
    else:
        opts.device = torch.device('cpu')

    ckpt_path = util.CheckpointPath(opts.ckpt_template)

    # Construct model
    if mode == 'new':
        # Initialize model
        # Options are grouped by prefix into per-submodule parameter dicts.
        pre_params = parse_tools.get_prefixed_items(vars(opts), 'pre_')
        enc_params = parse_tools.get_prefixed_items(vars(opts), 'enc_')
        bn_params = parse_tools.get_prefixed_items(vars(opts), 'bn_')
        dec_params = parse_tools.get_prefixed_items(vars(opts), 'dec_')

        # Initialize data
        sample_catalog = D.parse_sample_catalog(opts.sam_file)
        data = D.WavSlices(sample_catalog, pre_params['sample_rate'],
                opts.frac_permutation_use, opts.requested_wav_buf_sz)
        # Decoder needs the speaker count, known only after loading the data.
        dec_params['n_speakers'] = data.num_speakers()

        #with torch.autograd.set_detect_anomaly(True):
        model = ae.AutoEncoder(pre_params, enc_params, bn_params, dec_params)
        print('Initializing model parameters', file=stderr)
        model.initialize_weights()

        # Construct overall state
        state = checkpoint.State(0, model, data)
    else:
        # Resume: restore model and data state from the checkpoint file.
        state = checkpoint.State()
        state.load(opts.ckpt_file)
        print('Restored model and data from {}'.format(opts.ckpt_file),
                file=stderr)

    # Geometry must be set on both model and data regardless of mode.
    state.model.set_geometry(opts.n_sam_per_slice)
    state.data.set_geometry(opts.n_batch, state.model.input_size,
            state.model.output_size)
    state.model.to(device=opts.device)

    #total_bytes = 0
    #for name, par in model.named_parameters():
    #    n_bytes = par.data.nelement() * par.data.element_size()
    #    total_bytes += n_bytes
    #    print(name, type(par.data), par.size(), n_bytes)
    #print('total_bytes: ', total_bytes)

    # Initialize optimizer
    model_params = state.model.parameters()
    metrics = ae.Metrics(state.model, None)
    batch_gen = state.data.batch_slice_gen_fn()
    #loss_fcn = state.model.loss_factory(state.data.batch_slice_gen_fn())

    # Start training
    print('Starting training...', file=stderr)
    print("Step\tLoss\tAvgProbTarget\tPeakDist\tAvgMax", file=stderr)
    stderr.flush()

    # Map step -> learning rate; a fresh Adam optimizer is created at each
    # scheduled LR boundary (NOTE(review): this also resets Adam's moment
    # estimates — presumably intentional; confirm).
    learning_rates = dict(zip(opts.learning_rate_steps,
        opts.learning_rate_rates))
    start_step = state.step
    if start_step not in learning_rates:
        # Resuming between LR boundaries: back off to the most recent
        # scheduled step at or before start_step to pick the current LR.
        ref_step = util.greatest_lower_bound(opts.learning_rate_steps,
                start_step)
        metrics.optim = torch.optim.Adam(params=model_params,
                lr=learning_rates[ref_step])

    while state.step < opts.max_steps:
        if state.step in learning_rates:
            metrics.optim = torch.optim.Adam(params=model_params,
                    lr=learning_rates[state.step])
        # do 'pip install --upgrade scipy' if you get 'FutureWarning: ...'
        # One training step: compute metrics/loss on a batch, then optimize.
        metrics.update(batch_gen)
        loss = metrics.optim.step(metrics.loss)
        avg_peak_dist = metrics.peak_dist()
        avg_max = metrics.avg_max()
        avg_prob_target = metrics.avg_prob_target()

        # Progress reporting
        if state.step % opts.progress_interval == 0:
            fmt = "{}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}"
            print(fmt.format(state.step, loss, avg_prob_target, avg_peak_dist,
                avg_max), file=stderr)
            stderr.flush()

        # Checkpointing
        # NOTE: unlike the step counter below, this test runs BEFORE the
        # increment, and skips the step we started from.
        if state.step % opts.save_interval == 0 and state.step != start_step:
            ckpt_file = ckpt_path.path(state.step)
            state.save(ckpt_file)
            print('Saved checkpoint to {}'.format(ckpt_file), file=stderr)

        state.step += 1