def forward_fn(inputs, is_train):
    """Build the MobileNet variant selected by FLAGS.mobilenet_version.

    FLAGS.nb_classes and FLAGS.mobilenet_depth_mult are passed to the chosen
    builder as `num_classes` and `depth_multiplier`. The builder is called
    inside the matching TF-Slim arg scope.

    Args:
    * inputs: inputs to the network's forward pass
    * is_train: whether to build the forward pass with training operations
      inserted; it is passed to the arg-scope factory (and, for v1, to the
      builder as well)

    Returns:
    * outputs: the first element of the pair returned by the selected builder

    Raises:
    * ValueError: if FLAGS.mobilenet_version is neither 1 nor 2
    """
    num_classes = FLAGS.nb_classes
    multiplier = FLAGS.mobilenet_depth_mult
    version = FLAGS.mobilenet_version

    if version == 1:
        make_scope = MobileNetV1.mobilenet_v1_arg_scope
        with slim.arg_scope(make_scope(is_training=is_train)):  # pylint: disable=not-context-manager
            result, _ = MobileNetV1.mobilenet_v1(
                inputs,
                is_training=is_train,
                num_classes=num_classes,
                depth_multiplier=multiplier)
        return result

    if version == 2:
        make_scope = MobileNetV2.training_scope
        with slim.arg_scope(make_scope(is_training=is_train)):  # pylint: disable=not-context-manager
            # The v2 builder is not handed `is_training` directly; the flag
            # only reaches it through the arg scope built above.
            result, _ = MobileNetV2.mobilenet(
                inputs,
                num_classes=num_classes,
                depth_multiplier=multiplier)
        return result

    raise ValueError('invalid MobileNet version: {}'.format(version))
def mobilenetv2_head(inputs, is_training=True):
    """Apply the MobileNetV2 head layers (V2_HEAD_DEF) to `inputs`.

    The builder runs under `mobilenetv2_scope(is_training=is_training,
    trainable=True)`, with no classifier (`num_classes=None`) and
    depth multiplier 1.0. Axes 1 and 2 of its output are then squeezed away.

    Args:
        inputs: input tensor fed to the head; presumably a feature map coming
            from an earlier part of the network (TODO confirm with callers).
        is_training: forwarded only to `mobilenetv2_scope`.

    Returns:
        The builder's output tensor with axes 1 and 2 removed. This assumes
        both axes have size 1 (e.g. after global pooling); tf.squeeze fails
        otherwise. TODO confirm.
    """
    head_arg_scope = mobilenetv2_scope(is_training=is_training, trainable=True)
    with slim.arg_scope(head_arg_scope):
        head_out, _ = mobilenet_v2.mobilenet(
            input_tensor=inputs,
            num_classes=None,
            # NOTE(review): hard-coded False regardless of the `is_training`
            # argument, which only reaches the arg scope above. Possibly
            # deliberate (e.g. frozen batch norm) -- confirm before changing.
            is_training=False,
            depth_multiplier=1.0,
            scope='MobilenetV2',
            conv_defs=V2_HEAD_DEF,
            finegrain_classification_mode=False)
        flattened = tf.squeeze(head_out, [1, 2])
    return flattened