Example 1
def train():
    tr_tfrecords_lst, tr_num_batches = read_list_file("tr_tf",
                                                      FLAGS.batch_size)

    with tf.Graph().as_default():
        with tf.device('/cpu:0'):
            with tf.name_scope('input'):
                tr_mixed, tr_labels, tr_genders, tr_lengths = get_padded_batch(
                    tr_tfrecords_lst,
                    FLAGS.batch_size,
                    FLAGS.input_size * 2,
                    FLAGS.output_size * 2,
                    num_enqueuing_threads=FLAGS.num_threads,
                    num_epochs=FLAGS.max_epochs)

                tr_inputs = tf.slice(tr_mixed, [0, 0, 0],
                                     [-1, -1, FLAGS.input_size])
                tr_inputs = tf.reshape(tr_inputs, [-1, FLAGS.input_size])

        with tf.name_scope('model'):
            tr_model = LSTM(tr_inputs, tr_lengths)
            # tr_model = LSTM(FLAGS, tr_inputs, tr_labels,tr_lengths,tr_genders)
            # tr_model and val_model should share variables
            tf.get_variable_scope().reuse_variables()
        show_all_variables()
        init = tf.group(tf.global_variables_initializer(),
                        tf.local_variables_initializer())
        # Prevent exhausting all the gpu memories.
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        # config.gpu_options.per_process_gpu_memory_fraction = 0.5
        config.allow_soft_placement = True
        #sess = tf.InteractiveSession(config=config)
        sess = tf.Session(config=config)
        sess.run(init)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            # Cross validation before training.
            # lengths = sess.run(tr_lengths)
            # print('lengths shape: ')
            # print(np.shape(lengths))
            # print(lengths)
            # tr_mix = sess.run(tr_labels)
            # print('tr_mix shape: ')
            # print(np.shape(tr_mix))
            X_mean, X_std = train_one_epoch(sess, coord, tr_model,
                                            tr_num_batches)
            # print(np.shape(X_mean))
            np.savetxt('tr_mean_good1', X_mean, fmt="%s")
            np.savetxt('tr_std_good1', X_std, fmt="%s")

        except Exception as e:
            # Report exceptions to the coordinator.
            coord.request_stop(e)
        finally:
            # Conventional queue-runner shutdown: wait for threads, then close.
            coord.join(threads)
            sess.close()
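
In this example train_one_epoch() is used only to compute per-dimension mean and standard deviation for normalization, which are then written out with np.savetxt. Its implementation is not shown here; below is a minimal sketch of the kind of statistic accumulation it presumably performs, with illustrative names (RunningMoments, accumulate, finalize are not from the source):

import numpy as np


class RunningMoments(object):
    """Accumulates per-dimension mean/std over batches of frames (illustrative)."""

    def __init__(self, dim):
        self.count = 0
        self.sum = np.zeros(dim, dtype=np.float64)
        self.sum_sq = np.zeros(dim, dtype=np.float64)

    def accumulate(self, frames):
        # frames: [num_frames, dim] array of feature vectors.
        self.count += frames.shape[0]
        self.sum += frames.sum(axis=0)
        self.sum_sq += (frames ** 2).sum(axis=0)

    def finalize(self):
        mean = self.sum / self.count
        std = np.sqrt(self.sum_sq / self.count - mean ** 2)
        return mean, std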
Example 2
def _build(self):
    if not self._infer:
        input_sequence, target_sequence, length = get_padded_batch(
            file_list=self._tfrecords_lst,
            batch_size=self._batch_size,
            input_size=self._input_size,
            output_size=self._output_size,
            num_enqueuing_threads=self._num_enqueuing_threads,
            num_epochs=self._num_epochs,
            infer=self._infer)
        return input_sequence, target_sequence, length
    else:
        input_sequence, length = get_padded_batch(
            file_list=self._tfrecords_lst,
            batch_size=self._batch_size,
            input_size=self._input_size,
            output_size=self._output_size,
            num_enqueuing_threads=self._num_enqueuing_threads,
            num_epochs=self._num_epochs,
            infer=self._infer)
        return input_sequence, length
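
All of these examples feed batches through get_padded_batch(), whose implementation is not shown and whose signature varies slightly between them (some variants also return genders). As a point of reference only, here is a minimal TF 1.x queue-based sketch matching this example's two-mode return values; the feature keys, queue capacity, and helper name are assumptions, not taken from the source:

import tensorflow as tf


def get_padded_batch_sketch(file_list, batch_size, input_size, output_size,
                            num_enqueuing_threads=4, num_epochs=None,
                            infer=False, shuffle=True):
    """Illustrative TF 1.x queue-based reader; feature keys are assumptions."""
    file_queue = tf.train.string_input_producer(
        file_list, num_epochs=num_epochs, shuffle=shuffle)
    reader = tf.TFRecordReader()
    _, serialized = reader.read(file_queue)

    sequence_features = {
        'inputs': tf.FixedLenSequenceFeature([input_size], tf.float32),
    }
    if not infer:
        sequence_features['labels'] = tf.FixedLenSequenceFeature(
            [output_size], tf.float32)

    _, sequence = tf.parse_single_sequence_example(
        serialized, sequence_features=sequence_features)

    length = tf.shape(sequence['inputs'])[0]
    if infer:
        tensors = [sequence['inputs'], length]
    else:
        tensors = [sequence['inputs'], sequence['labels'], length]

    # dynamic_pad pads every sequence in a batch to the longest one.
    return tf.train.batch(tensors,
                          batch_size=batch_size,
                          num_threads=num_enqueuing_threads,
                          capacity=500 + 3 * batch_size,
                          dynamic_pad=True,
                          allow_smaller_final_batch=False)

dynamic_pad=True is what pads every sequence in a batch to the length of the longest one, which is why the callers also receive a length tensor for masking the padded frames.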
Example 3
def decode():
    """Decoding the inputs using current model."""
    tfrecords_lst, num_batches = read_list_file('tt_tf', FLAGS.batch_size)

    with tf.Graph().as_default():
        with tf.device('/cpu:0'):
            with tf.name_scope('input'):
                tt_mixed, tt_labels, tt_genders, tt_lengths = get_padded_batch(
                    tfrecords_lst,
                    FLAGS.batch_size,
                    FLAGS.input_size * 2,
                    FLAGS.output_size * 2,
                    num_enqueuing_threads=1,
                    num_epochs=1,
                    shuffle=False)
                tt_inputs = tf.slice(tt_mixed, [0, 0, 0],
                                     [-1, -1, FLAGS.input_size])
                tt_angles = tf.slice(tt_mixed, [0, 0, FLAGS.input_size],
                                     [-1, -1, -1])
        # Create two models with train_input and val_input individually.
        with tf.name_scope('model'):
            model = LSTM(FLAGS,
                         tt_inputs,
                         tt_labels,
                         tt_lengths,
                         tt_genders,
                         infer=True)

        init = tf.group(tf.global_variables_initializer(),
                        tf.local_variables_initializer())

        sess = tf.Session()

        sess.run(init)

        ckpt = tf.train.get_checkpoint_state(FLAGS.save_dir + '/nnet')
        if ckpt and ckpt.model_checkpoint_path:
            tf.logging.info("Restore from " + ckpt.model_checkpoint_path)
            model.saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            tf.logging.fatal("checkpoint not fou1nd.")
            sys.exit(-1)

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    #cmvn_filename = os.path.join(FLAGS.date_dir, "/train_cmvn.npz")
    #if os.path.isfile(cmvn_filename):
    #    cmvn = np.load(cmvn_filename)
    #else:
    #    tf.logging.fatal("%s not exist, exit now." % cmvn_filename)
    #    sys.exit(-1)

    data_dir = FLAGS.data_dir
    if not os.path.exists(data_dir):
        os.makedirs(data_dir)
    processed = 0
    try:
        for batch in xrange(num_batches):
            if coord.should_stop():
                break
            if FLAGS.assign == 'def':
                cleaned1, cleaned2, angles, lengths = sess.run(
                    [model._cleaned1, model._cleaned2, tt_angles, tt_lengths])
            else:
                x1, x2 = model.get_opt_output()
                # Also fetch the mixture phases and frame counts so that
                # 'angles' and 'lengths' are defined in this branch too.
                cleaned1, cleaned2, angles, lengths = sess.run(
                    [x1, x2, tt_angles, tt_lengths])
            spec1 = cleaned1 * np.exp(angles * 1j)
            spec2 = cleaned2 * np.exp(angles * 1j)
            #sequence = activations * cmvn['stddev_labels'] + \
            #    cmvn['mean_labels']
            for i in range(0, FLAGS.batch_size):
                tffilename = tfrecords_lst[i + processed]
                (_, name) = os.path.split(tffilename)
                (partname, _) = os.path.splitext(name)
                wav_name1 = data_dir + '/' + partname + '_1.wav'
                wav_name2 = data_dir + '/' + partname + '_2.wav'
                wav1 = istft(spec1[i, 0:lengths[i], :], size=256, shift=128)
                wav2 = istft(spec2[i, 0:lengths[i], :], size=256, shift=128)
                audiowrite(wav1, wav_name1, 8000, True, True)
                audiowrite(wav2, wav_name2, 8000, True, True)
            processed = processed + FLAGS.batch_size

            if batch % 50 == 0:
                print(batch)

    except Exception as e:
        # Report exceptions to the coordinator.
        coord.request_stop(e)
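
The istft and audiowrite helpers used above are external to this snippet. For orientation, here is a minimal overlap-add inverse STFT consistent with the size=256, shift=128 call, assuming a Hann synthesis window and squared-window normalization (both assumptions):

import numpy as np


def istft_sketch(spec, size=256, shift=128):
    """Overlap-add inverse STFT sketch; window and normalization are assumptions."""
    frames = np.fft.irfft(spec, n=size, axis=-1)   # [num_frames, size] real frames
    window = np.hanning(size)
    num_frames = frames.shape[0]
    out = np.zeros(shift * (num_frames - 1) + size)
    norm = np.zeros_like(out)
    for t in range(num_frames):
        start = t * shift
        out[start:start + size] += frames[t] * window
        norm[start:start + size] += window ** 2
    norm[norm < 1e-8] = 1.0                        # avoid division by zero at edges
    return out / norm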
Example 4
def train():
    tr_tfrecords_lst, tr_num_batches = read_list_file("tr_tf",
                                                      FLAGS.batch_size)
    val_tfrecords_lst, val_num_batches = read_list_file(
        "cv_tf", FLAGS.batch_size)

    with tf.Graph().as_default():
        with tf.device('/cpu:0'):
            with tf.name_scope('input'):
                tr_mixed, tr_labels, tr_genders, tr_lengths = get_padded_batch(
                    tr_tfrecords_lst,
                    FLAGS.batch_size,
                    FLAGS.input_size * 2,
                    FLAGS.output_size * 2,
                    num_enqueuing_threads=FLAGS.num_threads,
                    num_epochs=FLAGS.max_epochs)

                val_mixed, val_labels, val_genders, val_lengths = get_padded_batch(
                    val_tfrecords_lst,
                    FLAGS.batch_size,
                    FLAGS.input_size * 2,
                    FLAGS.output_size * 2,
                    num_enqueuing_threads=FLAGS.num_threads,
                    num_epochs=FLAGS.max_epochs + 1)
                tr_inputs = tf.slice(tr_mixed, [0, 0, 0],
                                     [-1, -1, FLAGS.input_size])
                val_inputs = tf.slice(val_mixed, [0, 0, 0],
                                      [-1, -1, FLAGS.input_size])

        with tf.name_scope('model'):
            tr_model = LSTM(FLAGS, tr_inputs, tr_labels, tr_lengths,
                            tr_genders)
            # tr_model and val_model should share variables
            tf.get_variable_scope().reuse_variables()
            val_model = LSTM(FLAGS, val_inputs, val_labels, val_lengths,
                             val_genders)
        show_all_variables()
        init = tf.group(tf.global_variables_initializer(),
                        tf.local_variables_initializer())
        # Prevent exhausting all the gpu memories.
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        # config.gpu_options.per_process_gpu_memory_fraction = 0.5
        config.allow_soft_placement = True
        #sess = tf.InteractiveSession(config=config)
        sess = tf.Session(config=config)
        sess.run(init)
        if FLAGS.resume_training.lower() == 'true':
            ckpt = tf.train.get_checkpoint_state(FLAGS.save_dir + '/nnet')
            if ckpt and ckpt.model_checkpoint_path:
                tf.logging.info("restore from" + ckpt.model_checkpoint_path)
                tr_model.saver.restore(sess, ckpt.model_checkpoint_path)
                best_path = ckpt.model_checkpoint_path
            else:
                tf.logging.fatal("checkpoint not found")
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            # Cross validation before training.
            loss_prev = eval_one_epoch(sess, coord, val_model, val_num_batches)
            tf.logging.info("CROSSVAL PRERUN AVG.LOSS %.4F" % loss_prev)

            sess.run(tf.assign(tr_model.lr, FLAGS.learning_rate))
            for epoch in xrange(FLAGS.max_epochs):
                start_time = time.time()

                # Training
                tr_loss = train_one_epoch(sess, coord, tr_model,
                                          tr_num_batches)

                # Validation
                val_loss = eval_one_epoch(sess, coord, val_model,
                                          val_num_batches)

                end_time = time.time()
                # Determine checkpoint path
                ckpt_name = "nnet_iter%d_lrate%e_tr%.4f_cv%.4f" % (
                    epoch + 1, FLAGS.learning_rate, tr_loss, val_loss)
                ckpt_dir = FLAGS.save_dir + '/nnet'
                if not os.path.exists(ckpt_dir):
                    os.makedirs(ckpt_dir)
                ckpt_path = os.path.join(ckpt_dir, ckpt_name)
                # Relative loss between previous and current val_loss
                rel_impr = (loss_prev - val_loss) / loss_prev
                # Accept or reject new parameters
                if val_loss < loss_prev:
                    tr_model.saver.save(sess, ckpt_path)
                    # Logging train loss along with validation loss
                    loss_prev = val_loss
                    best_path = ckpt_path
                    tf.logging.info(
                        "ITERATION %d: TRAIN AVG.LOSS %.4f, (lrate%e) CROSSVAL"
                        " AVG.LOSS %.4f, %s (%s), TIME USED: %.2fs" %
                        (epoch + 1, tr_loss, FLAGS.learning_rate, val_loss,
                         "nnet accepted", ckpt_name,
                         end_time - start_time))
                else:
                    tr_model.saver.restore(sess, best_path)
                    tf.logging.info(
                        "ITERATION %d: TRAIN AVG.LOSS %.4f, (lrate%e) CROSSVAL"
                        " AVG.LOSS %.4f, %s, (%s), TIME USED: %.2fs" %
                        (epoch + 1, tr_loss, FLAGS.learning_rate, val_loss,
                         "nnet rejected", ckpt_name,
                         end_time - start_time))

                # Start halving when improvement is low
                if rel_impr < FLAGS.start_halving_impr:
                    FLAGS.learning_rate *= FLAGS.halving_factor
                    sess.run(tf.assign(tr_model.lr, FLAGS.learning_rate))

                # Stopping criterion
                if rel_impr < FLAGS.end_halving_impr:
                    if epoch < FLAGS.min_epochs:
                        tf.logging.info(
                            "we were supposed to finish, but we continue as "
                            "min_epochs : %s" % FLAGS.min_epochs)
                        continue
                    else:
                        tf.logging.info(
                            "finished, too small rel. improvement %g" %
                            rel_impr)
                        break
        except Exception as e:
            # Report exceptions to the coordinator.
            coord.request_stop(e)
        finally:
            # Conventional queue-runner shutdown: wait for threads, then close.
            coord.join(threads)
            sess.close()
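
The accept/halve/stop logic in this loop follows the usual newbob-style schedule: keep the new weights only if the cross-validation loss improved, halve the learning rate once the relative improvement drops below start_halving_impr, and stop once it drops below end_halving_impr (after min_epochs). A compact restatement, with placeholder thresholds since the actual FLAGS defaults are not shown:

def newbob_step(loss_prev, val_loss, lr, halving_factor=0.5,
                start_halving_impr=0.003, end_halving_impr=0.001):
    """Restates the schedule above; thresholds are placeholders, not FLAGS defaults.

    Returns (new_lr, accept, stop)."""
    rel_impr = (loss_prev - val_loss) / loss_prev
    accept = val_loss < loss_prev          # keep new weights only if CV improved
    if rel_impr < start_halving_impr:      # small improvement: halve the rate
        lr *= halving_factor
    stop = rel_impr < end_halving_impr     # tiny improvement: stop (after min_epochs)
    return lr, accept, stop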
Example 5
def decode():
    """Decoding the inputs using current model."""
    tfrecords_lst, num_batches = read_config_file('test_8k', FLAGS.batch_size)

    with tf.Graph().as_default():
        with tf.device('/cpu:0'):
            with tf.name_scope('input'):
                tt_mixed, tt_inputs, tt_labels1, tt_labels2, tt_lengths = get_padded_batch(
                    tfrecords_lst,
                    FLAGS.batch_size,
                    FLAGS.input_dim,
                    FLAGS.output_dim,
                    num_enqueuing_threads=1,
                    num_epochs=1,
                    shuffle=False)

        # Create two models with train_input and val_input individually.
        with tf.name_scope('model'):
            model = LSTM(FLAGS,
                         tt_inputs,
                         tt_mixed,
                         tt_labels1,
                         tt_labels2,
                         tt_lengths,
                         infer=True)

        init = tf.group(tf.global_variables_initializer(),
                        tf.local_variables_initializer())

        sess = tf.Session()

        sess.run(init)

        ckpt = tf.train.get_checkpoint_state(FLAGS.save_dir + '/nnet')
        if ckpt and ckpt.model_checkpoint_path:
            tf.logging.info("Restore from " + ckpt.model_checkpoint_path)
            model.saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            tf.logging.fatal("checkpoint not fou1nd.")
            sys.exit(-1)

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        #cmvn_filename = os.path.join(FLAGS.date_dir, "/train_cmvn.npz")
        #if os.path.isfile(cmvn_filename):
        #    cmvn = np.load(cmvn_filename)
        #else:
        #    tf.logging.fatal("%s not exist, exit now." % cmvn_filename)
        #    sys.exit(-1)

    data_dir = FLAGS.data_dir
    if not os.path.exists(data_dir):
        os.makedirs(data_dir)
    if (FLAGS.save_mask).lower() == 'true':
        mask_dir = data_dir + '/mask/'
        if not os.path.exists(mask_dir):
            os.makedirs(mask_dir)
    try:
        for batch in xrange(num_batches):
            if coord.should_stop():
                break
            if FLAGS.assign == 'def':
                cleaned1, cleaned2, m1, m2 = sess.run([
                    model._cleaned1, model._cleaned2, model._activations1,
                    model._activations2
                ])
            else:
                x1, x2 = model.get_opt_output()
                cleaned1, cleaned2, m1, m2 = sess.run(
                    [x1, x2, model._activations1, model._activations2])

                #sequence = activations * cmvn['stddev_labels'] + \
                #    cmvn['mean_labels']

            tffilename = tfrecords_lst[batch]
            (_, name) = os.path.split(tffilename)
            (uttid, _) = os.path.splitext(name)
            (partname, _) = os.path.splitext(uttid)
            if (FLAGS.save_mask).lower() == 'true':
                np.savetxt(mask_dir + partname + '_1.mask', m1[0, :, :])
                np.savetxt(mask_dir + partname + '_2.mask', m2[0, :, :])
            np.save(data_dir + '/' + partname + '_1.wav', cleaned1[0, :, :])
            np.save(data_dir + '/' + partname + '_2.wav', cleaned2[0, :, :])
            if batch % 500 == 0:
                print(batch)

    except Exception as e:
        # Report exceptions to the coordinator.
        coord.request_stop(e)
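
One caveat about the output step above: np.save appends '.npy' to file names that do not already end in it, so the loop writes NumPy magnitude arrays named like 'utt_1.wav.npy' rather than playable audio; a separate inverse STFT step (as in Example 3) is still needed to obtain waveforms. A small self-contained check of that behavior (path and shapes are made up):

import os
import tempfile

import numpy as np

# Fake [frames, bins] magnitude array standing in for cleaned1[0, :, :].
mag = np.abs(np.random.randn(100, 129)).astype(np.float32)
out_path = os.path.join(tempfile.mkdtemp(), 'utt_1.wav')
np.save(out_path, mag)                    # actually creates out_path + '.npy'
restored = np.load(out_path + '.npy')
assert np.allclose(mag, restored)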