def main(args, logger):
    logger.debug('load the model')
    with open(args.model, 'rb') as f:
        model = pickle.load(f)

    logger.debug('load the dataset')
    with open(args.dataset, 'rb') as f:
        dataset = pickle.load(f)

    elbo = beer.evidence_lower_bound(datasize=dataset.size)
    count = 0
    for line in sys.stdin:
        utt = dataset[line.strip().split()[0]]
        logger.debug(f'processing utterance: {utt.id}')
        elbo += beer.evidence_lower_bound(model, utt.features,
                                          datasize=dataset.size)
        count += 1

    logger.debug('saving the accumulated ELBO...')
    with open(args.out, 'wb') as f:
        pickle.dump((elbo, count), f)

    logger.info(f'accumulated ELBO over {count} utterances: '
                f'{float(elbo) / (count * dataset.size):.3f}.')
def main(args, logger):
    logger.debug('load the model')
    with open(args.model, 'rb') as f:
        model = pickle.load(f)

    logger.debug('load the dataset')
    with open(args.dataset, 'rb') as f:
        dataset = pickle.load(f)

    alis = None
    if args.alis:
        logger.debug('loading alignment graphs')
        alis = np.load(args.alis, allow_pickle=True)

    elbo = beer.evidence_lower_bound(datasize=dataset.size)
    count = 0

    if args.uttid == '-':
        infile = sys.stdin
    else:
        with open(args.uttid, 'r') as f:
            infile = f.readlines()

    for line in infile:
        uttid = line.strip().split()[0]
        try:
            utt = dataset[uttid]
        except KeyError:
            logger.warning(f'no utterance "{uttid}" in {args.dataset}')
            continue

        aligraph = None
        if alis is not None:
            try:
                aligraph = alis[uttid][0]
            except KeyError:
                logger.warning(f'no alignment graph for utterance "{uttid}"')

        logger.debug(f'processing utterance: {utt.id}')
        elbo += beer.evidence_lower_bound(model, utt.features,
                                          inference_graph=aligraph,
                                          datasize=dataset.size,
                                          scale=args.acoustic_scale)
        logger.debug(f'ELBO: {float(elbo)}')
        count += 1

    logger.debug('saving the accumulated ELBO...')
    with open(args.out, 'wb') as f:
        pickle.dump((elbo, count), f)

    logger.info(f'accumulated ELBO over {count} utterances: '
                f'{float(elbo) / (count * dataset.size):.3f}.')
def test_sum(self):
    for i, model in enumerate(self.models):
        with self.subTest(model=self.conf_files[i]):
            optim = beer.BayesianModelCoordinateAscentOptimizer(
                model.mean_field_groups, lrate=1.)
            previous = -float('inf')
            for _ in range(N_EPOCHS):
                self.seed(1)
                optim.zero_grad()
                elbo = beer.evidence_lower_bound(datasize=len(self.data))
                for _ in range(N_ITER):
                    elbo += beer.evidence_lower_bound(model, self.data)
                elbo.natural_backward()
                optim.step()
                elbo_val = round(
                    float(elbo) / (len(self.data) * self.dim), 3)
                self.assertGreaterEqual(elbo_val - previous, -TOLERANCE)
                previous = elbo_val
def main(args, logger):
    logger.debug('load the model')
    with open(args.model, 'rb') as f:
        model = pickle.load(f)

    logger.debug('load the dataset')
    with open(args.dataset, 'rb') as f:
        dataset = pickle.load(f)

    logger.debug('create the optimizer')
    optim = beer.BayesianModelOptimizer(model.mean_field_factorization(),
                                        lrate=args.lrate)

    for epoch in range(1, args.epochs + 1):
        elbo = beer.evidence_lower_bound(datasize=dataset.size)
        optim.init_step()
        for i, utt in enumerate(dataset.utterances(), start=1):
            logger.debug(f'processing utterance: {utt.id}')
            elbo += beer.evidence_lower_bound(model, utt.features,
                                              datasize=dataset.size)

            # Update the model after N utterances.
            if i % args.batch_size == 0:
                elbo.backward()
                optim.step()
                logger.info(f'{"epoch=" + str(epoch):<20} '
                            f'{"batch=" + str(i // args.batch_size) + "/" + str(int(len(dataset) / args.batch_size)):<20} '
                            f'{"ELBO=" + str(round(float(elbo) / (args.batch_size * dataset.size), 3)):<20}')
                elbo = beer.evidence_lower_bound(datasize=dataset.size)
                optim.init_step()

    logger.debug('save the model on disk...')
    with open(args.out, 'wb') as f:
        pickle.dump(model, f)

    logger.info(f'finished training after {args.epochs} epochs. '
                f'KL(q || p) = {float(model.kl_div_posterior_prior()): .3f}')
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('lm', help='unigram language model to train')
    parser.add_argument('data', help='data')
    parser.add_argument('outlm', help='output model')
    args = parser.parse_args()

    # Load the model.
    with open(args.lm, 'rb') as fh:
        model = pickle.load(fh)

    # Load the data for the training.
    data = np.load(args.data)

    # Count the total number of tokens in the training data.
    tot_counts = 0
    for utt in data:
        tot_counts += len(data[utt])

    # Prepare the optimizer for the training.
    params = model.mean_field_factorization()
    optimizer = beer.BayesianModelCoordinateAscentOptimizer(params, lrate=1.)
    optimizer.zero_grad()

    # Initialize the objective function.
    elbo = beer.evidence_lower_bound(datasize=tot_counts)

    # Re-estimate the LM.
    for utt in data:
        ft = torch.from_numpy(data[utt])
        elbo += beer.evidence_lower_bound(model, ft, datasize=tot_counts)
    elbo.backward()
    optimizer.step()

    # Save the model.
    with open(args.outlm, 'wb') as fh:
        pickle.dump(model, fh)
def test_type_switch_double(self):
    for i, orig_model in enumerate(self.models):
        model = orig_model.double()
        with self.subTest(model=self.conf_files[i]):
            optim = beer.BayesianModelCoordinateAscentOptimizer(
                model.mean_field_groups, lrate=1.)
            previous = -float('inf')
            for _ in range(N_ITER):
                self.seed(1)
                optim.zero_grad()
                elbo = beer.evidence_lower_bound(model, self.data.double())
                elbo.natural_backward()
                optim.step()
                elbo = round(float(elbo) / (len(self.data) * self.dim), 3)
                self.assertGreaterEqual(elbo - previous, -TOLERANCE)
                previous = elbo
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch-size', type=int, default=-1,
                        help='utterance number in each batch')
    parser.add_argument('--epochs', type=int, default=1,
                        help='number of epochs to train')
    parser.add_argument('--fast-eval', action='store_true')
    parser.add_argument('--kl-weight', type=float, default=1.,
                        help='weighting of KL div. of the ELBO')
    parser.add_argument('--lrate-nnet', type=float, default=1e-3,
                        help='learning rate for the nnet components')
    parser.add_argument('--lrate', type=float, default=1.,
                        help='learning rate')
    parser.add_argument('--nnet-optim-state',
                        help='file where to load/save state of the nnet '
                             'optimizer')
    parser.add_argument('--use-gpu', action='store_true')
    parser.add_argument('--verbose', action='store_true')
    parser.add_argument('model', help='model to train')
    parser.add_argument('alis', help='alignments')
    parser.add_argument('feats', help='Feature file')
    parser.add_argument('feat_stats', help='data statistics')
    parser.add_argument('out', help='output model')
    args = parser.parse_args()

    if args.verbose:
        logging.getLogger().setLevel(logging.DEBUG)

    # Load the data.
    alis = np.load(args.alis)
    feats = np.load(args.feats)
    stats = np.load(args.feat_stats)

    # Load the model and move it to the chosen device (CPU/GPU).
    with open(args.model, 'rb') as fh:
        model = pickle.load(fh)
    if args.use_gpu:
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    model = model.to(device)

    # NNET optimizer.
    nnet_optim = torch.optim.Adam(model.modules_parameters(),
                                  lr=args.lrate_nnet,
                                  eps=1e-3,
                                  amsgrad=False,
                                  weight_decay=1e-2)
    if args.nnet_optim_state and os.path.isfile(args.nnet_optim_state):
        logging.debug('load nnet optimizer state: {}'.format(
            args.nnet_optim_state))
        optim_state = torch.load(args.nnet_optim_state)
        nnet_optim.load_state_dict(optim_state)

    # Prepare the optimizer for the training.
    params = model.mean_field_factorization()
    optimizer = beer.BayesianModelCoordinateAscentOptimizer(
        params, lrate=args.lrate, std_optim=nnet_optim)

    # If no batch_size is specified, use the whole data.
    batch_size = len(feats.files)
    if args.batch_size > 0:
        batch_size = args.batch_size

    tot_counts = int(stats['nframes'])

    for epoch in range(1, args.epochs + 1):
        # Shuffle the order of the utterances.
        keys = list(feats.keys())
        random.shuffle(keys)
        batches = [keys[i:i + batch_size]
                   for i in range(0, len(keys), batch_size)]
        logging.debug('Data shuffled into {} batches'.format(len(batches)))

        for batch_no, batch_keys in enumerate(batches, start=1):
            # Reset the gradients.
            optimizer.zero_grad()

            # Load the batch data.
            ft, labels = load_batch(feats, alis, batch_keys)
            ft, labels = ft.to(device), labels.to(device)

            # Compute the objective function.
            elbo = beer.evidence_lower_bound(model, ft,
                                             state_path=labels,
                                             kl_weight=args.kl_weight,
                                             datasize=tot_counts,
                                             fast_eval=args.fast_eval)

            # Compute the gradient of the model.
            #elbo.natural_backward()
            elbo.backward()

            # Clip the gradient to avoid explosion.
            torch.nn.utils.clip_grad_norm_(model.modules_parameters(), 100.0)

            # Update the parameters.
            optimizer.step()

            elbo_value = float(elbo) / tot_counts
            log_msg = 'epoch={}/{} batch={}/{} elbo={}'
            logging.info(log_msg.format(epoch, args.epochs, batch_no,
                                        len(batches), round(elbo_value, 3)))

    if args.nnet_optim_state:
        torch.save(nnet_optim.state_dict(), args.nnet_optim_state)

    with open(args.out, 'wb') as fh:
        pickle.dump(model.to(torch.device('cpu')), fh)
def run():
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('--pretrain-epochs', type=int, default=1,
                        help='number of pre-training epochs')
    parser.add_argument('--epochs', type=int, default=1,
                        help='number of epochs')
    parser.add_argument('--lrate', type=float, default=1e-1,
                        help='learning rate')
    parser.add_argument('--nnet-lrate', type=float, default=1e-3,
                        help='learning rate for the nnet parameters')
    parser.add_argument('--update-prior', action='store_true',
                        help='update the prior')
    parser.add_argument('--weight-decay', type=float, default=1e-2,
                        help='weight decay for the nnet parameters')
    parser.add_argument('init_model', help='initial model')
    parser.add_argument('train_data_dir', help='training data directory')
    parser.add_argument('test_data_dir', help='test data directory')
    parser.add_argument('out_model', help='output model')
    args = parser.parse_args()

    # Retrieve the list of batches.
    batches = list(glob.glob(os.path.join(args.train_data_dir, 'batch*npz')))

    # Compute the total number of data points in the database.
    tot_counts = 0
    for batch in batches:
        tot_counts += np.load(batch)['features'].shape[0]

    # Load the model.
    with open(args.init_model, 'rb') as f:
        model = pickle.load(f)

    # Prepare the optimizer.
    std_optimizer = torch.optim.Adam(model.modules_parameters(),
                                     lr=args.nnet_lrate,
                                     weight_decay=args.weight_decay)
    if args.update_prior:
        params = model.mean_field_factorization()
    else:
        params = model.normal.mean_field_factorization()
    optimizer = beer.BayesianModelOptimizer(params, lrate=args.lrate,
                                            std_optim=std_optimizer)

    # To monitor the convergence.
    elbos = []
    klds = []
    log_preds = []

    def callback(model, epoch, elbo_value):
        kld = compute_kl_div(model, args.test_data_dir)
        elbos.append(elbo_value)
        klds.append(kld)
        l_pred = log_pred(model, args.test_data_dir)
        log_preds.append(l_pred)
        print(f'epoch={epoch}/{args.epochs} ln p(X) >= {elbo_value:.2f} '
              f'D(q || p) = {kld:.2f} (nats) '
              f'ln p(X_test|X_train) = {l_pred:.2f}')

    epoch = 0
    while epoch < args.epochs:
        # Randomize the order of the batches.
        batch_list = list(batches)
        random.shuffle(batch_list)

        elbo_value = 0.
        if epoch + 1 <= args.pretrain_epochs:
            kwargs = {'kl_weight': 0.}
        else:
            kwargs = {'kl_weight': 1.}

        for batch in batch_list:
            X = torch.from_numpy(np.load(batch)['features']).float()
            optimizer.init_step()
            elbo = beer.evidence_lower_bound(model, X, datasize=tot_counts,
                                             **kwargs)
            elbo.backward()
            optimizer.step()
            elbo_value += float(elbo) / tot_counts
            del X, elbo
            gc.collect()

        epoch += 1

        if callback is not None:
            callback(model, epoch, elbo_value / len(batches))

    with open(args.out_model, 'wb') as f:
        pickle.dump(model, f)
def main(args, logger):
    if args.gpu:
        gpu_idx = beer.utils.reserve_gpu(logger=logger)

    logger.debug('loading the GSM')
    with open(args.gsm, 'rb') as f:
        gsm = pickle.load(f)
    root_gsm = gsm['root']
    gsms_dict = gsm['langs']
    for _gsm in gsms_dict.values():
        _gsm.transform.root_transform = root_gsm.shared_transform

    logger.debug('loading the units posterior')
    with open(args.posts, 'rb') as f:
        lp_data = pickle.load(f)
    posts_dict, units_dict, nstates, groupidx = lp_data
    labels = None

    logger.debug('loading the subspace phoneloop')
    phoneloops_dict = {}
    for line in open(args.sploop_to_lang):
        _sploop, lang = line.strip().split()
        with open(_sploop, 'rb') as f:
            phoneloops_dict[lang] = pickle.load(f)

    if args.gpu:
        logger.info(f'using gpu device: {gpu_idx}')
        for lang in phoneloops_dict:
            phoneloops_dict[lang] = phoneloops_dict[lang].cuda()
        root_gsm = root_gsm.cuda()
        for lang in gsms_dict.keys():
            gsms_dict[lang] = gsms_dict[lang].cuda()
            posts_dict[lang] = posts_dict[lang].cuda()

    logger.debug('loading the units')
    units_emissions_dict = {}
    for lang, sploop in phoneloops_dict.items():
        units_emissions_dict[lang] = \
            sploop.modelset.original_modelset.modelsets[groupidx]
        if len(units_emissions_dict[lang]) // nstates < 2:
            units_emissions_dict[lang] = \
                sploop.modelset.original_modelset.modelsets[1 - groupidx]
        units_dict[lang] = [unit for unit in
                            iterate_units(units_emissions_dict[lang],
                                          len(units_dict[lang]), nstates)]

    logger.debug('building the optimizer')
    if args.skip_root_subspace:
        params = [[]]
    else:
        params = sum([list(_gsm.conjugate_bayesian_parameters(keepgroups=True))
                      for _gsm in gsms_dict.values()], [])
    cjg_optim = beer.VBConjugateOptimizer(params, lrate=args.learning_rate_cjg)

    params = list(root_gsm.parameters())
    for lang, _gsm in gsms_dict.items():
        params += list(_gsm.transform.latent_posterior.parameters())
    all_latent_posts = []
    for lang, latent_posts in posts_dict.items():
        all_latent_posts.append(latent_posts)
        params += list(latent_posts.parameters())
    if not args.use_sgd:
        std_optim = torch.optim.Adam(params, lr=args.learning_rate_std)
    else:
        std_optim = torch.optim.SGD(params, lr=args.learning_rate_std)
    optim = beer.VBOptimizer(cjg_optim, std_optim)

    if args.optim_state and os.path.isfile(args.optim_state):
        logger.debug(f'loading optimizer state from: {args.optim_state}')
        if args.gpu:
            maplocation = 'cuda'
        else:
            maplocation = 'cpu'
        state = torch.load(args.optim_state, maplocation)
        optim.load_state_dict(state)

    # Listify all the dictionaries for training.
    models_and_submodels = []
    for lang, _gsm in gsms_dict.items():
        lang_units = [unit for i, unit in enumerate(units_dict[lang])]
        logger.debug(f'{lang}: {len(lang_units)}')
        models_and_submodels.append([_gsm, lang_units])

    kwargs = {
        'univ_latent_nsamples': args.lang_latent_nsamples,
        'latent_posts': all_latent_posts,
        'latent_nsamples': args.unit_latent_nsamples,
        'params_nsamples': args.params_nsamples,
    }

    for epoch in range(1, args.epochs + 1):
        optim.init_step()
        elbo = beer.evidence_lower_bound(root_gsm, models_and_submodels,
                                         **kwargs)
        elbo.backward()

        if args.skip_root_subspace:
            root_gsm.zero_grad()
        if args.skip_language_posterior:
            for lang, _gsm in gsms_dict.items():
                _gsm.transform.latent_posterior.zero_grad()
        if args.skip_unit_posterior:
            for lang, latent_posts in posts_dict.items():
                latent_posts.zero_grad()

        if args.clip_grad > 0:
            for grp in optim.std_optim.param_groups:
                norm = torch.nn.utils.clip_grad_norm_(grp['params'],
                                                      args.clip_grad)
        else:
            norm = 0

        optim.step()

        if args.logging_rate > 0 and epoch % args.logging_rate == 0:
            logger.info(f'epoch={epoch:<20} elbo={float(elbo):<20}')
            logger.debug(f'epoch={epoch:<20} norm={float(norm):<20}')

    logger.info(f'finished training at epoch={epoch} with elbo={float(elbo)}')

    if args.gpu:
        for lang in phoneloops_dict:
            phoneloops_dict[lang] = phoneloops_dict[lang].cpu()
        root_gsm = root_gsm.cpu()
        for lang in gsms_dict.keys():
            gsms_dict[lang] = gsms_dict[lang].cpu()
            posts_dict[lang] = posts_dict[lang].cpu()

    logger.debug('saving the HGSM')
    gsm['root'] = root_gsm
    gsm['langs'] = gsms_dict
    with open(args.out_gsm, 'wb') as f:
        pickle.dump(gsm, f)

    logger.debug('saving the units posterior')
    with open(args.out_posts, 'wb') as f:
        pickle.dump((posts_dict, units_dict, nstates, groupidx), f)

    logger.debug('saving the subspace phoneloop')
    with open(args.out_sploop, 'wb') as f:
        pickle.dump(phoneloops_dict, f)
    for lang, sploop in phoneloops_dict.items():
        with open(args.out_sploop + '_' + lang, 'wb') as f:
            pickle.dump(sploop, f)

    if args.optim_state:
        logger.debug(f'saving the optimizer state to: {args.optim_state}')
        torch.save(optim.state_dict(), args.optim_state)
print(hmm_full)

# %%
epochs = 2
lrate = 1.
X = torch.from_numpy(data).cuda()
optim = beer.VBConjugateOptimizer(hmm_full.mean_field_factorization(), lrate)
elbos = []
for epoch in range(epochs):
    optim.init_step()
    elbo = beer.evidence_lower_bound(hmm_full, X, datasize=len(X),
                                     viterbi=False)
    elbo.backward()
    elbos.append(float(elbo) / len(X))
    optim.step()

# %%
fig = figure()
fig.line(range(len(elbos)), elbos)
show(fig)

# %%
fig = figure(title='hmm_full', width=250, height=250)
fig.circle(data[:, 0], data[:, 1], alpha=.1)
plotting.plot_hmm(fig, hmm_full.cpu(),
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch-size', type=int, default=-1,
                        help='utterance number in each batch')
    parser.add_argument('--epochs', type=int, default=1,
                        help='number of epochs')
    parser.add_argument('--fast-eval', action='store_true')
    parser.add_argument('--lrate', type=float, default=1.,
                        help='learning rate')
    parser.add_argument('--use-gpu', action='store_true')
    parser.add_argument('--verbose', action='store_true')
    parser.add_argument('hmm', help='hmm model to train')
    parser.add_argument('alis', help='alignments')
    parser.add_argument('feats', help='Feature file')
    parser.add_argument('feat_stats', help='data statistics')
    parser.add_argument('out', help='output model')
    args = parser.parse_args()

    if args.verbose:
        logging.getLogger().setLevel(logging.DEBUG)

    # Load the data.
    alis = np.load(args.alis)
    feats = np.load(args.feats)
    stats = np.load(args.feat_stats)

    # Load the model and move it to the chosen device (CPU/GPU).
    with open(args.hmm, 'rb') as fh:
        model = pickle.load(fh)
    if args.use_gpu:
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    model = model.to(device)

    # Prepare the optimizer for the training.
    params = model.mean_field_groups
    optimizer = beer.BayesianModelCoordinateAscentOptimizer(params,
                                                            lrate=args.lrate)

    # If no batch_size is specified, use the whole data.
    batch_size = len(feats.files)
    if args.batch_size > 0:
        batch_size = args.batch_size

    tot_counts = int(stats['nframes'])

    for epoch in range(1, args.epochs + 1):
        # Shuffle the order of the utterances.
        keys = list(feats.keys())
        random.shuffle(keys)
        batches = [keys[i:i + batch_size]
                   for i in range(0, len(keys), batch_size)]
        logging.debug('Data shuffled into {} batches'.format(len(batches)))

        for batch_no, batch_keys in enumerate(batches, start=1):
            # Reset the gradients.
            optimizer.zero_grad()

            elbo = beer.evidence_lower_bound(datasize=tot_counts)
            for uttid in batch_keys:
                # Load the batch data.
                ft = torch.from_numpy(feats[uttid]).float()
                ali = torch.from_numpy(alis[uttid]).long()
                ft, ali = ft.to(device), ali.to(device)

                # Compute the objective function.
                elbo += beer.evidence_lower_bound(model, ft, state_path=ali,
                                                  datasize=tot_counts,
                                                  fast_eval=args.fast_eval)

            # Compute the gradient of the model.
            elbo.natural_backward()

            # Update the parameters.
            optimizer.step()

            elbo_value = float(elbo) / tot_counts
            log_msg = 'epoch={}/{} batch={}/{} elbo={}'
            logging.info(log_msg.format(epoch, args.epochs, batch_no,
                                        len(batches), round(elbo_value, 3)))

            del ft, ali

    with open(args.out, 'wb') as fh:
        pickle.dump(model.to(torch.device('cpu')), fh)
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--alignments', help='utterance alignments')
    parser.add_argument('--batch-size', type=int,
                        help='utterance number in each batch')
    parser.add_argument('--epochs', type=int)
    parser.add_argument('--fast-eval', action='store_true')
    parser.add_argument('--infer-type', default='viterbi',
                        choices=['baum_welch', 'viterbi'],
                        help='how to compute the state posteriors')
    parser.add_argument('--lrate', type=float, help='learning rate')
    parser.add_argument('--tmpdir', help='directory to store intermediary '
                                         'models')
    parser.add_argument('--use-gpu', action='store_true')
    parser.add_argument('hmm', help='hmm model to train')
    parser.add_argument('feats', help='Feature file')
    parser.add_argument('feat_stats', help='data statistics')
    parser.add_argument('out', help='output model')
    args = parser.parse_args()

    # Load the data for the training.
    feats = np.load(args.feats)
    ali = None
    if args.alignments:
        ali = np.load(args.alignments)
    stats = np.load(args.feat_stats)

    with open(args.hmm, 'rb') as fh:
        model = pickle.load(fh)
    if args.use_gpu:
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    model = model.to(device)

    # Prepare the optimizer for the training.
    params = model.mean_field_groups
    optimizer = beer.BayesianModelCoordinateAscentOptimizer(params,
                                                            lrate=args.lrate)

    tot_counts = int(stats['nframes'])

    for epoch in range(1, args.epochs + 1):
        # Shuffle the order of the utterances.
        keys = list(feats.keys())
        random.shuffle(keys)
        batches = [keys[i:i + args.batch_size]
                   for i in range(0, len(keys), args.batch_size)]
        logging.debug('Data shuffled into {} batches'.format(len(batches)))

        # One mini-batch update.
        for batch_no, batch_keys in enumerate(batches, start=1):
            # Reset the gradients.
            optimizer.zero_grad()

            # Initialize the ELBO.
            elbo = beer.evidence_lower_bound(datasize=tot_counts)
            for utt in batch_keys:
                ft = torch.from_numpy(feats[utt]).float().to(device)

                # Get the alignment graph if provided.
                graph = None
                if ali is not None:
                    graph = ali[utt][0].to(device)

                elbo += beer.evidence_lower_bound(
                    model, ft, datasize=tot_counts,
                    fast_eval=args.fast_eval,
                    inference_graph=graph,
                    inference_type=args.infer_type)

            # Compute the gradient of the model.
            elbo.natural_backward()

            # Update the parameters.
            optimizer.step()

            elbo_value = float(elbo) / (tot_counts * len(batch_keys))
            log_msg = 'epoch={}/{} batch={}/{} ELBO={}'
            logging.info(log_msg.format(epoch, args.epochs, batch_no,
                                        len(batches), round(elbo_value, 3)))

        if args.tmpdir:
            path = os.path.join(args.tmpdir, str(epoch) + '.mdl')
            with open(path, 'wb') as fh:
                pickle.dump(model.to(torch.device('cpu')), fh)

    with open(args.out, 'wb') as fh:
        pickle.dump(model.to(torch.device('cpu')), fh)
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--epochs', type=int, default=1,
                        help='number of epochs to train')
    parser.add_argument('--fast-eval', action='store_true')
    parser.add_argument('--lrate', type=float, default=1.,
                        help='learning rate')
    parser.add_argument('--use-gpu', action='store_true')
    parser.add_argument('--verbose', action='store_true')
    parser.add_argument('model', help='model to train')
    parser.add_argument('batches', help='list of batches file')
    parser.add_argument('feat_stats', help='data statistics')
    parser.add_argument('out', help='output model')
    args = parser.parse_args()

    if args.verbose:
        logging.getLogger().setLevel(logging.DEBUG)

    # Load the data.
    stats = np.load(args.feat_stats)

    # Load the batches.
    batches_list = []
    with open(args.batches, 'r') as f:
        for line in f:
            batches_list.append(line.strip())

    # Load the model and move it to the chosen device (CPU/GPU).
    with open(args.model, 'rb') as fh:
        model = pickle.load(fh)
    if args.use_gpu:
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    model = model.to(device)

    # Prepare the optimizer for the training.
    params = model.mean_field_groups
    optimizer = beer.BayesianModelCoordinateAscentOptimizer(params,
                                                            lrate=args.lrate)

    tot_counts = int(stats['nframes'])

    for epoch in range(1, args.epochs + 1):
        # Shuffle the order of the batches.
        random.shuffle(batches_list)

        for batch_no, path in enumerate(batches_list, start=1):
            # Reset the gradients.
            optimizer.zero_grad()

            # Load the batch data.
            batch = np.load(path)
            ft = torch.from_numpy(batch['features']).float()
            ft = ft.to(device)

            # Compute the objective function.
            elbo = beer.evidence_lower_bound(model, ft, datasize=tot_counts,
                                             fast_eval=args.fast_eval)

            # Compute the gradient of the model.
            elbo.natural_backward()

            # Update the parameters.
            optimizer.step()

            elbo_value = float(elbo) / tot_counts
            log_msg = 'epoch={}/{} batch={}/{} elbo={}'
            logging.info(log_msg.format(epoch, args.epochs, batch_no,
                                        len(batches_list),
                                        round(elbo_value, 3)))

    with open(args.out, 'wb') as fh:
        pickle.dump(model.to(torch.device('cpu')), fh)
def main(args, logger):
    if args.gpu:
        gpu_idx = beer.utils.reserve_gpu(logger=logger)

    logger.debug('loading the GSM')
    with open(args.gsm, 'rb') as f:
        gsm = pickle.load(f)

    logger.debug('loading the units posterior')
    with open(args.posts, 'rb') as f:
        lp_data = pickle.load(f)
    if len(lp_data) == 4:
        latent_posts, nunits, nstates, groupidx = lp_data
        labels = None
    else:
        logger.debug('using labels for training the latent prior')
        latent_posts, nunits, nstates, groupidx, labels = lp_data

    logger.debug('loading the subspace phoneloop')
    with open(args.sploop, 'rb') as f:
        sploop = pickle.load(f)

    if args.gpu:
        logger.info(f'using gpu device: {gpu_idx}')
        sploop = sploop.cuda()
        gsm = gsm.cuda()
        latent_posts = latent_posts.cuda()
        if labels is not None:
            labels = labels.cuda()

    logger.debug('loading the units')
    units_emissions = sploop.modelset.original_modelset.modelsets[groupidx]
    units = [unit for unit in iterate_units(units_emissions, nunits, nstates)]

    logger.debug('building the optimizer')
    if args.posteriors:
        params = [[]]
    else:
        params = gsm.conjugate_bayesian_parameters(keepgroups=True)
    cjg_optim = beer.VBConjugateOptimizer(params, lrate=args.learning_rate_cjg)
    if args.posteriors:
        params = list(latent_posts.parameters())
    else:
        params = list(latent_posts.parameters()) + list(gsm.parameters())
    std_optim = torch.optim.Adam(params, lr=args.learning_rate_std)
    optim = beer.VBOptimizer(cjg_optim, std_optim)

    if args.optim_state and os.path.isfile(args.optim_state):
        logger.debug(f'loading optimizer state from: {args.optim_state}')
        if args.gpu:
            maplocation = 'cuda'
        else:
            maplocation = 'cpu'
        state = torch.load(args.optim_state, maplocation)
        optim.load_state_dict(state)

    kwargs = {
        'latent_posts': latent_posts,
        'latent_nsamples': args.latent_nsamples,
        'params_nsamples': args.params_nsamples,
    }
    if labels is not None:
        kwargs['labels'] = labels

    for epoch in range(1, args.epochs + 1):
        optim.init_step()
        elbo = beer.evidence_lower_bound(gsm, units, **kwargs)
        elbo.backward()
        optim.step()

        if args.logging_rate > 0 and epoch % args.logging_rate == 0:
            logger.info(f'epoch={epoch:<20} elbo={float(elbo):<20}')

    logger.info(f'finished training at epoch={epoch} with elbo={float(elbo)}')

    if args.gpu:
        sploop = sploop.cpu()
        gsm = gsm.cpu()
        latent_posts = latent_posts.cpu()
        if labels is not None:
            labels = labels.cpu()

    logger.debug('saving the GSM')
    with open(args.out_gsm, 'wb') as f:
        pickle.dump(gsm, f)

    logger.debug('saving the units posterior')
    with open(args.out_posts, 'wb') as f:
        if labels is None:
            pickle.dump((latent_posts, nunits, nstates, groupidx), f)
        else:
            pickle.dump((latent_posts, nunits, nstates, groupidx, labels), f)

    logger.debug('saving the subspace phoneloop')
    with open(args.out_sploop, 'wb') as f:
        pickle.dump(sploop, f)

    if args.optim_state:
        logger.debug(f'saving the optimizer state to: {args.optim_state}')
        torch.save(optim.state_dict(), args.optim_state)