def train(self, sess):
    """Train the spectrogram GAN (d/g alternating updates).

    Runs until the input queue raises OutOfRangeError or
    ``self.epoch`` epochs complete. Checkpoints via ``self.saver``
    every 2000 steps and once more at the end. ``sess`` is the
    tf.Session used for initialization; steps run on ``self.sess``.
    Returns None.
    """
    # tf.initialize_all_variables().run()
    print('initializing...opt')
    d_opt = self.d_opt
    g_opt = self.g_opt
    try:
        init = tf.global_variables_initializer()
        sess.run(init)
    except AttributeError:
        # Fallback for pre-r0.12 TensorFlow. Fixed: the original called
        # the misspelled tf.intializer_all_varialble(), which does not
        # exist and would itself raise AttributeError.
        init = tf.initialize_all_variables()
        sess.run(init)
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    # Pull one batch up-front just to report tensor shapes.
    sample_spec, sample_rat, sample_sid = self.sess.run(
        [self.spec, self.rat, self.sid],
        feed_dict={self.is_valid: False})
    print('sample spec shape: ', sample_spec.shape)
    print('sample rat shape: ', sample_rat.shape)
    counter = 0  # count of num of samples (global step)
    num_examples = 0
    for record in tf.python_io.tf_record_iterator(self.tfrecords):
        num_examples += 1
    print("total num of patches in tfrecords", self.tfrecords, ": ",
          num_examples)
    # Fixed: integer division -- '/' yields a float in Python 3, which
    # makes the batch/epoch bookkeeping below compare against a float.
    num_batches = num_examples // self.batch_size
    print('batches per epoch: ', num_batches)
    batch_idx = 0
    current_epoch = 0
    batch_timings = []
    d_losses = []
    d_fake_losses = []
    d_real_losses = []
    g_losses = []
    g_adv_losses = []
    g_l1_losses = []
    v_l1_losses = []
    try:
        while not coord.should_stop():
            start = timeit.default_timer()
            # Two discriminator updates per generator update.
            # NOTE(review): nesting inferred from a whitespace-mangled
            # source; the generator step is taken to be outside the
            # range(2) loop (the conventional d:g = 2:1 schedule).
            for i in range(2):
                _d_opt, d_fake_loss, d_real_loss = self.sess.run(
                    [d_opt, self.d_fake_losses[0], self.d_real_losses[0]],
                    feed_dict={self.is_valid: False})
            _g_opt, g_adv_loss, g_l1_loss = self.sess.run(
                [g_opt, self.g_adv_losses[0], self.g_l1_losses[0]],
                feed_dict={self.is_valid: False})
            # Validation L1 loss for monitoring (no optimizer run).
            v_l1_loss = self.sess.run(self.g_l1_losses[0],
                                      feed_dict={self.is_valid: True})
            end = timeit.default_timer()
            batch_timings.append(end - start)
            d_fake_losses.append(d_fake_loss)
            d_real_losses.append(d_real_loss)
            g_adv_losses.append(g_adv_loss)
            g_l1_losses.append(g_l1_loss)
            v_l1_losses.append(v_l1_loss)
            print('{}/{} (epoch {}), '
                  'd_rl_loss = {:.5f}, '
                  'd_fk_loss = {:.5f}, '
                  'g_adv_loss = {:.5f}, '
                  'g_l1_loss = {:.5f}, '
                  'valid_l1_loss = {:.5f}, '
                  ' time/batch = {:.5f}, '
                  'mtime/batch = {:.5f}'.format(
                      counter, self.epoch * num_batches, current_epoch,
                      d_real_loss, d_fake_loss, g_adv_loss, g_l1_loss,
                      v_l1_loss, end - start, np.mean(batch_timings)))
            if counter % 10 == 0:
                # Periodically peek at predictions vs. ground truth.
                pred, y = self.sess.run([self.fake_input, self.rat],
                                        feed_dict={self.is_valid: False})
                print('pred', pred)
                print('y', y)
            batch_idx += 1
            counter += 1
            if counter % 2000 == 0 and counter > 0:
                self.saver.save(self.sess,
                                "./only_spec_gan_" + str(counter) + "_th")
            if batch_idx >= num_batches:
                current_epoch += 1
                # reset batch idx
                batch_idx = 0
                if current_epoch >= self.epoch:
                    print(str(self.epoch), ': epoch limit')
                    print('saving last model at iteration', str(counter))
                    self.saver.save(self.sess, "./only_spec_gan_final")
                    break
    except tf.errors.OutOfRangeError:
        # Input queue exhausted -- normal termination path.
        print('done training')
    finally:
        coord.request_stop()
        coord.join(threads)
def train(self, sess, config):
    """Train the cycle-consistent (A<->B) dereverberation GAN.

    Alternates discriminator and generator updates on mismatched and
    matched pairs, logs losses per batch, periodically writes
    validation/training spectrogram images and wav files under
    ``config.save_path``, and checkpoints every 2000 steps.

    Args:
        sess: tf.Session used for variable initialization; training
            steps run on ``self.sess``.
        config: options object providing ``save_path``, ``save_freq``
            and ``epoch``.

    Returns:
        None. Stops after ``config.epoch`` epochs or when the input
        queue raises OutOfRangeError.
    """
    print('initializing...opt')
    d_opt = self.d_opt
    g_opt = self.g_opt
    try:
        init = tf.global_variables_initializer()
        sess.run(init)
    except AttributeError:
        # Fallback for pre-r0.12 TensorFlow. Fixed: the original called
        # the misspelled, nonexistent tf.intializer_all_varialble().
        init = tf.initialize_all_variables()
        sess.run(init)
    print('initializing...var')
    if not os.path.exists(os.path.join(config.save_path, 'train')):
        os.makedirs(os.path.join(config.save_path, 'train'))
    self.writer = tf.summary.FileWriter(
        os.path.join(config.save_path, 'train'), self.sess.graph)
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    # Pull one training and one validation batch up-front, mainly to
    # report tensor shapes before the loop starts.
    sample_low, sample_high, sample_z = self.sess.run(
        [self.gt_low[0], self.gt_high[0], self.zz_A[0]],
        feed_dict={
            self.is_valid: False,
            self.is_train: True,
            self.is_mismatch: False
        })
    v_sample_low, v_sample_high, v_sample_z = self.sess.run(
        [self.gt_low[0], self.gt_high[0], self.zz_A[0]],
        feed_dict={
            self.is_valid: True,
            self.is_train: False,
            self.is_mismatch: False
        })
    print('sample low shape: ', sample_low.shape)
    print('sample high shape: ', sample_high.shape)
    print('sample z shape: ', sample_z.shape)
    save_path = config.save_path
    counter = 0  # count of num of samples (global step)
    num_examples = 0
    for record in tf.python_io.tf_record_iterator(self.tfrecords):
        num_examples += 1
    print("total num of patches in tfrecords", self.tfrecords, ": ",
          num_examples)
    # Fixed: integer division -- '/' yields a float in Python 3.
    num_batches = num_examples // self.batch_size
    print('batches per epoch: ', num_batches)
    if self.load(self.save_path):
        print('load success')
    else:
        print('load failed')
    batch_idx = 0
    current_epoch = 0
    batch_timings = []
    g_losses = []
    d_A_losses = []
    d_B_losses = []
    g_adv_losses = []
    g_l1_losses_BAB = []
    g_l1_losses_AB = []
    g_l1_losses_ABA = []
    g_l1_losses_BA = []
    try:
        while not coord.should_stop():
            start = timeit.default_timer()
            # NOTE(review): the original switched on
            # (counter % config.save_freq) between two update paths
            # whose summary fetches were commented out, leaving both
            # branches token-identical; they are unified here.
            for d_iter in range(self.disc_updates):
                # Discriminator step on mismatched pairs...
                _d_opt, d_A_loss, d_B_loss = self.sess.run(
                    [d_opt, self.d_A_losses[0], self.d_B_losses[0]],
                    feed_dict={
                        self.is_valid: False,
                        self.is_train: True,
                        self.is_mismatch: True
                    })
                # ...then on matched pairs.
                _d_opt, d_A_loss, d_B_loss = self.sess.run(
                    [d_opt, self.d_A_losses[0], self.d_B_losses[0]],
                    feed_dict={
                        self.is_valid: False,
                        self.is_train: True,
                        self.is_mismatch: False
                    })
                if self.d_clip_weights:
                    # WGAN-style weight clipping after each disc update.
                    # NOTE(review): nesting inferred from a
                    # whitespace-mangled source; placed inside the
                    # disc_updates loop per the usual SEGAN layout.
                    self.sess.run(self.d_clip,
                                  feed_dict={
                                      self.is_valid: False,
                                      self.is_train: True
                                  })
            # Generator step on mismatched, then matched pairs; only the
            # losses from the matched run are kept for logging.
            _g_opt, g_adv_loss, g_AB_loss, g_BA_loss, g_ABA_loss, g_BAB_loss = self.sess.run(
                [
                    g_opt, self.g_adv_losses[0], self.g_losses_AB[0],
                    self.g_losses_BA[0], self.g_l1_losses_ABA[0],
                    self.g_l1_losses_BAB[0]
                ],
                feed_dict={
                    self.is_valid: False,
                    self.is_train: True,
                    self.is_mismatch: True
                })
            _g_opt, g_adv_loss, g_AB_loss, g_BA_loss, g_ABA_loss, g_BAB_loss = self.sess.run(
                [
                    g_opt, self.g_adv_losses[0], self.g_losses_AB[0],
                    self.g_losses_BA[0], self.g_l1_losses_ABA[0],
                    self.g_l1_losses_BAB[0]
                ],
                feed_dict={
                    self.is_valid: False,
                    self.is_train: True,
                    self.is_mismatch: False
                })
            end = timeit.default_timer()
            batch_timings.append(end - start)
            d_A_losses.append(d_A_loss)
            d_B_losses.append(d_B_loss)
            g_adv_losses.append(g_adv_loss)
            g_l1_losses_BAB.append(g_BAB_loss)  # clean - reverb - clean
            g_l1_losses_AB.append(g_AB_loss)  # reverb - clean
            g_l1_losses_ABA.append(g_ABA_loss)  # reverb - clean - reverb
            g_l1_losses_BA.append(g_BA_loss)  # clean - reverb
            print(
                '{}/{} (epoch {}), d_A_loss = {:.5f}, '
                'd_B_loss = {:.5f}, '
                'g_adv_loss = {:.5f}, g_AB_loss = {:.5f}, g_BAB_loss = {:.5f}, '
                'g_BA_loss = {:.5f}, g_ABA_loss = {:.5f}, '
                ' time/batch = {:.5f}, '
                'mtime/batch = {:.5f}'.format(
                    counter, config.epoch * num_batches, current_epoch,
                    d_A_loss, d_B_loss, g_adv_loss, g_AB_loss, g_BAB_loss,
                    g_BA_loss, g_ABA_loss, end - start,
                    np.mean(batch_timings)))
            batch_idx += 1
            counter += 1
            if counter % 2000 == 0 and counter > 0:
                self.save(config.save_path, counter)
            if (counter % config.save_freq == 0) or (counter == 1):
                # ---- validation sample dump (spec image + wavs) ----
                s_A, s_B, s_reverb, s_gt, r_phase, f_phase = self.sess.run(
                    [
                        self.GG_A[0][0, :, :, :], self.GG_B[0][0, :, :, :],
                        self.gt_low[0][0, :, :, :],
                        self.gt_high[0][0, :, :, :],
                        self.ori_phase_[0][0, :, :, :],
                        self.rev_phase_[0][0, :, :, :]
                    ],
                    feed_dict={
                        self.is_valid: True,
                        self.is_train: False,
                        self.is_mismatch: False
                    })
                if not os.path.exists(save_path + '/wav'):
                    os.makedirs(save_path + '/wav')
                if not os.path.exists(save_path + '/txt'):
                    os.makedirs(save_path + '/txt')
                if not os.path.exists(save_path + '/spec'):
                    os.makedirs(save_path + '/spec')
                print(str(counter) + 'th finished')
                x_AB = s_A
                x_BA = s_B
                x_reverb = s_reverb
                x_gt = s_gt
                # assumes magnitude patches are 512x128 -- TODO confirm
                Sre = self.get_spectrum(x_reverb).reshape(512, 128)
                Sgt = self.get_spectrum(x_gt).reshape(512, 128)
                SAB = self.get_spectrum(x_AB).reshape(512, 128)
                SBA = self.get_spectrum(x_BA).reshape(512, 128)
                S = np.concatenate((Sre, Sgt, SAB, SBA), axis=1)
                fig = Figure(figsize=S.shape[::-1], dpi=1, frameon=False)
                canvas = FigureCanvas(fig)
                fig.figimage(S, cmap='jet')
                fig.savefig(save_path + '/spec/' + 'valid_batch_index' +
                            str(counter) + '-th_pr.png')
                # NOTE(review): the phase pairings below (f_phase with
                # s_A/s_reverb, r_phase with s_B/s_gt) are kept exactly
                # as the original -- verify they are the intended ones.
                x_pr = librosa.istft(self.inv_magphase(s_A, f_phase))
                librosa.output.write_wav(
                    save_path + '/wav/' + str(counter) + '_AB(dereverb).wav',
                    x_pr, 16000)
                x_pr = librosa.istft(self.inv_magphase(s_B, r_phase))
                librosa.output.write_wav(
                    save_path + '/wav/' + str(counter) + '_BA(reverb).wav',
                    x_pr, 16000)
                x_lr = librosa.istft(self.inv_magphase(s_reverb, f_phase))
                librosa.output.write_wav(
                    save_path + '/wav/' + str(counter) + '_reverb.wav',
                    x_lr, 16000)
                x_hr = librosa.istft(self.inv_magphase(s_gt, r_phase))
                librosa.output.write_wav(
                    save_path + '/wav/' + str(counter) + '_orig.wav',
                    x_hr, 16000)
                # ---- training sample dump (spec image only) ----
                s_AB, s_BA, s_reverb, s_gt = self.sess.run(
                    [
                        self.GG_A[0][0, :, :, :], self.GG_B[0][0, :, :, :],
                        self.gt_low[0][0, :, :, :],
                        self.gt_high[0][0, :, :, :]
                    ],
                    feed_dict={
                        self.is_valid: False,
                        self.is_train: True,
                        self.is_mismatch: False
                    })
                x_AB = s_AB
                x_BA = s_BA
                x_reverb = s_reverb
                x_gt = s_gt
                Sre = self.get_spectrum(x_reverb).reshape(512, 128)
                Sgt = self.get_spectrum(x_gt).reshape(512, 128)
                SAB = self.get_spectrum(x_AB).reshape(512, 128)
                SBA = self.get_spectrum(x_BA).reshape(512, 128)
                S = np.concatenate((Sre, Sgt, SAB, SBA), axis=1)
                fig = Figure(figsize=S.shape[::-1], dpi=1, frameon=False)
                canvas = FigureCanvas(fig)
                fig.figimage(S, cmap='jet')
                fig.savefig(save_path + '/spec/' + 'train_batch_index' +
                            str(counter) + '-th_pr.png')
            if batch_idx >= num_batches:
                current_epoch += 1
                # reset batch idx
                batch_idx = 0
                if current_epoch >= config.epoch:
                    # Fixed: report the limit actually enforced
                    # (config.epoch); the original printed self.epoch.
                    print(str(config.epoch), ': epoch limit')
                    print('saving last model at iteration', str(counter))
                    self.save(config.save_path, counter)
                    break
    except tf.errors.InternalError:
        print('InternalError')
    except tf.errors.OutOfRangeError:
        # Input queue exhausted -- normal termination path.
        print('done training')
    finally:
        coord.request_stop()
        coord.join(threads)
def infer(self, sess):
    """Run inference over the test tfrecords and dump predictions.

    Restores the latest checkpoint from ``self.save_path``, iterates
    the validation input pipeline until exhaustion, and saves the
    collected predictions, ground-truth ratings and sample ids as
    .npy files in the working directory. Returns None.
    """
    # tf.initialize_all_variables().run()
    print('initializing...opt')
    opt = self.opt
    try:
        init = tf.global_variables_initializer()
        sess.run(init)
    except AttributeError:
        # Fallback for pre-r0.12 TensorFlow. Fixed: the original called
        # the misspelled, nonexistent tf.intializer_all_varialble().
        init = tf.initialize_all_variables()
        sess.run(init)
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    counter = 0  # count of num of samples
    num_examples = 0
    for record in tf.python_io.tf_record_iterator(self.tfrecords_test):
        num_examples += 1
    # Fixed: report the file actually counted (tfrecords_test); the
    # original printed self.tfrecords.
    print("total num of patches in tfrecords", self.tfrecords_test, ": ",
          num_examples)
    # Fixed: integer division -- '/' yields a float in Python 3.
    num_batches = num_examples // self.batch_size
    print('batches per epoch: ', num_batches)
    batch_idx = 0
    current_epoch = 0
    batch_timings = []
    test_pred_rat = []
    test_gt_rat = []
    test_sid = []
    if self.load(self.save_path):
        print('load success')
    else:
        print('load failed')
    try:
        while not coord.should_stop():
            output_, y_, sid = self.sess.run(
                [self.pred_only_spec, self.rat, self.sid],
                feed_dict={self.is_valid: True})
            test_pred_rat.append(output_)
            test_gt_rat.append(y_)
            test_sid.append(sid)
            if counter % 10 == 0:
                print(counter)
            counter += 1
    except tf.errors.OutOfRangeError:
        # Queue exhausted: one full pass over the test set is done,
        # so persist everything collected.
        np.save('./test_pred_rat_spec.npy', test_pred_rat)
        np.save('./test_gt_rat_spec.npy', test_gt_rat)
        np.save('./test_sid.npy', test_sid)
        print('done training')
    finally:
        coord.request_stop()
        coord.join(threads)
def train(self, sess, config):
    """Train the dereverberation GAN with TensorBoard summaries.

    Alternates discriminator and generator updates, merging and
    writing g/d summaries every ``config.save_freq`` steps, dumping
    validation/training spectrogram images and wav files under
    ``config.save_path``, and checkpointing every 2000 steps.

    Args:
        sess: tf.Session used for variable initialization; training
            steps run on ``self.sess``.
        config: options object providing ``save_path``, ``save_freq``
            and ``epoch``.

    Returns:
        None. Stops after ``config.epoch`` epochs or when the input
        queue raises OutOfRangeError.
    """
    print('initializing...opt')
    d_opt = self.d_opt
    g_opt = self.g_opt
    try:
        init = tf.global_variables_initializer()
        sess.run(init)
    except AttributeError:
        # Fallback for pre-r0.12 TensorFlow. Fixed: the original called
        # the misspelled, nonexistent tf.intializer_all_varialble().
        init = tf.initialize_all_variables()
        sess.run(init)
    print('initializing...var')
    g_summaries = [
        self.d_fake_summary, self.d_fake_loss_summary,
        self.g_loss_summary, self.g_l2_loss_summary,
        self.g_loss_adv_summary, self.generated_wav_summary,
        self.generated_audio_summary
    ]
    d_summaries = [
        self.d_loss_summary, self.d_real_summary,
        self.d_real_loss_summary, self.nonreverb_audio_summary,
        self.nonreverb_wav_summary
    ]
    if hasattr(self, 'alpha_summ'):
        g_summaries += self.alpha_summ
    self.g_sum = tf.summary.merge(g_summaries)
    self.d_sum = tf.summary.merge(d_summaries)
    if not os.path.exists(os.path.join(config.save_path, 'train')):
        os.makedirs(os.path.join(config.save_path, 'train'))
    self.writer = tf.summary.FileWriter(
        os.path.join(config.save_path, 'train'), self.sess.graph)
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    # Pull one training and one validation batch up-front, mainly to
    # report tensor shapes before the loop starts.
    sample_reverb, sample_nonreverb, sample_z = self.sess.run(
        [self.gt_reverb[0], self.gt_nonreverb[0], self.zz[0]],
        feed_dict={self.is_valid: False})
    v_sample_reverb, v_sample_nonreverb, v_sample_z = self.sess.run(
        [self.gt_reverb[0], self.gt_nonreverb[0], self.zz[0]],
        feed_dict={
            self.is_valid: True,
            self.is_train: False
        })
    print('sample reverb shape: ', sample_reverb.shape)
    print('sample nonreverb shape: ', sample_nonreverb.shape)
    print('sample z shape: ', sample_z.shape)
    save_path = config.save_path
    counter = 0  # count of num of samples (global step)
    num_examples = 0
    for record in tf.python_io.tf_record_iterator(self.tfrecords):
        num_examples += 1
    print("total num of patches in tfrecords", self.tfrecords, ": ",
          num_examples)
    # Fixed: integer division -- '/' yields a float in Python 3.
    num_batches = num_examples // self.batch_size
    print('batches per epoch: ', num_batches)
    if self.load(self.save_path):
        print('load success')
    else:
        print('load failed')
    batch_idx = 0
    current_epoch = 0
    batch_timings = []
    g_losses = []
    d_fake_losses = []
    d_real_losses = []
    g_adv_losses = []
    g_l2_losses = []
    try:
        while not coord.should_stop():
            start = timeit.default_timer()
            if counter % config.save_freq == 0:
                # Summary-producing step: also fetch merged summaries.
                for d_iter in range(self.disc_updates):
                    _d_opt, _d_sum, d_fake_loss, d_real_loss = self.sess.run(
                        [
                            d_opt, self.d_sum, self.d_fake_losses[0],
                            self.d_real_losses[0]
                        ],
                        feed_dict={
                            self.is_valid: False,
                            self.is_train: True
                        })
                    if self.d_clip_weights:
                        # WGAN-style weight clipping.
                        # NOTE(review): nesting inferred from a
                        # whitespace-mangled source; placed inside the
                        # disc_updates loop per the usual SEGAN layout.
                        self.sess.run(self.d_clip,
                                      feed_dict={
                                          self.is_valid: False,
                                          self.is_train: True
                                      })
                _g_opt, _g_sum, g_adv_loss, g_l2_loss = self.sess.run(
                    [
                        g_opt, self.g_sum, self.g_adv_losses[0],
                        self.g_l2_losses[0]
                    ],
                    feed_dict={
                        self.is_valid: False,
                        self.is_train: True
                    })
            else:
                # Plain step without summary fetches.
                for d_iter in range(self.disc_updates):
                    _d_opt, d_fake_loss, d_real_loss = self.sess.run(
                        [
                            d_opt, self.d_fake_losses[0],
                            self.d_real_losses[0]
                        ],
                        feed_dict={
                            self.is_valid: False,
                            self.is_train: True
                        })
                    if self.d_clip_weights:
                        self.sess.run(self.d_clip,
                                      feed_dict={
                                          self.is_valid: False,
                                          self.is_train: True
                                      })
                _g_opt, g_adv_loss, g_l2_loss = self.sess.run(
                    [g_opt, self.g_adv_losses[0], self.g_l2_losses[0]],
                    feed_dict={
                        self.is_valid: False,
                        self.is_train: True
                    })
            end = timeit.default_timer()
            batch_timings.append(end - start)
            d_fake_losses.append(d_fake_loss)
            d_real_losses.append(d_real_loss)
            g_adv_losses.append(g_adv_loss)
            g_l2_losses.append(g_l2_loss)
            print('{}/{} (epoch {}), d_rl_loss = {:.5f}, '
                  'd_fk_loss = {:.5f}, '
                  'g_adv_loss = {:.5f}, g_l1_loss = {:.5f},'
                  ' time/batch = {:.5f}, '
                  'mtime/batch = {:.5f}'.format(
                      counter, config.epoch * num_batches, current_epoch,
                      d_real_loss, d_fake_loss, g_adv_loss, g_l2_loss,
                      end - start, np.mean(batch_timings)))
            batch_idx += 1
            counter += 1
            if counter % 2000 == 0 and counter > 0:
                self.save(config.save_path, counter)
            if counter % config.save_freq == 0:
                # NOTE(review): counter was incremented above, so this
                # condition is offset by one from the summary-producing
                # branch -- _g_sum/_d_sum written here are stale by
                # save_freq - 1 steps. Kept as original; confirm intent
                # before changing the schedule.
                self.writer.add_summary(_g_sum, counter)
                self.writer.add_summary(_d_sum, counter)
                # ---- validation sample dump ----
                canvas_w, s_reverb, s_nonreverb = self.sess.run(
                    [self.GG[0], self.gt_reverb[0], self.gt_nonreverb[0]],
                    feed_dict={
                        self.is_valid: True,
                        self.is_train: False
                    })
                if not os.path.exists(save_path + '/wav'):
                    os.makedirs(save_path + '/wav')
                if not os.path.exists(save_path + '/txt'):
                    os.makedirs(save_path + '/txt')
                if not os.path.exists(save_path + '/spec'):
                    os.makedirs(save_path + '/spec')
                print('max :', np.max(canvas_w[0]), 'min :',
                      np.min(canvas_w[0]))
                if self.pre_emphasis > 0:
                    # Undo the pre-emphasis filter before writing audio.
                    canvas_w = self.de_emphasis(canvas_w,
                                                self.pre_emphasis)
                    s_reverb = self.de_emphasis(s_reverb,
                                                self.pre_emphasis)
                    s_nonreverb = self.de_emphasis(s_nonreverb,
                                                   self.pre_emphasis)
                x_pr = canvas_w.flatten()
                # Keep only the first eighth of the generated signal.
                x_pr = x_pr[:int(len(x_pr) / 8)]
                x_lr = s_reverb.flatten()[:len(x_pr)]
                x_hr = s_nonreverb.flatten()[:len(x_pr)]
                Sl = self.get_spectrum(x_lr, n_fft=2048)
                Sh = self.get_spectrum(x_hr, n_fft=2048)
                Sp = self.get_spectrum(x_pr, n_fft=2048)
                S = np.concatenate(
                    (Sl.reshape(Sh.shape[0], Sh.shape[1]), Sh, Sp), axis=1)
                fig = Figure(figsize=S.shape[::-1], dpi=1, frameon=False)
                canvas = FigureCanvas(fig)
                fig.figimage(S, cmap='jet')
                fig.savefig(save_path + '/spec/' + 'valid_batch_index' +
                            str(counter) + '-th_pr.png')
                librosa.output.write_wav(
                    save_path + '/wav/' + str(counter) + '_dereverb.wav',
                    x_pr, 16000)
                librosa.output.write_wav(
                    save_path + '/wav/' + str(counter) + '_reverb.wav',
                    x_lr, 16000)
                librosa.output.write_wav(
                    save_path + '/wav/' + str(counter) + '_orig.wav',
                    x_hr, 16000)
                # ---- training sample dump (spec image only) ----
                canvas_w, s_reverb, s_nonreverb = self.sess.run(
                    [self.GG[0], self.gt_reverb[0], self.gt_nonreverb[0]],
                    feed_dict={
                        self.is_valid: False,
                        self.is_train: True
                    })
                print('max :', np.max(canvas_w[0]), 'min :',
                      np.min(canvas_w[0]))
                # NOTE(review): unlike the validation dump above, no
                # de_emphasis is applied here -- kept as original.
                x_pr = canvas_w.flatten()
                x_pr = x_pr[:int(len(x_pr) / 8)]
                x_lr = s_reverb.flatten()[:len(x_pr)]
                x_hr = s_nonreverb.flatten()[:len(x_pr)]
                Sl = self.get_spectrum(x_lr, n_fft=2048)
                Sh = self.get_spectrum(x_hr, n_fft=2048)
                Sp = self.get_spectrum(x_pr, n_fft=2048)
                S = np.concatenate(
                    (Sl.reshape(Sh.shape[0], Sh.shape[1]), Sh, Sp), axis=1)
                fig = Figure(figsize=S.shape[::-1], dpi=1, frameon=False)
                canvas = FigureCanvas(fig)
                fig.figimage(S, cmap='jet')
                fig.savefig(save_path + '/spec/' + 'train_batch_index' +
                            str(counter) + '-th_pr.png')
            if batch_idx >= num_batches:
                current_epoch += 1
                # reset batch idx
                batch_idx = 0
                if current_epoch >= config.epoch:
                    # Fixed: report the limit actually enforced
                    # (config.epoch); the original printed self.epoch.
                    print(str(config.epoch), ': epoch limit')
                    print('saving last model at iteration', str(counter))
                    self.save(config.save_path, counter)
                    self.writer.add_summary(_g_sum, counter)
                    self.writer.add_summary(_d_sum, counter)
                    break
    except tf.errors.OutOfRangeError:
        # Input queue exhausted -- normal termination path.
        print('done training')
    finally:
        coord.request_stop()
        coord.join(threads)