Example #1
0
def main():
    """Train the siamese network with triplet loss on the EYC dataset.

    Builds the training DataLoader, optimises with Adam for ``epoch_num``
    epochs, runs ``test.test`` on both splits every 10 epochs, and saves a
    full-model checkpoint after every epoch.
    """
    dataset = EycDataset(train=True)
    net = SiameseNetwork().cuda()
    # net = torch.load('models/model_triplet_pr_po_max_pool_fix.pt')
    print("model loaded")

    train_dataloader = DataLoader(dataset,
                                  shuffle=True,
                                  num_workers=8,
                                  batch_size=train_batch_size)
    # criterion = TGLoss()
    criterion = TripletLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.0005)

    for epoch in range(0, epoch_num):
        # Periodic evaluation on both test configurations.
        if epoch % 10 == 0:
            test.test(True)
            test.test(False)

        for i, data in enumerate(train_dataloader):
            anchor, positive, negative = data

            # NOTE(review): Variable is a no-op since PyTorch 0.4; kept so the
            # snippet still matches the rest of this project's style.
            anchor, positive, negative = Variable(anchor).cuda(), Variable(
                positive).cuda(), Variable(negative).cuda()
            (anchor_output, positive_output,
             negative_output) = net(anchor, positive, negative)

            optimizer.zero_grad()
            loss = criterion(anchor_output, positive_output, negative_output)
            loss.backward()
            optimizer.step()

            if i % 10 == 0:
                # FIX: loss.item() replaces the deprecated loss.data[0], which
                # raises IndexError on 0-dim tensors in PyTorch >= 0.5.
                print("Epoch number {}\n Current loss {}\n".format(
                    epoch, loss.item()))

        print("Saving model")
        torch.save(net, 'models/model_triplet_pr_po_max_pool_fix_weighted.pt')
        print("-- Model Checkpoint saved ---")
Example #2
0
def main():
    """Run the full transfer-learning siamese training pipeline.

    Extracts the raw archive, splits it into train/test, builds siamese
    samplers and loaders, wraps a truncated pretrained MobileNetV2 inside a
    twin network, trains with BCE-with-logits, and saves the twin weights.
    """
    print("Extract data")
    unzip_data()

    print("Split on train and test")
    split_on_train_and_test()

    print("Create datasets")
    ds_train, ds_test = prepare_datasets()

    print("Create data loaders")
    sampler_train = SiameseSampler(ds_train, random_state=RS)
    sampler_test = SiameseSampler(ds_test, random_state=RS)
    loader_train = DataLoader(ds_train, batch_size=BATCH_SIZE,
                              sampler=sampler_train, num_workers=4)
    loader_test = DataLoader(ds_test, batch_size=BATCH_SIZE,
                             sampler=sampler_test, num_workers=4)

    print("Build computational graph")
    backbone = mobilenet_v2(pretrained=True)
    # Drop the classifier head so the backbone emits feature maps only.
    backbone = torch.nn.Sequential(*list(backbone.children())[:-1])
    model = SiameseNetwork(twin_net=TransferTwinNetwork(
        base_model=backbone, output_dim=EMBEDDING_DIM))
    model.to(DEVICE)
    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=LR)

    print("Train model")
    model = train(model, criterion, optimizer, loader_train, loader_test)

    print("Save model")
    # Only the twin sub-network's weights are persisted.
    torch.save(model.twin_net.state_dict(), 'models/twin.pt')
Example #3
0
    # Ensure the output directory exists before writing summaries/checkpoints.
    os.makedirs(args.out_path, exist_ok=True)

    # Set device to CUDA if a CUDA device is available, else CPU
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Training pairs are reshuffled and augmented; validation pairs are fixed
    # so that the validation metric is comparable across epochs.
    train_dataset = Dataset(args.train_path, shuffle_pairs=True, augment=True)
    val_dataset = Dataset(args.val_path, shuffle_pairs=False, augment=False)

    # drop_last=True skips a possibly short final batch during training.
    train_dataloader = DataLoader(train_dataset, batch_size=8, drop_last=True)
    val_dataloader = DataLoader(val_dataset, batch_size=8)

    model = SiameseNetwork(backbone=args.backbone)
    model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
    criterion = torch.nn.BCELoss()

    # TensorBoard writer; event files go under <out_path>/summary.
    writer = SummaryWriter(os.path.join(args.out_path, "summary"))

    # Sentinel "worst" validation loss; presumably lowered when a better
    # checkpoint is found later in the (unseen) training loop -- TODO confirm.
    best_val = 10000000000

    for epoch in range(args.epochs):
        print("[{} / {}]".format(epoch, args.epochs))
        model.train()

        # Per-epoch accumulators for loss and pair-classification accuracy.
        losses = []
        correct = 0
        total = 0

        # Training Loop Start
Example #4
0
# 1: build randomly paired training data.
# NOTE(review): CameraDataset is a project type; from this call site it acts
# as an iterable loader built from pivot/positive image arrays -- see its
# class definition for the pairing logic.
train_loader = CameraDataset(pivot_images,
                             positive_images,
                             batch_size,
                             num_batch,
                             data_transform,
                             is_train=True)
print('Randomly paired data are generated.')

# 2: load network
branch = BranchNetwork()
net = SiameseNetwork(branch)

# Contrastive loss: pulls positive pairs together, pushes negatives apart
# until they exceed the given margin.
criterion = ContrastiveLoss(margin=1.0)

# Only optimise parameters that require gradients (frozen layers skipped).
optimizer = optim.Adam(filter(lambda p: p.requires_grad, net.parameters()),
                       lr=learning_rate,
                       weight_decay=0.000001)

# 3: setup computation device
if resume:
    if os.path.isfile(resume):
        # map_location forces the checkpoint tensors onto CPU regardless of
        # where they were saved, so loading works on GPU-less machines.
        checkpoint = torch.load(resume,
                                map_location=lambda storage, loc: storage)
        net.load_state_dict(checkpoint['state_dict'])
        print('resume from {}.'.format(resume))
    else:
        print('file not found at {}'.format(resume))
else:
    print('Learning from scratch')