예제 #1
0
def load_weights_from_hdf5(model_path, keras_model):
    """Read the per-layer weight arrays stored in a Keras HDF5 model file.

    Only layers that own weights are considered, both in the file and in
    ``keras_model``; their counts must agree.

    Args:
        model_path: Path to an HDF5 file containing a ``model_weights`` group.
        keras_model: Model whose weighted layers are counted against the
            file. If it has an ``inner_model`` attribute, the layers of that
            inner model are used instead.

    Returns:
        dict mapping every saved layer name that lists weights to the list of
        its weight arrays (``np.ndarray``), in ``weight_names`` order.

    Raises:
        ValueError: if the file and the model contain a different number of
            weighted layers.
    """
    model = keras_model.inner_model if hasattr(keras_model, "inner_model") \
        else keras_model
    filtered_layers = [layer for layer in model.layers if layer.weights]

    # Close the file on exit. np.asarray copies each dataset into memory
    # beforehand, so the returned arrays remain valid afterwards.
    with h5py.File(model_path, 'r') as h5_file:
        f = h5_file['model_weights']
        layer_names = [
            name for name in load_attributes_from_hdf5_group(f, 'layer_names')
            if load_attributes_from_hdf5_group(f[name], 'weight_names')
        ]
        if len(layer_names) != len(filtered_layers):
            raise ValueError('You are trying to load a weight file '
                             'containing ' + str(len(layer_names)) +
                             ' layers into a model with ' +
                             str(len(filtered_layers)) + ' layers.')
        weight_dict = {}
        for name in layer_names:
            g = f[name]
            weight_names = load_attributes_from_hdf5_group(g, 'weight_names')
            weight_dict[name] = [
                np.asarray(g[weight_name]) for weight_name in weight_names
            ]
    return weight_dict
예제 #2
0
def _load_weights_from_hdf5_group(f,
                                  layers):
    """
    Assign saved weights to ``layers`` purely by position (topological order).

    Only layers that actually own weights take part: the n-th weighted layer
    of the model is paired with the n-th saved layer whose ``weight_names``
    attribute is non-empty. All assignments are sent to the backend in a
    single ``K.batch_set_value`` call.

    Parameters
    ----------
    f : File
        A pointer to a HDF5 group holding a ``layer_names`` attribute.
    layers : list
        Target layers (objects exposing ``name`` and ``weights``).

    Raises
    ------
    ValueError
        If the file and the model contain a different number of weighted
        layers, or if a paired layer expects a different number of weights
        than were saved (after ``_preprocess_weights_for_loading``).
    """
    target_layers = [candidate for candidate in layers if candidate.weights]

    saved_layer_names = [
        saved_name
        for saved_name in load_attributes_from_hdf5_group(f, "layer_names")
        if load_attributes_from_hdf5_group(f[saved_name], "weight_names")
    ]
    if len(saved_layer_names) != len(target_layers):
        raise ValueError("You are trying to load a weight file "
                         "containing " + str(len(saved_layer_names)) +
                         " layers into a model with " +
                         str(len(target_layers)) + " layers.")

    assignments = []
    for position, (saved_name, target) in enumerate(
            zip(saved_layer_names, target_layers)):
        group = f[saved_name]
        stored = [
            np.asarray(group[weight_name])
            for weight_name in load_attributes_from_hdf5_group(
                group, "weight_names")
        ]
        expected = target.weights
        converted = _preprocess_weights_for_loading(layer=target,
                                                    weights=stored)
        if len(converted) != len(expected):
            raise ValueError("Layer #" + str(position) +
                             " (named `" + target.name +
                             "` in the current model) was found to "
                             "correspond to layer " + saved_name +
                             " in the save file. "
                             "However the new layer " + target.name +
                             " expects " + str(len(expected)) +
                             " weights, but the saved weights have " +
                             str(len(converted)) +
                             " elements.")
        assignments.extend(zip(expected, converted))
    K.batch_set_value(assignments)
예제 #3
0
def _load_weights_from_hdf5_group_by_name(f, layers):
    """
    Implements name-based weight loading.

    Each saved layer is matched by name against ``layers``; every model layer
    with that name receives the saved weights. A layer whose weight count
    differs from the saved count is skipped entirely, and an individual
    weight whose shape differs is skipped; both cases emit a warning.

    Parameters
    ----------
    f : File
        A pointer to a HDF5 group.
    layers : list
        Target layers (objects exposing ``name`` and ``weights``).
    """
    # New file format.
    layer_names = load_attributes_from_hdf5_group(f, "layer_names")

    # Reverse index of layer name to list of layers with name.
    index = {}
    for layer in layers:
        if layer.name:
            index.setdefault(layer.name, []).append(layer)

    weight_value_tuples = []
    for name in layer_names:
        g = f[name]
        weight_names = load_attributes_from_hdf5_group(g, "weight_names")
        saved_values = [
            np.asarray(g[weight_name]) for weight_name in weight_names
        ]

        for layer in index.get(name, []):
            symbolic_weights = layer.weights
            # Preprocess a fresh copy for each layer. Reusing the result
            # (or a list mutated in place) would hand already-converted
            # arrays to the next layer that shares this name.
            weight_values = _preprocess_weights_for_loading(
                layer=layer, weights=list(saved_values))
            if len(weight_values) != len(symbolic_weights):
                warnings.warn(
                    "Skipping loading of weights for layer {} due to mismatch in number of weights ({} vs"
                    " {}).".format(layer, len(symbolic_weights),
                                   len(weight_values)))
                continue
            # Set values, skipping individual weights whose shape differs.
            for symbolic_weight, weight_value in zip(symbolic_weights,
                                                     weight_values):
                if K.int_shape(symbolic_weight) != weight_value.shape:
                    warnings.warn(
                        "Skipping loading of weights for layer {} due to mismatch in shape ({} vs"
                        " {}).".format(layer, symbolic_weight.shape,
                                       weight_value.shape))
                else:
                    weight_value_tuples.append((symbolic_weight, weight_value))
    K.batch_set_value(weight_value_tuples)
예제 #4
0
파일: vgg16.py 프로젝트: egrassl/ts-cnn
def get_named_layer_weights_from_h5py(h5py_file):
    """Decode a Keras-saved HDF5 weight file into (layer name, weights) pairs.

    Args:
        h5py_file: Path to an HDF5 file written by Keras whose root group
            carries a ``layer_names`` attribute.

    Returns:
        list of ``(layer_name, [np.ndarray, ...])`` tuples in file order;
        layers that list no weight names are omitted.
    """
    # Open explicitly read-only. Before h5py 3.0 the default mode was 'a',
    # which creates a missing file and acquires write access.
    with h5py.File(h5py_file, 'r') as h5py_stream:
        layer_names = load_attributes_from_hdf5_group(h5py_stream,
                                                      'layer_names')

        weights_values = []
        for name in layer_names:
            group = h5py_stream[name]
            weight_names = load_attributes_from_hdf5_group(
                group, 'weight_names')
            if weight_names:
                # np.asarray copies the data, so it stays valid after close.
                weights_values.append((name, [
                    np.asarray(group[weight_name])
                    for weight_name in weight_names
                ]))
    return weights_values
예제 #5
0
파일: weights_test.py 프로젝트: sai36/SRGAN
import h5py
import os
from keras.engine import saving
# Inspect a SegSRGAN snapshot: locate the first ('gen_conv1') and last
# ('gen_1conv') generator convolutions and decide whether the saved network
# also predicts a mask (more than 2 output kernels).
with h5py.File("/proj/SegSRGAN/snapshot/SegSRGAN_epoch_99", 'r') as weights:
    print(weights)
    G = weights[list(weights.keys())[1]]
    weight_names = saving.load_attributes_from_hdf5_group(G, 'weight_names')
    print(weight_names)
    weight_values = None
    last_conv = None
    for i in weight_names:
        if 'gen_conv1' in i:
            weight_values = G[i]
        if 'gen_1conv' in i:
            last_conv = G[i]
    # Shapes are read only after the full scan, so both datasets are known.
    # They are read while the file is still open.
    if weight_values is None or last_conv is None:
        raise KeyError("Could not find both 'gen_conv1' and 'gen_1conv' "
                       "weights in the generator group")
    first_generator_kernel = weight_values.shape[4]
    nb_out_kernel = last_conv.shape[4]
    if nb_out_kernel > 2:
        fit_mask = True
        nb_classe_mask = nb_out_kernel - 2
        print("The initialize network will fit mask")
    else:
        fit_mask = False
        nb_classe_mask = 0
        print("The initialize network won't fit mask")
def segmentation(input_file_path,
                 step,
                 new_resolution,
                 path_output_cortex,
                 path_output_hr,
                 weights_path,
                 patch=None,
                 spline_order=3,
                 by_batch=False):
    """
    Super-resolve and segment an image with pre-trained SegSRGAN weights.

    :param input_file_path: path of the image to be super resolved and segmented
        (a ``.nii.gz`` file or a DICOM directory)
    :param step: the shifting step for the patches
    :param new_resolution: the new resolution wanted for the output image, either
        a scalar (used for every axis) or a sequence of 3 values
    :param path_output_cortex: output path of the segmented cortex
    :param path_output_hr: output path of the super resolution output image
    :param weights_path: the path of the file which contains the pre-trained weights for the neural network
    :param patch: the size of the patches; when None the whole padded image is used
    :param spline_order: spline order for the interpolation
    :param by_batch: to enable the by-batch processing
    :return: the string "Segmentation Done"
    :raises AssertionError: for an unsupported resolution, a step larger than the
        patch, or a patch larger than the interpolated image
    :raises ValueError: if the input path is neither a ``.nii.gz`` file nor a
        directory, or an expected convolution is missing from the weights file
    """
    # Read the architecture hyper-parameters from the weights file. Shapes are
    # read while the file is open; the context manager then closes the handle
    # before the network is built from weights_path.
    with h5py.File(weights_path, 'r') as weights:
        group_names = list(weights.keys())

        # Number of kernels of the first generator convolution.
        generator_name = group_names[1]
        generator = weights[generator_name]
        first_generator_kernel = None
        for weight_name in saving.load_attributes_from_hdf5_group(
                generator, 'weight_names'):
            if 'gen_conv1' in weight_name:
                first_generator_kernel = generator[weight_name].shape[4]
        if first_generator_kernel is None:
            raise ValueError(
                "No 'gen_conv1' weights found in {}".format(weights_path))

        # Number of kernels of the first discriminator convolution.
        discriminator = weights[group_names[0]]
        first_discriminator_kernel = None
        for weight_name in saving.load_attributes_from_hdf5_group(
                discriminator, 'weight_names'):
            if 'conv_dis_1/kernel' in weight_name:
                first_discriminator_kernel = discriminator[weight_name].shape[4]
        if first_discriminator_kernel is None:
            raise ValueError("No 'conv_dis_1/kernel' weights found in {}".format(
                weights_path))

    # Selection of the kind of network, encoded in the generator group name
    if "_nn_residual" in generator_name:

        residual_string = "_nn_residual"
        is_residual = False

    else:

        residual_string = ""
        is_residual = True

    if ('G_cond' + residual_string) == generator_name:
        is_conditional = True
        u_net_gen = False
    elif ('G_unet' + residual_string) == generator_name:
        is_conditional = False
        u_net_gen = True
    elif ('G_unet_cond' + residual_string) == generator_name:
        is_conditional = True
        u_net_gen = True
    else:
        is_conditional = False
        u_net_gen = False

    # Check resolution
    if np.isscalar(new_resolution):
        new_resolution = (new_resolution, new_resolution, new_resolution)
    else:
        if len(new_resolution) != 3:
            raise AssertionError('Resolution not supported!')

    # Read low-resolution image
    if input_file_path.endswith('.nii.gz'):
        image_instance = NIFTIReader(input_file_path)
    elif os.path.isdir(input_file_path):
        image_instance = DICOMReader(input_file_path)
    else:
        raise ValueError('Unsupported input: expected a .nii.gz file or a '
                         'DICOM directory, got {}'.format(input_file_path))

    test_image = image_instance.get_np_array()
    test_imageMinValue = float(np.min(test_image))
    test_imageMaxValue = float(np.max(test_image))
    test_imageNorm = test_image / test_imageMaxValue

    resolution = image_instance.get_resolution()
    itk_image = image_instance.itk_image

    # Scale factor per axis: current spacing / requested spacing
    up_scale = tuple(
        itema / itemb
        for itema, itemb in zip(itk_image.GetSpacing(), new_resolution))

    # spline interpolation
    interpolated_image = scipy.ndimage.zoom(test_imageNorm,
                                            zoom=up_scale,
                                            order=spline_order)

    if patch is not None:

        print("patch given")

        patch1 = patch2 = patch3 = int(patch)

        border = (int((interpolated_image.shape[0] - int(patch)) % step),
                  int((interpolated_image.shape[1] - int(patch)) % step),
                  int((interpolated_image.shape[2] - int(patch)) % step))

        border_to_add = (step - border[0], step - border[1], step - border[2])

        # pad the border so the patch/step grid covers the whole image
        padded_interpolated_image = pad3D(interpolated_image, border_to_add)

    else:
        border = (int(interpolated_image.shape[0] % 4),
                  int(interpolated_image.shape[1] % 4),
                  int(interpolated_image.shape[2] % 4))
        border_to_add = (4 - border[0], 4 - border[1], 4 - border[2])

        # pad the border so every dimension is a multiple of 4
        padded_interpolated_image = pad3D(interpolated_image, border_to_add)

        height, width, depth = np.shape(padded_interpolated_image)
        patch1 = height
        patch2 = width
        patch3 = depth

    if ((step > patch1) | (step > patch2) |
        (step > patch3)) & (patch is not None):

        raise AssertionError('The step need to be smaller than the patch size')

    if (np.shape(padded_interpolated_image)[0] <
            patch1) | (np.shape(padded_interpolated_image)[1] < patch2) | (
                np.shape(padded_interpolated_image)[2] < patch3):

        raise AssertionError(
            'The patch size need to be smaller than the interpolated image size'
        )

    # Loading weights
    segsrgan_test_instance = SegSRGAN_test(weights_path, patch1, patch2,
                                           patch3, is_conditional, u_net_gen,
                                           is_residual, first_generator_kernel,
                                           first_discriminator_kernel,
                                           resolution)

    # GAN: run the network over the image patch by patch
    print("Testing : ")
    estimated_hr_image, estimated_cortex = segsrgan_test_instance.test_by_patch(
        padded_interpolated_image, step=step, by_batch=by_batch)

    # Undo the padding applied above
    padded_estimated_hr_image = shave3D(estimated_hr_image, border_to_add)
    estimated_cortex = shave3D(estimated_cortex, border_to_add)

    # SR image
    estimated_hr_imageInverseNorm = padded_estimated_hr_image * test_imageMaxValue
    estimated_hr_imageInverseNorm[
        estimated_hr_imageInverseNorm <=
        test_imageMinValue] = test_imageMinValue  # Clear negative value
    output_image = sitk.GetImageFromArray(
        np.swapaxes(estimated_hr_imageInverseNorm, 0, 2))
    output_image.SetSpacing(new_resolution)
    output_image.SetOrigin(itk_image.GetOrigin())
    output_image.SetDirection(itk_image.GetDirection())

    sitk.WriteImage(output_image, path_output_hr)

    # Cortex segmentation
    output_cortex = sitk.GetImageFromArray(np.swapaxes(estimated_cortex, 0, 2))
    output_cortex.SetSpacing(new_resolution)
    output_cortex.SetOrigin(itk_image.GetOrigin())
    output_cortex.SetDirection(itk_image.GetDirection())

    sitk.WriteImage(output_cortex, path_output_cortex)

    return "Segmentation Done"
예제 #7
0
def load_weights_from_hdf5_group_by_name(f,
                                         layers,
                                         skip_mismatch=False,
                                         reshape=False):
    """Implements name-based weight loading.

    (instead of topological weight loading).

    Layers that have no matching name are skipped. Within a matched layer,
    each saved weight is assigned to the layer weight with the same base
    name (the part after the last ``/``). Saved weights without a same-named
    counterpart are ignored, so the layers are not required to have the
    same number of weights.

    # Arguments
        f: A pointer to a HDF5 group.
        layers: A list of target layers.
        skip_mismatch: Boolean, whether to skip (with a warning) weights
            whose shape does not match the saved weight of the same name,
            instead of raising.
        reshape: Reshape weights to fit the layer when the correct number
            of values are present but the shape does not match.

    # Raises
        ValueError: if a layer weight and the saved weight of the same name
            differ in shape and skip_mismatch=False.
    """
    if 'keras_version' in f.attrs:
        original_keras_version = f.attrs['keras_version'].decode('utf8')
    else:
        original_keras_version = '1'
    if 'backend' in f.attrs:
        original_backend = f.attrs['backend'].decode('utf8')
    else:
        original_backend = None

    # New file format.
    layer_names = load_attributes_from_hdf5_group(f, 'layer_names')

    # Reverse index of layer name to list of layers with name.
    index = {}
    for layer in layers:
        if layer.name:
            index.setdefault(layer.name, []).append(layer)

    # We batch weight value assignments in a single backend call
    # which provides a speedup in TensorFlow.
    weight_value_tuples = []
    for k, name in enumerate(layer_names):
        g = f[name]
        weight_names = load_attributes_from_hdf5_group(g, 'weight_names')
        saved_values = [
            np.asarray(g[weight_name]) for weight_name in weight_names
        ]
        # Match on the base name only (delete the layer prefix of the name).
        saved_short_names = [n.split('/')[-1] for n in weight_names]

        for layer in index.get(name, []):
            symbolic_weights = layer.weights
            symbolic_short_names = [
                w.name.split('/')[-1] for w in symbolic_weights
            ]

            # Preprocess a fresh copy for each layer so that several layers
            # sharing this name each start from the raw saved arrays.
            weight_values = preprocess_weights_for_loading(
                layer,
                list(saved_values),
                original_keras_version,
                original_backend,
                reshape=reshape)
            # The weight-count check is intentionally not performed: weights
            # are matched by name, so a layer with extra or missing weights
            # can still be partially loaded.

            # Set values.
            for i, weight_value in enumerate(weight_values):
                short_name = saved_short_names[i]
                if short_name not in symbolic_short_names:
                    continue
                symbolic_weight = symbolic_weights[
                    symbolic_short_names.index(short_name)]

                if K.int_shape(symbolic_weight) != weight_value.shape:
                    if skip_mismatch:
                        warnings.warn(
                            'Skipping loading of weights for layer {}'.
                            format(layer.name) +
                            ' due to mismatch in shape' +
                            ' ({} vs {}).'.format(symbolic_weight.shape,
                                                  weight_value.shape))
                        continue
                    raise ValueError(
                        'Layer #' + str(k) + ' (named "' + layer.name +
                        '"), weight ' + str(symbolic_weight) +
                        ' has shape {}'.format(K.int_shape(symbolic_weight)) +
                        ', but the saved weight has shape ' +
                        str(weight_value.shape) + '.')
                weight_value_tuples.append((symbolic_weight, weight_value))

    K.batch_set_value(weight_value_tuples)
예제 #8
0
def load_weights_from_hdf5_group_by_name(f,
                                         layers,
                                         skip_mismatch=False,
                                         reshape=False,
                                         consider_weight_name_match=False):
    """Implements name-based weight loading.

    (instead of topological weight loading).

    Layers that have no matching name are skipped.

    # Arguments
        f: A pointer to a HDF5 group.
        layers: A list of target layers.
        skip_mismatch: Boolean, whether to skip loading of layers
            where there is a mismatch in the number of weights,
            or a mismatch in the shape of the weights.
        reshape: Reshape weights to fit the layer when the correct number
            of values are present but the shape does not match.
        consider_weight_name_match: Boolean, whether to still load a layer
            whose number of weights does not match. In that case each layer
            variable receives any stored weight with the same base name
            (text after the last ``/`` and before ``:``) and the same shape.
            Only applicable when `skip_mismatch` = False.

    # Raises
        ValueError: in case of mismatch between provided layers
            and weights file and skip_mismatch=False.
    """
    def base_name(full_name):
        # 'scope/layer/kernel:0' -> 'kernel'
        return full_name.split('/')[-1].split(':')[0]

    if 'keras_version' in f.attrs:
        original_keras_version = f.attrs['keras_version'].decode('utf8')
    else:
        original_keras_version = '1'
    if 'backend' in f.attrs:
        original_backend = f.attrs['backend'].decode('utf8')
    else:
        original_backend = None

    # New file format.
    layer_names = load_attributes_from_hdf5_group(f, 'layer_names')

    # Reverse index of layer name to list of layers with name.
    index = {}
    for layer in layers:
        if layer.name:
            index.setdefault(layer.name, []).append(layer)

    # We batch weight value assignments in a single backend call
    # which provides a speedup in TensorFlow.
    weight_value_tuples = []
    for k, name in enumerate(layer_names):
        g = f[name]
        weight_names = load_attributes_from_hdf5_group(g, 'weight_names')
        saved_values = [
            np.asarray(g[weight_name]) for weight_name in weight_names
        ]

        for layer in index.get(name, []):
            symbolic_weights = layer.weights
            # Preprocess a fresh copy for each layer so that several layers
            # sharing this name each start from the raw saved arrays.
            weight_values = preprocess_weights_for_loading(
                layer,
                list(saved_values),
                original_keras_version,
                original_backend,
                reshape=reshape)
            if len(weight_values) != len(symbolic_weights):
                if skip_mismatch:
                    warnings.warn(
                        'Skipping loading of weights for '
                        'layer {}'.format(layer.name) + ' due to mismatch '
                        'in number of weights ({} vs {}).'.format(
                            len(symbolic_weights), len(weight_values)))
                    continue
                # (thanhnt): Allows loading if variable name match
                # (conditioned on variable shape match).
                if not consider_weight_name_match:
                    raise ValueError(
                        'Layer #' + str(k) + ' (named "' + layer.name +
                        '") expects ' + str(len(symbolic_weights)) +
                        ' weight(s), but the saved weights' + ' have ' +
                        str(len(weight_values)) + ' element(s). ' +
                        'Consider setting `consider_weight_name_match`' +
                        ' to `True` to load weights by name match.')
                warnings.warn(
                    'Mismatch in '
                    'the number of weights ({} vs {}).'.format(
                        len(symbolic_weights), len(weight_values)) +
                    ' Loading still continues for whichever model variable whose name matches that of the stored variables '
                    '(conditioned on variable shape match).')
                warning_weights = []
                for symbolic_weight in symbolic_weights:
                    symbolic_shape = K.int_shape(symbolic_weight)
                    symbolic_name = base_name(symbolic_weight.name)
                    # Look up for any stored weight with same name and shape.
                    matches = [
                        weight_value for weight_name, weight_value in zip(
                            weight_names, weight_values)
                        if base_name(weight_name) == symbolic_name and
                        weight_value.shape == symbolic_shape
                    ]
                    if matches:
                        weight_value_tuples.extend(
                            (symbolic_weight, value) for value in matches)
                    else:
                        warning_weights.append(symbolic_weight.name)
                if warning_weights:
                    warnings.warn(
                        'Skipping loading of weights of some variables for '
                        'layer {}'.format(layer.name) +
                        ' due to mismatch '
                        'in variable names or variable shapes. '
                        'The variables are {}. '.format(warning_weights) +
                        'The stored variables are {}.'.format(weight_names))
            else:
                # Set values.
                for symbolic_weight, weight_value in zip(symbolic_weights,
                                                         weight_values):
                    symbolic_shape = K.int_shape(symbolic_weight)
                    if symbolic_shape != weight_value.shape:
                        if skip_mismatch:
                            warnings.warn('Skipping loading of weights for '
                                          'layer {}'.format(layer.name) +
                                          ' due to '
                                          'mismatch in shape ({} vs {}).'.
                                          format(symbolic_weight.shape,
                                                 weight_value.shape))
                            continue
                        raise ValueError(
                            'Layer #' + str(k) + ' (named "' + layer.name +
                            '"), weight ' + str(symbolic_weight) +
                            ' has shape {}'.format(symbolic_shape) +
                            ', but the saved weight has shape ' +
                            str(weight_value.shape) + '.')
                    weight_value_tuples.append((symbolic_weight, weight_value))

    K.batch_set_value(weight_value_tuples)
예제 #9
0
def segmentation(input_file_path,
                 step,
                 new_resolution,
                 path_output_cortex,
                 path_output_hr,
                 path_output_mask,
                 weights_path,
                 interpolation_type,
                 patch=None,
                 spline_order=3,
                 by_batch=False,
                 interp='scipy'):
    """
    Super-resolve and segment an image (and optionally a mask) with SegSRGAN.

    :param input_file_path: path of the image to be super resolved and segmented
        (a ``.nii.gz`` / ``.hdr`` file or a DICOM directory)
    :param step: the shifting step for the patches
    :param new_resolution: the new resolution wanted for the output image, either
        a scalar (used for every axis) or a sequence of 3 values
    :param path_output_cortex: output path of the segmented cortex
    :param path_output_hr: output path of the super resolution output image
    :param path_output_mask: output path of the estimated mask. Only used if the
        weights were trained to also fit a mask (more than 2 output kernels).
    :param weights_path: the path of the file which contains the pre-trained weights for the neural network
    :param interpolation_type: forwarded to ``inter.Interpolation``
    :param patch: the size of the patches; when None the whole padded image is used
    :param spline_order: spline order for the interpolation
    :param by_batch: to enable the by-batch processing
    :param interp: interpolation backend, forwarded to ``inter.Interpolation``
    :return: the string "Segmentation Done"
    :raises AssertionError: for an unsupported resolution, a step larger than the
        patch, or a patch larger than the interpolated image
    :raises ValueError: if the input path is not a supported file or directory,
        or an expected convolution is missing from the weights file
    """
    # Read the architecture hyper-parameters from the weights file. Shapes are
    # read while the file is open; the context manager then closes the handle
    # before the network is built from weights_path.
    with h5py.File(weights_path, 'r') as weights:
        group_names = list(weights.keys())

        # First generator convolution -> kernel count; last generator
        # convolution -> number of output channels.
        generator_name = group_names[1]
        generator = weights[generator_name]
        first_generator_kernel = None
        nb_out_kernel = None
        for weight_name in saving.load_attributes_from_hdf5_group(
                generator, 'weight_names'):
            if 'gen_conv1' in weight_name:
                first_generator_kernel = generator[weight_name].shape[4]
            if 'gen_1conv' in weight_name:
                nb_out_kernel = generator[weight_name].shape[4]
        if first_generator_kernel is None or nb_out_kernel is None:
            raise ValueError("Weights 'gen_conv1' and 'gen_1conv' not both "
                             "found in {}".format(weights_path))

        # Number of kernels of the first discriminator convolution.
        discriminator = weights[group_names[0]]
        first_discriminator_kernel = None
        for weight_name in saving.load_attributes_from_hdf5_group(
                discriminator, 'weight_names'):
            if 'conv_dis_1/kernel' in weight_name:
                first_discriminator_kernel = discriminator[weight_name].shape[4]
        if first_discriminator_kernel is None:
            raise ValueError("No 'conv_dis_1/kernel' weights found in {}".format(
                weights_path))

    if nb_out_kernel > 2:
        fit_mask = True
        nb_classe_mask = nb_out_kernel - 2
        print("The initialize network will fit mask")
    else:
        fit_mask = False
        nb_classe_mask = 0
        print("The initialize network won't fit mask")

    # Selection of the kind of network, encoded in the generator group name
    if "_nn_residual" in generator_name:

        residual_string = "_nn_residual"
        is_residual = False

    else:

        residual_string = ""
        is_residual = True

    if ('G_cond' + residual_string) == generator_name:
        is_conditional = True
        u_net_gen = False
    elif ('G_unet' + residual_string) == generator_name:
        is_conditional = False
        u_net_gen = True
    elif ('G_unet_cond' + residual_string) == generator_name:
        is_conditional = True
        u_net_gen = True
    else:
        is_conditional = False
        u_net_gen = False

    # Check resolution
    if np.isscalar(new_resolution):
        new_resolution = (new_resolution, new_resolution, new_resolution)
    else:
        if len(new_resolution) != 3:
            raise AssertionError('Resolution not supported!')

    # Read low-resolution image
    if input_file_path.endswith('.nii.gz') or input_file_path.endswith('.hdr'):
        image_instance = NIFTIReader(input_file_path)
    elif os.path.isdir(input_file_path):
        image_instance = DICOMReader(input_file_path)
    else:
        raise ValueError('Unsupported input: expected a .nii.gz/.hdr file or '
                         'a DICOM directory, got {}'.format(input_file_path))

    test_image = image_instance.get_np_array()
    test_imageMinValue = float(np.min(test_image))

    norm_instance = Normalization(test_image)

    # index 0 selects only the normalized LR image
    test_imageNorm = norm_instance.get_normalized_image()[0]

    resolution = image_instance.get_resolution()
    itk_image = image_instance.itk_image

    # Scale factor per axis: current spacing / requested spacing
    up_scale = tuple(
        itema / itemb
        for itema, itemb in zip(itk_image.GetSpacing(), new_resolution))

    # spline interpolation (may also return an adjusted up_scale)
    interpolated_image, up_scale = inter.Interpolation(test_imageNorm, up_scale, spline_order, interp,
                                                       interpolation_type). \
        get_interpolated_image(image_instance)

    if patch is not None:

        print("patch given")

        patch1 = patch2 = patch3 = int(patch)

        border = (int((interpolated_image.shape[0] - int(patch)) % step),
                  int((interpolated_image.shape[1] - int(patch)) % step),
                  int((interpolated_image.shape[2] - int(patch)) % step))

        border_to_add = (step - border[0], step - border[1], step - border[2])

        # pad the border so the patch/step grid covers the whole image
        padded_interpolated_image = pad3D(interpolated_image, border_to_add)

    else:
        border = (int(interpolated_image.shape[0] % 4),
                  int(interpolated_image.shape[1] % 4),
                  int(interpolated_image.shape[2] % 4))
        border_to_add = (4 - border[0], 4 - border[1], 4 - border[2])

        # pad the border so every dimension is a multiple of 4
        padded_interpolated_image = pad3D(interpolated_image, border_to_add)

        height, width, depth = np.shape(padded_interpolated_image)
        patch1 = height
        patch2 = width
        patch3 = depth

    if ((step > patch1) | (step > patch2) |
        (step > patch3)) & (patch is not None):

        raise AssertionError('The step need to be smaller than the patch size')

    if (np.shape(padded_interpolated_image)[0] <
            patch1) | (np.shape(padded_interpolated_image)[1] < patch2) | (
                np.shape(padded_interpolated_image)[2] < patch3):

        raise AssertionError(
            'The patch size need to be smaller than the interpolated image size'
        )

    # Loading weights
    segsrgan_test_instance = SegSRGAN_test(weights_path,
                                           patch1,
                                           patch2,
                                           patch3,
                                           is_conditional,
                                           u_net_gen,
                                           is_residual,
                                           first_generator_kernel,
                                           first_discriminator_kernel,
                                           resolution,
                                           fit_mask=fit_mask,
                                           nb_classe_mask=nb_classe_mask)

    # GAN: run the network over the image patch by patch
    print("Testing : ")
    estimated_hr_image, estimated_seg = segsrgan_test_instance.test_by_patch(
        padded_interpolated_image,
        step=step,
        by_batch=by_batch,
        nb_classe_mask=nb_classe_mask)

    estimated_cortex = estimated_seg[0]

    # All outputs share the spacing implied by the effective up-scaling.
    output_spacing = tuple(
        np.array(itk_image.GetSpacing()) / np.array(up_scale))

    # Undo the padding applied above, then write the SR image
    padded_estimated_hr_image = shave3D(estimated_hr_image, border_to_add)
    estimated_hr_imageInverseNorm = norm_instance.get_denormalized_result_image(
        padded_estimated_hr_image)
    estimated_hr_imageInverseNorm[
        estimated_hr_imageInverseNorm <=
        test_imageMinValue] = test_imageMinValue  # Clear negative value
    output_image = sitk.GetImageFromArray(
        np.swapaxes(estimated_hr_imageInverseNorm, 0, 2))
    output_image.SetSpacing(output_spacing)
    output_image.SetOrigin(itk_image.GetOrigin())
    output_image.SetDirection(itk_image.GetDirection())
    sitk.WriteImage(output_image, path_output_hr)

    # Cortex segmentation
    estimated_cortex = shave3D(estimated_cortex, border_to_add)
    output_cortex = sitk.GetImageFromArray(np.swapaxes(estimated_cortex, 0, 2))
    output_cortex.SetSpacing(output_spacing)
    output_cortex.SetOrigin(itk_image.GetOrigin())
    output_cortex.SetDirection(itk_image.GetDirection())
    sitk.WriteImage(output_cortex, path_output_cortex)

    # Mask: label each voxel with the index of its highest-scoring mask channel
    if fit_mask:
        estimated_mask = estimated_seg[1:]
        estimated_mask_discretized = np.argmax(estimated_mask,
                                               axis=0).astype(np.float64)
        estimated_mask_discretized = shave3D(estimated_mask_discretized,
                                             border_to_add)

        output_mask = sitk.GetImageFromArray(
            np.swapaxes(estimated_mask_discretized, 0, 2))
        output_mask.SetSpacing(output_spacing)
        output_mask.SetOrigin(itk_image.GetOrigin())
        output_mask.SetDirection(itk_image.GetDirection())
        sitk.WriteImage(output_mask, path_output_mask)

    return "Segmentation Done"
예제 #10
0
def load_weights_from_hdf5_group_new(f, layers, reshape=False):
    """Implements topological (order-based) weight loading.

    Variant of the Keras order-based HDF5 loader that additionally
    tolerates 4-D weights stored with their axes in reverse order: when a
    saved 4-D array does not match the shape of the corresponding model
    weight but its fully transposed form (axes reversed) does, the
    transposed array is loaded instead.

    # Arguments
        f: A pointer to a HDF5 group.
        layers: a list of target layers.
        reshape: Reshape weights to fit the layer when the correct number
            of values are present but the shape does not match.

    # Raises
        ValueError: in case of mismatch between provided layers
            and weights file, including a weight whose shape matches
            neither directly nor (for 4-D arrays) after reversing its axes.
    """
    def _attr_to_str(value):
        # h5py < 3 returns string attributes as bytes, h5py >= 3 as str;
        # decoding unconditionally would fail on the latter.
        return value.decode('utf8') if isinstance(value, bytes) else value

    if 'keras_version' in f.attrs:
        original_keras_version = _attr_to_str(f.attrs['keras_version'])
    else:
        original_keras_version = '1'
    if 'backend' in f.attrs:
        original_backend = _attr_to_str(f.attrs['backend'])
    else:
        original_backend = None

    # Only layers that own weights take part in the positional matching.
    # NOTE(review): the original author flagged m.layers[180] (an
    # "attentive" layer) as problematic because its weights list is empty.
    filtered_layers = [layer for layer in layers if layer.weights]

    # Likewise keep only the saved layer groups that actually store weights.
    layer_names = [
        name
        for name in saving.load_attributes_from_hdf5_group(f, 'layer_names')
        if saving.load_attributes_from_hdf5_group(f[name], 'weight_names')
    ]

    if len(layer_names) != len(filtered_layers):
        raise ValueError('You are trying to load a weight file '
                         'containing ' + str(len(layer_names)) +
                         ' layers into a model with ' +
                         str(len(filtered_layers)) + ' layers.')

    # We batch weight value assignments in a single backend call
    # which provides a speedup in TensorFlow.
    weight_value_tuples = []
    for k, (name, layer) in enumerate(zip(layer_names, filtered_layers)):
        g = f[name]
        weight_names = saving.load_attributes_from_hdf5_group(g, 'weight_names')
        weight_values = [np.asarray(g[weight_name]) for weight_name in weight_names]
        symbolic_weights = layer.weights

        weight_values = saving.preprocess_weights_for_loading(layer,
                                                              weight_values,
                                                              original_keras_version,
                                                              original_backend,
                                                              reshape=reshape)

        if len(weight_values) != len(symbolic_weights):
            raise ValueError('Layer #' + str(k) +
                             ' (named "' + layer.name +
                             '" in the current model) was found to '
                             'correspond to layer ' + name +
                             ' in the save file. '
                             'However the new layer ' + layer.name +
                             ' expects ' + str(len(symbolic_weights)) +
                             ' weights, but the saved weights have ' +
                             str(len(weight_values)) +
                             ' elements.')

        # Shape workaround: a saved 4-D weight whose axes are in reverse
        # order relative to the model variable is transposed to fit.
        # NOTE(review): presumably these files come from a producer using
        # the opposite axis ordering - confirm against the weight source.
        for i, symbolic_weight in enumerate(symbolic_weights):
            expected_shape = tuple(K.int_shape(symbolic_weight))
            value = weight_values[i]
            if value.shape == expected_shape:
                continue
            # For a 4-D array, .T is identical to
            # np.moveaxis(value, (0, 1, 2, 3), (3, 2, 1, 0)).
            if value.ndim == 4 and value.T.shape == expected_shape:
                weight_values[i] = value.T
            else:
                raise ValueError('Layer #' + str(k) +
                                 ' (named "' + layer.name +
                                 '" in the current model): weight #' +
                                 str(i) + ' expects shape ' +
                                 str(expected_shape) +
                                 ' but the saved weight has shape ' +
                                 str(value.shape) + '.')

        weight_value_tuples += zip(symbolic_weights, weight_values)

    K.batch_set_value(weight_value_tuples)
    print("Procedimiento weights_proc.py finalizado")