Example no. 1
def train():
    x_transforms = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    ])
    y_transforms = transforms.ToTensor()

    model = Unet(3, 1).to(device)
    model.load_state_dict(torch.load(r"./results/weights.pth", map_location=device))
    batch_size = 1
    num_epochs = 2
    criterion = torch.nn.BCELoss()
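    # Note: BCELoss expects probabilities in [0, 1]; this assumes the Unet's
    # final layer applies a sigmoid (otherwise use BCEWithLogitsLoss).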
    optimizer = optim.Adam(model.parameters())
    liver_dataset = LiverDataset(r'D:\project\data_sets\liver\train',
                                 transform=x_transforms,
                                 target_transform=y_transforms)
    data_loaders = DataLoader(liver_dataset,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=0)
    print("Start training at ", strftime("%Y-%m-%d %H:%M:%S", localtime()))
    for epoch in range(num_epochs):
        prev_time = datetime.now()
        print('Epoch {}/{}'.format(epoch, num_epochs))
        print('-' * 10)
        dt_size = len(data_loaders.dataset)
        epoch_loss = 0
        step = 0
        for x, y in data_loaders:
            step += 1
            inputs = x.to(device)
            labels = y.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            if (step % 10) == 0:
                print("%d/%d, train_loss:%0.3f" %
                      (step, (dt_size - 1) // data_loaders.batch_size + 1,
                       loss.item()))
        # print the results of the current training
        cur_time = datetime.now()
        h, remainder = divmod((cur_time - prev_time).seconds, 3600)
        m, s = divmod(remainder, 60)
        time_str = 'Time:{:02d}:{:02d}:{:02d}'.format(h, m, s)
        epoch_str = "epoch {} loss:{:.4f} ".format(epoch, epoch_loss / 400)
        print(epoch_str + time_str)
        res_record("Time:" + strftime("%Y-%m-%d %H:%M:%S  ", localtime()))
        res_record(epoch_str + '\n')
    print("End training at ", strftime("%Y-%m-%d %H:%M:%S", localtime()))
    # record data (save the trained weights)
    t = localtime()  # read the timestamp once so the fields cannot straddle a boundary
    torch.save(
        model.state_dict(),
        './results/weights{}_{}_{}.pth'.format(t.tm_mday, t.tm_hour, t.tm_sec))
Example no. 2
    # (snippet begins mid-script, inside a branch that reloads a previous checkpoint)
    net.load_state_dict(torch.load('./models/model_{}.pth'.format(name)))
    try:
        best_val_loss = np.load('./models/best_val_loss_{}.npy'.format(name))
    except FileNotFoundError:
        best_val_loss = np.finfo(np.float64).max
    print("Model reloaded. Previous lowest validation loss =",
          str(best_val_loss))
else:
    best_val_loss = np.finfo(np.float64).max

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net.to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(net.parameters(), lr=5e-4, weight_decay=1e-4)

best_weights = net.state_dict()
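# Note: state_dict() returns references to the live tensors, so best_weights
# follows the weights as they change; a frozen snapshot would need
# copy.deepcopy(net.state_dict()).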
num_epochs = 5
train_loss = np.zeros(num_epochs)
validation_loss = np.zeros(num_epochs)

print('\nStart training')
np.savetxt('epochs_completed.txt', np.zeros(1), fmt='%d')
for epoch in range(num_epochs):  #TODO decide epochs
    print('-----------------Epoch = %d-----------------' % (epoch + 1))
    train_loss[epoch], _ = train(train_loader, net, criterion, optimizer,
                                 device, epoch + 1)

    # TODO create your evaluation set, load the evaluation set and test on evaluation set
    val_loss = test(eval_loader, net, criterion, device)
    validation_loss[epoch] = val_loss
    if val_loss < best_val_loss:
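        # (the example is cut off here; a plausible continuation, reusing the
        # checkpoint names from the reload above, would be:)
        best_val_loss = val_loss
        np.save('./models/best_val_loss_{}.npy'.format(name), best_val_loss)
        torch.save(net.state_dict(), './models/model_{}.pth'.format(name))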
Example no. 3
        with tqdm(train_loader) as it:
            for x, label in it:
                num += 1
                optim.zero_grad()
                loss = diffusion(x.cuda())
                loss.backward()
                optim.step()
                it.set_postfix(ordered_dict={
                    'train loss': loss.item(),
                    'epoch': epoch
                },
                               refresh=False)

                # EMA warm-up: for the first 1000 steps copy the weights
                # verbatim; afterwards blend them into the running average
                if num % update_ema_every == 0:
                    if num < 1000:
                        ema_model.load_state_dict(model.state_dict())
                    else:
                        ema.update_model_average(ema_model, model)

    # sample images
    shape = (36, 1, 28, 28)
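    # eta controls DDIM stochasticity: eta = 0 gives deterministic sampling,
    # eta = 1 recovers DDPM-like noise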
    eta = 0.
    res = diffusion.sample(shape)
    # res = res.detach().cpu().numpy()
    res1 = diffusion.ddim_sample(shape, eta)  #.detach().cpu().numpy()
    # plot_fig(res,path='../figures/diff_1.jpg')
    # plot_fig(res1, path='../figures/diff_2.jpg')
    save_image(res, '../figures/diff_1.jpg', nrow=6)
    save_image(res1, '../figures/diff_2.jpg', nrow=6)
    # plt.show()
Example no. 4
            predict_map, predict_area = unet(
                img_batch
            )  # batch_size*num_class*height*width, batch_size*num_class

            loss = criterion(predict_map, predict_area, mask_batch, device)

            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()

        print("epoch %d , average train loss %f" % (epoch + 1, epoch_loss /
                                                    (step + 1)))

        # ----------- save model parameters -----------
        if (epoch + 1) % 10 == 0:
            torch.save(unet.state_dict(),
                       os.path.join('train_log',
                                    str(epoch + 1) + '.pt'))

        # ----------validation------------

        unet.eval()

        dice_sum = np.zeros(num_class)
        estimate_mae_sum = np.zeros(num_class)
        segment_mae_sum = np.zeros(num_class)
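        # accumulate per-class Dice and area-MAE over the validation set below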

        for step, (img_valid, mask_valid) in enumerate(validate_dataloader):

            img_valid = img_valid.float().to(
                device)  #batch_size*channel*height*width
Example no. 5
train_data = Dataset_brain_4(file_path)
batch_size = 64
train_loader = data.DataLoader(dataset=train_data,
                               batch_size=batch_size,
                               shuffle=True,
                               drop_last=True,
                               num_workers=4)
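# drop_last=True discards the final incomplete batch, so every update sees a
# full batch of 64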
unet = Unet(4)
optimizer = torch.optim.Adam(unet.parameters(), lr=0.001)
unet.cuda()
unet.train()
EPOCH = 30
print(EPOCH)
for epoch in range(EPOCH):
    batch_score = 0
    num_batch = 0
    for i, (img, label) in enumerate(train_loader):
        seg = unet(img.float().cuda())
        loss = dice_loss(seg, label.float().cuda())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        seg = seg.detach().cpu()  # detach so the thresholding happens outside the graph
        seg[seg >= 0.5] = 1.
        seg[seg != 1] = 0.
        batch_score += dice_score(seg, label.float()).data.numpy()
        num_batch += img.size(0)
    batch_score /= num_batch
    print('EPOCH %d : train_score = %.4f' % (epoch, batch_score))
torch.save(unet.state_dict(), model_save)
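
dice_loss and dice_score are not defined in this listing; a minimal soft-Dice sketch, assuming seg holds per-voxel probabilities and label a binary mask of the same shape:

def dice_loss(pred, target, eps=1e-6):
    # soft Dice: 1 - 2|P∩T| / (|P| + |T|), summed over all elements
    inter = (pred * target).sum()
    return 1. - (2. * inter + eps) / (pred.sum() + target.sum() + eps)

def dice_score(pred, target, eps=1e-6):
    # hard Dice on an already-binarized prediction
    inter = (pred * target).sum()
    return (2. * inter + eps) / (pred.sum() + target.sum() + eps)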
Example no. 6
def main(argv):
    """

    IMAGES VALID:
    * 005-TS_13C08351_2-2014-02-12 12.22.44.ndpi | id : 77150767
    * 024-12C07162_2A-2012-08-14-17.21.05.jp2 | id : 77150761
    * 019-CP_12C04234_2-2012-08-10-12.49.26.jp2 | id : 77150809

    IMAGES TEST:
    * 004-PF_08C11886_1-2012-08-09-19.05.53.jp2 | id : 77150623
    * 011-TS_13C10153_3-2014-02-13 15.22.21.ndpi | id : 77150611
    * 018-PF_07C18435_1-2012-08-17-00.55.09.jp2 | id : 77150755

    """
    with Cytomine.connect_from_cli(argv):
        parser = ArgumentParser()
        parser.add_argument("-b",
                            "--batch_size",
                            dest="batch_size",
                            default=4,
                            type=int)
        parser.add_argument("-j",
                            "--n_jobs",
                            dest="n_jobs",
                            default=1,
                            type=int)
        parser.add_argument("-e",
                            "--epochs",
                            dest="epochs",
                            default=1,
                            type=int)
        parser.add_argument("-d", "--device", dest="device", default="cpu")
        parser.add_argument("-o",
                            "--overlap",
                            dest="overlap",
                            default=0,
                            type=int)
        parser.add_argument("-t",
                            "--tile_size",
                            dest="tile_size",
                            default=256,
                            type=int)
        parser.add_argument("-z",
                            "--zoom_level",
                            dest="zoom_level",
                            default=0,
                            type=int)
        parser.add_argument("--lr", dest="lr", default=0.01, type=float)
        parser.add_argument("--init_fmaps",
                            dest="init_fmaps",
                            default=16,
                            type=int)
        parser.add_argument("--data_path",
                            "--dpath",
                            dest="data_path",
                            default=os.path.join(str(Path.home()), "tmp"))
        parser.add_argument("-w",
                            "--working_path",
                            "--wpath",
                            dest="working_path",
                            default=os.path.join(str(Path.home()), "tmp"))
        parser.add_argument("-s",
                            "--save_path",
                            dest="save_path",
                            default=os.path.join(str(Path.home()), "tmp"))
        args, _ = parser.parse_known_args(argv)

        os.makedirs(args.save_path, exist_ok=True)
        os.makedirs(args.data_path, exist_ok=True)
        os.makedirs(args.working_path, exist_ok=True)

        # fetch annotations (filter val/test sets + other annotations)
        all_annotations = AnnotationCollection(project=77150529,
                                               showWKT=True,
                                               showMeta=True,
                                               showTerm=True).fetch()
        val_ids = {77150767, 77150761, 77150809}
        test_ids = {77150623, 77150611, 77150755}
        val_test_ids = val_ids.union(test_ids)
        train_collection = all_annotations.filter(lambda a: (
            a.user in {55502856} and len(a.term) > 0 and a.term[0] in
            {35777351, 35777321, 35777459} and a.image not in val_test_ids))
        val_rois = all_annotations.filter(
            lambda a: (a.user in {142954314} and a.image in val_ids and len(
                a.term) > 0 and a.term[0] in {154890363}))
        val_foreground = all_annotations.filter(
            lambda a: (a.user in {142954314} and a.image in val_ids and len(
                a.term) > 0 and a.term[0] in {154005477}))

        train_wsi_ids = list({an.image
                              for an in all_annotations
                              }.difference(val_test_ids))
        val_wsi_ids = list(val_ids)

        download_path = os.path.join(args.data_path,
                                     "crops-{}".format(args.tile_size))
        images = {
            _id: ImageInstance().fetch(_id)
            for _id in (train_wsi_ids + val_wsi_ids)
        }

        train_crops = [
            AnnotationCrop(images[annot.image],
                           annot,
                           download_path,
                           args.tile_size,
                           zoom_level=args.zoom_level)
            for annot in train_collection
        ]
        val_crops = [
            AnnotationCrop(images[annot.image],
                           annot,
                           download_path,
                           args.tile_size,
                           zoom_level=args.zoom_level) for annot in val_rois
        ]

        for crop in train_crops + val_crops:
            crop.download()

        np.random.seed(42)
        dataset = RemoteAnnotationTrainDataset(
            train_crops, seg_trans=segmentation_transform)
        loader = DataLoader(dataset,
                            shuffle=True,
                            batch_size=args.batch_size,
                            num_workers=args.n_jobs,
                            worker_init_fn=worker_init)

        # network
        device = torch.device(args.device)
        unet = Unet(args.init_fmaps, n_classes=1)
        unet.train()
        unet.to(device)

        optimizer = Adam(unet.parameters(), lr=args.lr)
        loss_fn = BCEWithLogitsLoss(reduction="mean")
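        # BCEWithLogitsLoss applies the sigmoid internally, so the network
        # outputs raw logits here (unlike the BCELoss setups above)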

        results = {
            "train_losses": [],
            "val_losses": [],
            "val_metrics": [],
            "save_path": []
        }

        for e in range(args.epochs):
            print("########################")
            print("        Epoch {}".format(e))
            print("########################")

            epoch_losses = list()
            unet.train()
            for i, (x, y) in enumerate(loader):
                x, y = (t.to(device) for t in [x, y])
                y_pred = unet(x)
                loss = loss_fn(y_pred, y)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
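                # keep only the six most recent batch losses, so the value
                # printed below is a short running average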
                epoch_losses = [loss.detach().cpu().item()] + epoch_losses[:5]
                print("{} - {:1.5f}".format(i, np.mean(epoch_losses)))
                results["train_losses"].append(epoch_losses[0])

            unet.eval()
            # validation
            # np.float and np.int were removed from NumPy; use explicit dtypes
            val_losses = np.zeros(len(val_rois), dtype=np.float64)
            val_roc_auc = np.zeros(len(val_rois), dtype=np.float64)
            val_cm = np.zeros([len(val_rois), 2, 2], dtype=np.int64)

            for i, roi in enumerate(val_crops):
                foregrounds = find_intersecting_annotations(
                    roi.annotation, val_foreground)
                with torch.no_grad():
                    y_pred, y_true = predict_roi(
                        roi,
                        foregrounds,
                        unet,
                        device,
                        in_trans=transforms.ToTensor(),
                        batch_size=args.batch_size,
                        tile_size=args.tile_size,
                        overlap=args.overlap,
                        n_jobs=args.n_jobs,
                        zoom_level=args.zoom_level)

                val_losses[i] = metrics.log_loss(y_true.flatten(),
                                                 y_pred.flatten())
                val_roc_auc[i] = metrics.roc_auc_score(y_true.flatten(),
                                                       y_pred.flatten())
                val_cm[i] = metrics.confusion_matrix(
                    y_true.flatten().astype(np.uint8),
                    (y_pred.flatten() > 0.5).astype(np.uint8))

            print("------------------------------")
            print("Epoch {}:".format(e))
            val_loss = np.mean(val_losses)
            roc_auc = np.mean(val_roc_auc)
            print("> val_loss: {:1.5f}".format(val_loss))
            print("> roc_auc : {:1.5f}".format(roc_auc))
            cm = np.sum(val_cm, axis=0)
            cnt = np.sum(val_cm)
            print("CM at 0.5 threshold")
            print("> {:3.2f}%  {:3.2f}%".format(100 * cm[0, 0] / cnt,
                                                100 * cm[0, 1] / cnt))
            print("> {:3.2f}%  {:3.2f}%".format(100 * cm[1, 0] / cnt,
                                                100 * cm[1, 1] / cnt))
            print("------------------------------")

            filename = "{}_e_{}_val_{:0.4f}_roc_{:0.4f}_z{}_s{}.pth".format(
                datetime.now().timestamp(), e, val_loss, roc_auc,
                args.zoom_level, args.tile_size)
            torch.save(unet.state_dict(), os.path.join(args.save_path,
                                                       filename))

            results["val_losses"].append(val_loss)
            results["val_metrics"].append(roc_auc)
            results["save_path"].append(filename)

        return results
Example no. 7
args["save_dir"].mkdir(parents = True, exist_ok=True)

from data import GlacierDataset
from torch.utils.data import DataLoader

paths = {
    "x": sorted(list(args["base_dir"].glob("x*"))),
    "y": sorted(list(args["base_dir"].glob("y*")))
}
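# pairing x*/y* files by sorted order assumes the input and label filenames
# sort consistently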

ds = GlacierDataset(paths["x"], paths["y"])
loader = DataLoader(ds, batch_size=args["batch_size"], shuffle=False)

import torch.optim
from unet import Unet
from train import train_epoch

model = Unet(9, 3, 4, dropout=0.2).to(args["device"])
optimizer = torch.optim.Adam(model.parameters(), lr=args["lr"])

L = []
for epoch in range(args["epochs"]):
    l = train_epoch(model, loader, optimizer, args["device"], epoch,
                    args["save_dir"])
    L.append([l[0], l[1]])

torch.save(model.state_dict(), "model.pt")

import pickle

with open('loss.pkl', 'wb') as f:
    pickle.dump(L, f)
Example no. 8
                save_image(valOutput[3, :, :, :],
                           '{}/predicted_{}.png'.format(output_path, num))

                del Xval, yval, valOutput
                torch.cuda.empty_cache()

        print("Loss value =", sumloss / len(local_batch))

        avgDice = sumDice / num
        diceList.append(avgDice)
        print("Dice score on validation set =", avgDice)
    print("\nAverage Dice score for the current epoch =", np.mean(avgDice))

# Save the trained UNET model:
unetpath = './model/unet.pth'
torch.save(unet.state_dict(), unetpath)

# free CUDA
#del X, y, Xval, yval, val_batch, local_batch
#torch.cuda.empty_cache()

##---------- Classification Network below -----------------

# Generating the data again with random train-val set
partition, labels = dataProvider(classes)
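# partition presumably maps the 'train'/'validation' splits to sample IDs,
# and labels maps each ID to its class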

# Check if the dictionaries were populated correctly
#print(partition)
print("VGG16: Train and validation data count: ")
print(len(partition['train']))
print(len(partition['validation']))
Example no. 9
    return options


if __name__ == '__main__':
    args = get_args()

    net = Unet(input_channels=3,
               input_width=480,
               input_height=360,
               n_classes=12)

    if args.load:
        net.load_state_dict(torch.load(args.load, map_location='cpu'))
        print('Model loaded from {}'.format(args.load))

    if args.gpu:
        net.cuda()
    try:
        train_net(net=net,
                  epochs=args.epochs,
                  batch_size=args.batchsize,
                  lr=args.lr,
                  gpu=args.gpu,
                  img_scale=args.scale)
    except KeyboardInterrupt:
        torch.save(net.state_dict(), 'INTERRUPTED.pth')
        print('Saved interrupt')
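        # force an immediate exit: os._exit bypasses cleanup in case
        # SystemExit is intercepted further up the stack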
        try:
            sys.exit(0)
        except SystemExit:
            os._exit(0)