Example #1
0
def get_num_batch(file_list, infer=False):
    """Count how many batches one pass over ``file_list`` produces.

    Builds a throw-away input pipeline in a fresh graph with ``get_batch``
    (left/right context 0, ``FLAGS.num_threads * 2`` enqueuing threads,
    1 epoch). It then pulls batches until the input queue raises
    ``tf.errors.OutOfRangeError``, counting every successful ``sess.run``.

    Args:
        file_list: Path of the list file; parsed with ``read_list``.
        infer: Forwarded to ``get_batch``. When True the pipeline is
            unpacked as ``(utt_id, inputs, _)``, otherwise as
            ``(inputs, labels)``.

    Returns:
        The number of batches read before the queue was exhausted.
    """
    data_list = read_list(file_list)
    counter = 0

    with tf.Graph().as_default():
        # A single call covers both modes; only the unpacking differs.
        batch = get_batch(data_list,
                          FLAGS.batch_size,
                          FLAGS.input_dim,
                          FLAGS.output_dim,
                          0,
                          0,
                          FLAGS.num_threads * 2,
                          1,
                          infer=infer)
        if infer:
            _, inputs, _ = batch
        else:
            inputs, _ = batch

        init = tf.group(tf.global_variables_initializer(),
                        tf.local_variables_initializer())
        # The context manager guarantees the session is closed even when an
        # unexpected error escapes the read loop or coord.join() re-raises a
        # queue-runner failure.
        with tf.Session() as sess:
            sess.run(init)
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            start = datetime.datetime.now()
            try:
                while not coord.should_stop():
                    sess.run([inputs])
                    counter += 1
            except tf.errors.OutOfRangeError:
                # Expected end of the single epoch: the queue is drained.
                duration = (datetime.datetime.now() - start).total_seconds()
                print('Number of batches is %d. Reading time is %.0fs.' %
                      (counter, duration))
            finally:
                # When done, ask the threads to stop.
                coord.request_stop()

            # Wait for threads to finish (re-raises any error they reported).
            coord.join(threads)

    return counter
Example #2
0
def train(cv_num_batch, tr_num_batch):
    with tf.Graph().as_default():
        with tf.device('/cpu:0'):
            with tf.name_scope('input'):
                tr_data_list = read_list(FLAGS.tr_list_file)
                tr_inputs, tr_labels = get_batch(tr_data_list,
                                                 FLAGS.batch_size,
                                                 FLAGS.input_dim,
                                                 FLAGS.output_dim,
                                                 FLAGS.left_context,
                                                 FLAGS.right_context,
                                                 FLAGS.num_threads,
                                                 FLAGS.max_epoches)

                cv_data_list = read_list(FLAGS.cv_list_file)
                cv_inputs, cv_labels = get_batch(cv_data_list,
                                                 FLAGS.batch_size,
                                                 FLAGS.input_dim,
                                                 FLAGS.output_dim,
                                                 FLAGS.left_context,
                                                 FLAGS.right_context,
                                                 FLAGS.num_threads,
                                                 FLAGS.max_epoches)

        devices = []
        for i in xrange(FLAGS.num_gpu):
            device_name = ("/gpu:%d" % i)
            print('Using device: ', device_name)
            devices.append(device_name)
        # Prevent exhausting all the gpu memories.
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        config.allow_soft_placement = True
        # execute the session
        with tf.Session(config=config) as sess:
            # Create two models with tr_inputs and cv_inputs individually.
            with tf.name_scope('model'):
                print("=======================================================")
                print("                  Build Train model                    ")
                print("=======================================================")
                tr_model = SEGAN(sess, FLAGS, devices, tr_inputs, tr_labels,
                                 cross_validation=False)
                # tr_model and val_model should share variables
                print("=======================================================")
                print("            Build Cross-Validation model               ")
                print("=======================================================")
                tf.get_variable_scope().reuse_variables()
                cv_model = SEGAN(sess, FLAGS, devices, cv_inputs, cv_labels,
                                 cross_validation=True)

            show_all_variables()

            init = tf.group(tf.global_variables_initializer(),
                            tf.local_variables_initializer())
            print("Initializing variables ...")
            sess.run(init)

            if tr_model.load(tr_model.save_dir):
                print("[*] Load SUCCESS")
            else:
                print("[!] Load failed, maybe begin a new model.")

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord)
            try:
                cv_g_loss, cv_d_loss = eval_one_epoch(sess, coord, cv_model,
                                                      cv_num_batch)
                print(("CROSSVAL PRERUN AVG.LOSS "
                    "%.4F(G_loss), %.4F(D_loss)") % (cv_g_loss, cv_d_loss))
                sys.stdout.flush()

                for epoch in xrange(FLAGS.max_epoches):
                    start = datetime.datetime.now()
                    tr_g_loss, tr_d_loss = train_one_epoch(sess, coord,
                            tr_model, tr_num_batch, epoch+1)
                    cv_g_loss, cv_d_loss = eval_one_epoch(sess, coord,
                            cv_model, cv_num_batch)
                    end = datetime.datetime.now()
                    print(("Epoch %02d: TRAIN AVG.LOSS "
                        "%.5F(G_loss, lrate(%e)), %.5F(D_loss, lrate(%e)), "
                        "CROSSVAL AVG.LOSS "
                        "%.5F(G_loss), %.5F(D_loss), TIME USED: %.2fmin") % (
                            epoch+1, tr_g_loss, tr_model.g_learning_rate,
                            tr_d_loss, tr_model.d_learning_rate,
                            cv_g_loss, cv_d_loss, (end-start).seconds/60.0))
                    sys.stdout.flush()
                    FLAGS.d_learning_rate *= FLAGS.halving_factor
                    FLAGS.g_learning_rate *= FLAGS.halving_factor
                    tr_model.d_learning_rate, tr_model.g_learning_rate = \
                        FLAGS.d_learning_rate, FLAGS.g_learning_rate
                    tr_model.save(tr_model.save_dir, epoch+1)
            except Exception, e:
                # Report exceptions to the coordinator.
                coord.request_stop(e)
            finally:
Example #3
0
def train(cv_num_batch, tr_num_batch):
    with tf.Graph().as_default():
        with tf.device('/cpu:0'):
            with tf.name_scope('input'):
                tr_data_list = read_list(FLAGS.tr_list_file)
                tr_inputs, tr_labels = get_batch(
                    tr_data_list, FLAGS.batch_size, FLAGS.input_dim,
                    FLAGS.output_dim, FLAGS.left_context, FLAGS.right_context,
                    FLAGS.num_threads, FLAGS.max_epoches)

                cv_data_list = read_list(FLAGS.cv_list_file)
                cv_inputs, cv_labels = get_batch(
                    cv_data_list, FLAGS.batch_size, FLAGS.input_dim,
                    FLAGS.output_dim, FLAGS.left_context, FLAGS.right_context,
                    FLAGS.num_threads, FLAGS.max_epoches)

        devices = []
        for i in xrange(FLAGS.num_gpu):
            device_name = ("/gpu:%d" % i)
            print('Using device: ', device_name)
            devices.append(device_name)
        # Prevent exhausting all the gpu memories.
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        config.allow_soft_placement = True
        # execute the session
        with tf.Session(config=config) as sess:
            # Create two models with tr_inputs and cv_inputs individually.
            with tf.name_scope('model'):
                print(
                    "=======================================================")
                print(
                    "|                Build Train model                    |")
                print(
                    "=======================================================")
                tr_model = GAN(sess,
                               FLAGS,
                               devices,
                               tr_inputs,
                               tr_labels,
                               cross_validation=False)
                # tr_model and val_model should share variables
                print(
                    "=======================================================")
                print(
                    "|           Build Cross-Validation model              |")
                print(
                    "=======================================================")
                tf.get_variable_scope().reuse_variables()
                cv_model = GAN(sess,
                               FLAGS,
                               devices,
                               cv_inputs,
                               cv_labels,
                               cross_validation=True)

            show_all_variables()

            init = tf.group(tf.global_variables_initializer(),
                            tf.local_variables_initializer())
            print("Initializing variables ...")
            sess.run(init)

            if tr_model.load(tr_model.save_dir):
                print("[*] Load SUCCESS")
            else:
                print("[!] Begin a new model.")

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            try:
                cv_d_rl_loss, cv_d_fk_loss, \
                cv_d_loss, cv_g_adv_loss, \
                cv_g_mse_loss, cv_g_l2_loss, \
                cv_g_loss = eval_one_epoch(sess, coord, cv_model, cv_num_batch, 0)
                print("CROSSVAL.LOSS PRERUN: "
                      "d_rl_loss = {:.5f}, d_fk_loss = {:.5f}, "
                      "d_loss = {:.5f}, g_adv_loss = {:.5f}, "
                      "g_mse_loss = {:.5f}, g_l2_loss = {:.5f}, "
                      "g_loss = {:.5f}".format(cv_d_rl_loss, cv_d_fk_loss,
                                               cv_d_loss, cv_g_adv_loss,
                                               cv_g_mse_loss, cv_g_l2_loss,
                                               cv_g_loss))
                sys.stdout.flush()

                g_loss_prev = cv_g_loss
                decay_steps = 1

                for epoch in range(FLAGS.max_epoches):
                    start = datetime.datetime.now()
                    tr_d_rl_loss, tr_d_fk_loss, \
                    tr_d_loss, tr_g_adv_loss, \
                    tr_g_mse_loss, tr_g_l2_loss, \
                    tr_g_loss = train_one_epoch(sess, coord,
                                                tr_model, tr_num_batch, epoch+1)
                    cv_d_rl_loss, cv_d_fk_loss, \
                    cv_d_loss, cv_g_adv_loss, \
                    cv_g_mse_loss, cv_g_l2_loss, \
                    cv_g_loss = eval_one_epoch(sess, coord,
                                               cv_model, cv_num_batch, epoch+1)
                    d_lr, g_lr = sess.run(
                        [tr_model.d_learning_rate, tr_model.g_learning_rate])

                    end = datetime.datetime.now()
                    print("Epoch {} (TRAIN AVG.LOSS): "
                          "d_rl_loss = {:.5f}, d_fk_loss = {:.5f}, "
                          "d_loss = {:.5f}, g_adv_loss = {:.5f}, "
                          "g_mse_loss = {:.5f}, g_l2_loss = {:.5f}, "
                          "g_loss = {:.5f}, "
                          "d_lr = {:.3e}, g_lr = {:.3e}\n"
                          "Epoch {} (CROSS AVG.LOSS): "
                          "d_rl_loss = {:.5f}, d_fk_loss = {:.5f}, "
                          "d_loss = {:.5f}, g_adv_loss = {:.5f}, "
                          "g_mse_loss = {:.5f}, g_l2_loss = {:.5f}, "
                          "g_loss = {:.5f}, "
                          "time = {:.2f} h".format(
                              epoch + 1, tr_d_rl_loss, tr_d_fk_loss, tr_d_loss,
                              tr_g_adv_loss, tr_g_mse_loss, tr_g_l2_loss,
                              tr_g_loss, d_lr, g_lr, epoch + 1, cv_d_rl_loss,
                              cv_d_fk_loss, cv_d_loss, cv_g_adv_loss,
                              cv_g_mse_loss, cv_g_l2_loss, cv_g_loss,
                              (end - start).seconds / 3600.0))
                    sys.stdout.flush()

                    g_loss_new = cv_g_loss
                    # Accept or reject new parameters
                    if g_loss_new < g_loss_prev:
                        tr_model.save(tr_model.save_dir, epoch + 1)
                        print("Epoch {}: Nnet Accepted. "
                              "Save model SUCCESS.".format(epoch + 1))
                        # Relative loss between previous and current val_loss
                        g_rel_impr = (g_loss_prev - g_loss_new) / g_loss_prev
                        g_loss_prev = g_loss_new
                    else:
                        print("Epoch {}: Nnet Rejected.".format(epoch + 1))
                        if tr_model.load(tr_model.save_dir):
                            print("[*] Load previous model SUCCESS.")
                            sys.stdout.flush()
                        else:
                            print("[!] Load failed. No checkpoint from {} to "
                                  "restore previous model. Exit now.".format(
                                      tr_model.save_dir))
                            sys.stdout.flush()
                            sys.exit(1)
                        # Relative loss between previous and current val_loss
                        g_rel_impr = (g_loss_prev - g_loss_new) / g_loss_prev

                    # Start decay when improvement is low (Exponential decay)
                    if g_rel_impr < FLAGS.start_decay_impr and \
                            epoch+1 >= FLAGS.keep_lr:
                        g_learning_rate = \
                                FLAGS.g_learning_rate * \
                                FLAGS.decay_factor ** (decay_steps)
                        d_learning_rate = \
                                FLAGS.d_learning_rate * \
                                FLAGS.decay_factor ** (decay_steps)
                        disc_noise_std = \
                                FLAGS.init_disc_noise_std * \
                                FLAGS.decay_factor ** (decay_steps)
                        sess.run(
                            tf.assign(tr_model.g_learning_rate,
                                      g_learning_rate))
                        sess.run(
                            tf.assign(tr_model.d_learning_rate,
                                      d_learning_rate))
                        sess.run(
                            tf.assign(tr_model.disc_noise_std, disc_noise_std))
                        decay_steps += 1

                    # Stopping criterion
                    if g_rel_impr < FLAGS.end_decay_impr:
                        if epoch < FLAGS.min_epoches:
                            print("Epoch %d: We were supposed to finish, "
                                  "but we continue as min_epoches %d" %
                                  (epoch + 1, FLAGS.min_epoches))
                            continue
                        else:
                            print("Epoch %d: Finished, too small relative "
                                  "G improvement %g" % (epoch + 1, g_rel_impr))
                            break

            except Exception, e:
                # Report exceptions to the coordinator.
                coord.request_stop(e)
            finally:
Example #4
0
def decode():
    """Decoding the inputs using current model."""
    tf.logging.info("Get TEST sets number.")
    num_batch = get_num_batch(FLAGS.test_list_file, infer=True)
    with tf.Graph().as_default():
        with tf.device('/cpu:0'):
            with tf.name_scope('input'):
                data_list = read_list(FLAGS.test_list_file)
                test_utt_id, test_inputs, _ = get_batch(
                    data_list,
                    batch_size=1,
                    input_size=FLAGS.input_dim,
                    output_size=FLAGS.output_dim,
                    left=FLAGS.left_context,
                    right=FLAGS.right_context,
                    num_enqueuing_threads=FLAGS.num_threads,
                    num_epochs=1,
                    infer=True)

        devices = []
        for i in xrange(FLAGS.num_gpu):
            device_name = ("/gpu:%d" % i)
            print('Using device: ', device_name)
            devices.append(device_name)

        # Prevent exhausting all the gpu memories.
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        config.allow_soft_placement = True
        # execute the session
        with tf.Session(config=config) as sess:
            # Create two models with tr_inputs and cv_inputs individually.
            with tf.name_scope('model'):
                model = GAN(sess,
                            FLAGS,
                            devices,
                            test_inputs,
                            labels=None,
                            cross_validation=True)

            show_all_variables()

            init = tf.group(tf.global_variables_initializer(),
                            tf.local_variables_initializer())
            print("Initializing variables ...")
            sess.run(init)

            if model.load(model.save_dir, moving_average=True):
                print("[*] Load SUCCESS")
            else:
                print("[!] Load failed. Checkpoint not found. Exit now.")
                sys.exit(1)

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            cmvn_filename = os.path.join(FLAGS.data_dir, "train_cmvn.npz")
            if os.path.isfile(cmvn_filename):
                cmvn = np.load(cmvn_filename)
            else:
                tf.logging.fatal("%s not exist, exit now." % cmvn_filename)
                sys.exit(1)

            out_dir_name = os.path.join(FLAGS.save_dir, 'test')
            if not os.path.exists(out_dir_name):
                os.makedirs(out_dir_name)

            write_scp_path = os.path.join(out_dir_name, 'feats.scp')
            write_ark_path = os.path.join(out_dir_name, 'feats.ark')
            writer = ArkWriter(write_scp_path)

            try:
                for batch in range(num_batch):
                    if coord.should_stop():
                        break
                    outputs = model.generator(test_inputs, None, reuse=True)
                    outputs = tf.reshape(outputs, [-1, model.output_dim])
                    utt_id, activations = sess.run([test_utt_id, outputs])
                    sequence = activations * cmvn['stddev_labels'] + \
                            cmvn['mean_labels']
                    save_result = np.vstack(sequence)
                    writer.write_next_utt(write_ark_path, utt_id[0],
                                          save_result)
                    tf.logging.info("Write inferred %s to %s" %
                                    (utt_id[0], write_ark_path))
            except Exception, e:
                # Report exceptions to the coordinator.
                coord.request_stop(e)
            finally:
def decode():
    """Decoding the inputs using current model."""
    tf.logging.info("Get TEST sets number.")
    num_batch = get_num_batch(FLAGS.test_list_file, infer=True)
    with tf.Graph().as_default():
        with tf.device('/cpu:0'):
            with tf.name_scope('input'):
                data_list = read_list(FLAGS.test_list_file)
                test_utt_id, test_inputs, _ = get_batch(
                    data_list,
                    batch_size=1,
                    input_size=FLAGS.input_dim,
                    output_size=FLAGS.output_dim,
                    left=FLAGS.left_context,
                    right=FLAGS.right_context,
                    num_enqueuing_threads=FLAGS.num_threads,
                    num_epochs=1,
                    infer=True)
                # test_inputs = tf.squeeze(test_inputs, axis=[0])
        devices = []
        for i in xrange(FLAGS.num_gpu):
            device_name = ("/gpu:%d" % i)
            print('Using device: ', device_name)
            devices.append(device_name)

        # Prevent exhausting all the gpu memories.
        config = tf.ConfigProto()
        config.gpu_options.per_process_gpu_memory_fraction = 0.4
        #config.gpu_options.allow_growth = True
        config.allow_soft_placement = True
        set_session(tf.Session(config=config))
        # execute the session
        with tf.Session(config=config) as sess:
            # Create two models with tr_inputs and cv_inputs individually.
            with tf.name_scope('model'):
                model = DNNTrainer(sess,
                                   FLAGS,
                                   devices,
                                   test_inputs,
                                   labels=None,
                                   cross_validation=True)

            show_all_variables()

            init = tf.group(tf.global_variables_initializer(),
                            tf.local_variables_initializer())
            print("Initializing variables ...")
            sess.run(init)

            if model.load(model.save_dir, moving_average=False):
                print("[*] Load Moving Average model SUCCESS")
            else:
                print("[!] Load failed. Checkpoint not found. Exit now.")
                sys.exit(1)

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            cmvn_filename = os.path.join(FLAGS.data_dir, "train_cmvn.npz")
            if os.path.isfile(cmvn_filename):
                cmvn = np.load(cmvn_filename)
            else:
                tf.logging.fatal("%s not exist, exit now." % cmvn_filename)
                sys.exit(1)
            out_dir_name = os.path.join('/Work18/2017/linan/SE/my_enh',
                                        FLAGS.save_dir, FLAGS.savetestdir)
            # out_dir_name = os.path.join(FLAGS.save_dir, 'test')
            if not os.path.exists(out_dir_name):
                os.makedirs(out_dir_name)

            write_scp_path = os.path.join(out_dir_name, 'feats.scp')
            write_ark_path = os.path.join(out_dir_name, 'feats.ark')
            writer = ArkWriter(write_scp_path)

            outputs = model.generator(test_inputs, None, reuse=True)
            outputs = tf.reshape(outputs, [-1, model.output_dim])
            print('shape is', np.shape(outputs))
            try:
                for batch in range(num_batch):
                    if coord.should_stop():
                        break
                    # outputs = model.generator(test_inputs, None, reuse=True)
                    # outputs = tf.reshape(outputs, [-1, model.output_dim])
                    utt_id, activations = sess.run([test_utt_id, outputs])
                    # sequence = activations * cmvn['stddev_labels'] + \
                    # cmvn['mean_labels']
                    sequence = activations
                    save_result = np.vstack(sequence)
                    dir_load = FLAGS.savetestdir
                    dir_load = dir_load.split('/')[-1]
                    mode = FLAGS.mode
                    if mode == 'use_org':
                        inputs_path = os.path.join(
                            'workspace/features/spectrogram/test', dir_load,
                            '%s.wav.p' % utt_id[0])
                        data = cPickle.load(open(inputs_path, 'rb'))
                        [mixed_complx_x] = data
                        #tf.logging.info("Write inferred %s to %s" %(utt_id[0], np.shape(save_result)))
                        save_result = np.exp(save_result)
                        n_window = cfg.n_window
                        s = recover_wav(save_result, mixed_complx_x,
                                        cfg.n_overlap, np.hamming)
                        s *= np.sqrt(
                            (np.hamming(n_window)**2
                             ).sum())  # Scaler for compensate the amplitude
                        # change after spectrogram and IFFT.
                        print("start enhance wav file")
                        # Write out enhanced wav.
                        out_path = os.path.join("workspace", "enh_wavs",
                                                "test", dir_load,
                                                "%s.enh.wav" % utt_id[0])
                        print("have enhanced all  the wav")
                        pp_data.create_folder(os.path.dirname(out_path))
                        pp_data.write_audio(out_path, s, 16000)
                    elif mode == 'g_l':
                        inputs_path = os.path.join(
                            'workspace/features/spectrogram/test', dir_load,
                            '%s.wav.p' % utt_id[0])
                        data = cPickle.load(open(inputs_path, 'rb'))
                        [mixed_complx_x] = data
                        save_result = np.exp(save_result)
                        s = save_result
                        s = audio_utilities.reconstruct_signal_griffin_lim(
                            s, mixed_complx_x, 512, 256, 15)
                        #s = recover_wav(save_result,mixed_complx_x,cfg.n_overlap, np.hamming)
                        s *= np.sqrt((np.hamming(cfg.n_window)**2).sum())
                        #s = audio._griffin_lim(s)
                        out_path = os.path.join("workspace", "enh_wavs",
                                                "test2", dir_load,
                                                "%s.enh.wav" % utt_id[0])
                        pp_data.create_folder(os.path.dirname(out_path))
                        pp_data.write_audio(out_path, s, 16000)
                        tf.logging.info("Write inferred%s" % (np.shape(s)))
                    #writer.write_next_utt(write_ark_path, utt_id[0], save_result)
                    tf.logging.info("Write inferred %s to %s" %
                                    (utt_id[0], out_path))

            except Exception, e:
                # Report exceptions to the coordinator.
                coord.request_stop(e)
            finally:
Example #6
0
def train(valdi_batch_per_iter, train_batch_per_iter, min_iters, max_iters):
    with tf.Graph().as_default():
        with tf.device('/cpu:0'):
            with tf.name_scope('input'):
                tr_data_list = read_list(FLAGS.tr_list_file)
                tr_inputs, tr_labels = get_batch(
                    tr_data_list, FLAGS.batch_size, FLAGS.input_dim,
                    FLAGS.output_dim, FLAGS.left_context, FLAGS.right_context,
                    FLAGS.num_threads, FLAGS.max_epoches)

                cv_data_list = read_list(FLAGS.cv_list_file)
                cv_inputs, cv_labels = get_batch(
                    cv_data_list, FLAGS.batch_size, FLAGS.input_dim,
                    FLAGS.output_dim, FLAGS.left_context, FLAGS.right_context,
                    FLAGS.num_threads, None)

        devices = []
        for i in xrange(FLAGS.num_gpu):
            device_name = ("/gpu:%d" % i)
            print('Using device: ', device_name)
            devices.append(device_name)
        # Prevent exhausting all the gpu memories.
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        config.allow_soft_placement = True
        # execute the session
        with tf.Session(config=config) as sess:
            # Create two models with tr_inputs and cv_inputs individually.
            with tf.name_scope('model'):
                print(
                    "=======================================================")
                print(
                    "|                Build Train model                    |")
                print(
                    "=======================================================")
                tr_model = GAN(sess,
                               FLAGS,
                               devices,
                               tr_inputs,
                               tr_labels,
                               cross_validation=False)
                # tr_model and val_model should share variables
                print(
                    "=======================================================")
                print(
                    "|           Build Cross-Validation model              |")
                print(
                    "=======================================================")
                tf.get_variable_scope().reuse_variables()
                cv_model = GAN(sess,
                               FLAGS,
                               devices,
                               cv_inputs,
                               cv_labels,
                               cross_validation=True)

            show_all_variables()

            init = tf.group(tf.global_variables_initializer(),
                            tf.local_variables_initializer())
            print("Initializing variables ...")
            sess.run(init)

            if tr_model.load(tr_model.save_dir, moving_average=False):
                print("[*] Load SUCCESS")
            else:
                print("[!] Begin a new model.")

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            try:
                # Early stop counter
                g_loss_prev = 10000.0
                g_rel_impr = 1.0
                check_interval = 3
                windows_g_loss = []

                g_learning_rate = FLAGS.num_gpu * FLAGS.g_learning_rate
                d_learning_rate = FLAGS.num_gpu * FLAGS.d_learning_rate
                sess.run(tf.assign(tr_model.g_learning_rate, g_learning_rate))
                sess.run(tf.assign(tr_model.d_learning_rate, d_learning_rate))

                for iteration in range(max_iters):
                    start = datetime.datetime.now()
                    tr_d_rl_loss, tr_d_fk_loss, \
                    tr_d_loss, tr_g_adv_loss, \
                    tr_g_mse_loss, tr_g_l2_loss, \
                    tr_g_loss = train_one_iteration(sess, coord,
                        tr_model, train_batch_per_iter, iteration+1)
                    cv_d_rl_loss, cv_d_fk_loss, \
                    cv_d_loss, cv_g_adv_loss, \
                    cv_g_mse_loss, cv_g_l2_loss, \
                    cv_g_loss = eval_one_iteration(sess, coord,
                        cv_model, valdi_batch_per_iter, iteration+1)
                    d_learning_rate, \
                    g_learning_rate = sess.run([tr_model.d_learning_rate,
                                                tr_model.g_learning_rate])

                    end = datetime.datetime.now()
                    print("{}/{} (INFO): d_learning_rate = {:.5e}, "
                          "g_learning_rate = {:.5e}, time = {:.3f} min\n"
                          "{}/{} (TRAIN AVG.LOSS): "
                          "d_rl_loss = {:.5f}, d_fk_loss = {:.5f}, "
                          "d_loss = {:.5f}, g_adv_loss = {:.5f}, "
                          "g_mse_loss = {:.5f}, g_l2_loss = {:.5f}, "
                          "g_loss = {:.5f}\n"
                          "{}/{} (CROSS AVG.LOSS): "
                          "d_rl_loss = {:.5f}, d_fk_loss = {:.5f}, "
                          "d_loss = {:.5f}, g_adv_loss = {:.5f}, "
                          "g_mse_loss = {:.5f}, g_l2_loss = {:.5f}, "
                          "g_loss = {:.5f}".format(
                              iteration + 1, max_iters, d_learning_rate,
                              g_learning_rate, (end - start).seconds / 60.0,
                              iteration + 1, max_iters, tr_d_rl_loss,
                              tr_d_fk_loss, tr_d_loss, tr_g_adv_loss,
                              tr_g_mse_loss, tr_g_l2_loss, tr_g_loss,
                              iteration + 1, max_iters, cv_d_rl_loss,
                              cv_d_fk_loss, cv_d_loss, cv_g_adv_loss,
                              cv_g_mse_loss, cv_g_l2_loss, cv_g_loss))
                    sys.stdout.flush()

                    # Start decay learning rate
                    g_learning_rate = exponential_decay(
                        iteration + 1, FLAGS.num_gpu, min_iters,
                        FLAGS.g_learning_rate)
                    d_learning_rate = exponential_decay(
                        iteration + 1, FLAGS.num_gpu, min_iters,
                        FLAGS.d_learning_rate)
                    disc_noise_std = exponential_decay(
                        iteration + 1,
                        FLAGS.num_gpu,
                        min_iters,
                        FLAGS.init_disc_noise_std,
                        multiply_jobs=False)
                    sess.run(
                        tf.assign(tr_model.g_learning_rate, g_learning_rate))
                    sess.run(
                        tf.assign(tr_model.d_learning_rate, d_learning_rate))
                    sess.run(tf.assign(tr_model.disc_noise_std,
                                       disc_noise_std))

                    windows_g_loss.append(cv_g_loss)

                    # Accept or reject new parameters.
                    if (iteration + 1) % check_interval == 0:
                        g_loss_new = np.mean(windows_g_loss)
                        g_rel_impr = (g_loss_prev - g_loss_new) / g_loss_prev
                        if g_rel_impr > 0.0:
                            tr_model.save(tr_model.save_dir, iteration + 1)
                            print("Iteration {}: Nnet Accepted. "
                                  "Save model SUCCESS. g_loss_prev = {:.5f}, "
                                  "g_loss_new = {:.5f}".format(
                                      iteration + 1, g_loss_prev, g_loss_new))
                            g_loss_prev = g_loss_new
                        else:
                            print("Iteration {}: Nnet Rejected. "
                                  "g_loss_prev = {:.5f}, "
                                  "g_loss_new = {:.5f}".format(
                                      iteration + 1, g_loss_prev, g_loss_new))
                            # tr_model.load(tr_model.save_dir, moving_average=False)
                        windows_g_loss = []

                    # Stopping criterion.
                    if iteration + 1 > min_iters and \
                            (iteration + 1) % check_interval == 0:
                        if g_rel_impr < FLAGS.end_improve:
                            print("Iteration %d: Finished, too small relative "
                                  "G improvement %g" %
                                  (iteration + 1, g_rel_impr))
                            break
                    sys.stdout.flush()

                if windows_g_loss:
                    g_loss_new = np.mean(windows_g_loss)
                    g_rel_impr = (g_loss_prev - g_loss_new) / g_loss_prev
                    if g_rel_impr > 0.0:
                        tr_model.save(tr_model.save_dir, iteration + 1)
                        print("Iteration {}: Nnet Accepted. "
                              "Save model SUCCESS. g_loss_prev = {:.5f}, "
                              "g_loss_new = {:.5f}".format(
                                  iteration + 1, g_loss_prev, g_loss_new))
                        g_loss_prev = g_loss_new
                        sys.stdout.flush()
                    windows_g_loss = []

            except Exception, e:
                # Report exceptions to the coordinator.
                coord.request_stop(e)
            finally: