示例#1
0
  def train(self):
    """Run training until the step from self.get_step() reaches pr.train_iters.

    Every pass reads the current (step, lr) pair from self.get_step(), calls
    the fast/slow checkpoint hooks when the step is a multiple of their
    intervals (never on the first pass of this call), runs self.train_op
    together with the loss tensors, and prints the loss values and step time
    after passing them through moving_avg.  On every 20th step the merged
    summary is fetched in the same session call and written with
    self.sum_writer.
    """
    pr = self.pr
    history = {}  # state dict handed to every moving_avg call
    first_pass = True
    while True:
      step, lr = self.get_step()

      # Checkpoint hooks never fire on the first pass of this call.
      if not first_pass:
        if step % pr.check_iters == 0:
          self.checkpoint_fast()
        if step % pr.slow_check_iters == 0:
          self.checkpoint_slow()

      # The stop test comes after the hooks, so they can still fire on the
      # step that ends training.
      if step >= pr.train_iters:
        break

      tic = ut.now_sec()
      with_summary = (step % 20 == 0)
      if with_summary:
        lead = [self.train_op, self.merged_summary]
      else:
        lead = [self.train_op]
      fetched = self.sess.run(lead + self.loss.get_losses())
      if with_summary:
        self.sum_writer.add_summary(fetched[1], step)
      # Drop the leading non-loss results (train op, optional summary).
      loss_vals = fetched[len(lead):]
      elapsed = moving_avg('time', ut.now_sec() - tic, history)

      labels = self.loss.get_loss_names()
      report = ' '.join(
          '%s: %.3f' % (label, moving_avg(label, value, history))
          for label, value in zip(labels, loss_vals))

      # Print every one of the first ten steps, then only every print_iters.
      if step < 10 or step % pr.print_iters == 0:
        print('Iteration %d, lr = %.0e, %s, time: %.3f'
              % (step, lr, report, elapsed))
      first_pass = False
示例#2
0
    def train(self):
        """Run training until the step from self.get_step() reaches pr.train_iters.

        Every pass reads the current (step, lr) pair from self.get_step(),
        calls the fast/slow checkpoint hooks when the step is a multiple of
        their intervals (never on the first pass of this call), optionally
        runs self.show_train, and then runs self.train_op together with the
        generator and discriminator loss tensors in one of three modes:

        * summary: also fetches self.merged_summary and writes it with
          self.sum_writer (when pr.summary_iters is set and divides step);
        * profile: runs with full tracing and feeds the run metadata to
          self.profiler (when self.profile is truthy, pr.profile_iters is set
          and divides step, and this is not the first pass);
        * plain: only the train op and the losses.

        The loss values and the step time are passed through moving_avg and
        printed for the first ten steps and then every pr.print_iters steps.
        """
        pr = self.pr
        history = {}  # state dict handed to every moving_avg call
        first = True
        while True:
            step, lr = self.get_step()

            # Checkpoint hooks never fire on the first pass of this call.
            if not first:
                if step % pr.check_iters == 0:
                    self.checkpoint_fast()
                if step % pr.slow_check_iters == 0:
                    self.checkpoint_slow()

            # The stop test comes after the hooks, so they can still fire on
            # the step that ends training.
            if step >= pr.train_iters:
                break

            if pr.show_iters is not None:
                if first or step % pr.show_iters == 0:
                    self.sess.run(self.show_train)

            gen, discrim = self.gen_loss, self.discrim_loss
            loss_ops = gen.get_losses() + discrim.get_losses()
            loss_names = gen.get_loss_names() + discrim.get_loss_names()

            tic = ut.now_sec()
            want_summary = (pr.summary_iters is not None
                            and step % pr.summary_iters == 0)
            if want_summary:
                fetched = self.sess.run(
                    [self.train_op, self.merged_summary] + loss_ops)
                self.sum_writer.add_summary(fetched[1], step)
                loss_vals = fetched[2:]
            elif (self.profile and pr.profile_iters is not None
                  and not first and step % pr.profile_iters == 0):
                run_meta = tf.RunMetadata()
                trace_opts = tf.RunOptions(
                    trace_level=tf.RunOptions.FULL_TRACE)
                fetched = self.sess.run([self.train_op] + loss_ops,
                                        options=trace_opts,
                                        run_metadata=run_meta)
                loss_vals = fetched[1:]
                prof_opts = tf.profiler.ProfileOptionBuilder.time_and_memory()
                profiler = self.profiler
                profiler.add_step(step, run_meta)
                profiler.profile_operations(options=prof_opts)
                profiler.profile_graph(options=prof_opts)
                # NOTE(review): self.profiler is created outside this view; if
                # it is a tf.profiler.Profiler, its advise() documents an
                # options dict rather than a graph -- confirm this call.
                profiler.advise(self.sess.graph)
            else:
                loss_vals = self.sess.run([self.train_op] + loss_ops)[1:]

            # Runs before the timer is read, so the collection time is part
            # of the reported step time on these steps.
            if step % 100 == 0:
                gc.collect()

            elapsed = moving_avg('time', ut.now_sec() - tic, history)
            report = ' '.join(
                '%s: %.3f' % (label, moving_avg(label, value, history))
                for label, value in zip(loss_names, loss_vals))

            # Print every one of the first ten steps, then only every
            # print_iters.
            if step < 10 or step % pr.print_iters == 0:
                print('Iteration %d, lr = %.0e, %s, time: %.3f'
                      % (step, lr, report, elapsed))

            first = False