Example 1
    def draw_inference_from_image(self):

        log.basicConfig(format="[ %(levelname)s ] %(message)s", level=log.INFO, stream=sys.stdout)
        # Plugin initialization for specified device and load extensions library if specified
        log.info("Creating Inference Engine")
        ie = IECore()
        if self.extention_lib_path and 'CPU' in self.device:
            ie.add_extension(self.extention_lib_path, "CPU")
        # Read IR
        log.info("Loading network files:\n\t{}\n\t{}".format(self.model_xml, self.model_path))
        net = IENetwork(model=self.model_xml, weights=self.model_path)

        if "CPU" in self.device:
            supported_layers = ie.query_network(net, "CPU")
            not_supported_layers = [l for l in net.layers.keys() if l not in supported_layers]
            if len(not_supported_layers) != 0:
                log.error("Following layers are not supported by the plugin for specified device {}:\n {}".
                          format(self.device, ', '.join(not_supported_layers)))
                log.error("Please try to specify cpu extensions library path in config")
                sys.exit(1)

        assert len(net.inputs.keys()) == 1, "Sample supports only single input topologies"
        assert len(net.outputs) == 1, "Sample supports only single output topologies"

        log.info("Preparing input blobs")
        input_blob = next(iter(net.inputs))
        out_blob = next(iter(net.outputs))
        net.batch_size = 1

        # Read and pre-process input images
        n, c, h, w = net.inputs[input_blob].shape
        
        
        image = cv2.imread(self.input_stream)
        initial_h, initial_w = image.shape[:2]

        if image.shape[:-1] != (h, w):
            log.warning("Image {} is resized from {} to {}".format(self.input_stream, image.shape[:-1], (h, w)))
            input_image = cv2.resize(image, (w, h))
        else:
            input_image = image
        input_image = input_image.transpose((2, 0, 1))  # Change data layout from HWC to CHW

        # Loading model to the plugin
        log.info("Loading model to the plugin")
        exec_net = ie.load_network(network=net, device_name=self.device)

        if self.labels:
            with open(self.labels, 'r') as f:
                labels_map = [x.strip() for x in f]
        else:
            labels_map = None

        # Start sync inference
        log.info("Starting inference in synchronous mode")
        res = exec_net.infer(inputs={input_blob: input_image})

        # Processing output blob
        log.info("Processing output blob")
        res = res[out_blob]
        
        detections = list()
        for obj in res[0][0]:
            if obj[2] > self.prob_thresh:
                detection_data = dict()
                xmin = int(obj[3] * initial_w)
                ymin = int(obj[4] * initial_h)
                xmax = int(obj[5] * initial_w)
                ymax = int(obj[6] * initial_h)
                class_id = int(obj[1])
                detection_data['class'] = class_id
                detection_data['bbox'] = [(xmin, ymin), (xmax, ymax)]
                detections.append(detection_data)
                # Draw bounding box and label/class_id
                color = (min(class_id * 12.5, 255),
                         min(class_id * 7, 255),
                         min(class_id * 5, 255))
                cv2.rectangle(image, (xmin, ymin), (xmax, ymax), color, 2)
                det_label = labels_map[class_id] if labels_map else str(class_id)
                cv2.putText(image, det_label + ' ' + str(round(obj[2] * 100, 1)) + ' %', (xmin, ymin - 7),
                            cv2.FONT_HERSHEY_COMPLEX, 0.6, color, 1)
        # Comment out the next two lines to disable rendering of the detection result
        cv2.imshow("Detection Result(s)", image)
        cv2.waitKey(0)
        return detections
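A minimal driver for this method might look as follows. The ObjectDetector wrapper and its constructor arguments are assumptions for illustration; they simply supply the attributes the method reads (model_xml, model_path, device, input_stream, prob_thresh, labels, extention_lib_path).

# Hypothetical wrapper (names assumed, not part of the original sample)
class ObjectDetector:
    def __init__(self, model_xml, model_path, input_stream, device='CPU',
                 prob_thresh=0.5, labels=None, extention_lib_path=None):
        self.model_xml = model_xml                    # path to the IR .xml file
        self.model_path = model_path                  # path to the IR .bin weights
        self.input_stream = input_stream              # path to the input image
        self.device = device
        self.prob_thresh = prob_thresh
        self.labels = labels                          # optional labels file path
        self.extention_lib_path = extention_lib_path

    # draw_inference_from_image (shown above) would be defined here

detector = ObjectDetector('model.xml', 'model.bin', 'test.jpg')
detections = detector.draw_inference_from_image()  # [{'class': ..., 'bbox': [(x1, y1), (x2, y2)]}, ...]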
Example 2
def main():
    log.basicConfig(format="[ %(levelname)s ] %(message)s",
                    level=log.INFO,
                    stream=sys.stdout)
    args = build_argparser().parse_args()
    model_xml = args.model
    model_bin = os.path.splitext(model_xml)[0] + ".bin"

    preprocess_times = collections.deque()
    infer_times = collections.deque()
    postprocess_times = collections.deque()

    ROIfile = open("ROIs.txt", "w")
    # output stored here, view with ROIviewer

    # Plugin initialization for specified device and load extensions library if specified
    log.info("Initializing plugin for {} device...".format(args.device))
    plugin = IEPlugin(device=args.device, plugin_dirs=args.plugin_dir)
    if args.cpu_extension and 'CPU' in args.device:
        plugin.add_cpu_extension(args.cpu_extension)

    # Read IR
    log.info("Reading IR...")
    net = IENetwork(model=model_xml, weights=model_bin)

    if plugin.device == "CPU":
        supported_layers = plugin.get_supported_layers(net)
        not_supported_layers = [
            l for l in net.layers.keys() if l not in supported_layers
        ]
        if len(not_supported_layers) != 0:
            log.error(
                "Following layers are not supported by the plugin for specified device {}:\n {}"
                .format(plugin.device, ', '.join(not_supported_layers)))
            log.error(
                "Please try to specify cpu extensions library path in demo's command line parameters using -l "
                "or --cpu_extension command line argument")
            sys.exit(1)

    # Set batch size
    net.batch_size = args.b
    batchSize = net.batch_size
    frameLimit = args.fr
    assert len(net.inputs.keys()) == 1, "Demo supports only single input topologies"
    assert len(net.outputs) == 1, "Demo supports only single output topologies"
    input_blob = next(iter(net.inputs))
    out_blob = next(iter(net.outputs))
    log.info("Loading IR to the plugin...")
    exec_net = plugin.load(network=net, num_requests=2)

    # Read and pre-process input image
    n, c, h, w = net.inputs[input_blob].shape
    output_dims = net.outputs[out_blob].shape
    infer_width = w
    infer_height = h
    num_channels = c
    channel_size = infer_width * infer_height
    full_image_size = channel_size * num_channels

    print("inputdims=", w, h, c, n)
    print("outputdims=", output_dims[3], output_dims[2], output_dims[1],
          output_dims[0])
    if int(output_dims[3]) > 1:
        print("SSD Mode")
        output_mode = output_mode_type.SSD_MODE
    else:
        print("Single Classification Mode")
        output_mode = output_mode_type.CLASSIFICATION_MODE
        output_data_size = int(output_dims[2]) * int(output_dims[1]) * int(
            output_dims[0])
    del net
    if args.input == 'cam':
        input_stream = 0
    else:
        input_stream = args.input
        assert os.path.isfile(args.input), "Specified input file doesn't exist"
    if args.labels:
        with open(args.labels, 'r') as f:
            labels_map = [x.strip() for x in f]
    else:
        labels_map = None

    cap = cv2.VideoCapture(input_stream)

    cur_request_id = 0
    next_request_id = 1

    is_async_mode = True
    if is_async_mode:
        log.info("Starting inference in async mode...")
    else:
        log.info("Starting inference in sync mode...")

    render_time = 0

    framenum = 0
    process_more_frames = True
    frames_in_output = batchSize

    while process_more_frames:
        time1 = time.time()
        for mb in range(0, batchSize):
            ret, frame = cap.read()
            if not ret or (framenum >= frameLimit):
                process_more_frames = False
                frames_in_output = mb

            if not process_more_frames:
                break

            # convert image to blob
            # Fill input tensor with planes. First b channel, then g and r channels
            in_frame = cv2.resize(frame, (w, h))
            in_frame = in_frame.transpose((2, 0, 1))  # Change data layout from HWC to CHW

        time2 = time.time()
        diffPreProcess = time2 - time1
        if process_more_frames:
            preprocess_times.append(diffPreProcess * 1000)

            # Main sync point:
            # in the truly Async mode we start the NEXT infer request, while waiting for the CURRENT to complete
            # in the regular mode we start the CURRENT request and immediately wait for its completion
            inf_start = time.time()
            if is_async_mode:
                exec_net.start_async(request_id=next_request_id,
                                     inputs={input_blob: in_frame})
            else:
                exec_net.start_async(request_id=cur_request_id,
                                     inputs={input_blob: in_frame})
            if exec_net.requests[cur_request_id].wait(-1) == 0:
                inf_end = time.time()
                det_time = inf_end - inf_start
                infer_times.append(det_time * 1000)
                time1 = time.time()

                for mb in range(0, batchSize):
                    if framenum >= frameLimit:
                        process_more_frames = False
                        break

                    # Parse detection results of the current request
                    res = exec_net.requests[cur_request_id].outputs[out_blob]
                    for obj in res[0][0]:
                        # Write into ROIs.txt only objects whose probability exceeds the specified threshold
                        if obj[2] > args.prob_threshold:
                            confidence = obj[2]
                            locallabel = obj[1] - 1
                            print(str(0),
                                  str(framenum),
                                  str(locallabel),
                                  str(confidence),
                                  str(obj[3]),
                                  str(obj[4]),
                                  str(obj[5]),
                                  str(obj[6]),
                                  file=ROIfile)

                    sys.stdout.write("\rframenum:" + str(framenum + 1))
                    sys.stdout.flush()
                    render_start = time.time()
                    framenum = framenum + 1

                time2 = time.time()
                diffPostProcess = time2 - time1
                postprocess_times.append(diffPostProcess * 1000)

            if is_async_mode:
                cur_request_id, next_request_id = next_request_id, cur_request_id

    print("\n")
    preprocesstime = sum(preprocess_times)
    inferencetime = sum(infer_times)
    postprocesstime = sum(postprocess_times)

    print("Preprocess: {:.2f} ms/frame".format(
        preprocesstime / (len(preprocess_times) * batchSize)))
    print("Inference: {:.2f} ms/frame ".format(inferencetime /
                                               (len(infer_times) * batchSize)))
    print("Postprocess: {:.2f} ms/frame".format(
        postprocesstime / (len(postprocess_times) * batchSize)))

    ROIfile.close()
    del exec_net
    del plugin
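The heart of this demo is the double-buffered swap between two infer requests: while the current request is being waited on, the next one is already running. A stripped-down sketch of just that pattern, assuming exec_net was loaded with num_requests=2 and frames is an iterable of preprocessed CHW blobs:

# Minimal ping-pong sketch (exec_net, input_blob, out_blob, frames assumed from context)
cur_request_id, next_request_id = 0, 1
for in_frame in frames:
    # Start inference for the NEXT frame...
    exec_net.start_async(request_id=next_request_id, inputs={input_blob: in_frame})
    # ...while waiting for the CURRENT request to complete
    if exec_net.requests[cur_request_id].wait(-1) == 0:
        res = exec_net.requests[cur_request_id].outputs[out_blob]
        # ...parse res here...
    cur_request_id, next_request_id = next_request_id, cur_request_id  # swap buffers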
Example 3
def main():
    log.basicConfig(format='[ %(levelname)s ] %(message)s',
                    level=log.INFO,
                    stream=sys.stdout)
    args = parse_args()

    # ---------------------------Step 1. Initialize inference engine core--------------------------------------------------
    log.info('Creating Inference Engine')
    ie = IECore()

    # ---------------------------Step 2. Read a model in OpenVINO Intermediate Representation------------------------------
    log.info(
        f'Loading the network using ngraph function with weights from {args.model}'
    )
    ngraph_function = create_ngraph_function(args)
    net = IENetwork(ngraph.impl.Function.to_capsule(ngraph_function))

    # ---------------------------Step 3. Configure input & output----------------------------------------------------------
    log.info('Configuring input and output blobs')
    # Get names of input and output blobs
    input_blob = next(iter(net.input_info))
    out_blob = next(iter(net.outputs))

    # Set input and output precision manually
    net.input_info[input_blob].precision = 'U8'
    net.outputs[out_blob].precision = 'FP32'

    # Set the batch size equal to the number of input images
    net.batch_size = len(args.input)

    # ---------------------------Step 4. Loading model to the device-------------------------------------------------------
    log.info('Loading the model to the plugin')
    exec_net = ie.load_network(network=net, device_name=args.device)

    # ---------------------------Step 5. Create infer request--------------------------------------------------------------
    # The load_network() method of the IECore class, called with a specified number of requests (default 1), returns
    # an ExecutableNetwork instance that stores the infer requests, so they were already created in the previous step.

    # ---------------------------Step 6. Prepare input---------------------------------------------------------------------
    n, c, h, w = net.input_info[input_blob].input_data.shape
    input_data = np.ndarray(shape=(n, c, h, w))

    for i in range(n):
        image = read_image(args.input[i])

        light_pixel_count = np.count_nonzero(image > 127)
        dark_pixel_count = np.count_nonzero(image < 127)
        is_light_image = (light_pixel_count - dark_pixel_count) > 0

        if is_light_image:
            log.warning(
                f'Image {args.input[i]} is inverted to white over black')
            image = cv2.bitwise_not(image)

        if image.shape != (h, w):
            log.warning(
                f'Image {args.input[i]} is resized from {image.shape} to {(h, w)}'
            )
            image = cv2.resize(image, (w, h))

        input_data[i] = image

    # ---------------------------Step 7. Do inference----------------------------------------------------------------------
    log.info('Starting inference in synchronous mode')
    res = exec_net.infer(inputs={input_blob: input_data})

    # ---------------------------Step 8. Process output--------------------------------------------------------------------
    # Generate a label list
    if args.labels:
        with open(args.labels, 'r') as f:
            labels = [line.split(',')[0].strip() for line in f]

    res = res[out_blob]

    for i in range(n):
        probs = res[i]
        # Get an array of args.number_top class IDs in descending order of probability
        top_n_indexes = np.argsort(probs)[-args.number_top:][::-1]

        header = 'classid probability'
        header = header + ' label' if args.labels else header

        log.info(f'Image path: {args.input[i]}')
        log.info(f'Top {args.number_top} results: ')
        log.info(header)
        log.info('-' * len(header))

        for class_id in top_n_indexes:
            probability_indent = ' ' * (len('classid') - len(str(class_id)) + 1)
            label_indent = ' ' * (len('probability') - 8) if args.labels else ''
            label = labels[class_id] if args.labels else ''
            log.info(
                f'{class_id}{probability_indent}{probs[class_id]:.7f}{label_indent}{label}'
            )
        log.info('')


    # ----------------------------------------------------------------------------------------------------------------------
    log.info('This sample is an API example, '
             'for any performance measurements please use the dedicated benchmark_app tool\n')
    return 0
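The top-N selection in step 8 relies on np.argsort returning indices in ascending order of probability, so the last number_top indices, reversed, are the class IDs in descending order. A standalone check:

import numpy as np

probs = np.array([0.05, 0.70, 0.10, 0.15])
number_top = 2
top_n_indexes = np.argsort(probs)[-number_top:][::-1]
print(top_n_indexes)         # [1 3] -> class 1 (0.70) first, then class 3 (0.15)
print(probs[top_n_indexes])  # [0.7  0.15]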
Example 4
def main():
    log.basicConfig(format="[ %(levelname)s ] %(message)s", level=log.INFO, stream=sys.stdout)
    args = build_argparser().parse_args()
    model_xml = args.model
    model_bin = os.path.splitext(model_xml)[0] + ".bin"

    # Plugin initialization for specified device and load extensions library if specified
    plugin = IEPlugin(device=args.device, plugin_dirs=args.plugin_dir)

    # Configure plugin to support dynamic batch size
    plugin.set_config({"DYN_BATCH_ENABLED": "YES"})

    # Load cpu_extensions library if specified
    if args.cpu_extension and 'CPU' in args.device:
        plugin.add_cpu_extension(args.cpu_extension)

    # Read IR
    log.info("Loading network files:\n\t{}\n\t{}".format(model_xml, model_bin))
    net = IENetwork(model=model_xml, weights=model_bin)

    # Check for unsupported layers if the device is 'CPU'
    if plugin.device == "CPU":
        unsupported_layers = [layer for layer in net.layers if layer not in plugin.get_supported_layers(net)]
        if len(unsupported_layers) != 0:
            log.error("Following layers are not supported by the plugin for specified device {}:\n {}".
                      format(plugin.device, ', '.join(unsupported_layers)))
            log.error("Please try to specify cpu extensions library path in sample's command line parameters using -l "
                      "or --cpu_extension command line argument")
            sys.exit(1)

    assert len(net.inputs.keys()) == 1, "Sample supports only single input topologies"

    log.info("Preparing input blobs")
    input_blob = next(iter(net.inputs))

    # Set max batch size
    inputs_count = len(args.input)
    if args.max_batch < inputs_count:
        log.warning("Defined max_batch size {} less than input images count {}."
                    "\n\t\t\tInput images count will be used as max batch size".format(args.max_batch, inputs_count))
    net.batch_size = max(args.max_batch, inputs_count)

    # Create numpy array for the max_batch size images
    n, c, h, w = net.inputs[input_blob].shape
    images = np.zeros(shape=(n, c, h, w))

    # Read and pre-process input images
    for i in range(inputs_count):
        image = cv2.imread(args.input[i])
        if image.shape[:-1] != (h, w):
            log.warning("Image {} is resized from {} to {}".format(args.input[i], image.shape[:-1], (h, w)))
            image = cv2.resize(image, (w, h))
        image = image.transpose((2, 0, 1))  # Change data layout from HWC to CHW
        images[i] = image
    log.info("Batch size is {}".format(n))

    # Loading model to the plugin
    log.info("Loading model to the plugin")
    exec_net = plugin.load(network=net)
    del net

    def infer():
        for i in range(args.number_iter):
            t0 = time()
            exec_net.infer(inputs={input_blob: images})
            infer_time.append((time() - t0) * 1000)
        log.info("Average running time of one iteration: {} ms".format(np.average(np.asarray(infer_time))))
        if args.perf_counts:
            perf_counts = exec_net.requests[0].get_perf_counts()
            log.info("Performance counters:")
            print("{:<70} {:<15} {:<15} {:<15} {:<10}".format('name', 'layer_type', 'exet_type', 'status',
                                                              'real_time, us'))
            for layer, stats in perf_counts.items():
                print("{:<70} {:<15} {:<15} {:<15} {:<10}".format(layer, stats['layer_type'], stats['exec_type'],
                                                                  stats['status'], stats['real_time']))

    # Start sync inference with full batch size
    log.info(
        "Starting inference with full batch {} ({} iterations)".format(n, args.number_iter)
    )
    infer_time = []
    infer()

    # Set batch size dynamically for the infer request and start sync inference
    infer_time = []
    exec_net.requests[0].set_batch(inputs_count)
    log.info("Starting inference with dynamically defined batch {} for the 2nd infer request ({} iterations)".format(
        inputs_count, args.number_iter))
    infer()

    del exec_net
    del plugin
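The dynamic-batch flow above reduces to three calls: enable DYN_BATCH_ENABLED on the plugin, set net.batch_size to the maximum batch the request must hold, then shrink the effective batch per request with set_batch(). A condensed sketch with placeholder paths and sizes:

# Condensed dynamic-batch sketch (model paths and batch sizes are placeholders)
plugin = IEPlugin(device="CPU")
plugin.set_config({"DYN_BATCH_ENABLED": "YES"})  # allow per-request batch changes
net = IENetwork(model="model.xml", weights="model.bin")
net.batch_size = 8                               # maximum batch the request can hold
exec_net = plugin.load(network=net)
exec_net.requests[0].set_batch(3)                # only the first 3 images are processed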
Example 5
def main():
    args = build_argparser().parse_args()

    model_xml = args.model
    model_bin = os.path.splitext(model_xml)[0] + ".bin"

    # Plugin initialization for Movidius stick
    plugin = IEPlugin(device="MYRIAD")

    # Read IR
    log.info("Loading network files:\n\t{}\n\t{}".format(model_xml, model_bin))
    net = IENetwork(model=model_xml, weights=model_bin)

    log.info("Preparing input blobs")
    input_blob = next(iter(net.inputs))
    out_blob = next(iter(net.outputs))
    net.batch_size = 1  # Single image input, so the batch size is 1

    # Read and pre-process input images
    image = cv2.imread(args.image).astype(
        np.float16)  # Will have to load in range [0-1] for resizing prob maps?
    image = image.transpose((2, 0, 1))  # Change data layout from HWC to CHW

    # Reshape input layer for image
    net.reshape(
        {input_blob: (1, image.shape[0], image.shape[1], image.shape[2])})

    # Loading model to the plugin
    log.info("Loading model to the plugin")
    exec_net = plugin.load(network=net)
    del net
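Note how the reshape tuple is built: after the HWC to CHW transpose, image.shape is already (c, h, w), so prepending 1 yields the NCHW shape that net.reshape() expects. In plain NumPy:

import numpy as np

image = np.zeros((480, 640, 3), dtype=np.float16)  # HWC, as returned by cv2.imread
image = image.transpose((2, 0, 1))                 # now CHW: (3, 480, 640)
nchw = (1,) + image.shape                          # (1, 3, 480, 640) for net.reshape
print(nchw)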
Example 6
def _run_in_batches(f, data_dict, out, batch_size):
    # Reconstructed signature: the original function header was clipped from this
    # snippet; f maps a dict of equal-length batch slices to a batch of outputs
    num_batches = len(out) // batch_size

    s, e = 0, 0
    for i in range(num_batches):
        s, e = i * batch_size, (i + 1) * batch_size
        batch_data_dict = {k: v[s:e] for k, v in data_dict.items()}
        out[s:e] = f(batch_data_dict)
    if e < len(out):
        batch_data_dict = {k: v[e:] for k, v in data_dict.items()}
        out[e:] = f(batch_data_dict)


modelname = 'resources/networks/mars-small128'

# OV configuration
ov_net = IENetwork(model=modelname + '.xml', weights=modelname + '.bin')
ov_net.batch_size = all_batch_size
ov_plugin = IEPlugin(device='CPU')

# TF configuration
tf_session = tf.Session()
with tf.gfile.GFile(modelname + '.pb', 'rb') as gfile:
    tf_graph = tf.GraphDef()
    tf_graph.ParseFromString(gfile.read())
tf.import_graph_def(tf_graph, name='net')
tf_input_node = tf.get_default_graph().get_tensor_by_name('net/images:0')
tf_output_node = tf.get_default_graph().get_tensor_by_name('net/features:0')

# ?x128x64x3
testinput = np.random.random_sample((all_batch_size, 128, 64, 3))
testinput2 = testinput[:, :, :, ::-1]
print(testinput - testinput2)
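A quick self-contained check of the batching helper at the top of this example, using a dummy f that doubles its single input; the tail smaller than batch_size is handled by the final if:

import numpy as np

data = np.arange(10, dtype=np.float32).reshape(10, 1)
out = np.zeros_like(data)

def f(batch):
    return batch['x'] * 2  # stand-in for a real per-batch inference call

_run_in_batches(f, {'x': data}, out, batch_size=4)  # batches of 4, 4, then a tail of 2
print(out.ravel())  # [ 0.  2.  4. ... 18.]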
Example 7
def main():
    args = parse_arguments()

    # --------------------------------- 1. Load Plugin for inference engine ---------------------------------
    logger.info("Creating Inference Engine")
    ie = IECore()

    if 'CPU' in args.target_device:
        if args.path_to_extension:
            ie.add_extension(args.path_to_extension, "CPU")
        if args.number_threads is not None:
            ie.set_config({'CPU_THREADS_NUM': str(args.number_threads)}, "CPU")
    elif 'GPU' in args.target_device:
        if args.path_to_cldnn_config:
            ie.set_config({'CONFIG_FILE': args.path_to_cldnn_config}, "GPU")
            logger.info("GPU extensions is loaded {}".format(
                args.path_to_cldnn_config))
    else:
        raise AttributeError(
            "Device {} does not support 3D convolution. "
            "Please use CPU, GPU or HETERO:*CPU*, HETERO:*GPU*".format(args.target_device))

    logger.info("Device is {}".format(args.target_device))
    version = ie.get_versions(args.target_device)[args.target_device]
    version_str = "{}.{}.{}".format(version.major, version.minor,
                                    version.build_number)
    logger.info("Plugin version is {}".format(version_str))

    # --------------------- 2. Read IR Generated by ModelOptimizer (.xml and .bin files) ---------------------

    xml_filename = os.path.abspath(args.path_to_model)
    bin_filename = os.path.abspath(os.path.splitext(xml_filename)[0] + '.bin')

    ie_network = IENetwork(xml_filename, bin_filename)

    input_info = ie_network.inputs
    if len(input_info) == 0:
        raise AttributeError('No inputs info is provided')
    elif len(input_info) != 1:
        raise AttributeError("only one input layer network is supported")

    input_name = next(iter(input_info))
    out_name = next(iter(ie_network.outputs))

    if args.shape:
        logger.info("Reshape of network from {} to {}".format(
            input_info[input_name].shape, args.shape))
        ie_network.reshape({input_name: args.shape})
        input_info = ie_network.inputs

    # ---------------------------------------- 3. Preparing input data ----------------------------------------
    logger.info("Preparing inputs")

    if len(input_info[input_name].shape) != 5:
        raise AttributeError(
            "Incorrect shape {} for 3d convolution network".format(args.shape))

    n, c, d, h, w = input_info[input_name].shape
    ie_network.batch_size = n

    if not os.path.exists(args.path_to_input_data):
        raise AttributeError("Path to input data: '{}' does not exist".format(
            args.path_to_input_data))

    input_type = get_input_type(args.path_to_input_data)
    is_nifti_data = (input_type == NIFTI_FILE or input_type == NIFTI_FOLDER)

    if input_type == NIFTI_FOLDER:
        series_name = find_series_name(args.path_to_input_data)
        original_data, data_crop, affine, original_size, bbox = \
            read_image(args.path_to_input_data, data_name=series_name, sizes=(d, h, w),
                       mri_sequence_order=args.mri_sequence, full_intensities_range=args.full_intensities_range)

    elif input_type == NIFTI_FILE:
        original_data, data_crop, affine, original_size, bbox = \
            read_image(args.path_to_input_data, data_name=args.path_to_input_data, sizes=(d, h, w), is_series=False,
                       mri_sequence_order=args.mri_sequence, full_intensities_range=args.full_intensities_range)
    else:
        data_crop = np.zeros(shape=(n, c, d, h, w), dtype=float)
        im_seq = ImageSequence.Iterator(Image.open(args.path_to_input_data))
        for i, page in enumerate(im_seq):
            im = np.array(page).reshape(h, w, c)
            for channel in range(c):
                data_crop[:, channel, i, :, :] = im[:, :, channel]
        original_data = data_crop
        original_size = original_data.shape[-3:]

    test_im = {input_name: data_crop}

    # ------------------------------------- 4. Loading model to the plugin -------------------------------------
    logger.info("Loading model to the plugin")
    executable_network = ie.load_network(network=ie_network,
                                         device_name=args.target_device)
    del ie_network

    # ---------------------------------------------- 5. Do inference --------------------------------------------
    logger.info("Start inference")
    start_time = datetime.now()
    res = executable_network.infer(test_im)
    infer_time = datetime.now() - start_time
    logger.info("Finish inference")
    logger.info("Inference time is {}".format(infer_time))

    # ---------------------------- 6. Processing of the received inference results ------------------------------
    result = res[out_name]
    batch, channels, out_d, out_h, out_w = result.shape

    list_img = list()
    list_seg_result = list()

    logger.info("Processing of the received inference results is started")
    start_time = datetime.now()
    for batch, data in enumerate(result):
        seg_result = np.zeros(shape=original_size, dtype=np.uint8)
        if data.shape[1:] != original_size:
            x = bbox[1] - bbox[0]
            y = bbox[3] - bbox[2]
            z = bbox[5] - bbox[4]
            out_result = np.zeros(shape=((channels, ) + original_size),
                                  dtype=float)
            out_result[:,bbox[0]:bbox[1], bbox[2]:bbox[3], bbox[4]:bbox[5]] = \
                resample_np(data, (channels, x, y, z), 1)
        else:
            out_result = data

        if channels == 1:
            reshaped_data = out_result.reshape(original_size[0],
                                               original_size[1],
                                               original_size[2])
            mask = reshaped_data[:, :, :] > 0.5
            reshaped_data[mask] = 1
            seg_result = reshaped_data.astype(int)
        elif channels == 4:
            seg_result = np.argmax(out_result, axis=0).astype(int)
        elif channels == 3:
            res = out_result > 0.5
            wt = res[0]
            tc = res[1]
            et = res[2]

            seg_result[wt] = 2
            seg_result[tc] = 1
            seg_result[et] = 3

        im = np.stack([original_data[batch, 0, :, :, :],
                       original_data[batch, 0, :, :, :],
                       original_data[batch, 0, :, :, :]], axis=3)

        im = 255 * (im - im.min()) / (im.max() - im.min())
        color_seg_frame = np.zeros(im.shape, dtype=np.uint8)
        for idx, c in enumerate(CLASSES_COLOR_MAP):
            color_seg_frame[seg_result[:, :, :] == idx, :] = np.array(
                c, dtype=np.uint8)
        mask = seg_result[:, :, :] > 0
        im[mask] = color_seg_frame[mask]

        for k in range(im.shape[2]):
            if is_nifti_data:
                list_img.append(
                    Image.fromarray(im[:, :, k, :].astype('uint8'), 'RGB'))
            else:
                list_img.append(
                    Image.fromarray(im[k, :, :, :].astype('uint8'), 'RGB'))

        if args.output_nifti and is_nifti_data:
            list_seg_result.append(seg_result)

    result_processing_time = datetime.now() - start_time
    logger.info("Processing of the received inference results is finished")
    logger.info("Processing time is {}".format(result_processing_time))

    # --------------------------------------------- 7. Save output -----------------------------------------------
    tiff_output_name = os.path.join(args.path_to_output, 'output.tiff')
    Image.new('RGB', (original_data.shape[3], original_data.shape[2])).save(
        tiff_output_name, append_images=list_img, save_all=True)
    logger.info("Result tiff file was saved to {}".format(tiff_output_name))

    if args.output_nifti and is_nifti_data:
        for idx, seg_res in enumerate(list_seg_result):
            nii_filename = os.path.join(
                args.path_to_output,
                'output_{}.nii.gz'.format(idx))
            nib.save(nib.Nifti1Image(seg_res, affine=affine), nii_filename)
            logger.info(
                "Result nifti file was saved to {}".format(nii_filename))