Example #1
def summary(args):
    if args.model == "unet":
        model = unet_model.UNet(args)
        model.summary()
    elif args.model == "resunet":
        model = resunet_model.build_res_unet(args)
        model.summary()
    elif args.model == "segnet":
        model = segnet_model.create_segnet(args)
        model.summary()
    elif args.model == "linknet":
        # pretrained_encoder = 'True',
        # weights_path = './checkpoints/linknet_encoder_weights.h5'
        model = LinkNet(1, input_shape=(256, 256, 3))
        model = model.get_model()
        model.summary()
    elif args.model == "DLinkNet":
        model = segnet_model.create_segnet(args)
        model.summary()
    else:
        print("The model name should be from the unet, resunet, linknet or segnet")
Example #2
def train(args):
    train_csv = args.train_csv
    valid_csv = args.valid_csv
    image_paths = []
    label_paths = []
    valid_image_paths = []
    valid_label_paths = []

    with open(train_csv, 'r', newline='\n') as csvfile:
        plots = csv.reader(csvfile, delimiter=',')
        for row in plots:
            # print(row)
            image_paths.append(row[0])
            label_paths.append(row[1])

    with open(valid_csv, 'r', newline='\n') as csvfile:
        plots = csv.reader(csvfile, delimiter=',')
        for row in plots:
            valid_image_paths.append(row[0])
            valid_label_paths.append(row[1])

    if args.model == "unet":
        model = unet_model.UNet(args)
    elif args.model == "resunet":
        model = resunet_model.build_res_unet(args)
    elif args.model == "segnet":
        model = segnet_model.create_segnet(args)
    else:
        print("The model name should be one of unet, resunet or segnet")
        return

    model.compile(optimizer="adam",
                  loss="binary_crossentropy",
                  metrics=["acc"])
    input_shape = args.input_shape
    train_gen = datagen.DataGenerator(image_paths,
                                      label_paths,
                                      batch_size=args.batch_size,
                                      n_channels=input_shape[2],
                                      patch_size=input_shape[1],
                                      shuffle=True)
    valid_gen = datagen.DataGenerator(valid_image_paths,
                                      valid_label_paths,
                                      batch_size=args.batch_size,
                                      n_channels=input_shape[2],
                                      patch_size=input_shape[1],
                                      shuffle=True)
    train_steps = len(image_paths) // args.batch_size
    valid_steps = len(valid_image_paths) // args.batch_size

    model_name = args.model_name
    date_suffix = datetime.datetime.today().strftime("_%d_%m_%y")
    model_file = model_name + str(args.epochs) + date_suffix + ".hdf5"
    log_file = model_name + str(args.epochs) + date_suffix + ".log"
    # Training the model
    model_checkpoint = ModelCheckpoint(model_file,
                                       monitor='val_loss',
                                       verbose=1,
                                       save_best_only=True)
    csv_logger = CSVLogger(log_file, separator=',', append=False)
    model.fit_generator(train_gen,
                        validation_data=valid_gen,
                        steps_per_epoch=train_steps,
                        validation_steps=valid_steps,
                        epochs=args.epochs,
                        callbacks=[model_checkpoint, csv_logger])

    # Training complete (best weights are saved by the ModelCheckpoint callback)
    print("Model successfully trained")
Example #3
data.shuffledata(imagepaths, maskpaths)
if len(imagepaths) != len(maskpaths):
    print('dataset error: the number of images and masks does not match!')
    exit(0)
img_num = len(imagepaths)
print('found images:', img_num)
imagepaths_train = (imagepaths[0:int(img_num*0.8)]).copy()
maskpaths_train = (maskpaths[0:int(img_num*0.8)]).copy()
imagepaths_eval = (imagepaths[int(img_num*0.8):]).copy()
maskpaths_eval = (maskpaths[int(img_num*0.8):]).copy()

'''
--------------------------def network--------------------------
'''
if opt.model == 'UNet':
    net = unet_model.UNet(n_channels=3, n_classes=1)
elif opt.model == 'BiSeNet':
    net = BiSeNet_model.BiSeNet(num_classes=1, context_path='resnet18')

if opt.continuetrain:
    if not os.path.isfile(os.path.join(dir_checkpoint,'last.pth')):
        opt.continuetrain = False
        print('cannot load last.pth, training from initial weights.')
if opt.continuetrain:
    net.load_state_dict(torch.load(os.path.join(dir_checkpoint,'last.pth')))
    f = open(os.path.join(dir_checkpoint,'epoch_log.txt'),'r')
    opt.startepoch = int(f.read())
    f.close()
if opt.use_gpu:
    net.cuda()
    cudnn.benchmark = True
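The resume block above expects dir_checkpoint to contain last.pth (model weights) and epoch_log.txt (the last finished epoch). A sketch of the matching save step, assuming net, dir_checkpoint and epoch exist in the surrounding training loop as they do here.

# Save the files that the continuetrain branch above loads on restart.
os.makedirs(dir_checkpoint, exist_ok=True)
torch.save(net.state_dict(), os.path.join(dir_checkpoint, 'last.pth'))
with open(os.path.join(dir_checkpoint, 'epoch_log.txt'), 'w') as f:
    f.write(str(epoch))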
Example #4
if __name__ == "__main__":

    if args.mode == 'debug':
        args.ngf = 8
        # df = df[:args.batch_size]
        args.display_freq = 5
        args.summary_freq = 2

    output_dir = os.path.join(args.output_dir, 'samples_images')
    checkpoint_dir = os.path.join(args.output_dir, 'checkpoints')
    summary_dir = os.path.join(args.output_dir, 'summary')
    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(checkpoint_dir, exist_ok=True)
    os.makedirs(summary_dir, exist_ok=True)

    model = unet_model.UNet(args.nClass, args.height, args.width, ngf=args.ngf)
    loss_loc = losses.WeightedHausdorffDistance(args.height,
                                                args.width,
                                                p=args.p,
                                                return_2_terms=True)
    # dataset
    df = pandas.read_json('./datasets/train/train.json')

    print('---- hyper parameters>>>>>>>')

    for k, v in args.__dict__.items():
        print('{}: {}'.format(k, v))

    print('<<<<<< hyper parameters')

    dataset = data.create_dataset(df, args.batch_size, args.height, args.width)
Example #5
                               collate_fn=csv_collator,
                               height=args.height,
                               width=args.width,
                               seed=args.seed,
                               batch_size=args.batch_size,
                               drop_last_batch=args.drop_last_batch,
                               num_workers=args.nThreads,
                               val_dir=args.val_dir,
                               max_valset_size=args.max_valset_size)

# Model
with peter('Building network'):
    model = unet_model.UNet(3,
                            1,
                            height=args.height,
                            width=args.width,
                            known_n_points=args.n_points,
                            device=device,
                            ultrasmall=args.ultrasmallnet)
    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f" with {ballpark(num_params)} trainable parameters. ", end='')
model = nn.DataParallel(model)
model.to(device)

# Loss functions
loss_regress = nn.SmoothL1Loss()
loss_loc = losses.WeightedHausdorffDistance(resized_height=args.height,
                                            resized_width=args.width,
                                            p=args.p,
                                            return_2_terms=True,
                                            device=device)
Example #6
# Restore saved checkpoint (model weights)
with peter("Loading checkpoint"):

    if os.path.isfile(args.model):
        if args.cuda:
            checkpoint = torch.load(args.model)
        else:
            checkpoint = torch.load(
                args.model, map_location=lambda storage, loc: storage)
        # Model
        if args.n_points is None:
            if 'n_points' not in checkpoint:
                # Model will also estimate # of points
                model = unet_model.UNet(3, 1,
                                        known_n_points=None,
                                        height=args.height,
                                        width=args.width,
                                        ultrasmall=args.ultrasmallnet)

            else:
                # The checkpoint tells us the # of points to estimate
                model = unet_model.UNet(3, 1,
                                        known_n_points=checkpoint['n_points'],
                                        height=args.height,
                                        width=args.width,
                                        ultrasmall=args.ultrasmallnet)
        else:
            # The user tells us the # of points to estimate
            model = unet_model.UNet(3, 1,
                                    known_n_points=args.n_points,
                                    height=args.height,
Example #7
def accuracy(args):
    if args.onehot == "yes":
        if args.model == "resunet":
            model = resunet_novel2.build_res_unet(args)
            model.load_weights(args.weights)
    else:
        if args.model == "unet":
            model = unet_model.UNet(args)
            model.load_weights(args.weights)
        elif args.model == "resunet":
            # model = load_model(args.weights)
            model = resunet_model.build_res_unet(args)
            model.load_weights(args.weights)
        elif args.model == "segnet":
            model = segnet_model.create_segnet(args)
            model.load_weights(args.weights)
        else:
            print("The model name should be from the unet, resunet or segnet")
    # print(model)
    paths_file = args.csv_paths
    test_image_paths = []
    test_label_paths = []
    test_pred_paths = []
    with open(paths_file, 'r', newline='\n') as csvfile:
        plots = csv.reader(csvfile, delimiter=',')
        for row in plots:
            test_image_paths.append(row[0])
            test_label_paths.append(row[1])
            if len(row) > 2:
                test_pred_paths.append(row[2])
            # print(row[0], row[1])
    print(len(test_image_paths), len(test_label_paths), len(test_pred_paths))
    # print(test_image_paths, test_label_paths)

    tn, fp, fn, tp = 0, 0, 0, 0
    rows = []
    for i in range(len(test_image_paths)):
        image = gdal.Open(test_image_paths[i])
        image_array = np.array(image.ReadAsArray()) / 255
        image_array = image_array.transpose(1, 2, 0)
        label = gdal.Open(test_label_paths[i])
        label_array = np.array(label.ReadAsArray()) / 255
        label_array = np.expand_dims(label_array, axis=-1)
        # print(len(test_pred_paths))
        if len(test_pred_paths) > 0:
            pred = gdal.Open(test_pred_paths[i])
            pred_array = np.array(pred.ReadAsArray())
            pred_array = np.expand_dims(pred_array, axis=-1)
            image_array = np.concatenate((image_array, pred_array), axis=2)
        fm = np.expand_dims(image_array, axis=0)
        result_array = model.predict(fm)
        result_array = np.squeeze(result_array)  # .transpose(2, 0, 1)
        # print(result_array.shape)
        # result_array = result_array[1:, :, :]
        # print(result_array.shape)
        A = np.around(label_array.flatten())
        B = np.around(result_array.flatten())
        cm = confusion_matrix(A, B)
        if len(cm) == 1:
            rows.append(
                [test_image_paths[i], test_label_paths[i], cm[0][0], 0, 0, 0])
            tn += cm[0][0]
        else:
            rows.append([
                test_image_paths[i], test_label_paths[i], cm[0][0], cm[0][1],
                cm[1][0], cm[1][1]
            ])
            tn += cm[0][0]
            fp += cm[0][1]
            fn += cm[1][0]
            tp += cm[1][1]
        print("Predicted " + str(i + 1) + " Images")

    iou = tp / (tp + fp + fn)
    f_score = (2 * tp) / (2 * tp + fp + fn)

    print("IOU Score: " + str(iou))
    print("F-Score: " + str(f_score))
Example #8
training_transforms += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
trainset = CSVDataset(args.train_dir,
                      transforms=transforms.Compose(training_transforms),
                      max_dataset_size=args.max_trainset_size)
trainset_loader = DataLoader(trainset,
                             batch_size=args.batch_size,
                             drop_last=args.drop_last_batch,
                             shuffle=True,
                             num_workers=args.nThreads,
                             collate_fn=csv_collator)

# Model
with peter('Building network'):
    model = unet_model.UNet(3,
                            1,
                            height=args.height,
                            width=args.width,
                            known_n_points=args.n_points)
    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f" with {ballpark(num_params)} trainable parameters. ", end='')
model = nn.DataParallel(model)
model.to(device)

# Loss function
loss_regress = nn.SmoothL1Loss()
loss_loc = losses.WeightedHausdorffDistance(resized_height=args.height,
                                            resized_width=args.width,
                                            p=args.p,
                                            return_2_terms=True,
                                            device=device)
l1_loss = nn.L1Loss(size_average=False)
Example #9
def accuracy(args):
    if args.model == "unet":
        model = unet_model.UNet(args)
        model.load_weights(args.weights)
    elif args.model == "resunet":
        # model = load_model(args.weights)
        model = resunet_model.build_res_unet(args)
        model.load_weights(args.weights)
    elif args.model == "segnet":
        model = segnet_model.create_segnet(args)
        model.load_weights(args.weights)
    else:
        print("The model name should be from the unet, resunet or segnet")
    # print(model)
    paths_file = args.csv_paths
    test_image_paths = []
    test_label_paths = []
    with open(paths_file, 'r', newline='\n') as csvfile:
        plots = csv.reader(csvfile, delimiter=',')
        for row in plots:
            test_image_paths.append(row[0])
            test_label_paths.append(row[1])
    y = []
    y_pred = []
    for i in range(len(test_image_paths)):
        image = gdal.Open(test_image_paths[i])
        image_array = np.array(image.ReadAsArray()) / 255
        image_array = image_array.transpose(1, 2, 0)
        label = gdal.Open(test_label_paths[i])
        label_array = np.array(label.ReadAsArray())
        label_array = np.expand_dims(label_array, axis=-1)
        fm = np.expand_dims(image_array, axis=0)
        result_array = model.predict(fm)
        result_array = np.argmax(result_array[0], axis=2)
        result_array = np.squeeze(result_array)
        y.append(np.around(label_array))
        y_pred.append(result_array)
        print("Predicted " + str(i + 1) + " Images")
    # print(len(np.array(y).flatten()), len(np.array(y_pred).flatten()))
    print("\n")
    cm = confusion_matrix(np.array(y).flatten(), np.array(y_pred).flatten())
    cm_multi = multilabel_confusion_matrix(np.array(y).flatten(), np.array(y_pred).flatten())
    print("Confusion Matrix " + "\n")
    print(cm, "\n")
    accuracy = np.trace(cm/np.sum(cm))
    print("Overal Accuracy: ", round(accuracy, 3), "\n")

    mean_iou = 0
    mean_f1 = 0
    for j in range(len(cm_multi)):
        print("Class: " + str(j))
        iou = cm_multi[j][1][1] / (cm_multi[j][1][1] + cm_multi[j][0][1] + cm_multi[j][1][0])
        f1 = (2 * cm_multi[j][1][1]) / (2 * cm_multi[j][1][1] + cm_multi[j][0][1] + cm_multi[j][1][0])
        precision = cm_multi[j][1][1] / (cm_multi[j][1][1] + cm_multi[j][0][1])
        recall = cm_multi[j][1][1] / (cm_multi[j][1][1] + cm_multi[j][1][0])
        mean_iou  += iou
        mean_f1 += f1
        print("IoU Score: ", round(iou, 3))
        print("F1-Measure: ", round(f1, 3))
        print("Precision: ", round(precision, 3))
        print("Recall: ", round(recall, 3), "\n")
    print("Mean IoU Score: ", round(mean_iou/len(cm_multi), 3))
    print("Mean F1-Measure: ", round(mean_f1/len(cm_multi), 3))
Example #10
def eval(opt):
    import matplotlib
    matplotlib.use('qt5agg')
    import matplotlib.pyplot as plt
    plt.ion()
    import torch.nn.functional as F

    # files = sorted(glob.glob(f'{opt.dataDirectory}/*'))

    print('===> Loading Data ... ', end='')
    opt.dataDirectory = f'{opt.dataDirectory}/skull{opt.skull}_as_test_ed/'

    inputs, mask, label = load_test_data(opt)

    inputs = inputs.squeeze().unfold(-1, 3, 1).permute((0, 4, 1, 2, 3)).contiguous()
    inputs = inputs.view(-1, 512, 512, 400).permute(3, 0, 1, 2).split(64, 0)

    masks = mask.squeeze().unfold(-1, 3, 1).permute((3, 0, 1, 2)).contiguous()
    masks = masks.permute(3, 0, 1, 2)[:, 1].split(64, 0)

    label = label.squeeze().unfold(-1, 3, 1).permute((3, 0, 1, 2)).contiguous()
    label = label.permute(3, 0, 1, 2)[:, 1]
    label = ((label + 1000) / 4000).split(64, 0)
    print('done')

    print('===> Loading Model ... ', end='')
    if not opt.ckpt:
        timestamps = sorted(glob.glob(f'{opt.model_dir}/*'))
        if not timestamps:
            raise Exception(f'No save directories found in {opt.model_dir}')
        lasttime = timestamps[-1].split('/')[-1]
        models = sorted(glob.glob(f'{opt.model_dir}/{lasttime}/*'))
        if not models:
            raise Exception(f'No models found in the last run ({opt.model_dir}/{lasttime})')
        model_file = models[-1].split('/')[-1]
        opt.ckpt = f'{lasttime}/{model_file}'

    model = unet_model.UNet(6, 1)
    saved_dict = SimpleNamespace(**torch.load(f'{opt.model_dir}{opt.ckpt}'))
    model, device = _check_branch(opt, saved_dict, model)
    print('done')

    crit = nn.MSELoss()
    e_loss = 0.0
    preds = []
    input_vols = []
    label_vol = []
    mask_vol = []
    b_losses = []
    n_samps = []

    print('===> Evaluating Model')
    with torch.no_grad():
        model.eval()
        n_batches = len(inputs)
        for i, f, m, l in zip(range(1, n_batches + 1), inputs, masks, label):
            inputs, mask, label = f.to(device=device), m.to(device=device).bool(), l.to(device=device)

            n_samps.append(mask.sum().item())
            pred = model(inputs).squeeze()
            loss = crit(pred[mask], label[mask])
            b_loss = loss.item()
            b_losses.append(loss.item())

            preds.append(pred.clone().cpu())
            input_vols.append(inputs.clone().cpu())
            label_vol.append(label.clone().cpu())
            mask_vol.append(mask.clone().cpu())

            print(f"=> Done with {i} / {len(inputs)}  Batch Loss: {b_loss:.6f}")

        e_loss = (torch.tensor(n_samps) * torch.tensor(b_losses)).sum() / torch.tensor(n_samps).sum()
        print(f"===> Avg. Loss: {e_loss:.6f}")

        # torch.backends.cudnn.enabled = False

        # model = torch.jit.trace(model, inputs[:, :, 0:128, 0:128, 0:128].clone())

        #
        # inputs = inputs.to(device=device)
        # mask = mask.to(device=device)
        # label = label.to(device=device)

        # for param in model.parameters():
        #     param.requires_grad = False

        # pred = model(inputs).squeeze()
        # preds.append(pred.clone())
        # input_vols.append(inputs.clone())
        # label_vol.append(label.clone())
        # mask_vol.append(mask.clone())
        #
        # loss = crit(pred[mask], label[mask])
        # e_loss += loss.item()
        # b_loss = loss.item()
        #
        # print(f"===> Avg. MSE Loss: {e_loss / len(infer_loader):.6f}")

    pred_vol = F.pad(torch.cat(preds, dim=0), [0, 0, 0, 0, 1, 1])
    # mask_vol = F.pad(torch.cat(mask_vol, dim=0), [0, 0, 0, 0, 1, 1])
    label_vol = F.pad(torch.cat(label_vol, dim=0), [0, 0, 0, 0, 1, 1])
    # input_vols = torch.cat(input_vols, dim=0)
    # label_vol = torch.cat(label_vol, dim=0)

    # pred_vol = pred_vol * mask_vol
    # label_vol = label_vol * mask_vol
    pred_vol = (pred_vol * 4000.0) - 1000.0
    pred_vol[pred_vol < -1000.0] = -1000
    pred_vol[pred_vol > 3000.0] = 3000.0
    # pred_vol = pred_vol.permute(1, 2, 0)

    label_vol = (label_vol * 4000.0) - 1000.0
    label_vol[label_vol < -1000.0] = -1000
    label_vol[label_vol > 3000.0] = 3000.0
    # label_vol = label_vol.permute(1, 2, 0)

    raw_file = sorted(glob.glob(f'{opt.rawDir}skull{opt.skull}*.mat'))[-1]
    raw_dict = loadmat(raw_file)
    ct_mask = torch.tensor(raw_dict['boneMask2']).permute(2, 0, 1)
    ct_mask = ct_mask >= 0.5
    ct_mask = ct_mask.to(dtype=torch.float32)

    # label_vol *= ct_mask
    # pred_vol *= ct_mask

    import CAMP.Core as core
    import CAMP.FileIO as io
    import CAMP.StructuredGridTools as st
    import CAMP.UnstructuredGridOperators as uo
    import CAMP.StructuredGridOperators as so

    out_pred = core.StructuredGrid(pred_vol.shape, tensor=pred_vol.unsqueeze(0), spacing=[0.48, 0.48, 0.48])
    out_label = core.StructuredGrid(label_vol.shape, tensor=label_vol.unsqueeze(0), spacing=[0.48, 0.48, 0.48])
    out_mask = core.StructuredGrid(ct_mask.shape, tensor=ct_mask.unsqueeze(0), spacing=[0.48, 0.48, 0.48])

    a = -0.35
    rot_mat = torch.tensor([[np.cos(a), -np.sin(a), 0, 0],
                            [np.sin(a), np.cos(a), 0, 0],
                            [0, 0, 1, 0],
                            [0, 0, 0, 1]])
    aff_tf = so.AffineTransform.Create(affine=rot_mat)
    out_pred = aff_tf(out_pred)
    out_label = aff_tf(out_label)
    out_mask = aff_tf(out_mask)

    out_mask.data[out_mask.data >= 0.5] = 1.0

    out_pred = out_mask * out_pred
    out_label = out_mask * out_label

    print('Saving ... ', end='')

    WriteDICOM(out_pred, f'{opt.outDirectory}/skull{opt.skull}/skull{opt.skull}_prediction_ed/')
    WriteDICOM(out_label, f'{opt.outDirectory}/skull{opt.skull}/skull{opt.skull}_label/')

    # io.SaveITKFile(out_pred, f'{opt.outDirectory}/skull{opt.skull}/prediction_volume.nii.gz')
    # io.SaveITKFile(out_label, f'{opt.outDirectory}/skull{opt.skull}/label_volume.nii.gz')

    print('done')
Example #11
def learn(opt):
    import matplotlib
    matplotlib.use('agg')
    import matplotlib.pyplot as plt
    plt.ion()

    def checkpoint(state, opt, epoch):
        path = f'{opt.outDirectory}/saves/{opt.timestr}/epoch_{epoch:05d}_model.pth'
        torch.save(state, path)
        print(f"===> Checkpoint saved for Epoch {epoch}")

    def train(epoch, scheduler):
        model.train()
        n_samps = []
        b_losses = []
        # crit = nn.MSELoss(reduction='none')
        crit = nn.MSELoss()

        for iteration, batch in enumerate(training_data_loader, 1):
            inputs, mask, label = batch[0].to(device=device), batch[1].to(device=device), batch[2].to(device=device)

            n_samps.append(mask.sum().item())
            optimizer.zero_grad()
            pred = model(inputs).squeeze()

            loss = crit(pred[mask], label[mask])
            loss.backward()

            b_loss = loss.item()
            b_losses.append(loss.item())
            optimizer.step()

            if iteration == len(training_data_loader) // 2 and epoch % 10 == 0:
                with torch.no_grad():
                    l1Loss = nn.MSELoss()
                    im = len(inputs) // 2

                    mask_slice = mask[im, :, :]
                    label_slice = label[im, :, :] * mask_slice
                    pred_slice = pred[im, :, :] * mask_slice
                    input1_slice = inputs[im, 0, :, :] * mask_slice
                    input2_slice = inputs[im, 1, :, :] * mask_slice

                    add_figure(input1_slice, writer, title='Input 1', label='Train/Input1', cmap='viridis', epoch=epoch)
                    add_figure(input2_slice, writer, title='Input 2', label='Train/Input2', cmap='viridis', epoch=epoch)

                    # Add the prediction
                    pred_loss = l1Loss(pred_slice[mask_slice], label_slice[mask_slice])
                    add_figure(pred_slice, writer, title='Predicted CT', label='Train/Pred CT', cmap='plasma',
                               epoch=epoch,
                               text=[f'Loss: {pred_loss.item():.4f}',
                                     f'Mean: {pred_slice[mask_slice].mean():.2f}',
                                     f'Min:  {pred_slice[mask_slice].min():.2f}',
                                     f'Max:  {pred_slice[mask_slice].max():.2f}'
                                     ], min_max=[0.0, 1.0], text_start=[5, 5], text_spacing=40)
                    # Add the stir
                    add_figure(label_slice, writer, title='Real CT', label='Train/Real CT', cmap='plasma', epoch=epoch,
                               text=[f'Mean: {label_slice[mask_slice].mean():.2f}',
                                     f'Min:  {label_slice[mask_slice].min():.2f}',
                                     f'Max:  {label_slice[mask_slice].max():.2f}'
                                     ], min_max=[0.0, 1.0], text_start=[5, 5], text_spacing=40)

            for param_group in optimizer.param_groups:
                clr = param_group['lr']
            writer.add_scalar('Batch/Learning Rate', clr, (iteration + (len(training_data_loader) * (epoch - 1))))
            writer.add_scalar('Batch/Avg. MSE Loss', b_loss, (iteration + (len(training_data_loader) * (epoch - 1))))
            print("=> Done with {} / {}  Batch Loss: {:.6f}".format(iteration, len(training_data_loader), b_loss))

        e_loss = (torch.tensor(n_samps) * torch.tensor(b_losses)).sum() / torch.tensor(n_samps).sum()
        writer.add_scalar('Epoch/Avg. MSE Loss', e_loss, epoch)
        # print(f"===> Avg. Loss: {e_loss:.6f}")
        print("===> Epoch {} Complete: Avg. Loss: {:.6f}".format(epoch, e_loss))
        scheduler.step(e_loss / len(training_data_loader))

    def infer(epoch):
        # crit = nn.MSELoss(reduction='none')
        n_samps = []
        b_losses = []
        crit = nn.MSELoss()

        print('===> Evaluating Model')
        with torch.no_grad():
            model.eval()
            for iteration, batch in enumerate(testing_data_loader, 1):
                inputs, mask, label = batch[0].to(device=device), batch[1].to(device=device), batch[2].to(device=device)

                n_samps.append(mask.sum().item())
                pred = model(inputs).squeeze()
                loss = crit(pred[mask], label[mask])
                b_loss = loss.item()
                b_losses.append(loss.item())

                if iteration == len(testing_data_loader) // 2 and epoch % 10 == 0:
                    im = len(inputs) // 4

                    l1Loss = nn.MSELoss()
                    mask_slice = mask[im, :, :]
                    label_slice = label[im, :, :] * mask_slice
                    pred_slice = pred[im, :, :] * mask_slice

                    if epoch == 10:
                        # Add the input images - they are not going to change
                        input1_slice = inputs[im, 0, :, :] * mask_slice
                        input2_slice = inputs[im, 1, :, :] * mask_slice
                        add_figure(input1_slice, writer, title='Input 1', label='Infer/Input1', cmap='viridis',
                                   epoch=epoch)
                        add_figure(input2_slice, writer, title='Input 2', label='Infer/Input2', cmap='viridis',
                                   epoch=epoch)
                        add_figure(label_slice, writer, title='Real CT', label='Infer/Real CT', cmap='plasma',
                                   epoch=epoch,
                                   text=[f'Mean: {label_slice[mask_slice].mean():.2f}',
                                         f'Min:  {label_slice[mask_slice].min():.2f}',
                                         f'Max:  {label_slice[mask_slice].max():.2f}'
                                         ], min_max=[0.0, 1.0])

                    # Add the prediction

                    pred_loss = l1Loss(pred_slice[mask_slice], label_slice[mask_slice])
                    add_figure(pred_slice, writer, title='Predicted CT', label='Infer/Pred T1', cmap='plasma',
                               epoch=epoch,
                               text=[f'Loss: {pred_loss.item():.4f}',
                                     f'Mean: {pred_slice[mask_slice].mean():.2f}',
                                     f'Min:  {pred_slice[mask_slice].min():.2f}',
                                     f'Max:  {pred_slice[mask_slice].max():.2f}'
                                     ], min_max=[0.0, 1.0])
                print(f"=> Done with {iteration} / {len(testing_data_loader)}  Batch Loss: {b_loss:.6f}")

            e_loss = (torch.tensor(n_samps) * torch.tensor(b_losses)).sum() / torch.tensor(n_samps).sum()
            writer.add_scalar('Infer/Avg. MSE Loss', e_loss, epoch)
            print(f"===> Avg. Loss: {e_loss:.6f}")

    # Add the git information to the opt
    _get_branch(opt)
    timestr = time.strftime("%Y-%m-%d-%H%M%S")
    opt.timestr = timestr
    writer = SummaryWriter(f'{opt.outDirectory}/runs/{timestr}')
    writer.add_text('Parameters', opt.__str__())

    # Seed anything random to be generated
    torch.manual_seed(opt.seed)

    try:
        os.stat(f'{opt.outDirectory}/saves/{timestr}/')
    except OSError:
        os.makedirs(f'{opt.outDirectory}/saves/{timestr}/')

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print('===> Generating Datasets ... ', end='')
    training_data_loader, testing_data_loader = get_loaders(opt)
    print(' done')

    model = unet_model.UNet(6, 1)
    model = model.to(device)
    model = nn.DataParallel(model)

    optimizer = optim.Adam(model.parameters(), lr=opt.lr, weight_decay=1e-6)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=25, verbose=True, factor=0.5,
                                                     threshold=5e-3, cooldown=75, min_lr=1e-6)

    print("===> Beginning Training")

    epochs = range(1, opt.nEpochs + 1)

    for epoch in epochs:
        print("===> Learning Rate = {}".format(optimizer.param_groups[0]['lr']))
        train(epoch, scheduler)
        if epoch % 10 == 0:
            checkpoint({
                'epoch': epoch,
                'scheduler': opt.scheduler,
                'git_branch': opt.git_branch,
                'git_hash': opt.git_hash,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict()}, opt, epoch)
        infer(epoch)
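A hedged sketch of resuming from one of the files written by checkpoint() above; the path is a placeholder, and model and optimizer are assumed to be built exactly as in learn() (including the nn.DataParallel wrapper, so the saved state_dict keys match).

# Placeholder path; point this at a real save produced by checkpoint().
state = torch.load('output/saves/2021-01-01-000000/epoch_00010_model.pth')
model.load_state_dict(state['state_dict'])
optimizer.load_state_dict(state['optimizer'])
start_epoch = state['epoch'] + 1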