示例#1
0
def train_lopq(start, args):
    """Train an LOPQ approximator on a training set and register a retriever.

    Every index entry referenced by the selected training set is loaded.
    Entries whose ``count`` is falsy are skipped. The remaining vectors are
    stacked into one matrix, and an LOPQ model is trained on it. The result is
    persisted as a ``TrainedModel``, and an LOPQ ``Retriever`` is created for it.

    Args:
        start: Value stored on the trained model as ``m.event``.
        args: Training options. Keys read here: ``training_set_selector``,
            ``name``, ``components``, ``m``, ``v``, ``sub``,
            ``indexer_shasum`` and ``lopq_train_opts``.

    Raises:
        ValueError: If no index entry of the training set contributed vectors.
    """
    args = copy.deepcopy(args)
    dt = TrainingSet.objects.get(**args['training_set_selector'])
    m = TrainedModel()
    dirname = "{}/models/{}".format(settings.MEDIA_ROOT, m.uuid)
    m.create_directory()
    trainer = lopq_trainer.LOPQTrainer(name=args["name"],
                                       dirname=dirname,
                                       components=args['components'],
                                       m=args['m'],
                                       v=args['v'],
                                       sub=args['sub'],
                                       source_indexer_shasum=args['indexer_shasum'])
    index_list = []
    for f in dt.files:
        di = IndexEntries.objects.get(pk=f['pk'])
        vecs, _ = di.load_index()
        if di.count:
            index_list.append(np.atleast_2d(vecs))
            logging.info("loaded {}".format(index_list[-1].shape))
        else:
            logging.info("Ignoring {}".format(di.pk))
    if not index_list:
        # Without this check np.concatenate fails with an opaque message.
        raise ValueError("No vectors loaded for training set selected by {}".format(
            args['training_set_selector']))
    # Stack the vectors of *all* loaded entries. Previously only the last
    # entry's vectors (`vecs`) were used.
    data = np.concatenate(index_list).squeeze()
    logging.info("Final shape {}".format(data.shape))
    trainer.train(data, lopq_train_opts=args["lopq_train_opts"])
    j = trainer.save()
    # Copy the metadata describing the saved model onto the DB record.
    m.name = j["name"]
    m.algorithm = j["algorithm"]
    m.model_type = j["model_type"]
    m.arguments = j["arguments"]
    m.shasum = j["shasum"]
    m.files = j["files"]
    m.event = start
    m.training_set = dt
    m.save()
    m.upload()
    Retriever.objects.create(name="Retriever for approximator {}".format(m.pk),
                             source_filters={},
                             algorithm=Retriever.LOPQ,
                             approximator_shasum=m.shasum,
                             indexer_shasum=args['indexer_shasum'])
from django.conf import settings
from dvaapp.models import TrainedModel, Retriever
from dvalib.trainers import lopq_trainer
import numpy as np

if __name__ == '__main__':
    # Script entry point: train an LOPQ approximator on precomputed Facenet
    # vectors ('facenet.npy'), dump its metadata to JSON, register it as a
    # TrainedModel, copy its files into the media directory, and create an
    # LOPQ Retriever for it.
    import json
    import os
    import shutil

    trainer = lopq_trainer.LOPQTrainer(
        name="Facenet_LOPQ_on_LFW",
        # Resolve relative to this script's own directory. The previous code
        # passed the string literal '__file__', whose dirname is always '',
        # so the path silently depended on the current working directory.
        dirname=os.path.join(os.path.dirname(os.path.abspath(__file__)),
                             "../../shared/facenet_lopq/"),
        components=64,
        m=32,
        v=32,
        sub=256,
        source_indexer_shasum="9f99caccbc75dcee8cb0a55a0551d7c5cb8a6836")
    data = np.load('facenet.npy')
    trainer.train(data)
    j = trainer.save()
    with open("lopq_facenet_approximator.json", 'w') as out:
        json.dump(j, out)
    m = TrainedModel(**j)
    m.save()
    m.create_directory()
    for f in m.files:
        shutil.copy(
            f['url'], '{}/models/{}/{}'.format(settings.MEDIA_ROOT, m.pk,
                                               f['filename']))
    Retriever.objects.create(name="lopq retriever",
                             source_filters={},
                             algorithm=Retriever.LOPQ,
                             approximator_shasum=m.shasum)
示例#3
0
def train_faiss(start, args):
    """Train a FAISS index on a training set and register a retriever.

    Every index entry referenced by the selected training set is loaded.
    Entries whose ``count`` is falsy are skipped. The remaining vectors are
    stacked into one matrix, and ``faiss_trainer.train_index`` builds an index
    written to ``<MEDIA_ROOT>/models/<uuid>/faiss.index``. The result is
    persisted as a ``TrainedModel``, and a FAISS ``Retriever`` is created for it.

    Args:
        start: Value stored on the trained model as ``m.event``.
        args: Training options. Keys read here: ``training_set_selector``,
            ``index_factory``, ``name`` and ``indexer_shasum``.

    Raises:
        ValueError: If no index entry of the training set contributed vectors.
    """
    args = copy.deepcopy(args)
    dt = TrainingSet.objects.get(**args['training_set_selector'])
    m = TrainedModel()
    m.create_directory()
    index_list = []
    for f in dt.files:
        di = IndexEntries.objects.get(pk=f['pk'])
        vecs, _ = di.load_index()
        if di.count:
            index_list.append(np.atleast_2d(vecs))
            logging.info("loaded {}".format(index_list[-1].shape))
        else:
            logging.info("Ignoring {}".format(di.pk))
    if not index_list:
        # Without this check np.concatenate fails with an opaque message.
        raise ValueError("No vectors loaded for training set selected by {}".format(
            args['training_set_selector']))
    # Stack the vectors of *all* loaded entries. Previously only the last
    # entry's vectors (`vecs`) were used.
    data = np.concatenate(index_list).squeeze()
    logging.info("Final shape {}".format(data.shape))
    output_file = "{}/models/{}/faiss.index".format(settings.MEDIA_ROOT,
                                                    m.uuid)
    index_factory = args['index_factory']
    shasum = faiss_trainer.train_index(data, index_factory, output_file)
    m.name = args['name']
    m.algorithm = "FAISS"
    m.model_type = m.APPROXIMATOR
    m.arguments = {'index_factory': index_factory}
    m.shasum = shasum
    m.files = [{"filename": "faiss.index", "url": output_file}]
    m.event = start
    m.training_set = dt
    m.save()
    m.upload()
    Retriever.objects.create(name="Retriever for approximator {}".format(m.pk),
                             source_filters={},
                             algorithm=Retriever.FAISS,
                             approximator_shasum=m.shasum,
                             indexer_shasum=args['indexer_shasum'])