Example #1
def finetune_embedding(model, loss_type, train_dataset, P, K, epochs):
    """
    Fine tune the embedding model.

    Arguments:
        model: Model to fine tune.
        loss_type: The loss function to minimize while fine tuning.
        train_dataset: Dataset object to use to train the embedding.
        P: Number of classes to sample from the dataset if using a balanced loader.
        K: Number of samples from each class to sample from the dataset if using a balanced loader.
        epochs: Number of epochs to train the embedding for.
    """
    train_dataset.image_mode()

    if loss_type.lower() == 'softmax':
        criterion = nn.CrossEntropyLoss().cuda()
        train_loader = train_dataset.getSingleLoader()
    elif loss_type.lower() == 'siamese':
        criterion = OnlineContrastiveLoss(1, HardNegativePairSelector())
        train_loader = train_dataset.getBalancedLoader(P=P, K=K)
    else:
        criterion = OnlineTripletLoss(1, RandomNegativeTripletSelector(1))
        train_loader = train_dataset.getBalancedLoader(P=P, K=K)

    params = model.parameters()
    optimizer = torch.optim.Adam(params, lr=0.0001)  #, weight_decay = 0.0005)
    e = Engine(model, criterion, optimizer, verbose=True, print_freq=10)

    for epoch in range(epochs):
        e.train_one_epoch(train_loader, epoch, False)
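
Note: the balanced loader requested with P and K above yields batches containing P classes with K samples each, which is the layout the online pair/triplet losses mine from. A minimal sketch of such a sampler (not the repository's getBalancedLoader implementation; `labels` is assumed to be a list of integer class labels aligned with the dataset):

import random
from collections import defaultdict

class PKBatchSampler:
    """Yields index batches with K samples from each of P randomly chosen classes."""
    def __init__(self, labels, P, K):
        self.P, self.K = P, K
        self.by_class = defaultdict(list)
        for idx, label in enumerate(labels):
            self.by_class[label].append(idx)
        self.classes = list(self.by_class)  # requires P <= number of classes
        self.num_batches = max(1, len(labels) // (P * K))

    def __iter__(self):
        for _ in range(self.num_batches):
            batch = []
            for c in random.sample(self.classes, self.P):
                batch.extend(random.choices(self.by_class[c], k=self.K))  # with replacement
            yield batch

    def __len__(self):
        return self.num_batches

# usage: torch.utils.data.DataLoader(image_dataset, batch_sampler=PKBatchSampler(labels, P=32, K=4))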
Example #2
def active(self, event):
    # Load the saved embedding model, extract embeddings for the model-detected
    # crops, then select samples and refresh the labeling tab of the GUI.
    self.parentWidget().statusBar().showMessage("Start Learning")
    checkpoint = load_checkpoint('../merge/triplet_model_0054.tar')
    run_dataset = SQLDataLoader(DetectionKind.ModelDetection,
                                "/home/pangolin/all_crops/SS_full_crops",
                                False,
                                num_workers=8,
                                batch_size=2048)
    #run_dataset.setup(Detection.select(Detection.id,Category.id).join(Category).where(Detection.kind==DetectionKind.ModelDetection.value).limit(250000))
    num_classes = len(run_dataset.getClassesInfo())
    print("Num Classes= " + str(num_classes))
    run_loader = run_dataset.getSingleLoader()
    embedding_net = EmbeddingNet(checkpoint['arch'],
                                 checkpoint['feat_dim'])
    if checkpoint['loss_type'].lower() == 'center':
        model = torch.nn.DataParallel(
            ClassificationNet(embedding_net, n_classes=14)).cuda()
    else:
        model = torch.nn.DataParallel(embedding_net).cuda()
    model.load_state_dict(checkpoint['state_dict'])
    self.parentWidget().progressBar.setMaximum(len(run_dataset) // 2048)
    e = Engine(model,
               None,
               None,
               verbose=True,
               progressBar=self.parentWidget().progressBar)
    self.parentWidget().statusBar().showMessage("Extract Embeddings")
    embd, label, paths = e.predict(run_loader, load_info=True)
    self.parentWidget().statusBar().showMessage("Clustering Images")
    self.parentWidget().progressBar.setMaximum(0)
    new_selected = self.selectSamples(embd, paths, 300)
    self.tab1.update()
    self.tab1.showCurrentPage(force=True)
    self.parentWidget().statusBar().showMessage("Clustering Finished")
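
Note: e.predict above runs the embedding network over the whole loader and returns embeddings, labels, and (with load_info=True) file paths. A minimal stand-in for that step in plain PyTorch, assuming the loader yields (image, label) batches:

import torch

def extract_embeddings(model, loader, device='cuda'):
    """Collect embeddings and labels for every batch in the loader."""
    model.eval()
    feats, labels = [], []
    with torch.no_grad():
        for images, targets in loader:
            feats.append(model(images.to(device)).cpu())
            labels.append(targets)
    return torch.cat(feats), torch.cat(labels)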
Example #3
def finetune_embedding(model, train_dataset, P, K, epochs):
    train_dataset.image_mode()
    train_loader = train_dataset.getBalancedLoader(P=P, K=K)
    criterion = OnlineTripletLoss(1, RandomNegativeTripletSelector(1))
    params = model.parameters()
    optimizer = torch.optim.Adam(params, lr=0.0001)  #, weight_decay = 0.0005)
    e = Engine(model, criterion, optimizer, verbose=True, print_freq=10)
    for epoch in range(epochs):
        e.train_one_epoch(train_loader, epoch, False)
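
Note: OnlineTripletLoss with a triplet selector mines anchor/positive/negative triplets inside each P×K batch and penalizes triplets that violate the margin. A simplified batch-hard variant (a related but not identical mining strategy to the RandomNegativeTripletSelector used above):

import torch
import torch.nn.functional as F

def batch_hard_triplet_loss(embeddings, labels, margin=1.0):
    """Batch-hard triplet loss: for each anchor use its hardest positive and
    hardest negative within the batch, then apply the margin hinge."""
    labels = labels.to(embeddings.device)
    dist = torch.cdist(embeddings, embeddings)                # pairwise L2 distances
    same = labels.unsqueeze(0) == labels.unsqueeze(1)         # same-class mask
    eye = torch.eye(len(labels), dtype=torch.bool, device=embeddings.device)
    hardest_pos = (dist * (same & ~eye)).max(dim=1).values    # farthest same-class sample
    hardest_neg = dist.masked_fill(same, float('inf')).min(dim=1).values  # closest other-class sample
    return F.relu(hardest_pos - hardest_neg + margin).mean()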
Example #4
def train_eval_classifier(clf_model, unlabeled_dataset, model, pth, epochs=15):
    trainset_query = Detection.select(
        Detection.id, Oracle.label).join(Oracle).where(
            Detection.kind == DetectionKind.UserDetection.value)
    train_dataset = SQLDataLoader(trainset_query,
                                  os.path.join(args.run_data, 'crops'),
                                  is_training=True)
    train_dataset.updateEmbedding(model)
    train_dataset.embedding_mode()
    train_dataset.train()
    clf_criterion = nn.CrossEntropyLoss()
    clf_optimizer = torch.optim.Adam(clf_model.parameters(),
                                     lr=0.001,
                                     weight_decay=0.0005)
    clf_e = Engine(clf_model,
                   clf_criterion,
                   clf_optimizer,
                   verbose=True,
                   print_freq=1)

    clf_model.train()
    clf_train_loader = train_dataset.getSingleLoader(batch_size=64)
    for i in range(epochs):
        clf_e.train_one_epoch(clf_train_loader, i, True)
    clf_model.eval()
    unlabeled_dataset.embedding_mode()
    unlabeled_dataset.eval()
    eval_loader = unlabeled_dataset.getSingleLoader(batch_size=1024)
    clf_e.validate(eval_loader, True)
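
Note: after updateEmbedding and embedding_mode, the classifier above is trained on cached embedding vectors rather than images. A minimal, hypothetical stand-in that fits a small head directly on precomputed embeddings (`embeddings` as a float tensor, `labels` as a LongTensor):

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

def train_linear_head(embeddings, labels, num_classes, epochs=15, device='cuda'):
    """Fit a small classification head on precomputed embedding vectors."""
    head = nn.Linear(embeddings.shape[1], num_classes).to(device)
    optimizer = torch.optim.Adam(head.parameters(), lr=0.001, weight_decay=0.0005)
    loss_fn = nn.CrossEntropyLoss()
    loader = DataLoader(TensorDataset(embeddings, labels), batch_size=64, shuffle=True)
    head.train()
    for _ in range(epochs):
        for x, y in loader:
            optimizer.zero_grad()
            loss_fn(head(x.to(device)), y.to(device)).backward()
            optimizer.step()
    return head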
Example #5
def main():
    global args, best_acc1
    args = parser.parse_args()

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')
    checkpoint = {}
    if args.resume != '':
        checkpoint = load_checkpoint(args.resume)
        args.loss_type = checkpoint['loss_type']
        args.feat_dim = checkpoint['feat_dim']
        best_acc1 = checkpoint['best_acc1']

    train_dataset = BaseDataLoader(args.train_data,
                                   True,
                                   batch_size=args.batch_size,
                                   num_workers=args.workers,
                                   raw_size=args.raw_size,
                                   processed_size=args.processed_size)
    if args.val_data is not None:
        val_dataset = BaseDataLoader(args.val_data,
                                     False,
                                     batch_size=args.batch_size,
                                     num_workers=args.workers)
    num_classes = len(train_dataset.getClassesInfo()[0])

    if args.balanced_P == -1:
        args.balanced_P = num_classes
    print("Num Classes= " + str(num_classes))

    if args.loss_type.lower() == 'center':
        train_loader = train_dataset.getSingleLoader()
        train_embd_loader = train_loader
        if args.val_data is not None:
            val_loader = val_dataset.getSingleLoader()
            val_embd_loader = val_loader
    else:
        train_loader = train_dataset.getBalancedLoader(P=args.balanced_P,
                                                       K=args.balanced_K)
        train_embd_loader = train_dataset.getSingleLoader()
        if args.val_data is not None:
            val_loader = val_dataset.getBalancedLoader(P=args.balanced_P,
                                                       K=args.balanced_K)
            val_embd_loader = val_dataset.getSingleLoader()

    embedding_net = EmbeddingNet(args.arch, args.feat_dim, args.pretrained)
    center_loss = None
    if args.loss_type.lower() == 'center':
        model = torch.nn.DataParallel(
            ClassificationNet(embedding_net, n_classes=num_classes)).cuda()
        criterion = CenterLoss(num_classes=num_classes, feat_dim=args.feat_dim)
        params = list(model.parameters()) + list(criterion.parameters())
    else:
        model = torch.nn.DataParallel(embedding_net).cuda()
        if args.loss_type.lower() == 'siamese':
            criterion = OnlineContrastiveLoss(args.margin,
                                              RandomNegativePairSelector())
        else:
            criterion = OnlineTripletLoss(
                args.margin, SemihardNegativeTripletSelector(args.margin))
        params = model.parameters()

    # define loss function (criterion) and optimizer
    optimizer = torch.optim.Adam(params,
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)
    start_epoch = 0

    if checkpoint:
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        #optimizer.load_state_dict(checkpoint['optimizer'])
        if args.loss_type.lower() == 'center':
            criterion.load_state_dict(checkpoint['centers'])

    e = Engine(model,
               criterion,
               optimizer,
               verbose=True,
               print_freq=args.print_freq)
    for epoch in range(start_epoch, args.epochs):
        # train for one epoch

        e.train_one_epoch(train_loader, epoch,
                          args.loss_type.lower() == 'center')
        # evaluate on validation set
        if args.val_data is not None:
            e.validate(val_loader, args.loss_type.lower() == 'center')
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_acc1': best_acc1,
                'optimizer': optimizer.state_dict(),
                'loss_type': args.loss_type,
                'feat_dim': args.feat_dim,
                'centers': criterion.state_dict()
                if args.loss_type.lower() == 'center' else None
            }, False, "%s%s_model_%04d.tar" %
            (args.checkpoint_prefix, args.loss_type, epoch))
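
Note: with the 'center' loss the criterion itself carries learnable class centers, which is why `params` above concatenates the model and criterion parameters before building the optimizer. A minimal sketch of that idea (not the repository's CenterLoss):

import torch
import torch.nn as nn

class CenterLossSketch(nn.Module):
    """Each class owns a learnable center; embeddings are pulled toward the
    center of their class. Because the centers are nn.Parameters, the
    criterion must be optimized together with the model."""
    def __init__(self, num_classes, feat_dim):
        super().__init__()
        self.centers = nn.Parameter(torch.randn(num_classes, feat_dim))

    def forward(self, features, labels):
        return ((features - self.centers[labels]) ** 2).sum(dim=1).mean()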
Example #6
def completeLoop(X, y, base_ind):
    embedding_net = EmbeddingNet('densenet161', 256, True)
    center_loss = None
    model = torch.nn.DataParallel(embedding_net).cuda()
    criterion = OnlineTripletLoss(1.0, SemihardNegativeTripletSelector(1.0))
    params = model.parameters()

    # define loss function (criterion) and optimizer
    optimizer = torch.optim.Adam(params, lr=0.0001)
    start_epoch = 0
    checkpoint = load_checkpoint('triplet_model_0086.tar')
    if checkpoint:
        model.load_state_dict(checkpoint['state_dict'])

    e = Engine(model, criterion, optimizer, verbose=True, print_freq=10)
    trainset_query = Detection.select(
        Detection.id, Oracle.label, Detection.embedding).join(Oracle).where(
            Oracle.status == 0,
            Detection.kind == DetectionKind.UserDetection.value)
    embedding_dataset = SQLDataLoader(trainset_query,
                                      "all_crops/SS_full_crops",
                                      True,
                                      num_workers=4,
                                      batch_size=64)
    print(len(embedding_dataset))
    num_classes = 48  #len(run_dataset.getClassesInfo()[0])
    print("Num Classes= " + str(num_classes))
    embedding_loader = embedding_dataset.getBalancedLoader(16, 4)
    for i in range(200):
        e.train_one_epoch(embedding_loader, i, False)
    embedding_dataset2 = SQLDataLoader(trainset_query,
                                       "all_crops/SS_full_crops",
                                       False,
                                       num_workers=4,
                                       batch_size=512)
    em = e.embedding(embedding_dataset2.getSingleLoader())
    lb = np.asarray([x[1] for x in embedding_dataset2.samples])
    pt = [x[0] for x in embedding_dataset2.samples]  #e.predict(run_loader, load_info=True)

    co = 0
    for r, s, t in zip(em, lb, pt):
        co += 1
        smp = Detection.get(id=t)
        smp.embedding = r
        smp.save()
        if co % 100 == 0:
            print(co)
    print("Loop Started")
    train_dataset = SQLDataLoader(trainset_query,
                                  "all_crops/SS_full_crops",
                                  False,
                                  datatype='embedding',
                                  num_workers=4,
                                  batch_size=64)
    print(len(train_dataset))
    num_classes = 48  #len(run_dataset.getClassesInfo()[0])
    print("Num Classes= " + str(num_classes))
    run_loader = train_dataset.getSingleLoader()
    clf_model = ClassificationNet(256, 48).cuda()
    clf_criterion = nn.CrossEntropyLoss()
    clf_optimizer = torch.optim.Adam(clf_model.parameters(),
                                     lr=0.001,
                                     weight_decay=0.0005)
    clf_e = Engine(clf_model,
                   clf_criterion,
                   clf_optimizer,
                   verbose=True,
                   print_freq=10)
    for i in range(250):
        clf_e.train_one_epoch(run_loader, i, True)
    testset_query = Detection.select(
        Detection.id, Oracle.label, Detection.embedding).join(Oracle).where(
            Oracle.status == 0,
            Detection.kind == DetectionKind.ModelDetection.value)
    test_dataset = SQLDataLoader(testset_query,
                                 "all_crops/SS_full_crops",
                                 False,
                                 datatype='image',
                                 num_workers=4,
                                 batch_size=512)
    print(len(test_dataset))
    num_classes = 48  #len(run_dataset.getClassesInfo()[0])
    print("Num Classes= " + str(num_classes))
    test_loader = test_dataset.getSingleLoader()
    test_em = e.embedding(test_loader)
    test_lb = np.asarray([x[1] for x in test_dataset.samples])
    print(test_lb.shape, test_em.shape)
    testset = TensorDataset(torch.from_numpy(test_em),
                            torch.from_numpy(test_lb))
    clf_e.validate(DataLoader(testset, batch_size=512, shuffle=False), True)
    """X_train= X[list(base_ind)]
Example #7
def main():
    args = parser.parse_args()
    print("DB Connect")
    db_path = os.path.join(args.run_data, os.path.basename(
        args.run_data)) + ".db"
    print(db_path)
    db = SqliteDatabase(db_path)
    proxy.initialize(db)
    db.connect()
    print("connected")
    print("CompleteLoop")

    checkpoint = load_checkpoint(args.base_model)
    embedding_net = EmbeddingNet(checkpoint['arch'], checkpoint['feat_dim'],
                                 False)
    #embedding_net = EmbeddingNet('resnet50', 256, True)
    model = torch.nn.DataParallel(embedding_net).cuda()
    model.load_state_dict(checkpoint['state_dict'])
    #unlabeledset_query= Detection.select(Detection.id,Oracle.label).join(Oracle).where(Detection.kind==DetectionKind.ModelDetection.value).order_by(fn.random()).limit(150000)
    #unlabeled_dataset = SQLDataLoader(unlabeledset_query, os.path.join(args.run_data, "crops"), is_training= False, num_workers= 8)
    dataset = SQLDataLoader(os.path.join(args.run_data, "crops"),
                            is_training=False,
                            kind=DetectionKind.ModelDetection.value,
                            num_workers=8)
    dataset.updateEmbedding(model)
    #print('Embedding Done')
    #sys.stdout.flush()
    #plot_embedding(dataset.em[dataset.current_set], np.asarray(dataset.getlabels()) , dataset.getpaths(), {})
    # Random examples to start
    random_ids = np.random.choice(dataset.current_set, 5000,
                                  replace=False).tolist()
    #random_ids = selectSamples(dataset.em[dataset.current_set], dataset.current_set, 2000)
    #print(random_ids)
    # Move Records
    moveRecords(dataset, DetectionKind.ModelDetection.value,
                DetectionKind.UserDetection.value, random_ids)

    print([len(x) for x in dataset.set_indices])
    # Finetune the embedding model
    dataset.setKind(DetectionKind.UserDetection.value)
    dataset.train()
    #train_dataset = SQLDataLoader(trainset_query, os.path.join(args.run_data, 'crops'), is_training= True)
    finetune_embedding(model, dataset, 32, 4, 0)
    #unlabeled_dataset.updateEmbedding(model)
    dataset.updateEmbedding(model)
    dataset.setKind(DetectionKind.UserDetection.value)
    #print(dataset.em[dataset.current_set].shape, np.asarray(dataset.getlabels()).shape, len(dataset.getpaths()))
    #plot_embedding( dataset.em[dataset.current_set], np.asarray(dataset.getlabels()) , dataset.getpaths(), {})
    #plot_embedding( unlabeled_dataset.em, np.asarray(unlabeled_dataset.getlabels()) , unlabeled_dataset.getIDs(), {})
    dataset.embedding_mode()
    dataset.train()
    clf_model = ClassificationNet(256, 48).cuda()
    #train_eval_classifier()
    #clf_model = ClassificationNet(checkpoint['feat_dim'], 48).cuda()
    clf_criterion = FocalLoss(gamma=2)  #nn.CrossEntropyLoss()
    clf_optimizer = torch.optim.Adam(clf_model.parameters(),
                                     lr=0.001,
                                     weight_decay=0.0005)
    clf_e = Engine(clf_model,
                   clf_criterion,
                   clf_optimizer,
                   verbose=True,
                   print_freq=10)
    #names = ["Linear SVM", "RBF SVM", "Random Forest", "Neural Net", "Naive Bayes"]
    #classifiers = [SVC(kernel="linear", C=0.025, probability= True, class_weight='balanced'),
    #    SVC(gamma=2, C=1, probability= True, class_weight='balanced'),
    #    RandomForestClassifier(max_depth=None, n_estimators=100, class_weight='balanced'),
    #    MLPClassifier(alpha=1),
    #    GaussianNB()]
    #estimators= []
    #for name, clf in zip(names, classifiers):
    #    estimators.append((name, clf))
    #eclf1 = VotingClassifier(estimators= estimators, voting='hard')
    #eclf2 = VotingClassifier(estimators= estimators, voting='soft')
    #names.append("ensemble hard")
    #classifiers.append(eclf1)
    #names.append("ensemble soft")
    #classifiers.append(eclf2)
    names = ["Neural Net"]
    classifiers = [MLPClassifier(alpha=1)]
    """dataset.setKind(DetectionKind.UserDetection.value)

    learner = ActiveLearner(
            estimator=MLPClassifier(),
            query_strategy=uncertainty_sampling,
            X_training = dataset.em[dataset.current_set], y_training = np.asarray(dataset.getlabels()))

    for step in range(91):
        dataset.setKind(DetectionKind.ModelDetection.value)
        query_idx, query_inst = learner.query(dataset.em[dataset.current_set], n_instances=100)
        moveRecords(dataset, DetectionKind.ModelDetection.value, DetectionKind.UserDetection.value, [dataset.current_set[i] for i in query_idx])
        dataset.setKind(DetectionKind.UserDetection.value)
        learner.teach(dataset.em[dataset.current_set], np.asarray(dataset.getlabels()))
        if step in [11, 31, 51, 71, 91, 101]:
            finetune_embedding(model, dataset, 32, 4, 100)
            dataset.updateEmbedding(model)
            dataset.embedding_mode()
        dataset.setKind(DetectionKind.ModelDetection.value)
        print(learner.score(dataset.em[dataset.current_set], np.asarray(dataset.getlabels())))
        print([len(x) for x in dataset.set_indices])
        sys.stdout.flush()"""
    sampler = get_AL_sampler('uniform')(dataset.em[dataset.current_set],
                                        dataset.getlabels(), 12)
    print(sampler, type(sampler), dir(sampler))
    kwargs = {}
    kwargs["N"] = 100
    kwargs["already_selected"] = []
    kwargs["model"] = SVC(kernel="linear",
                          C=0.025,
                          probability=True,
                          class_weight='balanced')
    kwargs["model"].fit(dataset.em[dataset.current_set], dataset.getlabels())
    batch_AL = sampler.select_batch(**kwargs)
    print(batch_AL)
    for step in range(101):
        dataset.setKind(DetectionKind.UserDetection.value)
        clf_model.train()
        clf_train_loader = dataset.getSingleLoader(batch_size=64)
        for i in range(15):
            clf_e.train_one_epoch(clf_train_loader, i, True)
        clf_model.eval()
        X_train = dataset.em[dataset.current_set]
        y_train = np.asarray(dataset.getlabels())
        for name, clf in zip(names, classifiers):
            clf.fit(X_train, y_train)
            print(name)

        dataset.setKind(DetectionKind.ModelDetection.value)
        #dataset.image_mode()
        #dataset.updateEmbedding(model)
        dataset.embedding_mode()
        dataset.eval()
        eval_loader = dataset.getSingleLoader(batch_size=1024)
        clf_e.validate(eval_loader, True)
        X_test = dataset.em[dataset.current_set]
        y_test = np.asarray(dataset.getlabels())
        prob_list = []
        for name, clf in zip(names, classifiers):
            #y_pred= clf.predict(X_test)
            #print(confusion_matrix(y_test, y_pred))
            #paths= dataset.getpaths()
            #for i, (yp, yt) in enumerate(zip(y_pred, y_test)):
            #    if yp != yt:
            #copy(paths[i],"mistakes")
            #print(yt, yp, paths[i],i)
            if not name.startswith("ensemble"):
                prob_list.append(clf.predict_proba(X_test))
            score = clf.score(X_test, y_test)
            print(name, score)
        #clf_output= clf_e.embedding(eval_loader, dim=48)
        if step % 10 == 1 and step > 10:
            dataset.setKind(DetectionKind.UserDetection.value)
            finetune_embedding(model, dataset, 32, 4, 50)
            dataset.setKind(DetectionKind.ModelDetection.value)
            dataset.updateEmbedding(model)
            dataset.embedding_mode()
        indices = activeLearning(prob_list, dataset)
        moveRecords(dataset, DetectionKind.ModelDetection.value,
                    DetectionKind.UserDetection.value,
                    [dataset.current_set[i] for i in indices])
        print([len(x) for x in dataset.set_indices])
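
Note: `activeLearning(prob_list, dataset)` above chooses which model-detected samples to hand to the user next. A minimal, hypothetical least-confidence selector over the collected probability matrices illustrates the idea:

import numpy as np

def least_confidence_selection(prob_list, n=100):
    """Average the classifiers' probability matrices and return the indices of
    the n samples whose top-class probability is lowest."""
    probs = np.mean(prob_list, axis=0)     # (num_samples, num_classes)
    confidence = probs.max(axis=1)         # top-class probability per sample
    return np.argsort(confidence)[:n].tolist()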
Example #8
def main():
    args = parser.parse_args()
    print("DB Connect")
    db_path = os.path.join(args.run_data, os.path.basename(
        args.run_data)) + ".db"
    print(db_path)
    db = SqliteDatabase(db_path)
    proxy.initialize(db)
    db.connect()
    print("connected")
    print("CompleteLoop")

    checkpoint = load_checkpoint(args.base_model)
    embedding_net = EmbeddingNet(checkpoint['arch'], checkpoint['feat_dim'],
                                 False)
    #embedding_net = EmbeddingNet('resnet50', 256, True)
    model = torch.nn.DataParallel(embedding_net).cuda()
    model.load_state_dict(checkpoint['state_dict'])
    #unlabeledset_query= Detection.select(Detection.id,Oracle.label).join(Oracle).where(Detection.kind==DetectionKind.ModelDetection.value).order_by(fn.random()).limit(150000)
    #unlabeled_dataset = SQLDataLoader(unlabeledset_query, os.path.join(args.run_data, "crops"), is_training= False, num_workers= 8)
    unlabeled_dataset = SQLDataLoader(os.path.join(args.run_data, "crops"),
                                      is_training=False,
                                      kind=DetectionKind.ModelDetection.value,
                                      num_workers=8)
    unlabeled_dataset.updateEmbedding(model)
    #print('Embedding Done')
    #sys.stdout.flush()
    plot_embedding(unlabeled_dataset.em,
                   np.asarray(unlabeled_dataset.getlabels()),
                   unlabeled_dataset.getpaths(), {})
    # Random examples to start
    random_ids = np.random.choice(unlabeled_dataset.getIDs(),
                                  5000,
                                  replace=False).tolist()
    #random_ids = noveltySamples(unlabeled_dataset.em, unlabeled_dataset.getIDs(), 1000)
    #print(random_ids)
    # Move Records
    moveRecords(DetectionKind.ModelDetection, DetectionKind.UserDetection,
                random_ids)

    # Finetune the embedding model
    print(len(unlabeled_dataset))
    unlabeled_dataset.setKind(DetectionKind.UserDetection.value)
    unlabeled_dataset.train()
    print(len(unlabeled_dataset))
    #train_dataset = SQLDataLoader(trainset_query, os.path.join(args.run_data, 'crops'), is_training= True)
    finetune_embedding(model, unlabeled_dataset, 32, 4, 100)
    #unlabeled_dataset.updateEmbedding(model)
    train_dataset.updateEmbedding(model)
    plot_embedding(train_dataset.em, np.asarray(train_dataset.getlabels()),
                   train_dataset.getpaths(), {})
    #plot_embedding( unlabeled_dataset.em, np.asarray(unlabeled_dataset.getlabels()) , unlabeled_dataset.getIDs(), {})
    train_dataset.embedding_mode()
    train_dataset.train()
    clf_model = ClassificationNet(256, 48).cuda()
    train_eval_classifier()
    #clf_model = ClassificationNet(checkpoint['feat_dim'], 48).cuda()
    clf_criterion = nn.CrossEntropyLoss()
    clf_optimizer = torch.optim.Adam(clf_model.parameters(),
                                     lr=0.001,
                                     weight_decay=0.0005)
    clf_e = Engine(clf_model,
                   clf_criterion,
                   clf_optimizer,
                   verbose=True,
                   print_freq=1)

    clf_model.train()
    clf_train_loader = train_dataset.getSingleLoader(batch_size=64)
    for i in range(15):
        clf_e.train_one_epoch(clf_train_loader, i, True)
    clf_model.eval()
    unlabeledset_query = Detection.select(
        Detection.id, Oracle.label).join(Oracle).where(
            Detection.kind == DetectionKind.ModelDetection.value).order_by(
                fn.random()).limit(20000)
    unlabeled_dataset = SQLDataLoader(unlabeledset_query,
                                      os.path.join(args.run_data, 'crops'),
                                      is_training=False,
                                      num_workers=4)
    unlabeled_dataset.refresh(unlabeledset_query)
    unlabeled_dataset.updateEmbedding(model)
    unlabeled_dataset.embedding_mode()
    unlabeled_dataset.eval()
    eval_loader = unlabeled_dataset.getSingleLoader(batch_size=1024)
    clf_e.validate(eval_loader, True)
    clf_output = clf_e.embedding(eval_loader, dim=48)
    indices = activeLearning(clf_output, unlabeled_dataset.em)
    moveRecords(DetectionKind.ModelDetection, DetectionKind.UserDetection,
                [unlabeled_dataset.getIDs()[i] for i in indices])
Example #9
def main():
    global args, best_acc1
    args = parser.parse_args()
    print(args)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')
    checkpoint = {}
    if args.resume != '':
        checkpoint = load_checkpoint(args.resume)
        args.loss_type = checkpoint['loss_type']
        args.feat_dim = checkpoint['feat_dim']
        best_acc1 = checkpoint['best_acc1']

    db_path = os.path.join(args.train_data, os.path.basename(
        args.train_data)) + ".db"
    print(db_path)
    db = SqliteDatabase(db_path)
    proxy.initialize(db)
    db.connect()
    """
    to use full images
    train_query =  Detection.select(Detection.image_id,Oracle.label,Detection.kind).join(Oracle).order_by(fn.random()).limit(limit)
    
    train_dataset = SQLDataLoader('/lscratch/datasets/serengeti', is_training= True, num_workers= args.workers, 
            raw_size= args.raw_size, processed_size= args.processed_size)
    """
    train_dataset = SQLDataLoader(os.path.join(args.train_data, 'crops'),
                                  is_training=True,
                                  num_workers=args.workers,
                                  raw_size=args.raw_size,
                                  processed_size=args.processed_size)
    train_dataset.setKind(DetectionKind.UserDetection.value)
    if args.val_data is not None:
        val_dataset = SQLDataLoader(os.path.join(args.val_data, 'crops'),
                                    is_training=False,
                                    num_workers=args.workers)
    #num_classes= len(train_dataset.getClassesInfo()[0])
    num_classes = args.num_classes
    if args.balanced_P == -1:
        args.balanced_P = num_classes
    #print("Num Classes= "+str(num_classes))
    if args.loss_type.lower() in ('center', 'softmax'):
        train_loader = train_dataset.getSingleLoader(
            batch_size=args.batch_size)
        train_embd_loader = train_loader
        if args.val_data is not None:
            val_loader = val_dataset.getSingleLoader(
                batch_size=args.batch_size)
            val_embd_loader = val_loader
    else:
        train_loader = train_dataset.getBalancedLoader(P=args.balanced_P,
                                                       K=args.balanced_K)
        train_embd_loader = train_dataset.getSingleLoader(
            batch_size=args.batch_size)
        if args.val_data is not None:
            val_loader = val_dataset.getBalancedLoader(P=args.balanced_P,
                                                       K=args.balanced_K)
            val_embd_loader = val_dataset.getSingleLoader(
                batch_size=args.batch_size)

    center_loss = None
    if args.loss_type.lower() in ('center', 'softmax'):
        model = torch.nn.DataParallel(
            SoftmaxNet(args.arch,
                       args.feat_dim,
                       num_classes,
                       use_pretrained=args.pretrained)).cuda()
        if args.loss_type.lower() == 'center':
            criterion = CenterLoss(num_classes=num_classes,
                                   feat_dim=args.feat_dim)
            params = list(model.parameters()) + list(criterion.parameters())
        else:
            criterion = nn.CrossEntropyLoss().cuda()
            params = model.parameters()
    else:
        model = torch.nn.DataParallel(
            NormalizedEmbeddingNet(args.arch,
                                   args.feat_dim,
                                   use_pretrained=args.pretrained)).cuda()
        if args.loss_type.lower() == 'siamese':
            criterion = OnlineContrastiveLoss(args.margin,
                                              HardNegativePairSelector())
        else:
            criterion = OnlineTripletLoss(
                args.margin, RandomNegativeTripletSelector(args.margin))
        params = model.parameters()

    # define loss function (criterion) and optimizer
    optimizer = torch.optim.Adam(params,
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)
    #optimizer = torch.optim.SGD(params, momentum = 0.9, lr = args.lr, weight_decay = args.weight_decay)
    start_epoch = 0

    if checkpoint:
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        #optimizer.load_state_dict(checkpoint['optimizer'])
        if args.loss_type.lower() == 'center':
            criterion.load_state_dict(checkpoint['centers'])

    e = Engine(model,
               criterion,
               optimizer,
               verbose=True,
               print_freq=args.print_freq)
    for epoch in range(start_epoch, args.epochs):
        # train for one epoch
        #adjust_lr(optimizer,epoch)
        e.train_one_epoch(
            train_loader, epoch,
            args.loss_type.lower() in ('center', 'softmax'))
        #if epoch % 1 == 0 and epoch > 0:
        #    a, b, c = e.predict(train_embd_loader, load_info = True, dim = args.feat_dim)
        #    plot_embedding(reduce_dimensionality(a), b, c, {})
        # evaluate on validation set
        if args.val_data is not None:
            e.validate(val_loader, args.loss_type.lower() == 'center')
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_acc1': best_acc1,
                'optimizer': optimizer.state_dict(),
                'loss_type': args.loss_type,
                'num_classes': args.num_classes,
                'feat_dim': args.feat_dim,
                'centers': criterion.state_dict()
                if args.loss_type.lower() == 'center' else None
            }, False, "%s%s_%s_%04d.tar" %
            (args.checkpoint_prefix, args.loss_type, args.arch, epoch))
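
Note: the 'siamese' branch above pairs OnlineContrastiveLoss with a hard-negative pair selector. A simplified contrastive loss over all pairs in a P×K batch (without hard mining) shows the underlying objective:

import torch
import torch.nn.functional as F

def contrastive_pair_loss(embeddings, labels, margin=1.0):
    """Contrastive loss over all pairs in the batch: positive pairs are pulled
    together, negative pairs are pushed beyond the margin."""
    labels = labels.to(embeddings.device)
    dist = torch.cdist(embeddings, embeddings)
    same = labels.unsqueeze(0) == labels.unsqueeze(1)
    eye = torch.eye(len(labels), dtype=torch.bool, device=embeddings.device)
    pos = dist[same & ~eye]                # distances of same-class pairs
    neg = F.relu(margin - dist[~same])     # hinge on different-class pairs
    return torch.cat([pos.pow(2), neg.pow(2)]).mean()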