예제 #1
0
 def testBuildEmptyOptimizer(self):
     """Building from an Optimizer proto with no fields set raises ValueError."""
     # Whitespace-only text proto: parses to a default Optimizer message.
     optimizer_text_proto = """
 """
     global_summaries = set()
     optimizer_proto = optimizer_pb2.Optimizer()
     text_format.Merge(optimizer_text_proto, optimizer_proto)
     with self.assertRaises(ValueError):
         optimizer_builder.build(optimizer_proto, global_summaries)
예제 #2
0
 def testBuildAdamOptimizer(self):
     """An adam_optimizer config without moving average builds an AdamOptimizer."""
     optimizer_text_proto = """
   adam_optimizer: {
     learning_rate: {
       constant_learning_rate {
         learning_rate: 0.002
       }
     }
   }
   use_moving_average: false
 """
     global_summaries = set()
     optimizer_proto = optimizer_pb2.Optimizer()
     text_format.Merge(optimizer_text_proto, optimizer_proto)
     optimizer_object = optimizer_builder.build(optimizer_proto,
                                                global_summaries)
     # assertIsInstance reports the actual type on failure, unlike
     # assertTrue(isinstance(...)) which only reports "False is not true".
     self.assertIsInstance(optimizer_object, tf.train.AdamOptimizer)
예제 #3
0
 def testBuildMovingAverageOptimizerWithNonDefaultDecay(self):
     """use_moving_average builds a MovingAverageOptimizer with the given decay.

     The config sets moving_average_decay to 0.2; the test checks that this
     value is what the built optimizer's EMA ends up using.
     """
     optimizer_text_proto = """
   adam_optimizer: {
     learning_rate: {
       constant_learning_rate {
         learning_rate: 0.002
       }
     }
   }
   use_moving_average: True
   moving_average_decay: 0.2
 """
     global_summaries = set()
     optimizer_proto = optimizer_pb2.Optimizer()
     text_format.Merge(optimizer_text_proto, optimizer_proto)
     optimizer_object = optimizer_builder.build(optimizer_proto,
                                                global_summaries)
     self.assertIsInstance(optimizer_object,
                           tf.contrib.opt.MovingAverageOptimizer)
     # TODO: Find a way to not depend on the private members.
     # The decay is only reachable through the private `_ema` attribute here;
     # assertAlmostEqual tolerates float rounding of the configured 0.2.
     self.assertAlmostEqual(optimizer_object._ema._decay, 0.2)
예제 #4
0
 def testBuildRMSPropOptimizer(self):
     """An rms_prop_optimizer config builds an RMSPropOptimizer.

     The config uses an exponential-decay learning rate and moving average
     disabled; only the type of the built optimizer is checked.
     """
     optimizer_text_proto = """
   rms_prop_optimizer: {
     learning_rate: {
       exponential_decay_learning_rate {
         initial_learning_rate: 0.004
         decay_steps: 800720
         decay_factor: 0.95
       }
     }
     momentum_optimizer_value: 0.9
     decay: 0.9
     epsilon: 1.0
   }
   use_moving_average: false
 """
     global_summaries = set()
     optimizer_proto = optimizer_pb2.Optimizer()
     text_format.Merge(optimizer_text_proto, optimizer_proto)
     optimizer_object = optimizer_builder.build(optimizer_proto,
                                                global_summaries)
     self.assertIsInstance(optimizer_object, tf.train.RMSPropOptimizer)