Example #1
def process_minibatch(svi, epoch, mini_batch, n_mini_batches, which_mini_batch,
                      shuffled_indices):
    if args.annealing_epochs > 0 and epoch < args.annealing_epochs:
        # compute the KL annealing factor appropriate for the current mini-batch in the current epoch
        min_af = args.minimum_annealing_factor
        annealing_factor = min_af + (1.0 - min_af) * \
            (float(which_mini_batch + epoch * n_mini_batches + 1) /
             float(args.annealing_epochs * n_mini_batches))
    else:
        # by default the KL annealing factor is unity
        annealing_factor = 1.0

    # Generate dummy data that we can feed into the helper function below.
    training_data_sequences = mini_batch.type(torch.FloatTensor)
    mini_batch_indices = torch.arange(0, training_data_sequences.size(0))
    training_seq_lengths = torch.full(
        (training_data_sequences.size(0), ),
        training_data_sequences.size(1)).type(torch.IntTensor)

    # grab a fully prepped mini-batch using the helper function in the data loader
    mini_batch, mini_batch_reversed, mini_batch_mask, mini_batch_seq_lengths \
        = poly.get_mini_batch(mini_batch_indices, training_data_sequences,
                              training_seq_lengths, cuda=args.cuda)

    # do an actual gradient step
    loss = svi.step(mini_batch, mini_batch_reversed, mini_batch_mask,
                    mini_batch_seq_lengths, annealing_factor)

    # keep track of the training loss
    return loss
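
The annealing branch above ramps the KL weight linearly from args.minimum_annealing_factor up to 1.0 over the first args.annealing_epochs epochs of training. Below is a minimal, self-contained sketch of that schedule; the function name and its parameters mirror the arguments used above but are illustrative, not part of the original snippet.

def kl_annealing_factor(epoch, which_mini_batch, n_mini_batches,
                        annealing_epochs, minimum_annealing_factor):
    """Linearly anneal the KL weight from minimum_annealing_factor to 1.0."""
    if annealing_epochs > 0 and epoch < annealing_epochs:
        # global index of the current mini-batch, counted from the start of training
        step = which_mini_batch + epoch * n_mini_batches + 1
        total = annealing_epochs * n_mini_batches
        return minimum_annealing_factor + (1.0 - minimum_annealing_factor) * step / total
    # once the annealing phase is over, the KL term gets its full weight
    return 1.0

# e.g. with 10 annealing epochs of 12 mini-batches each, starting at 0.2:
# kl_annealing_factor(0, 0, 12, 10, 0.2) -> ~0.207, kl_annealing_factor(9, 11, 12, 10, 0.2) -> 1.0
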
Example #2
    def process_minibatch(epoch, which_mini_batch, shuffled_indices):
        if args.annealing_epochs > 0 and epoch < args.annealing_epochs:
            # compute the KL annealing factor appropriate for the current mini-batch in the current epoch
            min_af = args.minimum_annealing_factor
            annealing_factor = min_af + (1.0 - min_af) * \
                (float(which_mini_batch + epoch * N_mini_batches + 1) /
                 float(args.annealing_epochs * N_mini_batches))
        else:
            # by default the KL annealing factor is unity
            annealing_factor = 1.0

        # compute which sequences in the training set we should grab
        mini_batch_start = (which_mini_batch * args.mini_batch_size)
        mini_batch_end = np.min([(which_mini_batch + 1) * args.mini_batch_size,
                                 N_train_data])
        mini_batch_indices = shuffled_indices[mini_batch_start:mini_batch_end]
        # grab a fully prepped mini-batch using the helper function in the data loader
        mini_batch, mini_batch_reversed, mini_batch_mask, mini_batch_seq_lengths \
            = poly.get_mini_batch(mini_batch_indices, training_data_sequences,
                                  training_seq_lengths, cuda=args.cuda)
        # do an actual gradient step
        loss = svi.step(mini_batch, mini_batch_reversed, mini_batch_mask,
                        mini_batch_seq_lengths, annealing_factor)
        # keep track of the training loss
        return loss
Example #3
    def process_minibatch(epoch, which_mini_batch, shuffled_indices):
        if args.annealing_epochs > 0 and epoch < args.annealing_epochs:
            # compute the KL annealing factor appropriate for the current mini-batch in the current epoch
            min_af = args.minimum_annealing_factor
            annealing_factor = min_af + (1.0 - min_af) * \
                (float(which_mini_batch + epoch * N_mini_batches + 1) /
                 float(args.annealing_epochs * N_mini_batches))
        else:
            # by default the KL annealing factor is unity
            annealing_factor = 1.0

        # compute which sequences in the training set we should grab
        mini_batch_start = (which_mini_batch * args.mini_batch_size)
        mini_batch_end = np.min([(which_mini_batch + 1) * args.mini_batch_size, N_train_data])
        mini_batch_indices = shuffled_indices[mini_batch_start:mini_batch_end]
        # grab a fully prepped mini-batch using the helper function in the data loader
        mini_batch, mini_batch_reversed, mini_batch_mask, mini_batch_seq_lengths \
            = poly.get_mini_batch(mini_batch_indices, training_data_sequences,
                                  training_seq_lengths, cuda=args.cuda)
        # do an actual gradient step
        loss = elbo.step(mini_batch, mini_batch_reversed, mini_batch_mask,
                         mini_batch_seq_lengths, annealing_factor)
        # keep track of the training loss
        return loss
Example #4
def main(args):
    # setup logging
    log = get_logger(args.log)
    log(args)

    data = poly.load_data(poly.JSB_CHORALES)
    training_seq_lengths = data['train']['sequence_lengths']
    training_data_sequences = data['train']['sequences']
    test_seq_lengths = data['test']['sequence_lengths']
    test_data_sequences = data['test']['sequences']
    val_seq_lengths = data['valid']['sequence_lengths']
    val_data_sequences = data['valid']['sequences']
    N_train_data = len(training_seq_lengths)
    N_train_time_slices = float(torch.sum(training_seq_lengths))
    N_mini_batches = int(N_train_data / args.mini_batch_size +
                         int(N_train_data % args.mini_batch_size > 0))

    log("N_train_data: %d     avg. training seq. length: %.2f    N_mini_batches: %d" %
        (N_train_data, training_seq_lengths.float().mean(), N_mini_batches))

    # how often we do validation/test evaluation during training
    val_test_frequency = 50
    # the number of samples we use to do the evaluation
    n_eval_samples = 1

    # package repeated copies of val/test data for faster evaluation
    # (i.e. set us up for vectorization)
    def rep(x):
        rep_shape = torch.Size([x.size(0) * n_eval_samples]) + x.size()[1:]
        repeat_dims = [1] * len(x.size())
        repeat_dims[0] = n_eval_samples
        return x.repeat(repeat_dims).reshape(n_eval_samples, -1).transpose(1, 0).reshape(rep_shape)

    # get the validation/test data ready for the dmm: pack into sequences, etc.
    val_seq_lengths = rep(val_seq_lengths)
    test_seq_lengths = rep(test_seq_lengths)
    val_batch, val_batch_reversed, val_batch_mask, val_seq_lengths = poly.get_mini_batch(
        torch.arange(n_eval_samples * val_data_sequences.shape[0]), rep(val_data_sequences),
        val_seq_lengths, cuda=args.cuda)
    test_batch, test_batch_reversed, test_batch_mask, test_seq_lengths = poly.get_mini_batch(
        torch.arange(n_eval_samples * test_data_sequences.shape[0]), rep(test_data_sequences),
        test_seq_lengths, cuda=args.cuda)

    # instantiate the dmm
    dmm = DMM(rnn_dropout_rate=args.rnn_dropout_rate, num_iafs=args.num_iafs,
              iaf_dim=args.iaf_dim, use_cuda=args.cuda)

    # setup optimizer
    adam_params = {"lr": args.learning_rate, "betas": (args.beta1, args.beta2),
                   "clip_norm": args.clip_norm, "lrd": args.lr_decay,
                   "weight_decay": args.weight_decay}
    adam = ClippedAdam(adam_params)

    # setup inference algorithm
    elbo = JitTrace_ELBO() if args.jit else Trace_ELBO()
    svi = SVI(dmm.model, dmm.guide, adam, loss=elbo)

    # now we're going to define some functions we need to form the main training loop

    # saves the model and optimizer states to disk
    def save_checkpoint():
        log("saving model to %s..." % args.save_model)
        torch.save(dmm.state_dict(), args.save_model)
        log("saving optimizer states to %s..." % args.save_opt)
        adam.save(args.save_opt)
        log("done saving model and optimizer checkpoints to disk.")

    # loads the model and optimizer states from disk
    def load_checkpoint():
        assert exists(args.load_opt) and exists(args.load_model), \
            "--load-model and/or --load-opt misspecified"
        log("loading model from %s..." % args.load_model)
        dmm.load_state_dict(torch.load(args.load_model))
        log("loading optimizer states from %s..." % args.load_opt)
        adam.load(args.load_opt)
        log("done loading model and optimizer states.")

    # prepare a mini-batch and take a gradient step to minimize -elbo
    def process_minibatch(epoch, which_mini_batch, shuffled_indices):
        if args.annealing_epochs > 0 and epoch < args.annealing_epochs:
            # compute the KL annealing factor appropriate for the current mini-batch in the current epoch
            min_af = args.minimum_annealing_factor
            annealing_factor = min_af + (1.0 - min_af) * \
                (float(which_mini_batch + epoch * N_mini_batches + 1) /
                 float(args.annealing_epochs * N_mini_batches))
        else:
            # by default the KL annealing factor is unity
            annealing_factor = 1.0

        # compute which sequences in the training set we should grab
        mini_batch_start = (which_mini_batch * args.mini_batch_size)
        mini_batch_end = np.min([(which_mini_batch + 1) * args.mini_batch_size, N_train_data])
        mini_batch_indices = shuffled_indices[mini_batch_start:mini_batch_end]
        # grab a fully prepped mini-batch using the helper function in the data loader
        mini_batch, mini_batch_reversed, mini_batch_mask, mini_batch_seq_lengths \
            = poly.get_mini_batch(mini_batch_indices, training_data_sequences,
                                  training_seq_lengths, cuda=args.cuda)
        # do an actual gradient step
        loss = svi.step(mini_batch, mini_batch_reversed, mini_batch_mask,
                        mini_batch_seq_lengths, annealing_factor)
        # keep track of the training loss
        return loss

    # helper function for doing evaluation
    def do_evaluation():
        # put the RNN into evaluation mode (i.e. turn off drop-out if applicable)
        dmm.rnn.eval()

        # compute the validation and test loss over the n_eval_samples repetitions packaged above
        val_nll = svi.evaluate_loss(val_batch, val_batch_reversed, val_batch_mask,
                                    val_seq_lengths) / torch.sum(val_seq_lengths)
        test_nll = svi.evaluate_loss(test_batch, test_batch_reversed, test_batch_mask,
                                     test_seq_lengths) / torch.sum(test_seq_lengths)

        # put the RNN back into training mode (i.e. turn on drop-out if applicable)
        dmm.rnn.train()
        return val_nll, test_nll

    # if checkpoint files provided, load model and optimizer states from disk before we start training
    if args.load_opt != '' and args.load_model != '':
        load_checkpoint()

    #################
    # TRAINING LOOP #
    #################
    times = [time.time()]
    for epoch in range(args.num_epochs):
        # if specified, save model and optimizer states to disk every checkpoint_freq epochs
        if args.checkpoint_freq > 0 and epoch > 0 and epoch % args.checkpoint_freq == 0:
            save_checkpoint()

        # accumulator for our estimate of the negative log likelihood (or rather -elbo) for this epoch
        epoch_nll = 0.0
        # prepare mini-batch subsampling indices for this epoch
        shuffled_indices = torch.randperm(N_train_data)

        # process each mini-batch; this is where we take gradient steps
        for which_mini_batch in range(N_mini_batches):
            epoch_nll += process_minibatch(epoch, which_mini_batch, shuffled_indices)

        # report training diagnostics
        times.append(time.time())
        epoch_time = times[-1] - times[-2]
        log("[training epoch %04d]  %.4f \t\t\t\t(dt = %.3f sec)" %
            (epoch, epoch_nll / N_train_time_slices, epoch_time))

        # do evaluation on test and validation data and report results
        if val_test_frequency > 0 and epoch > 0 and epoch % val_test_frequency == 0:
            val_nll, test_nll = do_evaluation()
            log("[val/test epoch %04d]  %.4f  %.4f" % (epoch, val_nll, test_nll))
Example #5
def main(args):
    # setup logging
    log = get_logger(args.log)
    log(args)

    jsb_file_loc = "./data/jsb_processed.pkl"
    # ingest training/validation/test data from disk
    data = pickle.load(open(jsb_file_loc, "rb"))
    training_seq_lengths = data['train']['sequence_lengths']
    training_data_sequences = data['train']['sequences']
    test_seq_lengths = data['test']['sequence_lengths']
    test_data_sequences = data['test']['sequences']
    val_seq_lengths = data['valid']['sequence_lengths']
    val_data_sequences = data['valid']['sequences']
    N_train_data = len(training_seq_lengths)
    N_train_time_slices = np.sum(training_seq_lengths)
    N_mini_batches = int(N_train_data / args.mini_batch_size +
                         int(N_train_data % args.mini_batch_size > 0))

    log("N_train_data: %d     avg. training seq. length: %.2f    N_mini_batches: %d" %
        (N_train_data, np.mean(training_seq_lengths), N_mini_batches))

    # how often we do validation/test evaluation during training
    val_test_frequency = 50
    # the number of samples we use to do the evaluation
    n_eval_samples = 1

    # package repeated copies of val/test data for faster evaluation
    # (i.e. set us up for vectorization)
    def rep(x):
        y = np.repeat(x, n_eval_samples, axis=0)
        return y

    # get the validation/test data ready for the dmm: pack into sequences, etc.
    val_seq_lengths = rep(val_seq_lengths)
    test_seq_lengths = rep(test_seq_lengths)
    val_batch, val_batch_reversed, val_batch_mask, val_seq_lengths = poly.get_mini_batch(
        np.arange(n_eval_samples * val_data_sequences.shape[0]), rep(val_data_sequences),
        val_seq_lengths, volatile=True, cuda=args.cuda)
    test_batch, test_batch_reversed, test_batch_mask, test_seq_lengths = poly.get_mini_batch(
        np.arange(n_eval_samples * test_data_sequences.shape[0]), rep(test_data_sequences),
        test_seq_lengths, volatile=True, cuda=args.cuda)

    # instantiate the dmm
    dmm = DMM(rnn_dropout_rate=args.rnn_dropout_rate, num_iafs=args.num_iafs,
              iaf_dim=args.iaf_dim, use_cuda=args.cuda)

    # setup optimizer
    adam_params = {"lr": args.learning_rate, "betas": (args.beta1, args.beta2),
                   "clip_norm": args.clip_norm, "lrd": args.lr_decay,
                   "weight_decay": args.weight_decay}
    adam = ClippedAdam(adam_params)

    # setup inference algorithm
    elbo = SVI(dmm.model, dmm.guide, adam, "ELBO", trace_graph=False)

    # now we're going to define some functions we need to form the main training loop

    # saves the model and optimizer states to disk
    def save_checkpoint():
        log("saving model to %s..." % args.save_model)
        torch.save(dmm.state_dict(), args.save_model)
        log("saving optimizer states to %s..." % args.save_opt)
        adam.save(args.save_opt)
        log("done saving model and optimizer checkpoints to disk.")

    # loads the model and optimizer states from disk
    def load_checkpoint():
        assert exists(args.load_opt) and exists(args.load_model), \
            "--load-model and/or --load-opt misspecified"
        log("loading model from %s..." % args.load_model)
        dmm.load_state_dict(torch.load(args.load_model))
        log("loading optimizer states from %s..." % args.load_opt)
        adam.load(args.load_opt)
        log("done loading model and optimizer states.")

    # prepare a mini-batch and take a gradient step to minimize -elbo
    def process_minibatch(epoch, which_mini_batch, shuffled_indices):
        if args.annealing_epochs > 0 and epoch < args.annealing_epochs:
            # compute the KL annealing factor appropriate for the current mini-batch in the current epoch
            min_af = args.minimum_annealing_factor
            annealing_factor = min_af + (1.0 - min_af) * \
                (float(which_mini_batch + epoch * N_mini_batches + 1) /
                 float(args.annealing_epochs * N_mini_batches))
        else:
            # by default the KL annealing factor is unity
            annealing_factor = 1.0

        # compute which sequences in the training set we should grab
        mini_batch_start = (which_mini_batch * args.mini_batch_size)
        mini_batch_end = np.min([(which_mini_batch + 1) * args.mini_batch_size, N_train_data])
        mini_batch_indices = shuffled_indices[mini_batch_start:mini_batch_end]
        # grab a fully prepped mini-batch using the helper function in the data loader
        mini_batch, mini_batch_reversed, mini_batch_mask, mini_batch_seq_lengths \
            = poly.get_mini_batch(mini_batch_indices, training_data_sequences,
                                  training_seq_lengths, cuda=args.cuda)
        # do an actual gradient step
        loss = elbo.step(mini_batch, mini_batch_reversed, mini_batch_mask,
                         mini_batch_seq_lengths, annealing_factor)
        # keep track of the training loss
        return loss

    # helper function for doing evaluation
    def do_evaluation():
        # put the RNN into evaluation mode (i.e. turn off drop-out if applicable)
        dmm.rnn.eval()

        # compute the validation and test loss over the n_eval_samples repetitions packaged above
        val_nll = elbo.evaluate_loss(val_batch, val_batch_reversed, val_batch_mask,
                                     val_seq_lengths) / np.sum(val_seq_lengths)
        test_nll = elbo.evaluate_loss(test_batch, test_batch_reversed, test_batch_mask,
                                      test_seq_lengths) / np.sum(test_seq_lengths)

        # put the RNN back into training mode (i.e. turn on drop-out if applicable)
        dmm.rnn.train()
        return val_nll, test_nll

    # if checkpoint files provided, load model and optimizer states from disk before we start training
    if args.load_opt != '' and args.load_model != '':
        load_checkpoint()

    #################
    # TRAINING LOOP #
    #################
    times = [time.time()]
    for epoch in range(args.num_epochs):
        # if specified, save model and optimizer states to disk every checkpoint_freq epochs
        if args.checkpoint_freq > 0 and epoch > 0 and epoch % args.checkpoint_freq == 0:
            save_checkpoint()

        # accumulator for our estimate of the negative log likelihood (or rather -elbo) for this epoch
        epoch_nll = 0.0
        # prepare mini-batch subsampling indices for this epoch
        shuffled_indices = np.arange(N_train_data)
        np.random.shuffle(shuffled_indices)

        # process each mini-batch; this is where we take gradient steps
        for which_mini_batch in range(N_mini_batches):
            epoch_nll += process_minibatch(epoch, which_mini_batch, shuffled_indices)

        # report training diagnostics
        times.append(time.time())
        epoch_time = times[-1] - times[-2]
        log("[training epoch %04d]  %.4f \t\t\t\t(dt = %.3f sec)" %
            (epoch, epoch_nll / N_train_time_slices, epoch_time))

        # do evaluation on test and validation data and report results
        if val_test_frequency > 0 and epoch > 0 and epoch % val_test_frequency == 0:
            val_nll, test_nll = do_evaluation()
            log("[val/test epoch %04d]  %.4f  %.4f" % (epoch, val_nll, test_nll))
Example #6
def test_minibatch(dmm, mini_batch, args, sample_z=True):

    # Generate data that we can feed into the helper function below.
    test_data_sequences = mini_batch.type(torch.FloatTensor)
    mini_batch_indices = torch.arange(0, test_data_sequences.size(0))
    test_seq_lengths = torch.full(
        (test_data_sequences.size(0), ),
        test_data_sequences.size(1)).type(torch.IntTensor)

    # grab a fully prepped mini-batch using the helper function in the data loader
    mini_batch, mini_batch_reversed, mini_batch_mask, mini_batch_seq_lengths \
        = poly.get_mini_batch(mini_batch_indices, test_data_sequences,
                              test_seq_lengths, cuda=args.cuda)

    # Get the initial RNN state.
    h_0 = dmm.h_0
    h_0_contig = h_0.expand(1, mini_batch.size(0),
                            dmm.rnn.hidden_size).contiguous()

    # Feed the test sequence into the RNN.
    rnn_output, rnn_hidden_state = dmm.rnn(mini_batch_reversed, h_0_contig)

    # Reverse the time ordering of the hidden state and unpack it.
    rnn_output = poly.pad_and_reverse(rnn_output, mini_batch_seq_lengths)
    # print(rnn_output)
    # print(rnn_output.shape)

    # set z_prev = z_q_0 to setup the recursive conditioning in q(z_t |...)
    z_prev = dmm.z_q_0.expand(mini_batch.size(0), dmm.z_q_0.size(0))

    # sample the latents z one time step at a time
    T_max = mini_batch.size(1)
    sequence_z = []
    sequence_output = []
    for t in range(1, T_max + 1):
        # the next few lines assemble the distribution q(z_t | z_{t-1}, x_{t:T})
        z_loc, z_scale = dmm.combiner(z_prev, rnn_output[:, t - 1, :])

        if sample_z:
            # if we are using normalizing flows, we apply the sequence of transformations
            # parameterized by self.iafs to the base distribution defined in the previous line
            # to yield a transformed distribution that we use for q(z_t|...)
            if len(dmm.iafs) > 0:
                z_dist = TransformedDistribution(dist.Normal(z_loc, z_scale),
                                                 dmm.iafs)
            else:
                z_dist = dist.Normal(z_loc, z_scale)
            assert z_dist.event_shape == ()
            assert z_dist.batch_shape == (len(mini_batch), dmm.z_q_0.size(0))

            # sample z_t from the distribution z_dist
            annealing_factor = 1.0
            with pyro.poutine.scale(scale=annealing_factor):
                z_t = pyro.sample(
                    "z_%d" % t,
                    z_dist.mask(mini_batch_mask[:, t - 1:t]).to_event(1))
        else:
            z_t = z_loc

        z_t_np = z_t.detach().numpy()
        z_t_np = z_t_np[:, np.newaxis, :]
        sequence_z.append(z_t_np)

        # print("z_{}:".format(t), z_t)
        # print(z_t.shape)

        # compute the probabilities that parameterize the bernoulli likelihood
        emission_probs_t = dmm.emitter(z_t)

        emission_probs_t_np = emission_probs_t.detach().numpy()
        sequence_output.append(emission_probs_t_np)

        # print("x_{}:".format(t), emission_probs_t)
        # print(emission_probs_t.shape)

        # the latent sampled at this time step will be conditioned upon in the next time step
        # so keep track of it
        z_prev = z_t

    # Run the model another few steps.
    n_extra_steps = 100
    for t in range(1, n_extra_steps + 1):
        # first compute the parameters of the diagonal gaussian distribution p(z_t | z_{t-1})
        z_loc, z_scale = dmm.trans(z_prev)

        # then sample z_t according to dist.Normal(z_loc, z_scale)
        # note that we use to_event(1) so that the univariate Normal distribution
        # is treated as a multivariate Normal distribution with a diagonal covariance.
        annealing_factor = 1.0
        with poutine.scale(scale=annealing_factor):
            z_t = pyro.sample(
                "z_%d" % t,
                dist.Normal(z_loc, z_scale)
                # .mask(mini_batch_mask[:, t - 1:t])
                .to_event(1))

        z_t_np = z_t.detach().numpy()
        z_t_np = z_t_np[:, np.newaxis, :]
        sequence_z.append(z_t_np)

        # compute the probabilities that parameterize the bernoulli likelihood
        emission_probs_t = dmm.emitter(z_t)

        emission_probs_t_np = emission_probs_t.detach().numpy()
        sequence_output.append(emission_probs_t_np)

        # the latent sampled at this time step will be conditioned upon
        # in the next time step so keep track of it
        z_prev = z_t

    sequence_z = np.concatenate(sequence_z, axis=1)
    sequence_output = np.concatenate(sequence_output, axis=1)
    # print(sequence_output.shape)

    # n_plots = 5
    # fig, axes = plt.subplots(nrows=n_plots, ncols=1)
    # x = range(sequence_output.shape[1])
    # for i in range(n_plots):
    #     input = mini_batch[i, :].numpy().squeeze()
    #     output = sequence_output[i, :]
    #     axes[i].plot(range(input.shape[0]), input)
    #     axes[i].plot(range(len(output)), output)
    #     axes[i].grid()

    return mini_batch, sequence_z, sequence_output  #fig
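
A hedged usage sketch for the test_minibatch helper above. It assumes a trained dmm, the args namespace used in the other examples, a test_data_sequences tensor shaped (batch, time, features), and a CPU run (the helper calls .numpy() on its intermediate tensors). None of the variable names below come from the original snippet.

import torch

with torch.no_grad():
    small_batch = test_data_sequences[:5]  # a few test sequences
    prepped_batch, z_seq, x_probs = test_minibatch(dmm, small_batch, args, sample_z=False)
    # z_seq stacks one latent per time step: the observed steps plus the 100 extrapolated ones;
    # x_probs holds the corresponding emission probabilities, concatenated along axis 1
    print(z_seq.shape, x_probs.shape)
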
Example #7
def main(args):
    # setup logging
    log = get_logger(args.log)
    log(args)

    jsb_file_loc = "./data/jsb_processed.pkl"
    # ingest training/validation/test data from disk
    data = pickle.load(open(jsb_file_loc, "rb"))
    training_seq_lengths = data['train']['sequence_lengths']
    training_data_sequences = data['train']['sequences']
    test_seq_lengths = data['test']['sequence_lengths']
    test_data_sequences = data['test']['sequences']
    val_seq_lengths = data['valid']['sequence_lengths']
    val_data_sequences = data['valid']['sequences']
    N_train_data = len(training_seq_lengths)
    N_train_time_slices = float(np.sum(training_seq_lengths))
    N_mini_batches = int(N_train_data / args.mini_batch_size +
                         int(N_train_data % args.mini_batch_size > 0))

    log("N_train_data: %d     avg. training seq. length: %.2f    N_mini_batches: %d"
        % (N_train_data, np.mean(training_seq_lengths), N_mini_batches))

    # how often we do validation/test evaluation during training
    val_test_frequency = 50
    # the number of samples we use to do the evaluation
    n_eval_samples = 1

    # package repeated copies of val/test data for faster evaluation
    # (i.e. set us up for vectorization)
    def rep(x):
        y = np.repeat(x, n_eval_samples, axis=0)
        return y

    # get the validation/test data ready for the dmm: pack into sequences, etc.
    val_seq_lengths = rep(val_seq_lengths)
    test_seq_lengths = rep(test_seq_lengths)
    val_batch, val_batch_reversed, val_batch_mask, val_seq_lengths = poly.get_mini_batch(
        np.arange(n_eval_samples * val_data_sequences.shape[0]),
        rep(val_data_sequences),
        val_seq_lengths,
        cuda=args.cuda)
    test_batch, test_batch_reversed, test_batch_mask, test_seq_lengths = poly.get_mini_batch(
        np.arange(n_eval_samples * test_data_sequences.shape[0]),
        rep(test_data_sequences),
        test_seq_lengths,
        cuda=args.cuda)

    # instantiate the dmm
    dmm = DMM(rnn_dropout_rate=args.rnn_dropout_rate,
              num_iafs=args.num_iafs,
              iaf_dim=args.iaf_dim,
              use_cuda=args.cuda)

    # setup optimizer
    adam_params = {
        "lr": args.learning_rate,
        "betas": (args.beta1, args.beta2),
        "clip_norm": args.clip_norm,
        "lrd": args.lr_decay,
        "weight_decay": args.weight_decay
    }
    adam = ClippedAdam(adam_params)

    # setup inference algorithm
    elbo = SVI(dmm.model, dmm.guide, adam, Trace_ELBO())

    # now we're going to define some functions we need to form the main training loop

    # saves the model and optimizer states to disk
    def save_checkpoint():
        log("saving model to %s..." % args.save_model)
        torch.save(dmm.state_dict(),
                   os.path.join('.', 'checkpoints', 'dmm_model.pth'))
        log("saving optimizer states to %s..." % args.save_opt)
        adam.save(os.path.join('.', 'checkpoints', 'dmm_opt.pth'))
        log("done saving model and optimizer checkpoints to disk.")

    # loads the model and optimizer states from disk
    def load_checkpoint():
        assert exists(args.load_opt) and exists(args.load_model), \
            "--load-model and/or --load-opt misspecified"
        log("loading model from %s..." % args.load_model)
        dmm.load_state_dict(
            torch.load(os.path.join('.', 'checkpoints', 'dmm_model.pth')))
        log("loading optimizer states from %s..." % args.load_opt)
        adam.load(os.path.join('.', 'checkpoints', 'dmm_opt.pth'))
        log("done loading model and optimizer states.")

    # prepare a mini-batch and take a gradient step to minimize -elbo
    def process_minibatch(epoch, which_mini_batch, shuffled_indices):
        if args.annealing_epochs > 0 and epoch < args.annealing_epochs:
            # compute the KL annealing factor appropriate for the current mini-batch in the current epoch
            min_af = args.minimum_annealing_factor
            annealing_factor = min_af + (1.0 - min_af) * \
                (float(which_mini_batch + epoch * N_mini_batches + 1) /
                 float(args.annealing_epochs * N_mini_batches))
        else:
            # by default the KL annealing factor is unity
            annealing_factor = 1.0

        # compute which sequences in the training set we should grab
        mini_batch_start = (which_mini_batch * args.mini_batch_size)
        mini_batch_end = np.min([(which_mini_batch + 1) * args.mini_batch_size,
                                 N_train_data])
        mini_batch_indices = shuffled_indices[mini_batch_start:mini_batch_end]

        # grab a fully prepped mini-batch using the helper function in the data loader
        mini_batch, mini_batch_reversed, mini_batch_mask, mini_batch_seq_lengths \
            = poly.get_mini_batch(mini_batch_indices, training_data_sequences,
                                  training_seq_lengths, cuda=args.cuda)
        # do an actual gradient step
        loss = elbo.step(mini_batch, mini_batch_reversed, mini_batch_mask,
                         mini_batch_seq_lengths, annealing_factor)
        # keep track of the training loss
        return loss

    # helper function for doing evaluation
    def do_evaluation():
        # put the RNN into evaluation mode (i.e. turn off drop-out if applicable)
        dmm.rnn.eval()

        # compute the validation and test loss over the n_eval_samples repetitions packaged above
        val_nll = elbo.evaluate_loss(val_batch, val_batch_reversed,
                                     val_batch_mask,
                                     val_seq_lengths) / np.sum(val_seq_lengths)
        test_nll = elbo.evaluate_loss(
            test_batch, test_batch_reversed, test_batch_mask,
            test_seq_lengths) / np.sum(test_seq_lengths)

        # put the RNN back into training mode (i.e. turn on drop-out if applicable)
        dmm.rnn.train()
        return val_nll, test_nll

    # if checkpoint files provided, load model and optimizer states from disk before we start training
    if args.load_opt != '' and args.load_model != '':
        load_checkpoint()

    #################
    # TRAINING LOOP #
    #################
    desc = "Epoch %4d | Loss %.4f | CUDA enabled" if args.cuda else "Epoch %4d | Loss %.4f"
    pbar = trange(args.num_epochs, desc=desc, unit="epoch")
    # for epoch in range(args.num_epochs):
    for epoch in pbar:
        # if specified, save model and optimizer states to disk every checkpoint_freq epochs
        if args.checkpoint_freq > 0 and epoch > 0 and epoch % args.checkpoint_freq == 0:
            save_checkpoint()

        # accumulator for our estimate of the negative log likelihood (or rather -elbo) for this epoch
        epoch_nll = 0.0
        # prepare mini-batch subsampling indices for this epoch
        shuffled_indices = np.arange(N_train_data)
        np.random.shuffle(shuffled_indices)

        # process each mini-batch; this is where we take gradient steps
        for which_mini_batch in range(N_mini_batches):
            epoch_nll += process_minibatch(epoch, which_mini_batch,
                                           shuffled_indices)

        # report training diagnostics
        pbar.set_description(desc % (epoch, epoch_nll / N_train_time_slices))

        if np.isnan(epoch_nll):
            print("Gradient exploded. Exiting program.")
            quit()

        # do evaluation on test and validation data and report results
        if val_test_frequency > 0 and epoch > 0 and epoch % val_test_frequency == 0:
            val_nll, test_nll = do_evaluation()
            pbar.write("Epoch %04d | Validation Loss %.4f, Test Loss %.4f" %
                       (epoch, val_nll, test_nll))
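
Note that save_checkpoint in Example #7 writes to ./checkpoints but never creates that directory, so the first checkpoint save will fail if it is missing. A small illustrative guard that could run before the training loop (not part of the original snippet):

import os

checkpoint_dir = os.path.join('.', 'checkpoints')
os.makedirs(checkpoint_dir, exist_ok=True)  # no-op if the directory already exists
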
Example #8
    def test_minibatch(which_mini_batch, shuffled_indices):

        # compute which sequences in the training set we should grab
        mini_batch_start = (which_mini_batch * args.mini_batch_size)
        mini_batch_end = np.min([(which_mini_batch + 1) * args.mini_batch_size,
                                 N_test_data])
        mini_batch_indices = shuffled_indices[mini_batch_start:mini_batch_end]
        # grab a fully prepped mini-batch using the helper function in the data loader
        mini_batch, mini_batch_reversed, mini_batch_mask, mini_batch_seq_lengths \
            = poly.get_mini_batch(mini_batch_indices, test_data_sequences,
                                  test_seq_lengths, cuda=args.cuda)

        # Get the initial RNN state.
        h_0 = dmm.h_0
        h_0_contig = h_0.expand(1, mini_batch.size(0),
                                dmm.rnn.hidden_size).contiguous()

        # Feed the test sequence into the RNN.
        rnn_output, rnn_hidden_state = dmm.rnn(mini_batch_reversed, h_0_contig)

        # Reverse the time ordering of the hidden state and unpack it.
        rnn_output = poly.pad_and_reverse(rnn_output, mini_batch_seq_lengths)
        print(rnn_output)
        print(rnn_output.shape)

        # set z_prev = z_q_0 to setup the recursive conditioning in q(z_t |...)
        z_prev = dmm.z_q_0.expand(mini_batch.size(0), dmm.z_q_0.size(0))

        # sample the latents z one time step at a time
        T_max = mini_batch.size(1)
        sequence_output = []
        for t in range(1, T_max + 1):
            # the next few lines assemble the distribution q(z_t | z_{t-1}, x_{t:T})
            z_loc, z_scale = dmm.combiner(z_prev, rnn_output[:, t - 1, :])

            # if we are using normalizing flows, we apply the sequence of transformations
            # parameterized by self.iafs to the base distribution defined in the previous line
            # to yield a transformed distribution that we use for q(z_t|...)
            if len(dmm.iafs) > 0:
                z_dist = TransformedDistribution(dist.Normal(z_loc, z_scale),
                                                 dmm.iafs)
            else:
                z_dist = dist.Normal(z_loc, z_scale)
            assert z_dist.event_shape == ()
            assert z_dist.batch_shape == (len(mini_batch), dmm.z_q_0.size(0))

            # sample z_t from the distribution z_dist
            annealing_factor = 1.0
            with pyro.poutine.scale(scale=annealing_factor):
                z_t = pyro.sample(
                    "z_%d" % t,
                    z_dist.mask(mini_batch_mask[:, t - 1:t]).to_event(1))

            print("z_{}:".format(t), z_t)
            print(z_t.shape)

            # compute the probabilities that parameterize the bernoulli likelihood
            emission_probs_t = dmm.emitter(z_t)

            emission_probs_t_np = emission_probs_t.detach().numpy()
            sequence_output.append(emission_probs_t_np)

            print("x_{}:".format(t), emission_probs_t)
            print(emission_probs_t.shape)

            # the latent sampled at this time step will be conditioned upon in the next time step
            # so keep track of it
            z_prev = z_t

        # Run the model another few steps.
        n_steps = 100
        for t in range(1, n_steps + 1):
            # first compute the parameters of the diagonal gaussian distribution p(z_t | z_{t-1})
            z_loc, z_scale = dmm.trans(z_prev)

            # then sample z_t according to dist.Normal(z_loc, z_scale)
            # note that we use to_event(1) so that the univariate Normal distribution
            # is treated as a multivariate Normal distribution with a diagonal covariance.
            with poutine.scale(scale=annealing_factor):
                z_t = pyro.sample(
                    "z_%d" % t,
                    dist.Normal(z_loc, z_scale).mask(
                        mini_batch_mask[:, t - 1:t]).to_event(1))

            # compute the probabilities that parameterize the bernoulli likelihood
            emission_probs_t = dmm.emitter(z_t)

            emission_probs_t_np = emission_probs_t.detach().numpy()
            sequence_output.append(emission_probs_t_np)

            # # the next statement instructs pyro to observe x_t according to the
            # # bernoulli distribution p(x_t|z_t)
            # pyro.sample("obs_x_%d" % t,
            #             # dist.Bernoulli(emission_probs_t)
            #             dist.Normal(emission_probs_t, 0.5)
            #             .mask(mini_batch_mask[:, t - 1:t])
            #             .to_event(1),
            #             obs=mini_batch[:, t - 1, :])

            # the latent sampled at this time step will be conditioned upon
            # in the next time step so keep track of it
            z_prev = z_t

        sequence_output = np.concatenate(sequence_output, axis=1)
        print(sequence_output.shape)

        n_plots = 5
        fig, axes = plt.subplots(nrows=n_plots, ncols=1)
        x = range(sequence_output.shape[1])
        for i in range(n_plots):
            input = mini_batch[i, :].numpy().squeeze()
            output = sequence_output[i, :]
            axes[i].plot(range(input.shape[0]), input)
            axes[i].plot(range(len(output)), output)
            axes[i].grid()

        # plt.plot(sequence_output[0, :])
        plt.show()
Example #9
def main(args):
    # setup logging
    log = get_logger(args.log)
    log(args)

    if 0:
        data = generate_sine_wave_data()
        input_dim = 1
    elif 1:
        data = generate_returns_data()
        input_dim = 1
        # return
    else:
        data = poly.load_data(poly.JSB_CHORALES)
        input_dim = 88
    training_seq_lengths = data['train']['sequence_lengths']
    training_data_sequences = data['train']['sequences']
    test_seq_lengths = data['test']['sequence_lengths']
    test_data_sequences = data['test']['sequences']
    val_seq_lengths = data['valid']['sequence_lengths']
    val_data_sequences = data['valid']['sequences']
    N_train_data = len(training_seq_lengths)
    N_train_time_slices = float(torch.sum(training_seq_lengths))
    N_mini_batches = int(N_train_data / args.mini_batch_size +
                         int(N_train_data % args.mini_batch_size > 0))

    N_test_data = len(test_seq_lengths)

    log("N_train_data: %d     avg. training seq. length: %.2f    N_mini_batches: %d"
        % (N_train_data, training_seq_lengths.float().mean(), N_mini_batches))

    # how often we do validation/test evaluation during training
    val_test_frequency = 50
    # the number of samples we use to do the evaluation
    n_eval_samples = 1

    # package repeated copies of val/test data for faster evaluation
    # (i.e. set us up for vectorization)
    def rep(x):
        rep_shape = torch.Size([x.size(0) * n_eval_samples]) + x.size()[1:]
        repeat_dims = [1] * len(x.size())
        repeat_dims[0] = n_eval_samples
        return x.repeat(repeat_dims).reshape(n_eval_samples, -1).transpose(
            1, 0).reshape(rep_shape)

    # get the validation/test data ready for the dmm: pack into sequences, etc.
    val_seq_lengths = rep(val_seq_lengths)
    test_seq_lengths = rep(test_seq_lengths)
    val_batch, val_batch_reversed, val_batch_mask, val_seq_lengths = poly.get_mini_batch(
        torch.arange(n_eval_samples * val_data_sequences.shape[0]),
        rep(val_data_sequences),
        val_seq_lengths,
        cuda=args.cuda)
    test_batch, test_batch_reversed, test_batch_mask, test_seq_lengths = poly.get_mini_batch(
        torch.arange(n_eval_samples * test_data_sequences.shape[0]),
        rep(test_data_sequences),
        test_seq_lengths,
        cuda=args.cuda)

    # instantiate the dmm
    dmm = DMM(input_dim=input_dim,
              rnn_dropout_rate=args.rnn_dropout_rate,
              num_iafs=args.num_iafs,
              iaf_dim=args.iaf_dim,
              use_cuda=args.cuda)

    # setup optimizer
    adam_params = {
        "lr": args.learning_rate,
        "betas": (args.beta1, args.beta2),
        "clip_norm": args.clip_norm,
        "lrd": args.lr_decay,
        "weight_decay": args.weight_decay
    }
    adam = ClippedAdam(adam_params)

    # setup inference algorithm
    elbo = JitTrace_ELBO() if args.jit else Trace_ELBO()
    svi = SVI(dmm.model, dmm.guide, adam, loss=elbo)

    # now we're going to define some functions we need to form the main training loop

    # saves the model and optimizer states to disk
    def save_checkpoint():
        log("saving model to %s..." % args.save_model)
        torch.save(dmm.state_dict(), args.save_model)
        log("saving optimizer states to %s..." % args.save_opt)
        adam.save(args.save_opt)
        log("done saving model and optimizer checkpoints to disk.")

    # loads the model and optimizer states from disk
    def load_checkpoint():
        assert exists(args.load_opt) and exists(args.load_model), \
            "--load-model and/or --load-opt misspecified"
        log("loading model from %s..." % args.load_model)
        dmm.load_state_dict(torch.load(args.load_model))
        log("loading optimizer states from %s..." % args.load_opt)
        adam.load(args.load_opt)
        log("done loading model and optimizer states.")

    # prepare a mini-batch and take a gradient step to minimize -elbo
    def process_minibatch(epoch, which_mini_batch, shuffled_indices):
        if args.annealing_epochs > 0 and epoch < args.annealing_epochs:
            # compute the KL annealing factor appropriate for the current mini-batch in the current epoch
            min_af = args.minimum_annealing_factor
            annealing_factor = min_af + (1.0 - min_af) * \
                (float(which_mini_batch + epoch * N_mini_batches + 1) /
                 float(args.annealing_epochs * N_mini_batches))
        else:
            # by default the KL annealing factor is unity
            annealing_factor = 1.0

        # compute which sequences in the training set we should grab
        mini_batch_start = (which_mini_batch * args.mini_batch_size)
        mini_batch_end = np.min([(which_mini_batch + 1) * args.mini_batch_size,
                                 N_train_data])
        mini_batch_indices = shuffled_indices[mini_batch_start:mini_batch_end]
        # grab a fully prepped mini-batch using the helper function in the data loader
        mini_batch, mini_batch_reversed, mini_batch_mask, mini_batch_seq_lengths \
            = poly.get_mini_batch(mini_batch_indices, training_data_sequences,
                                  training_seq_lengths, cuda=args.cuda)
        # do an actual gradient step
        loss = svi.step(mini_batch, mini_batch_reversed, mini_batch_mask,
                        mini_batch_seq_lengths, annealing_factor)
        # keep track of the training loss
        return loss

    # prepare a mini-batch and take a gradient step to minimize -elbo
    def test_minibatch(which_mini_batch, shuffled_indices):

        # compute which sequences in the training set we should grab
        mini_batch_start = (which_mini_batch * args.mini_batch_size)
        mini_batch_end = np.min([(which_mini_batch + 1) * args.mini_batch_size,
                                 N_test_data])
        mini_batch_indices = shuffled_indices[mini_batch_start:mini_batch_end]
        # grab a fully prepped mini-batch using the helper function in the data loader
        mini_batch, mini_batch_reversed, mini_batch_mask, mini_batch_seq_lengths \
            = poly.get_mini_batch(mini_batch_indices, test_data_sequences,
                                  test_seq_lengths, cuda=args.cuda)

        # Get the initial RNN state.
        h_0 = dmm.h_0
        h_0_contig = h_0.expand(1, mini_batch.size(0),
                                dmm.rnn.hidden_size).contiguous()

        # Feed the test sequence into the RNN.
        rnn_output, rnn_hidden_state = dmm.rnn(mini_batch_reversed, h_0_contig)

        # Reverse the time ordering of the hidden state and unpack it.
        rnn_output = poly.pad_and_reverse(rnn_output, mini_batch_seq_lengths)
        print(rnn_output)
        print(rnn_output.shape)

        # set z_prev = z_q_0 to setup the recursive conditioning in q(z_t |...)
        z_prev = dmm.z_q_0.expand(mini_batch.size(0), dmm.z_q_0.size(0))

        # sample the latents z one time step at a time
        T_max = mini_batch.size(1)
        sequence_output = []
        for t in range(1, T_max + 1):
            # the next few lines assemble the distribution q(z_t | z_{t-1}, x_{t:T})
            z_loc, z_scale = dmm.combiner(z_prev, rnn_output[:, t - 1, :])

            # if we are using normalizing flows, we apply the sequence of transformations
            # parameterized by self.iafs to the base distribution defined in the previous line
            # to yield a transformed distribution that we use for q(z_t|...)
            if len(dmm.iafs) > 0:
                z_dist = TransformedDistribution(dist.Normal(z_loc, z_scale),
                                                 dmm.iafs)
            else:
                z_dist = dist.Normal(z_loc, z_scale)
            assert z_dist.event_shape == ()
            assert z_dist.batch_shape == (len(mini_batch), dmm.z_q_0.size(0))

            # sample z_t from the distribution z_dist
            annealing_factor = 1.0
            with pyro.poutine.scale(scale=annealing_factor):
                z_t = pyro.sample(
                    "z_%d" % t,
                    z_dist.mask(mini_batch_mask[:, t - 1:t]).to_event(1))

            print("z_{}:".format(t), z_t)
            print(z_t.shape)

            # compute the probabilities that parameterize the bernoulli likelihood
            emission_probs_t = dmm.emitter(z_t)

            emission_probs_t_np = emission_probs_t.detach().numpy()
            sequence_output.append(emission_probs_t_np)

            print("x_{}:".format(t), emission_probs_t)
            print(emission_probs_t.shape)

            # the latent sampled at this time step will be conditioned upon in the next time step
            # so keep track of it
            z_prev = z_t

        # Run the model another few steps.
        n_steps = 100
        for t in range(1, n_steps + 1):
            # first compute the parameters of the diagonal gaussian distribution p(z_t | z_{t-1})
            z_loc, z_scale = dmm.trans(z_prev)

            # then sample z_t according to dist.Normal(z_loc, z_scale)
            # note that we use to_event(1) so that the univariate Normal distribution
            # is treated as a multivariate Normal distribution with a diagonal covariance.
            with poutine.scale(scale=annealing_factor):
                z_t = pyro.sample(
                    "z_%d" % t,
                    dist.Normal(z_loc, z_scale).mask(
                        mini_batch_mask[:, t - 1:t]).to_event(1))

            # compute the probabilities that parameterize the bernoulli likelihood
            emission_probs_t = dmm.emitter(z_t)

            emission_probs_t_np = emission_probs_t.detach().numpy()
            sequence_output.append(emission_probs_t_np)

            # # the next statement instructs pyro to observe x_t according to the
            # # bernoulli distribution p(x_t|z_t)
            # pyro.sample("obs_x_%d" % t,
            #             # dist.Bernoulli(emission_probs_t)
            #             dist.Normal(emission_probs_t, 0.5)
            #             .mask(mini_batch_mask[:, t - 1:t])
            #             .to_event(1),
            #             obs=mini_batch[:, t - 1, :])

            # the latent sampled at this time step will be conditioned upon
            # in the next time step so keep track of it
            z_prev = z_t

        sequence_output = np.concatenate(sequence_output, axis=1)
        print(sequence_output.shape)

        n_plots = 5
        fig, axes = plt.subplots(nrows=n_plots, ncols=1)
        x = range(sequence_output.shape[1])
        for i in range(n_plots):
            input = mini_batch[i, :].numpy().squeeze()
            output = sequence_output[i, :]
            axes[i].plot(range(input.shape[0]), input)
            axes[i].plot(range(len(output)), output)
            axes[i].grid()

        # plt.plot(sequence_output[0, :])
        plt.show()

    # helper function for doing evaluation
    def do_evaluation():
        # put the RNN into evaluation mode (i.e. turn off drop-out if applicable)
        dmm.rnn.eval()

        # compute the validation and test loss over the n_eval_samples repetitions packaged above
        val_nll = svi.evaluate_loss(
            val_batch, val_batch_reversed, val_batch_mask,
            val_seq_lengths) / torch.sum(val_seq_lengths)
        test_nll = svi.evaluate_loss(
            test_batch, test_batch_reversed, test_batch_mask,
            test_seq_lengths) / torch.sum(test_seq_lengths)

        # put the RNN back into training mode (i.e. turn on drop-out if applicable)
        dmm.rnn.train()
        return val_nll, test_nll

    # if checkpoint files provided, load model and optimizer states from disk before we start training
    if args.load_opt != '' and args.load_model != '':
        load_checkpoint()

    #################
    # TRAINING LOOP #
    #################
    times = [time.time()]
    for epoch in range(args.num_epochs):
        # if specified, save model and optimizer states to disk every checkpoint_freq epochs
        if args.checkpoint_freq > 0 and epoch > 0 and epoch % args.checkpoint_freq == 0:
            save_checkpoint()

        # accumulator for our estimate of the negative log likelihood (or rather -elbo) for this epoch
        epoch_nll = 0.0
        # prepare mini-batch subsampling indices for this epoch
        shuffled_indices = torch.randperm(N_train_data)

        # process each mini-batch; this is where we take gradient steps
        for which_mini_batch in range(N_mini_batches):
            epoch_nll += process_minibatch(epoch, which_mini_batch,
                                           shuffled_indices)

        # report training diagnostics
        times.append(time.time())
        epoch_time = times[-1] - times[-2]
        log("[training epoch %04d]  %.4f \t\t\t\t(dt = %.3f sec)" %
            (epoch, epoch_nll / N_train_time_slices, epoch_time))

        # do evaluation on test and validation data and report results
        if val_test_frequency > 0 and epoch > 0 and epoch % val_test_frequency == 0:
            val_nll, test_nll = do_evaluation()
            log("[val/test epoch %04d]  %.4f  %.4f" %
                (epoch, val_nll, test_nll))

    # Testing.
    print("Testing")
    shuffled_indices = torch.randperm(N_test_data)
    which_mini_batch = 0
    test_minibatch(which_mini_batch, shuffled_indices)
Example #10
def main(args):
    # setup logging
    log = get_logger(args.log)
    log(args)

    data = poly.load_data(poly.JSB_CHORALES)
    training_seq_lengths = data['train']['sequence_lengths']
    training_data_sequences = data['train']['sequences']
    test_seq_lengths = data['test']['sequence_lengths']
    test_data_sequences = data['test']['sequences']
    val_seq_lengths = data['valid']['sequence_lengths']
    val_data_sequences = data['valid']['sequences']
    N_train_data = len(training_seq_lengths)
    N_train_time_slices = float(torch.sum(training_seq_lengths))
    N_mini_batches = int(N_train_data / args.mini_batch_size +
                         int(N_train_data % args.mini_batch_size > 0))

    log("N_train_data: %d     avg. training seq. length: %.2f    N_mini_batches: %d"
        % (N_train_data, training_seq_lengths.float().mean(), N_mini_batches))

    # how often we do validation/test evaluation during training
    val_test_frequency = 5
    # the number of samples we use to do the evaluation
    n_eval_samples = 1

    # package repeated copies of val/test data for faster evaluation
    # (i.e. set us up for vectorization)
    def rep(x):
        rep_shape = torch.Size([x.size(0) * n_eval_samples]) + x.size()[1:]
        repeat_dims = [1] * len(x.size())
        repeat_dims[0] = n_eval_samples
        return x.repeat(repeat_dims).reshape(n_eval_samples, -1).transpose(
            1, 0).reshape(rep_shape)

    # get the validation/test data ready for the dmm: pack into sequences, etc.
    val_seq_lengths = rep(val_seq_lengths)
    test_seq_lengths = rep(test_seq_lengths)
    val_batch, val_batch_reversed, val_batch_mask, val_seq_lengths = poly.get_mini_batch(
        torch.arange(n_eval_samples * val_data_sequences.shape[0]),
        rep(val_data_sequences),
        val_seq_lengths,
        cuda=args.cuda)
    test_batch, test_batch_reversed, test_batch_mask, test_seq_lengths = poly.get_mini_batch(
        torch.arange(n_eval_samples * test_data_sequences.shape[0]),
        rep(test_data_sequences),
        test_seq_lengths,
        cuda=args.cuda)

    # instantiate the dmm
    dmm = DMM(rnn_dropout_rate=args.rnn_dropout_rate, use_cuda=args.cuda)

    # prepare a mini-batch and take a gradient step to minimize -elbo
    def process_minibatch(epoch, which_mini_batch, shuffled_indices):
        if args.annealing_epochs > 0 and epoch < args.annealing_epochs:
            # compute the KL annealing factor appropriate for the current mini-batch in the current epoch
            min_af = args.minimum_annealing_factor
            annealing_factor = min_af + (1.0 - min_af) * \
                               (float(which_mini_batch + epoch * N_mini_batches + 1) /
                                float(args.annealing_epochs * N_mini_batches))
        else:
            # by default the KL annealing factor is unity
            annealing_factor = 1.0

        # compute which sequences in the training set we should grab
        mini_batch_start = (which_mini_batch * args.mini_batch_size)
        mini_batch_end = np.min([(which_mini_batch + 1) * args.mini_batch_size,
                                 N_train_data])
        mini_batch_indices = shuffled_indices[mini_batch_start:mini_batch_end]
        # grab a fully prepped mini-batch using the helper function in the data loader
        mini_batch, mini_batch_reversed, mini_batch_mask, mini_batch_seq_lengths \
            = poly.get_mini_batch(mini_batch_indices, training_data_sequences,
                                  training_seq_lengths, cuda=args.cuda)
        # do an actual gradient step
        loss, loss_AT = dmm.train_ae(mini_batch, mini_batch_reversed,
                                     mini_batch_seq_lengths, annealing_factor)
        # keep track of the training loss
        return loss

    # if checkpoint files provided, load model and optimizer states from disk before we start training
    if args.load_opt != '' and args.load_model != '':
        load_checkpoint(dmm, log)

    #################
    # TRAINING LOOP #
    #################
    times = [time.time()]
    for epoch in range(args.num_epochs):
        # if specified, save model and optimizer states to disk every checkpoint_freq epochs
        if args.checkpoint_freq > 0 and epoch > 0 and epoch % args.checkpoint_freq == 0:
            save_checkpoint(dmm, log)

        # accumulator for our estimate of the negative log likelihood (or rather -elbo) for this epoch
        epoch_nll = 0.0
        # prepare mini-batch subsampling indices for this epoch
        shuffled_indices = torch.randperm(N_train_data)

        # process each mini-batch; this is where we take gradient steps
        for which_mini_batch in range(N_mini_batches):
            epoch_nll += process_minibatch(epoch, which_mini_batch,
                                           shuffled_indices)

        # report training diagnostics
        times.append(time.time())
        epoch_time = times[-1] - times[-2]
        log("[training epoch %04d]  %.4f \t\t\t\t(dt = %.3f sec)" %
            (epoch, epoch_nll / N_train_time_slices, epoch_time))

        # do evaluation on test and validation data and report results
        if val_test_frequency > 0 and epoch > 0 and epoch % val_test_frequency == 0:
            # sample = dmm.generate(n=n_eval_samples, T_max=torch.max(training_seq_lengths).item())
            val_loss = dmm.build_loss(val_batch, val_batch_reversed,
                                      val_seq_lengths)
            test_loss = dmm.build_loss(test_batch, test_batch_reversed,
                                       test_seq_lengths)
            log("[val/test epoch %04d]  %.4f  %.4f" %
                (epoch, val_loss[0], test_loss[0]))
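
Every main above computes N_mini_batches with int(N_train_data / args.mini_batch_size + int(N_train_data % args.mini_batch_size > 0)), which is a hand-written ceiling division. A tiny sketch of the equivalence for positive integer inputs; the integer form below also sidesteps the float division used in the original expression.

import math

def n_mini_batches(n_data, batch_size):
    # ceil(n_data / batch_size) without going through floats
    return -(-n_data // batch_size)

assert n_mini_batches(229, 20) == math.ceil(229 / 20) == 12
assert n_mini_batches(220, 20) == 11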