示例#1
0
    def train(self, sess):
        """Train the spectrogram model adversarially until the epoch limit.

        Each step runs two discriminator updates and one generator update on
        training batches (``self.is_valid: False``), then evaluates the
        generator L1 loss tensor on a validation batch
        (``self.is_valid: True``). Every 10 steps the predictions
        (``self.fake_input``) and targets (``self.rat``) of a training batch
        are printed. A checkpoint is written every 2000 steps and once more
        after ``self.epoch`` epochs.

        Args:
            sess: session used to run the variable initializer; every other
                ``run`` call goes through ``self.sess`` (presumably the same
                session -- TODO confirm).
        """
        print('initializing...opt')
        d_opt = self.d_opt
        g_opt = self.g_opt

        try:
            init = tf.global_variables_initializer()
        except AttributeError:
            # Older TensorFlow releases only expose the deprecated name.
            init = tf.initialize_all_variables()
        sess.run(init)

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)

        train_feed = {self.is_valid: False}
        valid_feed = {self.is_valid: True}

        # Pull one training batch up front just to report the input shapes.
        sample_spec, sample_rat, _ = self.sess.run(
            [self.spec, self.rat, self.sid], feed_dict=train_feed)
        print('sample spec shape: ', sample_spec.shape)
        print('sample rat shape: ', sample_rat.shape)

        # Count the examples by reading the whole TFRecord file once.
        num_examples = sum(
            1 for _ in tf.python_io.tf_record_iterator(self.tfrecords))
        print("total num of patches in tfrecords", self.tfrecords, ":  ",
              num_examples)

        # Under Python 3 this is a float, so with the `batch_idx >=` check
        # below an epoch rolls over after ceil(num_examples / batch_size) steps.
        num_batches = num_examples / self.batch_size
        print('batches per epoch: ', num_batches)

        counter = 0
        batch_idx = 0
        current_epoch = 0
        batch_timings = []
        # Per-step loss history (accumulated only; not persisted here).
        d_fake_losses = []
        d_real_losses = []
        g_adv_losses = []
        g_l1_losses = []
        v_l1_losses = []

        try:
            while not coord.should_stop():
                start = timeit.default_timer()

                # Two discriminator updates per generator update.
                for _round in range(2):
                    _, d_fake_loss, d_real_loss = self.sess.run(
                        [d_opt, self.d_fake_losses[0], self.d_real_losses[0]],
                        feed_dict=train_feed)
                _, g_adv_loss, g_l1_loss = self.sess.run(
                    [g_opt, self.g_adv_losses[0], self.g_l1_losses[0]],
                    feed_dict=train_feed)
                # Same L1 loss tensor, evaluated on a validation batch.
                v_l1_loss = self.sess.run(self.g_l1_losses[0],
                                          feed_dict=valid_feed)

                end = timeit.default_timer()
                batch_timings.append(end - start)

                d_fake_losses.append(d_fake_loss)
                d_real_losses.append(d_real_loss)
                g_adv_losses.append(g_adv_loss)
                g_l1_losses.append(g_l1_loss)
                v_l1_losses.append(v_l1_loss)

                print('{}/{} (epoch {}), '
                      'd_rl_loss = {:.5f}, '
                      'd_fk_loss = {:.5f}, '
                      'g_adv_loss = {:.5f}, '
                      'g_l1_loss = {:.5f}, '
                      'valid_l1_loss = {:.5f}, '
                      ' time/batch = {:.5f}, '
                      'mtime/batch = {:.5f}'.format(counter,
                                                    self.epoch * num_batches,
                                                    current_epoch, d_real_loss,
                                                    d_fake_loss, g_adv_loss,
                                                    g_l1_loss, v_l1_loss,
                                                    end - start,
                                                    np.mean(batch_timings)))

                if counter % 10 == 0:
                    pred, y = self.sess.run([self.fake_input, self.rat],
                                            feed_dict=train_feed)
                    print('pred', pred)
                    print('y', y)

                batch_idx += 1
                counter += 1
                # counter >= 1 here, so no separate "> 0" check is needed.
                if counter % 2000 == 0:
                    self.saver.save(self.sess,
                                    "./only_spec_gan_" + str(counter) + "_th")
                if batch_idx >= num_batches:
                    current_epoch += 1
                    batch_idx = 0
                if current_epoch >= self.epoch:
                    print(str(self.epoch), ': epoch limit')
                    print('saving last model at iteration', str(counter))
                    self.saver.save(self.sess, "./only_spec_gan_final")
                    break

        except tf.errors.OutOfRangeError:
            # The input queue has been closed and drained.
            print('done training')
        finally:
            coord.request_stop()
        coord.join(threads)
    def train(self, sess, config):
        """Train the two-direction (A->B, B->A) GAN with checkpoints and samples.

        Each step runs ``self.disc_updates`` discriminator rounds followed by
        one generator round. Every round runs once with ``is_mismatch: True``
        and once with ``is_mismatch: False``. A checkpoint is written to
        ``config.save_path`` every 2000 steps and at the epoch limit. After the
        first step and every ``config.save_freq`` steps, validation and
        training outputs are saved as spectrogram images under
        ``<save_path>/spec``. The validation outputs are also saved as 16 kHz
        wav files under ``<save_path>/wav``.

        Args:
            sess: session used to run the variable initializer; every other
                ``run`` call goes through ``self.sess`` (presumably the same
                session -- TODO confirm).
            config: must provide ``save_path``, ``save_freq`` and ``epoch``.
        """
        print('initializing...opt')
        d_opt = self.d_opt
        g_opt = self.g_opt

        try:
            init = tf.global_variables_initializer()
        except AttributeError:
            # Older TensorFlow releases only expose the deprecated name.
            init = tf.initialize_all_variables()
        sess.run(init)

        print('initializing...var')

        save_path = config.save_path
        train_dir = os.path.join(save_path, 'train')
        if not os.path.exists(train_dir):
            os.makedirs(train_dir)
        self.writer = tf.summary.FileWriter(train_dir, self.sess.graph)

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)

        def feed(is_valid, is_train, is_mismatch):
            """Build a feed dict for the three mode placeholders."""
            return {
                self.is_valid: is_valid,
                self.is_train: is_train,
                self.is_mismatch: is_mismatch
            }

        # Pull one training and one validation batch up front; only the
        # training shapes are reported.
        sample_low, sample_high, sample_z = self.sess.run(
            [self.gt_low[0], self.gt_high[0], self.zz_A[0]],
            feed_dict=feed(False, True, False))
        self.sess.run([self.gt_low[0], self.gt_high[0], self.zz_A[0]],
                      feed_dict=feed(True, False, False))

        print('sample low shape: ', sample_low.shape)
        print('sample high shape: ', sample_high.shape)
        print('sample z shape: ', sample_z.shape)

        # Count the examples by reading the whole TFRecord file once.
        num_examples = sum(
            1 for _ in tf.python_io.tf_record_iterator(self.tfrecords))
        print("total num of patches in tfrecords", self.tfrecords, ":  ",
              num_examples)

        num_batches = num_examples / self.batch_size
        print('batches per epoch: ', num_batches)

        if self.load(self.save_path):
            print('load success')
        else:
            print('load failed')

        def save_spectrogram(s_reverb, s_gt, s_AB, s_BA, tag, step):
            """Save reverb | gt | AB | BA spectrograms side by side as a PNG."""
            S = np.concatenate(
                [self.get_spectrum(x).reshape(512, 128)
                 for x in (s_reverb, s_gt, s_AB, s_BA)],
                axis=1)
            fig = Figure(figsize=S.shape[::-1], dpi=1, frameon=False)
            FigureCanvas(fig)  # the canvas attaches itself to `fig`
            fig.figimage(S, cmap='jet')
            fig.savefig(save_path + '/spec/' + tag + str(step) + '-th_pr.png')

        def write_wav(suffix, magnitude, phase, step):
            """Combine via self.inv_magphase, istft, and write 16 kHz audio."""
            audio = librosa.istft(self.inv_magphase(magnitude, phase))
            librosa.output.write_wav(
                save_path + '/wav/' + str(step) + suffix, audio, 16000)

        def dump_samples(step):
            """Save validation spectrogram + audio, then training spectrogram."""
            (s_AB, s_BA, s_reverb, s_gt, ori_phase,
             rev_phase) = self.sess.run(
                 [
                     self.GG_A[0][0, :, :, :], self.GG_B[0][0, :, :, :],
                     self.gt_low[0][0, :, :, :],
                     self.gt_high[0][0, :, :, :],
                     self.ori_phase_[0][0, :, :, :],
                     self.rev_phase_[0][0, :, :, :]
                 ],
                 feed_dict=feed(True, False, False))

            for sub in ('/wav', '/txt', '/spec'):
                if not os.path.exists(save_path + sub):
                    os.makedirs(save_path + sub)

            print(str(step) + 'th finished')

            save_spectrogram(s_reverb, s_gt, s_AB, s_BA, 'valid_batch_index',
                             step)
            # AB (dereverb) output and the reverb input use rev_phase_;
            # BA (reverb) output and the clean target use ori_phase_.
            write_wav('_AB(dereverb).wav', s_AB, rev_phase, step)
            write_wav('_BA(reverb).wav', s_BA, ori_phase, step)
            write_wav('_reverb.wav', s_reverb, rev_phase, step)
            write_wav('_orig.wav', s_gt, ori_phase, step)

            s_AB, s_BA, s_reverb, s_gt = self.sess.run(
                [
                    self.GG_A[0][0, :, :, :], self.GG_B[0][0, :, :, :],
                    self.gt_low[0][0, :, :, :],
                    self.gt_high[0][0, :, :, :]
                ],
                feed_dict=feed(False, True, False))
            save_spectrogram(s_reverb, s_gt, s_AB, s_BA, 'train_batch_index',
                             step)

        # Every optimisation round runs a mismatched batch, then a matched one.
        train_feeds = (feed(False, True, True), feed(False, True, False))
        # The weight-clipping op is fed without the mismatch flag.
        clip_feed = {self.is_valid: False, self.is_train: True}
        d_fetches = [d_opt, self.d_A_losses[0], self.d_B_losses[0]]
        g_fetches = [
            g_opt, self.g_adv_losses[0], self.g_losses_AB[0],
            self.g_losses_BA[0], self.g_l1_losses_ABA[0],
            self.g_l1_losses_BAB[0]
        ]

        counter = 0
        batch_idx = 0
        current_epoch = 0
        batch_timings = []
        # Per-step loss history (accumulated only; not persisted here).
        d_A_losses = []
        d_B_losses = []
        g_adv_losses = []
        g_l1_losses_BAB = []  # clean - reverb - clean
        g_l1_losses_AB = []  # reverb - clean
        g_l1_losses_ABA = []  # reverb - clean - reverb
        g_l1_losses_BA = []  # clean - reverb

        try:
            while not coord.should_stop():
                start = timeit.default_timer()

                for _round in range(self.disc_updates):
                    for fd in train_feeds:
                        _, d_A_loss, d_B_loss = self.sess.run(d_fetches,
                                                              feed_dict=fd)
                    if self.d_clip_weights:
                        self.sess.run(self.d_clip, feed_dict=clip_feed)

                for fd in train_feeds:
                    (_, g_adv_loss, g_AB_loss, g_BA_loss, g_ABA_loss,
                     g_BAB_loss) = self.sess.run(g_fetches, feed_dict=fd)

                end = timeit.default_timer()
                batch_timings.append(end - start)
                d_A_losses.append(d_A_loss)
                d_B_losses.append(d_B_loss)
                g_adv_losses.append(g_adv_loss)
                g_l1_losses_BAB.append(g_BAB_loss)
                g_l1_losses_AB.append(g_AB_loss)
                g_l1_losses_ABA.append(g_ABA_loss)
                g_l1_losses_BA.append(g_BA_loss)

                print(
                    '{}/{} (epoch {}), d_A_loss = {:.5f}, '
                    'd_B_loss = {:.5f}, '
                    'g_adv_loss = {:.5f}, g_AB_loss = {:.5f}, g_BAB_loss = {:.5f}, '
                    'g_BA_loss = {:.5f}, g_ABA_loss = {:.5f}, '
                    ' time/batch = {:.5f}, '
                    'mtime/batch = {:.5f}'.format(
                        counter, config.epoch * num_batches, current_epoch,
                        d_A_loss, d_B_loss, g_adv_loss, g_AB_loss, g_BAB_loss,
                        g_BA_loss, g_ABA_loss, end - start,
                        np.mean(batch_timings)))
                batch_idx += 1
                counter += 1

                # counter >= 1 here, so no separate "> 0" check is needed.
                if counter % 2000 == 0:
                    self.save(save_path, counter)

                if counter % config.save_freq == 0 or counter == 1:
                    dump_samples(counter)

                if batch_idx >= num_batches:
                    current_epoch += 1
                    batch_idx = 0

                if current_epoch >= config.epoch:
                    # Report the limit that is actually being enforced.
                    print(str(config.epoch), ': epoch limit')
                    print('saving last model at iteration', str(counter))
                    self.save(save_path, counter)
                    break

        except tf.errors.InternalError:
            print('InternalError')

        except tf.errors.OutOfRangeError:
            # The input queue has been closed and drained.
            print('done training')
        finally:
            coord.request_stop()
        coord.join(threads)
示例#3
0
文件: infer_spec.py 项目: sjlee7/CBMC
    def infer(self, sess):
        """Run the spectrogram model over the test input and save the outputs.

        Batches are evaluated with ``self.is_valid: True`` until the input
        queue is exhausted or the coordinator stops the loop. The collected
        ``self.pred_only_spec`` outputs, ``self.rat`` targets and ``self.sid``
        values are then saved to ``./test_pred_rat_spec.npy``,
        ``./test_gt_rat_spec.npy`` and ``./test_sid.npy``.

        Args:
            sess: session used to run the variable initializer; every other
                ``run`` call goes through ``self.sess`` (presumably the same
                session -- TODO confirm).
        """
        print('initializing...opt')
        try:
            init = tf.global_variables_initializer()
        except AttributeError:
            # Older TensorFlow releases only expose the deprecated name.
            init = tf.initialize_all_variables()
        sess.run(init)

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)

        # Count the test examples by reading the whole TFRecord file once.
        num_examples = sum(
            1 for _ in tf.python_io.tf_record_iterator(self.tfrecords_test))
        # Report the file that was actually counted.
        print("total num of patches in tfrecords", self.tfrecords_test, ":  ",
              num_examples)

        num_batches = num_examples / self.batch_size
        print('batches per epoch: ', num_batches)

        if self.load(self.save_path):
            print('load success')
        else:
            print('load failed')

        test_pred_rat = []
        test_gt_rat = []
        test_sid = []
        counter = 0

        try:
            while not coord.should_stop():
                output_, y_, sid = self.sess.run(
                    [self.pred_only_spec, self.rat, self.sid],
                    feed_dict={self.is_valid: True})
                test_pred_rat.append(output_)
                test_gt_rat.append(y_)
                test_sid.append(sid)

                if counter % 10 == 0:
                    print(counter)
                counter += 1
        except tf.errors.OutOfRangeError:
            # Normal end: the test input queue has been closed and drained.
            pass
        finally:
            coord.request_stop()

        # Save on either normal ending (queue drained or coordinator stop);
        # any other exception has already propagated past this point.
        np.save('./test_pred_rat_spec.npy', test_pred_rat)
        np.save('./test_gt_rat_spec.npy', test_gt_rat)
        np.save('./test_sid.npy', test_sid)
        print('done training')

        coord.join(threads)
示例#4
0
    def train(self, sess, config):
        """Train the waveform GAN (reverb -> non-reverb) with summaries and samples.

        Each step runs ``self.disc_updates`` discriminator updates and one
        generator update. After each discriminator update, discriminator
        weights are clipped when ``self.d_clip_weights`` is set. Merged
        TensorBoard summaries are fetched on exactly the steps that are
        logged, so the summaries written every ``config.save_freq`` steps
        under ``<save_path>/train`` belong to the step they are tagged with.
        On those steps validation and training outputs are also saved as
        spectrogram images, and the validation outputs as 16 kHz wav files.
        A checkpoint is written every 2000 steps and at the epoch limit.

        Args:
            sess: session used to run the variable initializer; every other
                ``run`` call goes through ``self.sess`` (presumably the same
                session -- TODO confirm).
            config: must provide ``save_path``, ``save_freq`` and ``epoch``.
        """
        print('initializing...opt')
        d_opt = self.d_opt
        g_opt = self.g_opt

        try:
            init = tf.global_variables_initializer()
        except AttributeError:
            # Older TensorFlow releases only expose the deprecated name.
            init = tf.initialize_all_variables()
        sess.run(init)

        print('initializing...var')
        g_summaries = [
            self.d_fake_summary, self.d_fake_loss_summary, self.g_loss_summary,
            self.g_l2_loss_summary, self.g_loss_adv_summary,
            self.generated_wav_summary, self.generated_audio_summary
        ]
        d_summaries = [
            self.d_loss_summary, self.d_real_summary, self.d_real_loss_summary,
            self.nonreverb_audio_summary, self.nonreverb_wav_summary
        ]

        if hasattr(self, 'alpha_summ'):
            g_summaries += self.alpha_summ
        self.g_sum = tf.summary.merge(g_summaries)
        self.d_sum = tf.summary.merge(d_summaries)

        save_path = config.save_path
        train_dir = os.path.join(save_path, 'train')
        if not os.path.exists(train_dir):
            os.makedirs(train_dir)
        self.writer = tf.summary.FileWriter(train_dir, self.sess.graph)

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)

        train_feed = {self.is_valid: False, self.is_train: True}
        valid_feed = {self.is_valid: True, self.is_train: False}

        # Pull one training and one validation batch up front; only the
        # training shapes are reported.
        sample_reverb, sample_nonreverb, sample_z = self.sess.run(
            [self.gt_reverb[0], self.gt_nonreverb[0], self.zz[0]],
            feed_dict={self.is_valid: False})
        self.sess.run([self.gt_reverb[0], self.gt_nonreverb[0], self.zz[0]],
                      feed_dict=valid_feed)

        print('sample reverb shape: ', sample_reverb.shape)
        print('sample nonreverb shape: ', sample_nonreverb.shape)
        print('sample z shape: ', sample_z.shape)

        # Count the examples by reading the whole TFRecord file once.
        num_examples = sum(
            1 for _ in tf.python_io.tf_record_iterator(self.tfrecords))
        print("total num of patches in tfrecords", self.tfrecords, ":  ",
              num_examples)

        num_batches = num_examples / self.batch_size
        print('batches per epoch: ', num_batches)

        if self.load(self.save_path):
            print('load success')
        else:
            print('load failed')

        def fetch_samples(feed_dict):
            """Run the generator on one batch and return 1-D (pred, reverb, clean).

            Pre-emphasis is undone (when enabled) for every split, and all
            three signals are trimmed to the first eighth of the flattened
            prediction.
            """
            canvas_w, s_reverb, s_nonreverb = self.sess.run(
                [self.GG[0], self.gt_reverb[0], self.gt_nonreverb[0]],
                feed_dict=feed_dict)
            print('max :', np.max(canvas_w[0]), 'min :', np.min(canvas_w[0]))
            if self.pre_emphasis > 0:
                canvas_w = self.de_emphasis(canvas_w, self.pre_emphasis)
                s_reverb = self.de_emphasis(s_reverb, self.pre_emphasis)
                s_nonreverb = self.de_emphasis(s_nonreverb, self.pre_emphasis)
            x_pr = canvas_w.flatten()
            x_pr = x_pr[:len(x_pr) // 8]
            x_lr = s_reverb.flatten()[:len(x_pr)]
            x_hr = s_nonreverb.flatten()[:len(x_pr)]
            return x_pr, x_lr, x_hr

        def save_spectrogram(x_lr, x_hr, x_pr, tag, step):
            """Save reverb | clean | predicted spectrograms side by side."""
            Sl = self.get_spectrum(x_lr, n_fft=2048)
            Sh = self.get_spectrum(x_hr, n_fft=2048)
            Sp = self.get_spectrum(x_pr, n_fft=2048)
            S = np.concatenate(
                (Sl.reshape(Sh.shape[0], Sh.shape[1]), Sh, Sp), axis=1)
            fig = Figure(figsize=S.shape[::-1], dpi=1, frameon=False)
            FigureCanvas(fig)  # the canvas attaches itself to `fig`
            fig.figimage(S, cmap='jet')
            fig.savefig(save_path + '/spec/' + tag + str(step) + '-th_pr.png')

        def dump_samples(step):
            """Save validation spectrogram + audio, then training spectrogram."""
            for sub in ('/wav', '/txt', '/spec'):
                if not os.path.exists(save_path + sub):
                    os.makedirs(save_path + sub)

            x_pr, x_lr, x_hr = fetch_samples(valid_feed)
            save_spectrogram(x_lr, x_hr, x_pr, 'valid_batch_index', step)
            librosa.output.write_wav(
                save_path + '/wav/' + str(step) + '_dereverb.wav', x_pr, 16000)
            librosa.output.write_wav(
                save_path + '/wav/' + str(step) + '_reverb.wav', x_lr, 16000)
            librosa.output.write_wav(
                save_path + '/wav/' + str(step) + '_orig.wav', x_hr, 16000)

            x_pr, x_lr, x_hr = fetch_samples(train_feed)
            save_spectrogram(x_lr, x_hr, x_pr, 'train_batch_index', step)

        def write_summaries(step, *summaries):
            """Write every summary that has been computed so far."""
            for summary in summaries:
                if summary is not None:
                    self.writer.add_summary(summary, step)

        d_fetches = [d_opt, self.d_fake_losses[0], self.d_real_losses[0]]
        g_fetches = [g_opt, self.g_adv_losses[0], self.g_l2_losses[0]]
        d_fetches_with_sum = d_fetches + [self.d_sum]
        g_fetches_with_sum = g_fetches + [self.g_sum]

        counter = 0
        batch_idx = 0
        current_epoch = 0
        batch_timings = []
        # Per-step loss history (accumulated only; not persisted here).
        d_fake_losses = []
        d_real_losses = []
        g_adv_losses = []
        g_l2_losses = []
        # Most recent merged summaries; None until first fetched.
        _g_sum = None
        _d_sum = None

        try:
            while not coord.should_stop():
                start = timeit.default_timer()
                # Summaries are logged below after `counter += 1`, so fetch
                # them on the step whose incremented counter will be logged.
                log_step = (counter + 1) % config.save_freq == 0
                d_run = d_fetches_with_sum if log_step else d_fetches
                g_run = g_fetches_with_sum if log_step else g_fetches

                for _round in range(self.disc_updates):
                    d_out = self.sess.run(d_run, feed_dict=train_feed)
                    d_fake_loss, d_real_loss = d_out[1], d_out[2]
                    if log_step:
                        _d_sum = d_out[3]
                    if self.d_clip_weights:
                        self.sess.run(self.d_clip, feed_dict=train_feed)

                g_out = self.sess.run(g_run, feed_dict=train_feed)
                g_adv_loss, g_l2_loss = g_out[1], g_out[2]
                if log_step:
                    _g_sum = g_out[3]

                end = timeit.default_timer()
                batch_timings.append(end - start)
                d_fake_losses.append(d_fake_loss)
                d_real_losses.append(d_real_loss)
                g_adv_losses.append(g_adv_loss)
                g_l2_losses.append(g_l2_loss)
                print('{}/{} (epoch {}), d_rl_loss = {:.5f}, '
                      'd_fk_loss = {:.5f}, '
                      'g_adv_loss = {:.5f}, g_l1_loss = {:.5f},'
                      ' time/batch = {:.5f}, '
                      'mtime/batch = {:.5f}'.format(counter,
                                                    config.epoch * num_batches,
                                                    current_epoch, d_real_loss,
                                                    d_fake_loss, g_adv_loss,
                                                    g_l2_loss, end - start,
                                                    np.mean(batch_timings)))
                batch_idx += 1
                counter += 1

                # counter >= 1 here, so no separate "> 0" check is needed.
                if counter % 2000 == 0:
                    self.save(save_path, counter)
                if counter % config.save_freq == 0:
                    write_summaries(counter, _g_sum, _d_sum)
                    dump_samples(counter)

                if batch_idx >= num_batches:
                    current_epoch += 1
                    batch_idx = 0

                if current_epoch >= config.epoch:
                    # Report the limit that is actually being enforced.
                    print(str(config.epoch), ': epoch limit')
                    print('saving last model at iteration', str(counter))
                    self.save(save_path, counter)
                    write_summaries(counter, _g_sum, _d_sum)
                    break

        except tf.errors.OutOfRangeError:
            # The input queue has been closed and drained.
            print('done training')
        finally:
            coord.request_stop()
        coord.join(threads)