Example #1
0
    def publish(self):
        """Enqueue one metric taken from the request and answer in JSONP form.

        Reads 'callback', 'metric', 'value' and 'timestamp' from
        RequestParams. Every outcome (missing parameter, invalid metric,
        success) is reported as the string ``callback({'msg': ...})``.
        Note that when 'callback' itself is missing the message is still
        wrapped using the ``None`` callback value.
        """
        jsonp = '{0}({1})'
        required_msg = 'Param %s is required'

        callback = RequestParams.get('callback', None)
        if callback is None:
            return jsonp.format(callback, {'msg': required_msg % 'callback'})

        # The remaining mandatory parameters are checked in a fixed order so
        # the first missing one is the one reported.
        for name in ('metric', 'value', 'timestamp'):
            if name not in RequestParams:
                return jsonp.format(callback, {'msg': required_msg % name})

        payload = {
            'metric': RequestParams.get('metric'),
            'value': RequestParams.get('value'),
            'timestamp': RequestParams.get('timestamp')
        }

        # Imported lazily inside the method, as in the rest of this handler.
        from lib import app
        queue = app.CarbonHttpApplication().queue

        try:
            Metric(payload).enqueue(queue)
        except ValueError:
            return jsonp.format(callback, {'msg': 'Invalid Metric.'})

        return jsonp.format(callback, {'msg': 'Publish success.'})
Example #2
0
    def publish(self):
        """Enqueue one metric built from the request parameters.

        Raises:
            CParamsMissing: if 'metric', 'value' or 'timestamp' is absent
                from RequestParams (checked in that order).
            CParamsError: if a ValueError is raised while building or
                enqueueing the metric, or while building the result.

        Returns:
            Whatever ``self.make_result()`` returns.
        """
        for name in ('metric', 'value', 'timestamp'):
            if name not in RequestParams:
                raise CParamsMissing('Param %s is required' % name)

        payload = {
            'metric': RequestParams.get('metric'),
            'value': RequestParams.get('value'),
            'timestamp': RequestParams.get('timestamp')
        }

        # Imported lazily inside the method.
        from lib import app
        queue = app.CarbonHttpApplication().queue

        try:
            Metric(payload).enqueue(queue)
            return self.make_result()
        except ValueError:
            raise CParamsError('Invalid Metric.')
Example #3
0
def main(config, Model, data, base_path, _run):
    """Run a trained model over the test split and store predictions and stats.

    For every test sample the prediction, the label and the input are written
    as NIfTI files under ``<base_path>/<run id>/preds`` and one tab-separated
    line (id, dice, hausdorff, islands) is collected; all lines are finally
    written to ``<base_path>/<run id>/results``.

    Args:
        config: dict read for "device", "model_state" and "loss_fn";
            "base_path" is written into it.
        Model: callable building the network from `config`.
        data: callable building the dataset for a given split name.
        base_path: root output folder; the run id is appended to it.
        _run: object exposing `_id`, used to name the output folder.

    Raises:
        Exception: if config["model_state"] is the empty string.
    """
    log("Start " + ex.get_experiment_info()["name"])

    # Each run writes into its own sub-folder named after the run id.
    base_path = base_path + str(_run._id) + "/"
    config["base_path"] = base_path

    # Network, placed on the configured device and switched to eval mode.
    model = Model(config)
    model.to(config["device"])
    model.eval()

    # Trained weights are mandatory for this script.
    if config["model_state"] == "":
        raise Exception("A model needs to be loaded")
    model.load_state_dict(torch.load(config["model_state"]))

    os.makedirs(config["base_path"] + "preds")

    test_data = data("test", loss=config["loss_fn"], dev=config["device"])

    def class_map(scores):
        # Reshape to (2, 18, 256, 256), move axis 1 to the end and take the
        # argmax over the leading axis.
        volume = np.reshape(scores, (2, 18, 256, 256))
        return np.argmax(np.moveaxis(volume, 1, -1), axis=0)

    def save_nifti(volume, name):
        # Store `volume` with an identity affine inside the preds folder.
        nib.save(nib.Nifti1Image(volume, np.eye(4)),
                 config["base_path"] + "preds/" + name)

    lines = []
    with torch.no_grad():
        for idx in range(len(test_data)):
            # Progress report every 10 samples.
            if idx % 10 == 0:
                print("{}/{}".format(idx, len(test_data)))
            X, Y, id_, _ = test_data[idx]

            pred = model(X)[0].cpu().numpy()
            Y = Y.cpu().numpy()

            # Per-sample statistics; only the first entry of each is kept.
            scores = Metric(pred, Y)
            dice_res = scores.dice()[0]
            haus_res = scores.hausdorff()[0]
            islands_res = scores.islands()[0]
            lines.append("{}\t{}\t{}\t{}\n".format(id_, dice_res, haus_res,
                                                   islands_res))

            # Prediction, label and input volumes, in that order.
            save_nifti(class_map(pred), id_ + "_pred.nii.gz")
            save_nifti(class_map(Y), id_ + "_label.nii.gz")
            image = np.reshape(X.cpu().numpy(), (18, 256, 256))
            save_nifti(np.moveaxis(image, 0, -1), id_ + ".nii.gz")

    with open(config["base_path"] + "results", "w") as f:
        f.writelines(lines)
Example #4
0
    def publish(self):
        """Enqueue every metric listed in the 'metrics' request parameter.

        Each element of 'metrics' is handed to ``Metric`` and enqueued in
        order. Processing stops at the first element that raises
        ValueError; elements enqueued before it are not rolled back.

        Raises:
            CParamsMissing: if 'metrics' is absent from RequestParams.
            CParamsError: if 'metrics' is not a list, or one of its
                elements is rejected with a ValueError.

        Returns:
            Whatever ``self.make_result()`` returns.
        """
        if 'metrics' not in RequestParams:
            raise CParamsMissing('Param %s is required' % 'metrics')

        metrics = RequestParams.get('metrics')

        # isinstance (rather than an exact type comparison) also accepts
        # list subclasses.
        if not isinstance(metrics, list):
            raise CParamsError('Param %s must be list' % 'metrics')

        # Imported lazily inside the method.
        from lib import app
        queue = app.CarbonHttpApplication().queue

        for metric_dict in metrics:
            try:
                Metric(metric_dict).enqueue(queue)
            except ValueError:
                raise CParamsError('Invalid Metric.')

        return self.make_result()
Example #5
0
            if te_i % 10 == 0:
                print("Masks generated: {}/{}".format(te_i, len(test_data)))

            X, Y, id_ = test_data[te_i]

            output = model(X)
            pred = output[0].cpu().numpy()  # BCDHW

            # Optional Post-processing
            if removeSmallIslands_thr != -1:
                pred = removeSmallIslands(pred, thr=removeSmallIslands_thr)

            if type(Y) != type(None):
                Y = Y.cpu().numpy()
                with open(outputPath + "stats.csv", "a") as f:
                    measures = Metric(pred, Y)
                    f.write("{},{},{},{}\n".format(
                        id_,
                        measures.dice()[0, 1],
                        measures.compactness()[0],
                        measures.hausdorff_distance()[0]))

            pred = pred[0]  # CDHW
            pred = np.argmax(pred, axis=0)  # DHW
            pred = np.moveaxis(pred, 0, -1)  # HWD

            # The filename will be the name of the last 3 folders
            # Ideally Study_Timepoint_ScanID_predMask.nii.gz
            filename = "_".join(id_[:-1].split("/")[-3:]) + "_predMask.nii.gz"

            nib.save(nib.Nifti1Image(pred, np.eye(4)), outputPath + filename)
def main(config, Model, data, base_path, _run):
    """Train `Model`, validate it after every epoch and evaluate it on the test split.

    Steps, in order:
      1. Build the train/validation datasets and the model, initialise the
         Conv3d layers with config["initW"] / config["initB"].
      2. Write the model graph to tensorboard, a profile of one forward pass
         to ``<run folder>/profile`` and the parameter/optimizer layout to
         ``<run folder>/state_dict``.
      3. Optionally load weights from config["model_state"].
      4. Train for at most config["epochs"] epochs; after each epoch run
         validation, log scalars to tensorboard, step the optional
         lr_scheduler (which may stop training) and save a checkpoint,
         keeping only the five most recent ones.
      5. Evaluate on the test split, optionally saving predictions, and write
         ``results.json`` (and ``results-post.json`` when small-island
         removal is enabled).

    Args:
        config: dict with the experiment settings. Keys read here include
            "loss_fn", "brainmask", "overlap", "device", "initW", "initB",
            "save_validation", "save_npy", "epochs", "batch", "opt", "lr",
            "weight_decay", "lr_scheduler", "model_state", "pc_name",
            "removeSmallIslands_thr", "save_prediction_mask",
            "save_prediction_softmaxprob" and "save_prediction_logits".
            "base_path" is written into it.
        Model: callable building the network from `config`.
        data: callable building a dataset for a split name ("train",
            "validation", "test"); samples unpack as (X, Y, id_, W).
        base_path: root output folder; the run id is appended to it.
        _run: object exposing `_id`, used to name the run folder.
    """
    log("Start " + ex.get_experiment_info()["name"])

    # Every run writes into its own sub-folder named after the run id.
    base_path = base_path + str(_run._id) + "/"
    config["base_path"] = base_path

    # Data
    tr_data = data("train",
                   loss=config["loss_fn"],
                   brainmask=config["brainmask"],
                   overlap=config["overlap"],
                   dev=config["device"])
    val_data = data("validation",
                    loss=config["loss_fn"],
                    brainmask=config["brainmask"],
                    overlap=config["overlap"],
                    dev=config["device"])
    #tr_data = data("train", loss=config["loss_fn"], dev=config["device"])
    #val_data = data("validation", loss=config["loss_fn"], dev=config["device"])

    # Calculating compactness in the training data
    # (disabled: the snippet below is kept as an inert string literal)
    """
    comp_vals = []
    for tr_i in range(len(tr_data)):
        X, Y, id_, W = tr_data[tr_i]
        comp_vals.append(str(Metric(Y.detach().cpu().numpy(), None).compactness()[0][-1]))

    with open("compactness_results", "a") as f:
        f.write(",".join(comp_vals) + "\n")
    import sys
    sys.exit(1)
    """

    # Model
    model = Model(config)
    #model.cuda()
    model.to(config["device"])

    # Weight initialization
    def weight_init(m):
        # Only 3D convolutions are initialised; other modules keep their
        # default initialisation.
        if isinstance(m, torch.nn.Conv3d):
            config["initW"](m.weight)
            config["initB"](m.bias)

    model.apply(weight_init)

    # Save graph
    # The first training sample serves as example input for the graph and
    # for the profiling pass below.
    X, _, _, _ = tr_data[0]
    # Tensorboard folder: drop the trailing "/", then place the run under
    # "<path without its last two components>/tensorboard/" using the last
    # two components joined by "_" as folder name.
    tb_path = base_path[:-1].split("/")
    tb_path = "/".join(tb_path[:-2]) + "/tensorboard/" + "_".join(tb_path[-2:])
    writer = SummaryWriter(tb_path)
    writer.add_graph(model, X)
    writer.close()

    # Test how long each operation take
    with torch.autograd.profiler.profile(use_cuda=True) as prof:
        model(X)
    with open(base_path + "profile", "w") as f:
        f.write(str(prof))

    # Create folder for saving the model and validation results
    if len(config["save_validation"]) > 0:
        os.makedirs(config["base_path"] + "val_evol")
    os.makedirs(config["base_path"] + "model")

    # Config
    ep = config["epochs"]
    # NOTE(review): `bs` is not referenced again in this function.
    bs = config["batch"]
    loss_fn = config["loss_fn"]
    opt = config["opt"](model.parameters(),
                        lr=config["lr"],
                        weight_decay=config["weight_decay"])
    lr_scheduler = config["lr_scheduler"]
    if not lr_scheduler is None:
        lr_scheduler.setOptimizer(opt)

    # Save model and optimizer's state dict
    param_num = sum(p.numel() for p in model.parameters() if p.requires_grad)

    log("Number of parameters: " + str(param_num))
    with open(base_path + "state_dict", "w") as f:
        f.write(">> Model's state dict:\n")  # Important for debugging
        for param_tensor in model.state_dict():
            f.write(param_tensor + "\t" +
                    str(model.state_dict()[param_tensor].size()) + "\n")

        f.write(
            "\n>> Optimizer's state dict:\n")  # Important for reproducibility
        for var_name in opt.state_dict():
            f.write(var_name + "\t" + str(opt.state_dict()[var_name]) + "\n")

    # Load weights if necessary
    # This happens after weight_init, so loaded weights replace the
    # freshly initialised ones.
    if config["model_state"] != "":
        log("Loading previous model")
        model.load_state_dict(torch.load(config["model_state"]))

    # Counters and flags
    e = 0  # Epoch counter
    it = 0  # Iteration counter
    keep_training = True  # Flag to stop training when overfitting

    log("Training")
    while e < ep and keep_training:
        model.train()

        tr_loss = 0
        tr_islands = 0
        tr_i = 0
        # keep_training is only changed after validation (see the scheduler
        # block below), so it cannot flip while this inner loop is running.
        while tr_i < len(tr_data) and keep_training:
            X, Y, id_, W = tr_data[tr_i]

            output = model(X)
            pred = output[0]

            # W is passed to the loss only when the dataset provides it.
            if W is None:
                tr_loss_tmp = loss_fn(output, Y, config)
            else:
                tr_loss_tmp = loss_fn(output, Y, config, W)
            # NOTE(review): tr_loss_tmp is added as returned by loss_fn; if
            # that is a tensor attached to autograd, the running sum keeps
            # graph history for the whole epoch — confirm whether
            # accumulating a plain number would be preferable.
            tr_loss += tr_loss_tmp
            tr_islands += np.sum(Metric(pred.detach().cpu(), None).islands())

            # Optimization
            opt.zero_grad()
            tr_loss_tmp.backward()

            #writer = SummaryWriter(tb_path)
            #record_grads = ["bottleneck4.seq.2.weight", "conv1.weight", "block3.seq.4.weight"]
            #for n, p in model.named_parameters():
            #    if n in record_grads:
            #        writer.add_histogram(n, p.grad, e)
            #writer.close()

            #plot_grad_flow(model.named_parameters())
            opt.step()

            it += 1
            tr_i += 1

        # Per-epoch averages over the number of training samples.
        tr_loss /= len(tr_data)
        tr_islands /= len(tr_data)

        #for n, p in named_parameters:
        #    if(p.requires_grad) and ("bias" not in n):
        #        layers.append(n)
        #        ave_grads.append(p.grad.abs().mean())
        #        max_grads.append(p.grad.abs().max()
        # Tensorboard summaries
        writer = SummaryWriter(tb_path)
        writer.add_scalar("tr_loss", tr_loss, e)
        writer.add_scalar("tr_islands", tr_islands, e)
        writer.close()

        log("Validation")
        val_loss = 0
        val_islands = 0
        val_dice = 0
        val_i = 0
        model.eval()
        with torch.no_grad():
            while val_i < len(val_data) and keep_training:
                X, Y, id_, W = val_data[val_i]

                output = model(X)
                pred = output[0]
                if W is None:
                    val_loss_tmp = loss_fn(output, Y, config)
                else:
                    val_loss_tmp = loss_fn(output, Y, config, W)
                val_loss += val_loss_tmp
                m = Metric(pred.cpu().numpy(), Y.cpu().numpy())
                val_islands += np.sum(m.islands())
                val_dice += m.dice()[:, 1]  # Lesion Dice

                # Samples listed in config["save_validation"] are dumped
                # every epoch so their evolution can be inspected.
                if id_ in config["save_validation"]:
                    name = id_ + "_" + str(e)
                    # (2, 18, 256, 256) -> (256, 256, 18, 2)
                    pred = np.moveaxis(
                        np.moveaxis(
                            np.reshape(pred.cpu().numpy(), (2, 18, 256, 256)),
                            1, -1), 0, -1)
                    if config["save_npy"]:
                        np.save(config["base_path"] + "val_evol/" + name, pred)
                    pred = np.argmax(pred, axis=-1)
                    nib.save(
                        nib.Nifti1Image(pred, np.eye(4)),
                        config["base_path"] + "val_evol/" + name + ".nii.gz")

                val_i += 1

        val_loss /= len(val_data)
        val_islands /= len(val_data)
        val_dice /= len(val_data)

        # Tensorboard summaries
        writer = SummaryWriter(tb_path)
        writer.add_scalar("val_loss", val_loss, e)
        writer.add_scalar("val_islands", val_islands, e)
        writer.add_scalar("val_dice", val_dice, e)
        writer.close()

        log("Epoch: {}. Loss: {}. Val Loss: {}".format(e, tr_loss, val_loss))

        # Reduce learning rate if needed, and stop if limit is reached.
        if lr_scheduler != None:
            lr_scheduler.step(val_loss)
            #keep_training = lr_scheduler.limit_cnt > -1 # -1 -> stop training
            if lr_scheduler.limit_cnt < 0:
                keep_training = False
                lr_scheduler.limit_cnt = lr_scheduler.limit  # Needed if we run ex. more than once!

        # Save model after every epoch
        # Checkpoints older than five epochs are removed, so at most the
        # five most recent ones remain on disk.
        torch.save(model.state_dict(),
                   config["base_path"] + "model/model-" + str(e))
        if e > 4 and os.path.exists(config["base_path"] + "model/model-" +
                                    str(e - 5)):
            os.remove(config["base_path"] + "model/model-" + str(e - 5))

        e += 1

    log("Testing")
    test_data = data("test",
                     loss=config["loss_fn"],
                     brainmask=config["brainmask"],
                     overlap=config["overlap"],
                     dev=config["device"])
    if config["save_prediction_mask"] or config[
            "save_prediction_softmaxprob"] or config["save_prediction_logits"]:
        os.makedirs(config["base_path"] + "preds")

    # id -> metrics (or pending async result) before / after post-processing
    results = {}
    results_post = {}
    model.eval()
    if config[
            "pc_name"] != "sampo-tipagpu1":  # Sampo seems to not deal well with multiprocess
        pool = multiprocessing.Pool(processes=PROCESSES)
    with torch.no_grad():
        # Assuming that batch_size is 1
        for test_i in range(len(test_data)):
            X, Y, id_, _ = test_data[test_i]
            output = model(X)
            pred = output[0].cpu().numpy()
            Y = Y.cpu().numpy()  # NBWHC

            # Without post-processing the raw argmax mask is saved here;
            # otherwise the post-processed mask is saved further below.
            if config["removeSmallIslands_thr"] == -1 and config[
                    "save_prediction_mask"]:
                _out = np.argmax(np.moveaxis(
                    np.reshape(pred, (2, 18, 256, 256)), 1, -1),
                                 axis=0)
                nib.save(nib.Nifti1Image(_out, np.eye(4)),
                         config["base_path"] + "preds/" + id_ + "_mask.nii.gz")
            # Logits are only available when the model returns a second output.
            if config["save_prediction_logits"] and len(output) > 1:
                logits = output[1].cpu().numpy()
                _out = np.moveaxis(
                    np.moveaxis(np.reshape(logits, (2, 18, 256, 256)), 1, -1),
                    0, -1)
                nib.save(
                    nib.Nifti1Image(_out, np.eye(4)),
                    config["base_path"] + "preds/" + id_ + "_logits.nii.gz")

            if config["save_prediction_softmaxprob"]:
                _out = np.moveaxis(
                    np.moveaxis(np.reshape(pred, (2, 18, 256, 256)), 1, -1), 0,
                    -1)
                nib.save(
                    nib.Nifti1Image(_out, np.eye(4)), config["base_path"] +
                    "preds/" + id_ + "_softmaxprob.nii.gz")

            # With a pool, metrics are computed asynchronously and collected
            # after the loop; otherwise they are computed right away.
            if config["pc_name"] != "sampo-tipagpu1":
                results[id_] = pool.apply_async(Metric(pred, Y).all)
            else:
                results[id_] = Metric(pred, Y).all()

            # Results after post-processing
            if config["removeSmallIslands_thr"] != -1:
                pred_post = pred.copy()
                pred_post = removeSmallIslands(
                    pred_post, thr=config["removeSmallIslands_thr"])
                if config["pc_name"] != "sampo-tipagpu1":
                    results_post[id_] = pool.apply_async(
                        Metric(pred_post, Y).all)
                else:
                    results_post[id_] = Metric(pred_post, Y).all()
                if config["save_prediction_mask"]:
                    _out = np.argmax(np.moveaxis(
                        np.reshape(pred_post, (2, 18, 256, 256)), 1, -1),
                                     axis=0)
                    nib.save(
                        nib.Nifti1Image(_out, np.eye(4)),
                        config["base_path"] + "preds/" + id_ + "_mask.nii.gz")

    # Replace the pending async results with their actual values.
    if config["pc_name"] != "sampo-tipagpu1":
        for k in results:
            results[k] = results[k].get()
    with open(config["base_path"] + "results.json", "w") as f:
        f.write(json.dumps(results))

    if config["removeSmallIslands_thr"] != -1:
        if config["pc_name"] != "sampo-tipagpu1":
            for k in results_post:
                results_post[k] = results_post[k].get()
        # Results after post-processing
        with open(config["base_path"] + "results-post.json", "w") as f:
            f.write(json.dumps(results_post))

    # NOTE(review): the pool is only shut down on the success path; an
    # exception raised above leaves it open.
    if config["pc_name"] != "sampo-tipagpu1":
        pool.close()
        pool.join()
        pool.terminate()

    log("End")
Example #7
0
    def fit(self, tr_loader, val_loader, epochs, val_interval, loss,
            val_metrics, opt):
        """Trains the NN.

           After every epoch the average training loss (and, on validation
           epochs, the validation loss and metrics) is logged together with
           an ETA, and the model weights are saved; the checkpoint of the
           previous epoch is deleted so only the latest one is kept.

           Args:
            `tr_loader`: DataLoader with the training set.
            `val_loader`: DataLoader with the validation set.
            `epochs`: Number of epochs to train the model. If 0, no train.
            `val_interval`: After how many epochs to perform validation.
            `loss`: Loss function, called as `loss(output, Y, W)`.
            `val_metrics`: Which metrics to measure at validation time.
            `opt`: Optimizer.
        """
        t0 = time.time()
        e = 1
        # Expected classes of our dataset
        measure_classes = {0: "background", 1: "contra", 2: "R_hemisphere"}
        # Which classes will be reported during validation
        measure_classes_mean = np.array([1, 2])

        while e <= epochs:
            self.train()

            tr_loss = 0
            for X, Y, info, W in tr_loader:
                X = [x.to(self.device) for x in X]
                Y = [y.to(self.device) for y in Y]
                W = [w.to(self.device) for w in W]

                output = self(X)

                tr_loss_tmp = loss(output, Y, W)
                # Accumulate a plain Python number: summing the tensor itself
                # would keep autograd history of every batch alive until the
                # end of the epoch.
                tr_loss += tr_loss_tmp.item()

                # Optimization
                opt.zero_grad()
                tr_loss_tmp.backward()
                opt.step()

            tr_loss /= len(tr_loader)

            if len(val_loader) != 0 and e % val_interval == 0:
                log("Validation", self.out_path)
                self.eval()

                val_loss = 0
                # val_scores stores all needed metrics for assessing validation
                val_scores = np.zeros(
                    (len(val_metrics), len(val_loader), len(measure_classes)))
                Measure = Metric(val_metrics,
                                 onehot=softmax2onehot,
                                 classes=measure_classes,
                                 multiprocess=False)

                with torch.no_grad():
                    for val_i, (X, Y, info, W) in enumerate(val_loader):
                        X = [x.to(self.device) for x in X]
                        Y = [y.to(self.device) for y in Y]
                        W = [w.to(self.device) for w in W]

                        output = self(X)
                        # Same convention as the training loss: keep the
                        # running sum as a plain number.
                        val_loss += loss(output, Y, W).item()

                        y_true_cpu = Y[0].cpu().numpy()
                        y_pred_cpu = output[0].cpu().numpy()

                        # Record all needed metrics
                        # If batch_size > 1, Measure.all() returns an avg.
                        tmp_res = Measure.all(y_pred_cpu, y_true_cpu)
                        for i, m in enumerate(val_metrics):
                            val_scores[i, val_i] = tmp_res[m]

                # Validation loss
                val_loss /= len(val_loader)
                val_str = " Val Loss: {}".format(val_loss)

                # val_scores shape: num_metrics x num_batches x num_classes
                for i, m in enumerate(val_metrics):
                    # tmp shape: num_classes (averaged over num_batches when val != -1)
                    tmp = np.array(Measure._getMean(val_scores[i]))

                    # Mean validation value in metric m (all interesting classes)
                    tmp_val = tmp[measure_classes_mean]
                    # Note: if tmp_val is NaN, it means that the classes I am
                    # interested in (check lib/data/whatever, measure_classes_mean)
                    # were not found in the validation set.
                    tmp_val = np.mean(tmp_val[tmp_val != -1])
                    val_str += ". Val " + m + ": " + str(tmp_val)

            else:
                val_str = ""

            # ETA: average time per finished epoch times the epochs left.
            eta = " ETA: " + datetime.fromtimestamp(
                time.time() + (epochs - e) *
                (time.time() - t0) / e).strftime("%Y-%m-%d %H:%M:%S")
            log("Epoch: {}. Loss: {}.".format(e, tr_loss) + val_str + eta,
                self.out_path)

            # Save model after every epoch and drop the previous checkpoint
            torch.save(
                self.state_dict(),
                self.out_path + "model/MedicDeepLabv3Plus-model-" + str(e))
            if e > 1 and os.path.exists(self.out_path +
                                        "model/MedicDeepLabv3Plus-model-" +
                                        str(e - 1)):
                os.remove(self.out_path + "model/MedicDeepLabv3Plus-model-" +
                          str(e - 1))

            e += 1
Example #8
0
    def evaluate(self, test_loader, metrics, remove_islands, save_output=True):
        """Tests/Evaluates the NN.

           Every prediction is saved through `test_loader.dataset.save`.
           When ground truth is available the metrics are computed and,
           once all samples are processed, written to `stats.json` in
           `self.out_path`.

           Args:
            `test_loader`: DataLoader containing the test set. Batch_size = 1.
            `metrics`: Metrics to measure.
            `remove_islands`: (bool) whether to apply post-processing.
            `save_output`: (bool) whether to save the output segmentations.
                NOTE(review): this flag is not consulted in the body;
                predictions are always saved.
        """

        # Expected classes of our dataset
        measure_classes = {0: "background", 1: "contra", 2: "R_hemisphere"}

        results = {}
        self.eval()
        Measure = Metric(metrics,
                         onehot=sigmoid2onehot,
                         classes=measure_classes,
                         multiprocess=True)

        # The metric object is always closed, also when inference, saving
        # or result gathering fails, so its workers are not left behind.
        try:
            with torch.no_grad():
                for test_i, (X, Y, info, _) in enumerate(test_loader):
                    print("{}/{}".format(test_i + 1, len(test_loader)))
                    X = [x.to(self.device) for x in X]
                    Y = [y.to(self.device) for y in Y]
                    id_ = info["id"][0]

                    output = self(X)

                    y_pred_cpu = output[0].cpu().numpy()
                    y_true_cpu = Y[0].cpu().numpy()

                    if remove_islands:
                        y_pred_cpu = removeSmallIslands(y_pred_cpu, thr=20)

                    # Predictions (and GT) separate the two hemispheres
                    # combineLabels will combine these such that it creates
                    # brainmask and contra-hemisphere ROIs instead of
                    # two different hemisphere ROIs.
                    y_pred_cpu = combineLabels(y_pred_cpu)
                    # If GT was provided it measures the performance
                    if len(y_true_cpu.shape) > 1:
                        y_true_cpu = combineLabels(y_true_cpu)

                        results[id_] = Measure.all(y_pred_cpu, y_true_cpu)

                    test_loader.dataset.save(y_pred_cpu[0], info,
                                             self.out_path + id_)

            # Gather results (multiprocessing): each stored value exposes
            # `.get()` returning the computed metrics.
            for k in results:
                results[k] = results[k].get()

            if results:
                with open(self.out_path + "stats.json", "w") as f:
                    f.write(json.dumps(results))
        finally:
            # If we are using multiprocessing we need to close the pool
            Measure.close()