def main(architecture: str, tag: str):
    """Fine-tune a triplet embedding model and save its weights.

    Args:
        architecture: Name of an ``Architecture`` member; it is upper-cased
            before the lookup.
        tag: Tag string passed to ``Tag``. If the resulting tag carries no
            version, the version following the latest saved one is used,
            or 1 when ``get_latest_version`` raises ``ValueError``.
    """
    arch = Architecture[architecture.upper()]
    model = arch.get_triplet_embedding_model()
    weights_tag = Tag(tag)

    # Pick the version to save under when the caller did not give one.
    if not weights_tag.version:
        try:
            latest = arch.get_latest_version(weights_tag)
            weights_tag.version = latest + 1
        except ValueError:
            weights_tag.version = 1

    try:
        model.train(
            DATASET.triplets,
            BATCH_SIZE,
            NUM_EPOCHS,
            OPTIMIZER,
            LOSS,
            # Only the anchor images are augmented.
            augmenter=(augmenter, None, None),
        )
    except KeyboardInterrupt:
        # A manual interrupt ends training early; weights are still saved.
        pass

    model.save_weights(weights_tag)
def init_scorer(architecture: Architecture,
                tag: Optional[Union[str, Tag]]) -> ScorerModel:
    """Return the scorer model of ``architecture`` for ``tag``.

    Args:
        architecture: Provides ``get_latest_version`` and
            ``get_scorer_model``.
        tag: A ``Tag``, a tag string (converted via ``Tag``), or a falsy
            value, which is forwarded to ``get_scorer_model`` unchanged.

    Returns:
        Whatever ``architecture.get_scorer_model`` returns for the tag.
    """
    resolved = Tag(tag) if isinstance(tag, str) else tag

    # Nothing to resolve: either no tag at all or the version is already set.
    if not resolved or resolved.version:
        return architecture.get_scorer_model(resolved)

    # A tag without a version falls back to the latest saved version.
    resolved.version = architecture.get_latest_version(resolved)
    return architecture.get_scorer_model(resolved)
def get_embedding_model(self,
                        tag: Optional[Union[str, Tag]] = None,
                        use_triplets: bool = False) -> EmbeddingModel:
    """Wrap this architecture's base model in an embedding model.

    Args:
        tag: Optional ``Tag`` or tag string (converted via ``Tag``) that is
            handed to the embedding model constructor.
        use_triplets: When true, a ``TripletEmbeddingModel`` is built
            instead of a plain ``EmbeddingModel``.

    Returns:
        The constructed embedding model.
    """
    resolved_tag = Tag(tag) if isinstance(tag, str) else tag
    backbone = self.get_model()

    if use_triplets:
        model_cls = TripletEmbeddingModel
    else:
        model_cls = EmbeddingModel

    return model_cls(
        backbone,
        resolved_tag,
        self.resolution,
        self.model_dir,
        name=self.value,
    )
def get_latest_version(self, tag: Union[str, Tag]) -> int:
    """Return the highest version saved in ``self.model_dir`` for ``tag``.

    A file counts as belonging to the tag when its name ends in
    ``<tag name>-<digits>.<extension>`` and the tag name is either at the
    start of the filename or preceded by ``-``.

    Args:
        tag: A ``Tag`` or a tag string (converted via ``Tag``). Only its
            ``name`` is used for matching.

    Returns:
        The largest value of ``Tag.get_version_from_filename`` over the
        matching files.

    Raises:
        ValueError: If the model directory does not exist or contains no
            matching file.
    """
    if isinstance(tag, str):
        tag = Tag(tag)

    # Escape the name so regex metacharacters in it are matched literally,
    # and require a boundary before it so that e.g. 'tag' does not pick up
    # files saved for 'mytag'.
    pattern = re.compile(rf'(?:^|-){re.escape(tag.name)}-\d+\.\w+$')

    # Only the directory listing can raise FileNotFoundError; a missing
    # directory is treated the same as an empty one.
    try:
        filenames = os.listdir(self.model_dir)
    except FileNotFoundError:
        filenames = []

    model_files = [name for name in filenames if pattern.search(name)]
    if not model_files:
        raise ValueError(f'No {self.value} weights have been saved yet')

    return max(map(Tag.get_version_from_filename, model_files))
def finetune_and_embed(triplet_embedding_model: TripletEmbeddingModel,
                       triplets: List[FaceTriplet]):
    """Embed one image before training, after training and after reload.

    The model is trained on ``triplets`` with batch size 1 for 1 epoch, its
    weights are saved under ``Tag('tag:1')`` and then loaded again.

    Args:
        triplet_embedding_model: Model to train; it is modified by training
            and by reloading the saved weights.
        triplets: Training triplets; the anchor of the first one is the
            image that gets embedded.

    Returns:
        Tuple ``(y_original, y_trained, y_restored)`` with the embeddings
        taken before training, after training and after reloading.
    """
    probe_image = triplets[0].anchor
    embedding_before = triplet_embedding_model.embed(probe_image)

    adam = Adam(learning_rate=3e-4)
    triplet_loss = TripletLoss(alpha=0.5, force_normalization=True)
    # Positional 1, 1: batch size and number of epochs.
    triplet_embedding_model.train(triplets, 1, 1, adam, triplet_loss)

    saved_tag = Tag('tag:1')
    triplet_embedding_model.save_weights(saved_tag)
    embedding_trained = triplet_embedding_model.embed(probe_image)

    triplet_embedding_model.load_weights(saved_tag)
    embedding_restored = triplet_embedding_model.embed(probe_image)

    return embedding_before, embedding_trained, embedding_restored
def get_weights_path(self, tag: Tag):
    """Return the path of the weights file for ``tag``.

    The filename is ``'weights.h5'`` as rewritten by
    ``tag.append_to_filename`` and is joined onto ``self.model_dir``.
    """
    weights_name = tag.append_to_filename('weights.h5')
    weights_path = os.path.join(self.model_dir, weights_name)
    return weights_path
def test_create_tag_from_string():
    """A 'name:version' string yields both the name and the version."""
    parsed = Tag('tag:1')
    assert (parsed.name, parsed.version) == ('tag', 1)
def test_append_to_filename():
    """Name and version are inserted before the file extension."""
    result = Tag('mytag_simple', 2).append_to_filename('myfile.txt')
    assert result == 'myfile-mytag_simple-2.txt'
def test_get_tag_from_filename_without_version():
    """A filename without a version suffix gives a tag with no version."""
    parsed = Tag.from_filename('myfile-mytag_simple.txt')
    assert parsed.name == 'mytag_simple'
    assert parsed.version is None
def test_get_tag_from_filename():
    """Both name and version are recovered from a tagged filename."""
    parsed = Tag.from_filename('myfile-mytag_simple-2.txt')
    assert (parsed.name, parsed.version) == ('mytag_simple', 2)
def test_get_version_from_filename():
    """The numeric version is extracted from a tagged filename."""
    version = Tag.get_version_from_filename('myfile-mytag-2.txt')
    assert version == 2
def test_create_tag_from_string_double_colon():
    """A tag string containing two colons is rejected with ValueError."""
    malformed = 'tag:2:1'
    with pytest.raises(ValueError):
        Tag(malformed)
def test_create_tag_from_string_without_version():
    """A bare name yields a tag whose version is None."""
    parsed = Tag('tag')
    assert parsed.name == 'tag'
    assert parsed.version is None