Exemple #1
0
def main(args):
    """Train the selected HMM variant on the JSB chorales and report losses.

    The model chosen by ``args.model`` is fit with SVI: an ``AutoDelta``
    guide learns point estimates of every ``probs_*`` site while the hidden
    state is marginalized out by an enumerating ELBO. Afterwards the
    per-observation loss on the training and test splits and the number of
    learned parameters are logged.

    Args:
        args: parsed options. Attributes read here: ``cuda``, ``model``,
            ``truncate``, ``seed``, ``print_shapes``, ``batch_size``,
            ``learning_rate``, ``tmc``, ``tmc_num_samples``, ``jit``,
            ``time_compilation`` and ``num_steps``.

    Raises:
        NotImplementedError: if both ``args.tmc`` and ``args.jit`` are set.
    """
    if args.cuda:
        torch.set_default_tensor_type("torch.cuda.FloatTensor")

    logging.info("Loading data")
    data = poly.load_data(poly.JSB_CHORALES)

    logging.info("-" * 40)
    model = models[args.model]
    train_split = data["train"]
    logging.info("Training {} on {} sequences".format(
        model.__name__, len(train_split["sequences"])))
    sequences = train_split["sequences"]
    lengths = train_split["sequence_lengths"]

    # Keep only the notes that occur at least once in the training set
    # (this drops 37 of the 88 notes).
    note_counts = (sequences == 1).sum(0).sum(0)
    present_notes = note_counts > 0
    sequences = sequences[..., present_notes]

    if args.truncate:
        max_steps = args.truncate
        lengths = lengths.clamp(max=max_steps)
        sequences = sequences[:, :max_steps]
    train_observations = float(lengths.sum())
    pyro.set_rng_seed(args.seed)
    pyro.clear_param_store()

    # MAP Baum-Welch: MAP estimation with the hidden state x marginalized
    # out. The automatic guide only sees the conditional probability tables,
    # i.e. the sites whose names start with "probs_".
    def is_probs_site(msg):
        return msg["name"].startswith("probs_")

    guide = AutoDelta(poutine.block(model, expose_fn=is_probs_site))

    # All models use the plates "data" and "tones"; the ELBOs below are given
    # a nesting depth of 1 for model_0 and 2 for every other model.
    plate_nesting = 1 if model is model_0 else 2

    # Optional shape debugging: log the shape of each site's distribution,
    # value and log_prob tensor. (SVI prints the same information on most
    # errors.)
    if args.print_shapes:
        first_available_dim = -(plate_nesting + 1)
        guide_trace = poutine.trace(guide).get_trace(
            sequences, lengths, args=args, batch_size=args.batch_size)
        enumerated_model = poutine.enum(model, first_available_dim)
        replayed_model = poutine.replay(enumerated_model, guide_trace)
        model_trace = poutine.trace(replayed_model).get_trace(
            sequences, lengths, args=args, batch_size=args.batch_size)
        logging.info(model_trace.format_shapes())

    # Enumeration requires an enumerating ELBO with max_plate_nesting set.
    optim = Adam({"lr": args.learning_rate})
    if args.tmc:
        if args.jit:
            raise NotImplementedError(
                "jit support not yet added for TraceTMC_ELBO")
        elbo = TraceTMC_ELBO(max_plate_nesting=plate_nesting)

        def tmc_site_config(msg):
            # Only sites marked for parallel enumeration are reconfigured.
            if msg["infer"].get("enumerate", None) == "parallel":
                return {
                    "num_samples": args.tmc_num_samples,
                    "expand": False,
                }
            return {}

        tmc_model = poutine.infer_config(model, tmc_site_config)
        svi = SVI(tmc_model, guide, optim, elbo)
    else:
        elbo_cls = JitTraceEnum_ELBO if args.jit else TraceEnum_ELBO
        elbo = elbo_cls(
            max_plate_nesting=plate_nesting,
            strict_enumeration_warning=(model is not model_7),
            jit_options={"time_compilation": args.time_compilation},
        )
        svi = SVI(model, guide, optim, elbo)

    # Train on small minibatches, logging the per-observation loss each step.
    logging.info("Step\tLoss")
    for step in range(args.num_steps):
        loss = svi.step(
            sequences, lengths, args=args, batch_size=args.batch_size)
        logging.info("{: >5d}\t{}".format(step, loss / train_observations))

    if args.jit and args.time_compilation:
        logging.debug("time to compile: {} s.".format(
            elbo._differentiable_loss.compile_time))

    # Evaluate on the entire training set, excluding the prior term so that
    # results are comparable across models.
    train_loss = elbo.loss(
        model, guide, sequences, lengths, args, include_prior=False)
    logging.info("training loss = {}".format(train_loss / train_observations))

    # Finally evaluate on the test set.
    logging.info("-" * 40)
    test_split = data["test"]
    logging.info("Evaluating on {} test sequences".format(
        len(test_split["sequences"])))
    test_sequences = test_split["sequences"][..., present_notes]
    test_lengths = test_split["sequence_lengths"]
    if args.truncate:
        test_lengths = test_lengths.clamp(max=args.truncate)
    test_observations = float(test_lengths.sum())

    # Because unseen notes were removed above (to make the problem a bit
    # easier and for numerical stability), this test loss may not be directly
    # comparable to numbers reported on this dataset elsewhere.
    test_loss = elbo.loss(
        model, guide, test_sequences, test_lengths, args=args,
        include_prior=False)
    logging.info("test loss = {}".format(test_loss / test_observations))

    # Higher-capacity models are expected to perform better, but eventually
    # overfit the training set.
    capacity = sum(
        value.numel() for value in pyro.get_param_store().values())
    logging.info("{} capacity = {} parameters".format(model.__name__,
                                                      capacity))
Exemple #2
0
def main(args):
    """Train a deep Markov model (DMM) with SVI and log train/val/test losses.

    Sets up file + console logging, builds the data splits, instantiates the
    ``DMM`` and a ``ClippedAdam`` optimizer, selects the ELBO from the
    ``tmc`` / ``tmcelbo`` / ``jit`` flags, and then runs ``args.num_epochs``
    epochs of minibatch training with optional KL annealing, periodic
    checkpointing and periodic validation/test evaluation.

    Args:
        args: parsed options. Attributes read here: ``log``,
            ``mini_batch_size``, ``cuda``, ``rnn_dropout_rate``, ``num_iafs``,
            ``iaf_dim``, ``learning_rate``, ``beta1``, ``beta2``,
            ``clip_norm``, ``lr_decay``, ``weight_decay``, ``tmc``,
            ``tmcelbo``, ``jit``, ``tmc_num_samples``, ``save_model``,
            ``load_model``, ``load_opt``, ``annealing_epochs``,
            ``minimum_annealing_factor``, ``num_epochs`` and
            ``checkpoint_freq``.

    Raises:
        NotImplementedError: if ``args.jit`` is combined with ``args.tmc`` or
            ``args.tmcelbo``.
        AssertionError: if checkpoint loading is requested but one of the
            checkpoint files does not exist.
    """
    # "Do-re-mi": one sequence of 8 steps that walks up from key index 70 by
    # the offsets 0, 2, 4, 5, 7, 9, 11, 12 (one active key per time step).
    def easyTones():
        training_seq_lengths = torch.tensor([8] * 1)
        training_data_sequences = torch.zeros(1, 8, 88)
        scale_offsets = (0, 2, 4, 5, 7, 9, 11, 12)
        for i in range(1):
            base = int(70 - i * 10)
            for t, offset in enumerate(scale_offsets):
                training_data_sequences[i][t][base + offset] = 1
        return training_seq_lengths, training_data_sequences

    # Ten sequences of 8 steps; sequence i holds the single key 30 + 5 * i
    # at every time step.
    def superEasyTones():
        training_seq_lengths = torch.tensor([8] * 10)
        training_data_sequences = torch.zeros(10, 8, 88)
        for i in range(10):
            for j in range(8):
                training_data_sequences[i][j][int(30 + i * 5)] = 1
        return training_seq_lengths, training_data_sequences

    # "Do do do, do do do, do do do": ten sequences of 8 steps that all hold
    # the same key (index 70) at every time step.
    def easiestTones():
        training_seq_lengths = torch.tensor([8] * 10)
        training_data_sequences = torch.zeros(10, 8, 88)
        for i in range(10):
            for j in range(8):
                training_data_sequences[i][j][int(70)] = 1
        return training_seq_lengths, training_data_sequences

    # setup logging: everything at DEBUG goes to args.log, INFO and above is
    # echoed to the console
    logging.basicConfig(level=logging.DEBUG, format='%(message)s', filename=args.log, filemode='w')
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    logging.getLogger('').addHandler(console)
    logging.info(args)

    # NOTE(review): the JSB chorales are loaded, but every split is replaced
    # by the synthetic easiestTones() data on the very next line. This looks
    # like a deliberate experiment toggle; drop the overrides to train on the
    # real dataset.
    data = poly.load_data(poly.JSB_CHORALES)
    training_seq_lengths = data['train']['sequence_lengths']
    training_data_sequences = data['train']['sequences']
    training_seq_lengths, training_data_sequences = easiestTones()
    test_seq_lengths = data['test']['sequence_lengths']
    test_data_sequences = data['test']['sequences']
    test_seq_lengths, test_data_sequences = easiestTones()
    val_seq_lengths = data['valid']['sequence_lengths']
    val_data_sequences = data['valid']['sequences']
    val_seq_lengths, val_data_sequences = easiestTones()
    N_train_data = len(training_seq_lengths)
    N_train_time_slices = float(torch.sum(training_seq_lengths))
    # number of mini-batches per epoch, rounding up for a partial last batch
    N_mini_batches = int(N_train_data / args.mini_batch_size +
                         int(N_train_data % args.mini_batch_size > 0))

    logging.info("N_train_data: %d     avg. training seq. length: %.2f    N_mini_batches: %d" %
                 (N_train_data, training_seq_lengths.float().mean(), N_mini_batches))

    # how often we do validation/test evaluation during training
    val_test_frequency = 50
    # the number of samples we use to do the evaluation
    n_eval_samples = 1

    # package repeated copies of val/test data for faster evaluation
    # (i.e. set us up for vectorization)
    def rep(x):
        rep_shape = torch.Size([x.size(0) * n_eval_samples]) + x.size()[1:]
        repeat_dims = [1] * len(x.size())
        repeat_dims[0] = n_eval_samples
        return x.repeat(repeat_dims).reshape(n_eval_samples, -1).transpose(1, 0).reshape(rep_shape)

    # get the validation/test data ready for the dmm: pack into sequences, etc.
    val_seq_lengths = rep(val_seq_lengths)
    test_seq_lengths = rep(test_seq_lengths)
    val_batch, val_batch_reversed, val_batch_mask, val_seq_lengths = poly.get_mini_batch(
        torch.arange(n_eval_samples * val_data_sequences.shape[0]), rep(val_data_sequences),
        val_seq_lengths, cuda=args.cuda)
    test_batch, test_batch_reversed, test_batch_mask, test_seq_lengths = poly.get_mini_batch(
        torch.arange(n_eval_samples * test_data_sequences.shape[0]), rep(test_data_sequences),
        test_seq_lengths, cuda=args.cuda)

    # instantiate the dmm
    dmm = DMM(rnn_dropout_rate=args.rnn_dropout_rate, num_iafs=args.num_iafs,
              iaf_dim=args.iaf_dim, use_cuda=args.cuda)

    # setup optimizer
    adam_params = {"lr": args.learning_rate, "betas": (args.beta1, args.beta2),
                   "clip_norm": args.clip_norm, "lrd": args.lr_decay,
                   "weight_decay": args.weight_decay}
    adam = ClippedAdam(adam_params)

    # setup inference algorithm
    if args.tmc:
        if args.jit:
            raise NotImplementedError("no JIT support yet for TMC")
        tmc_loss = TraceTMC_ELBO()
        dmm_guide = config_enumerate(dmm.guide, default="parallel", num_samples=args.tmc_num_samples, expand=False)
        svi = SVI(dmm.model, dmm_guide, adam, loss=tmc_loss)
    elif args.tmcelbo:
        if args.jit:
            raise NotImplementedError("no JIT support yet for TMC ELBO")
        elbo = TraceEnum_ELBO()
        dmm_guide = config_enumerate(dmm.guide, default="parallel", num_samples=args.tmc_num_samples, expand=False)
        svi = SVI(dmm.model, dmm_guide, adam, loss=elbo)
    else:
        elbo = JitTrace_ELBO() if args.jit else Trace_ELBO()
        svi = SVI(dmm.model, dmm.guide, adam, loss=elbo)

    # now we're going to define some functions we need to form the main training loop

    # saves the model state to disk
    def save_checkpoint():
        logging.info("saving model to %s..." % args.save_model)
        torch.save(dmm.state_dict(), args.save_model)
        # NOTE: optimizer state is not written here, so the message below
        # only mentions the model. load_checkpoint() still requires an
        # optimizer state file, which has to come from elsewhere.
        logging.info("done saving model checkpoint to disk.")

    # loads the model and optimizer states from disk
    def load_checkpoint():
        assert exists(args.load_opt) and exists(args.load_model), \
            "--load-model and/or --load-opt misspecified"
        logging.info("loading model from %s..." % args.load_model)
        dmm.load_state_dict(torch.load(args.load_model))
        logging.info("loading optimizer states from %s..." % args.load_opt)
        adam.load(args.load_opt)
        logging.info("done loading model and optimizer states.")

    # prepare a mini-batch and take a gradient step to minimize -elbo
    def process_minibatch(epoch, which_mini_batch, shuffled_indices):
        if args.annealing_epochs > 0 and epoch < args.annealing_epochs:
            # compute the KL annealing factor appropriate for the current
            # mini-batch in the current epoch: it grows linearly from just
            # above min_af to 1.0 over the annealing epochs
            min_af = args.minimum_annealing_factor
            annealing_factor = min_af + (1.0 - min_af) * \
                (float(which_mini_batch + epoch * N_mini_batches + 1) /
                 float(args.annealing_epochs * N_mini_batches))
        else:
            # by default the KL annealing factor is unity
            annealing_factor = 1.0

        # compute which sequences in the training set we should grab; the
        # last mini-batch of an epoch may be shorter than mini_batch_size
        mini_batch_start = (which_mini_batch * args.mini_batch_size)
        mini_batch_end = min((which_mini_batch + 1) * args.mini_batch_size, N_train_data)
        mini_batch_indices = shuffled_indices[mini_batch_start:mini_batch_end]
        # grab a fully prepped mini-batch using the helper function in the data loader
        mini_batch, mini_batch_reversed, mini_batch_mask, mini_batch_seq_lengths \
            = poly.get_mini_batch(mini_batch_indices, training_data_sequences,
                                  training_seq_lengths, cuda=args.cuda)
        # do an actual gradient step
        loss = svi.step(mini_batch, mini_batch_reversed, mini_batch_mask,
                        mini_batch_seq_lengths, annealing_factor)
        # keep track of the training loss
        return loss

    # helper function for doing evaluation
    def do_evaluation():
        # put the RNN into evaluation mode (i.e. turn off drop-out if applicable)
        dmm.rnn.eval()
        try:
            # compute the validation and test loss, normalized per time slice
            val_nll = svi.evaluate_loss(val_batch, val_batch_reversed, val_batch_mask,
                                        val_seq_lengths) / float(torch.sum(val_seq_lengths))
            test_nll = svi.evaluate_loss(test_batch, test_batch_reversed, test_batch_mask,
                                         test_seq_lengths) / float(torch.sum(test_seq_lengths))
        finally:
            # always put the RNN back into training mode (i.e. turn drop-out
            # back on if applicable), even if evaluation raised
            dmm.rnn.train()
        return val_nll, test_nll

    # if checkpoint files provided, load model and optimizer states from disk before we start training
    if args.load_opt != '' and args.load_model != '':
        load_checkpoint()

    #################
    # TRAINING LOOP #
    #################
    times = [time.time()]
    for epoch in range(args.num_epochs):
        # if specified, save the model state to disk every checkpoint_freq epochs
        if args.checkpoint_freq > 0 and epoch > 0 and epoch % args.checkpoint_freq == 0:
            save_checkpoint()

        # accumulator for our estimate of the negative log likelihood (or rather -elbo) for this epoch
        epoch_nll = 0.0
        # prepare mini-batch subsampling indices for this epoch
        shuffled_indices = torch.randperm(N_train_data)

        # process each mini-batch; this is where we take gradient steps
        for which_mini_batch in range(N_mini_batches):
            epoch_nll += process_minibatch(epoch, which_mini_batch, shuffled_indices)

        # report training diagnostics
        times.append(time.time())
        epoch_time = times[-1] - times[-2]
        logging.info("[training epoch %04d]  %.4f \t\t\t\t(dt = %.3f sec)" %
                     (epoch, epoch_nll / N_train_time_slices, epoch_time))

        # do evaluation on test and validation data and report results
        if val_test_frequency > 0 and epoch > 0 and epoch % val_test_frequency == 0:
            val_nll, test_nll = do_evaluation()
            logging.info("[val/test epoch %04d]  %.4f  %.4f" % (epoch, val_nll, test_nll))
Exemple #3
0
def main(args):
    """Fit a capture-recapture model with SVI and print the final loss.

    Loads the capture history (and, for models "4" and "5" on the dipper
    data, the sex covariate) from CSV files located next to this script,
    trains the model selected by ``args.model`` with an
    ``AutoDiagonalNormal`` guide over the continuous latent variables, and
    finally evaluates the loss with 2000 particles.

    Args:
        args: parsed options. Attributes read here: ``dataset``, ``model``,
            ``learning_rate``, ``tmc``, ``tmc_num_samples`` and
            ``num_steps``.

    Raises:
        ValueError: if ``args.dataset`` is neither ``"dipper"`` nor
            ``"vole"``, or if model "4"/"5" is requested for the vole data
            (no sex information is available for it).
    """
    pyro.set_rng_seed(0)
    pyro.clear_param_store()
    pyro.enable_validation(__debug__)

    # the data files are expected in the same directory as this script
    data_dir = os.path.dirname(os.path.abspath(__file__))

    # load data
    if args.dataset == "dipper":
        capture_history_file = os.path.join(
            data_dir, 'dipper_capture_history.csv')
    elif args.dataset == "vole":
        capture_history_file = os.path.join(
            data_dir, 'meadow_voles_capture_history.csv')
    else:
        raise ValueError("Available datasets are 'dipper' and 'vole'.")

    # the first CSV column is dropped; the rest is the capture history
    capture_history = torch.tensor(
        np.genfromtxt(capture_history_file, delimiter=',')).float()[:, 1:]
    N, T = capture_history.shape
    print(
        "Loaded {} capture history for {} individuals collected over {} time periods."
        .format(args.dataset, N, T))

    # models "4" and "5" additionally take the sex covariate
    needs_sex = args.model in ["4", "5"]
    if args.dataset == "dipper" and needs_sex:
        sex_file = os.path.join(data_dir, 'dipper_sex.csv')
        sex = torch.tensor(
            np.genfromtxt(sex_file, delimiter=',')).float()[:, 1]
        print("Loaded dipper sex data.")
    elif args.dataset == "vole" and needs_sex:
        # Format the complete message: previously .format() was bound only to
        # the second string literal, so the "{}" placeholder was never filled.
        raise ValueError(
            "Cannot run model_{} on meadow voles data, since we lack sex "
            "information for these animals.".format(args.model))
    else:
        sex = None

    model = models[args.model]

    # we use poutine.block to only expose the continuous latent variables
    # in the models to AutoDiagonalNormal (all of which begin with 'phi'
    # or 'rho')
    def expose_fn(msg):
        return msg["name"][0:3] in ['phi', 'rho']

    # we use a mean field diagonal normal variational distribution
    # (i.e. guide) for the continuous latent variables.
    guide = AutoDiagonalNormal(poutine.block(model, expose_fn=expose_fn))

    # since we enumerate the discrete random variables,
    # we need to use TraceEnum_ELBO or TraceTMC_ELBO.
    optim = Adam({'lr': args.learning_rate})
    if args.tmc:
        elbo = TraceTMC_ELBO(max_plate_nesting=1)

        def tmc_site_config(msg):
            # only sites marked for parallel enumeration are reconfigured
            if msg["infer"].get("enumerate", None) == "parallel":
                return {
                    "num_samples": args.tmc_num_samples,
                    "expand": False,
                }
            return {}

        tmc_model = poutine.infer_config(model, tmc_site_config)
        svi = SVI(tmc_model, guide, optim, elbo)
    else:
        elbo = TraceEnum_ELBO(max_plate_nesting=1,
                              num_particles=20,
                              vectorize_particles=True)
        svi = SVI(model, guide, optim, elbo)

    losses = []

    print(
        "Beginning training of model_{} with Stochastic Variational Inference."
        .format(args.model))

    for step in range(args.num_steps):
        loss = svi.step(capture_history, sex)
        losses.append(loss)
        # report the mean of the most recent (up to 20) losses every 20 steps
        # and once more on the final step
        if step % 20 == 0 and step > 0 or step == args.num_steps - 1:
            print("[iteration %03d] loss: %.3f" %
                  (step, np.mean(losses[-20:])))

    # evaluate final trained model
    elbo_eval = TraceEnum_ELBO(max_plate_nesting=1,
                               num_particles=2000,
                               vectorize_particles=True)
    svi_eval = SVI(model, guide, optim, elbo_eval)
    print("Final loss: %.4f" % svi_eval.evaluate_loss(capture_history, sex))