def _select_spec_channels(spec, data_type):
    """Return the model-input channels of ``spec`` for ``data_type``.

    The last bin along axis 2 is always dropped (``[:, :, :-1, ...]``).
    The channel layout is presumably 0 = real, 1 = imaginary,
    2 = magnitude, 3 = phase, judging by the ``data_type`` names.
    NOTE(review): confirm this against ``dataset.prepare_datasets``.

    Args:
        spec: 4-D spectrogram tensor produced by the dataset iterator.
        data_type: One of 'mag', 'mag_phase', 'mag_phase_diff', 'real_imag',
            'mag_real_imag', 'mag_phase2' or 'mag_phase_real_imag'.

    Returns:
        A 4-D tensor containing the selected channels.

    Raises:
        ValueError: if ``data_type`` is not one of the supported values.
    """
    if data_type == 'mag':
        return tf.expand_dims(spec[:, :, :-1, 2], 3)
    if data_type in ('mag_phase', 'mag_phase_diff'):
        return spec[:, :, :-1, 2:4]
    if data_type == 'real_imag':
        return spec[:, :, :-1, 0:2]
    if data_type in ('mag_real_imag', 'mag_phase2'):
        # Channel 2 first, followed by channels 0 and 1.
        return tf.concat([
            tf.expand_dims(spec[:, :, :-1, 2], 3),
            spec[:, :, :-1, 0:2]
        ], 3)
    if data_type == 'mag_phase_real_imag':
        return spec[:, :, :-1, :]
    raise ValueError('Unsupported data_type: {dt!r}'.format(dt=data_type))


def do_experiment(model_config):
    """Run one train/test experiment for ``audio_models.MagnitudeModel``.

    Resets the default graph and builds the train/validation/test datasets.
    Opens a session restricted to ``model_config['GPU']`` and creates a
    feedable iterator that is switched between datasets via a string handle.
    Builds the model from the channels selected by
    ``model_config['data_type']``, then initialises variables and optionally
    restores a checkpoint. Finally runs an optional initialisation test,
    training and a final test, and prints the resulting losses.

    Config keys read directly here: 'GPU', 'data_type', 'model_variant',
    'learning_rate', 'phase_weight', 'loading', 'model_base_dir' and
    'checkpoint_to_load' (only when loading), 'log_dir' and
    'initialisation_test'. The whole dict is also passed to
    ``dataset.prepare_datasets``, ``train`` and ``test``.

    NOTE(review): a second ``do_experiment`` definition also appears in this
    source. If both live in the same module, the later one shadows this one.

    Args:
        model_config: Experiment configuration dict.

    Raises:
        ValueError: if ``model_config['data_type']`` is not supported.
    """
    tf.reset_default_graph()
    experiment_id = ex.current_run._id
    print('Experiment ID: {eid}'.format(eid=experiment_id))

    # Prepare data
    print('Preparing dataset')
    train_data, val_data, test_data = dataset.prepare_datasets(model_config)
    print('Dataset ready')

    # Start session, restricted to the configured GPU.
    tf_config = tf.ConfigProto()
    tf_config.gpu_options.visible_device_list = str(model_config['GPU'])
    sess = tf.Session(config=tf_config)
    print('Session started')

    writer = None
    try:
        # Create iterators: a single feedable iterator whose source dataset
        # is chosen at run time by feeding one of the string handles below.
        handle = tf.placeholder(tf.string, shape=[])
        iterator = tf.data.Iterator.from_string_handle(
            handle, train_data.output_types, train_data.output_shapes)
        (mixed_spec, voice_spec, background_spec,
         mixed_audio, voice_audio, background_audio) = iterator.get_next()
        training_iterator = train_data.make_initializable_iterator()
        validation_iterator = val_data.make_initializable_iterator()
        testing_iterator = test_data.make_initializable_iterator()
        training_handle = sess.run(training_iterator.string_handle())
        validation_handle = sess.run(validation_iterator.string_handle())
        testing_handle = sess.run(testing_iterator.string_handle())
        print('Iterators created')

        # Create variable placeholders and model
        is_training = tf.placeholder(shape=(), dtype=bool)
        # Channel 3 of the mixture (last bin of axis 2 dropped) is passed to
        # the model as the phase input.
        mixed_phase = tf.expand_dims(mixed_spec[:, :, :-1, 3], 3)
        print('Creating model')
        data_type = model_config['data_type']
        mixed_input = _select_spec_channels(mixed_spec, data_type)
        voice_input = _select_spec_channels(voice_spec, data_type)
        model = audio_models.MagnitudeModel(
            mixed_input, voice_input, mixed_phase, mixed_audio, voice_audio,
            background_audio, model_config['model_variant'], is_training,
            model_config['learning_rate'], model_config['data_type'],
            model_config['phase_weight'], name='Magnitude_Model')
        sess.run(tf.global_variables_initializer())

        if model_config['loading']:
            print('Loading checkpoint')
            checkpoint = os.path.join(model_config['model_base_dir'],
                                      model_config['checkpoint_to_load'])
            restorer = tf.train.Saver()
            restorer.restore(sess, checkpoint)

        # Summaries
        model_folder = str(experiment_id)
        writer = tf.summary.FileWriter(
            os.path.join(model_config["log_dir"], model_folder),
            graph=sess.graph)

        # Get baseline metrics at initialisation
        test_count = 0
        if model_config['initialisation_test']:
            print('Running initialisation test')
            initial_test_loss, test_count = test(
                sess, model, model_config, handle, testing_iterator,
                testing_handle, test_count, experiment_id)

        # Train the model
        model = train(sess, model, model_config, model_folder, handle,
                      training_iterator, training_handle,
                      validation_iterator, validation_handle, writer)

        # Test trained model
        mean_test_loss, test_count = test(
            sess, model, model_config, handle, testing_iterator,
            testing_handle, test_count, experiment_id)
        print('{ts}:\n\tAll done with experiment {exid}!'.format(
            ts=datetime.datetime.now(), exid=experiment_id))
        if model_config['initialisation_test']:
            print('\tInitial test loss: {init}'.format(init=initial_test_loss))
        print('\tFinal test loss: {final}'.format(final=mean_test_loss))
    finally:
        # Flush pending summaries and release session resources even if
        # model construction, training or testing raises.
        if writer is not None:
            writer.close()
        sess.close()
def do_experiment(model_config):
    """Run one train/test experiment for a U-Net magnitude or complex model.

    Resets the default graph, builds the train/validation/test datasets and
    opens a default-configured session. Creates a feedable iterator that is
    switched between datasets via a string handle. When
    ``model_config['mag_phase']`` is truthy it builds
    ``audio_models.MagnitudeModel``; otherwise it builds
    ``audio_models.ComplexNumberModel``. It then initialises variables,
    optionally restores a checkpoint, runs an optional initialisation test,
    training and a final test, and prints the resulting losses.

    Config keys read directly here: 'mag_phase', 'model_variant', 'loading',
    'model_base_dir' and 'checkpoint_to_load' (only when loading), 'log_dir'
    and 'initialisation_test'. The whole dict is also passed to
    ``dataset.prepare_datasets``, ``train`` and ``test``.

    Args:
        model_config: Experiment configuration dict.
    """
    tf.reset_default_graph()
    experiment_id = ex.current_run._id
    print('Experiment ID: {eid}'.format(eid=experiment_id))

    # Prepare data
    print('Preparing dataset')
    train_data, val_data, test_data = dataset.prepare_datasets(model_config)
    print('Dataset ready')

    # Start session (default configuration; no GPU options are set here).
    sess = tf.Session()
    print('Session started')

    writer = None
    try:
        # Create iterators: a single feedable iterator whose source dataset
        # is chosen at run time by feeding one of the string handles below.
        handle = tf.placeholder(tf.string, shape=[])
        iterator = tf.data.Iterator.from_string_handle(
            handle, train_data.output_types, train_data.output_shapes)
        mixed_spec, voice_spec, mixed_audio, voice_audio = iterator.get_next()
        training_iterator = train_data.make_initializable_iterator()
        validation_iterator = val_data.make_initializable_iterator()
        testing_iterator = test_data.make_initializable_iterator()
        training_handle = sess.run(training_iterator.string_handle())
        validation_handle = sess.run(validation_iterator.string_handle())
        testing_handle = sess.run(testing_iterator.string_handle())
        print('Iterators created')

        # Create variable placeholders and model
        is_training = tf.placeholder(shape=(), dtype=bool)
        if model_config['mag_phase']:
            # Channel 0 is used as magnitude and channel 1 as phase; the
            # last bin along axis 2 is dropped.
            # NOTE(review): channel layout inferred from variable names only.
            mixed_mag = tf.expand_dims(mixed_spec[:, :, :-1, 0], 3)
            mixed_phase = tf.expand_dims(mixed_spec[:, :, :-1, 1], 3)
            voice_mag = tf.expand_dims(voice_spec[:, :, :-1, 0], 3)
            print('Creating model')
            model = audio_models.MagnitudeModel(
                mixed_mag, voice_mag, mixed_phase, mixed_audio, voice_audio,
                model_config['model_variant'], is_training,
                name='U_Net_Model')
        else:
            # All channels are kept; only the last bin along axis 2 is dropped.
            mixed_spec_trim = mixed_spec[:, :, :-1, :]
            voice_spec_trim = voice_spec[:, :, :-1, :]
            print('Creating model')
            model = audio_models.ComplexNumberModel(
                mixed_spec_trim, voice_spec_trim, mixed_audio, voice_audio,
                model_config['model_variant'], is_training)
        sess.run(tf.global_variables_initializer())

        if model_config['loading']:
            # TODO: checkpoint restoring is believed to work but still needs
            # proper testing.
            print('Loading checkpoint')
            checkpoint = os.path.join(model_config['model_base_dir'],
                                      model_config['checkpoint_to_load'])
            restorer = tf.train.Saver()
            restorer.restore(sess, checkpoint)

        # Summaries
        model_folder = str(experiment_id)
        writer = tf.summary.FileWriter(
            os.path.join(model_config["log_dir"], model_folder),
            graph=sess.graph)

        # Get baseline metrics at initialisation
        test_count = 0
        if model_config['initialisation_test']:
            print('Running initialisation test')
            initial_test_loss, test_count = test(
                sess, model, model_config, handle, testing_iterator,
                testing_handle, writer, test_count, experiment_id)

        # Train the model
        model = train(sess, model, model_config, model_folder, handle,
                      training_iterator, training_handle,
                      validation_iterator, validation_handle, writer)

        # Test trained model
        mean_test_loss, test_count = test(
            sess, model, model_config, handle, testing_iterator,
            testing_handle, writer, test_count, experiment_id)
        print('{ts}:\n\tAll done with experiment {exid}!'.format(
            ts=datetime.datetime.now(), exid=experiment_id))
        if model_config['initialisation_test']:
            print('\tInitial test loss: {init}'.format(init=initial_test_loss))
        print('\tFinal test loss: {final}'.format(final=mean_test_loss))
    finally:
        # Flush pending summaries and release session resources even if
        # model construction, training or testing raises.
        if writer is not None:
            writer.close()
        sess.close()