def main():
    """Train the first-stage DeepPose model on LSP, then build the stage-2 cascade dataset.

    The dataset is split 60/20/20 into train/validation/test subsets and
    wrapped in DataLoaders (batch size 16). A ``Network.DeepPose`` model is
    trained for 100 epochs with a summed MSE loss and Adagrad (lr=1e-3) via
    ``train``. The trained model and the training split are then used to
    build the dataset for the first cascading stage (S=2).
    """
    # To also train on LSP-extended, add
    #   Dataset.LSP_Dataset(path="./lspet_dataset", is_lsp_extended_dataset=True)
    # to the ConcatDataset list below.
    lsp_dataset = Dataset.LSP_Dataset()
    dataset = torch.utils.data.ConcatDataset([lsp_dataset])

    batch_size = 16
    total = len(dataset)
    train_size = int(total * 0.6)
    val_size = int(total * 0.2)
    # The test split takes the remainder. Three independently truncated
    # int() fractions can sum to less than `total`, and random_split raises
    # ValueError unless integer lengths sum exactly to len(dataset).
    test_size = total - train_size - val_size
    lengths = [train_size, val_size, test_size]
    train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
        dataset, lengths)

    train_dl = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    # Shuffling is unnecessary for evaluation; a fixed order keeps
    # validation/test passes reproducible.
    val_dl = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_dl = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # `device` is a module-level name defined outside this function.
    model = Network.DeepPose().float().to(device)
    criterion = nn.MSELoss(reduction="sum")
    optimizer = torch.optim.Adagrad(model.parameters(), lr=1e-3)

    # NOTE(review): the three returned values appear to be loss histories
    # (per-epoch train/val and per-batch) judging by their names; confirm
    # against `train`.
    train_loss_lst, val_loss_lst, batch_epoch_loss_lst = train(
        epochs=100,
        model=model,
        train_dl=train_dl,
        val_dl=val_dl,
        optimizer=optimizer,
        criterion=criterion,
        train_size=train_size,
        val_size=val_size,
    )

    # First cascading stage, S=2: built from the training split and the
    # trained stage-1 model.
    stage2_dataset = cascade_train.LSP_cascade_Dataset(train_dataset, model)