Example 1
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# Project-local helpers assumed to be importable from this repository:
# EmbeddingNet, ClassificationNet, Engine, SQLDataLoader, load_checkpoint,
# OnlineTripletLoss, SemihardNegativeTripletSelector, Detection, Oracle,
# DetectionKind


def completeLoop(X, y, base_ind):
    embedding_net = EmbeddingNet('densenet161', 256, True)
    center_loss = None
    model = torch.nn.DataParallel(embedding_net).cuda()
    criterion = OnlineTripletLoss(1.0, SemihardNegativeTripletSelector(1.0))
    params = model.parameters()

    # define loss function (criterion) and optimizer
    optimizer = torch.optim.Adam(params, lr=0.0001)
    start_epoch = 0
    checkpoint = load_checkpoint('triplet_model_0086.tar')
    if checkpoint:
        model.load_state_dict(checkpoint['state_dict'])

    e = Engine(model, criterion, optimizer, verbose=True, print_freq=10)
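    # Training pool: user-labeled detections joined with their Oracle labels
    # (Oracle.status == 0 presumably marks confirmed labels).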
    trainset_query = Detection.select(
        Detection.id, Oracle.label, Detection.embedding).join(Oracle).where(
            Oracle.status == 0,
            Detection.kind == DetectionKind.UserDetection.value)
    embedding_dataset = SQLDataLoader(trainset_query,
                                      "all_crops/SS_full_crops",
                                      True,
                                      num_workers=4,
                                      batch_size=64)
    print(len(embedding_dataset))
    num_classes = 48  #len(run_dataset.getClassesInfo()[0])
    print("Num Classes= " + str(num_classes))
    embedding_loader = embedding_dataset.getBalancedLoader(16, 4)
    for i in range(200):
        e.train_one_epoch(embedding_loader, i, False)
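    # Re-embed the same crops with the finetuned model; the False flag
    # presumably selects eval-time (non-augmenting) transforms.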
    embedding_dataset2 = SQLDataLoader(trainset_query,
                                       "all_crops/SS_full_crops",
                                       False,
                                       num_workers=4,
                                       batch_size=512)
    em = e.embedding(embedding_dataset2.getSingleLoader())
    lb = np.asarray([x[1] for x in embedding_dataset2.samples])
    pt = [x[0] for x in embedding_dataset2.samples]  # e.predict(run_loader, load_info=True)

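    # Write the refreshed embeddings back to their Detection rows.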
    co = 0
    for r, s, t in zip(em, lb, pt):
        co += 1
        smp = Detection.get(id=t)
        smp.embedding = r
        smp.save()
        if co % 100 == 0:
            print(co)
    print("Loop Started")
    train_dataset = SQLDataLoader(trainset_query,
                                  "all_crops/SS_full_crops",
                                  False,
                                  datatype='embedding',
                                  num_workers=4,
                                  batch_size=64)
    print(len(train_dataset))
    num_classes = 48  #len(run_dataset.getClassesInfo()[0])
    print("Num Classes= " + str(num_classes))
    run_loader = train_dataset.getSingleLoader()
    clf_model = ClassificationNet(256, 48).cuda()
    clf_criterion = nn.CrossEntropyLoss()
    clf_optimizer = torch.optim.Adam(clf_model.parameters(),
                                     lr=0.001,
                                     weight_decay=0.0005)
    clf_e = Engine(clf_model,
                   clf_criterion,
                   clf_optimizer,
                   verbose=True,
                   print_freq=10)
    for i in range(250):
        clf_e.train_one_epoch(run_loader, i, True)
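    # Held-out evaluation: embed the model-labeled detections with the
    # triplet network, then validate the classifier on those embeddings.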
    testset_query = Detection.select(
        Detection.id, Oracle.label, Detection.embedding).join(Oracle).where(
            Oracle.status == 0,
            Detection.kind == DetectionKind.ModelDetection.value)
    test_dataset = SQLDataLoader(testset_query,
                                 "all_crops/SS_full_crops",
                                 False,
                                 datatype='image',
                                 num_workers=4,
                                 batch_size=512)
    print(len(test_dataset))
    num_classes = 48  #len(run_dataset.getClassesInfo()[0])
    print("Num Classes= " + str(num_classes))
    test_loader = test_dataset.getSingleLoader()
    test_em = e.embedding(test_loader)
    test_lb = np.asarray([x[1] for x in test_dataset.samples])
    print(test_lb.shape, test_em.shape)
    testset = TensorDataset(torch.from_numpy(test_em),
                            torch.from_numpy(test_lb))
    clf_e.validate(DataLoader(testset, batch_size=512, shuffle=False), True)
    """X_train= X[list(base_ind)]
Example 2
import argparse
import os

import numpy as np
import torch
import torch.nn as nn
from peewee import SqliteDatabase, fn

# Project-local helpers assumed to be importable from this repository:
# EmbeddingNet, ClassificationNet, Engine, SQLDataLoader, load_checkpoint,
# finetune_embedding, plot_embedding, moveRecords, activeLearning,
# noveltySamples, Detection, Oracle, DetectionKind, proxy

# Minimal parser covering only the flags used below (help text assumed):
parser = argparse.ArgumentParser()
parser.add_argument('--run_data',
                    help='run directory holding the .db file and crops')
parser.add_argument('--base_model',
                    help='checkpoint of the pretrained embedding model')


def main():
    args = parser.parse_args()
    print("DB Connect")
    db_path = os.path.join(args.run_data,
                           os.path.basename(args.run_data)) + ".db"
    print(db_path)
    db = SqliteDatabase(db_path)
    proxy.initialize(db)
    db.connect()
    print("connected")
    print("CompleteLoop")

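    # Restore the pretrained embedding network using the architecture and
    # feature dimension recorded in the checkpoint.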
    checkpoint = load_checkpoint(args.base_model)
    embedding_net = EmbeddingNet(checkpoint['arch'], checkpoint['feat_dim'],
                                 False)
    #embedding_net = EmbeddingNet('resnet50', 256, True)
    model = torch.nn.DataParallel(embedding_net).cuda()
    model.load_state_dict(checkpoint['state_dict'])
    #unlabeledset_query= Detection.select(Detection.id,Oracle.label).join(Oracle).where(Detection.kind==DetectionKind.ModelDetection.value).order_by(fn.random()).limit(150000)
    #unlabeled_dataset = SQLDataLoader(unlabeledset_query, os.path.join(args.run_data, "crops"), is_training= False, num_workers= 8)
    unlabeled_dataset = SQLDataLoader(os.path.join(args.run_data, "crops"),
                                      is_training=False,
                                      kind=DetectionKind.ModelDetection.value,
                                      num_workers=8)
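    # Embed every unlabeled crop with the pretrained model.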
    unlabeled_dataset.updateEmbedding(model)
    #print('Embedding Done')
    #sys.stdout.flush()
    plot_embedding(unlabeled_dataset.em,
                   np.asarray(unlabeled_dataset.getlabels()),
                   unlabeled_dataset.getpaths(), {})
    # Random examples to start
    random_ids = np.random.choice(unlabeled_dataset.getIDs(),
                                  5000,
                                  replace=False).tolist()
    #random_ids = noveltySamples(unlabeled_dataset.em, unlabeled_dataset.getIDs(), 1000)
    #print(random_ids)
    # Move Records
    moveRecords(DetectionKind.ModelDetection, DetectionKind.UserDetection,
                random_ids)

    # Finetune the embedding model
    print(len(unlabeled_dataset))
    unlabeled_dataset.setKind(DetectionKind.UserDetection.value)
    unlabeled_dataset.train()
    print(len(unlabeled_dataset))
    #train_dataset = SQLDataLoader(trainset_query, os.path.join(args.run_data, 'crops'), is_training= True)
    finetune_embedding(model, unlabeled_dataset, 32, 4, 100)
    # The labeled subset now doubles as the training set (assumption: the
    # otherwise undefined `train_dataset` below was meant to be this loader).
    train_dataset = unlabeled_dataset
    train_dataset.updateEmbedding(model)
    plot_embedding(train_dataset.em, np.asarray(train_dataset.getlabels()),
                   train_dataset.getpaths(), {})
    #plot_embedding( unlabeled_dataset.em, np.asarray(unlabeled_dataset.getlabels()) , unlabeled_dataset.getIDs(), {})
    train_dataset.embedding_mode()
    train_dataset.train()
    clf_model = ClassificationNet(256, 48).cuda()
    # train_eval_classifier()  # not defined in this snippet; equivalent steps follow inline
    #clf_model = ClassificationNet(checkpoint['feat_dim'], 48).cuda()
    clf_criterion = nn.CrossEntropyLoss()
    clf_optimizer = torch.optim.Adam(clf_model.parameters(),
                                     lr=0.001,
                                     weight_decay=0.0005)
    clf_e = Engine(clf_model,
                   clf_criterion,
                   clf_optimizer,
                   verbose=True,
                   print_freq=1)

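    # Fit the classifier on embeddings of the newly labeled set.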
    clf_model.train()
    clf_train_loader = train_dataset.getSingleLoader(batch_size=64)
    for i in range(15):
        clf_e.train_one_epoch(clf_train_loader, i, True)
    clf_model.eval()
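    # Score a fresh random pool of unlabeled model detections.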
    unlabeledset_query = Detection.select(
        Detection.id, Oracle.label).join(Oracle).where(
            Detection.kind == DetectionKind.ModelDetection.value).order_by(
                fn.random()).limit(20000)
    unlabeled_dataset = SQLDataLoader(unlabeledset_query,
                                      os.path.join(args.run_data, 'crops'),
                                      is_training=False,
                                      num_workers=4)
    unlabeled_dataset.refresh(unlabeledset_query)
    unlabeled_dataset.updateEmbedding(model)
    unlabeled_dataset.embedding_mode()
    unlabeled_dataset.eval()
    eval_loader = unlabeled_dataset.getSingleLoader(batch_size=1024)
    clf_e.validate(eval_loader, True)
    clf_output = clf_e.embedding(eval_loader, dim=48)
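    # Pick the next detections to label via the active-learning criterion and
    # promote them from model detections to user detections.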
    indices = activeLearning(clf_output, unlabeled_dataset.em)
    moveRecords(DetectionKind.ModelDetection, DetectionKind.UserDetection,
                [unlabeled_dataset.getIDs()[i] for i in indices])