def learn(data, K, resolution, epoch, device, checkpoint=None):
    """Train a ``Segmenter`` with two alternating updates per batch.

    Data is a tensor with size (len(dataset), number of colours,
    resolution[0], resolution[1]).

    Each batch gets two optimizer steps:
      1. A step on ``Ncutsoft`` applied to ``model(batch, True)``.
      2. A step on ``Jrecons`` applied to ``model(batch, False)[1]``.
    The model and optimizer state are saved after every epoch.

    Args:
        data: Dataset handed to ``torch.utils.data.DataLoader`` (see the
            shape note above).
        K: Passed to ``Segmenter``. Also the divisor in the printed
            "accuracy" figure.
        resolution: Passed to ``Segmenter``.
        epoch: Number of full passes over ``data``.
        device: Passed to ``Segmenter``. Both losses are moved here.
        checkpoint: Optional dict holding ``"model"`` and ``"optimizer"``
            state dicts to resume from. ``None`` starts from scratch.

    Returns:
        Tuple ``(model, optimizer)`` after the last epoch.
    """
    model = Segmenter(resolution, K, device)
    optimizer = optim.Adam(model.parameters(), lr=LR)
    if checkpoint is not None:
        model.load_state_dict(checkpoint["model"])
        optimizer.load_state_dict(checkpoint["optimizer"])

    # drop_last=True discards a trailing partial batch, so every batch has
    # exactly BATCH samples.
    dataloader = torch.utils.data.DataLoader(
        data, BATCH, shuffle=True, drop_last=True)

    for i in tqdm(range(epoch), desc="Epochs"):
        for batch in tqdm(dataloader, desc="Batchs"):
            # NOTE(review): batch is never moved to `device` here. This
            # presumably relies on Segmenter / the loss helpers doing it.
            # Confirm before training on GPU.

            # Step 1: presumably the soft normalized-cut loss on the
            # encoder output (the True flag selects that path — confirm).
            model.zero_grad()
            Uenc = model(batch, True)
            Nloss = Ncutsoft(Uenc, batch).to(device)
            nloss_value = float(Nloss)
            print("Loss on this batch: {}".format(nloss_value))
            # This "accuracy" is derived purely from the loss as 1 - loss/K.
            # It is not measured against labels.
            print("Accuracy on this batch : {}%".format(
                round(100 * (1 - nloss_value / K), 2)))
            Nloss.backward()
            optimizer.step()

            # Step 2: presumably a reconstruction loss on the decoder output
            # (element [1] of the False-flag forward pass — confirm).
            model.zero_grad()
            Udec = model(batch, False)[1]
            Jloss = Jrecons(Udec, batch).to(device)
            Jloss.backward()
            optimizer.step()

        # Save once per epoch. The file name is the epoch index appended to
        # path_saved_models (plain string concatenation, so that prefix
        # must already end with a separator if it is a directory).
        torch.save(
            {
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict()
            },
            path_saved_models + '{}epoch'.format(i))
    return (model, optimizer)