def create_field_from_array_like(field_name, maybe_array, annotations=None):
    """Create a pystencils ``Field`` that describes an array-like object.

    ``index_dimensions`` and ``field_type`` are resolved in this order:

    1. from ``annotations`` when it is a non-empty ``dict`` (keys
       ``'index_dimensions'`` and ``'field_type'``, both optional);
    2. from ``maybe_array`` when it is an ``ArrayWrapper``, which is then
       unwrapped to its ``array``;
    3. otherwise ``0`` and ``FieldType.GENERIC``.

    TensorFlow tensors become a fixed-size field built from their ``shape`` and
    ``dtype``. PyTorch tensors are first converted with
    ``_torch_tensor_to_numpy_shim``. Everything is then passed to
    ``Field.create_from_numpy_array``.

    Args:
        field_name: Name of the created field. For TensorFlow tensors the
            tensor's own ``name`` is tried first and ``field_name`` is the
            fallback.
        maybe_array: The array-like object to describe.
        annotations: Optional dict overriding ``index_dimensions`` /
            ``field_type``.

    Returns:
        The created ``Field``, with ``field_type`` set to the resolved value.
        Non-TensorFlow inputs also get ``coordinate_transform`` /
        ``coordinate_origin`` copied when the array carries them.
    """
    if annotations and isinstance(annotations, dict):
        index_dimensions = annotations.get('index_dimensions', 0)
        field_type = annotations.get('field_type', FieldType.GENERIC)
        # NOTE(review): an ArrayWrapper passed together with an annotations
        # dict is not unwrapped here -- confirm this is intended.
    elif isinstance(maybe_array, ArrayWrapper):
        index_dimensions = maybe_array.index_dimensions
        field_type = maybe_array.field_type
        maybe_array = maybe_array.array
    else:
        index_dimensions = 0
        field_type = FieldType.GENERIC

    type_name = str(type(maybe_array))
    if 'tensorflow' in type_name and 'Tensor' in type_name:
        shape = maybe_array.shape
        dtype = maybe_array.dtype.as_numpy_dtype()
        try:
            # This fails on eager execution (tensor names are unavailable there);
            # any failure falls back to the caller-provided name.
            field = Field.create_fixed_size(maybe_array.name or field_name,
                                            shape,
                                            index_dimensions=index_dimensions,
                                            dtype=dtype)
        except Exception:
            field = Field.create_fixed_size(field_name,
                                            shape,
                                            index_dimensions=index_dimensions,
                                            dtype=dtype)
        # Previously the resolved field_type was dropped on this path; apply it
        # just like the NumPy path below does.
        field.field_type = field_type
        return field
    elif 'torch.Tensor' in type_name:
        maybe_array = _torch_tensor_to_numpy_shim(maybe_array)

    field = Field.create_from_numpy_array(field_name, maybe_array, index_dimensions)
    field.field_type = field_type
    # Propagate optional geometry information attached to the array object.
    if hasattr(maybe_array, 'coordinate_transform'):
        field.coordinate_transform = maybe_array.coordinate_transform
    if hasattr(maybe_array, 'coordinate_origin'):
        field.coordinate_origin = maybe_array.coordinate_origin
    return field
def __init__(self,
             input: Field,
             output: Field,
             block_stencil,
             matching_stencil,
             compilation_target,
             max_block_matches,
             blockmatching_threshold,
             hard_threshold,
             matching_function=pystencils_reco.functions.squared_difference,
             wiener_sigma=None,
             **compilation_kwargs):
    """Create the intermediate fields and kernels of the block-matching pipeline.

    The following attributes are set on ``self``:

    * ``block_scores``, ``block_matched_field``, ``complex_field`` and
      ``group_weights``: intermediate fields;
    * ``block_matching``, ``collect_patches`` and ``aggregate``: built by the
      corresponding helper functions;
    * ``hard_thresholding``, ``get_wieners`` and ``apply_wieners``: compiled
      for ``compilation_target``.

    Args:
        input: Input field, or anything accepted by
            ``pystencils_reco._crazy_decorator.coerce_to_field``.
        output: Output field (same coercion as ``input``).
        block_stencil: Offsets forming one block; its length sizes the last
            axis of ``block_matched_field``.
        matching_stencil: Offsets searched during block matching. They are
            processed sorted by their L1 norm (sum of absolute components).
        compilation_target: Forwarded to every kernel factory / ``compile``.
        max_block_matches: Number of matched blocks stored per position.
        blockmatching_threshold: Forwarded to ``collect_patches`` and
            ``aggregate``.
        hard_threshold: Forwarded to ``hard_thresholding``.
        matching_function: Forwarded to ``block_matching_integer_offsets``.
        wiener_sigma: Value used by ``calc_wiener_coefficients``. When ``None``,
            a typed symbol named ``'wiener_sigma'`` is used instead
            (presumably bound at kernel call time -- TODO confirm).
        **compilation_kwargs: Forwarded to ``block_matching_integer_offsets``,
            ``collect_patches`` and ``aggregate``.
    """
    # Search offsets in ascending order of their L1 distance from the origin.
    matching_stencil = sorted(matching_stencil,
                              key=lambda offset: sum(abs(component) for component in offset))

    input_field = pystencils_reco._crazy_decorator.coerce_to_field('input_field', input)
    output_field = pystencils_reco._crazy_decorator.coerce_to_field('output_field', output)
    accumulated_weights = input_field.new_field_with_different_name('accumulated_weights')
    # All temporary fields share the element type of the input.
    dtype = input_field.dtype.numpy_dtype

    # One score per output position and matching offset.
    block_scores_shape = output_field.shape + (len(matching_stencil),)
    block_scores = Field.create_fixed_size('block_scores',
                                           block_scores_shape,
                                           index_dimensions=1,
                                           dtype=dtype)
    self.block_scores = block_scores

    # Per input position: max_block_matches blocks of len(block_stencil) values.
    block_matched_shape = input_field.shape + (max_block_matches, len(block_stencil))
    block_matched_field = Field.create_fixed_size('block_matched',
                                                  block_matched_shape,
                                                  index_dimensions=2,
                                                  dtype=dtype)
    self.block_matched_field = block_matched_field

    self.block_matching = block_matching_integer_offsets(input,
                                                         input,
                                                         block_scores,
                                                         block_stencil,
                                                         matching_stencil,
                                                         compilation_target,
                                                         matching_function,
                                                         **compilation_kwargs)
    self.collect_patches = collect_patches(block_scores,
                                           input,
                                           block_matched_field,
                                           block_stencil,
                                           matching_stencil,
                                           blockmatching_threshold,
                                           max_block_matches,
                                           compilation_target,
                                           **compilation_kwargs)

    # Trailing axis of size 2 -- presumably real/imaginary parts; TODO confirm.
    complex_field = Field.create_fixed_size('complex_field',
                                            block_matched_shape + (2,),
                                            index_dimensions=3,
                                            dtype=dtype)
    group_weights = Field.create_fixed_size('group_weights',
                                            block_scores_shape,
                                            index_dimensions=1,
                                            dtype=dtype)
    self.complex_field = complex_field
    self.group_weights = group_weights
    self.hard_thresholding = hard_thresholding(
        complex_field, group_weights, hard_threshold).compile(compilation_target)

    # Only fall back to a symbol when no value was given; an explicit 0 / 0.0
    # is a legitimate value and must not be replaced (the old truthiness check did).
    if wiener_sigma is None:
        wiener_sigma = pystencils_reco.typed_symbols('wiener_sigma', dtype)
    wiener_coefficients = Field.create_fixed_size('wiener_coefficients',
                                                  block_matched_shape,
                                                  index_dimensions=2,
                                                  dtype=dtype)
    self.get_wieners = calc_wiener_coefficients(
        complex_field, wiener_coefficients, wiener_sigma).compile(compilation_target)
    self.apply_wieners = apply_wieners(
        complex_field, wiener_coefficients, group_weights).compile(compilation_target)

    self.aggregate = aggregate(block_scores,
                               output,
                               block_matched_field,
                               block_stencil,
                               matching_stencil,
                               blockmatching_threshold,
                               max_block_matches,
                               compilation_target,
                               group_weights,
                               accumulated_weights,
                               **compilation_kwargs)