def verify(args):
    """Compare two images with a trained model and print their embedding distance.

    Reads from ``args``:
        verify_model: path of the checkpoint to load.
        verify_images: two image paths separated by a comma.

    Prints the squared Euclidean distance between the two embeddings produced
    by the model (smaller means more similar).

    NOTE(review): a later ``verify`` definition in this file rebinds this name,
    so this version is only reachable if that one is removed or renamed.
    """
    model_class = get_model_class(args)
    model = model_class(False).to(device)

    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only host.
    checkpoint = torch.load(args.verify_model, map_location=device)
    # Accept both a bare state dict and a wrapper dict with a "state_dict" entry.
    state_dict = checkpoint.get("state_dict", checkpoint)
    model.load_state_dict(state_dict, strict=False)
    model.eval()

    path_a, path_b = args.verify_images.split(",")
    transform = transform_for_infer(model_class.IMAGE_SHAPE)
    images = torch.stack([
        transform(image_loader(path_a)),
        transform(image_loader(path_b)),
    ]).to(device)

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        # The first model output is unused; the second holds one embedding per image.
        _, (embedding_a, embedding_b) = model(images)

    distance = torch.sum(torch.pow(embedding_a - embedding_b, 2)).item()
    print("distance: {}".format(distance))
def train(args):
    """Train the selected model on the dataset produced by ``create_datasets``.

    The training loader shuffles and the validation loader does not; both use
    ``args.batch_size`` and 6 workers. Optimization is SGD (momentum 0.9,
    ``args.lr``). Weight decay 1e-4 applies only to trainable parameters whose
    name does not contain "bn". The loop itself is delegated to ``Trainer``
    (``args.epochs``, ``args.resume``, logs under the configured log dir).
    """
    data_root = get_dataset_dir(args)
    logs = get_log_dir(args)
    net_cls = get_model_class(args)

    train_split, val_split, class_count = create_datasets(data_root)
    shape = net_cls.IMAGE_SHAPE
    train_ds = Dataset(train_split, transform_for_training(shape))
    val_ds = Dataset(val_split, transform_for_infer(shape))

    def make_loader(ds, shuffle):
        return torch.utils.data.DataLoader(
            ds, batch_size=args.batch_size, num_workers=6, shuffle=shuffle)

    train_loader = make_loader(train_ds, True)
    val_loader = make_loader(val_ds, False)

    net = net_cls(class_count).to(device)

    # Split trainable parameters: names containing "bn" (presumably
    # batch-norm layers) are excluded from weight decay.
    decayed, not_decayed = [], []
    for pname, tensor in net.named_parameters():
        if not tensor.requires_grad:
            continue
        if "bn" in pname:
            not_decayed.append(tensor)
        else:
            decayed.append(tensor)

    optimizer = torch.optim.SGD(
        [{"params": decayed, "weight_decay": 0.0001}, {"params": not_decayed}],
        lr=args.lr,
        momentum=0.9,
    )
    Trainer(optimizer, net, train_loader, val_loader,
            max_epoch=args.epochs, resume=args.resume, log_dir=logs).train()
def verify(args):
    """Compare two images with a trained model and print their embedding distance.

    Reads from ``args``:
        verify_model: path of the checkpoint to load.
        verify_images: two image paths separated by a comma.

    Prints the squared Euclidean distance between the two embeddings produced
    by the model (smaller means more similar).
    """
    model_class = get_model_class(args)
    model = model_class(False).to(device)

    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only host.
    checkpoint = torch.load(args.verify_model, map_location=device)
    # Prefer the "state_dict" entry; fall back to treating the file as a bare state dict.
    model.load_state_dict(checkpoint.get('state_dict', checkpoint), strict=False)
    model.eval()

    image_a, image_b = args.verify_images.split(',')
    image_a = transform_for_infer(model_class.IMAGE_SHAPE)(image_loader(image_a))
    image_b = transform_for_infer(model_class.IMAGE_SHAPE)(image_loader(image_b))
    images = torch.stack([image_a, image_b]).to(device)

    # Pure inference: skip building an autograd graph.
    with torch.no_grad():
        # The first model output is unused; the second holds one embedding per image.
        _, (embedding_a, embedding_b) = model(images)

    distance = torch.sum(torch.pow(embedding_a - embedding_b, 2)).item()
    print("distance: {}".format(distance))
def evaluate(args):
    """Evaluate a trained model on LFW-style image pairs and plot its ROC curve.

    Both images of every pair are embedded, the squared Euclidean distance of
    each pair is computed, and ``compute_roc`` sweeps thresholds
    ``np.arange(0, 4, 0.1)`` over those distances. The resulting accuracy is
    printed and the ROC curve is written as an image.

    Reads from ``args``:
        pairs: optional pairs-file path (default ``<dataset_dir>/pairs.txt``).
        batch_size: DataLoader batch size.
        evaluate: checkpoint path.
        roc: optional ROC image path (default ``<log_dir>/roc.png``).
    """
    dataset_dir = get_dataset_dir(args)
    log_dir = get_log_dir(args)
    model_class = get_model_class(args)

    pairs_path = args.pairs if args.pairs else os.path.join(dataset_dir, 'pairs.txt')
    if not os.path.isfile(pairs_path):
        # NOTE(review): download() is given dataset_dir, not pairs_path; if a custom
        # args.pairs path is missing, this download may not create it - confirm.
        download(dataset_dir, 'http://vis-www.cs.umass.edu/lfw/pairs.txt')

    dataset = LFWPairedDataset(
        dataset_dir, pairs_path, transform_for_infer(model_class.IMAGE_SHAPE))
    dataloader = DataLoader(dataset, batch_size=args.batch_size, num_workers=4)

    model = model_class(False).to(device)
    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only host.
    checkpoint = torch.load(args.evaluate, map_location=device)
    # Prefer the "state_dict" entry; fall back to treating the file as a bare state dict.
    model.load_state_dict(checkpoint.get('state_dict', checkpoint), strict=False)
    model.eval()

    # Preallocated result buffers, one row per pair.
    embeddings_a = torch.zeros(len(dataset), model.FEATURE_DIM)
    embeddings_b = torch.zeros(len(dataset), model.FEATURE_DIM)
    matches = torch.zeros(len(dataset), dtype=torch.uint8)

    # Pure inference: no autograd graph. A running offset keeps rows aligned
    # even if a batch is shorter than args.batch_size (e.g. the last one).
    start = 0
    with torch.no_grad():
        for images_a, images_b, batched_matches in dataloader:
            end = start + len(batched_matches)
            _, batched_embeddings_a = model(images_a.to(device))
            _, batched_embeddings_b = model(images_b.to(device))
            embeddings_a[start:end, :] = batched_embeddings_a
            embeddings_b[start:end, :] = batched_embeddings_b
            matches[start:end] = batched_matches
            start = end

    thresholds = np.arange(0, 4, 0.1)
    distances = torch.sum(torch.pow(embeddings_a - embeddings_b, 2), dim=1)
    tpr, fpr, accuracy, _best_thresholds = compute_roc(distances, matches, thresholds)

    roc_file = args.roc if args.roc else os.path.join(log_dir, 'roc.png')
    generate_roc_curve(fpr, tpr, roc_file)
    print('Model accuracy is {}'.format(accuracy))
    print('ROC curve generated at {}'.format(roc_file))
def train(args):
    """Train the selected model on the dataset produced by ``create_datasets``.

    Pipeline:
      1. Resolve the dataset dir, log dir and model class from ``args``.
      2. ``create_datasets`` yields (training split, validation split, number
         of classes). Each split is wrapped in ``Dataset`` with the training or
         inference transform for the model's ``IMAGE_SHAPE``.
      3. Build DataLoaders: the training loader shuffles and the validation
         loader does not. Both use ``args.batch_size`` and 6 workers.
      4. Build the model with the number of classes and move it to ``device``
         (a name defined outside this function).
      5. Optimize with SGD (``args.lr``, momentum 0.9). Weight decay 1e-4
         applies only to trainable parameters whose name lacks "bn".
      6. Run ``Trainer`` for ``args.epochs`` epochs.
    """
    dataset_root = get_dataset_dir(args)
    log_root = get_log_dir(args)
    arch = get_model_class(args)

    train_items, val_items, n_classes = create_datasets(dataset_root)

    datasets = {
        "train": Dataset(train_items, transform_for_training(arch.IMAGE_SHAPE)),
        "val": Dataset(val_items, transform_for_infer(arch.IMAGE_SHAPE)),
    }
    loaders = {
        split: torch.utils.data.DataLoader(
            ds, batch_size=args.batch_size, num_workers=6,
            shuffle=(split == "train"))
        for split, ds in datasets.items()
    }

    model = arch(n_classes).to(device)

    trainable = [(n, p) for n, p in model.named_parameters() if p.requires_grad]
    param_groups = [
        # Parameters whose name contains "bn" (presumably batch-norm layers)
        # are kept out of weight decay.
        {"params": [p for n, p in trainable if "bn" not in n], "weight_decay": 0.0001},
        {"params": [p for n, p in trainable if "bn" in n]},
    ]
    optimizer = torch.optim.SGD(param_groups, lr=args.lr, momentum=0.9)

    trainer = Trainer(
        optimizer, model, loaders["train"], loaders["val"],
        max_epoch=args.epochs, resume=args.resume, log_dir=log_root,
    )
    trainer.train()
def _take_subset(create_fn, dataset_dir, count, extra=0):
    """Load one dataset source and keep the first ``int(count / 2) + extra`` items of each split.

    Returns ``(training_subset, validation_subset, num_classes)``.
    """
    training, validation, num_classes = create_fn(dataset_dir)
    limit = int(count / 2) + extra
    return training[0:limit], validation[0:limit], num_classes


def train(args):
    """Train on a mix of up to four dataset sources selected by ``args``.

    ``args.w``, ``args.sa``, ``args.ai`` and ``args.af`` select the sources
    built by ``create_datasetsW``, ``create_datasetsSA``, ``create_datasetsAs``
    and ``create_datasetsAF`` respectively. A non-zero value N keeps about N/2
    items from that source's training split and N/2 from its validation split.
    The selected splits are concatenated and their class counts summed.

    Training then proceeds as follows:
      - Build DataLoaders with ``args.batch_size`` and ``args.num_workers``;
        the training loader shuffles and the validation loader does not.
      - Optimize with SGD (``args.lr``, momentum 0.9). Weight decay 1e-4
        applies only to trainable parameters whose name lacks "bn".
      - ``Trainer`` runs for ``args.epochs`` epochs.

    Raises:
        ValueError: if all four selectors are 0, leaving nothing to train on.
    """
    dataset_dir = get_dataset_dir(args)
    log_dir = get_log_dir(args)
    model_class = get_model_class(args)

    # One (training_subset, validation_subset, num_classes) tuple per selected source.
    parts = []
    if args.w != 0:
        parts.append(_take_subset(create_datasetsW, dataset_dir, args.w))
    if args.sa != 0:
        parts.append(_take_subset(create_datasetsSA, dataset_dir, args.sa))
    # NOTE(review): the AI and AF sources keep one extra item per split (+1)
    # while W and SA do not; this asymmetry is preserved - confirm it is intended.
    if args.ai != 0:
        parts.append(_take_subset(create_datasetsAs, dataset_dir, args.ai, extra=1))
    if args.af != 0:
        # The AF subset size is driven by its own selector, args.af.
        parts.append(_take_subset(create_datasetsAF, dataset_dir, args.af, extra=1))
    if not parts:
        raise ValueError(
            "train: args.w, args.sa, args.ai and args.af are all 0; "
            "select at least one dataset source")

    # Concatenate (not nest) the selected splits and sum their class counts.
    # NOTE(review): items are merged as-is; if each create_datasets* numbers its
    # classes from 0, labels from different sources will collide even though
    # num_classes is summed - confirm whether labels need offsetting.
    training_set, validation_set, num_classes = [], [], 0
    for part_training, part_validation, part_classes in parts:
        training_set.extend(part_training)
        validation_set.extend(part_validation)
        num_classes += part_classes

    training_dataset = Dataset(training_set, transform_for_training(model_class.IMAGE_SHAPE))
    validation_dataset = Dataset(validation_set, transform_for_infer(model_class.IMAGE_SHAPE))
    training_dataloader = torch.utils.data.DataLoader(
        training_dataset, batch_size=args.batch_size,
        num_workers=args.num_workers, shuffle=True)
    validation_dataloader = torch.utils.data.DataLoader(
        validation_dataset, batch_size=args.batch_size,
        num_workers=args.num_workers, shuffle=False)

    model = model_class(num_classes).to(device)

    # Parameters whose name contains "bn" (presumably batch-norm layers) are
    # kept out of weight decay.
    trainables_wo_bn = [
        param for name, param in model.named_parameters()
        if param.requires_grad and 'bn' not in name
    ]
    trainables_only_bn = [
        param for name, param in model.named_parameters()
        if param.requires_grad and 'bn' in name
    ]
    optimizer = torch.optim.SGD(
        [{'params': trainables_wo_bn, 'weight_decay': 0.0001},
         {'params': trainables_only_bn}],
        lr=args.lr,
        momentum=0.9)

    trainer = Trainer(optimizer, model, training_dataloader, validation_dataloader,
                      max_epoch=args.epochs, resume=args.resume, log_dir=log_dir)
    trainer.train()