Example #1
def reconstruct(model, device, melody_info, out_dirname=None):
    """"Reconstructs melody by given model.

    melody_info is string, containing a comma separated list of midi_path, melody_idx, start_bar"""
    melody, (midi_path, melody_idx,
             start_bar) = get_melody_from_info(melody_info)

    encode_event_fn = get_event_encoder(MIN_PITCH, MAX_PITCH,
                                        NUM_SPECIAL_EVENTS)
    decode_event_fn = get_event_decoder(MIN_PITCH, MAX_PITCH,
                                        NUM_SPECIAL_EVENTS)

    out_melody = reconstruct_melody(model, device, melody, encode_event_fn,
                                    decode_event_fn)

    fig = plt.figure(figsize=(15, 10))
    ax = fig.add_subplot(211)
    pp.plot_pianoroll(ax, melody_lib.melody_to_pianoroll(melody))
    plt.title('Original melody')

    ax = fig.add_subplot(212)
    pp.plot_pianoroll(ax, melody_lib.melody_to_pianoroll(out_melody))
    plt.title('Reconstruction')

    if out_dirname is not None:
        plot_basename = "recon_{}_{}_{}.png".format(
            Path(midi_path).stem, melody_idx, start_bar)
        fig.savefig(str(out_dirname / plot_basename))
    plt.show()
    plt.close()
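A hedged usage sketch for the call above; the trained model and MIDI path are hypothetical, and the melody_info string follows the "midi_path,melody_idx,start_bar" format from the docstring:

# Hypothetical usage: reconstruct melody 0 of tune.mid, starting at bar 4.
out_dir = Path("recon_plots")
out_dir.mkdir(exist_ok=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
reconstruct(trained_model, device, "midi/tune.mid,0,4", out_dirname=out_dir)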
Example #2
    def log_reconstruction(self,
                           label,
                           output,
                           trg,
                           global_step,
                           max_results=3):
        pred_sequence = output[:max_results].argmax(dim=-1)
        target_sequence = trg[:max_results].cpu()

        n_results = pred_sequence.size(0)
        fig = plt.figure(figsize=(15, 8))
        for i in range(n_results):
            ax = fig.add_subplot(2, n_results, i + 1)
            pp.plot_pianoroll(
                ax,
                melody_lib.melody_to_pianoroll(
                    self.melody_dict.sequence_to_melody(target_sequence[i])))
            plt.title('Original melody')

            ax = fig.add_subplot(2, n_results, n_results + i + 1)
            pp.plot_pianoroll(
                ax,
                melody_lib.melody_to_pianoroll(
                    self.melody_dict.sequence_to_melody(pred_sequence[i])))
            plt.title('Reconstruction')

        self.writer.add_figure(f"{label}.recon", fig, global_step)
        plt.close(fig)
Example #3
def plot_pianoroll_from_midi(midi_path, attr_labels, attr_str, type):
    pr_a = pretty_midi.PrettyMIDI(midi_path)
    piano_roll = pr_a.get_piano_roll().astype('int').T
    beat_resolution = 100
    if len(pr_a.instruments) == 0:
        return
    note_list = pr_a.instruments[0].notes
    num_measures = int(piano_roll.shape[0] / (2 * beat_resolution))
    downbeats = [i * 2 * beat_resolution for i in range(num_measures)]

    shaded_piano_roll = np.zeros_like(piano_roll)
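    # Build a faint background grid: lightly shade every other time column
    # and every 25th pitch row, then (below) thicken each note to a
    # 5-step-by-3-semitone block so it stands out against the grid.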
    for i in range(0, shaded_piano_roll.shape[1], 2):
        shaded_piano_roll[:, i] = 30
    for i in range(0, shaded_piano_roll.shape[0], 25):
        shaded_piano_roll[i, :] = 50
    for note in note_list:
        start = int(note.start * beat_resolution)
        pitch = int(note.pitch)
        piano_roll[start:start + 5, pitch - 1:pitch + 2] = 127
    shaded_piano_roll[piano_roll != 0] = piano_roll[piano_roll != 0]

    figsize = (16, 2)
    f, (ax1, ax2) = plt.subplots(1,
                                 2,
                                 figsize=figsize,
                                 gridspec_kw={'width_ratios': [6, 1]})
    pypianoroll.plot_pianoroll(
        ax1,
        shaded_piano_roll,
        downbeats=downbeats,
        beat_resolution=2 * beat_resolution,
        xtick='beat',
    )
    f.set_facecolor('white')
    if type == 'folk':
        ax1.set_ylim(55, 84)
    elif type == 'bach':
        ax1.set_ylim(55, 90)
    ax1.set_ylabel('Pitch')
    ax1.set_yticklabels([])
    ax1.set_yticks([])
    plt.tight_layout()
    ax1.set_xlabel('')
    save_path = os.path.join(
        os.path.dirname(midi_path),
        f'{os.path.splitext(os.path.basename(midi_path))[0]}.png')
    x = [n + 1 for n in range(attr_labels.size)]
    # ax2.bar(x, attr_labels, color='k')
    ax2.plot(x, attr_labels, 'o', color='k', markersize=7)
    ax2.set_ylabel(attr_str)
    # if attr_str == 'contour':
    #     ax2.set_ylim(-0.7, 0.7)
    # else:
    #     ax2.set_ylim(-0.1, 0.5)
    ax2.set_yticklabels([])
    ax2.set_xticks(np.arange(1, num_measures + 1))
    plt.savefig(save_path, dpi=500)
    plt.close()
Example #4
def melody_slideshow(generator):
    plt.ion()
    fig = plt.figure()
    ax = fig.add_subplot(111)

    for melody in generator:
        ax.clear()  # clear the previous frame so melodies do not stack up
        pp.plot_pianoroll(ax, melody_lib.melody_to_pianoroll(melody))
        fig.canvas.draw()
        fig.canvas.flush_events()
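A hedged usage sketch; the generator is hypothetical and only has to yield melodies in the format melody_lib expects:

def melody_stream(melody_infos):
    # Hypothetical generator: yield one melody per info string.
    for info in melody_infos:
        melody, _ = get_melody_from_info(info)
        yield melody

melody_slideshow(melody_stream(["midi/tune.mid,0,4", "midi/tune.mid,1,0"]))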
Example #5
def sample_cmd(model,
               device,
               sequence_length,
               seed_melody_info=None,
               out_dirname=None,
               to_midi=False):

    model.eval()
    with torch.no_grad():
        encode_event_fn = get_event_encoder(MIN_PITCH, MAX_PITCH,
                                            NUM_SPECIAL_EVENTS)
        decode_event_fn = get_event_decoder(MIN_PITCH, MAX_PITCH,
                                            NUM_SPECIAL_EVENTS)

        if seed_melody_info is not None:
            seed_melody, (midi_path, melody_idx,
                          start_bar) = get_melody_from_info(seed_melody_info)
            input_sequence = melody_to_sequence(seed_melody,
                                                encode_event_fn).to(device)

            mu, sigma = model.encode(input_sequence.unsqueeze(dim=0))
            z = model.reparameterize(mu, sigma)
        else:
            z = torch.randn((1, model.decoder.z_dim)).to(device)

        output_sequences, _, _ = model.decode(z,
                                              sequence_length=sequence_length)
        out_melody = sequence_to_melody(
            output_sequences.view(-1, output_sequences.size(-1)),
            decode_event_fn)

        fig = plt.figure(figsize=(20, 10))
        ax = fig.add_subplot(111)
        pp.plot_pianoroll(ax, melody_lib.melody_to_pianoroll(out_melody))
        ax.set_xticks(
            np.arange(MELODY_LENGTH, out_melody.shape[0], MELODY_LENGTH))
        plt.title("Sampeling")

        if out_dirname is not None:
            fileroot = "sample_len{}".format(sequence_length)
            if seed_melody_info is not None:
                fileroot += "_{}_{}_{}".format(
                    Path(midi_path).stem, melody_idx, start_bar)

            if to_midi:
                pm = melody_lib.melody_to_midi(out_melody)
                pm.write(str(out_dirname / (fileroot + ".mid")))

            fig.savefig(str(out_dirname / (fileroot + ".png")))

        plt.show()
        plt.close()
Example #6
def plot_midifile(filepath, samples_dir, name):
    roll = None
    try:
        roll = pypianoroll.Multitrack(filepath,
                                      beat_resolution=4).tracks[0].pianoroll
    except Exception:
        return None
    plt.figure(figsize=(14, 8))
    ax = plt.gca()
    pypianoroll.plot_pianoroll(ax, roll)
    plt.title(name)
    pathtopng = os.path.join(samples_dir, name)
    print('plotting pianoroll to %s' % pathtopng)
    plt.savefig(pathtopng, bbox_inches='tight')
    plt.close()  # free the figure so repeated calls do not accumulate
    return True
Example #7
def melody_to_graph(melody, figsize=None, start_bar=0):
    pianoroll = melody_lib.melody_to_pianoroll(melody)

    figsize = figsize or (10, 7)
    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot(111)
    pp.plot_pianoroll(ax, pianoroll)
    ticks = np.arange(0, pianoroll.shape[0], STEPS_PER_BAR * 2)

    ax.set_xticks(ticks)
    ax.set_xticklabels(np.arange(start_bar, start_bar + len(ticks) * 2, 2))
    # ax.set_xticklabels(np.arange(0, len(ticks)*2, 2))
    ax.grid(True, axis='x')

    plt.tight_layout()

    return fig_to_base64(fig)
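fig_to_base64 is not defined in this snippet; a minimal sketch of what it could look like, assuming the intent is an HTML-embeddable base64 PNG:

import base64
import io

def fig_to_base64(fig):
    # Render the figure to an in-memory PNG and return it base64-encoded.
    buf = io.BytesIO()
    fig.savefig(buf, format="png", bbox_inches="tight")
    plt.close(fig)
    buf.seek(0)
    return base64.b64encode(buf.read()).decode("ascii")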
Example #8
def log_reconstruction(logger, tag, pred_sequence, target_sequence, step):
    pred_sequence = pred_sequence[:MAX_N_RESULTS].cpu()
    target_sequence = target_sequence[:MAX_N_RESULTS].cpu()

    n_results = pred_sequence.size(0)

    recon_sampled_melodies = torch.distributions.categorical.Categorical(
        logits=pred_sequence).sample().numpy()
    recon_argmax_melodies = pred_sequence.argmax(dim=-1).numpy()
    target_melodies = target_sequence.argmax(dim=-1).numpy()

    fig = plt.figure(figsize=(15, 10))
    for i in range(n_results):
        ax = fig.add_subplot(3, n_results, i + 1)
        pp.plot_pianoroll(
            ax,
            melody_lib.melody_to_pianoroll(decode_melody(target_melodies[i])))
        plt.title('Original melody')

        ax = fig.add_subplot(3, n_results, n_results + i + 1)
        pp.plot_pianoroll(
            ax,
            melody_lib.melody_to_pianoroll(
                decode_melody(recon_argmax_melodies[i])))
        plt.title('Reconstruction (argmax)')

        ax = fig.add_subplot(3, n_results, 2 * n_results + i + 1)
        pp.plot_pianoroll(
            ax,
            melody_lib.melody_to_pianoroll(
                decode_melody(recon_sampled_melodies[i])))
        plt.title('Reconstruction (sampled)')
    logger.add_figure(tag, fig, step)
    plt.close(fig)
Example #9
def list_melodies(midi_path, out_dirname=None):
    """Lists extracted melodies of midi file given by path.
    Shows melody idx, length and optionally path to melody plot"""
    melodies = extract_melodies(midi_path)

    pp_multitrack = pp.parse(midi_path)

    for i, melody in enumerate(melodies):
        fig = plt.figure(figsize=(15, 10))
        ax = fig.add_subplot(211)

        pianoroll = melody_lib.melody_to_pianoroll(melody["events"])
        pp.plot_pianoroll(ax, pianoroll)

        ticks = np.arange(0, pianoroll.shape[0], STEPS_PER_BAR * 2)
        ax.set_xticks(ticks)
        ax.set_xticklabels(np.arange(0, len(ticks)) * 2)
        ax.grid(True, axis='x')
        plt.title('Extracted melody')

        ax = fig.add_subplot(212)
        pp.plot_pianoroll(ax,
                          pp_multitrack.tracks[melody["instrument"]].pianoroll)
        plt.title('Original midi track')

        plot_basename = ""
        if out_dirname is not None:
            plot_basename = "lst_{}_{}.png".format(Path(midi_path).stem, i)
            fig.savefig(str(out_dirname / plot_basename))
        plt.show()
        plt.close()

        print("idx: {:2d},\tlength: {:3d}{}".format(
            i,
            len(melody["events"]) // 16,
            "" if out_dirname is None else ",\t" + plot_basename))
Example #10
def interpolate(sample1_path,
                sample2_path,
                model,
                sample1_bar=0,
                sample2_bar=0,
                temperature=0.5,
                smooth_threshold=0,
                play_loud=False):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if model.training:  # model.train() returns the module itself (always truthy)
        model.eval()

    with torch.no_grad():
        sample1 = getSlicedPianorollMatrixNp(sample1_path)
        sample1 = transposeNotesHigherLower(sample1)
        sample1 = cutOctaves(sample1[sample1_bar])
        sample2 = getSlicedPianorollMatrixNp(sample2_path)
        sample2 = transposeNotesHigherLower(sample2)
        sample2 = cutOctaves(sample2[sample2_bar])

        # prepare for input
        sample1 = torch.from_numpy(sample1.reshape(1, 1, 96,
                                                   60)).float().to(device)
        sample2 = torch.from_numpy(sample2.reshape(1, 1, 96,
                                                   60)).float().to(device)

        # embed both sequences
        embed1, _ = model.encoder(sample1)
        embed2, _ = model.encoder(sample2)

        # for hamming distance
        recon1 = model.decoder(embed1)
        recon1 = torch.softmax(recon1, dim=3)
        recon1 = recon1.squeeze(0).squeeze(0).cpu().numpy()
        # recon1 /= np.max(np.abs(recon1))
        recon1[recon1 < (1 - temperature)] = 0
        recon1 = debinarizeMidi(recon1, prediction=False)
        recon1 = addCuttedOctaves(recon1)
        recon1[recon1 > 0] = 1
        hamming1 = recon1.flatten()

        recon2 = model.decoder(embed2)
        recon2 = torch.softmax(recon2, dim=3)
        recon2 = recon2.squeeze(0).squeeze(0).cpu().numpy()
        # recon2 /= np.max(np.abs(recon2))
        recon2[recon2 < (1 - temperature)] = 0
        recon2 = debinarizeMidi(recon2, prediction=False)
        recon2 = addCuttedOctaves(recon2)
        recon2[recon2 > 0] = 1
        hamming2 = recon2.flatten()

        hamming_dists1 = []
        hamming_dists2 = []
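        # `hamming` is presumably scipy.spatial.distance.hamming: the
        # fraction of positions at which two binary vectors differ.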

        for i in range(11):
            alpha = i / 10.
            c = (1. - alpha) * embed1 + alpha * embed2

            # decode current interpolation
            recon = model.decoder(c)
            recon = torch.softmax(recon, dim=3)
            recon = recon.squeeze(0).squeeze(0).cpu().numpy()
            # recon /= np.max(np.abs(recon))
            recon[recon < (1 - temperature)] = 0
            recon = debinarizeMidi(recon, prediction=False)
            recon = addCuttedOctaves(recon)
            if smooth_threshold:
                smoother = NoteSmoother(recon, threshold=smooth_threshold)
                recon = smoother.smooth()
            # Hamming distance for the current interpolation step
            recon_hamm = recon.flatten()
            recon_hamm[recon_hamm > 0] = 1
            current_hamming1 = hamming(hamming1, recon_hamm)
            current_hamming2 = hamming(hamming2, recon_hamm)
            hamming_dists1.append(current_hamming1)
            hamming_dists2.append(current_hamming2)

            # plot piano roll
            if i == 0:
                recon_plot = recon
            else:
                recon_plot = np.concatenate((recon_plot, recon), axis=0)

            print("alpha = {}".format(alpha))
            print("Hamming distance to sequence 1 is {}".format(
                current_hamming1))
            print("Hamming distance to sequence 2 is {}".format(
                current_hamming2))
            if play_loud:
                pianorollMatrixToTempMidi(recon,
                                          prediction=True,
                                          show=True,
                                          showPlayer=False,
                                          autoplay=True)

        alphas = np.arange(0, 1.1, 0.1)
        fig, ax = plt.subplots()
        ax.plot(alphas, hamming_dists1)
        ax.plot(alphas, hamming_dists2)
        ax.grid()

        fig2, ax2 = plt.subplots()
        # recon_plot = ppr.Track(recon_plot)
        downbeats = [i * 96 for i in range(11)]
        # recon_plot.plot(ax, downbeats=downbeats)
        ppr.plot_pianoroll(ax2, recon_plot, downbeats=downbeats)
        plt.show()
Example #11
        if result.shape[1] > truth.shape[1]:
            result = result[:, :truth.shape[1]]
        elif result.shape[1] < truth.shape[1]:
            result = np.pad(result,
                            ((0, 0), (0, truth.shape[1] - result.shape[1])),
                            'constant',
                            constant_values=0)
        p, r, f1 = multipitch_evaluation(result, truth, raw_value=False)
        print("Precision: %.4f, Recall: %.4f, F-score:, %.4f" % (p, r, f1))

        result_proll = np.pad(result.T, ((0, 0), (21, 19)),
                              'constant',
                              constant_values=0)
        truth_proll = np.pad(truth.T, ((0, 0), (21, 19)),
                             'constant',
                             constant_values=0)
        f, axes = plt.subplots(2, 1)
        plot_pianoroll(axes[0], truth_proll)
        axes[0].set_title('ground truth')
        plot_pianoroll(axes[1], result_proll)
        axes[1].set_title('predict')
        f.suptitle(basename)
    else:
        pianoroll = np.pad(result.T, ((0, 0), (21, 19)),
                           'constant',
                           constant_values=0)
        ax = plt.gca()
        plot_pianoroll(ax, pianoroll)
        plt.title('predict')
    plt.show()
Example #12
def interpolate(model,
                device,
                start_melody_info,
                end_melody_info,
                num_steps,
                out_dirname=None,
                to_midi=False):
    """Interpolates from start melody to end melody using (num_steps - 2) in between.

    x_melody_info contains is the tuple (midi_path, melody_idx, start_bar)
    for start respectively end melody."""
    def _slerp(p0, p1, t):
        """Spherical linear interpolation."""
        omega = np.arccos(
            np.dot(np.squeeze(p0 / np.linalg.norm(p0)),
                   np.squeeze(p1 / np.linalg.norm(p1))))
        so = np.sin(omega)
        return np.sin(
            (1.0 - t) * omega) / so * p0 + np.sin(t * omega) / so * p1

    start_melody, (midi_path1, melody_idx1,
                   start_bar1) = get_melody_from_info(start_melody_info)
    end_melody, (midi_path2, melody_idx2,
                 start_bar2) = get_melody_from_info(end_melody_info)

    model.eval()
    with torch.no_grad():
        encode_event_fn = get_event_encoder(MIN_PITCH, MAX_PITCH,
                                            NUM_SPECIAL_EVENTS)
        decode_event_fn = get_event_decoder(MIN_PITCH, MAX_PITCH,
                                            NUM_SPECIAL_EVENTS)

        start_sequence = melody_to_sequence(start_melody, encode_event_fn)
        end_sequence = melody_to_sequence(end_melody, encode_event_fn)
        input_sequences = torch.stack(
            (start_sequence, end_sequence)).to(device)

        mu, sigma = model.encode(input_sequences)
        z = model.reparameterize(mu, sigma)
        # Move the latent codes to the CPU so _slerp can mix them with numpy;
        # there may be a cleaner way to do this on-device.
        z = z.cpu()
        interpolated_z = torch.stack([
            _slerp(z[0], z[1], t) for t in np.linspace(0, 1, num_steps)
        ]).to(device)

        output_sequences, _, _ = model.decode(
            interpolated_z, sequence_length=input_sequences.size(1))
        out_melody = sequence_to_melody(
            output_sequences.view(-1, output_sequences.size(-1)),
            decode_event_fn)

        fig = plt.figure(figsize=(20, 10))
        ax = fig.add_subplot(111)
        pp.plot_pianoroll(ax, melody_lib.melody_to_pianoroll(out_melody))
        ax.set_xticks(
            np.arange(MELODY_LENGTH, out_melody.shape[0], MELODY_LENGTH))
        plt.title(
            "Interpolate like there is no tomorrow, but be assured there always is."
        )

        if out_dirname is not None:
            fileroot = "interpolate_{}_{}_{}_to_{}_{}_{}.png".format(
                Path(midi_path1).stem, melody_idx1, start_bar1,
                Path(midi_path2).stem, melody_idx2, start_bar2)

            if to_midi:
                pm = melody_lib.melody_to_midi(out_melody)
                pm.write(str(out_dirname / (fileroot + ".mid")))

            fig.savefig(str(out_dirname / (fileroot + ".png")))
        plt.show()
        plt.close()
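For intuition on the _slerp helper above, a tiny self-contained check (plain numpy, nothing from this codebase) showing that spherical interpolation keeps intermediate points at unit norm, which plain linear interpolation would not:

import numpy as np

p0 = np.array([1.0, 0.0])
p1 = np.array([0.0, 1.0])
omega = np.arccos(np.dot(p0, p1))  # angle between the two unit vectors
for t in np.linspace(0, 1, 5):
    p = (np.sin((1.0 - t) * omega) * p0 + np.sin(t * omega) * p1) / np.sin(omega)
    print(round(t, 2), np.linalg.norm(p))  # norm stays 1.0 along the arc
    # A plain linear mix, by contrast, dips to norm ~0.71 at t = 0.5.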