示例#1
0
def create_field_from_array_like(field_name, maybe_array, annotations=None):
    """Create a pystencils ``Field`` describing ``maybe_array``.

    ``maybe_array`` may be a numpy-like array, an ``ArrayWrapper``, a
    TensorFlow tensor or a PyTorch tensor. The latter two are detected by
    their type name, so neither framework needs to be importable here.

    Args:
        field_name: Name for the new field. For TensorFlow tensors the
            tensor's own ``name`` is preferred when it can be used.
        maybe_array: The array-like object the field should describe.
        annotations: Optional dict with ``'index_dimensions'`` (default 0)
            and ``'field_type'`` (default ``FieldType.GENERIC``). A non-empty
            dict takes precedence over the metadata of an ``ArrayWrapper``.

    Returns:
        The created ``Field``. Its ``field_type`` is set, and
        ``coordinate_transform`` / ``coordinate_origin`` are copied from the
        array when the array has those attributes. This is done for every
        input kind, including TensorFlow tensors.
    """
    if annotations and isinstance(annotations, dict):
        index_dimensions = annotations.get('index_dimensions', 0)
        field_type = annotations.get('field_type', FieldType.GENERIC)
        # NOTE(review): an ArrayWrapper is not unwrapped on this path; confirm
        # that Field.create_from_numpy_array accepts it directly.
    elif isinstance(maybe_array, ArrayWrapper):
        index_dimensions = maybe_array.index_dimensions
        field_type = maybe_array.field_type
        maybe_array = maybe_array.array
    else:
        index_dimensions = 0
        field_type = FieldType.GENERIC

    type_name = str(type(maybe_array))
    if 'tensorflow' in type_name and 'Tensor' in type_name:
        def _create_fixed_size(name):
            return Field.create_fixed_size(
                name,
                maybe_array.shape,
                index_dimensions=index_dimensions,
                dtype=maybe_array.dtype.as_numpy_dtype())

        try:
            # Accessing ``.name`` fails under eager execution. The broad
            # fallback also retries with ``field_name`` if the tensor's own
            # name cannot be used to create the field.
            field = _create_fixed_size(maybe_array.name or field_name)
        except Exception:
            field = _create_fixed_size(field_name)
    else:
        if 'torch.Tensor' in type_name:
            maybe_array = _torch_tensor_to_numpy_shim(maybe_array)
        field = Field.create_from_numpy_array(field_name, maybe_array,
                                              index_dimensions)

    # Shared tail: previously skipped for TensorFlow tensors, which silently
    # dropped the requested field_type.
    field.field_type = field_type
    if hasattr(maybe_array, 'coordinate_transform'):
        field.coordinate_transform = maybe_array.coordinate_transform
    if hasattr(maybe_array, 'coordinate_origin'):
        field.coordinate_origin = maybe_array.coordinate_origin
    return field
示例#2
0
    def __init__(self, input: Field,
                 output: Field,
                 block_stencil,
                 matching_stencil,
                 compilation_target,
                 max_block_matches,
                 blockmatching_threshold,
                 hard_threshold,
                 matching_function=pystencils_reco.functions.squared_difference,
                 wiener_sigma=None,
                 **compilation_kwargs):
        """Create the intermediate fields and kernels of a block-matching filter.

        Built stages, stored on ``self``: ``block_matching``,
        ``collect_patches``, ``hard_thresholding``, ``get_wieners``,
        ``apply_wieners`` and ``aggregate``. The intermediate fields
        ``block_scores``, ``block_matched_field``, ``complex_field`` and
        ``group_weights`` are exposed as well.
        NOTE(review): the stage names suggest a BM3D-style denoiser; confirm.

        Args:
            input: Field, or anything ``coerce_to_field`` accepts, holding the
                data to be filtered.
            output: Field, or anything ``coerce_to_field`` accepts, passed to
                the aggregation stage.
            block_stencil: Offsets forming one block. Its length sets the last
                axis of ``block_matched``.
            matching_stencil: Candidate offsets for block matching. They are
                sorted by L1 norm (sum of absolute components) before use.
            compilation_target: Target used for every kernel built here.
            max_block_matches: Size of the match axis of ``block_matched``.
            blockmatching_threshold: Forwarded to ``collect_patches`` and
                ``aggregate``.
            hard_threshold: Forwarded to ``hard_thresholding``.
            matching_function: Block distance function forwarded to
                ``block_matching_integer_offsets``.
            wiener_sigma: Value used by the Wiener-coefficient stage. If
                ``None``, a typed symbol named ``'wiener_sigma'`` is used
                instead (presumably bound at kernel call time; verify).
            **compilation_kwargs: Forwarded to ``block_matching_integer_offsets``,
                ``collect_patches`` and ``aggregate``.
        """
        # Smallest L1 norm first.
        matching_stencil = sorted(
            matching_stencil,
            key=lambda offset: sum(abs(component) for component in offset))

        input_field = pystencils_reco._crazy_decorator.coerce_to_field('input_field', input)
        output_field = pystencils_reco._crazy_decorator.coerce_to_field('output_field', output)
        accumulated_weights = input_field.new_field_with_different_name('accumulated_weights')
        # Every intermediate field shares the input's element type.
        dtype = input_field.dtype.numpy_dtype

        # One score per output cell and matching offset.
        block_scores_shape = output_field.shape + (len(matching_stencil),)
        block_scores = Field.create_fixed_size('block_scores',
                                               block_scores_shape,
                                               index_dimensions=1,
                                               dtype=dtype)
        self.block_scores = block_scores

        # For each input cell: up to max_block_matches blocks of block_stencil values.
        block_matched_shape = input_field.shape + (max_block_matches, len(block_stencil))
        block_matched_field = Field.create_fixed_size('block_matched',
                                                      block_matched_shape,
                                                      index_dimensions=2,
                                                      dtype=dtype)
        self.block_matched_field = block_matched_field

        self.block_matching = block_matching_integer_offsets(input,
                                                             input,
                                                             block_scores,
                                                             block_stencil,
                                                             matching_stencil,
                                                             compilation_target,
                                                             matching_function,
                                                             **compilation_kwargs)
        self.collect_patches = collect_patches(block_scores,
                                               input,
                                               block_matched_field,
                                               block_stencil,
                                               matching_stencil,
                                               blockmatching_threshold,
                                               max_block_matches,
                                               compilation_target,
                                               **compilation_kwargs)
        # Trailing axis of size 2; presumably real/imaginary parts (verify).
        complex_field = Field.create_fixed_size('complex_field',
                                                block_matched_shape + (2,),
                                                index_dimensions=3,
                                                dtype=dtype)
        group_weights = Field.create_fixed_size('group_weights',
                                                block_scores_shape,
                                                index_dimensions=1,
                                                dtype=dtype)
        self.complex_field = complex_field
        self.group_weights = group_weights
        self.hard_thresholding = hard_thresholding(
            complex_field, group_weights, hard_threshold).compile(compilation_target)

        # Explicit None check: a falsy value the caller passes on purpose
        # (e.g. 0) must not be silently replaced by a free symbol.
        if wiener_sigma is None:
            wiener_sigma = pystencils_reco.typed_symbols('wiener_sigma', dtype)
        wiener_coefficients = Field.create_fixed_size('wiener_coefficients',
                                                      block_matched_shape,
                                                      index_dimensions=2,
                                                      dtype=dtype)
        self.get_wieners = calc_wiener_coefficients(complex_field, wiener_coefficients,
                                                    wiener_sigma).compile(compilation_target)
        self.apply_wieners = apply_wieners(complex_field, wiener_coefficients,
                                           group_weights).compile(compilation_target)

        self.aggregate = aggregate(block_scores,
                                   output,
                                   block_matched_field,
                                   block_stencil,
                                   matching_stencil,
                                   blockmatching_threshold,
                                   max_block_matches,
                                   compilation_target,
                                   group_weights,
                                   accumulated_weights,
                                   **compilation_kwargs)