def build_model(images, training, override_params=None, arch=None):
  """A helper function that creates a ConvNet model and returns predicted logits.

  The network architecture is read from the text file at `arch`. Each line is
  stripped of surrounding whitespace, and the result is handed to
  `parse_netarch_string` to obtain the block and global parameters.

  Args:
    images: input images tensor.
    training: boolean, whether the model is constructed for training.
    override_params: A dictionary of params for overriding. Fields must exist in
      model_def.GlobalParams.
    arch: path to a text file describing the network architecture. Although it
      defaults to None, it is required; a missing path is rejected.

  Returns:
    logits: the logits tensor of classes.
    endpoints: the endpoints for each layer.
    macs: the model's second output divided by 1e6 (i.e. in millions).

  Raises:
    TypeError: if `images` is not a `tf.Tensor`.
    ValueError: if `arch` is None or does not name an existing file.
    ValueError: if override_params has invalid fields (raised by `_replace`).
  """
  # Explicit checks instead of `assert`, so validation survives `python -O`.
  if not isinstance(images, tf.Tensor):
    raise TypeError('images must be a tf.Tensor, got %s' % type(images))
  # Checked before os.path.isfile, which would otherwise raise an opaque
  # TypeError for None.
  if arch is None:
    raise ValueError('arch must be the path to an architecture file.')
  if not os.path.isfile(arch):
    raise ValueError('architecture file not found: %s' % (arch,))

  with open(arch, 'r', encoding='utf-8') as f:
    lines = [line.strip() for line in f]
  blocks_args, global_params = parse_netarch_string(lines)

  if override_params:
    # ValueError will be raised here if override_params has fields not
    # included in global_params.
    global_params = global_params._replace(**override_params)

  with tf.variable_scope('single-path'):
    model = model_def.MnasNetModel(blocks_args, global_params)
    logits, macs = model(images, training=training)

  macs /= 1e6  # macs to M
  logits = tf.identity(logits, 'logits')
  return logits, model.endpoints, macs
def build_model(images, model_name, training, override_params=None,
                parse_search_dir=None):
  """A helper function that creates a ConvNet model and returns predicted logits.

  Args:
    images: input images tensor.
    model_name: string, the model name of a pre-defined MnasNet. Only
      'single-path' is supported.
    training: boolean, whether the model is constructed for training.
    override_params: A dictionary of params for overriding. Fields must exist in
      model_def.GlobalParams.
    parse_search_dir: location passed to `parse_netarch_model` to obtain the
      architecture. Required when model_name is 'single-path'.

  Returns:
    logits: the logits tensor of classes.
    endpoints: the endpoints for each layer.

  Raises:
    TypeError: if `images` is not a `tf.Tensor`.
    ValueError: if model_name is 'single-path' and parse_search_dir is None.
    NotImplementedError: when model_name specifies an undefined model.
    ValueError: when override_params has invalid fields.
  """
  # Explicit checks instead of `assert`, so validation survives `python -O`.
  if not isinstance(images, tf.Tensor):
    raise TypeError('images must be a tf.Tensor, got %s' % type(images))

  if model_name == 'single-path':
    if parse_search_dir is None:
      raise ValueError(
          "parse_search_dir is required for model 'single-path'.")
    blocks_args, global_params = parse_netarch_model(parse_search_dir)
  else:
    # 1-tuple so a non-string model_name (e.g. a tuple) formats correctly.
    raise NotImplementedError(
        'model name is not pre-defined: %s' % (model_name,))

  if override_params:
    # ValueError will be raised here if override_params has fields not
    # included in global_params.
    global_params = global_params._replace(**override_params)

  with tf.variable_scope(model_name):
    model = model_def.MnasNetModel(blocks_args, global_params)
    logits = model(images, training=training)

  logits = tf.identity(logits, 'logits')
  return logits, model.endpoints