# Training script: fit a ConvModel on every .wav file found in the dataset
# directory, then evaluate it and serialize the trained model.
import logging
import os

# tf is used below (tf.Session); re-importing is harmless if an earlier
# line of the file already does so.
import tensorflow as tf
from matplotlib import pyplot as plt
from os.path import join as pj

from util import setup_logging
from conv_model import ConvModel
from env import current as env

setup_logging(logging.getLogger())

# Paths of all .wav files in the dataset directory, in sorted (stable) order.
# env.dataset() is listed as the directory; env.dataset(name) presumably
# builds the path of an entry inside it -- confirm in env.
data_source = [
    env.dataset(f)
    for f in sorted(os.listdir(env.dataset()))
    if f.endswith(".wav")
]

cm = ConvModel(
    batch_size=30000,
    filter_len=150,
    filters_num=100,
    target_sr=3000,
    gamma=1e-03,
    strides=8,
    avg_window=5,
    lrate=1e-04
)

# Use the session as a context manager so it is closed once training,
# evaluation and serialization are finished, instead of leaking for the
# lifetime of the process.
with tf.Session() as sess:
    # NOTE(review): proportion=0.1 presumably sets the held-out fraction;
    # confirm against ConvModel.form_dataset.
    dataset = cm.form_dataset(data_source, proportion=0.1)
    # Third argument presumably the number of training epochs -- confirm.
    cm.train(sess, dataset, 10000)
    cm.evaluate_and_save(sess, dataset)
    cm.serialize(sess)
model_fname = env.run("model.ckpt") batch_size = 30000 L = 150 filters_num = 100 target_sr = 3000 gamma = 1e-03 epochs = 2000 lrate = 1e-04 k = 8 # filter strides avg_size = 5 sel = None cm = ConvModel(batch_size, L, filters_num, k, avg_size, lrate, gamma) sess = tf.Session() saver = tf.train.Saver() if os.path.exists(model_fname): print "Restoring from {}".format(model_fname) saver.restore(sess, model_fname) epochs = 0 else: sess.run(tf.initialize_all_variables()) def read_song(source_id): song_data_raw, source_sr = lr.load(data_source[source_id]) song_data = lr.resample(song_data_raw, source_sr, target_sr, scale=True) song_data = song_data[:song_data.shape[0]/10]