def train_step(
    input,
    ground_truth,
    model,
    train_fout=None,
    val_fout=None
):
    """Apply one gradient update and score one randomly chosen validation file.

    The training input and ground truth are expected to arrive already
    synchronized and encoded (original author's note); only the validation
    sample is parsed, synced and encoded here.

    Args:
        input: Training batch passed straight to ``model``. The name shadows
            the builtin but is kept so keyword callers keep working.
        ground_truth: Labels compared against ``model(input)`` by every
            callable in the module-level ``loss_list``.
        model: Callable model exposing ``trainable_variables``.
        train_fout: Optional handle handed to ``write_loss`` together with all
            training losses; nothing is recorded when it is ``None``.
        val_fout: Same as ``train_fout`` but for the validation losses.

    Returns:
        Tuple ``(overall_train_loss, overall_val_loss)``: ``np.mean`` of the
        loss selected by ``default_loss_index`` for the training batch and
        for the validation sample respectively.
    """
    # pick one validation file at random and prepare it
    validation_input_name = np.random.choice(validation_files)
    validation_input = loader.parse_input(validation_input_name, input_path)
    validation_label = sync.trim_front(
        loader.parse_label(validation_input_name, ground_truth_data_path))
    val_input, val_label = sync.sync_data(validation_input, validation_label,
                                          len(validation_label))
    val_label = np.array(
        loader.encode_multihot(val_label))  # encode to multi-hot

    # reshape to the layout the network takes; the original author describes
    # it as (batch_size, 1, samples_per_batch)
    val_input = np.array(val_input)
    val_input = np.reshape(
        val_input, (val_input.shape[0], 1, val_input.shape[1]))

    # Only the training forward pass and its losses are recorded on the tape:
    # nothing else contributes to the gradient.
    with tf.GradientTape() as tape:
        # the original author notes that prediction and ground truth are
        # guaranteed to have the same shape
        training_prediction = model(input)
        training_losses = [
            x(ground_truth, training_prediction) for x in loss_list
        ]
        applicable_loss = training_losses[default_loss_index]

    # Validation is evaluated before the weights are updated (same ordering
    # as before), but outside the tape so it is not traced for gradients.
    validation_prediction = model(val_input)
    validation_losses = [
        x(val_label, validation_prediction) for x in loss_list
    ]
    visible_loss = validation_losses[default_loss_index]

    # record every loss if an output handle was supplied
    if train_fout is not None:
        write_loss(training_losses, loss_label, train_fout)
    if val_fout is not None:
        write_loss(validation_losses, loss_label, val_fout)

    # compute the gradient after leaving the tape context, then apply it
    grad = tape.gradient(applicable_loss, model.trainable_variables)
    default_opt.apply_gradients(zip(grad, model.trainable_variables))

    overall_train_loss = np.mean(applicable_loss)
    overall_val_loss = np.mean(visible_loss)
    return overall_train_loss, overall_val_loss
def train_step(input, label, model, train_fout=None, val_fout=None):
    """Apply one gradient update and score one randomly chosen validation file.

    Args:
        input: Iterable of training samples; each one is passed to ``model``
            separately and the predictions are concatenated along axis 0.
            The name shadows the builtin but is kept so keyword callers keep
            working.
        label: Labels compared against the concatenated training predictions
            by every callable in the module-level ``loss_list``.
        model: Callable model exposing ``trainable_variables``.
        train_fout: Optional handle handed to ``write_loss`` together with all
            training losses; nothing is recorded when it is ``None``.
        val_fout: Same as ``train_fout`` but for the validation losses.

    Returns:
        Tuple ``(overall_train_loss, overall_val_loss)``: ``np.mean`` of the
        loss selected by ``default_loss_index`` for the training mini-batch
        and for the validation sample respectively.
    """
    # pick one validation file at random
    validation_input_name = np.random.choice(validation_files)
    validation_input, val_sr = loader.parse_input(
        validation_input_name, input_path)  # input data and sample rate
    validation_label, val_bpm = loader.parse_label(
        validation_input_name, label_data_path)  # label data and bpm

    # generate mel spectrogram
    val_ml = loader.get_mel_spec(validation_input, mel_res, val_sr,
                                 window_size, hop_len)

    # trim the label data, then pair it with the spectrogram
    validation_label = sync.trim_front(validation_label)
    val_input, val_label = sync.sync_data(val_ml, validation_label, val_bpm,
                                          hop_len)
    val_label = np.array(
        loader.encode_multihot(val_label))  # encode to multi-hot

    # add a trailing channel axis; the original author describes the result
    # as (mini_batch_size, window_size, mel_resolution, channel)
    val_input = np.reshape(
        val_input,
        (val_input.shape[0], val_input.shape[1], val_input.shape[2], 1))

    # Only the training forward pass and its losses are recorded on the tape:
    # nothing else contributes to the gradient.
    with tf.GradientTape() as tape:
        # predict each sample on its own and join the results so the whole
        # mini-batch is scored at once
        training_prediction = tf.concat(
            [model(sample) for sample in input], 0)
        training_losses = [x(label, training_prediction) for x in loss_list]
        applicable_loss = training_losses[default_loss_index]  # training loss

    # Validation is evaluated before the weights are updated (same ordering
    # as before), but outside the tape so it is not traced for gradients.
    validation_prediction = model(val_input)
    validation_losses = [
        x(val_label, validation_prediction) for x in loss_list
    ]
    visible_loss = validation_losses[default_loss_index]  # validation loss

    # record every loss if an output handle was supplied
    if train_fout is not None:
        write_loss(training_losses, loss_label, train_fout)
    if val_fout is not None:
        write_loss(validation_losses, loss_label, val_fout)

    # compute the gradient after leaving the tape context, then apply it
    grad = tape.gradient(applicable_loss, model.trainable_variables)
    default_opt.apply_gradients(zip(grad, model.trainable_variables))

    overall_train_loss = np.mean(applicable_loss)
    overall_val_loss = np.mean(visible_loss)
    return overall_train_loss, overall_val_loss
        # ===== READING IN TRAINING FILES =====
        # Walks the slice of training_files that forms the current mini-batch
        # (n_batch files starting at index j) and turns each file into a
        # paired, encoded (input, label) sample.
        # NOTE(review): j, n_batch, file_range, train_loss and val_loss come
        # from enclosing code that is not visible here.
        for file in (
                training_files[j:j +
                               n_batch]):  # loop through current mini-batch
            # show the file being read plus the last train/validation losses
            # in the progress-bar description
            header = colored(
                'Reading and processing file: [{file}]...'.format(**locals()),
                'grey', 'on_yellow') + '          | ' + colored(
                    'Last Trn Loss: ', 'green') + colored(
                        str(train_loss), 'green', attrs=['bold', 'reverse']
                    ) + '; ' + colored('Last Val Loss: ', 'green') + colored(
                        str(val_loss), 'green', attrs=['bold', 'reverse'])
            file_range.set_description(header)
            file_range.refresh()

            unpaired_input, sr = loader.parse_input(file,
                                                    input_path)  # parse input
            # an empty input cannot be processed: report it and move on
            if (unpaired_input.size == 0):
                print(colored('skipped {file}'.format(**locals()), 'red'))
                continue
            # create mel spectrogram
            unpaired_input_ml = loader.get_mel_spec(unpaired_input, mel_res,
                                                    sr, window_size, hop_len)
            # get label for input
            unpaired_label, bpm = loader.parse_label(file, label_data_path)
            unpaired_label = sync.trim_front(
                unpaired_label)  # trim the front of the MIDI label before syncing

            input, label = sync.sync_data(unpaired_input_ml, unpaired_label,
                                          bpm, hop_len)  # pair IO
            # NOTE(original author): file 7480_4 causes trouble around this
            # step; the cause is not recorded here.
            label = np.array(loader.encode_multihot(label))  # encode label
    os.makedirs(output_dir_name)
    print(
        colored(
            "Successfully opened output directory. {output_dir_name} was created"
            .format(**locals()), 'green'))
except FileExistsError:
    print(
        colored(
            "Write path {output_dir_name} already exists. Opening path.".
            format(**locals()), 'yellow'))
# Dump every file found in input_dir_name as a text ".rwav" file in
# output_dir_name: the parsed (un-normalised) samples are written as one line
# of '/'-separated values. The total size of the written files is reported at
# the end.
print('Attempting to convert {parse_size} rWav files'.format(
    parse_size=parse_size))

files = os.listdir(input_dir_name)
total_size = 0
for file in tqdm(files):
    # output name = everything before the first '.' of the input name
    prefix = file.split('.')[0]
    write_path = os.path.join(output_dir_name,
                              '{prefix}.rwav'.format(prefix=prefix))
    # Parse first so a failing input does not leave a truncated output file.
    data = loader.parse_input(file, input_dir_name, norm=False)
    # The context manager closes the handle even if the write raises.
    with open(write_path, 'w') as fout:
        fout.write('/'.join(str(i) for i in data.tolist()))
    total_size += os.path.getsize(write_path)

# bytes -> decimal megabytes, truncated to two decimal places
total_size_mb = int(total_size / 10000) / 100
files_length = len(files)
print(
    colored(
        'Program Finished. Synthesized {files_length} rWavs, totalling {total_size_mb}MB.'
        .format(files_length=files_length, total_size_mb=total_size_mb),
        'green'))
import matplotlib.pyplot as plt
from functools import reduce
from time import sleep
import time
import librosa

from tensorflow.python.client import device_lib

# load in testing data
# Exploratory cell(s): load two wav/label pairs, build mel spectrograms with
# two different settings and plot them against the label data.
absolute_path = os.path.join('/home/lemonorange/catRemixV2')
data_root_path = os.path.join(absolute_path, 'data')
input_path = os.path.join(data_root_path, 'wav')
label_path = os.path.join(data_root_path, 'rawMid')
# same directory as label_path, kept under a second name
ground_truth_data_path = os.path.join(data_root_path, 'rawMid')

data, sr = loader.parse_input('7480_6.wav', input_path,
                              norm=False)  # NORMALIZATION MUST BE DISABLED
label, bpm = loader.parse_label('7480_6.wav', label_path)

data2, sr2 = loader.parse_input('1_0.wav', input_path,
                                norm=False)  # NORMALIZATION MUST BE DISABLED
label2, bpm2 = loader.parse_label('1_0.wav', label_path)

# NOTE(review): argument order presumably is (signal, mel resolution, sample
# rate, window size, hop length), matching other call sites -- confirm
# against loader.get_mel_spec.
ML = loader.get_mel_spec(data, 512, sr, 4096, 2048)
ML2 = loader.get_mel_spec(data2, 128, sr2, 2048, 512)

# bare expression: only produces output when run interactively
ML.shape

# plot the sum of each first-axis slice of ML
plt.plot([sum(i) for i in ML])
# overlay the trimmed label: each entry's sum is repeated 28 times
# NOTE(review): 28 presumably stretches label steps to the plot's x scale --
# confirm where the factor comes from.
plt.plot(np.concatenate([[sum(i)] * 28 for i in sync.trim_front(label)]),
         color='green')  # this is the midi sound wave (sorta)
plt.imshow(ML2)
# Exploratory cell(s): derive chunk timing from hop_length / sample_rate
# (both defined outside this view), preview one source file and allocate a
# silent buffer into which chunks will be assembled.
window_size = 10

# duration and sample count of one chunk of window_size hops
chunk_length_seconds = (hop_length * window_size)/sample_rate
sample_per_chunk = hop_length * window_size

absolute_path = os.path.join('/home/lemonorange/catRemixV2')
data_root_path = os.path.join(absolute_path, 'data')
input_path = os.path.join(data_root_path, 'wav')
label_data_path = os.path.join(data_root_path, 'rawMid')
storage_path = os.path.join(absolute_path, 'network')

that_file = "dancebg.wav"

# bare expression: only produces output when run interactively
that_file

data, sr = loader.parse_input(that_file, input_path, norm=False)

# NOTE(review): playback uses sample_rate rather than the sr returned by
# parse_input -- confirm the two always agree.
Audio(data, rate=sample_rate) # preview the audio

# read_path = os.path.join(data_root_path, 'lol-phoenix')
read_path = os.path.join(data_root_path, 'dancemonkey')

# order files numerically by the integer that precedes the first '_'
files = sorted(os.listdir(read_path), key = lambda x : int(x.split('_')[0]))
# total length: chunk duration times (largest file index + 10)
# NOTE(review): the +10 looks like headroom past the last chunk -- confirm.
max_length = chunk_length_seconds * (int(files[-1].split('_')[0])+10)
max_samples = int(max_length * sample_rate)
# re-derive the length so it matches the truncated whole number of samples
max_length = max_samples / sample_rate
max_note_index = int(files[-1].split('_')[0])

# start assembly
# plain Python list of max_samples zeros
overall_audio = np.zeros(max_samples).tolist()