def train_step(
    input,
    ground_truth,
    model,
    train_fout=None,
    val_fout=None
):
    """Run one optimization step on a pre-synchronized training batch.

    `input` and `ground_truth` must already be synchronized and encoded by
    the caller.  A random validation file is loaded, synced and multi-hot
    encoded here so every step also reports a validation loss.

    Args:
        input: training batch tensor, assumed shape
            (batch_size, 1, samples_per_batch) -- TODO confirm with caller.
        ground_truth: encoded labels matching `input`.
        model: callable Keras-style model exposing `trainable_variables`.
        train_fout: optional open file; training losses are appended via
            write_loss() when given.
        val_fout: optional open file for validation losses.

    Returns:
        (overall_train_loss, overall_val_loss) as numpy scalars.
    """
    # Pick and prepare a random validation example.
    validation_input_name = np.random.choice(validation_files)
    validation_input = loader.parse_input(validation_input_name, input_path)
    validation_label = sync.trim_front(
        loader.parse_label(validation_input_name, ground_truth_data_path))
    val_input, val_label = sync.sync_data(validation_input, validation_label,
                                          len(validation_label))
    val_label = np.array(loader.encode_multihot(val_label))  # multi-hot encode

    # Reshape to the (batch_size, 1, samples_per_batch) layout the net expects.
    val_input = np.array(val_input)
    val_input = np.reshape(val_input,
                           (val_input.shape[0], 1, val_input.shape[1]))

    # Record ONLY the training forward pass and loss on the tape; the
    # validation pass needs no gradient, so it stays outside to avoid
    # recording (and retaining memory for) ops we never differentiate.
    with tf.GradientTape() as tape:
        # ground truth and prediction are guaranteed to share a shape
        training_prediction = model(input)
        training_losses = [
            x(ground_truth, training_prediction) for x in loss_list
        ]  # all configured training losses
        applicable_loss = training_losses[default_loss_index]  # optimized loss

    validation_prediction = model(val_input)
    validation_losses = [
        x(val_label, validation_prediction) for x in loss_list
    ]  # all configured validation losses
    visible_loss = validation_losses[default_loss_index]

    # Optionally persist the per-loss-function values.
    if train_fout is not None:
        write_loss(training_losses, loss_label, train_fout)
    if val_fout is not None:
        write_loss(validation_losses, loss_label, val_fout)

    # BUGFIX: tape.gradient()/apply_gradients() were previously called inside
    # the `with` block (the original comment flagged it as not working) --
    # that records the backward pass itself on the tape.  Compute and apply
    # the gradient after the context has exited.
    grad = tape.gradient(applicable_loss, model.trainable_variables)
    default_opt.apply_gradients(zip(grad, model.trainable_variables))

    overall_train_loss = np.mean(applicable_loss)
    overall_val_loss = np.mean(visible_loss)
    return overall_train_loss, overall_val_loss
def train_step(input, label, model, train_fout=None, val_fout=None):
    """Run one optimization step on a mini-batch of mel-spectrogram windows.

    NOTE(review): this redefines the `train_step` earlier in the file; only
    this later definition is live at import time -- confirm the first one can
    be removed.

    Args:
        input: iterable of per-file training tensors, each assumed shaped
            (mini_batch_size, window_size, mel_resolution, 1) -- TODO confirm
            with caller.
        label: multi-hot encoded ground truth matching the concatenated
            predictions.
        model: callable Keras-style model exposing `trainable_variables`.
        train_fout: optional open file; training losses are appended via
            write_loss() when given.
        val_fout: optional open file for validation losses.

    Returns:
        (overall_train_loss, overall_val_loss) as numpy scalars.
    """
    # Pick and prepare a random validation example.
    validation_input_name = np.random.choice(validation_files)
    validation_input, val_sr = loader.parse_input(
        validation_input_name, input_path)  # audio data + sample rate
    validation_label, val_bpm = loader.parse_label(
        validation_input_name, label_data_path)  # label data + bpm

    # Mel spectrogram for the validation audio.
    val_ml = loader.get_mel_spec(validation_input, mel_res, val_sr,
                                 window_size, hop_len)

    # Trim the label data, then pair it with the spectrogram windows.
    validation_label = sync.trim_front(validation_label)
    val_input, val_label = sync.sync_data(val_ml, validation_label, val_bpm,
                                          hop_len)
    val_label = np.array(loader.encode_multihot(val_label))  # multi-hot encode

    # Reshape to (mini_batch_size, window_size, mel_resolution, channel).
    val_input = np.reshape(
        val_input,
        (val_input.shape[0], val_input.shape[1], val_input.shape[2], 1))

    # Record ONLY the training forward passes and loss on the tape; the
    # validation pass needs no gradient, so it stays outside.
    with tf.GradientTape() as tape:
        # One forward pass per sample, concatenated along the batch axis so
        # ground truth and prediction share a shape.
        training_prediction = tf.concat(
            [model(sample) for sample in input], 0)
        training_losses = [
            x(label, training_prediction) for x in loss_list
        ]  # all configured training losses
        applicable_loss = training_losses[default_loss_index]  # training loss

    validation_prediction = model(val_input)
    validation_losses = [
        x(val_label, validation_prediction) for x in loss_list
    ]  # all configured validation losses
    visible_loss = validation_losses[default_loss_index]  # validation loss

    # Optionally persist the per-loss-function values.
    if train_fout is not None:
        write_loss(training_losses, loss_label, train_fout)
    if val_fout is not None:
        write_loss(validation_losses, loss_label, val_fout)

    # BUGFIX: gradient computation/application moved outside the tape context
    # (the original called them inside `with`, which records the backward
    # pass itself and was flagged "doesn't work" in the original comments).
    grad = tape.gradient(applicable_loss, model.trainable_variables)
    default_opt.apply_gradients(zip(grad, model.trainable_variables))

    overall_train_loss = np.mean(applicable_loss)
    overall_val_loss = np.mean(visible_loss)
    return overall_train_loss, overall_val_loss
            # NOTE(review): orphaned fragment -- the enclosing mini-batch
            # `for file in ...:` loop header (defining `file`, `X`,
            # `temp_out`, `file_range`) is not visible in this chunk, so the
            # code is kept byte-identical.
            file_range.refresh()

            unpaired_input, sr = loader.parse_input(file,
                                                    input_path)  # parse input
            # skip files whose parsed audio came back empty
            if (unpaired_input.size == 0):
                print(colored('skipped {file}'.format(**locals()), 'red'))
                continue
            # create mel spectrogram
            unpaired_input_ml = loader.get_mel_spec(unpaired_input, mel_res,
                                                    sr, window_size, hop_len)
            # get label for input
            unpaired_label, bpm = loader.parse_label(file, label_data_path)
            unpaired_label = sync.trim_front(
                unpaired_label)  # trimming the MIDI before syncing the data

            # pair spectrogram windows with labels via the bpm and hop length
            input, label = sync.sync_data(unpaired_input_ml, unpaired_label,
                                          bpm, hop_len)  # pair IO
            # NOTE(review): sample 7480_4 breaks this step -- TODO investigate
            label = np.array(loader.encode_multihot(label))  # encode label

            # reshape to (mini_batch_size, window_size, mel_resolution, 1)
            input = np.reshape(
                input, (input.shape[0], input.shape[1], input.shape[2],
                        1))  # reshape to a tensor which the neural net can use

            X.append(input)  # add to stash
            temp_out.append(label)  # add to stash

        # NOTE(review): X is left as a list of per-file arrays here while
        # temp_out is concatenated -- confirm this asymmetry is intentional.
        y = np.concatenate(temp_out)

        # TODO: implement make shift early stopping system

        # actual training part
# Windowing math for matching MIDI notes to Mel frames:
# if the midi_size covers 700 samples in sr, and the hop length of Mel is 512,
# we will take 2x 512 slices, which adds up to 1024 samples in total matched
# to the note.  The math is as the following:
# length < ceil(midi_size / hop_length)
# start_position < floor(note_location / hop_length)

# NOTE(review): notebook-style scratch cell -- `ML`, `sr`, `label`, `bpm`,
# `ML2`, `label2`, `bpm2` are defined elsewhere; bare expressions below only
# display output in a REPL/notebook.
librosa.display.specshow(ML.transpose(),
                         sr=sr,
                         hop_length=512,
                         x_axis='time',
                         y_axis='mel')
plt.colorbar(format='%+2.0f dB')

label = sync.trim_front(label)

# pair the spectrogram with the trimmed labels (hop length 2048 here)
sync_in, sync_la = sync.sync_data(ML, label, bpm, 2048)
sync_in.shape  # inspect paired-input shape (notebook display only)

112 / 14  # scratch arithmetic -- notebook display only

sync_in2, sync_la2 = sync.sync_data(ML2, label2, bpm2, 512)

np.concatenate(sync_in).shape
# visual sanity check: summed spectrogram energy vs. label activity
plt.plot([sum(i) for i in np.concatenate(sync_in)])
plt.plot(np.concatenate([[sum(i)] * 8 for i in sync_la]),
         color='green')  # this is the midi sound

sync_in2.shape
np.array(loader.encode_multihot(sync_la2)).shape  # encoded label shape check

# NOTE(review): mis-pasted notebook cell -- loop variable `i` is never used,
# and `j`, `n_batch`, `training_files` are defined elsewhere; the body looks
# like the mini-batch loading loop of an outer training loop.  Kept
# byte-identical pending that context.
for i in sync_in:
        temp_input = []
        temp_out = []
        # ===== READING IN TRAINING FILES =====
        for file in (
                training_files[j:j +
                               n_batch]):  # loop through current mini-batch
            unpaired_input = loader.parse_input(file,
                                                input_path)  # parse input
            # skip files whose parsed audio came back empty
            if (unpaired_input.size == 0):
                print(colored('skipped {file}'.format(**locals()), 'red'))
                continue
            unpaired_label = sync.trim_front(
                loader.parse_label(file, ground_truth_data_path)
            )  # trimming the MIDI and syncing the data
            input, label = sync.sync_data(
                unpaired_input, unpaired_label,
                len(unpaired_label))  # pair IO + trim
            label = np.array(loader.encode_multihot(label))  # encode label

            input = np.array(input)
            # reshape to (batch_size, 1, samples_per_batch)
            input = np.reshape(
                input, (input.shape[0], 1, input.shape[1]
                        ))  # reshape to a tensor which the neural net can use

            temp_input.append(input)  # add to stash
            temp_out.append(label)  # add to stash

        X = np.concatenate(temp_input)
        y = np.concatenate(temp_out)

        # TODO: implement make shift early stopping system