Esempio n. 1
0
import matplotlib.pyplot as plt
import torch
import argparse
import numpy as np
from mpl_toolkits.axes_grid1 import make_axes_locatable
from reversible import ReversibleModel
from audio_midi_dataset import get_dataset_individually, Spec2MidiDataset, SqueezingDataset
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.sampler import SequentialSampler
from train_loop import normal_noise_like
import os
import re
import mpl_rc
import utils
import pretty_midi as pm
# Settings object returned by the project-local mpl_rc.default(); presumably
# matplotlib rc parameters (and possibly applied globally as a side effect of
# the call) — TODO confirm against mpl_rc.
RCPARAMS = mpl_rc.default()


def collect_samples(device, model, loader, n_samples):
    """Run ``model`` in eval mode over ``loader`` ``n_samples`` times, collecting samples.

    NOTE(review): this definition appears truncated in this file. The visible
    body stops right after perturbing ``y``. ``device``, ``x``, ``x_true``,
    ``x_pred``, ``samples_x_true`` and ``samples_x_pred`` are never used, and
    nothing is returned. The "collect true/predicted samples" purpose is
    inferred from the local names only — confirm against the full source.

    Args:
        device: not used in the visible code; presumably the torch device the
            batches should be moved to — TODO confirm.
        model: object providing ``eval()`` and a ``y_noise_scale`` attribute
            (presumably a ``ReversibleModel`` / ``torch.nn.Module``).
        loader: iterable of batches, each indexable by ``'x'`` and ``'y'``.
        n_samples: number of complete passes made over ``loader``.
    """
    # Put the model into evaluation mode before sampling (for a torch module
    # this disables dropout / batch-norm statistic updates).
    model.eval()
    samples_x_true = []
    samples_x_pred = []
    # One outer iteration per requested sample; each one re-walks the loader.
    for si in range(n_samples):
        x_true = []
        x_pred = []
        for batch in loader:
            x = batch['x']
            y = batch['y']

            # Perturb y with noise from train_loop.normal_noise_like, scaled by
            # model.y_noise_scale (presumably Gaussian noise shaped like y).
            y = y + normal_noise_like(y,
                                      model.y_noise_scale)  # tiny exaggeration
import matplotlib.pyplot as plt
import torch
import argparse
import numpy as np
from plot_input_output import plot_input_output
from reversible import ReversibleModel
from audio_midi_dataset import get_dataset_individually, Spec2MidiDataset, SqueezingDataset
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.sampler import SequentialSampler
from train_loop import normal_noise_like
import os
import mpl_rc
import utils
# Settings object returned by the project-local mpl_rc.default(); presumably
# matplotlib rc parameters — TODO confirm against mpl_rc.
rcParams = mpl_rc.default()

# NOTE(review): START/END have no users in the visible code; they look like a
# fixed index window (e.g. frames 100..150) for the plotted excerpt — confirm
# against the code that reads them.
START = 100
END = 150


def collect_input_output(device, model, loader, n_samples):
    model.eval()
    samples_x_true = []
    samples_x_invs = []
    samples_x_edit = []
    samples_x_zepa = []
    samples_x_samp = []

    samples_y_true = []
    samples_y_pred = []
    samples_y_edit = []