Exemple #1
0
 def test_mobilenet_weights(self):
     """Check that downloading the slim mobilenet folder yields exactly the expected files.

     The ``slim/mobilenet_v1_0.25_128`` folder is fetched with
     ``s3.download_folder`` into a temporary directory, and the resulting
     directory listing is compared (as a set) against the model name
     combined with each expected file suffix.
     """
     model = 'mobilenet_v1_0.25_128'
     expected_files = {
         model + suffix
         for suffix in (
             '.ckpt.data-00000-of-00001', '.ckpt.index', '.ckpt.meta',
             '.tflite', '_eval.pbtxt', '_frozen.pb', '_info.txt',
         )
     }
     with tempfile.TemporaryDirectory() as target_path:
         s3.download_folder(f"slim/{model}", target_path)
         # Snapshot the listing while the temporary directory still exists.
         downloaded_files = set(os.listdir(target_path))
     assert downloaded_files == expected_files
 def _find_model_weights(model_name):
     """Return the checkpoint prefix for ``model_name``, downloading the weights if needed.

     The model directory is ``<weights_dir>/<model_name>``, where
     ``<weights_dir>`` is taken from the ``CM_TSLIM_WEIGHTS_DIR`` environment
     variable and otherwise defaults to ``<CM_HOME>/model-weights/slim``
     (``CM_HOME`` itself falls back to ``~/.candidate_models``). When no
     ``*.ckpt*`` file is present in that directory, the remote folder
     ``slim/<model_name>`` is fetched with ``s3.download_folder``.

     :param model_name: model identifier, used both as the local
         sub-directory name and as the remote folder name.
     :return: ``<model_dir>/<stem>.ckpt`` where ``<stem>`` is the part of the
         first (sorted) matching file name that precedes ``.ckpt``.
     :raises AssertionError: if no checkpoint file exists even after the
         download attempt.
     """
     _logger = logging.getLogger(fullname(TFSlimModel._find_model_weights))
     framework_home = os.path.expanduser(os.getenv('CM_HOME', '~/.candidate_models'))
     weights_path = os.getenv('CM_TSLIM_WEIGHTS_DIR', os.path.join(framework_home, 'model-weights', 'slim'))
     model_path = os.path.join(weights_path, model_name)
     pattern = os.path.join(model_path, '*.ckpt*')
     # glob yields matches in arbitrary order; sort so the chosen file is deterministic.
     fnames = sorted(glob.glob(pattern))
     if not fnames:
         # Keyed on missing checkpoint files rather than a missing directory,
         # so a directory left empty by an earlier failed download is retried.
         _logger.debug(f"Downloading weights for {model_name} to {model_path}")
         os.makedirs(model_path, exist_ok=True)
         s3.download_folder(f"slim/{model_name}", model_path)
         fnames = sorted(glob.glob(pattern))
     if not fnames:
         # Explicit raise (same type and message as the former assert) so the
         # check is not stripped when Python runs with -O.
         raise AssertionError(f"no checkpoint found in {model_path}")
     # Split only the file name: a '.ckpt' substring inside a directory
     # component must not truncate the returned path.
     directory, filename = os.path.split(fnames[0])
     restore_path = os.path.join(directory, filename.split('.ckpt')[0] + '.ckpt')
     return restore_path
 def _find_model_json(model_name):
     """Return the path of the json file for ``model_name``, downloading it if needed.

     The model directory is ``<json_dir>/<model_name>``, where ``<json_dir>``
     is taken from the ``CM_TFUTILS_JSON_DIR`` environment variable and
     otherwise defaults to ``<CM_HOME>/model-jsons/tfutils`` (``CM_HOME``
     itself falls back to ``~/.candidate_models``). When no ``*.json*`` file
     is present in that directory, the remote folder
     ``model-jsons/<model_name>`` is fetched with ``s3.download_folder`` from
     bucket ``brain-score-tfutils-models`` in region ``us-west-1``.

     :param model_name: model identifier, used both as the local
         sub-directory name and as the remote folder name.
     :return: ``<model_dir>/<stem>.json`` where ``<stem>`` is the part of the
         first (sorted) matching file name that precedes ``.json``.
     :raises AssertionError: if no json file exists even after the download
         attempt.
     """
     _logger = logging.getLogger(fullname(TFUtilsModel._find_model_json))
     framework_home = os.path.expanduser(os.getenv('CM_HOME', '~/.candidate_models'))
     json_path = os.getenv('CM_TFUTILS_JSON_DIR', os.path.join(framework_home, 'model-jsons', 'tfutils'))
     model_path = os.path.join(json_path, model_name)
     pattern = os.path.join(model_path, '*.json*')
     # glob yields matches in arbitrary order; sort so the chosen file is deterministic.
     fnames = sorted(glob.glob(pattern))
     if not fnames:
         # Keyed on missing json files rather than a missing directory, so a
         # directory left empty by an earlier failed download is retried.
         _logger.debug(f"Downloading json for {model_name} to {model_path}")
         os.makedirs(model_path, exist_ok=True)
         s3.download_folder(f"model-jsons/{model_name}", model_path,
                            bucket='brain-score-tfutils-models', region='us-west-1')
         fnames = sorted(glob.glob(pattern))
     if not fnames:
         # Explicit raise (same type and message as the former assert) so the
         # check is not stripped when Python runs with -O.
         raise AssertionError(f"no json found in {model_path}")
     # Split only the file name: a '.json' substring inside a directory
     # component must not truncate the returned path.
     directory, filename = os.path.split(fnames[0])
     tnn_json = os.path.join(directory, filename.split('.json')[0] + '.json')
     return tnn_json