def train(self):
    """Run the optimisation loop until the step counter reaches ``pr.train_iters``.

    Every pass re-reads ``(step, lr)`` from ``self.get_step()``.  On every
    pass except the first one of this call, ``checkpoint_fast`` /
    ``checkpoint_slow`` are invoked when ``step`` is a multiple of
    ``pr.check_iters`` / ``pr.slow_check_iters``; this check happens before
    the termination test, so it also applies to the final step.

    Each pass runs ``self.train_op`` together with the loss tensors; on
    steps divisible by 20 the merged summary is fetched as well and written
    through ``self.sum_writer``.  Loss values and the wall-clock time of the
    session call are passed through ``moving_avg`` and printed for the first
    ten steps and then every ``pr.print_iters`` steps.
    """
    pr = self.pr
    running = {}  # history dict shared by every moving_avg call
    first_pass = True

    while True:
        step, lr = self.get_step()

        # Skip checkpointing on the very first pass of this call.
        if not first_pass:
            if step % pr.check_iters == 0:
                self.checkpoint_fast()
            if step % pr.slow_check_iters == 0:
                self.checkpoint_slow()

        if step >= pr.train_iters:
            break

        t0 = ut.now_sec()
        with_summary = (step % 20 == 0)
        head = [self.train_op]
        if with_summary:
            head.append(self.merged_summary)
        fetched = self.sess.run(head + self.loss.get_losses())
        if with_summary:
            self.sum_writer.add_summary(fetched[1], step)
        # Everything after the non-loss fetches is a loss value.
        loss_vals = fetched[len(head):]
        avg_time = moving_avg('time', ut.now_sec() - t0, running)

        report = ' '.join([
            '%s: %.3f' % (label, moving_avg(label, value, running))
            for label, value in zip(self.loss.get_loss_names(), loss_vals)
        ])

        if step < 10 or step % pr.print_iters == 0:
            print('Iteration %d, lr = %.0e, %s, time: %.3f'
                  % (step, lr, report, avg_time))

        first_pass = False
def train(self):
    """Run the generator/discriminator training loop up to ``pr.train_iters``.

    Every pass re-reads ``(step, lr)`` from ``self.get_step()``.  On every
    pass except the first one of this call, ``checkpoint_fast`` /
    ``checkpoint_slow`` are invoked when ``step`` is a multiple of
    ``pr.check_iters`` / ``pr.slow_check_iters``; the check precedes the
    termination test, so it also applies to the final step.

    Per pass, in order:
      * if ``pr.show_iters`` is set, ``self.show_train`` is run on the first
        pass and on every multiple of ``pr.show_iters``;
      * ``self.train_op`` is run with the concatenated generator and
        discriminator losses, in one of three modes: with the merged
        summary (multiples of ``pr.summary_iters``), with full tracing fed
        to ``self.profiler`` (``self.profile`` truthy, not the first pass,
        multiples of ``pr.profile_iters``), or plain;
      * ``gc.collect()`` is called on multiples of 100;
      * losses and elapsed time go through ``moving_avg`` and are printed
        for the first ten steps and every ``pr.print_iters`` steps.
    """
    pr = self.pr
    running = {}  # history dict shared by every moving_avg call
    first_pass = True

    while True:
        step, lr = self.get_step()

        # Skip checkpointing on the very first pass of this call.
        if not first_pass:
            if step % pr.check_iters == 0:
                self.checkpoint_fast()
            if step % pr.slow_check_iters == 0:
                self.checkpoint_slow()

        if step >= pr.train_iters:
            break

        if pr.show_iters is not None:
            if first_pass or step % pr.show_iters == 0:
                self.sess.run(self.show_train)

        gen, discrim = self.gen_loss, self.discrim_loss
        loss_ops = gen.get_losses() + discrim.get_losses()
        loss_names = gen.get_loss_names() + discrim.get_loss_names()

        t0 = ut.now_sec()
        if pr.summary_iters is not None and step % pr.summary_iters == 0:
            fetched = self.sess.run(
                [self.train_op, self.merged_summary] + loss_ops)
            self.sum_writer.add_summary(fetched[1], step)
            loss_vals = fetched[2:]
        elif (self.profile
              and pr.profile_iters is not None
              and not first_pass
              and step % pr.profile_iters == 0):
            # Traced run: collect metadata and hand it to the profiler.
            run_meta = tf.RunMetadata()
            trace_opts = tf.RunOptions(
                trace_level=tf.RunOptions.FULL_TRACE)
            fetched = self.sess.run(
                [self.train_op] + loss_ops,
                options=trace_opts,
                run_metadata=run_meta)
            loss_vals = fetched[1:]
            opts = tf.profiler.ProfileOptionBuilder.time_and_memory()
            self.profiler.add_step(step, run_meta)
            self.profiler.profile_operations(options=opts)
            self.profiler.profile_graph(options=opts)
            self.profiler.advise(self.sess.graph)
        else:
            fetched = self.sess.run([self.train_op] + loss_ops)
            loss_vals = fetched[1:]

        # NOTE(review): kept at loop level (runs regardless of which branch
        # above was taken) and inside the timed region -- confirm intent.
        if step % 100 == 0:
            gc.collect()

        avg_time = moving_avg('time', ut.now_sec() - t0, running)

        report = ' '.join([
            '%s: %.3f' % (label, moving_avg(label, value, running))
            for label, value in zip(loss_names, loss_vals)
        ])

        if step < 10 or step % pr.print_iters == 0:
            print('Iteration %d, lr = %.0e, %s, time: %.3f'
                  % (step, lr, report, avg_time))

        first_pass = False