Пример #1
0
def train_xz_init_model_chorda(num_classes, epochs):
    """Train two XZ U-Nets on the chorda label only, one per data set.

    ``dataset1.h5`` and ``dataset2.h5`` are read from the model's root
    folder. Channel 5 of their (already one-hot) ``labels`` is extracted as
    the chorda label and re-encoded with ``num_classes`` classes. Weights are
    written to ``<root>/<timestamp>/weights1.hdf5`` and ``weights2.hdf5``.

    Args:
        num_classes: Number of network outputs and of one-hot classes for
            the extracted chorda slices.
        epochs: Number of epochs forwarded to ``train_model.train_model``.

    Returns:
        The timestamp string naming the output folder (matching
        ``train_xy_init_model``).
    """
    model_name = "XZ-Unet-Init"
    # create a new model folder based on name and date
    now = datetime.datetime.now()
    now_str = now.strftime('%Y-%m-%d_%H-%M-%S')

    rootDir = "C:/users/jfauser/IPCAI2019/ModelData/" + model_name + "/"

    print("Started at {}".format(now_str))

    input_shape = (128, 512)
    model = unet(input_shape + (1, ), num_classes)
    batch_size = 8

    def load_chorda_data_set(path):
        """Return ``(images, one-hot chorda labels)`` read from an HDF5 file."""
        # The ``with`` block releases the HDF5 handle even if reading fails.
        # ``[()]`` reads each dataset fully into memory before the file closes.
        with h5py.File(path.strip(), "r") as h5_file:
            images = h5_file["images"][()]
            all_labels = h5_file["labels"][()]  # already in to_categorical
        # Channel 5 of every label volume is the chorda label.
        slices = np.stack([volume[:, :, 5] for volume in all_labels])
        return images, utils.to_categorical(slices, num_classes)

    print("loading data set 1")
    images_1, labels_1 = load_chorda_data_set(rootDir + "dataset1.h5")

    print("loading data set 2")
    images_2, labels_2 = load_chorda_data_set(rootDir + "dataset2.h5")

    out_dir = rootDir + now_str + "/"
    filename_weigths1 = out_dir + "weights1.hdf5"
    filename_weigths2 = out_dir + "weights2.hdf5"

    if not os.path.isdir(out_dir):
        os.mkdir(out_dir)

    print("train model 1")
    train_model.train_model([images_1, labels_1], filename_weigths1, model,
                            batch_size,
                            epochs)  # commences training on the set
    del model

    print("train model 2")
    # A fresh network so model 2 does not start from model 1's weights.
    model = unet(input_shape + (1, ), num_classes)
    train_model.train_model([images_2, labels_2], filename_weigths2, model,
                            batch_size,
                            epochs)  # commences training on the set
    del model

    return now_str
Пример #2
0
def train_xy_init_model(num_classes, epochs):
    """Train two independent XY U-Nets (one per data set) and store their weights.

    Args:
        num_classes: Number of network outputs. When it equals 2, the data
            sets are read with ``tools.get_data_set_for_chorda``; otherwise
            ``tools.get_data_set`` is used.
        epochs: Number of epochs forwarded to ``train_model.train_model``.

    Returns:
        The ``%Y-%m-%d_%H-%M-%S`` timestamp that names the output folder.
    """
    stamp = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    print("Started at {}".format(stamp))

    # Output goes to <root>/<model name>/<timestamp>/
    model_name = "XY-Unet-Init"
    root_dir = "C:/users/jfauser/IPCAI2019/ModelData/" + model_name + "/"
    out_dir = root_dir + stamp + "/"
    if not os.path.isdir(out_dir):
        os.mkdir(out_dir)

    batch_size = 16
    slice_shape = (512, 512)
    net = unet(slice_shape + (1, ), num_classes)

    # Pick the loader once; the binary case uses the chorda-specific reader.
    load_data_set = (tools.get_data_set_for_chorda if num_classes == 2
                     else tools.get_data_set)

    print("loading data set 1")
    images_1, labels_1 = load_data_set(root_dir + "dataset1.h5")
    print("loading data set 2")
    images_2, labels_2 = load_data_set(root_dir + "dataset2.h5")

    weights_1 = out_dir + "weights1.hdf5"
    weights_2 = out_dir + "weights2.hdf5"

    print("train model 1")
    train_model.train_model([images_1, labels_1], weights_1, net,
                            batch_size, epochs)
    del net

    # Second model starts from freshly initialised weights.
    print("train model 2")
    net = unet(slice_shape + (1, ), num_classes)
    train_model.train_model([images_2, labels_2], weights_2, net,
                            batch_size, epochs)
    del net

    return stamp
Пример #3
0
def predict(images, img_infos, filenames, model_file, out_dir, num_classes):
    """Segment all slices with a saved 512x512 U-Net and write one label image per volume.

    ``images`` holds the slices of every volume back to back. For each
    ``(info, filename)`` pair, ``info[0][2]`` consecutive predictions are
    taken (presumably the slice count of that volume, TODO confirm). Each
    slice's per-class maps are reduced with ``vote_on_predicted_slices``.
    The result is written to ``out_dir + filename`` via ``sitk_tools``.
    """
    print("predicting")
    for name, _ in zip(filenames, img_infos):
        print("  image " + name)

    print("")
    print("with model " + model_file)
    print("writing to " + out_dir)
    net = unet((512, 512) + (1, ), num_classes)
    net.load_weights(model_file)
    predictions = net.predict(images, batch_size=16, verbose=1)
    print("finished")

    # Walk through the flat prediction array one volume at a time.
    start = 0
    for info, name in zip(img_infos, filenames):
        stop = start + info[0][2]
        volume_predictions = predictions[start:stop]
        print("size of predictions: {}".format(volume_predictions.shape))
        # Collapse every slice's num_classes maps into one labelled slice.
        labeled_img = np.stack([
            vote_on_predicted_slices(slice_predictions)
            for slice_predictions in volume_predictions
        ])
        print("size of voting result: {}".format(labeled_img.shape))
        itk_labels = sitk_tools.create_itk_image(labeled_img, info)
        sitk_tools.write_img(itk_labels, out_dir + name)
        start = stop
Пример #4
0
def main():
    """Parse command-line options, build the U-Net graph and start training.

    ``--dataset`` is upper-cased once. That normalised value drives both the
    network configuration here and the call to ``train``, so the two always
    agree on input size and channel count.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--root", help="Directory that contains the data", type=str, required=True)
    parser.add_argument("--batch_size", type=int, default=1, help="Batch size for training")
    parser.add_argument("--gpu_id", type=str, default="1", help="Select a gpu")
    parser.add_argument("--use_vat", type=int, default=0, help="Enables VAT")
    parser.add_argument("--use_pseudo_labels", type=int, default=0, help="Enables pseudo-label usage")
    parser.add_argument("--use_mean_teacher", type=int, default=0, help="Enables mean teacher")
    parser.add_argument("--dataset", type=str, default="ENDOVIS", help="Choose RMIT or Endovis to train on.")
    args = parser.parse_args()

    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
    # Normalise once; every comparison below and in ``train`` uses this value.
    dataset = args.dataset.upper()

    # init network (any dataset other than ENDOVIS gets the RMIT configuration)
    ch = 3
    h = 256 if dataset == "ENDOVIS" else 288
    w = 320 if dataset == "ENDOVIS" else 384
    x = tf.placeholder(tf.float32, shape=[args.batch_size, h, w, ch])
    is_training = tf.placeholder(tf.bool)
    alpha = tf.placeholder_with_default(1 / 5.5, [], name="alpha_lrelu")
    num_parts = 5 if dataset == "ENDOVIS" else 4
    num_connections = 4 if dataset == "ENDOVIS" else 0
    keep_prob = .9 if dataset == "RMIT" else .7
    output_map, _ = unet(x, keep_prob, ch,
                         num_parts + num_connections,
                         is_training=is_training,
                         features_root=64,
                         alpha=alpha)

    # Pass the normalised name: ``train`` compares it against "ENDOVIS" to
    # size its label placeholder, which must match the graph built above.
    train(x, output_map, alpha, 50000, args.root, args.batch_size,
          is_training, args.gpu_id, args.use_vat,
          args.use_pseudo_labels, args.use_mean_teacher, dataset)
Пример #5
0
def train_xz_init_model(num_classes, epochs):
    """Train two independent XZ U-Nets (128x512 slices), one per data set.

    The model root folder (``.../ModelData/XZ-Unet-Init/``, adjust to your
    machine) must contain ``dataset1.h5`` and ``dataset2.h5``; both are read
    with ``get_data_set``. Weights go to
    ``<root>/<timestamp>/weights1.hdf5`` and ``weights2.hdf5``.

    Args:
        num_classes: Number of network outputs.
        epochs: Number of epochs forwarded to ``train_model.train_model``.
    """
    model_name = "XZ-Unet-Init"
    stamp = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    # Machine-specific location; change it for your own setup.
    root_dir = "C:/users/jfauser/IPCAI2019/ModelData/" + model_name + "/"
    print("Started at {}".format(stamp))

    slice_shape = (128, 512)
    net = unet(slice_shape + (1, ), num_classes)
    batch_size = 8

    print("loading data set 1")
    images_1, labels_1 = get_data_set(root_dir + "dataset1.h5")
    print("loading data set 2")
    images_2, labels_2 = get_data_set(root_dir + "dataset2.h5")

    # One folder per run, named by the start timestamp.
    out_dir = root_dir + stamp + "/"
    if not os.path.isdir(out_dir):
        os.mkdir(out_dir)

    print("train model 1")
    train_model.train_model([images_1, labels_1], out_dir + "weights1.hdf5",
                            net, batch_size, epochs)
    del net

    print("train model 2")
    net = unet(slice_shape + (1, ), num_classes)
    train_model.train_model([images_2, labels_2], out_dir + "weights2.hdf5",
                            net, batch_size, epochs)
    del net
Пример #6
0
 def build_net(self):
     """Build the 2-class mask network on ``self.tf_placeholders["images"]``.

     ``self.params.net == 'unet'`` selects ``unet`` (6 layers). Any other
     value selects ``fcn_gcn_net`` (7 layers, GCN kernel 11). Both start at
     8 channels and use ``self.is_train`` as the training flag.

     Returns:
         The mask logits produced by the selected network.
     """
     use_batch_norm = self.params.batch_norm
     images = self.tf_placeholders["images"]
     if self.params.net != 'unet':
         # NOTE(review): this branch hard-codes batch_norm=True and ignores
         # self.params.batch_norm; confirm that is intentional.
         return fcn_gcn_net(images,
                            num_classes=2,
                            k_gcn=11,
                            training=self.is_train,
                            init_channels=8,
                            n_layers=7,
                            batch_norm=True)
     return unet(images,
                 num_classes=2,
                 training=self.is_train,
                 init_channels=8,
                 n_layers=6,
                 batch_norm=use_batch_norm)
Пример #7
0
            if rawimg.GetDepth() < 128:
                sample = resample(rawimg, (512, 512, 128), False)
            else:
                sample = rawimg
            segment_cubed((sample, i))


ct_dir = "Data/rawdata/"  # Needs to be in current working directory
gt_dir = "Data/groundtruth/"
weight_dir = "model/weights2018-10-23-13/"
data_dir = "Dataset_1/"
input_shape = (128, 128)
num_classes = 11
cwd = os.getcwd() + "/"

model = unet(input_shape + (1, ), num_classes)
# os.listdir returns entries in arbitrary order, so sort before picking the
# last file. This presumes the weight file names sort chronologically
# (TODO confirm the checkpoint naming scheme).
weight_files = sorted(os.listdir(cwd + weight_dir + data_dir))
if not weight_files:
    raise FileNotFoundError("No weight files found in " + cwd + weight_dir + data_dir)
weights = weight_files[-1]
print("Using model {} for segmentation".format(weights))
model.load_weights(cwd + weight_dir + data_dir + weights)
model.predict(np.zeros((1, ) + input_shape + (1, )))  # warmup, shape follows input_shape

pr_dir = cwd + weight_dir + "prediction1/"
if not os.path.isdir(pr_dir):
    os.mkdir(pr_dir)
print(pr_dir)
# Skip volumes that already have a prediction. List the output folder once
# instead of re-reading it for every candidate file.
already_predicted = set(os.listdir(pr_dir))
images = [
    i for i in sorted(os.listdir(cwd + ct_dir))
    if i.endswith('.mhd') and i not in already_predicted
]
print(images)
segment_list(images)
Пример #8
0
def main():
    """Train the PyTorch U-Net on the image/map crops, optionally resuming.

    Builds the model, He-initialises its ``nn.Conv2d`` layers, creates the
    training and validation loaders, optionally restores ``./Unet.pt``, and
    hands everything to ``train``.

    Raises:
        ValueError: if ``Start_From_Checkpoint`` is True but no checkpoint
            exists, or if it is False and a checkpoint already exists
            (to prevent accidental overwriting).
    """
    # Set GPU device if available
    cuda = torch.cuda.is_available()
    device = torch.device("cuda" if cuda else "cpu")
    # Initialise Model
    model = u_net.unet().to(device)

    def initweights(layer):
        """He-normal initialisation, std = sqrt(2 / fan_in), for exact nn.Conv2d layers."""
        if type(layer) is nn.Conv2d:
            kernel_h, kernel_w = layer.kernel_size
            # fan_in covers both kernel dimensions (non-square kernels) and
            # the per-group input channels (grouped convolutions).
            fan_in = layer.in_channels // layer.groups * kernel_h * kernel_w
            standard = math.sqrt(2 / fan_in)
            torch.nn.init.normal_(layer.weight, std=standard)
    model.apply(initweights)

    Start_From_Checkpoint = False
    # Initialise Dataset
    input_folder = '/mnt/lustre/projects/ds19/eng121/Image_crops/'
    target_folder = '/mnt/lustre/projects/ds19/eng121/Map_crops/'
    batch_size = 5
    n_workers = multiprocessing.cpu_count()
    trainset = dataset(input_folder, target_folder, model, device, False)
    valset = dataset(input_folder, target_folder, model, device, True)
    # Initialise Dataloader
    dataloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                             shuffle=True, num_workers=n_workers)
    valloader = torch.utils.data.DataLoader(valset, batch_size=batch_size,
                                            shuffle=True, num_workers=n_workers)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    loss_fn = nn.CrossEntropyLoss()
    save_dir = './'
    model_name = 'Unet'
    # Checkpoint path built from save_dir and model_name; training saves and loads here.
    Load_Path = os.path.join(save_dir, model_name + ".pt")

    # Setup defaults:
    start_epoch = 0
    best_valid_acc = 0
    train_loss = []
    train_acc = []
    valid_loss = []
    valid_acc = []
    # Create the save directory if it does not exist
    if not os.path.isdir(save_dir):
        os.makedirs(save_dir)

    # Load Checkpoint
    if Start_From_Checkpoint:
        if os.path.isfile(Load_Path):
            # map_location lets a checkpoint saved on GPU load on a CPU-only host.
            check_point = torch.load(Load_Path, map_location=device)
            # Checkpoint is saved as a python dictionary
            model.load_state_dict(check_point['model_state_dict'])
            optimizer.load_state_dict(check_point['optimizer_state_dict'])
            start_epoch = check_point['epoch']
            best_valid_acc = check_point['best_acc']
            train_loss = check_point['train_loss']
            train_acc = check_point['train_acc']
            valid_loss = check_point['valid_loss']
            valid_acc = check_point['valid_acc']

            print("Checkpoint loaded, starting from epoch:", start_epoch)
        else:
            raise ValueError("Checkpoint Does not exist")
    else:
        # A checkpoint exists but we were told to start fresh:
        # refuse rather than risk overwriting it.
        if os.path.isfile(Load_Path):
            raise ValueError("Warning Checkpoint exists")
        else:
            print("Starting from scratch")
    train(save_dir, model_name, model, optimizer, device, loss_fn, dataloader, valloader,
          best_valid_acc, start_epoch, n_epochs=20, train_loss=train_loss,
          train_acc=train_acc, val_loss=valid_loss, val_acc=valid_acc)

if __name__ == '__main__':
    main()
Пример #9
0
def evaluate_ENDOVIS(root, model):
    """Evaluate the pose U-Net on the EndoVis training and test splits.

    For each split, every image is processed as follows:
    - it is run through the network;
    - the five joint heatmaps (output channels 0-4) are blurred and
      non-max-suppressed, and up to five peaks per joint are extracted;
    - the connection maps (channels 5-8) pair the joints into poses;
    - the predictions are matched to the ground-truth peaks with the
      Hungarian algorithm.
    Per-joint TP/FN/FP counts, RMSE (true positives only), precision, recall,
    F1 and mean error are printed.

    Args:
        root: Data directory handed to ``Batch``.
        model: Directory containing ``model.ckpt`` to restore. If ``None``,
            nothing is restored (NOTE(review): variables are then never
            initialised, so ``sess.run`` would fail).
    """
    # Distance threshold for a true positive; also sets the blur kernel size.
    T = 20

    # GPU Config
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=.9)
    sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))

    net_in = tf.placeholder(tf.float32, shape=[1, 256, 320, 3])
    y = tf.placeholder(tf.float32, shape=[1, 256, 320, 9])
    is_training = tf.placeholder(tf.bool)
    keep_prob = .9
    num_parts = 5
    num_connections = 4

    net_out, attention = unet(net_in,
                              keep_prob,
                              3,
                              num_parts + num_connections,
                              is_training=is_training,
                              features_root=64,
                              alpha=1 / 5.5)
    tv = tf.image.total_variation(net_out)
    # Not evaluated below (the per-image loss computation is disabled).
    loss = tf.losses.mean_squared_error(labels=y, predictions=net_out)

    if model is not None:
        restore_op, restore_dict = tf.contrib.framework.assign_from_checkpoint(
            model + "/model.ckpt",
            tf.contrib.slim.get_variables_to_restore(),
            ignore_missing_vars=True)
        sess.run(restore_op, feed_dict=restore_dict)
        print("Restored session and reset global step")

    mode = ("training", "test")
    testing = (False, True)

    # Scale network-resolution offsets (320 wide x 256 high) to 720x576,
    # presumably the original frame size. x uses the width factor, y the height factor.
    w_multiplier = 720. / 320.
    h_multiplier = 576. / 256.

    def precision(fp, tp):
        return tp / (tp + fp)

    def recall(fn, tp):
        return tp / (tp + fn)

    def f1(p, r):
        return (2 * p * r) / (p + r)

    for m in range(2):

        print("Results for", mode[m])

        b = Batch(root,
                  1,
                  testing=testing[m],
                  augment=False,
                  dataset="ENDOVIS",
                  include_unlabelled=False,
                  train_postprocessing=True)  # False for MSE
        false_pos = np.zeros((5, ), dtype=np.float32)
        false_neg = np.zeros((5, ), dtype=np.float32)
        true_pos = np.zeros((5, ), dtype=np.float32)
        rmse = np.zeros_like(true_pos)
        mae = np.zeros_like(rmse)
        counter = np.zeros_like(rmse)

        # Substrings of image names to leave out, e.g. ("test5", "test6")
        # or ("test1", "test2", "test3", "test4").
        exclude = ()

        for _ in range(len(b.img_list)):
            img, label, _, _ = b.get_batch()
            if any(e in b.name_list[0] for e in exclude):
                continue

            output, _, _, total_var = sess.run(
                [net_out, attention[1], attention[0], tv],
                feed_dict={
                    net_in: img,
                    is_training: False
                })
            blur = output[0].copy()
            # Smooth and non-max-suppress the five joint heatmaps.
            blur[:, :, :5] = cv2.GaussianBlur(blur[:, :, :5], (T + 1, T + 1),
                                              0)
            _, blur[:, :, :5] = nms(blur[:, :, :5])

            # Heuristic: within this total-variation band, sharpen the
            # connection maps by adding back their high-pass component.
            if 1000 > total_var > 700:
                mask = cv2.addWeighted(
                    blur[:, :, 5:], 1,
                    cv2.GaussianBlur(blur[:, :, 5:], (T + 1, T + 1), 0), -1, 0)
                blur[:, :, 5:] += mask

            loc_pred = [[], [], [], [], []]
            loc_true = [[], [], [], [], []]

            # Up to ``max_peaks`` peaks per joint: take the arg-max, then zero
            # a window around it so the next search finds a different peak.
            max_peaks = 5
            for i in range(5):
                heatmap = blur[:, :, i].copy()
                for _ in range(max_peaks):
                    _, _, _, max_loc = cv2.minMaxLoc(heatmap)
                    if max_loc[0] == max_loc[1] == 0:
                        break
                    loc_pred[i].append(max_loc)
                    # cv2.minMaxLoc returns (x, y) = (column, row). Clamp the
                    # window start at 0: a negative start wraps around, the
                    # slice becomes empty, and the same peak is found again.
                    col, row = max_loc
                    heatmap[max(row - 5, 0):row + 5, max(col - 5, 0):col + 5] = 0.

            for ch in range(10):
                _, _, _, max_loc = cv2.minMaxLoc(label[0][:, :, ch])
                # NOTE(review): unlike the prediction check above, this also
                # drops a peak lying on row 0 or column 0; confirm intended.
                if max_loc[0] != 0 and max_loc[1] != 0:
                    loc_true[ch % 5].append(max_loc)

            # Pair joints along each limb:
            # (joint index 1, joint index 2, connection channel).
            candidates = [[], [], [], []]

            for idx, (joint_idx1, joint_idx2, connection_idx) in enumerate(
                    [(0, 2, 5), (1, 2, 6), (2, 3, 7), (3, 4, 8)]):

                matching_scores = np.zeros(
                    (len(loc_pred[joint_idx1]), len(loc_pred[joint_idx2])),
                    dtype=np.float32)

                for i1, pt1 in enumerate(loc_pred[joint_idx1]):
                    for i2, pt2 in enumerate(loc_pred[joint_idx2]):
                        matching_scores[i1, i2] = compute_integral(
                            pt1, pt2, blur[:, :, connection_idx])

                # Maximise the total connection score (Hungarian on -scores).
                row_idx, col_idx = linear_sum_assignment(-matching_scores)

                for a, c in zip(row_idx, col_idx):
                    candidates[idx].append(
                        (loc_pred[joint_idx1][a], loc_pred[joint_idx2][c]))

            # Assemble poses: start from (shaft, end) pairs, attach the
            # matching (head, shaft) pair, then (right, head) and (left, head).
            parsed = []
            for shaft, end in candidates[-1]:
                for head, shaft_next in candidates[-2]:
                    if shaft[0] == shaft_next[0] and shaft[1] == shaft_next[1]:
                        parsed.append([head, shaft, end])

            for i, partial_pose in enumerate(parsed):
                head, _, _ = partial_pose
                for right, head_next in candidates[-3]:
                    if head[0] == head_next[0] and head[1] == head_next[1]:
                        parsed[i].insert(0, right)
                for left, head_next in candidates[-4]:
                    if head[0] == head_next[0] and head[1] == head_next[1]:
                        parsed[i].insert(0, left)

            # Left-pad incomplete poses with empty tuples up to five joints.
            for i, pose in enumerate(parsed):
                if len(pose) < 5:
                    for _ in range(5 - len(pose)):
                        parsed[i].insert(0, ())

            final_prediction = [[], [], [], [], []]
            if len(parsed) == 2:
                inst1, inst2 = parsed
                final_prediction = list(zip(inst1, inst2))
            elif len(parsed) == 1:
                # NOTE(review): here final_prediction[0] is the whole pose,
                # so k == 0 below matches all five joints against joint 0.
                final_prediction = parsed
            else:
                # Parsing failed: fall back to the joints of each limb's
                # first candidate pair.
                for i, pair in enumerate(candidates):
                    if len(pair) >= 2:
                        final_prediction[i] = [pair[0][0], pair[0][1]]
                    elif len(pair) == 1:
                        final_prediction[i] = [pair[0][0]]
                    else:
                        final_prediction[i] = []

            for k in range(5):
                try:
                    cost_matrix = np.zeros(
                        (len(loc_true[k]), len(final_prediction[k])),
                        dtype=np.float32)
                    pred = final_prediction[k]
                except IndexError:
                    # Single parsed pose: take joint k from that pose.
                    cost_matrix = np.zeros(
                        (len(loc_true[k]), len(final_prediction[0][k])),
                        dtype=np.float32)
                    pred = [final_prediction[0][k]]

                for i_true, e_true in enumerate(loc_true[k]):
                    for i_pred, e_pred in enumerate(pred):
                        try:
                            # Squared distance in rescaled pixels;
                            # index 0 is x (column), index 1 is y (row).
                            cost_matrix[i_true, i_pred] = ((e_pred[0] - e_true[0]) * w_multiplier) ** 2 + \
                                                          ((e_pred[1] - e_true[1]) * h_multiplier) ** 2
                        except IndexError:
                            # e_pred is an empty padding tuple.
                            if len(final_prediction[0][k]) != 0:
                                cost_matrix[i_true, i_pred] = 10000
                            else:
                                continue

                row_idx, col_idx = linear_sum_assignment(cost_matrix)

                if len(pred) < b.batch_instrument_count[0]:
                    false_neg[k] += abs(
                        len(pred) - b.batch_instrument_count[0])

                for r, c in zip(row_idx, col_idx):
                    # Every assignment contributes to the mean error; only
                    # those within T are true positives (and enter RMSE).
                    mae[k] += np.sqrt(cost_matrix[r, c])
                    counter[k] += 1
                    if np.sqrt(cost_matrix[r, c]) < T:
                        rmse[k] += cost_matrix[r, c]
                        true_pos[k] += 1
                    else:
                        false_pos[k] += 1

        p = precision(false_pos, true_pos)
        r = recall(false_neg, true_pos)
        print(true_pos, false_neg, false_pos)
        print("RMSE", np.sqrt(rmse / true_pos))
        print("Precision", p)
        print("Recall", r)
        print("F1", f1(p, r))
        print("MEA", mae / counter)
        print("\n")

    # Release the GPU memory held by the session.
    sess.close()
Пример #10
0
def main():
    """Train the Carvana segmentation U-Net and plot its learning curves.

    Steps:
    - read ``train_masks.csv`` under ``constants.competition_name``;
    - split image/mask paths 80/20 (``random_state=42``);
    - train with ``losses.bce_dice_loss``, keeping the best
      ``val_dice_loss`` weights in ``/tmp/weights.hdf5``;
    - plot Dice loss and total loss per epoch.

    Credential setup and the data download
    (``get_kaggle_credentials()``, ``get_data(...)``) are currently disabled.
    """
    data_root = constants.competition_name
    img_dir = os.path.join(data_root, "train")
    label_dir = os.path.join(data_root, "train_masks")

    # Image ids are the csv file names without their extension.
    df_train = pd.read_csv(os.path.join(data_root, "train_masks.csv"))
    ids_train = df_train["img"].map(lambda s: s.split(".")[0])

    x_train_filenames = [
        os.path.join(img_dir, "{}.jpg".format(img_id)) for img_id in ids_train
    ]
    y_train_filenames = [
        os.path.join(label_dir, "{}_mask.gif".format(img_id))
        for img_id in ids_train
    ]

    (x_train_filenames, x_val_filenames,
     y_train_filenames, y_val_filenames) = train_test_split(
        x_train_filenames, y_train_filenames, test_size=0.2, random_state=42)

    num_train_examples = len(x_train_filenames)
    num_val_examples = len(x_val_filenames)

    print("Number of training examples: {}".format(num_train_examples))
    print("Number of validation examples: {}".format(num_val_examples))
    print("x_train_filenames: {}".format(x_train_filenames[:5]))
    print("y_train_filenames: {}".format(y_train_filenames[:5]))

    # Augment the training data only; validation is just resized and scaled.
    tr_cfg = dict(
        resize=[constants.img_shape[0], constants.img_shape[1]],
        scale=1 / 255.0,
        hue_delta=0.1,
        horizontal_flip=True,
        width_shift_range=0.1,
        height_shift_range=0.1,
    )
    val_cfg = dict(
        resize=[constants.img_shape[0], constants.img_shape[1]],
        scale=1 / 255.0,
    )
    tr_preprocessing_fn = functools.partial(dataloader._augment, **tr_cfg)
    val_preprocessing_fn = functools.partial(dataloader._augment, **val_cfg)

    train_ds = dataloader.get_baseline_dataset(x_train_filenames,
                                               y_train_filenames,
                                               preproc_fn=tr_preprocessing_fn,
                                               batch_size=constants.batch_size)
    val_ds = dataloader.get_baseline_dataset(x_val_filenames,
                                             y_val_filenames,
                                             preproc_fn=val_preprocessing_fn,
                                             batch_size=constants.batch_size)

    model = u_net.unet(constants.img_shape)
    model.compile(optimizer="adam",
                  loss=losses.bce_dice_loss,
                  metrics=[losses.dice_loss])
    model.summary()

    # Keep only the checkpoint with the best validation Dice loss.
    checkpoint = tf.keras.callbacks.ModelCheckpoint(filepath="/tmp/weights.hdf5",
                                                    monitor="val_dice_loss",
                                                    save_best_only=True,
                                                    verbose=1)

    batch = float(constants.batch_size)
    history = model.fit(
        train_ds,
        steps_per_epoch=int(np.ceil(num_train_examples / batch)),
        epochs=constants.epochs,
        validation_data=val_ds,
        validation_steps=int(np.ceil(num_val_examples / batch)),
        callbacks=[checkpoint])

    curves = history.history
    dice = curves['dice_loss']
    val_dice = curves['val_dice_loss']
    loss = curves['loss']
    val_loss = curves['val_loss']

    epochs_range = range(constants.epochs)
    panels = (
        (dice, val_dice, 'Training Dice Loss', 'Validation Dice Loss',
         'Training and Validation Dice Loss'),
        (loss, val_loss, 'Training Loss', 'Validation Loss',
         'Training and Validation Loss'),
    )

    plt.figure(figsize=(16, 8))
    for position, (train_curve, val_curve, train_label, val_label,
                   title) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.plot(epochs_range, train_curve, label=train_label)
        plt.plot(epochs_range, val_curve, label=val_label)
        plt.legend(loc='upper right')
        plt.title(title)

    plt.show()
Пример #11
0
import torch
import os
import cv2
from torch.utils.data import Dataset
import torchvision.transforms as Transforms
import matplotlib.pyplot as plt
import u_net

# Check out how much memory is being used by a single image forward prop through network.
input_folder = '../Image_crops/'
image_path = input_folder + 'slide001_core004_crop0024.png'
image = cv2.imread(image_path)
# cv2.imread returns None instead of raising for a missing or unreadable file.
# Fail here with a clear message rather than later on ``image / 255``.
if image is None:
    raise FileNotFoundError("Could not read image: " + image_path)
# OpenCV loads BGR. Convert only for display; the network input stays as loaded.
plt.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
model = u_net.unet().float()
transform = Transforms.ToTensor()
image = transform(image / 255).float()

# Single forward pass in train mode and outside torch.no_grad(), presumably
# to mirror a training step's memory footprint.
model.train()
model.forward(image.unsqueeze(0))

plt.show()
Пример #12
0
import data_load
import torch
import u_net

input_folder = '../Image_crops/'
target_folder = '../Map_crops/'

# Quick sanity check: build a loader and print element 1 of the first batch
# (presumably the target maps).
model = u_net.unet()
dataset = data_load.dataset(input_folder, target_folder, model)
dataloader = torch.utils.data.DataLoader(
    dataset, batch_size=10, shuffle=True, num_workers=1)
first_batch = next(iter(dataloader))
print(first_batch[1])
Пример #13
0
def train(input_t, output_map, alpha, max_it, root, batch_size, is_training, id, use_vat, use_pseudo_labels,
          use_mean_teacher, dataset):
    """
    Train the network with Adam, optionally with VAT, pseudo labels or a mean teacher.

    :param input_t: input tensor
    :param output_map: output layer of the network
    :param alpha: placeholder for leaky relu
    :param max_it: maximum training iterations
    :param root: base directory that contains the images
    :param batch_size: batch size
    :param is_training: toggle training
    :param id: GPU id (not used inside this function)
    :param use_vat: Enable VAT
    :param use_pseudo_labels: Use pseudo labels
    :param use_mean_teacher: Use mean teacher
    :param dataset: Dataset name; "ENDOVIS" selects the EndoVis sizes, anything else the RMIT sizes.
        Also forwarded to ``Batch``.
    :return: None. Checkpoints are written to ``tmp<iteration>/model.ckpt``.
    """

    h = 256 if dataset == "ENDOVIS" else 288
    w = 320 if dataset == "ENDOVIS" else 384
    num_parts = 5 if dataset == "ENDOVIS" else 4
    num_connections = 4 if dataset == "ENDOVIS" else 0

    # GPU Config
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=.95)

    # Set up placeholders
    y = tf.placeholder(tf.float32, shape=[None, h, w, num_parts + num_connections])
    lr = tf.placeholder(tf.float32)
    loss_mask = tf.placeholder(tf.float32, shape=[batch_size])
    # Weight of the teacher-consistency term; only created in mean-teacher mode.
    m = None

    # Loss
    if not use_mean_teacher:
        avr_loss = tf.losses.mean_squared_error(y, output_map,
                                                weights=tf.reshape(loss_mask,
                                                                   [batch_size, 1, 1, 1]))
    else:
        ema = tf.train.ExponentialMovingAverage(decay=.95)

        def ema_getter(getter, name, *args, **kwargs):
            # Return the EMA shadow of a variable when one exists.
            var = getter(name, *args, **kwargs)
            ema_var = ema.average(var)
            return ema_var if ema_var else var

        tf.get_variable_scope().set_custom_getter(ema_getter)
        model_vars = tf.trainable_variables()
        output_student = output_map
        tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, ema.apply(model_vars))
        output_teacher, _ = unet(input_t, .9 if dataset == "RMIT" else .7, 3,
                                 num_parts + num_connections,
                                 is_training=is_training,
                                 features_root=64,
                                 alpha=alpha)
        output_teacher = tf.stop_gradient(output_teacher)
        # Rescale the masked MSE by batch_size / (number of labelled samples).
        avr_loss = batch_size / tf.reduce_sum(loss_mask) * \
                   tf.losses.mean_squared_error(y, output_student,
                                                weights=tf.reshape(loss_mask,
                                                                   [batch_size, 1, 1, 1]))
        m = tf.placeholder(tf.float32, shape=[])
        avr_loss = avr_loss + m * .1 * tf.losses.mean_squared_error(output_teacher, output_student)

    if use_vat:
        # NOTE(review): with use_mean_teacher the supervised term was already
        # rescaled above, so it is rescaled a second time here; confirm intended.
        avr_loss = batch_size / tf.reduce_sum(loss_mask) * avr_loss + \
                   virtual_adversarial_loss(input_t, y, is_training=is_training, alpha=alpha)

    # Adam solver
    with tf.variable_scope("Adam", reuse=tf.AUTO_REUSE):
        opt = tf.train.AdamOptimizer(lr).minimize(avr_loss)

    # Start session and initialize weights
    sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options,
                                            allow_soft_placement=True,
                                            log_device_placement=True))
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver(max_to_keep=10000)

    # Unlabelled images are only needed by the semi-supervised objectives.
    b_train = Batch(root, batch_size, dataset=dataset,
                    include_unlabelled=use_vat or use_mean_teacher,
                    pseudo_label=use_pseudo_labels)
    b_test = Batch(root, batch_size, dataset=dataset, include_unlabelled=False, testing=True, augment=False,
                   train_postprocessing=False)

    current_lr = 1e-3
    print("Chosen lr:", current_lr)

    # save graph
    writer = tf.summary.FileWriter(logdir='logdir', graph=sess.graph)
    writer.flush()

    if use_vat:
        test_interval = 250
    else:
        test_interval = 200

    def sigmoid_schedule(global_step, warm_up_steps=20000):
        """Ramp-up weight exp(-5 * (1 - t)^2) with t = step / warm_up_steps; 1 afterwards."""
        if global_step > warm_up_steps:
            return 1.

        return np.exp(-5. * (1. - (global_step / warm_up_steps)) ** 2)

    for i in range(max_it):

        imgs, targets, _, mask = b_train.get_batch()

        feed_dict = {input_t: imgs,
                     y: targets,
                     lr: current_lr,
                     is_training: True,
                     # Random leaky-ReLU slope in (1/8, 1/3] per iteration.
                     alpha: 1 / np.random.uniform(low=3, high=8),
                     loss_mask: mask}
        if m is not None:
            feed_dict[m] = sigmoid_schedule(i)

        current_loss, net_out, _ = sess.run([avr_loss, output_map, opt], feed_dict=feed_dict)

        if i % 100 == 0:
            print("Current regression loss:", current_loss.sum())
            loc_pred = []
            loc_true = []
            # Peak locations are only reported when the first sample shows one instrument.
            if b_train.batch_instrument_count[0] == 1:
                for ch in range(num_parts):
                    _, _, _, m_loc1 = cv2.minMaxLoc(net_out[0, :, :, ch])
                    loc_pred.append(m_loc1)
                    _, _, _, m_loc2 = cv2.minMaxLoc(targets[0][:, :, ch])
                    loc_true.append(m_loc2)

            print("For the first sample-> Predicted: {}    Ground Truth: {}\n".format(loc_pred, loc_true))

        # save model for evaluation
        if i % test_interval == 0 and i != 0:

            print("Testing at iteration", i, "...")
            dir2save = os.path.join("tmp" + str(i), "model.ckpt")
            save_path = saver.save(sess, dir2save)
            print("Saved model to", save_path)

    writer.close()
    sess.close()