    def __init__(self, num_shifts):
        self.num_shifts = num_shifts

        utils.reset_directory(autotune_preprocessed_directory, empty=False)
        logger.info("Results will be saved in: {0}".format(autotune_preprocessed_directory))

        # load the performance indices for the datasets: all ids that have an instance in all directories
        performance_list = sorted(list(
            set([f[:-4] for f in os.listdir(pyin_directory) if "npy" in f]) &
            set([f[:-4] for f in os.listdir(midi_directory) if "npy" in f]) &
            set([f[:-4] for f in os.listdir(back_chroma_directory) if "npy" in f])))

        training_list, validation_list, test_list = split_into_training_validation_test(performance_list)

        print("training list", len(training_list), training_list[:5], training_list[-5:])
        training_list = training_list[2500:] + training_list[:2500]
        print("rotated training list", len(training_list), training_list[:5], training_list[-5:])
        logger.info("training list {0} validation list {1}".format(len(training_list), len(validation_list)))
        # build custom datasets
        freeze = False
        self.device = "cpu"
        logger.info("fixing some of the shifts across all songs: {0}".format(freeze))
        self.training_dataset = get_dataset(data_list=training_list, num_shifts=self.num_shifts,
                                            songs_per_batch=1, mode="training", device=self.device, freeze=freeze)
        self.validation_dataset = get_dataset(data_list=validation_list, num_shifts=self.num_shifts,
                                              songs_per_batch=1, mode="testing", device=self.device, freeze=freeze)
        self.test_dataset = get_dataset(data_list=test_list, num_shifts=self.num_shifts,
                                        songs_per_batch=1, mode="testing", device=self.device, freeze=freeze)
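
# utils.reset_directory is called throughout these examples but its body is not
# shown. A minimal sketch of what it might do, inferred from the calls above; the
# implementation below is an assumption, not the project's actual code.
import os
import shutil


def reset_directory_sketch(path, empty=True):
    """Create path if it is missing; when empty is True, clear any previous contents."""
    if empty and os.path.isdir(path):
        shutil.rmtree(path)  # drop results from earlier runs
    os.makedirs(path, exist_ok=True)
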
Example 2
def main(args):
    utils.reset_directory(os.path.join(base_directory, analysis_dir))
    utils.reset_directory(differences_dir)
    intonation_hist_path = os.path.join(base_directory, analysis_dir, "intonation_hist.pkl")
    clustering_hist_path = os.path.join(base_directory, analysis_dir, "clustering_hist.pkl")

    plots_dir_intonation = os.path.join(base_directory, "plots/Intonation/")
    plots_dir_clustering = os.path.join(base_directory, "plots/clustering_data_sanna/")

    utils.reset_directory(plots_dir_intonation)
    utils.reset_directory(plots_dir_clustering)

    plt.style.use('classic')
    if args.get_histogram is True:
        # load the files if they exist
        if os.path.exists(intonation_hist_path):
            with open(intonation_hist_path, "rb") as fname:
                intonation_hist = pickle.load(fname)
            with open(clustering_hist_path, "rb") as fname:
                clustering_hist = pickle.load(fname)
        # otherwise, run the analysis
        else:
            print("computing histograms...")
            intonation_hist = get_histogram(intonation_pitch_dir, intonation_midi_dir,
                    plots_dir_intonation, differences_dir, args.max_count)
            clustering_hist = get_histogram(clustering_pitch_dir, clustering_midi_dir,
                    plots_dir_clustering, differences_dir, args.max_count)
            with open(intonation_hist_path, "wb") as fname:
                pickle.dump(intonation_hist, fname)
            with open(clustering_hist_path, "wb") as fname:
                pickle.dump(clustering_hist, fname)
        # process and normalize
        print("clustering", clustering_hist.most_common(10))
        print("intonation", intonation_hist.most_common(10))
        intonation_hist = np.array(list(map(list, zip(*sorted(intonation_hist.items())))))
        clustering_hist = np.array(list(map(list, zip(*sorted(clustering_hist.items())))))
        print("sums", np.sum(intonation_hist[1]), np.sum(clustering_hist[1]), )
        normalization = np.sum(intonation_hist[1])/np.sum(clustering_hist[1])
        clustering_hist[1] = (clustering_hist[1] * normalization).astype(int)

        # plot full histograms comparison
        # linear scale
        fig = plt.figure(figsize=(8, 5))
        plt.plot(clustering_hist[0], (clustering_hist[1]),
                label="Remaining clusters", color="red", linestyle="dotted")
        plt.plot(intonation_hist[0], (intonation_hist[1]),
                label="Selected clusters", color="blue", linewidth=0.75, linestyle="solid")
        plt.xlabel("Deviations (cents)", fontsize='large')
        plt.ylabel("Occurrences in 1000s", fontsize='large')
        plt.xlim(-1600, 1600)
        plt.ylim(0, 213000)
        ax = plt.gca()  # current axes holding the linear-scale curves
        ax.xaxis.label.set_size(18)
        ax.yaxis.label.set_size(18)
        ticks_y = ticker.FuncFormatter(lambda x, pos: '{0:g}'.format(x / 1000))
        ax.yaxis.set_major_formatter(ticks_y)
        # log scale
        ax2 = ax.twinx()
        ax2.set_ylabel("Log of occurrences", fontsize='large')
        ax2.set_ylim(0, 40)
        ax2.yaxis.label.set_size(18)
        ax2.plot(clustering_hist[0], np.log(clustering_hist[1]+1),
                label="Remaining (log)", color="orange", linestyle="dotted")
        ax2.plot(intonation_hist[0], np.log(intonation_hist[1]+1),
                label="Selected (log)", color="green", linewidth=0.75, linestyle="solid")

        lines, labels = ax.get_legend_handles_labels()
        lines2, labels2 = ax2.get_legend_handles_labels()
        ax2.legend(lines + lines2, labels + labels2, loc=0)

        plt.tight_layout()
        fig.savefig(os.path.join(plots_dir_intonation, "full_histograms_comparison.eps"), format="eps")
        fig.show()

        plt.style.use('ggplot')
        # plot zoomed histograms comparison in a fresh figure
        # linear scale
        plt.figure(figsize=(8, 5))
        plt.plot(clustering_hist[0], (clustering_hist[1]),
                 label="Remaining clusters", color="red", linewidth=0.75)
        plt.plot(intonation_hist[0], (intonation_hist[1]),
                 label="Selected clusters", color="blue", linewidth=0.75)
        plt.xlabel("Deviations (cents)")
        plt.ylabel("Occurrences in 1000s")
        plt.xlim(-500, 500)
        plt.ylim(0, 25000)
        ax = plt.gca()
        ticks_y = ticker.FuncFormatter(lambda x, pos: '{0:g}'.format(x / 1000))
        ax.yaxis.set_major_formatter(ticks_y)
        # log scale
        ax2 = ax.twinx()
        ax2.set_ylabel("Log of occurrences")
        ax2.set_ylim(0, 40)
        ax2.plot(clustering_hist[0], np.log(clustering_hist[1] + 1),
                 label="Remaining clusters (log)", color="orange", linewidth=0.75)
        ax2.plot(intonation_hist[0], np.log(intonation_hist[1] + 1),
                 label="Selected clusters (log)", color="green", linewidth=0.75)

        lines, labels = ax.get_legend_handles_labels()
        lines2, labels2 = ax2.get_legend_handles_labels()
        ax2.legend(lines + lines2, labels + labels2, loc=0)

        plt.tight_layout()
        plt.savefig(os.path.join(plots_dir_intonation, "full_histograms_comparison_zoom.eps"), format="eps")
        plt.show()

        rng = 100
        ind_pos = np.arange(rng + 1)  # the x locations for the groups
        width = 0.35  # the width of the bars

        intonation_pos = np.zeros(rng + 1)
        intonation_neg = np.zeros(rng + 1)
        for i in range(1, rng + 1):
            if i in intonation_hist[0]:
                intonation_pos[i] = intonation_hist[1][np.where(intonation_hist[0] == i)[0][0]]
            if -i in intonation_hist[0]:
                intonation_neg[i] = intonation_hist[1][np.where(intonation_hist[0] == -i)[0][0]]

        clustering_pos = np.zeros(rng + 1)
        clustering_neg = np.zeros(rng + 1)
        for i in range(1, rng + 1):
            if i in clustering_hist[0]:
                clustering_pos[i] = clustering_hist[1][np.where(clustering_hist[0] == i)[0][0]]
            if -i in clustering_hist[0]:
                clustering_neg[i] = clustering_hist[1][np.where(clustering_hist[0] == -i)[0][0]]

        fig, ax = plt.subplots(figsize=(6, 4))
        # matplotlib.rcParams.update({'font.size': 18})
        plt.plot(intonation_pos[1:], color='#66b3ff', linestyle=":", label="Selected clusters: Positive")
        plt.plot(intonation_neg[1:], color='#000099', linestyle="-.", label="Selected clusters: Negative")
        plt.plot(clustering_pos[1:], color='#ff751a', linestyle="--", label='Remaining clusters: Positive')
        plt.plot(clustering_neg[1:], color='#cc2900', linestyle="-", label='Remaining clusters: Negative')
        plt.legend(loc="upper right")
        ax.xaxis.set_major_locator(ticker.MultipleLocator(5))
        ticks_y = ticker.FuncFormatter(lambda x, pos: '{0:g}'.format(x / 1000))
        ax.yaxis.set_major_formatter(ticks_y)
        ax.set_ylabel('Occurrences in 1000s')
        ax.set_xlabel('Deviations (cents)')
        ax.set_xlim(0.5, 100.5)
        ax.set_ylim(0, 250000)
        # for i in range(5, 100, 5):
        #     plt.axvline(x=i, ls='dotted', color="green", linewidth=0.9)
        plt.tight_layout()
        plt.savefig(os.path.join(plots_dir_intonation, "full_pos_vs_neg_line.eps"), format="eps")
        plt.show()

        width = 0.24
        fig, ax = plt.subplots()
        rects1 = ax.bar(ind_pos - width * 1.5, intonation_pos, width, color='#66b3ff')
        rects2 = ax.bar(ind_pos - width * 0.5, intonation_neg, width, color='#000099')
        rects3 = ax.bar(ind_pos + width * 0.5, clustering_pos, width, color='#ff751a')
        rects4 = ax.bar(ind_pos + width * 1.5, clustering_neg, width, color='#cc2900')
        # add some text for labels, title and axes ticks
        ax.set_ylabel('Occurrences in 1000s')
        ax.set_xlabel('Deviations (cents)')
        ax.legend((rects1[0], rects2[0], rects3[0], rects4[0]), ('Selected clusters: Positive', 'Selected clusters: Negative',
                'Remaining clusters: Positive', 'Remaining clusters: Negative'))
        plt.ylim(0, 160000)
        plt.xlim(0.5, 100.5)
        ax.xaxis.set_major_locator(ticker.MultipleLocator(5))
        ticks_y = ticker.FuncFormatter(lambda x, pos: '{0:g}'.format(x / 1000))
        ax.yaxis.set_major_formatter(ticks_y)
        plt.tight_layout()
        plt.savefig(os.path.join(plots_dir_intonation, "full_pos_vs_neg.eps"), format="eps")
        plt.show()


        width = 0.4
        fig, ax = plt.subplots()
        rects3 = ax.bar(ind_pos - width * 0.5, clustering_pos, width, color='#ff751a')
        rects4 = ax.bar(ind_pos + width * 0.5, clustering_neg, width, color='#cc2900')
        # add some text for labels, title and axes ticks
        ax.set_ylabel('Occurrences in 1000s')
        ax.set_xlabel('Deviations (cents)')
        ax.legend((rects3[0], rects4[0]), ('Remaining clusters: Positive', 'Remaining clusters: Negative'))
        plt.ylim(0, 160000)
        plt.xlim(0.5, 60.5)
        ax.xaxis.set_major_locator(ticker.MultipleLocator(5))
        ticks_y = ticker.FuncFormatter(lambda x, pos: '{0:g}'.format(x / 1000))
        ax.yaxis.set_major_formatter(ticks_y)
        plt.tight_layout()
        plt.savefig(os.path.join(plots_dir_intonation, "clustering_pos_vs_neg.eps"), format="eps")
        plt.show()

        width = 0.4
        fig, ax = plt.subplots()
        rects1 = ax.bar(ind_pos - width * 0.5, intonation_pos, width, color='#66b3ff')
        rects2 = ax.bar(ind_pos + width * 0.5, intonation_neg, width, color='#000099')
        # add some text for labels, title and axes ticks
        ax.set_ylabel('Occurrences in 1000s')
        ax.set_xlabel('Deviations (cents)')
        ax.legend((rects1[0], rects2[0]), ('Selected clusters: Positive', 'Selected clusters: Negative'))
        plt.ylim(0, 160000)
        plt.xlim(0.5, 60.5)
        ax.xaxis.set_major_locator(ticker.MultipleLocator(5))
        ticks_y = ticker.FuncFormatter(lambda x, pos: '{0:g}'.format(x / 1000))
        ax.yaxis.set_major_formatter(ticks_y)
        plt.tight_layout()
        plt.savefig(os.path.join(plots_dir_intonation, "intonation_pos_vs_neg.eps"), format="eps")
        plt.show()

    if args.bootstrap is True:
        bootstrap(intonation_midi_dir, differences_dir, cents_min=args.bootstrap_cents_min,
                  cents_max=args.bootstrap_cents_max)
        bootstrap(clustering_midi_dir, differences_dir, cents_min=args.bootstrap_cents_min,
                  cents_max=args.bootstrap_cents_max)

    plot_pipeline(differences_dir, plots_dir_intonation)
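
# The Counter-to-array conversion and rescaling in main() above are dense, so here
# is a self-contained toy illustration of what those lines compute (the numbers
# below are made up for demonstration only).
from collections import Counter

import numpy as np

intonation = Counter({-10: 4, 0: 20, 10: 6})  # {deviation in cents: count}
clustering = Counter({-10: 1, 0: 5, 10: 4})

# row 0 holds the sorted deviations, row 1 their counts
intonation_arr = np.array(list(map(list, zip(*sorted(intonation.items())))))
clustering_arr = np.array(list(map(list, zip(*sorted(clustering.items())))))
print(intonation_arr)  # [[-10   0  10]
                       #  [  4  20   6]]

# rescale one histogram so both have the same total mass before plotting
normalization = np.sum(intonation_arr[1]) / np.sum(clustering_arr[1])
clustering_arr[1] = (clustering_arr[1] * normalization).astype(int)
print(np.sum(clustering_arr[1]), np.sum(intonation_arr[1]))  # 30 30
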
Example 3

import pickle

import torch
import segmentation_models_pytorch

from dataset import create_segmentation_dataloaders
from utils import reset_directory, train_segmentation_model
from history import plot_segmentation_history

num_epochs = 50
mode = "segmentation"
device = torch.device("cuda")

reset_directory(mode)

dataset_path = "/home/shouki/Desktop/Programming/Python/AI/Datasets/ImageData/Covid19XrayImageSegmentationDataset"
image_size = (224, 224)
batch_size = 16
train_dataloader = create_segmentation_dataloaders(dataset_path, image_size, "train", batch_size)
validation_dataloader = create_segmentation_dataloaders(dataset_path, image_size, "validation", batch_size)

model = segmentation_models_pytorch.Unet("resnet101", encoder_weights="imagenet", classes=1, activation=None).to(device)
criterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", patience=3, verbose=True)

history = train_segmentation_model(model, criterion, optimizer, scheduler, num_epochs, train_dataloader, validation_dataloader, device)
plot_segmentation_history(history, num_epochs, mode)

with open("./histories/segmentation/history.pkl", "wb") as f:
    pickle.dump(history, f)
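
# train_segmentation_model is imported from a local utils module and is not shown
# here. A minimal sketch of what such a loop might look like, assuming it returns
# per-epoch losses and that ReduceLROnPlateau is stepped with the validation loss;
# the function name and history keys below are assumptions, not the real API.
import torch


def train_segmentation_model_sketch(model, criterion, optimizer, scheduler,
                                     num_epochs, train_loader, val_loader, device):
    history = {"train_loss": [], "validation_loss": []}
    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0
        for images, masks in train_loader:
            images, masks = images.to(device), masks.to(device)
            optimizer.zero_grad()
            loss = criterion(model(images), masks)  # BCEWithLogitsLoss works on raw logits
            loss.backward()
            optimizer.step()
            train_loss += loss.item() * images.size(0)

        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for images, masks in val_loader:
                images, masks = images.to(device), masks.to(device)
                val_loss += criterion(model(images), masks).item() * images.size(0)

        train_loss /= len(train_loader.dataset)
        val_loss /= len(val_loader.dataset)
        scheduler.step(val_loss)  # ReduceLROnPlateau expects the monitored metric
        history["train_loss"].append(train_loss)
        history["validation_loss"].append(val_loss)
    return history
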
Example 4
    def __init__(self, extension, hidden_size, global_dropout, learning_rate, resume, small_dataset, max_norm,
                 num_layers, num_shifts, epochs, report_step, sandbox):
        self.extension = extension
        self.hidden_size = hidden_size
        self.global_dropout = global_dropout
        self.learning_rate = learning_rate
        self.resume = resume
        self.small_dataset = small_dataset
        self.max_norm = max_norm
        self.num_layers = num_layers
        self.num_shifts = num_shifts
        self.epochs = epochs
        self.report_step = report_step
        self.sandbox = sandbox
        self.is_best = False

        # set paths, make sure all directories exist and delete results from previous runs
        self.results_root = "./results_root" + self.extension  # root directory for plots and results
        self.results_directory = os.path.join(self.results_root, "rnn_results")
        self.plot_directory = os.path.join(self.results_root, "plots")
        self.parameter_plot_directory = os.path.join(self.plot_directory, "parameter_visualization")
        self.layer_plot_directory = os.path.join(self.plot_directory, "layer_visualization")
        self.user_prediction_directory = os.path.join(self.plot_directory, "user_prediction")
        self.test_prediction_directory = os.path.join(self.user_prediction_directory, "test_prediction")
        self.test_results_directory = os.path.join(self.results_directory, "test")
        utils.reset_directory(self.results_root, empty=False)
        utils.reset_directory(self.results_directory, empty=True)
        utils.reset_directory(self.user_prediction_directory, empty=True)
        utils.reset_directory(self.layer_plot_directory, empty=True)
        utils.reset_directory(self.test_prediction_directory, empty=True)
        utils.reset_directory(self.test_results_directory, empty=True)
        utils.reset_directory(pytorch_models_directory, empty=False)
        utils.reset_directory(autotune_preprocessed_directory, empty=False)
        logger.info("preprocessed dir: {0}".format(autotune_preprocessed_directory))
        # pytorch checkpoint directories
        self.resume_file = os.path.join(pytorch_models_directory, 'model_best' + self.extension + '.pth.tar')
        # save latest model parameters to this checkpoint
        self.latest_checkpoint_file = os.path.join(
            pytorch_models_directory, 'checkpoint_rnn' + self.extension + '.pth.tar')
        # save model parameters with best validation loss to this checkpoint
        self.best_checkpoint_file = os.path.join(pytorch_models_directory, 'model_best' + self.extension + '.pth.tar')

        # gpu versus cpu device
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # model
        self.model = ConvRNN(hidden_size=self.hidden_size, num_layers=self.num_layers).to(self.device)
        utils.print_param_sizes(self.model)

        # error and loss
        self.criterion = nn.MSELoss()
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.learning_rate)

        # initialize training parameters or load them from checkpoint if the file exists and resume is set to True
        self.best_prec1, self.start_epoch, self.training_losses, self.validation_losses, self.model, self.optimizer = \
            restore_checkpoint(self.resume, self.resume_file, self.device, self.model, self.optimizer)

        # load the performance indices for the datasets: all ids that have an instance in all directories
        performance_list = sorted(list(
            set([f[:-4] for f in os.listdir(pyin_directory) if "npy" in f]) &
            set([f[:-4] for f in os.listdir(back_chroma_directory) if "npy" in f])))

        # leave boundary_arr_id to default if using full dataset
        training_list, validation_list, test_list = split_into_training_validation_test(performance_list,
                boundary_arr_id="3769415_3769415")

        # keep only a subset for development purposes
        if self.small_dataset is True:
            validation_list = validation_list[:1] + training_list[:1]
            training_list = training_list[:1]
            test_list = training_list[:1]
            logger.info("training list {0} validation list {1}".format(training_list, validation_list))
        else:
            logger.info("training list {0} validation list {1}".format(len(training_list), len(validation_list)))
        # build custom datasets
        freeze = self.small_dataset  # freeze some note-wise shifts to check program accuracy
        logger.info("fixing some of the shifts across all songs: {0}".format(freeze))
        self.training_dataset = get_dataset(data_list=training_list, num_shifts=self.num_shifts,
                songs_per_batch=1, mode="training", device=self.device, freeze=freeze)
        self.validation_dataset = get_dataset(data_list=validation_list, num_shifts=self.num_shifts,
                songs_per_batch=1, mode="testing", device=self.device, freeze=freeze)
        self.test_dataset = get_dataset(data_list=test_list, num_shifts=self.num_shifts,
                songs_per_batch=1, mode="testing", device=self.device, freeze=freeze)
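
# restore_checkpoint is not shown in these examples. A hedged sketch of the
# resume logic implied by the call above; the checkpoint keys are assumptions
# inferred from the returned tuple, not the project's actual checkpoint format.
import os

import torch


def restore_checkpoint_sketch(resume, resume_file, device, model, optimizer):
    best_prec1 = float("inf")
    start_epoch = 0
    training_losses, validation_losses = [], []
    if resume and os.path.isfile(resume_file):
        checkpoint = torch.load(resume_file, map_location=device)
        best_prec1 = checkpoint["best_prec1"]
        start_epoch = checkpoint["epoch"]
        training_losses = checkpoint["training_losses"]
        validation_losses = checkpoint["validation_losses"]
        model.load_state_dict(checkpoint["state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer"])
    return best_prec1, start_epoch, training_losses, validation_losses, model, optimizer
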
Example 5
    def __getitem__(self, idx):
        """
        :param keep: number of songs to keep, less or equal to the number of pitch shifts
        :return: a list of length num_performance_shifts.
        Each item is a tuple with the labels <seq len, shifts>, input <seq len, shifts, fdim>, key (str)
        """
        fpath = os.path.join(autotune_preprocessed_directory, self.performance_list[idx] + ".pkl")
        if not os.path.exists(fpath):
            try:
                pyin = np.load(os.path.join(pyin_directory, self.performance_list[idx] + ".npy"))
                # load stft of vocals, keep complex values in order to use istft later for pitch shifting
                stft_v = dataset_analysis.get_stft(
                    os.path.join(vocals_directory, self.performance_list[idx] + ".wav")).T
                # load cqt of backing track
                cqt_b = np.abs(dataset_analysis.get_cqt(os.path.join(backing_tracks_directory,
                               self.arr_keys[self.performance_list[idx]] + ".wav"))).T
                # truncate pitch features to same length
                frames = min(cqt_b.shape[0], stft_v.shape[0], len(pyin))
                pyin = pyin[:frames]
                stft_v = stft_v[:frames, :]
                cqt_b = cqt_b[:frames, :]
                original_boundaries = np.arange(frames).astype(np.int64)  # store the original indices of the notes here
                # find locations of note onsets using pYIN
                min_note_frames = 24  # half second
                audio_beginnings = np.array([i for i in range(frames - min_note_frames)  # first nonzero frames
                                             if (i == 0 and pyin[i] > 0) or (i > 0 and pyin[i] > 0 and pyin[i - 1] == 0)])
                if self.plot is True:
                    utils.reset_directory("./plots")
                    bplt.output_file(os.path.join("./plots", "note_parse_" + self.performance_list[idx]) + ".html")
                    s1 = bplt.figure(title="note parse")
                    s1.line(np.arange(len(pyin)), pyin)
                    for i, ab in enumerate(audio_beginnings):
                        loc = Span(location=ab, dimension='height', line_color='green')
                        s1.add_layout(loc)
                # discard silent frames
                silent_frames = np.ones(frames)
                silent_frames[np.where(pyin < 1)[0]] *= 0
                pyin = pyin[silent_frames.astype(bool)]
                stft_v = stft_v[silent_frames.astype(bool), :]
                cqt_b = cqt_b[silent_frames.astype(bool), :]
                original_boundaries = original_boundaries[silent_frames.astype(bool)]
                audio_beginnings = [n - np.sum(silent_frames[:n] == 0) for _, n in enumerate(audio_beginnings)]
                frames = len(pyin)
                audio_endings = np.hstack((audio_beginnings[1:], frames - 1))
                # merge notes that are too short
                note_beginnings = []
                note_endings = []
                min_note_frames = 24
                start_note = end_note = 0
                while start_note < len(audio_beginnings):
                    note_beginnings.append(audio_beginnings[start_note])
                    while (audio_endings[end_note] - audio_beginnings[start_note] < min_note_frames and
                           end_note < len(audio_endings) - 1):
                        end_note += 1
                    note_endings.append(audio_endings[end_note])
                    start_note = end_note + 1
                    end_note = start_note
                # if the last note is too short, merge it into the previous note
                while note_endings[-1] - note_beginnings[-1] < min_note_frames:
                    del note_beginnings[-1]
                    del note_endings[-2]
                notes = np.array([note_beginnings, note_endings]).T
                # clamp the final note offset so it does not run past the last frame
                if notes[-1, 1] > frames - 1:
                    notes[-1, 1] = frames - 1
                if self.plot is True:
                    s2 = bplt.figure(title="note parse of active frames")
                    s2.line(np.arange(len(pyin)), pyin)
                    for i, ab in enumerate(note_beginnings):
                        loc = Span(location=ab, dimension='height', line_color='green')
                        s2.add_layout(loc)
                    for i, ab in enumerate(note_endings):
                        loc = Span(location=ab+1, dimension='height', line_color='red', line_dash='dotted')
                        s2.add_layout(loc)
                    bplt.save(bplt.gridplot([[s1, s2]], toolbar_location=None))
                # store the original indices of the notes
                original_boundaries = np.array([original_boundaries[notes[:, 0]], original_boundaries[notes[:, 1]]]).T
                # compute shifts for every note in every version in the batch (num_shifts)
                note_shifts = np.random.rand(self.num_shifts, notes.shape[0]) * 2 - 1  # all shift combinations
                if self.freeze is True:
                    note_shifts[:3, :] = self.frozen_shifts[:3, :note_shifts.shape[1]]
                # compute the framewise shifts
                frame_shifts = np.zeros((self.num_shifts, frames))  # this will be truncated later
                for i in range(self.num_shifts):
                    for j in range(len(notes)):
                        # only shift the non-silent frames between the note onset and note offset
                        frame_shifts[i, notes[j][0]:notes[j][1]] = note_shifts[i][j]
                # de-tune the pYIN pitch tracks and STFT of vocals
                shifted_pyin = np.vstack([pyin] * self.num_shifts) * np.power(2, max_semitone * frame_shifts / 12)
                # de-tune the vocals stft and vocals cqt
                stacked_cqt_v = np.zeros((frames, self.num_shifts, cqt_params['total_bins']))
                for i, note in enumerate(notes):
                    note_stft = np.array(stft_v[note[0]:note[1], :]).T
                    note_rt = librosa.istft(note_stft, hop_length=hopSize, center=False)
                    for j in range(self.num_shifts):
                        shifted_note_rt = librosa.effects.pitch_shift(note_rt, sr=global_fs, n_steps=note_shifts[j, i])
                        stacked_cqt_v[note[0]:note[1], j, :] = np.abs(librosa.core.cqt(
                                shifted_note_rt, sr=global_fs, hop_length=hopSize, n_bins=cqt_params['total_bins'],
                                bins_per_octave=cqt_params['bins_per_8va'], fmin=cqt_params['fmin']))[:, 4:-4].T
                # get the data into the proper format and shape for tensors
                cqt_b_binary = np.copy(cqt_b)  # copy single-channel CQT for binarization
                # need to repeat the backing track for the batch
                cqt_b = np.stack([cqt_b] * self.num_shifts, axis=1)
                # third channel
                stacked_cqt_v_binary = np.copy(stacked_cqt_v)
                for i in range(self.num_shifts):
                    thresh = threshold_mean(stacked_cqt_v_binary[:, i, :])
                    stacked_cqt_v_binary[:, i, :] = (stacked_cqt_v_binary[:, i, :] > thresh).astype(float)
                thresh = threshold_mean(cqt_b_binary)
                cqt_b_binary = (cqt_b_binary > thresh).astype(float)
                stacked_cqt_b_binary = np.stack([cqt_b_binary] * self.num_shifts, axis=1)
                stacked_cqt_combined = np.abs(stacked_cqt_v_binary - stacked_cqt_b_binary)

                data_dict = dict()
                data_dict['notes'] = notes
                data_dict['spect_v'] = stacked_cqt_v
                data_dict['spect_b'] = cqt_b
                data_dict['spect_c'] = stacked_cqt_combined
                data_dict['shifted_pyin'] = shifted_pyin
                data_dict['shifts_gt'] = note_shifts
                data_dict['original_boundaries'] = original_boundaries
                data_dict['perf_id'] = self.performance_list[idx]

                with open(fpath, "wb") as f:
                    pickle.dump(data_dict, f)  # save for future epochs
            except Exception as e:
                logger.info("exception in dataset {0} skipping song {1}".format(e, self.performance_list[idx]))
                return None
        else:
            # pre-processing has already been computed: load from file
            try:
                data_dict = self.loaditem(fpath)
            except Exception as e:
                logger.info("exception in dataset {0} skipping song {1}".format(e, self.performance_list[idx]))
                return None
        try:
            # now format the numpy arrays into torch tensors with note-wise splits
            data_dict['spect_v'] = torch.Tensor(data_dict['spect_v'])
            data_dict['spect_b'] = torch.Tensor(data_dict['spect_b'])
            data_dict['spect_c'] = torch.Tensor(data_dict['spect_c'])
            data_dict['shifted_pyin'] = torch.Tensor(data_dict['shifted_pyin'].T)
            data_dict['shifts_gt'] = torch.Tensor(data_dict['shifts_gt'].T)
            # adjust dimension of note shifts
            data_dict['shifts_gt'].unsqueeze_(1)

            # split full songs into sequences
            split_sizes = tuple(np.append(
                    np.diff(data_dict['notes'][:, 0]), data_dict['notes'][-1, 1] - data_dict['notes'][-1, 0] + 1))
            data_dict['spect_v'] = torch.split(data_dict['spect_v'], split_size_or_sections=split_sizes, dim=0)
            data_dict['spect_b'] = torch.split(data_dict['spect_b'], split_size_or_sections=split_sizes, dim=0)
            data_dict['spect_c'] = torch.split(data_dict['spect_c'], split_size_or_sections=split_sizes, dim=0)
            data_dict['shifted_pyin'] = torch.split(data_dict['shifted_pyin'], split_size_or_sections=split_sizes,
                                                    dim=0)
        except Exception as e:
            logger.info("exception in dataset {0} skipping song {1}".format(e, self.performance_list[idx]))
            return None
        return data_dict
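
# __getitem__ returns None whenever a song cannot be processed, so any DataLoader
# built on this dataset needs a collate function that drops those entries. A
# minimal sketch, assuming one song per batch as in the constructors above; the
# helper name and usage are hypothetical.
from torch.utils.data import DataLoader


def skip_failed_collate(batch):
    batch = [item for item in batch if item is not None]
    return batch[0] if batch else None  # a single song dict, or None if it failed


# hypothetical usage with a dataset returned by get_dataset(...):
# loader = DataLoader(training_dataset, batch_size=1, shuffle=True,
#                     collate_fn=skip_failed_collate)
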
Example 6

    def __getitem__(self, idx):
        """
        :param keep: number of songs to keep, less or equal to the number of pitch shifts
        :return: a list of length num_performance_shifts.
        Each item is a tuple with the labels <seq len, shifts>, input <seq len, shifts, fdim>, key (str)
        """
        fpath = os.path.join(autotune_preprocessed_directory, self.performance_list[idx] + ".pkl")
        if not os.path.exists(fpath):
            try:
                # pass if acctID is 33128648
                pyin = np.load(os.path.join(pyin_directory, self.performance_list[idx] + ".npy"))
                # load stft of vocals, keep complex values in order to use istft later for pitch shifting
                stft_v = dataset_analysis.get_stft(
                    os.path.join(vocals_directory, self.performance_list[idx] + ".wav")).T
                # load cqt of backing track
                cqt_b = np.abs(dataset_analysis.get_cqt(os.path.join(backing_tracks_directory,
                                                                     self.arr_keys[
                                                                         self.performance_list[idx]] + ".wav"))).T
                # truncate pitch features to same length
                frames = min(cqt_b.shape[0], stft_v.shape[0], len(pyin))
                pyin = pyin[:frames]
                stft_v = stft_v[:frames, :]
                cqt_b = cqt_b[:frames, :]
                original_boundaries = np.arange(frames).astype(np.int64)  # store the original indices of the notes here
                # find locations of note onsets using pYIN
                min_note_frames = 24  # half second
                audio_beginnings = np.array([i for i in range(frames - min_note_frames)  # first nonzero frames
                                             if (i == 0 and pyin[i] > 0) or (i > 0 and pyin[i] > 0 and pyin[i - 1] == 0)])
                if self.plot is True:
                    utils.reset_directory("./plots")
                    bplt.output_file(os.path.join("./plots", "note_parse_" + self.performance_list[idx]) + ".html")
                    s1 = bplt.figure(title="note parse")
                    s1.line(np.arange(len(pyin)), pyin)
                    for i, ab in enumerate(audio_beginnings):
                        loc = Span(location=ab, dimension='height', line_color='green')
                        s1.add_layout(loc)
                # discard silent frames
                silent_frames = np.ones(frames)
                silent_frames[np.where(pyin < 1)[0]] *= 0
                pyin = pyin[silent_frames.astype(bool)]
                stft_v = stft_v[silent_frames.astype(bool), :]
                cqt_b = cqt_b[silent_frames.astype(bool), :]
                original_boundaries = original_boundaries[silent_frames.astype(bool)]
                audio_beginnings = [n - np.sum(silent_frames[:n] == 0) for _, n in enumerate(audio_beginnings)]
                frames = len(pyin)
                audio_endings = np.hstack((audio_beginnings[1:], frames - 1))
                # merge notes that are too short
                note_beginnings = []
                note_endings = []
                min_note_frames = 24
                start_note = end_note = 0
                while start_note < len(audio_beginnings):
                    note_beginnings.append(audio_beginnings[start_note])
                    while (audio_endings[end_note] - audio_beginnings[start_note] < min_note_frames and
                           end_note < len(audio_endings) - 1):
                        end_note += 1
                    note_endings.append(audio_endings[end_note])
                    start_note = end_note + 1
                    end_note = start_note
                # if the last note is too short, merge it into the previous note
                while note_endings[-1] - note_beginnings[-1] < min_note_frames:
                    del note_beginnings[-1]
                    del note_endings[-2]
                notes = np.array([note_beginnings, note_endings]).T
                # clamp the final note offset so it does not run past the last frame
                if notes[-1, 1] > frames - 1:
                    notes[-1, 1] = frames - 1
                if self.plot is True:
                    s2 = bplt.figure(title="note parse of active frames")
                    s2.line(np.arange(len(pyin)), pyin)
                    for i, ab in enumerate(note_beginnings):
                        loc = Span(location=ab, dimension='height', line_color='green')
                        s2.add_layout(loc)
                    for i, ab in enumerate(note_endings):
                        loc = Span(location=ab + 1, dimension='height', line_color='red', line_dash='dotted')
                        s2.add_layout(loc)
                    bplt.save(bplt.gridplot([[s1, s2]], toolbar_location=None))
                # store the original indices of the notes
                original_boundaries = np.array([original_boundaries[notes[:, 0]], original_boundaries[notes[:, 1]]]).T
                # compute shifts for every note in every version in the batch (num_shifts)
                note_shifts = np.random.rand(self.num_shifts, notes.shape[0]) * 2 - 1  # all shift combinations
                if self.freeze is True:
                    note_shifts[:3, :] = self.frozen_shifts[:3, :note_shifts.shape[1]]
                # compute the framewise shifts
                frame_shifts = np.zeros((self.num_shifts, frames))  # this will be truncated later
                for i in range(self.num_shifts):
                    for j in range(len(notes)):
                        # only shift the non-silent frames between the note onset and note offset
                        frame_shifts[i, notes[j][0]:notes[j][1]] = note_shifts[i][j]
                # de-tune the pYIN pitch tracks and STFT of vocals
                shifted_pyin = np.vstack([pyin] * self.num_shifts) * np.power(2, max_semitone * frame_shifts / 12)
                # de-tune the vocals stft and vocals cqt
                stacked_cqt_v = np.zeros((frames, self.num_shifts, cqt_params['total_bins']))
                for i, note in enumerate(notes):
                    note_stft = np.array(stft_v[note[0]:note[1], :]).T
                    note_rt = librosa.istft(note_stft, hop_length=hopSize, center=False)
                    for j in range(self.num_shifts):
                        shifted_note_rt = librosa.effects.pitch_shift(note_rt, sr=global_fs, n_steps=note_shifts[j, i])
                        stacked_cqt_v[note[0]:note[1], j, :] = np.abs(librosa.core.cqt(
                            shifted_note_rt, sr=global_fs, hop_length=hopSize, n_bins=cqt_params['total_bins'],
                            bins_per_octave=cqt_params['bins_per_8va'], fmin=cqt_params['fmin']))[:, 4:-4].T
                # get the data into the proper format and shape for tensors
                cqt_b_binary = np.copy(cqt_b)  # copy single-channel CQT for binarization
                # need to repeat the backing track for the batch
                cqt_b = np.stack([cqt_b] * self.num_shifts, axis=1)
                # third channel
                stacked_cqt_v_binary = np.copy(stacked_cqt_v)
                for i in range(self.num_shifts):
                    thresh = threshold_mean(stacked_cqt_v_binary[:, i, :])
                    stacked_cqt_v_binary[:, i, :] = (stacked_cqt_v_binary[:, i, :] > thresh).astype(float)
                thresh = threshold_mean(cqt_b_binary)
                cqt_b_binary = (cqt_b_binary > thresh).astype(float)
                stacked_cqt_b_binary = np.stack([cqt_b_binary] * self.num_shifts, axis=1)
                stacked_cqt_combined = np.abs(stacked_cqt_v_binary - stacked_cqt_b_binary)

                start_f = 0
                end_f = 300
                matplotlib.rcParams.update(matplotlib.rcParamsDefault)
                f, ((ax1, ax2, ax3), (ax4, ax5, ax6)) = plt.subplots(2, 3, sharey=True)
                ax1.imshow(np.log(cqt_b[start_f:end_f, 0, :].T + 1e-10), aspect='auto', origin='lower')
                ax1.set_ylabel("CQT bins")
                ax1.set_title("Backing track CQT")
                ax1.set_xlabel("frames")
                ax4.plot(pyin[start_f:end_f] * 500 / np.max(pyin[start_f:end_f]))
                for _, (beg, end) in enumerate(notes):
                    if beg >= start_f and beg <= end_f:
                        ax4.axvline(x=beg, color="green")
                    if end >= start_f and end <= end_f:
                        ax4.axvline(x=end, ls='dotted', color="red")
                ax4.set_xlabel("frames")
                ax4.set_title("Note parse")
                ax2.imshow(np.log(stacked_cqt_v[start_f:end_f, 0, :].T + 1e-10), aspect='auto', origin='lower')
                ax2.set_title("Vocals CQT (in tune)")
                ax2.set_xlabel("frames")
                ax5.imshow(stacked_cqt_combined[start_f:end_f, 0, :].T, aspect='auto', origin='lower')
                ax5.set_title("Difference (in tune)")
                ax5.set_xlabel("frames")
                ax3.imshow(np.log(stacked_cqt_v[start_f:end_f, 2, :].T + 1e-10), aspect='auto', origin='lower')
                ax3.set_title("Vocals CQT (de-tuned)")
                ax3.set_xlabel("frames")
                ax6.imshow(stacked_cqt_combined[start_f:end_f, 2, :].T, aspect='auto', origin='lower')
                ax6.set_title("Difference (de-tuned)")
                ax6.set_xlabel("frames")
                plt.tight_layout()
                plt.savefig("/Users/scwager/Documents/autotune_fa18_data/plots/cqt_comparison_3_" +
                            self.performance_list[idx] + ".eps", format="eps")
                plt.show()

                matplotlib.rcParams.update({'font.size': 20})
                plt.imshow(np.log(cqt_b[start_f:end_f, 0, :].T + 1e-10), aspect=0.6, origin='lower')
                plt.ylabel("CQT bins")
                plt.xlabel("frames")
                plt.savefig("/Users/scwager/Documents/autotune_fa18_data/plots/cqt_comparison_1.eps", format="eps",
                            bbox_inches='tight')
                plt.clf()
                plt.gca()
                plt.cla()

                plt.imshow(np.log(stacked_cqt_b_binary[start_f:end_f, 0, :].T + 1e-10), aspect=0.6, origin='lower')
                plt.xlabel("frames")
                frame1 = plt.gca()
                frame1.axes.yaxis.set_ticklabels([])
                plt.savefig("/Users/scwager/Documents/autotune_fa18_data/plots/cqt_comparison_2.eps", format="eps",
                            bbox_inches='tight')
                plt.clf()
                plt.gca()
                plt.cla()

                plt.imshow(np.log(stacked_cqt_v[start_f:end_f, 0, :].T + 1e-10), aspect=0.6, origin='lower')
                plt.xlabel("frames")
                frame1 = plt.gca()
                frame1.axes.yaxis.set_ticklabels([])
                plt.savefig("/Users/scwager/Documents/autotune_fa18_data/plots/cqt_comparison_3.eps", format="eps",
                            bbox_inches='tight')
                plt.clf()
                plt.gca()
                plt.cla()

                plt.imshow(stacked_cqt_combined[start_f:end_f, 0, :].T, aspect=0.6, origin='lower')
                plt.xlabel("frames")
                frame1 = plt.gca()
                frame1.axes.yaxis.set_ticklabels([])
                plt.savefig("/Users/scwager/Documents/autotune_fa18_data/plots/cqt_comparison_4.eps", format="eps",
                            bbox_inches='tight')
                plt.clf()
                plt.gca()
                plt.cla()

                plt.imshow(np.log(stacked_cqt_v[start_f:end_f, 2, :].T + 1e-10), aspect=0.6, origin='lower')
                plt.xlabel("frames")
                frame1 = plt.gca()
                frame1.axes.yaxis.set_ticklabels([])
                plt.savefig("/Users/scwager/Documents/autotune_fa18_data/plots/cqt_comparison_5.eps", format="eps",
                            bbox_inches='tight')
                plt.clf()
                plt.gca()
                plt.cla()

                plt.imshow(stacked_cqt_combined[start_f:end_f, 2, :].T, aspect=0.6, origin='lower')
                plt.xlabel("frames")
                frame1 = plt.gca()
                frame1.axes.yaxis.set_ticklabels([])
                plt.savefig("/Users/scwager/Documents/autotune_fa18_data/plots/cqt_comparison_6.eps", format="eps",
                            bbox_inches='tight')
                plt.show()
                plt.clf()
                plt.gca()
                plt.cla()
                # ---------------------------------

                data_dict = dict()
                data_dict['notes'] = notes
                data_dict['spect_v'] = stacked_cqt_v
                data_dict['spect_b'] = cqt_b
                data_dict['spect_c'] = stacked_cqt_combined
                data_dict['shifted_pyin'] = shifted_pyin
                data_dict['shifts_gt'] = note_shifts
                data_dict['original_boundaries'] = original_boundaries
                data_dict['perf_id'] = self.performance_list[idx]

                with open(fpath, "wb") as f:
                    pickle.dump(data_dict, f)  # save for future epochs
            except Exception as e:
                logger.info("exception in dataset {0} skipping song {1}".format(e, self.performance_list[idx]))
                return None