def test_mobilenet_weights(self):
    """Download the slim ``mobilenet_v1_0.25_128`` folder and check its exact file listing."""
    model = 'mobilenet_v1_0.25_128'
    # Every file the S3 folder is expected to contain, expressed as suffixes of the model name.
    suffixes = (
        '.ckpt.data-00000-of-00001',
        '.ckpt.index',
        '.ckpt.meta',
        '.tflite',
        '_eval.pbtxt',
        '_frozen.pb',
        '_info.txt',
    )
    wanted = {model + suffix for suffix in suffixes}
    with tempfile.TemporaryDirectory() as scratch_dir:
        s3.download_folder(f"slim/{model}", scratch_dir)
        present = set(os.listdir(scratch_dir))
        assert present == wanted
def _find_model_weights(model_name):
    """Return the TensorFlow checkpoint prefix for the TF-Slim model ``model_name``.

    The local directory is ``$CM_TSLIM_WEIGHTS_DIR/<model_name>``. It defaults to
    ``$CM_HOME/model-weights/slim/<model_name>``, and ``$CM_HOME`` defaults to
    ``~/.candidate_models``. If that directory holds no ``*.ckpt*`` file, the S3
    folder ``slim/<model_name>`` is downloaded into it first.

    :param model_name: model identifier; used as the local directory name and as
        the suffix of the S3 folder key.
    :return: path of the first (sorted) ``*.ckpt*`` file, truncated right after
        ``.ckpt`` (the checkpoint prefix expected by TF savers).
    :raises AssertionError: if no checkpoint file is present after the optional download.
    """
    import shutil  # function-scope import keeps this block self-contained

    _logger = logging.getLogger(fullname(TFSlimModel._find_model_weights))
    framework_home = os.path.expanduser(os.getenv('CM_HOME', '~/.candidate_models'))
    weights_path = os.getenv('CM_TSLIM_WEIGHTS_DIR', os.path.join(framework_home, 'model-weights', 'slim'))
    model_path = os.path.join(weights_path, model_name)
    checkpoint_pattern = os.path.join(model_path, '*.ckpt*')

    # Decide whether to download based on the files present, not on whether the
    # directory exists. An interrupted earlier download can leave an empty
    # directory behind, and that must not block a retry forever.
    fnames = sorted(glob.glob(checkpoint_pattern))
    if not fnames:
        _logger.debug(f"Downloading weights for {model_name} to {model_path}")
        created_dir = not os.path.isdir(model_path)
        os.makedirs(model_path, exist_ok=True)
        try:
            s3.download_folder(f"slim/{model_name}", model_path)
        except Exception:
            # Only remove a directory this call created. A pre-existing (possibly
            # user-provided) directory is left untouched.
            if created_dir:
                shutil.rmtree(model_path, ignore_errors=True)
            raise
        fnames = sorted(glob.glob(checkpoint_pattern))
    assert len(fnames) > 0, f"no checkpoint found in {model_path}"
    # Split only the file name, so a '.ckpt' in a parent directory name cannot
    # truncate the returned path.
    checkpoint_prefix = os.path.basename(fnames[0]).split('.ckpt')[0] + '.ckpt'
    return os.path.join(model_path, checkpoint_prefix)
def _find_model_json(model_name):
    """Return the path of the model JSON for the tfutils model ``model_name``.

    The local directory is ``$CM_TFUTILS_JSON_DIR/<model_name>``. It defaults to
    ``$CM_HOME/model-jsons/tfutils/<model_name>``, and ``$CM_HOME`` defaults to
    ``~/.candidate_models``. If that directory holds no ``*.json*`` file, the S3
    folder ``model-jsons/<model_name>`` is downloaded into it first, from bucket
    ``brain-score-tfutils-models`` in region ``us-west-1``.

    :param model_name: model identifier; used as the local directory name and as
        the suffix of the S3 folder key.
    :return: path of the first (sorted) ``*.json*`` file, truncated right after ``.json``.
    :raises AssertionError: if no JSON file is present after the optional download.
    """
    import shutil  # function-scope import keeps this block self-contained

    _logger = logging.getLogger(fullname(TFUtilsModel._find_model_json))
    framework_home = os.path.expanduser(os.getenv('CM_HOME', '~/.candidate_models'))
    json_path = os.getenv('CM_TFUTILS_JSON_DIR', os.path.join(framework_home, 'model-jsons', 'tfutils'))
    model_path = os.path.join(json_path, model_name)
    json_pattern = os.path.join(model_path, '*.json*')

    # Decide whether to download based on the files present, not on whether the
    # directory exists. An interrupted earlier download can leave an empty
    # directory behind, and that must not block a retry forever.
    fnames = sorted(glob.glob(json_pattern))
    if not fnames:
        _logger.debug(f"Downloading json for {model_name} to {model_path}")
        created_dir = not os.path.isdir(model_path)
        os.makedirs(model_path, exist_ok=True)
        try:
            s3.download_folder(f"model-jsons/{model_name}", model_path,
                               bucket='brain-score-tfutils-models', region='us-west-1')
        except Exception:
            # Only remove a directory this call created. A pre-existing (possibly
            # user-provided) directory is left untouched.
            if created_dir:
                shutil.rmtree(model_path, ignore_errors=True)
            raise
        fnames = sorted(glob.glob(json_pattern))
    assert len(fnames) > 0, f"no json found in {model_path}"
    # Split only the file name, so a '.json' in a parent directory name cannot
    # truncate the returned path.
    tnn_json = os.path.join(model_path, os.path.basename(fnames[0]).split('.json')[0] + '.json')
    return tnn_json