Example #1
    model = torch.nn.DataParallel(embedding_net).cuda()
    model.load_state_dict(checkpoint['state_dict'])

    # ---------------------------------------------------------------------- #
    # CREATE QUEUE OF IMAGES TO LABEL
    # ---------------------------------------------------------------------- #
    dataset_query = (Detection.select(
        Detection.id, Detection.category_id, Detection.kind,
        Detection.category_confidence, Detection.bbox_confidence,
        Image.file_name, Image.grayscale)
        .join(Image, on=(Image.id == Detection.image))
        .order_by(fn.Random())
        .limit(args.db_query_limit))
    dataset = SQLDataLoader(args.crop_dir,
                            query=dataset_query,
                            is_training=False,
                            kind=DetectionKind.ModelDetection.value,
                            num_workers=8)

    grayscale_values = [rec[6] for rec in dataset.samples]
    grayscale_indices = list(
        itertools.compress(range(len(grayscale_values)),
                           grayscale_values))  # records with grayscale images
    color_indices = list(
        set(range(len(dataset.samples))) -
        set(grayscale_indices))  # records with color images
    detection_conf_values = [rec[4] for rec in dataset.samples]  # bbox_confidence column
    dataset.updateEmbedding(model)
    dataset.embedding_mode()
    dataset.train()
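
Example #1 splits the sampled records into grayscale and color groups with itertools.compress over the boolean grayscale column. A tiny illustration of that idiom with made-up flag values (not part of the original example):

import itertools

flags = [True, False, True, False]                           # per-record grayscale flags
gray = list(itertools.compress(range(len(flags)), flags))    # indices where the flag is True -> [0, 2]
color = sorted(set(range(len(flags))) - set(gray))           # remaining indices -> [1, 3]
print(gray, color)
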
Example #2
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--db_name',
                        default='missouricameratraps',
                        type=str,
                        help='Name of the training (target) data Postgres DB.')
    parser.add_argument('--db_user',
                        default='user',
                        type=str,
                        help='Name of the user accessing the Postgres DB.')
    parser.add_argument('--db_password',
                        default='password',
                        type=str,
                        help='Password of the user accessing the Postgres DB.')
    parser.add_argument(
        '--num',
        default=1000,
        type=int,
        help='Number of samples to draw from dataset to get embedding features.'
    )
    parser.add_argument(
        '--crop_dir',
        type=str,
        help='Path to directory with cropped images to get embedding features for.')
    parser.add_argument('--base_model',
                        type=str,
                        help='Path to latest embedding model checkpoint.')
    parser.add_argument('--random_seed',
                        default=1234,
                        type=int,
                        help='Random seed to get same samples from database.')
    args = parser.parse_args()

    random.seed(args.random_seed)
    np.random.seed(args.random_seed)

    BASE_MODEL = args.base_model

    # Load the saved embedding model from the checkpoint
    checkpoint = load_checkpoint(BASE_MODEL)
    if checkpoint['loss_type'].lower() in ('center', 'softmax'):
        embedding_net = SoftmaxNet(checkpoint['arch'], checkpoint['feat_dim'],
                                   checkpoint['num_classes'], False)
    else:
        embedding_net = NormalizedEmbeddingNet(checkpoint['arch'],
                                               checkpoint['feat_dim'], False)
    model = torch.nn.DataParallel(embedding_net).cuda()
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()

    # Get a sample from the database, with eval transforms applied, etc.
    DB_NAME = args.db_name
    USER = args.db_user
    PASSWORD = args.db_password

    # Connect to database and sample a dataset
    target_db = PostgresqlDatabase(DB_NAME,
                                   user=USER,
                                   password=PASSWORD,
                                   host='localhost')
    target_db.connect(reuse_if_open=True)
    db_proxy.initialize(target_db)
    dataset_query = Detection.select(Detection.image_id, Oracle.label,
                                     Detection.kind).join(Oracle).limit(
                                         args.num)
    dataset = SQLDataLoader(args.crop_dir,
                            query=dataset_query,
                            is_training=False,
                            kind=DetectionKind.ModelDetection.value,
                            num_workers=8,
                            limit=args.num)
    imagepaths = dataset.getallpaths()

    sample_image_path = '0ca68a6f-6348-4456-8fb5-c067e2cbfe14'  #'0a170ee9-166d-45df-8f45-b14550fc124e'#'43e3d2a6-38ea-4d17-a712-1b0feab92d58'#'0ef0f79f-7b58-473d-abbf-75bba59e834d'
    dataset.image_mode()
    sample_image = dataset.loader(sample_image_path)
    sample_image.save('sample_image_for_activations.png')
    print(sample_image.size)
    sample_image = dataset.eval_transform(sample_image)
    print(sample_image.shape)

    # output = model.forward(sample_image.unsqueeze(0))
    # print(output)

    model_inner_resnet = list(model.children())[0].inner_model
    model_inner_resnet.eval()
    model_inner_resnet.layer1[0].conv2.register_forward_hook(hook)

    output = model.forward(sample_image.unsqueeze(0))
    intermediate_output = outputs[0].cpu().detach().numpy()
    print(intermediate_output.shape)

    for i in range(intermediate_output.shape[1]):
        plt.subplot(8, 8, i + 1)
        plt.imshow(intermediate_output[0, i, :, :], cmap='viridis')
        plt.axis('off')
    plt.suptitle('ResNet Layer1 Conv2 Activations')
    plt.savefig('temp.png')
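
The activation-visualization code above registers a forward hook named hook and then reads a global outputs list, neither of which is defined in the snippet. A minimal sketch of what those might look like, assuming the hook simply collects each layer output it sees (names and behavior are assumptions, not taken from the original source):

# Hypothetical globals assumed by the example above.
outputs = []

def hook(module, inputs, output):
    # Store the intermediate activation tensor for later inspection.
    outputs.append(output)
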
Example #3
def main():
    args = parser.parse_args()

    # Initialize Database
    ## database connection credentials
    DB_NAME = args.db_name
    USER = args.db_user
    PASSWORD = args.db_password
    print("DB Connect")
    ## try to connect as USER to database DB_NAME through peewee
    target_db = PostgresqlDatabase(DB_NAME,
                                   user=USER,
                                   password=PASSWORD,
                                   host='localhost')
    target_db.connect(reuse_if_open=True)
    db_proxy.initialize(target_db)
    print("connected")

    # Load the saved embedding model
    checkpoint = load_checkpoint(args.base_model)
    if args.experiment_name == '':
        args.experiment_name = "experiment_%s_%s" % (checkpoint['loss_type'],
                                                     args.strategy)
    if not os.path.exists(args.experiment_name):
        os.mkdir(args.experiment_name)

    if checkpoint['loss_type'].lower() in ('center', 'softmax'):
        embedding_net = SoftmaxNet(checkpoint['arch'], checkpoint['feat_dim'],
                                   checkpoint['num_classes'], False)
    else:
        embedding_net = NormalizedEmbeddingNet(checkpoint['arch'],
                                               checkpoint['feat_dim'], False)

    model = torch.nn.DataParallel(embedding_net).cuda()
    model.load_state_dict(checkpoint['state_dict'])

    # dataset_query = Detection.select().limit(5)
    dataset_query = Detection.select(
        Detection.image_id, Oracle.label,
        Detection.kind).join(Oracle).order_by(fn.random()).limit(
            args.db_query_limit
        )  ## TODO: should this really be order_by random?
    dataset = SQLDataLoader(args.crop_dir,
                            query=dataset_query,
                            is_training=False,
                            kind=DetectionKind.ModelDetection.value,
                            num_workers=8,
                            limit=args.db_query_limit)
    dataset.updateEmbedding(model)
    # plot_embedding_images(dataset.em[:], np.asarray(dataset.getlabels()) , dataset.getpaths(), {})
    # plot_embedding_images(dataset.em[:], np.asarray(dataset.getalllabels()) , dataset.getallpaths(), {})

    # Random examples to start
    #random_ids = np.random.choice(dataset.current_set, 1000, replace=False).tolist()
    #random_ids = selectSamples(dataset.em[dataset.current_set], dataset.current_set, 2000)
    #print(random_ids)
    # Move Records
    #moveRecords(dataset, DetectionKind.ModelDetection.value, DetectionKind.UserDetection.value, random_ids)

    # #print([len(x) for x in dataset.set_indices])
    # # Finetune the embedding model
    # #dataset.set_kind(DetectionKind.UserDetection.value)
    # #dataset.train()
    # #train_dataset = SQLDataLoader(trainset_query, os.path.join(args.run_data, 'crops'), is_training= True)
    # #finetune_embedding(model, checkpoint['loss_type'], dataset, 32, 4, 100)
    # #save_checkpoint({
    # #        'arch': model.arch,
    # #        'state_dict': model.state_dict(),
    # #        'optimizer' : optimizer.state_dict(),
    # #        'loss_type' : loss_type,
    # #        }, False, "%s%s_%s_%04d.tar"%('finetuned', loss_type, model.arch, len(dataset.set_indices[DetectionKind.UserDetection.value])))

    dataset.embedding_mode()
    dataset.train()
    sampler = get_AL_sampler(args.strategy)(dataset.em, dataset.getalllabels(),
                                            12)

    kwargs = {}
    kwargs["N"] = args.active_batch
    kwargs["already_selected"] = dataset.set_indices[
        DetectionKind.UserDetection.value]
    kwargs["model"] = MLPClassifier(alpha=0.0001)

    print("Start the active learning loop")
    sys.stdout.flush()
    numLabeled = len(dataset.set_indices[DetectionKind.UserDetection.value])
    while numLabeled <= args.active_budget:
        print([len(x) for x in dataset.set_indices])
        sys.stdout.flush()

        # Get indices of samples to get user to label
        if numLabeled == 0:
            indices = np.random.choice(dataset.current_set,
                                       kwargs["N"],
                                       replace=False).tolist()
        else:
            indices = sampler.select_batch(**kwargs)
        # numLabeled = len(dataset.set_indices[DetectionKind.UserDetection.value])
        #kwargs["already_selected"].extend(indices)
        moveRecords(dataset, DetectionKind.ModelDetection.value,
                    DetectionKind.UserDetection.value, indices)
        numLabeled = len(
            dataset.set_indices[DetectionKind.UserDetection.value])

        # Train on samples that have been labeled so far
        dataset.set_kind(DetectionKind.UserDetection.value)
        X_train = dataset.em[dataset.current_set]
        y_train = np.asarray(dataset.getlabels())

        kwargs["model"].fit(X_train, y_train)
        joblib.dump(
            kwargs["model"], "%s/%s_%04d.skmodel" %
            (args.experiment_name, 'classifier', numLabeled))

        # Test on the samples that have not been labeled
        dataset.set_kind(DetectionKind.ModelDetection.value)
        dataset.embedding_mode()
        X_test = dataset.em[dataset.current_set]
        y_test = np.asarray(dataset.getlabels())
        print("Accuracy", kwargs["model"].score(X_test, y_test))

        sys.stdout.flush()
        if numLabeled % 2000 == 1000:
            dataset.set_kind(DetectionKind.UserDetection.value)
            finetune_embedding(model, checkpoint['loss_type'], dataset, 10, 4,
                               100 if numLabeled == 1000 else 50)
            save_checkpoint(
                {
                    'arch': checkpoint['arch'],
                    'state_dict': model.state_dict(),
                    #'optimizer' : optimizer.state_dict(),
                    'loss_type': checkpoint['loss_type'],
                    'feat_dim': checkpoint['feat_dim'],
                    'num_classes': args.num_classes
                },
                False,
                "%s/%s%s_%s_%04d.tar" %
                (args.experiment_name, 'finetuned', checkpoint['loss_type'],
                 checkpoint['arch'], numLabeled))

            dataset.set_kind(DetectionKind.ModelDetection.value)
            dataset.updateEmbedding(model)
            dataset.embedding_mode()
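
In the loop above, sampler.select_batch picks the next unlabeled embeddings to hand to the user. As a rough illustration of the margin-based uncertainty criterion such samplers commonly use (a sketch assuming a scikit-learn style classifier with predict_proba, not the project's get_AL_sampler implementation):

import numpy as np

def select_most_uncertain(clf, X_unlabeled, n):
    # Margin sampling: the smaller the gap between the top two class
    # probabilities, the less certain the classifier is about that sample.
    probs = clf.predict_proba(X_unlabeled)
    sorted_probs = np.sort(probs, axis=1)
    margins = sorted_probs[:, -1] - sorted_probs[:, -2]
    return np.argsort(margins)[:n].tolist()
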
Example #4
def main():
    args = parser.parse_args()
    print("DB Connect")
    db_path = os.path.join(args.run_data, os.path.basename(
        args.run_data)) + ".db"
    print(db_path)
    db = SqliteDatabase(db_path)
    proxy.initialize(db)
    db.connect()
    print("connected")
    print("CompleteLoop")

    checkpoint = load_checkpoint(args.base_model)
    embedding_net = EmbeddingNet(checkpoint['arch'], checkpoint['feat_dim'],
                                 False)
    #embedding_net = EmbeddingNet('resnet50', 256, True)
    model = torch.nn.DataParallel(embedding_net).cuda()
    model.load_state_dict(checkpoint['state_dict'])
    #unlabeledset_query= Detection.select(Detection.id,Oracle.label).join(Oracle).where(Detection.kind==DetectionKind.ModelDetection.value).order_by(fn.random()).limit(150000)
    #unlabeled_dataset = SQLDataLoader(unlabeledset_query, os.path.join(args.run_data, "crops"), is_training= False, num_workers= 8)
    dataset = SQLDataLoader(os.path.join(args.run_data, "crops"),
                            is_training=False,
                            kind=DetectionKind.ModelDetection.value,
                            num_workers=8)
    dataset.updateEmbedding(model)
    #print('Embedding Done')
    #sys.stdout.flush()
    #plot_embedding(dataset.em[dataset.current_set], np.asarray(dataset.getlabels()) , dataset.getpaths(), {})
    # Random examples to start
    random_ids = np.random.choice(dataset.current_set, 5000,
                                  replace=False).tolist()
    #random_ids = selectSamples(dataset.em[dataset.current_set], dataset.current_set, 2000)
    #print(random_ids)
    # Move Records
    moveRecords(dataset, DetectionKind.ModelDetection.value,
                DetectionKind.UserDetection.value, random_ids)

    print([len(x) for x in dataset.set_indices])
    # Finetune the embedding model
    dataset.setKind(DetectionKind.UserDetection.value)
    dataset.train()
    #train_dataset = SQLDataLoader(trainset_query, os.path.join(args.run_data, 'crops'), is_training= True)
    finetune_embedding(model, dataset, 32, 4, 0)
    #unlabeled_dataset.updateEmbedding(model)
    dataset.updateEmbedding(model)
    dataset.setKind(DetectionKind.UserDetection.value)
    #print(dataset.em[dataset.current_set].shape, np.asarray(dataset.getlabels()).shape, len(dataset.getpaths()))
    #plot_embedding( dataset.em[dataset.current_set], np.asarray(dataset.getlabels()) , dataset.getpaths(), {})
    #plot_embedding( unlabeled_dataset.em, np.asarray(unlabeled_dataset.getlabels()) , unlabeled_dataset.getIDs(), {})
    dataset.embedding_mode()
    dataset.train()
    clf_model = ClassificationNet(256, 48).cuda()
    #train_eval_classifier()
    #clf_model = ClassificationNet(checkpoint['feat_dim'], 48).cuda()
    clf_criterion = FocalLoss(gamma=2)  #nn.CrossEntropyLoss()
    clf_optimizer = torch.optim.Adam(clf_model.parameters(),
                                     lr=0.001,
                                     weight_decay=0.0005)
    clf_e = Engine(clf_model,
                   clf_criterion,
                   clf_optimizer,
                   verbose=True,
                   print_freq=10)
    #names = ["Linear SVM", "RBF SVM", "Random Forest", "Neural Net", "Naive Bayes"]
    #classifiers = [SVC(kernel="linear", C=0.025, probability= True, class_weight='balanced'),
    #    SVC(gamma=2, C=1, probability= True, class_weight='balanced'),
    #    RandomForestClassifier(max_depth=None, n_estimators=100, class_weight='balanced'),
    #    MLPClassifier(alpha=1),
    #    GaussianNB()]
    #estimators= []
    #for name, clf in zip(names, classifiers):
    #    estimators.append((name, clf))
    #eclf1 = VotingClassifier(estimators= estimators, voting='hard')
    #eclf2 = VotingClassifier(estimators= estimators, voting='soft')
    #names.append("ensemble hard")
    #classifiers.append(eclf1)
    #names.append("ensemble soft")
    #classifiers.append(eclf2)
    names = ["Neural Net"]
    classifiers = [MLPClassifier(alpha=1)]
    """dataset.setKind(DetectionKind.UserDetection.value)

    learner = ActiveLearner(
            estimator=MLPClassifier(),
            query_strategy=uncertainty_sampling,
            X_training = dataset.em[dataset.current_set], y_training = np.asarray(dataset.getlabels()))

    for step in range(91):
        dataset.setKind(DetectionKind.ModelDetection.value)
        query_idx, query_inst = learner.query(dataset.em[dataset.current_set], n_instances=100)
        moveRecords(dataset, DetectionKind.ModelDetection.value, DetectionKind.UserDetection.value, [dataset.current_set[i] for i in query_idx])
        dataset.setKind(DetectionKind.UserDetection.value)
        learner.teach(dataset.em[dataset.current_set], np.asarray(dataset.getlabels()))
        if step in [11, 31, 51, 71, 91, 101]:
            finetune_embedding(model, dataset, 32, 4, 100)
            dataset.updateEmbedding(model)
            dataset.embedding_mode()
        dataset.setKind(DetectionKind.ModelDetection.value)
        print(learner.score(dataset.em[dataset.current_set], np.asarray(dataset.getlabels())))
        print([len(x) for x in dataset.set_indices])
        sys.stdout.flush()"""
    sampler = get_AL_sampler('uniform')(dataset.em[dataset.current_set],
                                        dataset.getlabels(), 12)
    print(sampler, type(sampler), dir(sampler))
    kwargs = {}
    kwargs["N"] = 100
    kwargs["already_selected"] = []
    kwargs["model"] = SVC(kernel="linear",
                          C=0.025,
                          probability=True,
                          class_weight='balanced')
    kwargs["model"].fit(dataset.em[dataset.current_set], dataset.getlabels())
    batch_AL = sampler.select_batch(**kwargs)
    print(batch_AL)
    for step in range(101):
        dataset.setKind(DetectionKind.UserDetection.value)
        clf_model.train()
        clf_train_loader = dataset.getSingleLoader(batch_size=64)
        for i in range(15):
            clf_e.train_one_epoch(clf_train_loader, i, True)
        clf_model.eval()
        X_train = dataset.em[dataset.current_set]
        y_train = np.asarray(dataset.getlabels())
        for name, clf in zip(names, classifiers):
            clf.fit(X_train, y_train)
            print(name)

        dataset.setKind(DetectionKind.ModelDetection.value)
        #dataset.image_mode()
        #dataset.updateEmbedding(model)
        dataset.embedding_mode()
        dataset.eval()
        eval_loader = dataset.getSingleLoader(batch_size=1024)
        clf_e.validate(eval_loader, True)
        X_test = dataset.em[dataset.current_set]
        y_test = np.asarray(dataset.getlabels())
        prob_list = []
        for name, clf in zip(names, classifiers):
            #y_pred= clf.predict(X_test)
            #print(confusion_matrix(y_test, y_pred))
            #paths= dataset.getpaths()
            #for i, (yp, yt) in enumerate(zip(y_pred, y_test)):
            #    if yp != yt:
            #copy(paths[i],"mistakes")
            #print(yt, yp, paths[i],i)
            if not name.startswith("ensemble"):
                prob_list.append(clf.predict_proba(X_test))
            score = clf.score(X_test, y_test)
            print(name, score)
        #clf_output= clf_e.embedding(eval_loader, dim=48)
        if step % 10 == 1 and step > 10:
            dataset.setKind(DetectionKind.UserDetection.value)
            finetune_embedding(model, dataset, 32, 4, 50)
            dataset.setKind(DetectionKind.ModelDetection.value)
            dataset.updateEmbedding(model)
            dataset.embedding_mode()
        indices = activeLearning(prob_list, dataset)
        moveRecords(dataset, DetectionKind.ModelDetection.value,
                    DetectionKind.UserDetection.value,
                    [dataset.current_set[i] for i in indices])
        print([len(x) for x in dataset.set_indices])
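
Example #4 trains the classification head with FocalLoss(gamma=2) rather than plain cross entropy. A minimal sketch of the standard focal-loss formulation for reference (an illustration only, not the project's FocalLoss class):

import torch
import torch.nn.functional as F

class FocalLossSketch(torch.nn.Module):
    def __init__(self, gamma=2.0):
        super().__init__()
        self.gamma = gamma

    def forward(self, logits, targets):
        # Down-weight easy examples by (1 - p_t) ** gamma before averaging.
        log_probs = F.log_softmax(logits, dim=1)
        log_pt = log_probs.gather(1, targets.unsqueeze(1)).squeeze(1)
        pt = log_pt.exp()
        return (-(1.0 - pt) ** self.gamma * log_pt).mean()
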
Example #5
def main():
    args = parser.parse_args()
    print("DB Connect")
    db_path = os.path.join(args.run_data, os.path.basename(
        args.run_data)) + ".db"
    print(db_path)
    db = SqliteDatabase(db_path)
    proxy.initialize(db)
    db.connect()
    print("connected")
    print("CompleteLoop")

    checkpoint = load_checkpoint(args.base_model)
    embedding_net = EmbeddingNet(checkpoint['arch'], checkpoint['feat_dim'],
                                 False)
    #embedding_net = EmbeddingNet('resnet50', 256, True)
    model = torch.nn.DataParallel(embedding_net).cuda()
    model.load_state_dict(checkpoint['state_dict'])
    #unlabeledset_query= Detection.select(Detection.id,Oracle.label).join(Oracle).where(Detection.kind==DetectionKind.ModelDetection.value).order_by(fn.random()).limit(150000)
    #unlabeled_dataset = SQLDataLoader(unlabeledset_query, os.path.join(args.run_data, "crops"), is_training= False, num_workers= 8)
    unlabeled_dataset = SQLDataLoader(os.path.join(args.run_data, "crops"),
                                      is_training=False,
                                      kind=DetectionKind.ModelDetection.value,
                                      num_workers=8)
    unlabeled_dataset.updateEmbedding(model)
    #print('Embedding Done')
    #sys.stdout.flush()
    plot_embedding(unlabeled_dataset.em,
                   np.asarray(unlabeled_dataset.getlabels()),
                   unlabeled_dataset.getpaths(), {})
    # Random examples to start
    random_ids = np.random.choice(unlabeled_dataset.getIDs(),
                                  5000,
                                  replace=False).tolist()
    #random_ids = noveltySamples(unlabeled_dataset.em, unlabeled_dataset.getIDs(), 1000)
    #print(random_ids)
    # Move Records
    moveRecords(DetectionKind.ModelDetection, DetectionKind.UserDetection,
                random_ids)

    # Finetune the embedding model
    print(len(unlabeled_dataset))
    unlabeled_dataset.setKind(DetectionKind.UserDetection.value)
    unlabeled_dataset.train()
    print(len(unlabeled_dataset))
    #train_dataset = SQLDataLoader(trainset_query, os.path.join(args.run_data, 'crops'), is_training= True)
    finetune_embedding(model, unlabeled_dataset, 32, 4, 100)
    #unlabeled_dataset.updateEmbedding(model)
    train_dataset = unlabeled_dataset  # assumed alias: the user-labeled subset selected above serves as the training set
    train_dataset.updateEmbedding(model)
    plot_embedding(train_dataset.em, np.asarray(train_dataset.getlabels()),
                   train_dataset.getpaths(), {})
    #plot_embedding( unlabeled_dataset.em, np.asarray(unlabeled_dataset.getlabels()) , unlabeled_dataset.getIDs(), {})
    train_dataset.embedding_mode()
    train_dataset.train()
    clf_model = ClassificationNet(256, 48).cuda()
    train_eval_classifier()
    #clf_model = ClassificationNet(checkpoint['feat_dim'], 48).cuda()
    clf_criterion = nn.CrossEntropyLoss()
    clf_optimizer = torch.optim.Adam(clf_model.parameters(),
                                     lr=0.001,
                                     weight_decay=0.0005)
    clf_e = Engine(clf_model,
                   clf_criterion,
                   clf_optimizer,
                   verbose=True,
                   print_freq=1)

    clf_model.train()
    clf_train_loader = train_dataset.getSingleLoader(batch_size=64)
    for i in range(15):
        clf_e.train_one_epoch(clf_train_loader, i, True)
    clf_model.eval()
    unlabeledset_query = Detection.select(
        Detection.id, Oracle.label).join(Oracle).where(
            Detection.kind == DetectionKind.ModelDetection.value).order_by(
                fn.random()).limit(20000)
    unlabeled_dataset = SQLDataLoader(unlabeledset_query,
                                      os.path.join(args.run_data, 'crops'),
                                      is_training=False,
                                      num_workers=4)
    unlabeled_dataset.refresh(unlabeledset_query)
    unlabeled_dataset.updateEmbedding(model)
    unlabeled_dataset.embedding_mode()
    unlabeled_dataset.eval()
    eval_loader = unlabeled_dataset.getSingleLoader(batch_size=1024)
    clf_e.validate(eval_loader, True)
    clf_output = clf_e.embedding(eval_loader, dim=48)
    indices = activeLearning(clf_output, unlabeled_dataset.em)
    moveRecords(DetectionKind.ModelDetection, DetectionKind.UserDetection,
                [unlabeled_dataset.getIDs()[i] for i in indices])
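
plot_embedding is called above to visualize the embedding space before and after finetuning. A rough sketch of what such a helper might do, assuming a t-SNE projection colored by label (the paths and options arguments of the real plot_embedding are ignored here):

import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

def plot_embedding_sketch(embeddings, labels, paths=None, options=None):
    # Project the embedding vectors to 2-D with t-SNE and color points by label.
    coords = TSNE(n_components=2).fit_transform(embeddings)
    plt.figure(figsize=(8, 8))
    plt.scatter(coords[:, 0], coords[:, 1], c=labels, cmap='tab20', s=4)
    plt.title('Embedding space (t-SNE projection)')
    plt.savefig('embedding_tsne.png')
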
Example #6
def main():
    global args, best_acc1
    args = parser.parse_args()
    print(args)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')
    checkpoint = {}
    if args.resume != '':
        checkpoint = load_checkpoint(args.resume)
        args.loss_type = checkpoint['loss_type']
        args.feat_dim = checkpoint['feat_dim']
        best_acc1 = checkpoint['best_acc1']

    db_path = os.path.join(args.train_data, os.path.basename(
        args.train_data)) + ".db"
    print(db_path)
    db = SqliteDatabase(db_path)
    proxy.initialize(db)
    db.connect()
    """
    to use full images
    train_query =  Detection.select(Detection.image_id,Oracle.label,Detection.kind).join(Oracle).order_by(fn.random()).limit(limit)
    
    train_dataset = SQLDataLoader('/lscratch/datasets/serengeti', is_training= True, num_workers= args.workers, 
            raw_size= args.raw_size, processed_size= args.processed_size)
    """
    train_dataset = SQLDataLoader(os.path.join(args.train_data, 'crops'),
                                  is_training=True,
                                  num_workers=args.workers,
                                  raw_size=args.raw_size,
                                  processed_size=args.processed_size)
    train_dataset.setKind(DetectionKind.UserDetection.value)
    if args.val_data is not None:
        val_dataset = SQLDataLoader(os.path.join(args.val_data, 'crops'),
                                    is_training=False,
                                    num_workers=args.workers)
    #num_classes= len(train_dataset.getClassesInfo()[0])
    num_classes = args.num_classes
    if args.balanced_P == -1:
        args.balanced_P = num_classes
    #print("Num Classes= "+str(num_classes))
    if args.loss_type.lower() in ('center', 'softmax'):
        train_loader = train_dataset.getSingleLoader(
            batch_size=args.batch_size)
        train_embd_loader = train_loader
        if args.val_data is not None:
            val_loader = val_dataset.getSingleLoader(
                batch_size=args.batch_size)
            val_embd_loader = val_loader
    else:
        train_loader = train_dataset.getBalancedLoader(P=args.balanced_P,
                                                       K=args.balanced_K)
        train_embd_loader = train_dataset.getSingleLoader(
            batch_size=args.batch_size)
        if args.val_data is not None:
            val_loader = val_dataset.getBalancedLoader(P=args.balanced_P,
                                                       K=args.balanced_K)
            val_embd_loader = val_dataset.getSingleLoader(
                batch_size=args.batch_size)

    center_loss = None
    if args.loss_type.lower() in ('center', 'softmax'):
        model = torch.nn.DataParallel(
            SoftmaxNet(args.arch,
                       args.feat_dim,
                       num_classes,
                       use_pretrained=args.pretrained)).cuda()
        if args.loss_type.lower() == 'center':
            criterion = CenterLoss(num_classes=num_classes,
                                   feat_dim=args.feat_dim)
            params = list(model.parameters()) + list(criterion.parameters())
        else:
            criterion = nn.CrossEntropyLoss().cuda()
            params = model.parameters()
    else:
        model = torch.nn.DataParallel(
            NormalizedEmbeddingNet(args.arch,
                                   args.feat_dim,
                                   use_pretrained=args.pretrained)).cuda()
        if args.loss_type.lower() == 'siamese':
            criterion = OnlineContrastiveLoss(args.margin,
                                              HardNegativePairSelector())
        else:
            criterion = OnlineTripletLoss(
                args.margin, RandomNegativeTripletSelector(args.margin))
        params = model.parameters()

    # define loss function (criterion) and optimizer
    optimizer = torch.optim.Adam(params,
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)
    #optimizer = torch.optim.SGD(params, momentum = 0.9, lr = args.lr, weight_decay = args.weight_decay)
    start_epoch = 0

    if checkpoint:
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        #optimizer.load_state_dict(checkpoint['optimizer'])
        if args.loss_type.lower() == 'center':
            criterion.load_state_dict(checkpoint['centers'])

    e = Engine(model,
               criterion,
               optimizer,
               verbose=True,
               print_freq=args.print_freq)
    for epoch in range(start_epoch, args.epochs):
        # train for one epoch
        #adjust_lr(optimizer,epoch)
        e.train_one_epoch(train_loader, epoch,
                          args.loss_type.lower() in ('center', 'softmax'))
        #if epoch % 1 == 0 and epoch > 0:
        #    a, b, c = e.predict(train_embd_loader, load_info = True, dim = args.feat_dim)
        #    plot_embedding(reduce_dimensionality(a), b, c, {})
        # evaluate on validation set
        if args.val_data is not None:
            e.validate(val_loader, args.loss_type.lower() == 'center')
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_acc1': best_acc1,
                'optimizer': optimizer.state_dict(),
                'loss_type': args.loss_type,
                'num_classes': args.num_classes,
                'feat_dim': args.feat_dim,
                'centers': criterion.state_dict()
                if args.loss_type.lower() == 'center' else None
            }, False, "%s%s_%s_%04d.tar" %
            (args.checkpoint_prefix, args.loss_type, args.arch, epoch))
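
For the non-softmax branch, Example #6 trains the embedding with a margin-based contrastive or triplet objective (OnlineContrastiveLoss / OnlineTripletLoss). The underlying triplet loss is the standard formulation; a minimal sketch for reference (illustration only, not the project's selector-based implementation):

import torch.nn.functional as F

def triplet_loss(anchor, positive, negative, margin=1.0):
    # Pull anchor-positive pairs together and push anchor-negative pairs
    # apart by at least the margin.
    d_ap = F.pairwise_distance(anchor, positive)
    d_an = F.pairwise_distance(anchor, negative)
    return F.relu(d_ap - d_an + margin).mean()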