Example #1
0
def build_mnasnet_base(images, model_name, training, override_params=None):
    """Build a MnasNet feature extractor and return its global-pool output.

    Args:
      images: input images tensor.
      model_name: string, the model name of a pre-defined MnasNet.
      training: boolean, whether the model is constructed for training.
      override_params: A dictionary of params for overriding. Fields must exist
        in mnasnet_model.GlobalParams.

    Returns:
      features: global pool features, wrapped in a tf.identity op named
        'global_pool'.
      endpoints: the endpoints for each layer (the model's `endpoints`).

    Raises:
      NotImplementedError: when model_name specifies an undefined model
        (reported by get_model_params).
      ValueError: when override_params has invalid fields (reported by
        get_model_params).
    """
    assert isinstance(images, tf.Tensor)
    block_args_list, params = get_model_params(model_name, override_params)

    # Create the network's variables under a scope named after the model.
    with tf.variable_scope(model_name):
        net = mnasnet_model.MnasNetModel(block_args_list, params)
        # features_only=True requests the pooled feature output instead of the
        # classification logits.
        pooled = net(images, training=training, features_only=True)

    return tf.identity(pooled, 'global_pool'), net.endpoints


def build_mnasnet_model(images, model_name, training, override_params=None):
    """Build a MnasNet classifier and return its predicted logits.

    Args:
      images: input images tensor.
      model_name: string, the model name of a pre-defined MnasNet.
      training: boolean, whether the model is constructed for training.
      override_params: A dictionary of params for overriding. Fields must exist
        in mnasnet_model.GlobalParams.

    Returns:
      logits: the logits tensor of classes, wrapped in a tf.identity op named
        'logits'.
      endpoints: the endpoints for each layer (the model's `endpoints`).

    Raises:
      NotImplementedError: when model_name specifies an undefined model
        (reported by get_model_params).
      ValueError: when override_params has invalid fields (reported by
        get_model_params).
    """
    assert isinstance(images, tf.Tensor)
    block_args_list, params = get_model_params(model_name, override_params)
    with tf.variable_scope(model_name):
        net = mnasnet_model.MnasNetModel(block_args_list, params)
        raw_logits = net(images, training=training)

    # Add a leading axis and immediately remove it again; tensor values are
    # unchanged. NOTE(review): intent is not visible here (possibly a shape
    # workaround) -- confirm before removing.
    batched = tf.expand_dims(raw_logits, 0)
    named_logits = tf.identity(tf.squeeze(batched, 0), 'logits')
    return named_logits, net.endpoints
Example #3
0
def build_mnasnet_model(images, model_name, training, override_params=None):
    """A helper function to create a MnasNet model and return predicted logits.

    Args:
      images: input images tensor.
      model_name: string, the model name. 'mnasnet-a1' and 'mnasnet-b1' select
        the pre-defined MnasNet builders; any name containing 'legrnet' must be
        the name of a callable builder defined at module level in this file.
      training: boolean, whether the model is constructed for training.
      override_params: A dictionary of params for overriding. Fields must exist
        in mnasnet_model.GlobalParams.

    Returns:
      logits: the logits tensor of classes, wrapped in a tf.identity op named
        'logits'.
      endpoints: the endpoints for each layer (the model's `endpoints`).

    Raises:
      NotImplementedError: when model_name specifies an undefined model,
        including a 'legrnet' name with no matching module-level callable.
      ValueError: when override_params has fields not included in
        global_params (raised by `_replace`).
    """
    assert isinstance(images, tf.Tensor)
    is_legrnet = 'legrnet' in model_name

    if model_name == 'mnasnet-a1':
        blocks_args, global_params = mnasnet_a1()
    elif model_name == 'mnasnet-b1':
        blocks_args, global_params = mnasnet_b1()
    elif is_legrnet:
        # Look the builder up by name in this module's namespace instead of
        # using eval(). eval() would execute any expression embedded in
        # model_name, and for an unknown name it raises NameError rather than
        # the documented NotImplementedError.
        builder = globals().get(model_name)
        if not callable(builder):
            raise NotImplementedError('model name is not pre-defined: %s' %
                                      model_name)
        blocks_args, global_params = builder()
    else:
        raise NotImplementedError('model name is not pre-defined: %s' %
                                  model_name)

    if override_params:
        # ValueError will be raised here if override_params has fields not included
        # in global_params.
        global_params = global_params._replace(**override_params)

    with tf.variable_scope(model_name):
        if is_legrnet:
            model = legrnet_model.LeGRNetModel(blocks_args, global_params)
        else:
            model = mnasnet_model.MnasNetModel(blocks_args, global_params)
        logits = model(images, training=training)

    logits = tf.identity(logits, 'logits')
    return logits, model.endpoints