def split_train_valid_and_test_files_fn(args):
    """
    Generate a non-overlapping training/validation/test partition of the MAPS dataset.

    After downloading and unzipping the MAPS dataset,
    1. point args.data_root to the root directory of the MAPS dataset,
    2. populate test_dirs with the actual directories of the close and the ambient settings generated by
       the Disklavier piano,
    3. and populate train_dirs with the actual directories of the other 7 settings generated by the synthesizer.
    For example:
    test_dirs = ['ENSTDkCl/MUS']
    train_dirs = ['AkPnBcht/MUS', 'AkPnBsdf/MUS', 'AkPnCGdD/MUS', 'AkPnStgb/MUS',
                  'SptkBGAm/MUS', 'SptkBGCl/MUS', 'StbgTGd2/MUS']
    """
    # Put the "close" test files in one directory, ".../maps/ENSTDkCl",
    # and all training and validation files in another, ".../maps/train".
    test_dirs = ['ENSTDkCl/MUS']
    train_dirs = ['train']
    maps_dir = args.data_root

    test_files = []
    for directory in test_dirs:
        path = os.path.join(maps_dir, directory)
        path = os.path.join(path, '*.wav')
        wav_files = glob.glob(path)
        test_files += wav_files

    test_ids = {MiscFns.filename_to_id(wav_file) for wav_file in test_files}
    print('test_ids={}'.format(len(test_ids)))
    #assert len(test_ids) == 53

    training_files = []
    validation_files = []
    for directory in train_dirs:
        path = os.path.join(maps_dir, directory)
        path = os.path.join(path, '*.wav')
        wav_files = glob.glob(path)
        for wav_file in wav_files:
            me_id = MiscFns.filename_to_id(wav_file)
            if me_id not in test_ids:
                training_files.append(wav_file)
            else:
                validation_files.append(wav_file)

    print('train={0}, test={1}, valid={2}'.format(len(training_files),
                                                  len(test_files),
                                                  len(validation_files)))

    return dict(training=training_files,
                test=test_files,
                validation=validation_files)
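
# Example usage (a minimal sketch; assumes `args` carries a `data_root`
# attribute pointing at the MAPS root, e.g. an argparse.Namespace):
#
#   from argparse import Namespace
#   file_names = split_train_valid_and_test_files_fn(Namespace(data_root='/path/to/maps'))
#   print(len(file_names['training']), len(file_names['test']), len(file_names['validation']))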


def _dataset_iter_fn(name, args):
    """dataset generator"""
    assert name in ('validation', 'training', 'test')
    file_names = split_train_valid_and_test_files_fn(args)
    logging.debug('{} - enter generator'.format(name))

    if name == 'test' and args.test_with_30_secs:
        _duration = 30
    else:
        _duration = None
    logging.debug('{} - generate spectrograms and labels'.format(name))
    dataset = []
    for file_idx, wav_file_name in enumerate(file_names[name]):
        print('file_idx={}'.format(file_idx))
        logging.debug('{}/{} - {}'.format(
            file_idx + 1, len(file_names[name]),
            os.path.basename(wav_file_name))
        )
        samples, sr = librosa.load(
            mono=False, path=wav_file_name, sr=16000, duration=_duration, dtype=np.float32)

        assert sr == 16000
        assert samples.shape[0] == 2
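        # Compute a log-filterbank spectrogram for each of the two stereo
        # channels and stack them along the last axis -> (num_frames, 229, 2).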
        spectrogram = []
        for ch in range(2):
            sg = MiscFns.spectrogram_fn(
                samples=samples[ch],
                log_filter_bank_basis=MiscFns.log_filter_bank_fn(),
                spec_stride=512
            )

            spectrogram.append(sg)

        spectrogram = np.stack(spectrogram, axis=-1)
        assert spectrogram.shape[1:] == (229, 2)
        mid_file_name = wav_file_name[:-4] + '.mid'
        # spectrogram.shape[0]: the number of frames
        label = MiscFns.label_fn(mid_file_name=mid_file_name, num_frames=spectrogram.shape[0], spec_stride=512)
        # print('label.shape={}'.format(label.shape))
        dataset.append([spectrogram, label])

        logging.debug('number of frames - {}'.format(spectrogram.shape[0]))

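    # Split every recording into segments of at most 900 frames so that, for
    # training, segments from different recordings can be shuffled together.
    # E.g. num_frames=2000 -> segments (0, 900), (900, 1800), (1800, 2000).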
    rec_start_end_for_shuffle = []
    for rec_idx, rec_dict in enumerate(dataset):
        num_frames = len(rec_dict[0])
        split_frames = list(range(0, num_frames + 1, 900))
        if split_frames[-1] != num_frames:
            split_frames.append(num_frames)
        start_end_frame_pairs = zip(split_frames[:-1], split_frames[1:])
        rec_start_end_idx_list = [[rec_idx] + list(start_end_pair) for start_end_pair in start_end_frame_pairs]
        rec_start_end_for_shuffle += rec_start_end_idx_list

    if name == 'training':
        np.random.shuffle(rec_start_end_for_shuffle)

    new_dataset = []
    for rec_idx, start_frame, end_frame in rec_start_end_for_shuffle:
        rec_dict = dataset[rec_idx]
        new_spectrogram = rec_dict[0][start_frame:end_frame]
        new_label = rec_dict[1][start_frame:end_frame]
        new_dataset.append([new_spectrogram, new_label, rec_idx])

    return new_dataset
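

# Each element of the returned list is [spectrogram_chunk, label_chunk, rec_idx].
# A quick sanity check (a sketch, reusing `args` from the example above):
#
#   for spec, label, rec_idx in _dataset_iter_fn('test', args):
#       assert spec.shape[0] == label.shape[0] <= 900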
def main():
    warnings.simplefilter("ignore", ResourceWarning)
    MODEL_DICT = {}
    MODEL_DICT['config'] = Config()  # generate configurations

    # generate models
    #for name in ('training', 'validation', 'test'):
    for name in ('training', 'test'):
        MODEL_DICT[name] = Model(config=MODEL_DICT['config'], name=name)

    # placeholder for auxiliary information
    aug_info_pl = tf.placeholder(dtype=tf.string, name='aug_info_pl')
    aug_info_summary = tf.summary.text('aug_info_summary', aug_info_pl)

    os.environ['CUDA_VISIBLE_DEVICES'] = str(MODEL_DICT['config'].gpu_id)

    with tf.Session(config=MODEL_DICT['config'].config) as sess:
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess, coord)
        # define model saver
        if MODEL_DICT['config'].train_or_inference.inference is not None or \
                MODEL_DICT['config'].train_or_inference.from_saved is not None or \
                MODEL_DICT['config'].train_or_inference.model_prefix is not None:
            MODEL_DICT['model_saver'] = tf.train.Saver(max_to_keep=200)

            logging.info('saved/restored variables:')
            for idx, var in enumerate(MODEL_DICT['model_saver']._var_list):
                logging.info('{}\t{}'.format(idx, var.op.name))

        # define summary writers
        summary_writer_dict = {}
        #for training_valid_or_test in ('training', 'validation', 'test'):
        for training_valid_or_test in ('training', 'test'):
            if training_valid_or_test == 'training':
                summary_writer_dict[training_valid_or_test] = tf.summary.FileWriter(
                    os.path.join(MODEL_DICT['config'].tb_dir, training_valid_or_test),
                    sess.graph
                )
            else:
                summary_writer_dict[training_valid_or_test] = tf.summary.FileWriter(
                    os.path.join(MODEL_DICT['config'].tb_dir, training_valid_or_test)
                )

        aug_info = []
        if MODEL_DICT['config'].train_or_inference.inference is not None:
            aug_info.append('inference with {}'.format(MODEL_DICT['config'].train_or_inference.inference))
            aug_info.append('inference with only the first 30 secs - {}'.format(MODEL_DICT['config'].test_with_30_secs))
        elif MODEL_DICT['config'].train_or_inference.from_saved is not None:
            aug_info.append('continue training from {}'.format(MODEL_DICT['config'].train_or_inference.from_saved))
        aug_info.append('learning rate - {}'.format(MODEL_DICT['config'].learning_rate))
        aug_info.append('tb dir - {}'.format(MODEL_DICT['config'].tb_dir))
        aug_info.append('debug mode - {}'.format(MODEL_DICT['config'].debug_mode))
        aug_info.append('batch size - {}'.format(MODEL_DICT['config'].batch_size))
        aug_info.append('num of batches per epoch - {}'.format(MODEL_DICT['config'].batches_per_epoch))
        aug_info.append('num of epochs - {}'.format(MODEL_DICT['config'].num_epochs))
        aug_info.append('training start time - {}'.format(datetime.datetime.now()))
        aug_info = '\n\n'.join(aug_info)
        logging.info(aug_info)
        summary_writer_dict['training'].add_summary(sess.run(aug_info_summary, feed_dict={aug_info_pl: aug_info}))

        logging.info('global vars -')
        for idx, var in enumerate(tf.global_variables()):
            logging.info("{}\t{}\t{}".format(idx, var.name, var.shape))

        logging.info('local vars -')
        for idx, var in enumerate(tf.local_variables()):
            logging.info('{}\t{}'.format(idx, var.name))

        # collect the ops, per-epoch statistics, and TB summaries to run for each split
        op_stat_summary_dict = {}
        #for training_valid_or_test in ('training', 'validation', 'test'):
        for training_valid_or_test in ('training', 'test'):
            op_list = []
            if training_valid_or_test == 'training':
                op_list.append(MODEL_DICT[training_valid_or_test].op_dict['training_op'])
                #op_list.append(MODEL_DICT[training_valid_or_test].op_dict['update_op_after_each_batch'])
                op_stat_summary_dict[training_valid_or_test] = dict(
                    op_list=op_list
                )

            else:
                op_list.append(MODEL_DICT[training_valid_or_test].op_dict['update_op_after_each_batch'])
                stat_op_dict = MODEL_DICT[training_valid_or_test].op_dict['statistics_after_each_epoch']
                tb_summary_dict = MODEL_DICT[training_valid_or_test].op_dict['tb_summary']
                op_stat_summary_dict[training_valid_or_test] = dict(
                    op_list=op_list,
                    stat_op_dict=stat_op_dict,
                    tb_summary_dict=tb_summary_dict
                )


        if MODEL_DICT['config'].train_or_inference.inference is not None:  # inference
            save_path = os.path.join('saved_model', MODEL_DICT['config'].train_or_inference.inference)
            print('save_path:{}'.format(save_path))
            MODEL_DICT['model_saver'].restore(sess, save_path)

            logging.info('do inference ...')
            # initialize local variables for storing statistics
            sess.run(tf.initializers.variables(tf.local_variables()))
            # initialize dataset iterator
            sess.run(MODEL_DICT['test'].reinitializable_iter_for_dataset.initializer)

            op_list = op_stat_summary_dict['test']['op_list']
            stat_op_dict = op_stat_summary_dict['test']['stat_op_dict']
            tb_summary_image = op_stat_summary_dict['test']['tb_summary_dict']['image']
            tb_summary_stats = op_stat_summary_dict['test']['tb_summary_dict']['statistics']

            batch_idx = 0
            op_list_with_image_summary = [tb_summary_image] + op_list
            logging.info('batch - {}'.format(batch_idx + 1))
            tmp = sess.run(op_list_with_image_summary)
            images = tmp[0]
            summary_writer_dict['test'].add_summary(images, 0)

            while True:
                try:
                    sess.run(op_list)
                except tf.errors.OutOfRangeError:
                    break
                else:
                    batch_idx += 1
                    logging.info('batch - {}'.format(batch_idx + 1))
            # write summary data
            summary_writer_dict['test'].add_summary(sess.run(tb_summary_stats), 0)

            # generate statistics
            stat_dict = sess.run(stat_op_dict)

            # display statistics
            MiscFns.display_stat_dict_fn(stat_dict)

        elif MODEL_DICT['config'].train_or_inference.from_saved is not None:  # restore saved model for training
            save_path = os.path.join('saved_model', MODEL_DICT['config'].train_or_inference.from_saved)
            MODEL_DICT['model_saver'].restore(sess, save_path)

            # reproduce statistics
            logging.info('reproduce results ...')
            sess.run(tf.initializers.variables(tf.local_variables()))
            #for valid_or_test in ('validation', 'test'):
            for valid_or_test in ['test']:
                sess.run(MODEL_DICT[valid_or_test].reinitializable_iter_for_dataset.initializer)
            #for valid_or_test in ('validation', 'test'):
            for valid_or_test in ['test']:
                logging.info(valid_or_test)

                op_list = op_stat_summary_dict[valid_or_test]['op_list']
                stat_op_dict = op_stat_summary_dict[valid_or_test]['stat_op_dict']
                statistical_summary = op_stat_summary_dict[valid_or_test]['tb_summary_dict']['statistics']
                image_summary = op_stat_summary_dict[valid_or_test]['tb_summary_dict']['image']

                batch_idx = 0
                op_list_with_image_summary = [image_summary] + op_list
                logging.info('batch - {}'.format(batch_idx + 1))
                tmp = sess.run(op_list_with_image_summary)
                images = tmp[0]
                summary_writer_dict[valid_or_test].add_summary(images, 0)

                while True:
                    try:
                        sess.run(op_list)
                    except tf.errors.OutOfRangeError:
                        break
                    else:
                        batch_idx += 1
                        logging.info('batch - {}'.format(batch_idx + 1))

                summary_writer_dict[valid_or_test].add_summary(sess.run(statistical_summary), 0)

                stat_dict = sess.run(stat_op_dict)

                MiscFns.display_stat_dict_fn(stat_dict)
        else:  # train from scratch and need to initialize global variables
            sess.run(tf.initializers.variables(tf.global_variables()))

        if MODEL_DICT['config'].train_or_inference.inference is None:
            for training_valid_test_epoch_idx in range(MODEL_DICT['config'].num_epochs):
                logging.info('\n\nepoch - {}/{}'.format(training_valid_test_epoch_idx + 1, MODEL_DICT['config'].num_epochs))

                sess.run(tf.initializers.variables(tf.local_variables()))

                #for valid_or_test in ('validation', 'test'):
                for valid_or_test in ['test']:
                    sess.run(MODEL_DICT[valid_or_test].reinitializable_iter_for_dataset.initializer)

                #for training_valid_or_test in ('training', 'validation', 'test'):
                for training_valid_or_test in ('training', 'test'):
                    logging.info(training_valid_or_test)

                    op_list = op_stat_summary_dict[training_valid_or_test]['op_list']
                    if training_valid_or_test == 'test':
                        stat_op_dict = op_stat_summary_dict[training_valid_or_test]['stat_op_dict']
                        statistical_summary = op_stat_summary_dict[training_valid_or_test]['tb_summary_dict']['statistics']
                        image_summary = op_stat_summary_dict[training_valid_or_test]['tb_summary_dict']['image']


                    if training_valid_or_test == 'training':
                        for batch_idx in range(MODEL_DICT['config'].batches_per_epoch):
                            if batch_idx % 1000 == 0:
                                print('batch_idx={}'.format(batch_idx))

                            sess.run(op_list)

                            #print('batch_idx={}'.format(batch_idx))
                            logging.debug('batch - {}/{}'.format(batch_idx + 1, MODEL_DICT['config'].batches_per_epoch))

                        '''
                        summary_writer_dict[training_valid_or_test].add_summary(
                            sess.run(image_summary), training_valid_test_epoch_idx + 1)
                        summary_writer_dict[training_valid_or_test].add_summary(
                            sess.run(statistical_summary), training_valid_test_epoch_idx + 1)
                        param_summary = MODEL_DICT[training_valid_or_test].op_dict['tb_summary']['parameter']
                        summary_writer_dict[training_valid_or_test].add_summary(
                            sess.run(param_summary), training_valid_test_epoch_idx + 1)

                        stat_dict = sess.run(stat_op_dict)
                        '''

                        if MODEL_DICT['config'].train_or_inference.model_prefix is not None:
                            save_path = MODEL_DICT['config'].train_or_inference.model_prefix + \
                                        '_' + 'epoch_{}_of_{}'.format(training_valid_test_epoch_idx + 1,
                                                                      MODEL_DICT['config'].num_epochs)
                            save_path = os.path.join('saved_model', save_path)
                            save_path = MODEL_DICT['model_saver'].save(
                                sess=sess,
                                save_path=save_path,
                                global_step=None,
                                write_meta_graph=False
                            )
                            logging.info('model saved to {}'.format(save_path))
                    else:
                        batch_idx = 0
                        op_list_with_image_summary = [image_summary] + op_list
                        logging.debug('batch - {}'.format(batch_idx + 1))
                        tmp = sess.run(op_list_with_image_summary)
                        images = tmp[0]
                        summary_writer_dict[training_valid_or_test].add_summary(
                            images,
                            training_valid_test_epoch_idx + 1
                        )

                        while True:
                            try:
                                sess.run(op_list)
                            except tf.errors.OutOfRangeError:
                                break
                            else:
                                batch_idx += 1
                                logging.debug('batch - {}'.format(batch_idx + 1))

                        summary_writer_dict[training_valid_or_test].add_summary(
                            sess.run(statistical_summary),
                            training_valid_test_epoch_idx + 1
                        )

                        stat_dict = sess.run(stat_op_dict)

                        MiscFns.display_stat_dict_fn(stat_dict)


        msg = 'training end time - {}'.format(datetime.datetime.now())
        logging.info(msg)
        summary_writer_dict['training'].add_summary(sess.run(aug_info_summary, feed_dict={aug_info_pl: msg}))

        #for training_valid_or_test in ('training', 'validation', 'test'):
        for training_valid_or_test in ('training', 'test'):
            summary_writer_dict[training_valid_or_test].close()
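

# Entry-point sketch (an assumption; the original snippet does not show how
# main() is invoked):
#
#   if __name__ == '__main__':
#       main()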


class Config(object):
    def __init__(self):
        self.debug_mode = DEBUG
        self.test_with_30_secs = False
        self.gpu_id = GPU_ID

        self.num_epochs = 50
        self.batches_per_epoch = 5000
        self.batch_size = 4
        self.learning_rate = 1e-4

        self.train_or_inference = Namespace(
            inference=None,
            from_saved=None,
            model_prefix='net'
        )
        # inference: point to the saved model for inference
        # from_saved: point to the saved model from which the training continues
        # model_prefix: the prefix used when saving the model
        # precedence: if inference is not None, do inference; elif from_saved is not None, continue
        #             training from the saved model; else train from scratch.
        #             If model_prefix is None, the model will not be saved.
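        # E.g. (illustrative): setting inference='net_epoch_50_of_50' with
        # from_saved=None restores ./saved_model/net_epoch_50_of_50 and runs
        # inference only.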

        self.tb_dir = 'tb_inf'
        # the directory for saving tensorboard data including performance measures, model parameters, and the model itself

        # check if tb_dir exists
        #assert self.tb_dir is not None
        tmp_dirs = glob.glob('./*/')
        tmp_dirs = [s[2:-1] for s in tmp_dirs]
        if self.tb_dir in tmp_dirs:
            raise EnvironmentError('\n'
                                   'directory {} for storing tensorboard data already exists!\n'
                                   'Cannot proceed.\n'
                                   'Please specify a different directory.'.format(self.tb_dir)
                                   )

        # check if model exists
        if self.train_or_inference.inference is None and self.train_or_inference.model_prefix is not None:
            if os.path.isdir('./saved_model'):
                tmp_prefixes = glob.glob('./saved_model/*')
                prog = re.compile(r'./saved_model/(.+?)_')
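                # The regex captures the prefix of a saved model file, e.g.
                # (illustrative) './saved_model/net_epoch_3_of_50.index' -> 'net'.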
                tmp = []
                for file_name in tmp_prefixes:
                    try:
                        prefix = prog.match(file_name).group(1)
                    except AttributeError:
                        pass
                    else:
                        tmp.append(prefix)
                tmp_prefixes = set(tmp)
                if self.train_or_inference.model_prefix in tmp_prefixes:
                    raise EnvironmentError('\n'
                                           'models with prefix {} already exist.\n'
                                           'Please specify a different prefix.'.format(self.train_or_inference.model_prefix)
                                           )

        config = tf.ConfigProto(allow_soft_placement=False,
                                inter_op_parallelism_threads=1,
                                intra_op_parallelism_threads=1)

        config.gpu_options.allow_growth = True
        config.gpu_options.per_process_gpu_memory_fraction = 0.8
        self.config = config

        self.file_names = MiscFns.split_train_valid_and_test_files_fn()

        # in debug mode the numbers of recordings for training, validation and test are minimized for debugging purposes
        if self.debug_mode:
            # for name in ('training', 'validation', 'test'):
            #     if name == 'training':
            #         del self.file_names[name][2:]
            #     else:
            #         del self.file_names[name][1:]
            self.file_names['training'] = self.file_names['training'][:2]
            self.file_names['validation'] = self.file_names['validation'][:1]
            self.file_names['test'] = self.file_names['test'][:2]

            self.num_epochs = 3
            self.batches_per_epoch = 5
            self.gpu_id = 0

        # in inference mode, the numbers of recordings for training and validation are minimized
        if self.train_or_inference.inference is not None:
            for name in ('training', 'validation'):
                del self.file_names[name][1:]

        # the logarithmic filterbank
        self.log_filter_bank = MiscFns.log_filter_bank_fn()