Example #1
def verify(args):
    dataset_dir = get_dataset_dir(args)
    log_dir = get_log_dir(args)
    model_class = get_model_class(args)

    # Build the model in inference mode and restore the trained weights.
    # Note: this variant loads the checkpoint as a bare state dict; the
    # other examples read the weights from checkpoint["state_dict"].
    model = model_class(False).to(device)
    checkpoint = torch.load(args.verify_model)
    model.load_state_dict(checkpoint, strict=False)
    model.eval()

    # Load both images and apply the inference transform.
    image_a, image_b = args.verify_images.split(",")
    image_a = transform_for_infer(model_class.IMAGE_SHAPE)(
        image_loader(image_a))
    image_a = image_a.unsqueeze(0).to(device)
    image_b = transform_for_infer(model_class.IMAGE_SHAPE)(
        image_loader(image_b))

    # Forward pass without gradient tracking; this variant only embeds
    # image_a.
    with torch.no_grad():
        out = model(image_a)
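Example #1 loads the checkpoint file directly as a state dict, while Examples #3 and #4 read the weights from checkpoint['state_dict']. A small helper can tolerate both layouts; this is a sketch, not part of the repository, and load_weights is a hypothetical name:

def load_weights(model, checkpoint_path):
    # Hypothetical helper: accepts both a bare state dict (Example #1)
    # and a {'state_dict': ...} wrapper (Examples #3 and #4).
    checkpoint = torch.load(checkpoint_path, map_location=device)
    state_dict = checkpoint.get('state_dict', checkpoint)
    model.load_state_dict(state_dict, strict=False)
    return model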
Example #2
def train(args):
    dataset_dir = get_dataset_dir(args)
    log_dir = get_log_dir(args)
    model_class = get_model_class(args)

    training_set, validation_set, num_classes = create_datasets(dataset_dir)

    training_dataset = Dataset(training_set,
                               transform_for_training(model_class.IMAGE_SHAPE))
    validation_dataset = Dataset(validation_set,
                                 transform_for_infer(model_class.IMAGE_SHAPE))

    training_dataloader = torch.utils.data.DataLoader(
        training_dataset,
        batch_size=args.batch_size,
        num_workers=6,
        shuffle=True)

    validation_dataloader = torch.utils.data.DataLoader(
        validation_dataset,
        batch_size=args.batch_size,
        num_workers=6,
        shuffle=False)

    model = model_class(num_classes).to(device)

    # Apply weight decay to everything except the batch-norm parameters,
    # which are matched by "bn" in the parameter name.
    trainables_wo_bn = [
        param for name, param in model.named_parameters()
        if param.requires_grad and "bn" not in name
    ]
    trainables_only_bn = [
        param for name, param in model.named_parameters()
        if param.requires_grad and "bn" in name
    ]

    optimizer = torch.optim.SGD(
        [
            {
                "params": trainables_wo_bn,
                "weight_decay": 0.0001
            },
            {
                "params": trainables_only_bn
            },
        ],
        lr=args.lr,
        momentum=0.9,
    )

    trainer = Trainer(
        optimizer,
        model,
        training_dataloader,
        validation_dataloader,
        max_epoch=args.epochs,
        resume=args.resume,
        log_dir=log_dir,
    )
    trainer.train()
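train() only consumes a handful of attributes from args (batch_size, lr, epochs, resume, plus whatever get_dataset_dir, get_log_dir and get_model_class read). A minimal sketch of the CLI wiring might look like the following; the flag names and defaults are assumptions, not the repository's actual parser:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--batch-size', dest='batch_size', type=int, default=128)
parser.add_argument('--lr', type=float, default=0.01)
parser.add_argument('--epochs', type=int, default=30)
parser.add_argument('--resume', default=None)  # checkpoint to resume from
args = parser.parse_args()
train(args)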
Example #3
def verify(args):
    dataset_dir = get_dataset_dir(args)
    log_dir = get_log_dir(args)
    model_class = get_model_class(args)

    model = model_class(False).to(device)
    checkpoint = torch.load(args.verify_model)
    model.load_state_dict(checkpoint['state_dict'], strict=False)
    model.eval()

    image_a, image_b = args.verify_images.split(',')
    image_a = transform_for_infer(
        model_class.IMAGE_SHAPE)(image_loader(image_a))
    image_b = transform_for_infer(
        model_class.IMAGE_SHAPE)(image_loader(image_b))
    images = torch.stack([image_a, image_b]).to(device)

    # The second model output holds the embeddings; unpack one per image.
    _, (embeddings_a, embeddings_b) = model(images)

    # Squared Euclidean distance between the two embeddings.
    distance = torch.sum(torch.pow(embeddings_a - embeddings_b, 2)).item()
    print('distance: {}'.format(distance))
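The printed value is the squared Euclidean distance between the two embeddings; verification itself is a threshold test on that distance. A sketch, where the threshold value is an assumption (Example #4 selects one empirically on LFW pairs):

THRESHOLD = 1.0  # assumed value; best_thresholds from Example #4 is the principled choice

if distance < THRESHOLD:
    print('same person')
else:
    print('different persons')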
Example #4
def evaluate(args):
    dataset_dir = get_dataset_dir(args)
    log_dir = get_log_dir(args)
    model_class = get_model_class(args)

    pairs_path = args.pairs if args.pairs else \
        os.path.join(dataset_dir, 'pairs.txt')

    if not os.path.isfile(pairs_path):
        download(dataset_dir, 'http://vis-www.cs.umass.edu/lfw/pairs.txt')

    dataset = LFWPairedDataset(
        dataset_dir, pairs_path, transform_for_infer(model_class.IMAGE_SHAPE))
    dataloader = DataLoader(dataset, batch_size=args.batch_size, num_workers=4)
    model = model_class(False).to(device)

    checkpoint = torch.load(args.evaluate)
    model.load_state_dict(checkpoint['state_dict'], strict=False)
    model.eval()

    # Pre-allocate buffers for all embeddings and ground-truth match labels.
    embeddings_a = torch.zeros(len(dataset), model.FEATURE_DIM)
    embeddings_b = torch.zeros(len(dataset), model.FEATURE_DIM)
    matches = torch.zeros(len(dataset), dtype=torch.uint8)

    for iteration, (images_a, images_b, batched_matches) \
            in enumerate(dataloader):
        current_batch_size = len(batched_matches)
        images_a = images_a.to(device)
        images_b = images_b.to(device)

        _, batched_embeddings_a = model(images_a)
        _, batched_embeddings_b = model(images_b)

        start = args.batch_size * iteration
        end = start + current_batch_size

        embeddings_a[start:end, :] = batched_embeddings_a.data
        embeddings_b[start:end, :] = batched_embeddings_b.data
        matches[start:end] = batched_matches.data

    # Sweep thresholds over the squared Euclidean distance of each pair.
    thresholds = np.arange(0, 4, 0.1)
    distances = torch.sum(torch.pow(embeddings_a - embeddings_b, 2), dim=1)

    tpr, fpr, accuracy, best_thresholds = compute_roc(
        distances,
        matches,
        thresholds
    )

    roc_file = args.roc if args.roc else os.path.join(log_dir, 'roc.png')
    generate_roc_curve(fpr, tpr, roc_file)
    print('Model accuracy is {}'.format(accuracy))
    print('ROC curve generated at {}'.format(roc_file))
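compute_roc is not shown in this listing. As an illustration of the accounting it presumably performs, the sketch below sweeps the threshold grid over the squared distances, treating matches == 1 as genuine pairs; this is an assumption about its behavior, not the repository's implementation (numpy is imported as np, as in Example #4):

def roc_sketch(distances, matches, thresholds):
    # Assumed semantics: a pair is predicted "same" when its distance
    # falls below the threshold; matches == 1 marks genuine pairs.
    distances = distances.numpy()
    genuine = matches.numpy().astype(bool)
    tpr, fpr, accuracies = [], [], []
    for threshold in thresholds:
        predicted_same = distances < threshold
        tp = np.sum(predicted_same & genuine)
        fp = np.sum(predicted_same & ~genuine)
        tn = np.sum(~predicted_same & ~genuine)
        fn = np.sum(~predicted_same & genuine)
        tpr.append(tp / max(tp + fn, 1))
        fpr.append(fp / max(fp + tn, 1))
        accuracies.append((tp + tn) / len(genuine))
    best_threshold = thresholds[int(np.argmax(accuracies))]
    return np.array(tpr), np.array(fpr), max(accuracies), best_threshold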
Example #5
def train(args):
    dataset_dir = get_dataset_dir(args)
    log_dir = get_log_dir(args)
    model_class = get_model_class(args)

    # Build the training/validation splits and count the classes.
    training_set, validation_set, num_classes = create_datasets(dataset_dir)

    # Wrap each split in a Dataset with the matching transform.
    training_dataset = Dataset(training_set,
                               transform_for_training(model_class.IMAGE_SHAPE))
    validation_dataset = Dataset(validation_set,
                                 transform_for_infer(model_class.IMAGE_SHAPE))

    training_dataloader = torch.utils.data.DataLoader(
        training_dataset,
        batch_size=args.batch_size,
        num_workers=6,
        shuffle=True)

    validation_dataloader = torch.utils.data.DataLoader(
        validation_dataset,
        batch_size=args.batch_size,
        num_workers=6,
        shuffle=False)
    # Instantiate the model for num_classes classes and move it to the device.
    model = model_class(num_classes).to(device)

    # Separate the batch-norm parameters so weight decay is applied only to
    # the remaining trainables.
    trainables_wo_bn = [
        param for name, param in model.named_parameters()
        if param.requires_grad and 'bn' not in name
    ]
    trainables_only_bn = [
        param for name, param in model.named_parameters()
        if param.requires_grad and 'bn' in name
    ]

    optimizer = torch.optim.SGD([{
        'params': trainables_wo_bn,
        'weight_decay': 0.0001
    }, {
        'params': trainables_only_bn
    }],
                                lr=args.lr,
                                momentum=0.9)

    trainer = Trainer(optimizer,
                      model,
                      training_dataloader,
                      validation_dataloader,
                      max_epoch=args.epochs,
                      resume=args.resume,
                      log_dir=log_dir)
    trainer.train()
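The name-based split into trainables_wo_bn and trainables_only_bn relies on batch-norm layers having 'bn' in their attribute names. A more robust variant, sketched here as an alternative rather than the repository's approach, keys on the module type instead:

import torch.nn as nn

def split_bn_params(model):
    # Sketch: walk the modules so batch-norm layers are identified by
    # type instead of by name.
    bn_params, other_params = [], []
    for module in model.modules():
        bucket = bn_params if isinstance(
            module, nn.modules.batchnorm._BatchNorm) else other_params
        bucket.extend(p for p in module.parameters(recurse=False)
                      if p.requires_grad)
    return other_params, bn_params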
Example #6
def train(args):
    # Accumulators for the samples and class counts merged from all sources.
    t_training_set = []
    t_validation_set = []
    t_num_classes = 0
    dataset_dir = get_dataset_dir(args)
    log_dir = get_log_dir(args)
    model_class = get_model_class(args)

    if args.w != 0:
        w_training_set, w_validation_set, num_classes_w = create_datasetsW(
            dataset_dir)
        w_training_set = w_training_set[0:int(args.w / 2)]
        w_validation_set = w_validation_set[0:int(args.w / 2)]
        t_training_set += w_training_set
        t_validation_set += w_validation_set
        t_num_classes += num_classes_w

    if args.sa != 0:
        sa_training_set, sa_validation_set, num_classes_sa = create_datasetsSA(
            dataset_dir)
        sa_training_set = sa_training_set[0:int(args.sa / 2)]
        sa_validation_set = sa_validation_set[0:int(args.sa / 2)]
        t_training_set += sa_training_set
        t_validation_set += sa_validation_set
        t_num_classes += num_classes_sa

    if args.ai != 0:
        as_training_set, as_validation_set, num_classes_as = create_datasetsAs(
            dataset_dir)
        as_training_set = as_training_set[0:(int(args.ai / 2)) + 1]
        as_validation_set = as_validation_set[0:(int(args.ai / 2)) + 1]
        t_training_set += as_training_set
        t_validation_set += as_validation_set
        t_num_classes += num_classes_as

    if args.af != 0:
        af_training_set, af_validation_set, classes_af = create_datasetsAF(
            dataset_dir)
        af_training_set = af_training_set[0:(int(args.af / 2)) + 1]
        af_validation_set = af_validation_set[0:(int(args.af / 2)) + 1]
        t_training_set += af_training_set
        t_validation_set += af_validation_set
        t_num_classes += classes_af

    training_set = t_training_set
    validation_set = t_validation_set
    num_classes = t_num_classes

    training_dataset = Dataset(training_set,
                               transform_for_training(model_class.IMAGE_SHAPE))
    validation_dataset = Dataset(validation_set,
                                 transform_for_infer(model_class.IMAGE_SHAPE))

    training_dataloader = torch.utils.data.DataLoader(
        training_dataset,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        shuffle=True)

    validation_dataloader = torch.utils.data.DataLoader(
        validation_dataset,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        shuffle=False)

    model = model_class(num_classes).to(device)

    trainables_wo_bn = [
        param for name, param in model.named_parameters()
        if param.requires_grad and 'bn' not in name
    ]
    trainables_only_bn = [
        param for name, param in model.named_parameters()
        if param.requires_grad and 'bn' in name
    ]

    optimizer = torch.optim.SGD([{
        'params': trainables_wo_bn,
        'weight_decay': 0.0001
    }, {
        'params': trainables_only_bn
    }],
                                lr=args.lr,
                                momentum=0.9)

    trainer = Trainer(optimizer,
                      model,
                      training_dataloader,
                      validation_dataloader,
                      max_epoch=args.epochs,
                      resume=args.resume,
                      log_dir=log_dir)
    trainer.train()
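When samples from several sources are concatenated as above, the per-source class counts only sum up correctly if every source labels its classes from zero and the labels are offset during the merge. A sketch of that offsetting, assuming each sample is a (path, label) pair, which the code above does not show:

def merge_sources(sources):
    # sources: iterable of (samples, num_classes) pairs, one per dataset.
    # Assumes each sample is a (path, label) pair with labels starting at 0.
    merged, offset = [], 0
    for samples, num_classes in sources:
        merged.extend((path, label + offset) for path, label in samples)
        offset += num_classes
    return merged, offset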