def train_epoch(cur_epoch, model, dataloader, loss, opt, lr_scheduler=None, print_modulus=1):
    """
    ----------
    Author: Damon Gwinn
    ----------
    Trains the model for a single epoch.

    cur_epoch: 1-based epoch number (used only for progress printing)
    model: model to train (switched to train mode here)
    dataloader: yields (input, target) batch pairs
    loss: loss module; called on logits flattened to (batch*seq, vocab)
    opt: optimizer, stepped once per batch
    lr_scheduler: optional scheduler, stepped once per batch when given
    print_modulus: print progress every N batches
    ----------
    """

    out = -1  # most recent batch loss, kept for the progress printout
    model.train()
    for batch_num, batch in enumerate(dataloader):
        time_before = time.time()

        opt.zero_grad()

        x   = batch[0].to(get_device())
        tgt = batch[1].to(get_device())

        y = model(x)

        # Flatten to (batch*seq, vocab) vs (batch*seq,) for cross entropy
        y   = y.reshape(y.shape[0] * y.shape[1], -1)
        tgt = tgt.flatten()

        # Call the module (not .forward directly) so __call__ hooks run
        out = loss(y, tgt)

        out.backward()
        opt.step()

        if lr_scheduler is not None:
            lr_scheduler.step()

        time_took = time.time() - time_before

        if (batch_num + 1) % print_modulus == 0:
            print(SEPERATOR)
            print("Epoch", cur_epoch, " Batch", batch_num+1, "/", len(dataloader))
            print("LR:", get_lr(opt))
            print("Train loss:", float(out))
            print("")
            print("Time (s):", time_took)
            print(SEPERATOR)
            print("")

    return
def eval_model(model, dataloader, loss):
    """
    ----------
    Author: Damon Gwinn
    ----------
    Evaluates the model, returning (avg_loss, avg_acc) averaged over all batches.

    model: model to evaluate (switched to eval mode here)
    dataloader: yields (input, target) batch pairs; must be non-empty
    loss: loss module; called on logits flattened to (batch*seq, vocab)
    ----------
    """

    model.eval()

    avg_acc  = -1
    avg_loss = -1
    # Inference only: disable autograd for speed and memory
    with torch.no_grad():
        n_test   = len(dataloader)
        sum_loss = 0.0
        sum_acc  = 0.0
        for batch in dataloader:
            x   = batch[0].to(get_device())
            tgt = batch[1].to(get_device())

            y = model(x)

            sum_acc += float(compute_epiano_accuracy(y, tgt))

            # Flatten to (batch*seq, vocab) vs (batch*seq,) for the loss
            y   = y.reshape(y.shape[0] * y.shape[1], -1)
            tgt = tgt.flatten()

            # Call the module (not .forward directly) so __call__ hooks run
            out = loss(y, tgt)

            sum_loss += float(out)

        # NOTE(review): an empty dataloader raises ZeroDivisionError here,
        # matching the original behavior.
        avg_loss = sum_loss / n_test
        avg_acc  = sum_acc / n_test

    return avg_loss, avg_acc
def main():
    """
    ----------
    Author: Damon Gwinn
    ----------
    Entry point. Evaluates a model specified by command line arguments
    ----------
    """

    args = parse_eval_args()
    print_eval_args(args)

    if args.force_cpu:
        use_cuda(False)
        print("WARNING: Forced CPU usage, expect model to perform slower")
        print("")

    # Only the test split is needed for evaluation
    _, _, test_dataset = create_epiano_datasets(args.dataset_dir, args.max_sequence)

    test_loader = DataLoader(test_dataset, batch_size=args.batch_size,
                             num_workers=args.n_workers)

    # Build the model and restore the trained weights
    model = MusicTransformer(
        n_layers=args.n_layers,
        num_heads=args.num_heads,
        d_model=args.d_model,
        dim_feedforward=args.dim_feedforward,
        max_sequence=args.max_sequence,
        rpr=args.rpr,
    ).to(get_device())

    model.load_state_dict(torch.load(args.model_weights))

    # No smoothed loss
    loss = nn.CrossEntropyLoss(ignore_index=TOKEN_PAD)

    print("Evaluating:")
    model.eval()

    avg_loss, avg_acc = eval_model(model, test_loader, loss)

    print("Avg loss:", avg_loss)
    print("Avg acc:", avg_acc)
    print(SEPERATOR)
    print("")
# Esempio n. 4
    def forward(self, x, mask=True):
        """
        ----------
        Author: Damon Gwinn
        ----------
        Takes an input sequence and outputs predictions using a sequence to sequence method.

        A prediction at one index is the "next" prediction given all information seen previously.
        ----------
        """

        # Build a causal (subsequent) mask unless the caller opted out;
        # any value other than literal True disables masking.
        attn_mask = None
        if mask is True:
            seq_len = x.shape[1]
            attn_mask = self.transformer.generate_square_subsequent_mask(
                seq_len).to(get_device())

        # Embed, then move to (max_seq, batch_size, d_model) as the
        # transformer expects, and add positional information.
        emb = self.embedding(x).permute(1, 0, 2)
        emb = self.positional_encoding(emb)

        # Since there are no true decoder layers, the tgt is unused;
        # pytorch just wants src and tgt to share some dims.
        hidden = self.transformer(src=emb, tgt=emb, src_mask=attn_mask)

        # Back to (batch_size, max_seq, d_model) before projecting to vocab
        hidden = hidden.permute(1, 0, 2)

        # Raw logits; softmax (if any) is applied by the caller
        return self.Wout(hidden)
# Esempio n. 5
def main():
    """
    ----------
    Author: Damon Gwinn
    ----------
    Entry point. Generates music from a model specified by command line arguments

    Writes the primer and the generated sequence to args.output_dir as MIDI
    files via the TMIDI conversion pipeline (INT -> TXT -> Notes -> MIDI).
    ----------
    """

    args = parse_generate_args()
    print_generate_args(args)

    if (args.force_cpu):
        use_cuda(False)
        print("WARNING: Forced CPU usage, expect model to perform slower")
        print("")

    os.makedirs(args.output_dir, exist_ok=True)

    # Grabbing dataset if needed
    _, _, dataset = create_epiano_datasets(args.midi_root,
                                           args.num_prime,
                                           random_seq=False)

    # Can be None, an integer index to dataset, or a file path
    if (args.primer_file is None):
        # No primer given: pick a random dataset entry by index
        f = str(random.randrange(len(dataset)))
    else:
        f = args.primer_file

    if (f.isdigit()):
        # Primer is an index into the dataset
        idx = int(f)
        primer, _ = dataset[idx]
        primer = primer.to(get_device())

        print("Using primer index:", idx, "(", dataset.data_files[idx], ")")

    else:
        # Primer is a MIDI file path: encode and trim/pad it to num_prime tokens
        raw_mid = encode_midi(f)
        if (len(raw_mid) == 0):
            print("Error: No midi messages in primer file:", f)
            return

        primer, _ = process_midi(raw_mid, args.num_prime, random_seq=False)
        primer = torch.tensor(primer,
                              dtype=TORCH_LABEL_TYPE,
                              device=get_device())

        print("Using primer file:", f)

    model = MusicTransformer(n_layers=args.n_layers,
                             num_heads=args.num_heads,
                             d_model=args.d_model,
                             dim_feedforward=args.dim_feedforward,
                             max_sequence=args.max_sequence,
                             rpr=args.rpr).to(get_device())

    model.load_state_dict(torch.load(args.model_weights))

    # Saving primer first
    # Converts the primer token ids to MIDI through TMIDI instead of
    # decode_midi (the decode_midi call is kept commented out below).
    f_path = os.path.join(args.output_dir, "primer")
    #decode_midi(primer[:args.num_prime].cpu().numpy(), file_path=f_path)
    x = primer[:args.num_prime].cpu().numpy()
    y = x.tolist()
    z = TMIDI.Tegridy_INT_to_TXT_Converter(y)
    # NOTE(review): the primer uses the Reduced converter with
    # has_velocities=False, while the generated output below uses the
    # Optimus converter with has_velocities=True — confirm this asymmetry
    # is intentional.
    SONG = TMIDI.Tegridy_Reduced_TXT_to_Notes_Converter(
        z, has_MIDI_channels=False, has_velocities=False)
    stats = TMIDI.Tegridy_SONG_to_MIDI_Converter(SONG=SONG[0],
                                                 output_file_name=f_path)

    # GENERATION
    model.eval()
    with torch.set_grad_enabled(False):
        if (args.beam > 0):
            print("BEAM:", args.beam)
            beam_seq = model.generate(primer[:args.num_prime],
                                      args.target_seq_length,
                                      beam=args.beam)

            f_path = os.path.join(args.output_dir, "beam")
            # NOTE(review): this branch calls decode_midi AND the TMIDI
            # pipeline; the rand branch below has decode_midi commented out.
            decode_midi(beam_seq[0].cpu().numpy(), file_path=f_path)
            x = beam_seq[0].cpu().numpy()
            y = x.tolist()
            z = TMIDI.Tegridy_INT_to_TXT_Converter(y)
            SONG, song_name = TMIDI.Tegridy_Optimus_TXT_to_Notes_Converter(
                z,
                has_MIDI_channels=False,
                simulate_velocity=False,
                char_encoding_offset=33,
                save_only_first_composition=True,
                dataset_MIDI_events_time_denominator=10,
                has_velocities=True)
            stats = TMIDI.Tegridy_SONG_to_MIDI_Converter(
                SONG=SONG, output_file_name=f_path)
            print(stats)

        else:
            # Sample from the softmax distribution instead of beam search
            print("RAND DIST")
            rand_seq = model.generate(primer[:args.num_prime],
                                      args.target_seq_length,
                                      beam=0)

            f_path = os.path.join(args.output_dir, "rand")
            #decode_midi(rand_seq[0].cpu().numpy(), file_path=f_path)
            #print('Seq =', rand_seq[0].cpu().numpy())
            x = rand_seq[0].cpu().numpy()
            y = x.tolist()
            z = TMIDI.Tegridy_INT_to_TXT_Converter(y)
            #SONG = TMIDI.Tegridy_Reduced_TXT_to_Notes_Converter(z, has_MIDI_channels=False, has_velocities=False)
            SONG, song_name = TMIDI.Tegridy_Optimus_TXT_to_Notes_Converter(
                z,
                has_MIDI_channels=False,
                simulate_velocity=False,
                char_encoding_offset=33,
                save_only_first_composition=True,
                dataset_MIDI_events_time_denominator=10,
                has_velocities=True)
            stats = TMIDI.Tegridy_SONG_to_MIDI_Converter(
                SONG=SONG, output_file_name=f_path)
            print(stats)
# Esempio n. 6
    def generate(self,
                 primer=None,
                 target_seq_length=1024,
                 beam=0,
                 beam_chance=1.0):
        """
        ----------
        Author: Damon Gwinn
        ----------
        Generates midi given a primer sample. Music can be generated using a probability distribution over
        the softmax probabilities (recommended) or by using a beam search.

        primer: 1-D tensor of token ids to seed the generation
        target_seq_length: maximum length of the generated sequence
        beam: beam width; 0 disables beam search entirely
        beam_chance: per-step probability of taking a beam step (vs sampling)
        Returns the generated sequence, truncated at the stopping index.
        ----------
        """

        assert (not self.training), "Cannot generate while in training mode"

        print("Generating sequence of max length:", target_seq_length)

        # Pre-fill the output buffer with padding; generated tokens overwrite it
        gen_seq = torch.full((1, target_seq_length),
                             TOKEN_PAD,
                             dtype=TORCH_LABEL_TYPE,
                             device=get_device())

        num_primer = len(primer)
        gen_seq[..., :num_primer] = primer.type(TORCH_LABEL_TYPE).to(
            get_device())

        # print("primer:",primer)
        # print(gen_seq)
        cur_i = num_primer
        while (cur_i < target_seq_length):
            # gen_seq_batch     = gen_seq.clone()
            # Softmax over the vocab, excluding TOKEN_END and anything after it
            # so those ids cannot be produced by the beam path
            y = self.softmax(self.forward(
                gen_seq[..., :cur_i]))[..., :TOKEN_END]
            # Distribution for the next token = output at the last position
            token_probs = y[:, cur_i - 1, :]

            if (beam == 0):
                # beam_ran > beam_chance (max 1.0) forces the sampling branch
                beam_ran = 2.0
            else:
                beam_ran = random.uniform(0, 1)

            if (beam_ran <= beam_chance):
                # Beam step: top-k over all (beam, vocab) probabilities at once
                token_probs = token_probs.flatten()
                top_res, top_i = torch.topk(token_probs, beam)

                # Recover which beam row and which vocab id each winner came from
                beam_rows = top_i // VOCAB_SIZE
                beam_cols = top_i % VOCAB_SIZE

                # Re-select rows (this grows the batch dim to `beam` on the
                # first beam step) and append the chosen tokens
                gen_seq = gen_seq[beam_rows, :]
                gen_seq[..., cur_i] = beam_cols

            else:
                # Sampling step: draw the next token from the distribution
                distrib = torch.distributions.categorical.Categorical(
                    probs=token_probs)
                next_token = distrib.sample()
                # print("next token:",next_token)
                gen_seq[:, cur_i] = next_token

                # Let the transformer decide to end if it wants to
                # NOTE(review): this tensor comparison assumes batch size 1 —
                # it would be ambiguous once the beam path has grown the batch.
                if (next_token == TOKEN_END):
                    print("Model called end of sequence at:", cur_i, "/",
                          target_seq_length)
                    break

            cur_i += 1
            if (cur_i % 50 == 0):
                print(cur_i, "/", target_seq_length)

        return gen_seq[:, :cur_i]
# Esempio n. 7
# Notebook-style generation script: `args` is assumed to come from an earlier
# parse_generate_args() call outside this snippet — TODO confirm.
args.num_prime = 240          # number of primer tokens fed to the model
args.target_seq_length = 1200 # maximum generated sequence length
args.rpr = True               # use the relative position (RPR) model variant
args.output_dir = 'output'

print_generate_args(args)

# use prime from dataset
# _, _, dataset = create_epiano_datasets(args.midi_root, args.num_prime, random_seq=False)
# idx = round(np.random.rand()*len(dataset))
# primer, _  = dataset[idx]
# primer = primer.to(get_device())

# Primer comes from a MIDI file: encode, then trim/pad to num_prime tokens
raw_mid = encode_midi(args.primer_file)
primer, _ = process_midi(raw_mid, args.num_prime, random_seq=False)
primer = torch.tensor(primer, dtype=TORCH_LABEL_TYPE, device=get_device())

model = MusicTransformer(n_layers=args.n_layers,
                         num_heads=args.num_heads,
                         d_model=args.d_model,
                         dim_feedforward=args.dim_feedforward,
                         max_sequence=args.max_sequence,
                         rpr=args.rpr).to(get_device())

model.load_state_dict(torch.load(args.model_weights))

## optionally save primer

# Saving primer first
f_path = os.path.join(args.output_dir, "primer.mid")
decode_midi(primer[:args.num_prime].cpu().numpy(), file_path=f_path)
def main():
    """
    ----------
    Author: Damon Gwinn
    ----------
    Entry point. Trains a model specified by command line arguments

    Per epoch: trains on the train split, evaluates, saves weights every
    weight_modulus epochs, and writes per-epoch accuracy/loss text files.
    ----------
    """

    args = parse_train_args()
    print_train_args(args)

    if (args.force_cpu):
        use_cuda(False)
        print("WARNING: Forced CPU usage, expect model to perform slower")
        print("")

    os.makedirs(args.output_dir, exist_ok=True)

    # Output prep
    params_file = os.path.join(args.output_dir, "model_params.txt")
    write_model_params(args, params_file)

    weights_folder = os.path.join(args.output_dir, "weights")
    os.makedirs(weights_folder, exist_ok=True)

    results_folder = os.path.join(args.output_dir, "results")
    os.makedirs(results_folder, exist_ok=True)

    # Datasets
    train_dataset, val_dataset, test_dataset = create_epiano_datasets(
        args.input_dir, args.max_sequence)

    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              num_workers=args.n_workers,
                              shuffle=True)
    val_loader = DataLoader(val_dataset,
                            batch_size=args.batch_size,
                            num_workers=args.n_workers)
    test_loader = DataLoader(test_dataset,
                             batch_size=args.batch_size,
                             num_workers=args.n_workers)

    model = MusicTransformer(n_layers=args.n_layers,
                             num_heads=args.num_heads,
                             d_model=args.d_model,
                             dim_feedforward=args.dim_feedforward,
                             dropout=args.dropout,
                             max_sequence=args.max_sequence,
                             rpr=args.rpr).to(get_device())

    # Continuing from previous training session: both flags must be given together
    start_epoch = 0
    if (args.continue_weights is not None):
        if (args.continue_epoch is None):
            print(
                "ERROR: Need epoch number to continue from (-continue_epoch) when using continue_weights"
            )
            return
        else:
            model.load_state_dict(torch.load(args.continue_weights))
            start_epoch = args.continue_epoch
    elif (args.continue_epoch is not None):
        print(
            "ERROR: Need continue weights (-continue_weights) when using continue_epoch"
        )
        return

    # Lr Scheduler vs static lr: no explicit lr means warmup schedule
    if (args.lr is None):
        if (args.continue_epoch is None):
            init_step = 0
        else:
            # Resume the schedule from where the previous run left off
            init_step = args.continue_epoch * len(train_loader)

        lr = LR_DEFAULT_START
        lr_stepper = LrStepTracker(args.d_model, SCHEDULER_WARMUP_STEPS,
                                   init_step)
    else:
        lr = args.lr

    # Not smoothing evaluation loss
    eval_loss = nn.CrossEntropyLoss(ignore_index=TOKEN_PAD)

    # SmoothCrossEntropyLoss or CrossEntropyLoss for training
    if (args.ce_smoothing is None):
        train_loss = eval_loss
    else:
        train_loss = SmoothCrossEntropyLoss(args.ce_smoothing,
                                            VOCAB_SIZE,
                                            ignore_index=TOKEN_PAD)

    # Optimizer
    opt = Adam(model.parameters(),
               lr=lr,
               betas=(ADAM_BETA_1, ADAM_BETA_2),
               eps=ADAM_EPSILON)

    if (args.lr is None):
        lr_scheduler = LambdaLR(opt, lr_stepper.step)
    else:
        lr_scheduler = None

    best_acc = 0.0
    best_acc_epoch = -1
    best_loss = float("inf")
    best_loss_epoch = -1

    # TRAIN LOOP
    for epoch in range(start_epoch, args.epochs):
        print(SEPERATOR)
        print("NEW EPOCH:", epoch + 1)
        print(SEPERATOR)
        print("")

        train_epoch(epoch + 1, model, train_loader, train_loss, opt,
                    lr_scheduler)

        print(SEPERATOR)
        print("Evaluating:")

        # NOTE(review): per-epoch evaluation uses the TEST split while
        # val_loader goes unused — likely should be val_loader; confirm
        # before changing since best-epoch tracking depends on it.
        cur_loss, cur_acc = eval_model(model, test_loader, eval_loss)

        print("Avg loss:", cur_loss)
        print("Avg acc:", cur_acc)
        print(SEPERATOR)
        print("")

        if (cur_acc > best_acc):
            best_acc = cur_acc
            best_acc_epoch = epoch + 1
        if (cur_loss < best_loss):
            best_loss = cur_loss
            best_loss_epoch = epoch + 1

        epoch_str = str(epoch + 1).zfill(PREPEND_ZEROS_WIDTH)

        if ((epoch + 1) % args.weight_modulus == 0):
            path = os.path.join(weights_folder,
                                "epoch_" + epoch_str + ".pickle")
            torch.save(model.state_dict(), path)

        # Context manager guarantees the file is closed even on an exception
        path = os.path.join(results_folder, "epoch_" + epoch_str + ".txt")
        with open(path, "w") as o_stream:
            o_stream.write(str(cur_acc) + "\n")
            o_stream.write(str(cur_loss) + "\n")

    print(SEPERATOR)
    print("Best acc epoch:", best_acc_epoch)
    print("Best acc:", best_acc)
    print("")
    print("Best loss epoch:", best_loss_epoch)
    print("Best loss:", best_loss)
# Esempio n. 9
def main():
    """
    ----------
    Author: Damon Gwinn
    ----------
    Entry point. Trains a model specified by command line arguments

    Per epoch: trains (after a baseline evaluation epoch), evaluates on both
    the train and test splits, tracks/saves best-accuracy and best-loss
    weights, appends a CSV results row, and optionally logs to tensorboard.
    ----------
    """

    args = parse_train_args()
    print_train_args(args)

    if (args.force_cpu):
        use_cuda(False)
        print("WARNING: Forced CPU usage, expect model to perform slower")
        print("")

    os.makedirs(args.output_dir, exist_ok=True)

    ##### Output prep #####
    params_file = os.path.join(args.output_dir, "model_params.txt")
    write_model_params(args, params_file)

    weights_folder = os.path.join(args.output_dir, "weights")
    os.makedirs(weights_folder, exist_ok=True)

    results_folder = os.path.join(args.output_dir, "results")
    os.makedirs(results_folder, exist_ok=True)

    results_file = os.path.join(results_folder, "results.csv")
    best_loss_file = os.path.join(results_folder, "best_loss_weights.pickle")
    best_acc_file = os.path.join(results_folder, "best_acc_weights.pickle")
    best_text = os.path.join(results_folder, "best_epochs.txt")

    ##### Tensorboard #####
    if (args.no_tensorboard):
        tensorboard_summary = None
    else:
        # Imported lazily so tensorboard is only required when actually used
        from torch.utils.tensorboard import SummaryWriter

        tensorboad_dir = os.path.join(args.output_dir, "tensorboard")
        tensorboard_summary = SummaryWriter(log_dir=tensorboad_dir)

    ##### Datasets #####
    train_dataset, val_dataset, test_dataset = create_epiano_datasets(
        args.input_dir, args.max_sequence)

    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              num_workers=args.n_workers,
                              shuffle=True)
    val_loader = DataLoader(val_dataset,
                            batch_size=args.batch_size,
                            num_workers=args.n_workers)
    test_loader = DataLoader(test_dataset,
                             batch_size=args.batch_size,
                             num_workers=args.n_workers)

    model = MusicTransformer(n_layers=args.n_layers,
                             num_heads=args.num_heads,
                             d_model=args.d_model,
                             dim_feedforward=args.dim_feedforward,
                             dropout=args.dropout,
                             max_sequence=args.max_sequence,
                             rpr=args.rpr).to(get_device())

    ##### Continuing from previous training session #####
    start_epoch = BASELINE_EPOCH
    if (args.continue_weights is not None):
        if (args.continue_epoch is None):
            print(
                "ERROR: Need epoch number to continue from (-continue_epoch) when using continue_weights"
            )
            return
        else:
            model.load_state_dict(torch.load(args.continue_weights))
            start_epoch = args.continue_epoch
    elif (args.continue_epoch is not None):
        print(
            "ERROR: Need continue weights (-continue_weights) when using continue_epoch"
        )
        return

    ##### Lr Scheduler vs static lr #####
    if (args.lr is None):
        if (args.continue_epoch is None):
            init_step = 0
        else:
            # Resume the warmup schedule from where the previous run left off
            init_step = args.continue_epoch * len(train_loader)

        lr = LR_DEFAULT_START
        lr_stepper = LrStepTracker(args.d_model, SCHEDULER_WARMUP_STEPS,
                                   init_step)
    else:
        lr = args.lr

    ##### Not smoothing evaluation loss #####
    eval_loss_func = nn.CrossEntropyLoss(ignore_index=TOKEN_PAD)

    ##### SmoothCrossEntropyLoss or CrossEntropyLoss for training #####
    if (args.ce_smoothing is None):
        train_loss_func = eval_loss_func
    else:
        train_loss_func = SmoothCrossEntropyLoss(args.ce_smoothing,
                                                 VOCAB_SIZE,
                                                 ignore_index=TOKEN_PAD)

    ##### Optimizer #####
    opt = Adam(model.parameters(),
               lr=lr,
               betas=(ADAM_BETA_1, ADAM_BETA_2),
               eps=ADAM_EPSILON)

    if (args.lr is None):
        lr_scheduler = LambdaLR(opt, lr_stepper.step)
    else:
        lr_scheduler = None

    ##### Tracking best evaluation accuracy #####
    best_eval_acc = 0.0
    best_eval_acc_epoch = -1
    best_eval_loss = float("inf")
    best_eval_loss_epoch = -1

    ##### Results reporting #####
    if (not os.path.isfile(results_file)):
        with open(results_file, "w", newline="") as o_stream:
            writer = csv.writer(o_stream)
            writer.writerow(CSV_HEADER)

    ##### TRAIN LOOP #####
    for epoch in range(start_epoch, args.epochs):
        # Baseline has no training and acts as a base loss and accuracy (epoch 0 in a sense)
        if (epoch > BASELINE_EPOCH):
            print(SEPERATOR)
            print("NEW EPOCH:", epoch + 1)
            print(SEPERATOR)
            print("")

            # Train
            train_epoch(epoch + 1, model, train_loader, train_loss_func, opt,
                        lr_scheduler, args.print_modulus)

            print(SEPERATOR)
            print("Evaluating:")
        else:
            print(SEPERATOR)
            print("Baseline model evaluation (Epoch 0):")

        # Eval
        train_loss, train_acc = eval_model(model, train_loader,
                                           train_loss_func)
        eval_loss, eval_acc = eval_model(model, test_loader, eval_loss_func)

        # Learn rate
        lr = get_lr(opt)

        print("Epoch:", epoch + 1)
        print("Avg train loss:", train_loss)
        print("Avg train acc:", train_acc)
        print("Avg eval loss:", eval_loss)
        print("Avg eval acc:", eval_acc)
        print(SEPERATOR)
        print("")

        new_best = False

        if (eval_acc > best_eval_acc):
            best_eval_acc = eval_acc
            best_eval_acc_epoch = epoch + 1
            torch.save(model.state_dict(), best_acc_file)
            new_best = True

        if (eval_loss < best_eval_loss):
            best_eval_loss = eval_loss
            best_eval_loss_epoch = epoch + 1
            torch.save(model.state_dict(), best_loss_file)
            new_best = True

        # Writing out new bests
        if (new_best):
            with open(best_text, "w") as o_stream:
                print("Best eval acc epoch:",
                      best_eval_acc_epoch,
                      file=o_stream)
                print("Best eval acc:", best_eval_acc, file=o_stream)
                # FIX: blank separator line belongs in the file, not stdout
                print("", file=o_stream)
                print("Best eval loss epoch:",
                      best_eval_loss_epoch,
                      file=o_stream)
                print("Best eval loss:", best_eval_loss, file=o_stream)

        if (not args.no_tensorboard):
            tensorboard_summary.add_scalar("Avg_CE_loss/train",
                                           train_loss,
                                           global_step=epoch + 1)
            tensorboard_summary.add_scalar("Avg_CE_loss/eval",
                                           eval_loss,
                                           global_step=epoch + 1)
            tensorboard_summary.add_scalar("Accuracy/train",
                                           train_acc,
                                           global_step=epoch + 1)
            tensorboard_summary.add_scalar("Accuracy/eval",
                                           eval_acc,
                                           global_step=epoch + 1)
            tensorboard_summary.add_scalar("Learn_rate/train",
                                           lr,
                                           global_step=epoch + 1)
            tensorboard_summary.flush()

        if ((epoch + 1) % args.weight_modulus == 0):
            epoch_str = str(epoch + 1).zfill(PREPEND_ZEROS_WIDTH)
            path = os.path.join(weights_folder,
                                "epoch_" + epoch_str + ".pickle")
            torch.save(model.state_dict(), path)

        with open(results_file, "a", newline="") as o_stream:
            writer = csv.writer(o_stream)
            writer.writerow(
                [epoch + 1, lr, train_loss, train_acc, eval_loss, eval_acc])

    # Sanity check just to make sure everything is gone
    if (not args.no_tensorboard):
        tensorboard_summary.flush()

    return
# Esempio n. 10
def main():
    """
    ----------
    Author: Damon Gwinn
    ----------
    Entry point. Generates music from a model specified by command line arguments
    ----------
    """

    args = parse_generate_args()
    print_generate_args(args)

    if args.force_cpu:
        use_cuda(False)
        print("WARNING: Forced CPU usage, expect model to perform slower")
        print("")

    os.makedirs(args.output_dir, exist_ok=True)

    # Dataset is only needed when the primer is selected by index
    _, _, dataset = create_epiano_datasets(args.midi_root,
                                           args.num_prime,
                                           random_seq=False)

    # Primer spec: None -> random dataset index, digits -> dataset index,
    # anything else -> path to a MIDI file
    if args.primer_file is None:
        f = str(random.randrange(len(dataset)))
    else:
        f = args.primer_file

    if f.isdigit():
        idx = int(f)
        primer, _ = dataset[idx]
        primer = primer.to(get_device())
        print("Using primer index:", idx, "(", dataset.data_files[idx], ")")
    else:
        raw_mid = encode_midi(f)
        if not raw_mid:
            print("Error: No midi messages in primer file:", f)
            return
        primer, _ = process_midi(raw_mid, args.num_prime, random_seq=False)
        primer = torch.tensor(primer,
                              dtype=TORCH_LABEL_TYPE,
                              device=get_device())
        print("Using primer file:", f)

    model = MusicTransformer(
        n_layers=args.n_layers,
        num_heads=args.num_heads,
        d_model=args.d_model,
        dim_feedforward=args.dim_feedforward,
        max_sequence=args.max_sequence,
        rpr=args.rpr,
    ).to(get_device())
    model.load_state_dict(torch.load(args.model_weights))

    # Save the primer itself before generating anything
    decode_midi(primer[:args.num_prime].cpu().numpy(),
                file_path=os.path.join(args.output_dir, "primer.mid"))

    # GENERATION
    model.eval()
    with torch.set_grad_enabled(False):
        if args.beam > 0:
            print("BEAM:", args.beam)
            generated = model.generate(primer[:args.num_prime],
                                       args.target_seq_length,
                                       beam=args.beam)
            out_path = os.path.join(args.output_dir, "beam.mid")
        else:
            print("RAND DIST")
            generated = model.generate(primer[:args.num_prime],
                                       args.target_seq_length,
                                       beam=0)
            out_path = os.path.join(args.output_dir, "rand.mid")

        decode_midi(generated[0].cpu().numpy(), file_path=out_path)