def main():
    """Train a SiameseNetwork with ContrastiveLoss and plot the sampled loss.

    Builds a Contrastive_Dataset over ``dset.ImageFolder(Config.training_dir)``
    with images resized to 100x100 and converted to tensors, then trains for
    ``Config.train_number_epochs`` epochs using Adam (lr=0.0005) and
    ``Config.train_batch_size`` batches. Every 10th batch the loss is printed
    and recorded. The model's ``state_dict`` is written to
    ``trained_weights.pt`` after each epoch, and the recorded losses are passed
    to ``show_plot`` once training finishes.
    """
    print("Training starts...")

    folder_dataset = dset.ImageFolder(root=Config.training_dir)
    siamese_dataset = Contrastive_Dataset(
        imageFolderDataset=folder_dataset,
        transform=transforms.Compose([
            transforms.Resize((100, 100)),
            transforms.ToTensor(),
        ]),
        should_invert=False,
    )
    train_dataloader = DataLoader(
        siamese_dataset,
        shuffle=True,
        num_workers=8,
        batch_size=Config.train_batch_size,
    )

    net = SiameseNetwork()
    criterion = ContrastiveLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.0005)

    counter = []        # x-axis for the plot: cumulative iteration count
    loss_history = []   # loss sampled once every 10 batches
    iteration_number = 0

    for epoch in range(0, Config.train_number_epochs):
        for i, (img0, img1, label) in enumerate(train_dataloader, 0):
            # Plain tensors carry autograd state since PyTorch 0.4, so the
            # former Variable(...) wrappers were no-ops and are not needed.
            output1, output2 = net(img0, img1)
            optimizer.zero_grad()
            loss_contrastive = criterion(output1, output2, label)
            loss_contrastive.backward()
            optimizer.step()

            if i % 10 == 0:
                # .item() returns the Python number for a one-element tensor;
                # the old `.data[0]` raises IndexError on a 0-dim tensor
                # in PyTorch >= 0.4.
                loss_value = loss_contrastive.item()
                print("Epoch number {}\n Current loss {}\n".format(
                    epoch, loss_value))
                iteration_number += 10
                counter.append(iteration_number)
                loss_history.append(loss_value)

        # Overwrites the same file each epoch, so it always holds the latest weights.
        torch.save(net.state_dict(), "trained_weights.pt")

    show_plot(counter, loss_history)
# Test split: pairs built from Config.sample_test_path, converted to tensors only.
siamese_dataset_test = SiameseNetworkDataset(
    sample_path=Config.sample_test_path,
    transform=transforms.Compose([transforms.ToTensor()]),
    should_invert=False,
)
print("len(siamese_dataset_test) = ", len(siamese_dataset_test))
test_dataloader = DataLoader(
    siamese_dataset_test,
    shuffle=False,  # keep the test set in its original order (no reshuffling)
    num_workers=5,
    batch_size=360,
)

# Model training
net = SiameseNetwork()  # network architecture
criterion = ContrastiveLoss()  # loss function
optimizer = optim.Adam(net.parameters(), lr=0.0005)  # parameter optimizer

train_loss_history = []
test_loss_history = []
# NOTE(review): train_dataloader is not created in this chunk; assumed to be
# defined earlier in the script — confirm.
for epoch in range(0, Config.train_number_epochs):  # one pass over the whole training set
    # presumably collects per-batch losses for this epoch further down — TODO confirm
    list_loss_epoch_c = []
    for i, data in enumerate(train_dataloader, 0):  # iterate over batches
        img0, img1, label = data
        optimizer.zero_grad()  # reset parameter gradients to zero
        output1, output2 = net(img0, img1)
        loss_contrastive = criterion(output1, output2, label)
        loss_contrastive.backward()  # backpropagation
        optimizer.step()  # update the parameters