# Example #1
0
# Data loaders
train_set = build_dataset(cfg.data.train)
query_set = build_dataset(cfg.data.query)
print('datasets loaded')

train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=cfg.data.imgs_per_gpu,
    shuffle=True,  # reshuffle the training samples on every pass
)
# query_loader is handed to fit() as its second (presumably validation) loader.
# Shuffling it only changes batch composition between passes, which can make
# per-epoch evaluation numbers vary, so it is read in a fixed order.
query_loader = torch.utils.data.DataLoader(
    query_set,
    batch_size=cfg.data.imgs_per_gpu,
    shuffle=False,
)
print('dataloader created')

# Build model: a Vgg16L2 backbone (128-dim output) wrapped in TripletNet.
backbone_model = Vgg16L2(num_dim=128)
model = TripletNet(backbone_model)
# The same `cuda` flag is passed to fit() below; moving the model to the GPU
# unconditionally breaks runs where cuda is False (no GPU, or batches that are
# presumably left on the CPU by fit in that case).
if cuda:
    model.cuda()
print('model built')

# Training configuration for the TripletNet / TripletLoss run.
margin = 1.0
loss_fn = TripletLoss(margin)

lr = 1e-4
optimizer = optim.Adam(model.parameters(), lr=lr)
# Multiply the learning rate by 0.1 on every 8th scheduler.step() call.
scheduler = lr_scheduler.StepLR(optimizer, step_size=8, gamma=0.1, last_epoch=-1)

n_epochs = 20
log_interval = 500

print('start training')
fit(train_loader, query_loader, model, loss_fn, optimizer, scheduler, n_epochs,
    cuda, log_interval)

# Balanced batch samplers for the online-mining run.
# NOTE(review): this statement was truncated in the source (only
# `n_samples=12)` survived). It is restored to mirror test_batch_sampler:
# labels taken from the dataset's `train_labels`, n_classes=8, n_samples=12.
# Confirm these values against the intended training config.
train_batch_sampler = BalancedBatchSampler(torch.tensor(train_set.train_labels),
                                           n_classes=8,
                                           n_samples=12)
test_batch_sampler = BalancedBatchSampler(torch.tensor(query_set.train_labels),
                                          n_classes=8,
                                          n_samples=12)

# Dataloaders: batching is driven by the samplers above (batch_sampler excludes
# batch_size/shuffle). Worker and pinned-memory options apply only with CUDA.
kwargs = {'num_workers': 1, 'pin_memory': True} if cuda else {}
online_train_loader = torch.utils.data.DataLoader(
    train_set, batch_sampler=train_batch_sampler, **kwargs)
online_test_loader = torch.utils.data.DataLoader(
    query_set, batch_sampler=test_batch_sampler, **kwargs)
print('dataloaders built')

# Build the embedding model for the online-mining run. No checkpoint is loaded
# here; the config-driven alternative is left disabled:
# model = build_retriever(cfg.model)
model = Vgg16L2(num_dim=128)
# The loader kwargs above already branch on the same `cuda` flag; only move the
# model to the GPU when it is set.
if cuda:
    model.cuda()
print('model built')

# Loss, optimiser and schedule for the online triplet-mining run.
from losses import OnlineTripletLoss
from utils import (  # strategies for selecting triplets within a minibatch
    AllTripletSelector,
    HardestNegativeTripletSelector,
    RandomNegativeTripletSelector,
    SemihardNegativeTripletSelector,
)
from metrics import AverageNonzeroTripletsMetric

margin = 1.0
# Triplets are picked by SemihardNegativeTripletSelector, built with the same
# margin as the loss itself.
loss_fn = OnlineTripletLoss(margin, SemihardNegativeTripletSelector(margin))

lr = 1e-4
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
# NOTE(review): if scheduler.step() runs once per epoch, step_size=8 with
# gamma=0.1 takes the LR below 1e-9 after 40 epochs, so most of the 100 epochs
# barely update the weights. Confirm this schedule is intended.
scheduler = lr_scheduler.StepLR(optimizer, step_size=8, gamma=0.1, last_epoch=-1)

n_epochs = 100
log_interval = 50