예제 #1
0
 def test_create_from_params(self):
   """Builds an EVAL-mode plan from params and checks its derived settings.

   Five examples are supplied but `truncate_examples` is 3, so only the first
   three should remain. The plan's `logdir` is expected to be `<rundir>/eval`,
   its `logdir_restore` to be `<rundir>/train`, and `compute_summaries` to be
   True.
   """
   plan_dir = os.path.join('/tmp/', 'plan')
   run_dir = os.path.join(plan_dir, 'run_0')

   params = plan.plan_default_params()
   overrides = {
       'mode': plan.Plan.mode_keys.EVAL,
       'truncate_examples': 3,
   }
   params.update(overrides)

   # Collaborators are created in the same order as before (compiler, then
   # loss tensor, then examples) so graph construction order is unchanged.
   compiler = block_compiler.Compiler.create(blocks.Scalar())
   losses = {'foo': tf.constant(42.0)}
   setup = _setup_plan(compiler=compiler, losses=losses, examples=xrange(5))
   eval_plan = plan.Plan.create_from_params(setup, params)

   self.assertEqual(os.path.join(run_dir, 'eval'), eval_plan.logdir)
   self.assertEqual(os.path.join(run_dir, 'train'), eval_plan.logdir_restore)
   self.assertEqual(run_dir, eval_plan.rundir)
   self.assertEqual(plan_dir, eval_plan.plandir)
   self.assertEqual([0, 1, 2], list(eval_plan.examples))
   self.assertEqual(True, eval_plan.compute_summaries)
예제 #2
0
    def test_create_from_params(self):
        """Checks TRAIN-mode plans built from params and from flags.

        First builds a plan from explicit params and verifies:
          * `num_multiprocess_processes`, `master` and `batches_per_epoch`
            are passed through;
          * the directories under /tmp/plan/run_0 are as expected;
          * examples are truncated to three;
          * one run of `train_op` lowers `loss_total` below its initial
            value of 4.

        Then builds a plan via `create_from_flags` with
        `num_multiprocess_processes=None` and `task=42` and checks that:
          * `compute_summaries` and `is_chief_trainer` are False;
          * the supplied `train_op` is kept.

        The global flag values overridden here are restored afterwards.
        """
        params = plan.plan_default_params()
        params.update({
            'mode': plan.Plan.mode_keys.TRAIN,
            'truncate_examples': 3,
            'num_multiprocess_processes': 4,
            'master': 'foo',
            'batches_per_epoch': 123
        })
        foo = tf.get_variable('foo', [], tf.float32,
                              tf.constant_initializer(4))
        p = plan.Plan.create_from_params(
            _setup_plan(compiler=block_compiler.Compiler.create(
                blocks.Scalar()),
                        losses={'foo': foo},
                        examples=xrange(5)), params)
        self.assertEqual(p.num_multiprocess_processes, 4)
        self.assertEqual(p.master, 'foo')
        self.assertEqual(p.batches_per_epoch, 123)
        self.assertEqual(p.compute_summaries, True)
        self.assertEqual(p.is_chief_trainer, True)
        self.assertEqual(p.logdir,
                         os.path.join('/tmp/', 'plan', 'run_0', 'train'))
        self.assertEqual(p.rundir, os.path.join('/tmp/', 'plan', 'run_0'))
        self.assertEqual(p.plandir, os.path.join('/tmp/', 'plan'))
        self.assertEqual([0, 1, 2], list(p.examples))
        with self.test_session() as sess:
            sess.run(tf.global_variables_initializer())
            self.assertEqual(4, p.loss_total.eval())
            sess.run(p.train_op)  # should make loss smaller
            self.assertLess(p.loss_total.eval(), 4)

        # tf.flags.FLAGS is module-level (process-global) state. Save the
        # values overridden below and restore them even if an assertion fails,
        # so task=42 etc. cannot leak into tests that run afterwards. The
        # assertions stay inside the try in case the plan reads flags lazily.
        flags = tf.flags.FLAGS
        saved_num_processes = flags.num_multiprocess_processes
        saved_task = flags.task
        try:
            flags.num_multiprocess_processes = None
            flags.task = 42
            train_op = tf.no_op()
            p = plan.Plan.create_from_flags(
                _setup_plan(compiler=block_compiler.Compiler.create(
                    blocks.Scalar()),
                            losses={'foo': tf.constant(3.14)},
                            train_op=train_op,
                            examples=xrange(5)))
            self.assertIsNone(p.num_multiprocess_processes)
            self.assertEqual(p.compute_summaries, False)
            self.assertEqual(p.is_chief_trainer, False)
            self.assertEqual(p.train_op, train_op)
        finally:
            flags.num_multiprocess_processes = saved_num_processes
            flags.task = saved_task
예제 #3
0
  def test_create_from_params(self):
    """Checks TRAIN-mode plans built from params and from flags.

    First builds a plan from explicit params and verifies:
      * `num_multiprocess_processes`, `master` and `batches_per_epoch` are
        passed through;
      * the directories under /tmp/plan/run_0 are as expected;
      * examples are truncated to three;
      * one run of `train_op` lowers `loss_total` below its initial value
        of 4.

    Then builds a plan via `create_from_flags` with
    `num_multiprocess_processes=None` and `task=42` and checks that:
      * `compute_summaries` and `is_chief_trainer` are False;
      * the supplied `train_op` is kept.

    The global flag values overridden here are restored afterwards.
    """
    params = plan.plan_default_params()
    params.update({
        'mode': plan.Plan.mode_keys.TRAIN,
        'truncate_examples': 3,
        'num_multiprocess_processes': 4,
        'master': 'foo',
        'batches_per_epoch': 123})
    foo = tf.get_variable('foo', [], tf.float32, tf.constant_initializer(4))
    p = plan.Plan.create_from_params(_setup_plan(
        compiler=block_compiler.Compiler.create(blocks.Scalar()),
        losses={'foo': foo},
        examples=xrange(5)), params)
    self.assertEqual(p.num_multiprocess_processes, 4)
    self.assertEqual(p.master, 'foo')
    self.assertEqual(p.batches_per_epoch, 123)
    self.assertEqual(p.compute_summaries, True)
    self.assertEqual(p.is_chief_trainer, True)
    self.assertEqual(p.logdir, os.path.join('/tmp/', 'plan', 'run_0', 'train'))
    self.assertEqual(p.rundir, os.path.join('/tmp/', 'plan', 'run_0'))
    self.assertEqual(p.plandir, os.path.join('/tmp/', 'plan'))
    self.assertEqual([0, 1, 2], list(p.examples))
    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      self.assertEqual(4, p.loss_total.eval())
      sess.run(p.train_op)  # should make loss smaller
      self.assertLess(p.loss_total.eval(), 4)

    # tf.flags.FLAGS is module-level (process-global) state. Save the values
    # overridden below and restore them even if an assertion fails, so
    # task=42 etc. cannot leak into tests that run afterwards. The assertions
    # stay inside the try in case the plan reads flags lazily.
    flags = tf.flags.FLAGS
    saved_num_processes = flags.num_multiprocess_processes
    saved_task = flags.task
    try:
      flags.num_multiprocess_processes = None
      flags.task = 42
      train_op = tf.no_op()
      p = plan.Plan.create_from_flags(_setup_plan(
          compiler=block_compiler.Compiler.create(blocks.Scalar()),
          losses={'foo': tf.constant(3.14)},
          train_op=train_op,
          examples=xrange(5)))
      self.assertIsNone(p.num_multiprocess_processes)
      self.assertEqual(p.compute_summaries, False)
      self.assertEqual(p.is_chief_trainer, False)
      self.assertEqual(p.train_op, train_op)
    finally:
      flags.num_multiprocess_processes = saved_num_processes
      flags.task = saved_task