Example #1
def completeLoop(X, y, base_ind):
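    # NOTE: this listing assumes project-local helpers are importable:
    # EmbeddingNet, ClassificationNet, Engine, SQLDataLoader, Detection,
    # Oracle, DetectionKind, OnlineTripletLoss,
    # SemihardNegativeTripletSelector, load_checkpoint; plus torch,
    # torch.nn as nn, numpy as np, and TensorDataset/DataLoader from
    # torch.utils.data.
    # Build a DenseNet-161 embedding net (256-d output) and parallelize it.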
    embedding_net = EmbeddingNet('densenet161', 256, True)
    center_loss = None
    model = torch.nn.DataParallel(embedding_net).cuda()
    criterion = OnlineTripletLoss(1.0, SemihardNegativeTripletSelector(1.0))
    params = model.parameters()

    # Define the optimizer (the triplet criterion was created above).
    optimizer = torch.optim.Adam(params, lr=0.0001)
    start_epoch = 0
    checkpoint = load_checkpoint('triplet_model_0086.tar')
    if checkpoint:
        model.load_state_dict(checkpoint['state_dict'])

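    # Engine drives the train/validate loops for the triplet-embedding phase.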
    e = Engine(model, criterion, optimizer, verbose=True, print_freq=10)
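    # Select human-labeled crops (Oracle.status == 0, kind == UserDetection)
    # together with their labels and stored embeddings.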
    trainset_query = Detection.select(
        Detection.id, Oracle.label, Detection.embedding).join(Oracle).where(
            Oracle.status == 0,
            Detection.kind == DetectionKind.UserDetection.value)
    embedding_dataset = SQLDataLoader(trainset_query,
                                      "all_crops/SS_full_crops",
                                      True,
                                      num_workers=4,
                                      batch_size=64)
    print(len(embedding_dataset))
    num_classes = 48  # len(run_dataset.getClassesInfo()[0])
    print("Num Classes= " + str(num_classes))
    embedding_loader = embedding_dataset.getBalancedLoader(16, 4)
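    # Train the embedding with semi-hard triplet loss on balanced batches
    # (16 classes x 4 samples each) for 200 epochs.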
    for i in range(200):
        e.train_one_epoch(embedding_loader, i, False)
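    # Reload the same query (third arg False, presumably eval/no-augmentation
    # mode) with larger batches and compute an embedding for every sample.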
    embedding_dataset2 = SQLDataLoader(trainset_query,
                                       "all_crops/SS_full_crops",
                                       False,
                                       num_workers=4,
                                       batch_size=512)
    em = e.embedding(embedding_dataset2.getSingleLoader())
    lb = np.asarray([x[1] for x in embedding_dataset2.samples])
    pt = [x[0] for x in embedding_dataset2.samples]  # e.predict(run_loader, load_info=True)

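    # Persist the freshly computed embeddings back to the Detection table.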
    co = 0
    for r, s, t in zip(em, lb, pt):
        co += 1
        smp = Detection.get(id=t)
        smp.embedding = r
        smp.save()
        if co % 100 == 0:
            print(co)
    print("Loop Started")
    train_dataset = SQLDataLoader(trainset_query,
                                  "all_crops/SS_full_crops",
                                  False,
                                  datatype='embedding',
                                  num_workers=4,
                                  batch_size=64)
    print(len(train_dataset))
    num_classes = 48  # len(run_dataset.getClassesInfo()[0])
    print("Num Classes= " + str(num_classes))
    run_loader = train_dataset.getSingleLoader()
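    # Classification head mapping the 256-d embeddings to 48 classes.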
    clf_model = ClassificationNet(256, 48).cuda()
    clf_criterion = nn.CrossEntropyLoss()
    clf_optimizer = torch.optim.Adam(clf_model.parameters(),
                                     lr=0.001,
                                     weight_decay=0.0005)
    clf_e = Engine(clf_model,
                   clf_criterion,
                   clf_optimizer,
                   verbose=True,
                   print_freq=10)
    for i in range(250):
        clf_e.train_one_epoch(run_loader, i, True)
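    # Evaluation set: model-generated detections (kind == ModelDetection)
    # that have oracle labels.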
    testset_query = Detection.select(
        Detection.id, Oracle.label, Detection.embedding).join(Oracle).where(
            Oracle.status == 0,
            Detection.kind == DetectionKind.ModelDetection.value)
    test_dataset = SQLDataLoader(testset_query,
                                 "all_crops/SS_full_crops",
                                 False,
                                 datatype='image',
                                 num_workers=4,
                                 batch_size=512)
    print(len(test_dataset))
    num_classes = 48  # len(run_dataset.getClassesInfo()[0])
    print("Num Classes= " + str(num_classes))
    test_loader = test_dataset.getSingleLoader()
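    # Embed the test crops with the trained triplet model, then validate the
    # classifier on those embeddings.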
    test_em = e.embedding(test_loader)
    test_lb = np.asarray([x[1] for x in test_dataset.samples])
    print(test_lb.shape, test_em.shape)
    testset = TensorDataset(torch.from_numpy(test_em),
                            torch.from_numpy(test_lb))
    clf_e.validate(DataLoader(testset, batch_size=512, shuffle=False), True)
    """X_train= X[list(base_ind)]
Example #2
def main():
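    # NOTE: assumes module-level setup not shown here: an argparse `parser`,
    # a `best_acc1` global, imports of os, random, warnings, torch,
    # torch.nn as nn, torch.backends.cudnn as cudnn, and the project modules
    # (SQLDataLoader, Engine, SoftmaxNet, NormalizedEmbeddingNet, CenterLoss,
    # the online losses/selectors, load/save_checkpoint, peewee's
    # SqliteDatabase and proxy).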
    global args, best_acc1
    args = parser.parse_args()
    print(args)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')
    checkpoint = {}
    if args.resume != '':
        checkpoint = load_checkpoint(args.resume)
        args.loss_type = checkpoint['loss_type']
        args.feat_dim = checkpoint['feat_dim']
        best_acc1 = checkpoint['best_acc1']

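    # Open the SQLite database backing the Detection/Oracle tables and bind
    # it to the peewee proxy.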
    db_path = os.path.join(args.train_data, os.path.basename(
        args.train_data)) + ".db"
    print(db_path)
    db = SqliteDatabase(db_path)
    proxy.initialize(db)
    db.connect()
    """
    To use full images:
    train_query =  Detection.select(Detection.image_id,Oracle.label,Detection.kind).join(Oracle).order_by(fn.random()).limit(limit)
    
    train_dataset = SQLDataLoader('/lscratch/datasets/serengeti', is_training= True, num_workers= args.workers, 
            raw_size= args.raw_size, processed_size= args.processed_size)
    """
    train_dataset = SQLDataLoader(os.path.join(args.train_data, 'crops'),
                                  is_training=True,
                                  num_workers=args.workers,
                                  raw_size=args.raw_size,
                                  processed_size=args.processed_size)
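    # Restrict training to user-submitted detections.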
    train_dataset.setKind(DetectionKind.UserDetection.value)
    if args.val_data is not None:
        val_dataset = SQLDataLoader(os.path.join(args.val_data, 'crops'),
                                    is_training=False,
                                    num_workers=args.workers)
    # num_classes = len(train_dataset.getClassesInfo()[0])
    num_classes = args.num_classes
    if args.balanced_P == -1:
        args.balanced_P = num_classes
    # print("Num Classes= " + str(num_classes))
    if args.loss_type.lower() in ('center', 'softmax'):
        train_loader = train_dataset.getSingleLoader(
            batch_size=args.batch_size)
        train_embd_loader = train_loader
        if args.val_data is not None:
            val_loader = val_dataset.getSingleLoader(
                batch_size=args.batch_size)
            val_embd_loader = val_loader
    else:
        train_loader = train_dataset.getBalancedLoader(P=args.balanced_P,
                                                       K=args.balanced_K)
        train_embd_loader = train_dataset.getSingleLoader(
            batch_size=args.batch_size)
        if args.val_data is not None:
            val_loader = val_dataset.getBalancedLoader(P=args.balanced_P,
                                                       K=args.balanced_K)
            val_embd_loader = val_dataset.getSingleLoader(
                batch_size=args.batch_size)

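    # Pick the model/criterion pair for the requested loss type.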
    center_loss = None
    if args.loss_type.lower() in ('center', 'softmax'):
        model = torch.nn.DataParallel(
            SoftmaxNet(args.arch,
                       args.feat_dim,
                       num_classes,
                       use_pretrained=args.pretrained)).cuda()
        if args.loss_type.lower() == 'center':
            criterion = CenterLoss(num_classes=num_classes,
                                   feat_dim=args.feat_dim)
            params = list(model.parameters()) + list(criterion.parameters())
        else:
            criterion = nn.CrossEntropyLoss().cuda()
            params = model.parameters()
    else:
        model = torch.nn.DataParallel(
            NormalizedEmbeddingNet(args.arch,
                                   args.feat_dim,
                                   use_pretrained=args.pretrained)).cuda()
        if args.loss_type.lower() == 'siamese':
            criterion = OnlineContrastiveLoss(args.margin,
                                              HardNegativePairSelector())
        else:
            criterion = OnlineTripletLoss(
                args.margin, RandomNegativeTripletSelector(args.margin))
        params = model.parameters()

    # Define the optimizer (the loss criterion was chosen above).
    optimizer = torch.optim.Adam(params,
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)
    #optimizer = torch.optim.SGD(params, momentum = 0.9, lr = args.lr, weight_decay = args.weight_decay)
    start_epoch = 0

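    # Restore model (and, for center loss, criterion) state when resuming.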
    if checkpoint:
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        #optimizer.load_state_dict(checkpoint['optimizer'])
        if args.loss_type.lower() == 'center':
            criterion.load_state_dict(checkpoint['centers'])

    e = Engine(model,
               criterion,
               optimizer,
               verbose=True,
               print_freq=args.print_freq)
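    # Main loop: train one epoch, optionally validate, then checkpoint.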
    for epoch in range(start_epoch, args.epochs):
        # train for one epoch
        #adjust_lr(optimizer,epoch)
        e.train_one_epoch(train_loader, epoch,
                          args.loss_type.lower() in ('center', 'softmax'))
        #if epoch % 1 == 0 and epoch > 0:
        #    a, b, c = e.predict(train_embd_loader, load_info = True, dim = args.feat_dim)
        #    plot_embedding(reduce_dimensionality(a), b, c, {})
        # evaluate on validation set
        if args.val_data is not None:
            e.validate(val_loader, args.loss_type.lower() == 'center')
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_acc1': best_acc1,
                'optimizer': optimizer.state_dict(),
                'loss_type': args.loss_type,
                'num_classes': args.num_classes,
                'feat_dim': args.feat_dim,
                'centers': criterion.state_dict()
                           if args.loss_type.lower() == 'center' else None,
            }, False, "%s%s_%s_%04d.tar" %
            (args.checkpoint_prefix, args.loss_type, args.arch, epoch))