Example #1
def main(argv):
    # 0. Initialize Cytomine client and job
    with CytomineJob.from_cli(argv) as cj:
        cj.job.update(status=Job.RUNNING,
                      progress=0,
                      statusComment="Initialisation...")

        # 1. Create working directories on the machine:
        # - WORKING_PATH/in: input images
        # - WORKING_PATH/out: output images
        # - WORKING_PATH/ground_truth: ground truth images
        base_path = os.getenv("HOME")
        gt_suffix = "_lbl"
        working_path = os.path.join(base_path, str(cj.job.id))
        in_path = os.path.join(working_path, "in")
        out_path = os.path.join(working_path, "out")
        gt_path = os.path.join(working_path, "ground_truth")

        if not os.path.exists(working_path):
            os.makedirs(working_path)
            os.makedirs(in_path)
            os.makedirs(out_path)
            os.makedirs(gt_path)

        # 2. Download the images (first input, then ground truth image)
        cj.job.update(
            progress=1,
            statusComment="Downloading images (to {})...".format(in_path))
        image_group = ImageGroupCollection().fetch_with_filter(
            "project", cj.parameters.cytomine_id_project)

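        # Split the image group by name: images carrying the "_lbl" suffix are the
        # ground-truth masks, the others are the inputs of the workflow.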
        input_images = [i for i in image_group if gt_suffix not in i.name]
        gt_images = [i for i in image_group if gt_suffix in i.name]

        for input_image in input_images:
            input_image.download(os.path.join(in_path, "{id}.tif"))

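        # Save each ground-truth image under the id of its matching input image so
        # that input, output and ground-truth files share the same "<id>.tif" name.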
        for gt_image in gt_images:
            related_name = gt_image.name.replace(gt_suffix, '')
            related_image = [i for i in input_images if related_name == i.name]
            if len(related_image) == 1:
                gt_image.download(
                    os.path.join(gt_path,
                                 "{}.tif".format(related_image[0].id)))

        # 3. Call the image analysis workflow using the run script
        cj.job.update(progress=25, statusComment="Launching workflow...")
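        # ImageJ/Fiji is run headless inside a virtual framebuffer (xvfb-run); the
        # macro gets the input and output folders through its argument string.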
        command = "/usr/bin/xvfb-run java -Xmx6000m -cp /fiji/jars/ij-1.52d.jar ij.ImageJ --headless --console " \
                  "-macro macro.ijm \"input={}, output={}\"".format(in_path, out_path)
        return_code = call(command, shell=True,
                           cwd="/fiji")  # waits for the subprocess to return

        if return_code != 0:
            err_desc = "Failed to execute the ImageJ macro (return code: {})".format(
                return_code)
            cj.job.update(progress=50, statusComment=err_desc)
            raise ValueError(err_desc)

        # 4. Upload the annotations and labels to Cytomine (annotations are extracted
        # from the masks using the AnnotationExporter module). This step is currently
        # disabled in this example.
        # for image in cj.monitor(input_images, start=60, end=80, period=0.1,
        #                         prefix="Extracting and uploading polygons from masks"):
        #     file = "{}.tif".format(image.id)
        #     path = os.path.join(out_path, file)
        #     data = io.imread(path)
        #
        #     # extract objects
        #     slices = mask_to_objects_2d(data)
        #     print("Found {} polygons in image {}.".format(len(slices), image.id))
        #
        #     # upload
        #     collection = AnnotationCollection()
        #     for obj_slice in slices:
        #         collection.append(Annotation(
        #             location=affine_transform(obj_slice.polygon,
        #                                       [1, 0, 0, -1, 0, image.height]).wkt,
        #             id_image=image.id,
        #             id_project=cj.parameters.cytomine_id_project,
        #             property=[{"key": "index", "value": str(obj_slice.label)}]))
        #     collection.save()

        # 5. Compute the metrics
        cj.job.update(progress=80, statusComment="Computing metrics...")
        for image in cj.monitor(input_images,
                                start=80,
                                end=98,
                                period=0.1,
                                prefix="computing metrics"):
            afile = "{}.tif".format(image.id)
            pathi = os.path.join(in_path, afile)
            patho = os.path.join(out_path, afile)
            # data = io.imread(path)

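            # computemetrics comes from the workflow's benchmark utilities; it is
            # called with the input and output image paths, a problem-class code
            # ("TreTrc") and a scratch directory. Depending on the benchmark, the
            # ground-truth file in gt_path may be the intended reference rather than
            # the raw input (see the TODO below).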
            metrics, params = computemetrics(pathi, patho, "TreTrc", '/tmp')
            print('metrics for ' + pathi)
            print(metrics)

        # TODO: compute metrics:
        # in /out: output files {id}.tif
        # in /ground_truth: label files {id}.tif

        cj.job.update(progress=99, statusComment="Cleaning...")
        #        for image in input_images:
        #            os.remove(os.path.join(in_path, "{}.tif".format(image.id)))

        cj.job.update(status=Job.TERMINATED,
                      progress=100,
                      statusComment="Finished.")
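
# A minimal, hypothetical entry point for the example above (a sketch: the script is
# assumed to be launched by the Cytomine software runner, whose CLI arguments are
# consumed directly by CytomineJob.from_cli). Import paths for the helpers used above
# (computemetrics, mask_to_objects_2d, affine_transform, io.imread, ...) depend on the
# workflow's own utility modules and image libraries, so they are not listed here.
import sys

if __name__ == "__main__":
    main(sys.argv[1:])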

Example #2

def main(argv):
    with CytomineJob.from_cli(argv) as cj:
        # prepare paths
        working_path = str(Path.home())
        data_path = os.path.join(working_path, "pred_data")
        if not os.path.exists(data_path):
            os.makedirs(data_path)

        model_filename = "model.pkl"

        cj.job.update(progress=5, statusComment="Download model ...")
        model_job = Job().fetch(cj.parameters.cytomine_model_job_id)
        attached_files = AttachedFileCollection(model_job).fetch_with_filter(
            "project", cj.project.id)
        if len(attached_files) != 1:
            raise ValueError(
                "Expected exactly one file attached to the job (found {} file(s))."
                .format(len(attached_files)))
        attached_file = attached_files[0]
        if attached_file.filename != model_filename:
            raise ValueError(
                "Expected model file name is '{}' (found: '{}').".format(
                    model_filename, attached_file.filename))
        model_path = os.path.join(working_path, model_filename)
        attached_file.download(model_path)

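        # The pickle is expected to hold the bundle written by the companion training
        # job: the fitted model plus the settings needed to rebuild the feature
        # extractor (classifier, network, reduction).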
        # load model
        with open(model_path, "rb") as file:
            data = pickle.load(file)
            model = data["model"]
            classifier = data["classifier"]
            network = data["network"]
            reduction = data["reduction"]

        # load and dump annotations
        cj.job.update(progress=10, statusComment="Download annotations.")
        annotations = get_annotations(
            project_id=cj.parameters.cytomine_project_id,
            images=parse_list_or_none(cj.parameters.cytomine_images_ids),
            users=parse_list_or_none(cj.parameters.cytomine_users_ids),
            showWKT=True)

        cj.job.update(statusComment="Fetch crops.", progress=15)
        n_samples = len(annotations)
        x = np.zeros([n_samples], dtype=object)  # one crop file path per sample
        for i, annotation in cj.monitor(enumerate(annotations),
                                        start=15,
                                        end=40,
                                        prefix="Fetch crops",
                                        period=0.1):
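            # Dump each annotation crop to disk as "<annotation id>.png"; the file
            # paths are collected in x and fed to the feature extractor below.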
            file_format = os.path.join(data_path, "{id}.png")
            if not annotation.dump(dest_pattern=file_format):
                raise ValueError("Download error for annotation '{}'.".format(
                    annotation.id))
            x[i] = file_format.format(id=annotation.id)

        available_nets = {
            MODEL_RESNET50, MODEL_VGG19, MODEL_VGG16, MODEL_INCEPTION_V3,
            MODEL_INCEPTION_RESNET_V2, MODEL_MOBILE, MODEL_DENSE_NET_201,
            MODEL_NASNET_LARGE, MODEL_NASNET_MOBILE
        }

        if network not in available_nets:
            raise ValueError(
                "Invalid value (='{}'} for parameter 'network'.".format(
                    network))
        if reduction not in {"average_pooling"}:
            raise ValueError(
                "Invalid value (='{}') for parameter 'reduction'.".format(
                    reduction))
        if classifier not in {"svm"}:
            raise ValueError(
                "Invalid value (='{}') for parameter 'classifier'.".format(
                    classifier))

        # prepare network
        cj.job.update(statusComment="Load neural network '{}'".format(network),
                      progress=40)
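        # The pretrained network acts as a fixed feature extractor: crops are loaded
        # at the network's expected input size and reduced to one feature vector per
        # crop (here by average pooling).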
        features = PretrainedModelFeatures(model=network,
                                           layer="last",
                                           reduction=reduction,
                                           weights="imagenet")
        height, width, _ = features._get_input_shape(network)
        loader = ImageLoader(load_size_range=(height, height),
                             crop_size=height,
                             random_crop=False)

        cj.job.update(statusComment="Transform features.", progress=50)
        x_feat = batch_transform(loader,
                                 features,
                                 x,
                                 logger=cj.logger(start=50, end=70,
                                                  period=0.1),
                                 batch_size=128)

        cj.job.update(statusComment="Prediction with '{}'.".format(classifier),
                      progress=70)
        if hasattr(model, "n_jobs"):
            model.n_jobs = cj.parameters.n_jobs

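        # Use class probabilities when the classifier exposes them, so the predicted
        # term can be uploaded with a confidence value ("rate"); otherwise fall back
        # to plain predictions and a rate of 1.0.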
        probas = None
        if hasattr(model, "predict_proba"):
            probas = model.predict_proba(x_feat)
            y_pred = model.classes_.take(np.argmax(probas, axis=1), axis=0)
        else:
            y_pred = model.predict(x_feat)

        cj.job.update(statusComment="Upload annotations.", progress=80)
        annotation_collection = AnnotationCollection()
        for i, annotation in cj.monitor(enumerate(annotations),
                                        start=80,
                                        end=100,
                                        period=0.1,
                                        prefix="Upload annotations"):
            # Attach the predicted term; the rate is the probability of the predicted
            # class when available. The annotations are saved in one batch below, so
            # they are not saved individually here.
            annotation_collection.append(
                Annotation(location=annotation.location,
                           id_image=annotation.image,
                           id_project=annotation.project,
                           term=[int(y_pred[i])],
                           rate=float(np.max(probas[i]))
                           if probas is not None else 1.0))
        annotation_collection.save()

        cj.job.update(statusComment="Finished.", progress=100)
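
# Example #2 relies on two helpers that are not part of the listing: get_annotations
# (presumably a wrapper around an AnnotationCollection fetch filtered by project,
# images and users) and parse_list_or_none. A plausible sketch of the latter is given
# below; its exact behaviour is an assumption inferred from the calls above.
def parse_list_or_none(value):
    """Turn a comma-separated id string into a list of ints, or None when empty."""
    if value is None or str(value).strip() == "":
        return None
    return [int(v) for v in str(value).split(",")]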