def completeLoop(X, y, base_ind):
    """Run one full active-learning loop on the hard-coded SS crops dataset.

    Stages (all side-effecting; nothing is returned):
      1. Fine-tune a DenseNet-161 triplet-embedding model on user-labeled crops.
      2. Recompute embeddings for those crops and persist them into the
         ``Detection`` table.
      3. Train a small classifier on the stored embeddings and validate it on
         model-detected crops.

    Args:
        X, y, base_ind: currently unused — kept for interface compatibility
            with callers (the original body only used them in commented-out
            code). TODO confirm whether they can be dropped.
    """
    embedding_net = EmbeddingNet('densenet161', 256, True)
    model = torch.nn.DataParallel(embedding_net).cuda()
    criterion = OnlineTripletLoss(1.0, SemihardNegativeTripletSelector(1.0))
    # define loss function (criterion) and optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

    # Warm-start the embedding model from a previous triplet checkpoint, if present.
    checkpoint = load_checkpoint('triplet_model_0086.tar')
    if checkpoint:
        model.load_state_dict(checkpoint['state_dict'])
    e = Engine(model, criterion, optimizer, verbose=True, print_freq=10)

    # User-labeled (status == 0, UserDetection) crops form the training pool.
    trainset_query = Detection.select(
        Detection.id, Oracle.label, Detection.embedding).join(Oracle).where(
        Oracle.status == 0, Detection.kind == DetectionKind.UserDetection.value)

    # --- Stage 1: fine-tune the embedding with balanced triplet batches ---
    embedding_dataset = SQLDataLoader(trainset_query, "all_crops/SS_full_crops",
                                      True, num_workers=4, batch_size=64)
    print(len(embedding_dataset))
    num_classes = 48  # len(run_dataset.getClassesInfo()[0])
    print("Num Classes= " + str(num_classes))
    embedding_loader = embedding_dataset.getBalancedLoader(16, 4)
    for i in range(200):
        e.train_one_epoch(embedding_loader, i, False)

    # --- Stage 2: recompute embeddings and write them back to the DB ---
    embedding_dataset2 = SQLDataLoader(trainset_query, "all_crops/SS_full_crops",
                                       False, num_workers=4, batch_size=512)
    em = e.embedding(embedding_dataset2.getSingleLoader())
    lb = np.asarray([x[1] for x in embedding_dataset2.samples])
    pt = [x[0] for x in embedding_dataset2.samples]  # e.predict(run_loader, load_info=True)
    co = 0
    for r, s, t in zip(em, lb, pt):
        co += 1
        smp = Detection.get(id=t)
        smp.embedding = r
        smp.save()
        if co % 100 == 0:  # progress indicator only
            print(co)
    print("Loop Started")

    # --- Stage 3: train a classifier on the stored embeddings ---
    train_dataset = SQLDataLoader(trainset_query, "all_crops/SS_full_crops",
                                  False, datatype='embedding',
                                  num_workers=4, batch_size=64)
    print(len(train_dataset))
    num_classes = 48  # len(run_dataset.getClassesInfo()[0])
    print("Num Classes= " + str(num_classes))
    run_loader = train_dataset.getSingleLoader()
    clf_model = ClassificationNet(256, 48).cuda()
    clf_criterion = nn.CrossEntropyLoss()
    clf_optimizer = torch.optim.Adam(clf_model.parameters(), lr=0.001,
                                     weight_decay=0.0005)
    clf_e = Engine(clf_model, clf_criterion, clf_optimizer,
                   verbose=True, print_freq=10)
    for i in range(250):
        clf_e.train_one_epoch(run_loader, i, True)

    # Validate on model-detected (not yet user-verified) crops.
    testset_query = Detection.select(
        Detection.id, Oracle.label, Detection.embedding).join(Oracle).where(
        Oracle.status == 0, Detection.kind == DetectionKind.ModelDetection.value)
    test_dataset = SQLDataLoader(testset_query, "all_crops/SS_full_crops",
                                 False, datatype='image',
                                 num_workers=4, batch_size=512)
    print(len(test_dataset))
    num_classes = 48  # len(run_dataset.getClassesInfo()[0])
    print("Num Classes= " + str(num_classes))
    test_loader = test_dataset.getSingleLoader()
    test_em = e.embedding(test_loader)
    test_lb = np.asarray([x[1] for x in test_dataset.samples])
    print(test_lb.shape, test_em.shape)
    testset = TensorDataset(torch.from_numpy(test_em), torch.from_numpy(test_lb))
    clf_e.validate(DataLoader(testset, batch_size=512, shuffle=False), True)
    # NOTE(review): the original ended with an UNTERMINATED triple-quoted
    # string ('"""X_train= X[list(base_ind)]'), apparently the start of
    # commented-out legacy code; preserved here as a comment so the file parses:
    # X_train = X[list(base_ind)]
def main():
    """Entry point: one active-learning iteration over a run's crop database.

    Steps: connect to the per-run SQLite DB, load a pretrained embedding model,
    embed all model-detected crops, move a random sample to user-labeled status,
    fine-tune the embedding on those labels, train a classifier on the
    embeddings, then pick the next batch to label via active learning.
    """
    args = parser.parse_args()
    print("DB Connect")
    db_path = os.path.join(args.run_data,
                           os.path.basename(args.run_data)) + ".db"
    print(db_path)
    db = SqliteDatabase(db_path)
    proxy.initialize(db)
    db.connect()
    print("connected")
    print("CompleteLoop")

    # Rebuild the embedding net with the architecture/dim stored in the checkpoint.
    checkpoint = load_checkpoint(args.base_model)
    embedding_net = EmbeddingNet(checkpoint['arch'], checkpoint['feat_dim'], False)
    model = torch.nn.DataParallel(embedding_net).cuda()
    model.load_state_dict(checkpoint['state_dict'])

    # All model-detected crops, with embeddings computed by the base model.
    unlabeled_dataset = SQLDataLoader(os.path.join(args.run_data, "crops"),
                                      is_training=False,
                                      kind=DetectionKind.ModelDetection.value,
                                      num_workers=8)
    unlabeled_dataset.updateEmbedding(model)
    plot_embedding(unlabeled_dataset.em,
                   np.asarray(unlabeled_dataset.getlabels()),
                   unlabeled_dataset.getpaths(), {})

    # Seed the labeled pool with random examples.
    # NOTE: an alternative is novelty sampling, e.g.
    # noveltySamples(unlabeled_dataset.em, unlabeled_dataset.getIDs(), 1000).
    random_ids = np.random.choice(unlabeled_dataset.getIDs(), 5000,
                                  replace=False).tolist()
    moveRecords(DetectionKind.ModelDetection, DetectionKind.UserDetection,
                random_ids)

    # Fine-tune the embedding model on the newly user-labeled records.
    print(len(unlabeled_dataset))
    unlabeled_dataset.setKind(DetectionKind.UserDetection.value)
    unlabeled_dataset.train()
    print(len(unlabeled_dataset))
    finetune_embedding(model, unlabeled_dataset, 32, 4, 100)
    # FIX: the original called train_dataset.updateEmbedding(model) etc., but
    # train_dataset was never defined (its definition was commented out) — a
    # guaranteed NameError. The fine-tuned dataset is unlabeled_dataset.
    unlabeled_dataset.updateEmbedding(model)
    plot_embedding(unlabeled_dataset.em,
                   np.asarray(unlabeled_dataset.getlabels()),
                   unlabeled_dataset.getpaths(), {})
    unlabeled_dataset.embedding_mode()
    unlabeled_dataset.train()

    # Train a classifier on top of the (fixed) embeddings.
    clf_model = ClassificationNet(256, 48).cuda()
    train_eval_classifier()  # NOTE(review): zero-arg call, defined elsewhere — verify it is still needed
    clf_criterion = nn.CrossEntropyLoss()
    clf_optimizer = torch.optim.Adam(clf_model.parameters(), lr=0.001,
                                     weight_decay=0.0005)
    clf_e = Engine(clf_model, clf_criterion, clf_optimizer,
                   verbose=True, print_freq=1)
    clf_model.train()
    clf_train_loader = unlabeled_dataset.getSingleLoader(batch_size=64)
    for i in range(15):
        clf_e.train_one_epoch(clf_train_loader, i, True)
    clf_model.eval()

    # Evaluate on a random sample of still-unlabeled detections and pick the
    # most informative ones to send to the user next.
    unlabeledset_query = Detection.select(
        Detection.id, Oracle.label).join(Oracle).where(
        Detection.kind == DetectionKind.ModelDetection.value).order_by(
        fn.random()).limit(20000)
    unlabeled_dataset = SQLDataLoader(unlabeledset_query,
                                      os.path.join(args.run_data, 'crops'),
                                      is_training=False, num_workers=4)
    unlabeled_dataset.refresh(unlabeledset_query)
    unlabeled_dataset.updateEmbedding(model)
    unlabeled_dataset.embedding_mode()
    unlabeled_dataset.eval()
    eval_loader = unlabeled_dataset.getSingleLoader(batch_size=1024)
    clf_e.validate(eval_loader, True)
    clf_output = clf_e.embedding(eval_loader, dim=48)
    indices = activeLearning(clf_output, unlabeled_dataset.em)
    moveRecords(DetectionKind.ModelDetection, DetectionKind.UserDetection,
                [unlabeled_dataset.getIDs()[i] for i in indices])