Ejemplo n.º 1
0
 def testOverrideHParamsCifarModel(self):
   """Builds the CIFAR model with the `data_format` hparam set to 'NCHW'.

   The 'Stem' end point is expected to report the shape
   [batch_size, 96, 32, 32] for 32x32 inputs with 3 channels.
   """
   num_images, num_labels = 5, 10
   side = 32
   images = tf.random_uniform((num_images, side, side, 3))
   tf.train.create_global_step()

   # Set the data_format hparam on the config before building the network.
   hparams = nasnet.cifar_config()
   hparams.set_hparam('data_format', 'NCHW')
   with slim.arg_scope(nasnet.nasnet_cifar_arg_scope()):
     _, endpoint_map = nasnet.build_nasnet_cifar(
         images, num_labels, config=hparams)

   stem_shape = endpoint_map['Stem'].shape.as_list()
   self.assertListEqual(stem_shape, [num_images, 96, side, side])
Ejemplo n.º 2
0
 def testNoAuxHeadCifarModel(self):
   """Checks 'AuxLogits' is in end_points exactly when use_aux_head is set.

   The model is built once with `use_aux_head` enabled and once disabled,
   each time in a freshly reset default graph.
   """
   image_shape = (5, 32, 32, 3)
   num_labels = 10
   for aux_enabled in (True, False):
     # Reset so the previous iteration's ops do not linger in the graph.
     tf.reset_default_graph()
     images = tf.random_uniform(image_shape)
     tf.train.create_global_step()

     hparams = nasnet.cifar_config()
     # The hparam is written as an int (1 / 0) rather than a bool.
     hparams.set_hparam('use_aux_head', int(aux_enabled))
     with slim.arg_scope(nasnet.nasnet_cifar_arg_scope()):
       _, endpoint_map = nasnet.build_nasnet_cifar(
           images, num_labels, config=hparams)

     has_aux_logits = 'AuxLogits' in endpoint_map
     self.assertEqual(has_aux_logits, aux_enabled)
Ejemplo n.º 3
0
 def testUseBoundedAcitvationCifarModel(self):
   """Checks which Relu op type the graph uses for each hparam setting.

   For every graph node whose op type starts with 'Relu', the op must be
   'Relu6' when `use_bounded_activation` is True and must not be 'Relu6'
   when it is False.
   """
   image_shape = (1, 32, 32, 3)
   num_labels = 10
   for bounded in (True, False):
     # Reset so only the ops of the current configuration are inspected.
     tf.reset_default_graph()
     images = tf.random_uniform(image_shape)

     hparams = nasnet.cifar_config()
     hparams.set_hparam('use_bounded_activation', bounded)
     with slim.arg_scope(nasnet.nasnet_cifar_arg_scope()):
       # Only the resulting graph matters here; both outputs are discarded.
       _unused_logits, _unused_end_points = nasnet.build_nasnet_cifar(
           images, num_labels, config=hparams)

     graph_def = tf.get_default_graph().as_graph_def()
     relu_op_types = [
         node.op for node in graph_def.node if node.op.startswith('Relu')
     ]
     for op_type in relu_op_types:
       self.assertEqual(op_type == 'Relu6', bounded)