Beispiel #1
0
 def testFullyConvolutionalEndpointShapes(self):
   """Checks the static shapes of selected endpoints of the small PNASNet.

   Builds the network on a 2x321x321x3 test input and compares the static
   shape of each listed endpoint against its expected shape list.
   """
   num_classes = 10
   inputs = create_test_input(2, 321, 321, 3)
   with slim.arg_scope(nas_network.nas_arg_scope()):
     _, end_points = self._pnasnet_small(inputs, num_classes)
     endpoint_to_shape = {
         'Stem': [2, 81, 81, 128],
         'Cell_0': [2, 41, 41, 100],
         'Cell_1': [2, 21, 21, 200],
         'Cell_2': [2, 21, 21, 200]}
     # dict.iteritems() was removed in Python 3; items() works on 2 and 3.
     for endpoint, shape in endpoint_to_shape.items():
       # msg identifies which endpoint mismatched when the assertion fails.
       self.assertListEqual(end_points[endpoint].get_shape().as_list(), shape,
                            msg=endpoint)
Beispiel #2
0
 def testFullyConvolutionalEndpointShapes(self):
     """Checks the static shapes of selected endpoints of the small PNASNet.

     Builds the network on a 2x321x321x3 test input and compares the static
     shape of each listed endpoint against its expected shape list.
     """
     num_classes = 10
     inputs = create_test_input(2, 321, 321, 3)
     with slim.arg_scope(nas_network.nas_arg_scope()):
         _, end_points = self._pnasnet_small(inputs, num_classes)
         endpoint_to_shape = {
             'Stem': [2, 81, 81, 128],
             'Cell_0': [2, 41, 41, 100],
             'Cell_1': [2, 21, 21, 200],
             'Cell_2': [2, 21, 21, 200]
         }
         # dict.iteritems() was removed in Python 3; items() works on 2 and 3.
         for endpoint, shape in endpoint_to_shape.items():
             # msg identifies which endpoint mismatched on failure.
             self.assertListEqual(
                 end_points[endpoint].get_shape().as_list(), shape,
                 msg=endpoint)
Beispiel #3
0
 def testFullyConvolutionalEndpointShapes(self):
   """Verifies endpoint static shapes of a PNASNet built from a backbone list.

   The test input's first dimension is left as None (presumably an unknown
   batch size), so every expected shape keeps None in position 0. Endpoints
   are checked in the order listed below.
   """
   n_classes = 10
   cell_layout = [0, 0, 0, 1, 2, 1, 2, 2, 3, 3, 2, 1]
   images = create_test_input(None, 321, 321, 3)
   with slim.arg_scope(nas_network.nas_arg_scope()):
     _, nets = self._pnasnet(images, cell_layout, n_classes)
     expected = [
         ('Stem', [None, 81, 81, 128]),
         ('Cell_0', [None, 81, 81, 50]),
         ('Cell_1', [None, 81, 81, 50]),
         ('Cell_2', [None, 81, 81, 50]),
         ('Cell_3', [None, 41, 41, 100]),
         ('Cell_4', [None, 21, 21, 200]),
         ('Cell_5', [None, 41, 41, 100]),
         ('Cell_6', [None, 21, 21, 200]),
         ('Cell_7', [None, 21, 21, 200]),
         ('Cell_8', [None, 11, 11, 400]),
         ('Cell_9', [None, 11, 11, 400]),
         ('Cell_10', [None, 21, 21, 200]),
         ('Cell_11', [None, 41, 41, 100]),
     ]
     for name, want in expected:
       got = nets[name].get_shape().as_list()
       self.assertListEqual(got, want)