def main(args):
    if args.cuda:
        torch.set_default_tensor_type("torch.cuda.FloatTensor")

    logging.info("Loading data")
    data = poly.load_data(poly.JSB_CHORALES)

    logging.info("-" * 40)
    model = models[args.model]
    logging.info("Training {} on {} sequences".format(
        model.__name__, len(data["train"]["sequences"])))
    sequences = data["train"]["sequences"]
    lengths = data["train"]["sequence_lengths"]

    # find all the notes that are present at least once in the training set
    present_notes = (sequences == 1).sum(0).sum(0) > 0
    # remove notes that are never played (we remove 37/88 notes)
    sequences = sequences[..., present_notes]

    if args.truncate:
        lengths = lengths.clamp(max=args.truncate)
        sequences = sequences[:, :args.truncate]
    num_observations = float(lengths.sum())
    pyro.set_rng_seed(args.seed)
    pyro.clear_param_store()

    # We'll train using MAP Baum-Welch, i.e. MAP estimation while marginalizing
    # out the hidden state x. This is accomplished via an automatic guide that
    # learns point estimates of all of our conditional probability tables,
    # named probs_*.
    guide = AutoDelta(
        poutine.block(model, expose_fn=lambda msg: msg["name"].startswith("probs_")))

    # To help debug our tensor shapes, let's print the shape of each site's
    # distribution, value, and log_prob tensor. Note this information is
    # automatically printed on most errors inside SVI.
    if args.print_shapes:
        first_available_dim = -2 if model is model_0 else -3
        guide_trace = poutine.trace(guide).get_trace(
            sequences, lengths, args=args, batch_size=args.batch_size)
        model_trace = poutine.trace(
            poutine.replay(poutine.enum(model, first_available_dim), guide_trace)
        ).get_trace(sequences, lengths, args=args, batch_size=args.batch_size)
        logging.info(model_trace.format_shapes())

    # Enumeration requires a TraceEnum elbo and declaring the max_plate_nesting.
    # All of our models have two plates: "data" and "tones".
    optim = Adam({"lr": args.learning_rate})
    if args.tmc:
        if args.jit:
            raise NotImplementedError(
                "jit support not yet added for TraceTMC_ELBO")
        elbo = TraceTMC_ELBO(max_plate_nesting=1 if model is model_0 else 2)
        tmc_model = poutine.infer_config(
            model,
            lambda msg: {
                "num_samples": args.tmc_num_samples,
                "expand": False
            } if msg["infer"].get("enumerate", None) == "parallel" else {},
        )  # noqa: E501
        svi = SVI(tmc_model, guide, optim, elbo)
    else:
        Elbo = JitTraceEnum_ELBO if args.jit else TraceEnum_ELBO
        elbo = Elbo(
            max_plate_nesting=1 if model is model_0 else 2,
            strict_enumeration_warning=(model is not model_7),
            jit_options={"time_compilation": args.time_compilation},
        )
        svi = SVI(model, guide, optim, elbo)

    # We'll train on small minibatches.
    logging.info("Step\tLoss")
    for step in range(args.num_steps):
        loss = svi.step(sequences, lengths, args=args, batch_size=args.batch_size)
        logging.info("{: >5d}\t{}".format(step, loss / num_observations))

    if args.jit and args.time_compilation:
        logging.debug("time to compile: {} s.".format(
            elbo._differentiable_loss.compile_time))

    # We evaluate on the entire training dataset,
    # excluding the prior term so our results are comparable across models.
    train_loss = elbo.loss(model, guide, sequences, lengths, args, include_prior=False)
    logging.info("training loss = {}".format(train_loss / num_observations))

    # Finally we evaluate on the test dataset.
logging.info("-" * 40) logging.info("Evaluating on {} test sequences".format( len(data["test"]["sequences"]))) sequences = data["test"]["sequences"][..., present_notes] lengths = data["test"]["sequence_lengths"] if args.truncate: lengths = lengths.clamp(max=args.truncate) num_observations = float(lengths.sum()) # note that since we removed unseen notes above (to make the problem a bit easier and for # numerical stability) this test loss may not be directly comparable to numbers # reported on this dataset elsewhere. test_loss = elbo.loss(model, guide, sequences, lengths, args=args, include_prior=False) logging.info("test loss = {}".format(test_loss / num_observations)) # We expect models with higher capacity to perform better, # but eventually overfit to the training set. capacity = sum( value.reshape(-1).size(0) for value in pyro.get_param_store().values()) logging.info("{} capacity = {} parameters".format(model.__name__, capacity))
def main(args):
    ## Do Re Mi (a single ascending major scale)
    def easyTones():
        training_seq_lengths = torch.tensor([8] * 1)
        training_data_sequences = torch.zeros(1, 8, 88)
        for i in range(1):
            training_data_sequences[i][0][int(70 - i * 10)] = 1
            training_data_sequences[i][1][int(70 - i * 10) + 2] = 1
            training_data_sequences[i][2][int(70 - i * 10) + 4] = 1
            training_data_sequences[i][3][int(70 - i * 10) + 5] = 1
            training_data_sequences[i][4][int(70 - i * 10) + 7] = 1
            training_data_sequences[i][5][int(70 - i * 10) + 9] = 1
            training_data_sequences[i][6][int(70 - i * 10) + 11] = 1
            training_data_sequences[i][7][int(70 - i * 10) + 12] = 1
        return training_seq_lengths, training_data_sequences

    def superEasyTones():
        training_seq_lengths = torch.tensor([8] * 10)
        training_data_sequences = torch.zeros(10, 8, 88)
        for i in range(10):
            for j in range(8):
                training_data_sequences[i][j][int(30 + i * 5)] = 1
        return training_seq_lengths, training_data_sequences

    ## Do Do Do, Do Do Do, Do Do Do (the same single note repeated in every sequence)
    def easiestTones():
        training_seq_lengths = torch.tensor([8] * 10)
        training_data_sequences = torch.zeros(10, 8, 88)
        for i in range(10):
            for j in range(8):
                training_data_sequences[i][j][int(70)] = 1
        return training_seq_lengths, training_data_sequences

    # setup logging
    logging.basicConfig(level=logging.DEBUG, format='%(message)s',
                        filename=args.log, filemode='w')
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    logging.getLogger('').addHandler(console)
    logging.info(args)

    data = poly.load_data(poly.JSB_CHORALES)
    training_seq_lengths = data['train']['sequence_lengths']
    training_data_sequences = data['train']['sequences']
    # override the JSB chorales with the synthetic toy tones defined above
    training_seq_lengths, training_data_sequences = easiestTones()
    test_seq_lengths = data['test']['sequence_lengths']
    test_data_sequences = data['test']['sequences']
    test_seq_lengths, test_data_sequences = easiestTones()
    val_seq_lengths = data['valid']['sequence_lengths']
    val_data_sequences = data['valid']['sequences']
    val_seq_lengths, val_data_sequences = easiestTones()

    N_train_data = len(training_seq_lengths)
    N_train_time_slices = float(torch.sum(training_seq_lengths))
    N_mini_batches = int(N_train_data / args.mini_batch_size +
                         int(N_train_data % args.mini_batch_size > 0))

    logging.info("N_train_data: %d avg. training seq. length: %.2f N_mini_batches: %d" %
                 (N_train_data, training_seq_lengths.float().mean(), N_mini_batches))

    # how often we do validation/test evaluation during training
    val_test_frequency = 50
    # the number of samples we use to do the evaluation
    n_eval_samples = 1

    # package repeated copies of val/test data for faster evaluation
    # (i.e. set us up for vectorization)
    def rep(x):
        rep_shape = torch.Size([x.size(0) * n_eval_samples]) + x.size()[1:]
        repeat_dims = [1] * len(x.size())
        repeat_dims[0] = n_eval_samples
        return x.repeat(repeat_dims).reshape(n_eval_samples, -1).transpose(1, 0).reshape(rep_shape)

    # get the validation/test data ready for the dmm: pack into sequences, etc.
    val_seq_lengths = rep(val_seq_lengths)
    test_seq_lengths = rep(test_seq_lengths)
    val_batch, val_batch_reversed, val_batch_mask, val_seq_lengths = poly.get_mini_batch(
        torch.arange(n_eval_samples * val_data_sequences.shape[0]),
        rep(val_data_sequences),
        val_seq_lengths, cuda=args.cuda)
    test_batch, test_batch_reversed, test_batch_mask, test_seq_lengths = poly.get_mini_batch(
        torch.arange(n_eval_samples * test_data_sequences.shape[0]),
        rep(test_data_sequences),
        test_seq_lengths, cuda=args.cuda)

    # instantiate the dmm
    dmm = DMM(rnn_dropout_rate=args.rnn_dropout_rate, num_iafs=args.num_iafs,
              iaf_dim=args.iaf_dim, use_cuda=args.cuda)

    # setup optimizer
    adam_params = {"lr": args.learning_rate, "betas": (args.beta1, args.beta2),
                   "clip_norm": args.clip_norm, "lrd": args.lr_decay,
                   "weight_decay": args.weight_decay}
    adam = ClippedAdam(adam_params)

    # setup inference algorithm
    if args.tmc:
        if args.jit:
            raise NotImplementedError("no JIT support yet for TMC")
        tmc_loss = TraceTMC_ELBO()
        dmm_guide = config_enumerate(dmm.guide, default="parallel",
                                     num_samples=args.tmc_num_samples, expand=False)
        svi = SVI(dmm.model, dmm_guide, adam, loss=tmc_loss)
    elif args.tmcelbo:
        if args.jit:
            raise NotImplementedError("no JIT support yet for TMC ELBO")
        elbo = TraceEnum_ELBO()
        dmm_guide = config_enumerate(dmm.guide, default="parallel",
                                     num_samples=args.tmc_num_samples, expand=False)
        svi = SVI(dmm.model, dmm_guide, adam, loss=elbo)
    else:
        elbo = JitTrace_ELBO() if args.jit else Trace_ELBO()
        svi = SVI(dmm.model, dmm.guide, adam, loss=elbo)

    # now we're going to define some functions we need to form the main training loop

    # saves the model and optimizer states to disk
    def save_checkpoint():
        logging.info("saving model to %s..." % args.save_model)
        torch.save(dmm.state_dict(), args.save_model)
        # logging.info("saving optimizer states to %s..." % args.save_opt)
        # adam.save(args.save_opt)
        logging.info("done saving model and optimizer checkpoints to disk.")

    # loads the model and optimizer states from disk
    def load_checkpoint():
        assert exists(args.load_opt) and exists(args.load_model), \
            "--load-model and/or --load-opt misspecified"
        logging.info("loading model from %s..." % args.load_model)
        dmm.load_state_dict(torch.load(args.load_model))
        logging.info("loading optimizer states from %s..." % args.load_opt)
        adam.load(args.load_opt)
        logging.info("done loading model and optimizer states.")

    # prepare a mini-batch and take a gradient step to minimize -elbo
    def process_minibatch(epoch, which_mini_batch, shuffled_indices):
        if args.annealing_epochs > 0 and epoch < args.annealing_epochs:
            # compute the KL annealing factor appropriate
            # for the current mini-batch in the current epoch
            min_af = args.minimum_annealing_factor
            annealing_factor = min_af + (1.0 - min_af) * \
                (float(which_mini_batch + epoch * N_mini_batches + 1) /
                 float(args.annealing_epochs * N_mini_batches))
        else:
            # by default the KL annealing factor is unity
            annealing_factor = 1.0

        # compute which sequences in the training set we should grab
        mini_batch_start = (which_mini_batch * args.mini_batch_size)
        mini_batch_end = np.min([(which_mini_batch + 1) * args.mini_batch_size, N_train_data])
        mini_batch_indices = shuffled_indices[mini_batch_start:mini_batch_end]

        # grab a fully prepped mini-batch using the helper function in the data loader
        mini_batch, mini_batch_reversed, mini_batch_mask, mini_batch_seq_lengths \
            = poly.get_mini_batch(mini_batch_indices, training_data_sequences,
                                  training_seq_lengths, cuda=args.cuda)

        # do an actual gradient step
        loss = svi.step(mini_batch, mini_batch_reversed, mini_batch_mask,
                        mini_batch_seq_lengths, annealing_factor)
        # keep track of the training loss
        return loss

    # helper function for doing evaluation
    def do_evaluation():
        # put the RNN into evaluation mode (i.e. turn off drop-out if applicable)
        dmm.rnn.eval()

        # compute the validation and test loss n_samples many times
        val_nll = svi.evaluate_loss(val_batch, val_batch_reversed, val_batch_mask,
                                    val_seq_lengths) / float(torch.sum(val_seq_lengths))
        test_nll = svi.evaluate_loss(test_batch, test_batch_reversed, test_batch_mask,
                                     test_seq_lengths) / float(torch.sum(test_seq_lengths))

        # put the RNN back into training mode (i.e. turn on drop-out if applicable)
        dmm.rnn.train()
        return val_nll, test_nll

    # if checkpoint files provided, load model and optimizer states from disk before we start training
    if args.load_opt != '' and args.load_model != '':
        load_checkpoint()

    #################
    # TRAINING LOOP #
    #################
    times = [time.time()]
    for epoch in range(args.num_epochs):
        # if specified, save model and optimizer states to disk every checkpoint_freq epochs
        if args.checkpoint_freq > 0 and epoch > 0 and epoch % args.checkpoint_freq == 0:
            save_checkpoint()

        # accumulator for our estimate of the negative log likelihood
        # (or rather -elbo) for this epoch
        epoch_nll = 0.0
        # prepare mini-batch subsampling indices for this epoch
        shuffled_indices = torch.randperm(N_train_data)

        # process each mini-batch; this is where we take gradient steps
        for which_mini_batch in range(N_mini_batches):
            epoch_nll += process_minibatch(epoch, which_mini_batch, shuffled_indices)

        # report training diagnostics
        times.append(time.time())
        epoch_time = times[-1] - times[-2]
        logging.info("[training epoch %04d] %.4f \t\t\t\t(dt = %.3f sec)" %
                     (epoch, epoch_nll / N_train_time_slices, epoch_time))

        # do evaluation on test and validation data and report results
        if val_test_frequency > 0 and epoch > 0 and epoch % val_test_frequency == 0:
            val_nll, test_nll = do_evaluation()
            logging.info("[val/test epoch %04d] %.4f %.4f" % (epoch, val_nll, test_nll))
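
# Standalone sketch (not part of the original script): process_minibatch() above computes a
# KL annealing factor that ramps linearly from args.minimum_annealing_factor up to 1.0 over
# the first args.annealing_epochs epochs, counted in units of mini-batches. The helper below
# restates that schedule in isolation so it can be inspected or plotted; the argument names
# mirror the args attributes used above but are otherwise assumptions.
def kl_annealing_factor(epoch, which_mini_batch, n_mini_batches,
                        annealing_epochs, minimum_annealing_factor):
    if annealing_epochs > 0 and epoch < annealing_epochs:
        # fraction of the annealing window completed so far, measured in mini-batches
        progress = float(which_mini_batch + epoch * n_mini_batches + 1) / \
            float(annealing_epochs * n_mini_batches)
        return minimum_annealing_factor + (1.0 - minimum_annealing_factor) * progress
    # once the annealing window has passed, the KL term is weighted fully
    return 1.0

# For example, with 10 mini-batches per epoch, a floor of 0.2, and 5 annealing epochs,
# the factor rises from 0.216 at (epoch 0, mini-batch 0) to 1.0 at the end of epoch 4.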
def main(args):
    pyro.set_rng_seed(0)
    pyro.clear_param_store()
    pyro.enable_validation(__debug__)

    # load data
    if args.dataset == "dipper":
        capture_history_file = os.path.dirname(
            os.path.abspath(__file__)) + '/dipper_capture_history.csv'
    elif args.dataset == "vole":
        capture_history_file = os.path.dirname(
            os.path.abspath(__file__)) + '/meadow_voles_capture_history.csv'
    else:
        raise ValueError("Available datasets are 'dipper' and 'vole'.")

    capture_history = torch.tensor(
        np.genfromtxt(capture_history_file, delimiter=',')).float()[:, 1:]
    N, T = capture_history.shape
    print(
        "Loaded {} capture history for {} individuals collected over {} time periods."
        .format(args.dataset, N, T))

    if args.dataset == "dipper" and args.model in ["4", "5"]:
        sex_file = os.path.dirname(
            os.path.abspath(__file__)) + '/dipper_sex.csv'
        sex = torch.tensor(np.genfromtxt(sex_file, delimiter=',')).float()[:, 1]
        print("Loaded dipper sex data.")
    elif args.dataset == "vole" and args.model in ["4", "5"]:
        raise ValueError(
            "Cannot run model_{} on meadow voles data, since we lack sex "
            "information for these animals.".format(args.model))
    else:
        sex = None

    model = models[args.model]

    # we use poutine.block to only expose the continuous latent variables
    # in the models to AutoDiagonalNormal (all of which begin with 'phi'
    # or 'rho')
    def expose_fn(msg):
        return msg["name"][0:3] in ['phi', 'rho']

    # we use a mean field diagonal normal variational distribution (i.e. guide)
    # for the continuous latent variables.
    guide = AutoDiagonalNormal(poutine.block(model, expose_fn=expose_fn))

    # since we enumerate the discrete random variables,
    # we need to use TraceEnum_ELBO or TraceTMC_ELBO.
    optim = Adam({'lr': args.learning_rate})
    if args.tmc:
        elbo = TraceTMC_ELBO(max_plate_nesting=1)
        tmc_model = poutine.infer_config(model, lambda msg: {
            "num_samples": args.tmc_num_samples,
            "expand": False
        } if msg["infer"].get("enumerate", None) == "parallel" else {})  # noqa: E501
        svi = SVI(tmc_model, guide, optim, elbo)
    else:
        elbo = TraceEnum_ELBO(max_plate_nesting=1, num_particles=20,
                              vectorize_particles=True)
        svi = SVI(model, guide, optim, elbo)

    losses = []

    print(
        "Beginning training of model_{} with Stochastic Variational Inference."
        .format(args.model))

    for step in range(args.num_steps):
        loss = svi.step(capture_history, sex)
        losses.append(loss)
        if (step % 20 == 0 and step > 0) or step == args.num_steps - 1:
            print("[iteration %03d] loss: %.3f" % (step, np.mean(losses[-20:])))

    # evaluate final trained model
    elbo_eval = TraceEnum_ELBO(max_plate_nesting=1, num_particles=2000,
                               vectorize_particles=True)
    svi_eval = SVI(model, guide, optim, elbo_eval)
    print("Final loss: %.4f" % svi_eval.evaluate_loss(capture_history, sex))
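
# A minimal entry-point sketch (not from the original script) showing one way the flags
# referenced in main() above could be exposed on the command line; flag names and defaults
# are assumptions for illustration only.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="CJS capture-recapture models")
    parser.add_argument("--dataset", default="dipper", type=str,
                        help="one of: dipper, vole")
    parser.add_argument("--model", default="1", type=str,
                        help="key into the `models` dict selecting which CJS variant to run")
    parser.add_argument("--num-steps", default=400, type=int)
    parser.add_argument("--learning-rate", default=0.002, type=float)
    parser.add_argument("--tmc", action="store_true")
    parser.add_argument("--tmc-num-samples", default=10, type=int)
    main(parser.parse_args())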