def build_mnasnet_base(images, model_name, training, override_params=None):
  """Builds a MnasNet feature extractor and returns its global-pool output.

  The model is constructed with `features_only=True`, so no classification
  head is applied; the pooled feature tensor is returned instead of logits.

  Args:
    images: input images tensor.
    model_name: string, the model name of a pre-defined MnasNet.
    training: boolean, whether the model is constructed for training.
    override_params: A dictionary of params for overriding. Fields must exist
      in mnasnet_model.GlobalParams.

  Returns:
    A `(features, endpoints)` tuple: `features` is the pooled feature tensor
    (wrapped in an identity op named 'global_pool') and `endpoints` is the
    model's per-layer endpoints.

  Raises:
    When model_name specified an undefined model, raises NotImplementedError.
    When override_params has invalid fields, raises ValueError.
  """
  assert isinstance(images, tf.Tensor)
  blocks_args, global_params = get_model_params(model_name, override_params)

  with tf.variable_scope(model_name):
    net = mnasnet_model.MnasNetModel(blocks_args, global_params)
    pooled = net(images, training=training, features_only=True)

  # Give the output a stable, well-known op name for downstream lookup.
  return tf.identity(pooled, 'global_pool'), net.endpoints
def build_mnasnet_model(images, model_name, training, override_params=None):
  """Builds a MnasNet classifier and returns its predicted logits.

  NOTE(review): this module defines `build_mnasnet_model` a second time
  further down; at import time that later definition replaces this one.

  Args:
    images: input images tensor.
    model_name: string, the model name of a pre-defined MnasNet.
    training: boolean, whether the model is constructed for training.
    override_params: A dictionary of params for overriding. Fields must exist
      in mnasnet_model.GlobalParams.

  Returns:
    A `(logits, endpoints)` tuple: `logits` is the class-logits tensor
    (wrapped in an identity op named 'logits') and `endpoints` is the
    model's per-layer endpoints.

  Raises:
    When model_name specified an undefined model, raises NotImplementedError.
    When override_params has invalid fields, raises ValueError.
  """
  assert isinstance(images, tf.Tensor)
  blocks_args, global_params = get_model_params(model_name, override_params)

  with tf.variable_scope(model_name):
    net = mnasnet_model.MnasNetModel(blocks_args, global_params)
    raw_logits = net(images, training=training)

  # NOTE(review): expand_dims followed by squeeze on axis 0 leaves values and
  # shape unchanged; the purpose of this round trip is unclear — confirm
  # before removing it.
  round_tripped = tf.squeeze(tf.expand_dims(raw_logits, 0), 0)
  return tf.identity(round_tripped, 'logits'), net.endpoints
def build_mnasnet_model(images, model_name, training, override_params=None):
  """Creates a MnasNet (or LeGRNet) model and returns its predicted logits.

  Args:
    images: input images tensor.
    model_name: string naming the model to build. Either 'mnasnet-a1',
      'mnasnet-b1', or the name of a module-level function containing
      'legrnet' that returns a `(blocks_args, global_params)` pair.
    training: boolean, whether the model is constructed for training.
    override_params: A dictionary of params for overriding. Fields must exist
      in mnasnet_model.GlobalParams.

  Returns:
    logits: the logits tensor of classes (identity op named 'logits').
    endpoints: the endpoints for each layer.

  Raises:
    NotImplementedError: if model_name does not name a pre-defined model
      (including a 'legrnet' name with no matching callable in this module).
    ValueError: if override_params has fields not included in global_params.
  """
  assert isinstance(images, tf.Tensor)
  is_legrnet = 'legrnet' in model_name

  if model_name == 'mnasnet-a1':
    blocks_args, global_params = mnasnet_a1()
  elif model_name == 'mnasnet-b1':
    blocks_args, global_params = mnasnet_b1()
  elif is_legrnet:
    # Resolve the builder by name from this module's globals rather than with
    # eval(): eval would execute any expression embedded in model_name, which
    # may come from user-controlled configuration such as a command-line flag.
    model_fn = globals().get(model_name)
    if not callable(model_fn):
      raise NotImplementedError(
          'model name is not pre-defined: %s' % model_name)
    blocks_args, global_params = model_fn()
  else:
    raise NotImplementedError('model name is not pre-defined: %s' % model_name)

  if override_params:
    # ValueError will be raised here if override_params has fields not
    # included in global_params (namedtuple._replace rejects unknown fields).
    global_params = global_params._replace(**override_params)

  with tf.variable_scope(model_name):
    if is_legrnet:
      model = legrnet_model.LeGRNetModel(blocks_args, global_params)
    else:
      model = mnasnet_model.MnasNetModel(blocks_args, global_params)
    logits = model(images, training=training)

  logits = tf.identity(logits, 'logits')
  return logits, model.endpoints