def main(architecture: str, tag: str):
    """Fine-tune the triplet embedding model of an architecture and save it.

    Args:
        architecture: Name of an ``Architecture`` member; it is upper-cased
            before lookup, so the match is case-insensitive.
        tag: Tag string to save the weights under, optionally versioned
            (e.g. ``"name:2"``). Without a version, one more than the latest
            saved version is used, or 1 when ``get_latest_version`` raises
            ``ValueError``.
    """
    arch: Architecture = Architecture[architecture.upper()]
    model = arch.get_triplet_embedding_model()
    target_tag = Tag(tag)

    # No explicit version given: derive one from what is already saved.
    if not target_tag.version:
        try:
            target_tag.version = arch.get_latest_version(target_tag) + 1
        except ValueError:
            # Nothing saved under this tag yet, so start numbering at 1.
            target_tag.version = 1

    try:
        model.train(
            DATASET.triplets,
            BATCH_SIZE,
            NUM_EPOCHS,
            OPTIMIZER,
            LOSS,
            # Apply the augmenter to anchors only; the other two slots of
            # the triplet get None.
            augmenter=(augmenter, None, None))
    except KeyboardInterrupt:
        # A manual interrupt (e.g. Ctrl+C) just ends training early; the
        # weights trained so far are still saved below.
        pass
    model.save_weights(target_tag)
 def init_scorer(architecture: Architecture,
                 tag: Optional[Union[str, Tag]]) -> ScorerModel:
     """Return the scorer model of *architecture* for *tag*.

     Args:
         architecture: Architecture whose scorer model is requested.
         tag: Tag, tag string (e.g. ``"name:2"``), or ``None``. When the tag
             has no version, the latest saved version is looked up via
             ``architecture.get_latest_version``. A ``Tag`` object passed in
             by the caller is never modified.

     Returns:
         Whatever ``architecture.get_scorer_model`` returns for the
         resolved tag.

     Raises:
         ValueError: Propagated from ``architecture.get_latest_version`` when
             no version is given and none can be resolved (presumably
             because no weights were saved under the tag yet).
     """
     import copy

     if isinstance(tag, str):
         tag = Tag(tag)
     # If no version is specified, use the latest version. Resolve it on a
     # shallow copy so the caller's Tag object is not mutated as a side
     # effect of this call.
     if tag and not tag.version:
         tag = copy.copy(tag)
         tag.version = architecture.get_latest_version(tag)
     return architecture.get_scorer_model(tag)
Exemplo n.º 3
0
 def get_embedding_model(self,
                         tag: Optional[Union[str, Tag]] = None,
                         use_triplets: bool = False) -> EmbeddingModel:
     """Wrap this architecture's base model in a new embedding model.

     Args:
         tag: Tag or tag string for the model; a string is parsed into a
             ``Tag`` and ``None`` is passed through unchanged.
         use_triplets: Build a ``TripletEmbeddingModel`` instead of a plain
             ``EmbeddingModel``.

     Returns:
         The freshly constructed model. It is built from ``self.get_model()``,
         the resolved tag, ``self.resolution`` and ``self.model_dir``, and
         named after ``self.value``.
     """
     resolved_tag = Tag(tag) if isinstance(tag, str) else tag
     model_cls = EmbeddingModel
     if use_triplets:
         model_cls = TripletEmbeddingModel
     return model_cls(self.get_model(),
                      resolved_tag,
                      self.resolution,
                      self.model_dir,
                      name=self.value)
Exemplo n.º 4
0
    def get_latest_version(self, tag: Union[str, Tag]) -> int:
        """Return the highest version saved in ``self.model_dir`` for *tag*.

        Only the tag's name matters; any version on *tag* is ignored. A file
        counts when its name ends in ``-<tag name>-<digits>.<extension>``,
        which is the shape produced by ``Tag.append_to_filename`` (e.g.
        ``weights-mytag-2.h5``).

        Args:
            tag: Tag, or tag string, whose saved versions are searched.

        Returns:
            The largest version number, as parsed by
            ``Tag.get_version_from_filename``, among the matching files.

        Raises:
            ValueError: If the model directory does not exist or holds no
                file for this tag.
        """
        if isinstance(tag, str):
            tag = Tag(tag)
        # Escape the name so regex metacharacters in it match literally, and
        # require the '-' separator in front of it so that tag 'tag' does not
        # also pick up files saved under e.g. 'mytag'.
        pattern = re.compile(rf'-{re.escape(tag.name)}-\d+\.\w+$')
        try:
            filenames = os.listdir(self.model_dir)
        except FileNotFoundError:
            # No model directory yet means nothing has been saved.
            filenames = []
        model_files = [name for name in filenames if pattern.search(name)]
        if not model_files:
            raise ValueError(f'No {self.value} weights have been saved yet')
        return max(map(Tag.get_version_from_filename, model_files))
def finetune_and_embed(triplet_embedding_model: TripletEmbeddingModel,
                       triplets: List[FaceTriplet]):
    """Embed one image before training, after training, and after reloading.

    The anchor of the first triplet is embedded three times:

    1. with the model's initial weights;
    2. after one epoch of training (batch size 1) on *triplets*, once the
       trained weights have been saved under ``Tag('tag:1')``;
    3. after those saved weights have been loaded back into the model.

    Returns:
        Tuple ``(y_original, y_trained, y_restored)`` of the three embeddings.
    """
    probe = triplets[0].anchor
    before = triplet_embedding_model.embed(probe)

    triplet_embedding_model.train(
        triplets,
        1,  # batch size
        1,  # number of epochs
        Adam(learning_rate=3e-4),
        TripletLoss(alpha=0.5, force_normalization=True))

    saved_tag = Tag('tag:1')
    triplet_embedding_model.save_weights(saved_tag)
    after_training = triplet_embedding_model.embed(probe)

    triplet_embedding_model.load_weights(saved_tag)
    after_restore = triplet_embedding_model.embed(probe)
    return before, after_training, after_restore
Exemplo n.º 6
0
 def get_weights_path(self, tag: Tag):
     """Return the path of the tagged ``weights.h5`` file in ``model_dir``.

     The file name is ``'weights.h5'`` with *tag* applied via
     ``Tag.append_to_filename``.
     """
     tagged_name = tag.append_to_filename('weights.h5')
     weights_path = os.path.join(self.model_dir, tagged_name)
     return weights_path
def test_create_tag_from_string():
    """A 'name:version' string is split into a name and an integer version."""
    parsed = Tag('tag:1')
    assert parsed.name == 'tag'
    assert parsed.version == 1
def test_append_to_filename():
    """The tag is inserted before the extension as '-<name>-<version>'."""
    tagged = Tag('mytag_simple', 2).append_to_filename('myfile.txt')
    assert tagged == 'myfile-mytag_simple-2.txt'
def test_get_tag_from_filename_without_version():
    """A filename carrying only a tag name yields a tag without a version."""
    parsed = Tag.from_filename('myfile-mytag_simple.txt')
    assert parsed.name == 'mytag_simple'
    assert parsed.version is None
def test_get_tag_from_filename():
    """Both name and version are recovered from a tagged filename."""
    parsed = Tag.from_filename('myfile-mytag_simple-2.txt')
    assert parsed.name == 'mytag_simple'
    assert parsed.version == 2
def test_get_version_from_filename():
    """The version is parsed as an int from a tagged filename."""
    version = Tag.get_version_from_filename('myfile-mytag-2.txt')
    assert version == 2
def test_create_tag_from_string_double_colon():
    """A tag string with more than one colon is rejected with ValueError."""
    malformed = 'tag:2:1'
    with pytest.raises(ValueError):
        Tag(malformed)
def test_create_tag_from_string_without_version():
    """A bare name with no ':version' part yields a tag without a version."""
    parsed = Tag('tag')
    assert parsed.name == 'tag'
    assert parsed.version is None