Пример #1
0
def main():
    """Train a SiameseNetwork with a contrastive loss and save its weights.

    Builds a pair dataset (``Contrastive_Dataset``) on top of the
    ``ImageFolder`` at ``Config.training_dir``. Images are resized to
    100x100 and converted to tensors. The network is trained for
    ``Config.train_number_epochs`` epochs with Adam (lr=0.0005). The learned
    ``state_dict`` is written to ``trained_weights.pt`` and the sampled loss
    curve is plotted with ``show_plot``.
    """
    print("Training starts...")
    folder_dataset = dset.ImageFolder(root=Config.training_dir)
    siamese_dataset = Contrastive_Dataset(imageFolderDataset=folder_dataset,
                                          transform=transforms.Compose([
                                              transforms.Resize((100, 100)),
                                              transforms.ToTensor()
                                          ]),
                                          should_invert=False)

    train_dataloader = DataLoader(siamese_dataset,
                                  shuffle=True,
                                  num_workers=8,
                                  batch_size=Config.train_batch_size)

    net = SiameseNetwork()
    criterion = ContrastiveLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.0005)

    # x-axis (approximate iteration count) and y-axis (loss value) for
    # show_plot; one point is recorded every 10 batches.
    counter = []
    loss_history = []
    iteration_number = 0

    for epoch in range(Config.train_number_epochs):
        for i, (img0, img1, label) in enumerate(train_dataloader):
            # Tensors carry autograd state directly; the deprecated
            # Variable wrapper is a no-op and is not needed.
            output1, output2 = net(img0, img1)
            optimizer.zero_grad()
            loss_contrastive = criterion(output1, output2, label)
            loss_contrastive.backward()
            optimizer.step()
            if i % 10 == 0:
                # .item() returns a Python number from a single-element
                # tensor (0-dim or shape (1,)). Indexing a 0-dim tensor with
                # .data[0] raises IndexError on modern PyTorch. Storing a
                # plain number also avoids holding tensors in the history.
                loss_value = loss_contrastive.item()
                print("Epoch number {}\n Current loss {}\n".format(
                    epoch, loss_value))
                iteration_number += 10
                counter.append(iteration_number)
                loss_history.append(loss_value)
    torch.save(net.state_dict(), "trained_weights.pt")

    show_plot(counter, loss_history)
Пример #2
0
    siamese_dataset_test = SiameseNetworkDataset(sample_path=Config.sample_test_path,
                                        transform=transforms.Compose([transforms.ToTensor()])
                                       ,should_invert=False)

    print("len(siamese_dataset_test) = ", siamese_dataset_test.__len__())

    test_dataloader = DataLoader(siamese_dataset_test,
                        shuffle=False,#siamese_dataset 重排后再取数据
                        num_workers=5,
                        batch_size=360)
    

    #模型训练
    net = SiameseNetwork() #网络结构
    criterion = ContrastiveLoss() #损失函数
    optimizer = optim.Adam(net.parameters(),lr = 0.0005) #参数优化函数

    train_loss_history = []
    test_loss_history = []
    for epoch in range(0, Config.train_number_epochs):#整个样本集的迭代
        list_loss_epoch_c = []
        for i, data in enumerate(train_dataloader,0):#batch迭代
            img0, img1, label = data

            optimizer.zero_grad() #模型参数梯度设为0
            output1,output2 = net(img0,img1)
            loss_contrastive = criterion(output1,output2,label)

            loss_contrastive.backward() #反向传播
            optimizer.step() #更新参数空间