Example #1
def save_ragged_vals_to_dataset(vals_list, output_path, concat_all=True):
    # print(vals_list)
    def data_gen():
        if concat_all:
            yield tf.ragged.constant(vals_list, dtype=tf.int32)
        else:
            for i, vals in enumerate(vals_list):
                if i % 10000 == 0:
                    print(i)
                yield tf.ragged.constant([vals], dtype=tf.int32)

    dataset = tf.data.Dataset.from_generator(
        data_gen, output_signature=tf.RaggedTensorSpec(ragged_rank=1, dtype=tf.int32))

    # for v in dataset:
    #     print(v)
    print('saving to', output_path)
    tf.data.experimental.save(dataset, output_path, shard_func=lambda x: np.int64(0))
    print('saved')
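For completeness, here is a minimal sketch (not part of the original snippet; the helper name is made up, and it assumes a TF version where tf.data.experimental.load is available) of how a dataset written this way could be read back. The element_spec must match the RaggedTensorSpec used when saving:

import tensorflow as tf

def load_ragged_dataset(path):
    # ragged_rank=1 mirrors the output_signature used in data_gen above
    return tf.data.experimental.load(
        path, element_spec=tf.RaggedTensorSpec(ragged_rank=1, dtype=tf.int32))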
Example #2
    def _filter_poses(self, boxes, poses3d, poses2d):
        poses3d_mean = tf.reduce_mean(poses3d, axis=-3)
        poses2d_mean = tf.reduce_mean(poses2d, axis=-3)

        plausible_mask_flat = tf.logical_and(
            tf.logical_and(
                is_pose_plausible(poses3d_mean.flat_values),
                are_augmentation_results_consistent(poses3d.flat_values)),
            is_pose_consistent_with_box(poses2d_mean.flat_values, boxes.flat_values))

        plausible_mask = tf.RaggedTensor.from_row_lengths(
            plausible_mask_flat, boxes.row_lengths())

        # Apply pose similarity-based non-maximum suppression to reduce duplicates
        nms_indices = tf.map_fn(
            fn=lambda args: pose_non_max_suppression(*args),
            elems=(poses3d_mean, boxes[..., 4], plausible_mask),
            fn_output_signature=tf.RaggedTensorSpec(shape=(None,), dtype=tf.int32))
        return nms_indices
Example #3
def _truncate_row_lengths(ragged_tensor: tf.RaggedTensor,
                          new_lengths: tf.Tensor) -> tf.RaggedTensor:
    """Truncates the rows of `ragged_tensor` to the given row lengths."""
    new_lengths = tf.broadcast_to(new_lengths,
                                  ragged_tensor.bounding_shape()[0:1])

    def fn(x):
        row, new_length = x
        return row[0:new_length]

    fn_dtype = tf.RaggedTensorSpec(dtype=ragged_tensor.dtype,
                                   ragged_rank=ragged_tensor.ragged_rank - 1)
    result = tf.map_fn(fn, (ragged_tensor, new_lengths), dtype=fn_dtype)
    # Work around broken shape propagation: without this, result has unknown rank.
    flat_values_shape = [None] * ragged_tensor.flat_values.shape.rank
    result = result.with_flat_values(
        tf.ensure_shape(result.flat_values, flat_values_shape))

    return result
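A hypothetical usage example with toy values (not from the original code): truncating each row of a small ragged tensor to at most two elements.

rt = tf.ragged.constant([[1, 2, 3], [4], [5, 6, 7, 8]])
lengths = tf.constant([2, 2, 2], dtype=tf.int64)
print(_truncate_row_lengths(rt, lengths))  # <tf.RaggedTensor [[1, 2], [4], [5, 6]]>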
Example #4
 def get_input_info_dict(self, signature=None, tags=None):
   if signature == "ragged" and tags == set(["special"]):
     result = {
         "x":
             tensor_info.ParsedTensorInfo.from_type_spec(
                 type_spec=tf.RaggedTensorSpec(
                     shape=[None, None, None, 3], dtype=tf.float32,
                     ragged_rank=2)),
     }
   else:
     result = {
         "x":
             tensor_info.ParsedTensorInfo(
                 tf.float32,
                 tf.TensorShape([None]),
                 is_sparse=(signature == "sparse" and
                            tags == set(["special"]))),
     }
   if tags == set(["special"]) and signature == "extra":
     result["y"] = result["x"]
   return result
Example #5
    def get_traing_set(self, batch_size, seed):
        training_set = tf.data.experimental.load(
            "./train",
            tf.RaggedTensorSpec(tf.TensorShape([3, None]), tf.int32, 1,
                                tf.int64))

        with open('./utils/tid_2_aid') as file:
            tid_2_aid = tf.constant(json.load(file))
            file.close()

            self.tid_2_aid = tf.lookup.StaticHashTable(
                tf.lookup.KeyValueTensorInitializer(tid_2_aid[:, 0],
                                                    tid_2_aid[:, 1]),
                default_value=-1)
            del tid_2_aid
        tf.random.set_seed(seed)
        np.random.seed(seed)
        return training_set.map(lambda x: self.corrupt(x)).shuffle(
            1000, seed, True).apply(
                tf.data.experimental.dense_to_ragged_batch(
                    batch_size, drop_remainder=True))
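As a standalone illustration of the dense_to_ragged_batch step above (toy data, illustrative names only):

import tensorflow as tf

toy_ds = tf.data.Dataset.from_generator(
    lambda: iter([[1], [2, 3], [4, 5, 6], [7]]),
    output_signature=tf.TensorSpec(shape=[None], dtype=tf.int32))
toy_ds = toy_ds.apply(tf.data.experimental.dense_to_ragged_batch(2, drop_remainder=True))
for batch in toy_ds:
    print(batch)  # two 2-row tf.RaggedTensor batches: [[1], [2, 3]] and [[4, 5, 6], [7]]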
Example #6
    def __init__(self, model_path, antialias_factor=4):
        super().__init__()
        self.antialias_factor = antialias_factor
        self.crop_model = tf.saved_model.load(model_path)
        self.crop_side = 256
        self.joint_names = self.crop_model.joint_names
        self.joint_edges = self.crop_model.joint_edges
        joint_names = [b.decode('utf8') for b in self.joint_names.numpy()]
        self.joint_info = data.datasets3d.JointInfo(joint_names, self.joint_edges.numpy())

        self.__call__.get_concrete_function(
            tf.TensorSpec(shape=(None, None, None, 3), dtype=tf.uint8),
            tf.TensorSpec(shape=(None, 3, 3), dtype=tf.float32),
            tf.RaggedTensorSpec(shape=(None, None, 4), ragged_rank=1, dtype=tf.float32),
            tf.TensorSpec(shape=(), dtype=tf.int32),
            tf.TensorSpec(shape=(), dtype=tf.int32))

        self.__call__.get_concrete_function(
            tf.TensorSpec(shape=(None, None, 3), dtype=tf.uint8),
            tf.TensorSpec(shape=(3, 3), dtype=tf.float32),
            tf.TensorSpec(shape=(None, 4), dtype=tf.float32),
            tf.TensorSpec(shape=(), dtype=tf.int32),
            tf.TensorSpec(shape=(), dtype=tf.int32))
Example #7
    def hash_matrix(self):
        '''
        Creates hash_matrix with the weight indices
        '''
        indx = tf.convert_to_tensor([[
            xxhash.xxh32("{}_{}_{}".format(i, j, self.units),
                         self.hash_seed).intdigest() % self.n_weights
            for j in range(self.units)
        ] for i in range(self.input_dim)],
                                    dtype=tf.int32)

        spec = tf.RaggedTensorSpec(dtype=tf.dtypes.int64,
                                   ragged_rank=0,
                                   row_splits_dtype=tf.dtypes.int64)
        cond = tf.stack([indx == i for i in range(self.n_weights)])

        rag = tf.stack([
            tf.map_fn(fn=lambda t: tf.squeeze(tf.where(t), axis=1),
                      elems=c,
                      fn_output_signature=spec,
                      back_prop=False) for c in tf.transpose(cond, [2, 0, 1])
        ])
        self.indices = (rag + 1).to_tensor()
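The hashing trick used in hash_matrix can be shown standalone; here input_dim, units, n_weights and hash_seed are made-up example values rather than attributes of the original layer:

import tensorflow as tf
import xxhash

input_dim, units, n_weights, hash_seed = 4, 3, 8, 0
# Map every (input, unit) weight position to one of n_weights shared weights
indx = tf.convert_to_tensor(
    [[xxhash.xxh32('{}_{}_{}'.format(i, j, units), hash_seed).intdigest() % n_weights
      for j in range(units)] for i in range(input_dim)],
    dtype=tf.int32)
print(indx.shape)  # (4, 3); every entry lies in [0, n_weights)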
Example #8
 def testInputSpecsToTensorRepresentations(self):
     tensor_representations = model_util.input_specs_to_tensor_representations(
         {
             'input_1':
             tf.TensorSpec(shape=(None, 2), dtype=tf.int64),
             'input_2':
             tf.SparseTensorSpec(shape=(None, 1), dtype=tf.float32),
             'input_3':
             tf.RaggedTensorSpec(shape=(None, None), dtype=tf.float32),
         })
     dense_tensor_representation = text_format.Parse(
         """
     dense_tensor {
       column_name: "input_1"
       shape { dim { size: 2 } }
     }
     """, schema_pb2.TensorRepresentation())
     sparse_tensor_representation = text_format.Parse(
         """
     varlen_sparse_tensor {
       column_name: "input_2"
     }
     """, schema_pb2.TensorRepresentation())
     ragged_tensor_representation = text_format.Parse(
         """
     ragged_tensor {
       feature_path {
         step: "input_3"
       }
     }
     """, schema_pb2.TensorRepresentation())
     self.assertEqual(
         {
             'input_1': dense_tensor_representation,
             'input_2': sparse_tensor_representation,
             'input_3': ragged_tensor_representation
         }, tensor_representations)
Example #9
 def convert(self, tensor: TensorAlike) -> List[pa.Array]:
   """Converts the given TensorAlike to pa.Arrays after validating its spec."""
   if tf.__version__ < "2":
     if isinstance(tensor, np.ndarray):
       actual_spec = tf.TensorSpec(tensor.shape,
                                   tf.dtypes.as_dtype(tensor.dtype))
     elif isinstance(tensor, tf.compat.v1.SparseTensorValue):
       actual_spec = tf.SparseTensorSpec(tensor.dense_shape,
                                         tensor.values.dtype)
     elif isinstance(tensor, tf.compat.v1.ragged.RaggedTensorValue):
       actual_spec = tf.RaggedTensorSpec(
           tensor.shape,
           tensor.values.dtype,
           row_splits_dtype=tensor.row_splits.dtype)
     else:
       raise TypeError("Only ndarrays, SparseTensorValues and "
                       "RaggedTensorValues are supported with TF 1.x, "
                       "got {}".format(type(tensor)))
   else:
     actual_spec = tf.type_spec_from_value(tensor)
   if not self._type_spec.is_compatible_with(actual_spec):
     raise TypeError("Expected {} but got {}".format(self._type_spec,
                                                     actual_spec))
   return self._convert_internal(tensor)
Example #10
 def testInferTensorSpecs(self):
     sparse_value = types.SparseTensorValue(
         values=np.array([0.5, -1., 0.5, -1.], dtype=np.float32),
         indices=np.array([[0, 3, 1], [0, 20, 0], [1, 3, 1], [1, 20, 0]]),
         dense_shape=np.array([2, 100, 3]))
     ragged_value = types.RaggedTensorValue(
         values=np.array([3, 1, 4, 1, 5, 9, 2, 7, 1, 8, 8, 2, 1],
                         dtype=np.float32),
         nested_row_splits=[
             np.array([0, 3, 6]),
             np.array([0, 2, 3, 4, 5, 5, 8]),
             np.array([0, 2, 3, 3, 6, 9, 10, 11, 13])
         ])
     tensor_values = {
         'features': {
             'feature_1': np.array([1, 2, 3], dtype=np.float32),
             'feature_2': sparse_value,
             'feature_3': ragged_value,
         },
         'labels': np.array([1], dtype=np.float32),
     }
     expected_specs = {
         'features': {
             'feature_1':
             tf.TensorSpec([None], dtype=tf.float32),
             'feature_2':
             tf.SparseTensorSpec(shape=[None, 100, 3], dtype=tf.float32),
             'feature_3':
             tf.RaggedTensorSpec(shape=[None, None, None, None],
                                 dtype=tf.float32)
         },
         'labels': tf.TensorSpec([None], dtype=tf.float32)
     }
     got_specs = util.infer_tensor_specs(
         util.to_tensorflow_tensors(tensor_values))
     self.assertDictEqual(expected_specs, got_specs)
Example #11
  def test_ragged_roundtrip(self, exported_in_tf1):
    if not hasattr(meta_graph_pb2.TensorInfo, 'CompositeTensor'):
      self.skipTest('This version of TensorFlow does not support '
                    'CompositeTensors in TensorInfo.')
    input_specs = {
        'input':
            tf.RaggedTensorSpec(
                shape=[None, None],
                dtype=tf.float32,
                ragged_rank=1,
                row_splits_dtype=tf.int64)
    }

    def preprocessing_fn(inputs):
      return {'output': inputs['input'] / 2.0}

    export_path = _create_test_saved_model(
        exported_in_tf1,
        input_specs,
        preprocessing_fn,
        base_dir=self.get_temp_dir())

    splits = np.array([0, 2, 3], dtype=np.int64)
    values = np.array([1.0, 2.0, 4.0], dtype=np.float32)
    input_ragged = tf.RaggedTensor.from_row_splits(values, splits)

    # Using a computed input gives confidence that the graphs are fused
    inputs = {'input': input_ragged * 10}
    saved_model_loader = saved_transform_io_v2.SavedModelLoader(export_path)
    outputs = saved_model_loader.apply_transform_model(inputs)
    result = outputs['output']
    self.assertIsInstance(result, tf.RaggedTensor)

    # indices and shape unchanged; values multiplied by 10 and divided by 2
    self.assertAllEqual(splits, result.row_splits)
    self.assertEqual([5.0, 10.0, 20.0], result.values.numpy().tolist())
Example #12
class TensorToArrowTest(tf.test.TestCase, parameterized.TestCase):
    def _assert_tensor_alike_equal(self, left, right):
        self.assertIsInstance(left, type(right))
        if isinstance(left, tf.SparseTensor):
            self.assertAllEqual(left.values, right.values)
            self.assertAllEqual(left.indices, right.indices)
            self.assertAllEqual(left.dense_shape, right.dense_shape)
        else:
            self.assertAllEqual(left, right)

    @parameterized.named_parameters(*_CONVERT_TEST_CASES)
    def test_convert(self, type_specs, expected_schema,
                     expected_tensor_representations, tensor_input,
                     expected_record_batch):
        converter = tensor_to_arrow.TensorsToRecordBatchConverter(type_specs)

        expected_schema = pa.schema(
            [pa.field(n, t) for n, t in sorted(expected_schema.items())])

        self.assertTrue(converter.arrow_schema().equals(expected_schema),
                        "actual: {}".format(converter.arrow_schema()))

        canonical_expected_tensor_representations = {}
        for n, r in expected_tensor_representations.items():
            if not isinstance(r, schema_pb2.TensorRepresentation):
                r = text_format.Parse(r, schema_pb2.TensorRepresentation())
            canonical_expected_tensor_representations[n] = r

        self.assertEqual(canonical_expected_tensor_representations,
                         converter.tensor_representations())

        rb = converter.convert(tensor_input)
        self.assertTrue(
            rb.equals(
                pa.record_batch(
                    [arr for _, arr in sorted(expected_record_batch.items())],
                    schema=expected_schema)))

        # Test that TensorAdapter(TensorsToRecordBatchConverter()) is identity.
        adapter = tensor_adapter.TensorAdapter(
            tensor_adapter.TensorAdapterConfig(
                arrow_schema=converter.arrow_schema(),
                tensor_representations=converter.tensor_representations()))
        adapter_output = adapter.ToBatchTensors(rb, produce_eager_tensors=True)
        self.assertEqual(adapter_output.keys(), tensor_input.keys())
        for k in adapter_output.keys():
            self._assert_tensor_alike_equal(adapter_output[k], tensor_input[k])

    def test_unable_to_handle(self):
        with self.assertRaisesRegex(ValueError, "No handler found"):
            tensor_to_arrow.TensorsToRecordBatchConverter(
                {"sp": tf.SparseTensorSpec([None, None, None], tf.int32)})

        with self.assertRaisesRegex(ValueError, "No handler found"):
            tensor_to_arrow.TensorsToRecordBatchConverter(
                {"sp": tf.SparseTensorSpec([None, None], tf.bool)})

    def test_incompatible_type_spec(self):
        converter = tensor_to_arrow.TensorsToRecordBatchConverter(
            {"sp": tf.SparseTensorSpec([None, None], tf.int32)})
        with self.assertRaisesRegex(TypeError, "Expected SparseTensorSpec"):
            converter.convert({
                "sp":
                tf.SparseTensor(indices=[[0, 1]],
                                values=tf.constant([0], dtype=tf.int64),
                                dense_shape=[4, 1])
            })

    @parameterized.named_parameters(*[
        dict(testcase_name="bool_value_type",
             spec=tf.RaggedTensorSpec(shape=[2, None, None],
                                      dtype=tf.bool,
                                      ragged_rank=2,
                                      row_splits_dtype=tf.int64)),
        dict(testcase_name="2d_leaf_value",
             spec=tf.RaggedTensorSpec(shape=[2, None, None],
                                      dtype=tf.int32,
                                      ragged_rank=1,
                                      row_splits_dtype=tf.int64)),
        dict(testcase_name="ragged_rank_less_than_one",
             spec=tf.RaggedTensorSpec(shape=[2],
                                      dtype=tf.int32,
                                      ragged_rank=0,
                                      row_splits_dtype=tf.int64)),
    ])
    def test_unable_to_handle_ragged(self, spec):
        with self.assertRaisesRegex(ValueError, "No handler found"):
            tensor_to_arrow.TensorsToRecordBatchConverter({"rt": spec})
Example #13
          tf.SparseTensor(values=[b"aa", b"bb"],
                          indices=[[2, 0], [2, 1]],
                          dense_shape=[4, 2])
      },
      expected_record_batch={
          "sp1":
          pa.array([[1], [], [2], []], type=pa.large_list(pa.int32())),
          "sp2":
          pa.array([[], [], [b"aa", b"bb"], []],
                   type=pa.large_list(pa.large_binary()))
      }),
 dict(testcase_name="ragged_tensors",
      type_specs={
          "sp1":
          tf.RaggedTensorSpec(tf.TensorShape([3, None]),
                              tf.int64,
                              ragged_rank=1,
                              row_splits_dtype=tf.int64),
          "sp2":
          tf.RaggedTensorSpec(tf.TensorShape([3, None]),
                              tf.string,
                              ragged_rank=1,
                              row_splits_dtype=tf.int64),
      },
      expected_schema={
          "sp1": pa.large_list(pa.int64()),
          "sp2": pa.large_list(pa.large_binary()),
      },
      expected_tensor_representations={
          "sp1":
          """ragged_tensor {
                     feature_path {
Example #14
class Pose3dEstimator(tf.Module):
    def __init__(self):
        super().__init__()

        # Note that only the Trackable resource attributes such as Variables and Models will be
        # retained when saving to SavedModel
        self.crop_model = tf.saved_model.load(FLAGS.input_model_path)
        self.crop_side = FLAGS.crop_side
        self.joint_names = self.crop_model.joint_names
        self.joint_edges = self.crop_model.joint_edges
        joint_names = [b.decode('utf8') for b in self.joint_names.numpy()]
        self.joint_info = data.datasets3d.JointInfo(joint_names, self.joint_edges.numpy())
        self.detector = tf.saved_model.load(FLAGS.detector_path) if FLAGS.detector_path else None

        if len(joint_names) == 122:
            skeleton_infos = util.load_pickle('./saved_model_export/skeleton_types.pkl')
            self.per_skeleton_indices = {
                k: tf.Variable(v['indices'], dtype=tf.int32, trainable=False)
                for k, v in skeleton_infos.items()}

            self.per_skeleton_joint_names = {
                k: tf.Variable(v['names'], dtype=tf.string, trainable=False)
                for k, v in skeleton_infos.items()}
            self.per_skeleton_joint_edges = {
                k: tf.Variable(v['edges'], dtype=tf.int32, trainable=False)
                for k, v in skeleton_infos.items()}
            self.per_skeleton_indices[''] = tf.range(122, dtype=tf.int32)
            self.per_skeleton_joint_names[''] = self.joint_names
            self.per_skeleton_joint_edges[''] = self.joint_edges

    @tf.function(input_signature=[
        tf.TensorSpec(shape=(None, None, None, 3), dtype=tf.uint8),  # images
        tf.TensorSpec(shape=(None, 3, 3), dtype=tf.float32),  # intrinsic_matrix
        tf.TensorSpec(shape=(None, 5), dtype=tf.float32),  # distortion_coeffs
        tf.TensorSpec(shape=(None, 4, 4), dtype=tf.float32),  # extrinsic_matrix
        tf.TensorSpec(shape=(3,), dtype=tf.float32),  # world_up_vector
        tf.TensorSpec(shape=(), dtype=tf.float32),  # default_fov_degrees
        tf.TensorSpec(shape=(), dtype=tf.int32),  # internal_batch_size
        tf.TensorSpec(shape=(), dtype=tf.int32),  # antialias_factor
        tf.TensorSpec(shape=(), dtype=tf.int32),  # num_aug
        tf.TensorSpec(shape=(), dtype=tf.bool),  # average_aug
        tf.TensorSpec(shape=(), dtype=tf.string),  # skeleton
        tf.TensorSpec(shape=(), dtype=tf.float32),  # detector_threshold
        tf.TensorSpec(shape=(), dtype=tf.float32),  # detector_nms_iou_threshold
        tf.TensorSpec(shape=(), dtype=tf.int32),  # max_detections
        tf.TensorSpec(shape=(), dtype=tf.bool),  # detector_flip_aug
        tf.TensorSpec(shape=(), dtype=tf.bool)])  # suppress_implausible_poses
    def detect_poses_batched(
            self, images, intrinsic_matrix=(UNKNOWN_INTRINSIC_MATRIX,),
            distortion_coeffs=(DEFAULT_DISTORTION,), extrinsic_matrix=(DEFAULT_EXTRINSIC_MATRIX,),
            world_up_vector=DEFAULT_WORLD_UP, default_fov_degrees=55, internal_batch_size=64,
            antialias_factor=1, num_aug=5, average_aug=True, skeleton='', detector_threshold=0.3,
            detector_nms_iou_threshold=0.7, max_detections=-1, detector_flip_aug=False,
            suppress_implausible_poses=True):
        boxes = self._get_boxes(
            images, detector_flip_aug, detector_nms_iou_threshold, detector_threshold,
            max_detections)
        return self._estimate_poses_batched(
            images, boxes, intrinsic_matrix, distortion_coeffs, extrinsic_matrix, world_up_vector,
            default_fov_degrees, internal_batch_size, antialias_factor, num_aug, average_aug,
            skeleton, suppress_implausible_poses)

    @tf.function(input_signature=[
        tf.TensorSpec(shape=(None, None, None, 3), dtype=tf.uint8),  # images
        tf.RaggedTensorSpec(shape=(None, None, 4), ragged_rank=1, dtype=tf.float32),  # boxes
        tf.TensorSpec(shape=(None, 3, 3), dtype=tf.float32),  # intrinsic_matrix
        tf.TensorSpec(shape=(None, 5), dtype=tf.float32),  # distortion_coeffs
        tf.TensorSpec(shape=(None, 4, 4), dtype=tf.float32),  # extrinsic_matrix
        tf.TensorSpec(shape=(3,), dtype=tf.float32),  # world_up_vector
        tf.TensorSpec(shape=(), dtype=tf.float32),  # default_fov_degrees
        tf.TensorSpec(shape=(), dtype=tf.int32),  # internal_batch_size
        tf.TensorSpec(shape=(), dtype=tf.int32),  # antialias_factor
        tf.TensorSpec(shape=(), dtype=tf.int32),  # num_aug
        tf.TensorSpec(shape=(), dtype=tf.bool),  # average_aug
        tf.TensorSpec(shape=(), dtype=tf.string),  # skeleton
    ])
    def estimate_poses_batched(
            self, images, boxes, intrinsic_matrix=(UNKNOWN_INTRINSIC_MATRIX,),
            distortion_coeffs=(DEFAULT_DISTORTION,),
            extrinsic_matrix=(DEFAULT_EXTRINSIC_MATRIX,), world_up_vector=DEFAULT_WORLD_UP,
            default_fov_degrees=55, internal_batch_size=64, antialias_factor=1, num_aug=5,
            average_aug=True, skeleton=''):
        boxes = tf.concat([boxes, tf.ones_like(boxes[..., :1])], axis=-1)
        pred = self._estimate_poses_batched(
            images, boxes, intrinsic_matrix, distortion_coeffs, extrinsic_matrix, world_up_vector,
            default_fov_degrees, internal_batch_size, antialias_factor, num_aug, average_aug,
            skeleton, suppress_implausible_poses=False)
        del pred['boxes']
        return pred

    def _estimate_poses_batched(
            self, images, boxes, intrinsic_matrix, distortion_coeffs, extrinsic_matrix,
            world_up_vector, default_fov_degrees, internal_batch_size, antialias_factor, num_aug,
            average_aug, skeleton, suppress_implausible_poses):
        # Special case when zero boxes are provided or found
        # (i.e., all images are without person detections).
        # This must be explicitly handled, else the shapes don't work out automatically
        # for the TensorArray in _predict_in_batches.
        if tf.size(boxes) == 0:
            return self._predict_empty(images, num_aug, average_aug)

        n_images = tf.shape(images)[0]
        # If one intrinsic matrix is given, repeat it for all images
        if tf.shape(intrinsic_matrix)[0] == 1:
            # If intrinsic_matrix is not given, fill it in based on field of view
            if tf.reduce_all(intrinsic_matrix == -1):
                intrinsic_matrix = intrinsic_matrix_from_field_of_view(
                    default_fov_degrees, tf.shape(images)[1:3])
            intrinsic_matrix = tf.repeat(intrinsic_matrix, n_images, axis=0)

        # If one distortion coeff/extrinsic matrix is given, repeat it for all images
        if tf.shape(distortion_coeffs)[0] == 1:
            distortion_coeffs = tf.repeat(distortion_coeffs, n_images, axis=0)
        if tf.shape(extrinsic_matrix)[0] == 1:
            extrinsic_matrix = tf.repeat(extrinsic_matrix, n_images, axis=0)

        # Now repeat these camera params for each box
        n_box_per_image = boxes.row_lengths()
        intrinsic_matrix = tf.repeat(intrinsic_matrix, n_box_per_image, axis=0)
        distortion_coeffs = tf.repeat(distortion_coeffs, n_box_per_image, axis=0)

        # Up-vector in camera-space
        camspace_up = tf.einsum('c,bCc->bC', world_up_vector, extrinsic_matrix[..., :3, :3])
        camspace_up = tf.repeat(camspace_up, n_box_per_image, axis=0)

        # Set up the test-time augmentation parameters
        aug_gammas = tf.cast(tf.linspace(0.6, 1.0, num_aug), tf.float32)
        aug_angle_range = np.float32(np.deg2rad(FLAGS.rot_aug))
        if FLAGS.rot_aug_linspace_noend:
            aug_angles = linspace_noend(-aug_angle_range, aug_angle_range, num_aug)
        else:
            aug_angles = tf.linspace(-aug_angle_range, aug_angle_range, num_aug)
        aug_scales = tf.concat([
            linspace_noend(0.8, 1.0, num_aug // 2),
            tf.linspace(1.0, 1.1, num_aug - num_aug // 2)], axis=0)
        aug_should_flip = (tf.range(num_aug) - num_aug // 2) % 2 != 0
        aug_flipmat = tf.constant([[-1, 0, 0], [0, 1, 0], [0, 0, 1]], np.float32)
        aug_maybe_flipmat = tf.where(
            aug_should_flip[:, np.newaxis, np.newaxis], aug_flipmat, tf.eye(3))
        aug_rotmat = rotation_mat_zaxis(-aug_angles)
        aug_rotflipmat = aug_maybe_flipmat @ aug_rotmat

        # crops_flat, poses3dcam_flat = self._predict_in_batches(
        poses3d_flat = self._predict_in_batches(
            images, intrinsic_matrix, distortion_coeffs, camspace_up, boxes, internal_batch_size,
            aug_should_flip, aug_rotflipmat, aug_gammas, aug_scales, antialias_factor)

        # Project the 3D poses to get the 2D poses
        poses2d_flat_normalized = to_homogeneous(
            distort_points(project(poses3d_flat), distortion_coeffs))
        poses2d_flat = tf.einsum('bank,bjk->banj', poses2d_flat_normalized,
                                 intrinsic_matrix[..., :2, :])
        poses2d_flat = tf.ensure_shape(poses2d_flat, [None, None, self.joint_info.n_joints, 2])

        # Arrange the results back into ragged tensors
        poses3d = tf.RaggedTensor.from_row_lengths(poses3d_flat, n_box_per_image)
        poses2d = tf.RaggedTensor.from_row_lengths(poses2d_flat, n_box_per_image)
        # crops = tf.RaggedTensor.from_row_lengths(crops_flat, n_box_per_image)

        if suppress_implausible_poses:
            # Filter the resulting poses for individual plausibility to reduce false positives
            selected_indices = self._filter_poses(boxes, poses3d, poses2d)
            boxes, poses3d, poses2d = [
                tf.gather(x, selected_indices, batch_dims=1)
                for x in [boxes, poses3d, poses2d]]
            # crops = tf.gather(crops, selected_indices, batch_dims=1)

        # Convert to world coordinates
        extrinsic_matrix = tf.repeat(tf.linalg.inv(extrinsic_matrix), poses3d.row_lengths(), axis=0)
        poses3d = tf.RaggedTensor.from_row_lengths(
            tf.einsum(
                'bank,bjk->banj', to_homogeneous(poses3d.flat_values),
                extrinsic_matrix[..., :3, :]),
            poses3d.row_lengths())

        if skeleton != '':
            poses3d = self._get_skeleton(poses3d, skeleton)
            poses2d = self._get_skeleton(poses2d, skeleton)

        if average_aug:
            poses3d = tf.reduce_mean(poses3d, axis=-3)
            poses2d = tf.reduce_mean(poses2d, axis=-3)

        result = dict(boxes=boxes, poses3d=poses3d, poses2d=poses2d)
        # result['crops'] = crops
        return result

    def _get_boxes(
            self, images, detector_flip_aug, detector_nms_iou_threshold, detector_threshold,
            max_detections):
        if self.detector is None:
            n_images = tf.shape(images)[0]
            boxes = tf.RaggedTensor.from_row_lengths(
                tf.zeros(shape=(0, 5)), tf.zeros(n_images, tf.int64))
        else:
            boxes = self.detector.predict_multi_image(
                images, detector_threshold, detector_nms_iou_threshold, detector_flip_aug,
                detector_flip_aug and FLAGS.detector_flip_vertical_too)
            if max_detections > -1 and not tf.size(boxes) == 0:
                topk_indices = topk_indices_ragged(boxes[..., 4], max_detections)
                boxes = tf.gather(boxes, topk_indices, axis=1, batch_dims=1)
        return boxes

    def _predict_in_batches(
            self, images, intrinsic_matrix, distortion_coeffs, camspace_up, boxes,
            internal_batch_size, aug_should_flip, aug_rotflipmat, aug_gammas, aug_scales,
            antialias_factor):

        num_aug = tf.shape(aug_gammas)[0]
        boxes_per_batch = internal_batch_size // num_aug
        boxes_flat = boxes.flat_values
        image_id_per_box = boxes.value_rowids()

        # Gamma decoding for correct image rescaling later on
        images = (tf.cast(images, tf.float32) / np.float32(255)) ** 2.2

        if boxes_per_batch == 0:
            # Run all as a single batch
            return self._predict_single_batch(
                images, intrinsic_matrix, distortion_coeffs, camspace_up, boxes_flat,
                image_id_per_box, aug_rotflipmat, aug_should_flip, aug_scales, aug_gammas,
                antialias_factor)
        else:
            # Chunk the image crops into batches and predict them one by one
            n_total_boxes = tf.shape(boxes_flat)[0]
            n_batches = tf.cast(tf.math.ceil(n_total_boxes / boxes_per_batch), tf.int32)
            poses3d_batches = tf.TensorArray(
                tf.float32, size=n_batches, element_shape=(None, None, self.joint_info.n_joints, 3),
                infer_shape=False)
            # crop_batches = tf.TensorArray(
            #     tf.float32, size=n_batches,
            #     element_shape=(None, None, self.crop_side, self.crop_side, 3),
            #     infer_shape=False)

            for i in tf.range(n_batches):
                batch_slice = slice(i * boxes_per_batch, (i + 1) * boxes_per_batch)
                # crops, poses3d = self._predict_single_batch(
                poses3d = self._predict_single_batch(
                    images, intrinsic_matrix[batch_slice], distortion_coeffs[batch_slice],
                    camspace_up[batch_slice], boxes_flat[batch_slice],
                    image_id_per_box[batch_slice], aug_rotflipmat, aug_should_flip, aug_scales,
                    aug_gammas, antialias_factor)
                poses3d_batches = poses3d_batches.write(i, poses3d)
                # crop_batches = crop_batches.write(i, crops)

            # return crop_batches.concat(), poses3d_batches.concat()
            return poses3d_batches.concat()

    def _predict_single_batch(
            self, images, intrinsic_matrix, distortion_coeffs, camspace_up, boxes, image_ids,
            aug_rotflipmat, aug_should_flip, aug_scales, aug_gammas, antialias_factor):
        # Get crops and info about the transformation used to create them
        # Each has shape [num_aug, n_boxes, ...]
        crops, new_intrinsic_matrix, R = self._get_crops(
            images, intrinsic_matrix, distortion_coeffs, camspace_up, boxes, image_ids,
            aug_rotflipmat, aug_scales, aug_gammas, antialias_factor)

        # Flatten each and predict the pose with the crop model
        new_intrinsic_matrix_flat = tf.reshape(new_intrinsic_matrix, (-1, 3, 3))
        crops_flat = tf.reshape(crops, (-1, self.crop_side, self.crop_side, 3))
        poses_flat = self.crop_model.predict_multi(
            tf.cast(crops_flat, tf.float16),
            new_intrinsic_matrix_flat)

        # Unflatten the result
        num_aug = tf.shape(aug_should_flip)[0]
        poses = tf.reshape(poses_flat, [num_aug, -1, self.joint_info.n_joints, 3])

        # Reorder the joints for mirrored predictions (e.g., swap left and right wrist)
        left_right_swapped_poses = tf.gather(poses, self.joint_info.mirror_mapping, axis=-2)
        poses = tf.where(
            tf.reshape(aug_should_flip, [-1, 1, 1, 1]), left_right_swapped_poses, poses)

        # Transform the predictions back into the original camera space
        # We need to multiply by the inverse of R, but since we are using row vectors in `poses`
        # the inverse and transpose cancel out, leaving just R.
        poses_orig_camspace = poses @ R

        # Transpose to [n_boxes, num_aug, ...]
        # return (tf.transpose(crops, [1, 0, 2, 3, 4]),
        # tf.transpose(poses_orig_camspace, [1, 0, 2, 3]))
        return tf.transpose(poses_orig_camspace, [1, 0, 2, 3])

    def _get_crops(
            self, images, intrinsic_matrix, distortion_coeffs, camspace_up, boxes, image_ids,
            aug_rotflipmat, aug_scales, aug_gammas, antialias_factor):
        R_noaug, box_aug_scales = self._get_new_rotation_and_scale(
            intrinsic_matrix, distortion_coeffs, camspace_up, boxes)

        # How much we need to scale overall, taking scale augmentation into account
        # From here on, we introduce the dimension of augmentations
        crop_scales = aug_scales[:, tf.newaxis] * box_aug_scales[tf.newaxis, :]
        # Build the new intrinsic matrix
        n_box = tf.shape(boxes)[0]
        num_aug = tf.shape(aug_gammas)[0]
        new_intrinsic_matrix = tf.concat([
            tf.concat([
                # Top-left of original intrinsic matrix gets scaled
                intrinsic_matrix[tf.newaxis, :, :2, :2] * crop_scales[:, :, tf.newaxis, tf.newaxis],
                # Principal point is the middle of the new image size
                tf.fill((num_aug, n_box, 2, 1), self.crop_side / 2)], axis=3),
            tf.concat([
                # [0, 0, 1] as the last row of the intrinsic matrix:
                tf.zeros((num_aug, n_box, 1, 2), tf.float32),
                tf.ones((num_aug, n_box, 1, 1), tf.float32)], axis=3)], axis=2)
        R = aug_rotflipmat[:, tf.newaxis] @ R_noaug
        new_invprojmat = tf.linalg.inv(new_intrinsic_matrix @ R)

        # If we perform antialiasing through output scaling, we render a larger image first and then
        # shrink it. So we scale the homography first.
        if antialias_factor > 1:
            scaling_mat = corner_aligned_scale_mat(1 / tf.cast(antialias_factor, tf.float32))
            new_invprojmat = new_invprojmat @ scaling_mat

        crops = warp_images(
            images,
            intrinsic_matrix=tf.tile(intrinsic_matrix, [num_aug, 1, 1]),
            new_invprojmat=tf.reshape(new_invprojmat, [-1, 3, 3]),
            distortion_coeffs=tf.tile(distortion_coeffs, [num_aug, 1]),
            crop_scales=tf.reshape(crop_scales, [-1]) * tf.cast(antialias_factor, tf.float32),
            output_shape=(self.crop_side * antialias_factor, self.crop_side * antialias_factor),
            image_ids=tf.tile(image_ids, [num_aug]))

        # Downscale the result if we do antialiasing through output scaling
        if antialias_factor == 2:
            crops = tf.nn.avg_pool2d(crops, 2, 2, padding='VALID')
        elif antialias_factor == 4:
            crops = tf.nn.avg_pool2d(crops, 4, 4, padding='VALID')
        elif antialias_factor > 4:
            crops = tf.image.resize(
                crops, (self.crop_side, self.crop_side), tf.image.ResizeMethod.AREA)
        crops = tf.reshape(crops, [num_aug, n_box, self.crop_side, self.crop_side, 3])
        # The division by 2.2 cancels the original gamma decoding from earlier
        crops **= tf.reshape(aug_gammas / 2.2, [-1, 1, 1, 1, 1])
        return crops, new_intrinsic_matrix, R

    def _get_new_rotation_and_scale(self, intrinsic_matrix, distortion_coeffs, camspace_up, boxes):
        # Transform five points on each box: the center and the four side midpoints
        x, y, w, h = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
        boxpoints_homog = to_homogeneous(tf.stack([
            tf.stack([x + w / 2, y + h / 2], axis=1),
            tf.stack([x + w / 2, y], axis=1),
            tf.stack([x + w, y + h / 2], axis=1),
            tf.stack([x + w / 2, y + h], axis=1),
            tf.stack([x, y + h / 2], axis=1)], axis=1))
        boxpoints_camspace = tf.einsum(
            'bpc,bCc->bpC', boxpoints_homog, tf.linalg.inv(intrinsic_matrix))
        boxpoints_camspace = to_homogeneous(
            undistort_points(boxpoints_camspace[..., :2], distortion_coeffs))

        # Create a rotation matrix that will put the box center to the principal point
        # and apply the augmentation rotation and flip, to get the new coordinate frame
        box_center_camspace = boxpoints_camspace[:, 0]
        R_noaug = get_new_rotation_matrix(forward_vector=box_center_camspace, up_vector=camspace_up)

        # Transform the side midpoints of the box to the new coordinate frame
        sidepoints_camspace = boxpoints_camspace[:, 1:5]
        sidepoints_new = project(tf.einsum(
            'bpc,bCc->bpC', sidepoints_camspace, intrinsic_matrix @ R_noaug))

        # Measure the size of the reprojected boxes
        vertical_size = tf.linalg.norm(sidepoints_new[:, 0] - sidepoints_new[:, 2], axis=-1)
        horiz_size = tf.linalg.norm(sidepoints_new[:, 1] - sidepoints_new[:, 3], axis=-1)
        box_size_new = tf.maximum(vertical_size, horiz_size)

        # How much we need to scale (zoom) to have the boxes fill out the final crop
        box_aug_scales = self.crop_side / box_size_new
        return R_noaug, box_aug_scales

    def _predict_empty(self, image, num_aug, average_aug):
        if average_aug:
            poses3d = tf.zeros(shape=(0, self.joint_info.n_joints, 3))
            poses2d = tf.zeros(shape=(0, self.joint_info.n_joints, 2))
        else:
            poses3d = tf.zeros(shape=(0, num_aug, self.joint_info.n_joints, 3))
            poses2d = tf.zeros(shape=(0, num_aug, self.joint_info.n_joints, 2))

        n_images = tf.shape(image)[0]
        poses3d = tf.RaggedTensor.from_row_lengths(poses3d, tf.zeros(n_images, tf.int64))
        poses2d = tf.RaggedTensor.from_row_lengths(poses2d, tf.zeros(n_images, tf.int64))
        boxes = tf.zeros(shape=(0, 5))
        boxes = tf.RaggedTensor.from_row_lengths(boxes, tf.zeros(n_images, tf.int64))

        result = dict(boxes=boxes, poses3d=poses3d, poses2d=poses2d)
        #     crops = tf.zeros(shape=(0, num_aug, self.crop_side, self.crop_side, 3))
        #     crops = tf.RaggedTensor.from_row_lengths(crops, tf.zeros(n_images, tf.int64))
        #     result['crops'] = crops
        return result

    def _filter_poses(self, boxes, poses3d, poses2d):
        poses3d_mean = tf.reduce_mean(poses3d, axis=-3)
        poses2d_mean = tf.reduce_mean(poses2d, axis=-3)

        plausible_mask_flat = tf.logical_and(
            tf.logical_and(
                is_pose_plausible(poses3d_mean.flat_values),
                are_augmentation_results_consistent(poses3d.flat_values)),
            is_pose_consistent_with_box(poses2d_mean.flat_values, boxes.flat_values))

        plausible_mask = tf.RaggedTensor.from_row_lengths(
            plausible_mask_flat, boxes.row_lengths())

        # Apply pose similarity-based non-maximum suppression to reduce duplicates
        nms_indices = tf.map_fn(
            fn=lambda args: pose_non_max_suppression(*args),
            elems=(poses3d_mean, boxes[..., 4], plausible_mask),
            fn_output_signature=tf.RaggedTensorSpec(shape=(None,), dtype=tf.int32))
        return nms_indices

    def _get_skeleton(self, poses, skeleton):
        # We must list all possibilities since we can't address the Python dictionary
        # `self.per_skeleton_indices` with the tf.Tensor `skeleton`.
        if skeleton == b'smpl_24':
            indices = self.per_skeleton_indices['smpl_24']
        elif skeleton == b'coco_19':
            indices = self.per_skeleton_indices['coco_19']
        elif skeleton == b'h36m_17':
            indices = self.per_skeleton_indices['h36m_17']
        elif skeleton == b'h36m_25':
            indices = self.per_skeleton_indices['h36m_25']
        elif skeleton == b'mpi_inf_3dhp_17':
            indices = self.per_skeleton_indices['mpi_inf_3dhp_17']
        elif skeleton == b'mpi_inf_3dhp_28':
            indices = self.per_skeleton_indices['mpi_inf_3dhp_28']
        elif skeleton == b'smpl+head_30':
            indices = self.per_skeleton_indices['smpl+head_30']
        else:
            indices = tf.range(122, dtype=tf.int32)

        return tf.gather(poses, indices, axis=-2)

    @tf.function(input_signature=[
        tf.TensorSpec(shape=(None, None, 3), dtype=tf.uint8),  # image
        tf.TensorSpec(shape=(3, 3), dtype=tf.float32),  # intrinsic_matrix
        tf.TensorSpec(shape=(5,), dtype=tf.float32),  # distortion_coeffs
        tf.TensorSpec(shape=(4, 4), dtype=tf.float32),  # extrinsic_matrix
        tf.TensorSpec(shape=(3,), dtype=tf.float32),  # world_up_vector
        tf.TensorSpec(shape=(), dtype=tf.float32),  # default_fov_degrees
        tf.TensorSpec(shape=(), dtype=tf.int32),  # internal_batch_size
        tf.TensorSpec(shape=(), dtype=tf.int32),  # antialias_factor
        tf.TensorSpec(shape=(), dtype=tf.int32),  # num_aug
        tf.TensorSpec(shape=(), dtype=tf.bool),  # average_aug
        tf.TensorSpec(shape=(), dtype=tf.string),  # skeleton
        tf.TensorSpec(shape=(), dtype=tf.float32),  # detector_threshold
        tf.TensorSpec(shape=(), dtype=tf.float32),  # detector_nms_iou_threshold
        tf.TensorSpec(shape=(), dtype=tf.int32),  # max_detections
        tf.TensorSpec(shape=(), dtype=tf.bool),  # detector_flip_aug
        tf.TensorSpec(shape=(), dtype=tf.bool)  # suppress_implausible_poses
    ])
    def detect_poses(
            self, image, intrinsic_matrix=UNKNOWN_INTRINSIC_MATRIX,
            distortion_coeffs=DEFAULT_DISTORTION, extrinsic_matrix=DEFAULT_EXTRINSIC_MATRIX,
            world_up_vector=DEFAULT_WORLD_UP, default_fov_degrees=55, internal_batch_size=64,
            antialias_factor=1, num_aug=5, average_aug=True, skeleton='', detector_threshold=0.3,
            detector_nms_iou_threshold=0.7, max_detections=-1, detector_flip_aug=False,
            suppress_implausible_poses=True):
        images = image[tf.newaxis]
        intrinsic_matrix = intrinsic_matrix[tf.newaxis]
        distortion_coeffs = distortion_coeffs[tf.newaxis]
        extrinsic_matrix = extrinsic_matrix[tf.newaxis]
        result = self.detect_poses_batched(
            images, intrinsic_matrix, distortion_coeffs, extrinsic_matrix, world_up_vector,
            default_fov_degrees, internal_batch_size, antialias_factor, num_aug, average_aug,
            skeleton, detector_threshold, detector_nms_iou_threshold, max_detections,
            detector_flip_aug, suppress_implausible_poses)
        return tf.nest.map_structure(lambda x: tf.squeeze(x, 0), result)

    @tf.function(input_signature=[
        tf.TensorSpec(shape=(None, None, 3), dtype=tf.uint8),  # image
        tf.TensorSpec(shape=(None, 4), dtype=tf.float32),  # boxes
        tf.TensorSpec(shape=(3, 3), dtype=tf.float32),  # intrinsic_matrix
        tf.TensorSpec(shape=(5,), dtype=tf.float32),  # distortion_coeffs
        tf.TensorSpec(shape=(4, 4), dtype=tf.float32),  # extrinsic_matrix
        tf.TensorSpec(shape=(3,), dtype=tf.float32),  # world_up_vector
        tf.TensorSpec(shape=(), dtype=tf.float32),  # default_fov_degrees
        tf.TensorSpec(shape=(), dtype=tf.int32),  # internal_batch_size
        tf.TensorSpec(shape=(), dtype=tf.int32),  # antialias_factor
        tf.TensorSpec(shape=(), dtype=tf.int32),  # num_aug
        tf.TensorSpec(shape=(), dtype=tf.bool),  # average_aug
        tf.TensorSpec(shape=(), dtype=tf.string),  # skeleton
    ])
    def estimate_poses(
            self, image, boxes, intrinsic_matrix=UNKNOWN_INTRINSIC_MATRIX,
            distortion_coeffs=DEFAULT_DISTORTION, extrinsic_matrix=DEFAULT_EXTRINSIC_MATRIX,
            world_up_vector=DEFAULT_WORLD_UP, default_fov_degrees=55, internal_batch_size=64,
            antialias_factor=1, num_aug=5, average_aug=True, skeleton=''):
        images = image[tf.newaxis]
        boxes = tf.RaggedTensor.from_tensor(boxes[tf.newaxis])
        intrinsic_matrix = intrinsic_matrix[tf.newaxis]
        distortion_coeffs = distortion_coeffs[tf.newaxis]
        extrinsic_matrix = extrinsic_matrix[tf.newaxis]
        result = self.estimate_poses_batched(
            images, boxes, intrinsic_matrix, distortion_coeffs, extrinsic_matrix, world_up_vector,
            default_fov_degrees, internal_batch_size, antialias_factor, num_aug, average_aug,
            skeleton)
        return tf.nest.map_structure(lambda x: tf.squeeze(x, 0), result)
Example #15
        return encoding_layer

###################################################
# TF ENCODING HELPER FUNCTIONS
###################################################

@tf.function
def raged_lists_batch_to_multihot(ragged_lists_batch: tf.RaggedTensor, multihot_dim: int) -> tf.Tensor:
    """ Maps a batch of label indices to a batch of multi-hot ones """
    # TODO: Seems tf.one_hot supports ragged tensors, so try to remove to_tensor call
    t = ragged_lists_batch.to_tensor(-1)  # Pad with -1: tf.one_hot maps -1 to an all-zeros row
    t = tf.one_hot(t, multihot_dim)
    t = tf.reduce_max(t, axis=1)
    return t
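A quick usage sketch with toy data (assuming the function above is in scope):

labels = tf.ragged.constant([[0, 2], [1], []], dtype=tf.int64)
print(raged_lists_batch_to_multihot(labels, multihot_dim=4))
# -> [[1, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 0]] as float32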

@tf.function(input_signature=[tf.RaggedTensorSpec(shape=[None,None], dtype=tf.int64), tf.TensorSpec(shape=(), dtype=tf.bool)])
def pad_sequence_right( sequences_batch: tf.RaggedTensor, mask: bool) -> tf.Tensor:
    """ Pad sequences with zeros on right side """

    # Avoid sequences larger than sequence_length: Get last sequence_length of each sequence
    sequences_batch = sequences_batch[:,-settings.settings.sequence_length:]

    if mask:
        # Add one to indices, to reserve 0 index for padding
        sequences_batch += 1

    # Convert to dense, padding zeros to the right
    sequences_batch = sequences_batch.to_tensor(0, shape=[None, settings.settings.sequence_length])
    return sequences_batch
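A usage sketch, assuming settings.settings.sequence_length == 4 (the real value comes from the project's settings module):

seqs = tf.ragged.constant([[5, 6, 7, 8, 9], [1]], dtype=tf.int64)
print(pad_sequence_right(seqs, tf.constant(True)))
# Only the last 4 ids of each sequence are kept; mask=True shifts ids by +1 so that
# 0 is free for padding: [[7, 8, 9, 10], [2, 0, 0, 0]]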

Example #16
    def input_fn(self, data_file, n_repeat=1):
        batch_size = self.batch_size
        texts = list()
        mstr_sep_texts = list()
        all_label_strs, all_labels = list(), list()
        with open(data_file, encoding='utf-8') as f:
            for i, line in enumerate(f):
                x = json.loads(line)
                mstr = x['mention_span']
                text = '{} [MASK] such as {} {}'.format(
                    ' '.join(x['left_context_token']), mstr,
                    ' '.join(x['right_context_token']))
                # text = '{} {} {}'.format(
                #     ' '.join(x['left_context_token']), mstr, ' '.join(x['right_context_token']))
                mstr_sep_text = '{} [SEP] {} {} {}'.format(
                    mstr, ' '.join(x['left_context_token']), mstr,
                    ' '.join(x['right_context_token']))
                # print(text)
                texts.append(text)
                mstr_sep_texts.append(mstr_sep_text)
                labels = x['y_str']
                tids = [self.type_id_dict.get(t, -1) for t in labels]
                tids = [tid for tid in tids if tid > -1]
                # if i > 5:
                all_label_strs.append(labels)
                all_labels.append(tids)
                # if len(texts) >= 12:
                #     break
        print(len(texts), 'texts')

        def tok_id_seq_gen():
            text_ids, tok_id_seqs, tok_id_seqs_repeat = list(), list(), list()
            label_strs, y_vecs = list(), list()
            for _ in range(n_repeat):
                batch_id = 0
                for text_idx, text in enumerate(texts):
                    tokens = self.tokenizer.tokenize(text)
                    tokens_full = ['[CLS]'] + tokens + ['[SEP]']
                    tok_id_seq = self.tokenizer.convert_tokens_to_ids(
                        tokens_full)
                    tok_id_seqs.append(tok_id_seq)

                    fet_tokens = ['[CLS]'] + self.tokenizer.tokenize(
                        mstr_sep_texts[text_idx]) + ['[SEP]']
                    fet_tok_id_seq = self.tokenizer.convert_tokens_to_ids(
                        fet_tokens)
                    # tok_id_seq = np.array([len(text)], np.float32)
                    for _ in range(self.retriever_beam_size):
                        tok_id_seqs_repeat.append(fet_tok_id_seq)
                    text_ids.append(text_idx)
                    y_vecs.append(
                        to_one_hot(all_labels[text_idx], self.n_types))
                    label_strs.append(all_label_strs[text_idx])

                    if len(tok_id_seqs) >= batch_size:
                        # tok_id_seq_batch, input_mask = get_padded_bert_input(tok_id_seqs)
                        tok_id_seq_batch = tf.ragged.constant(tok_id_seqs)
                        tok_id_seqs_repeat_ragged = tf.ragged.constant(
                            tok_id_seqs_repeat)
                        # y_vecs_tensor = tf.concat(y_vecs)
                        # yield {'tok_id_seq_batch': tok_id_seq_batch, 'input_mask': input_mask}, y_vecs
                        yield {
                            'batch_id': batch_id,
                            'tok_id_seq_batch': tok_id_seq_batch,
                            'tok_id_seqs_repeat': tok_id_seqs_repeat_ragged,
                            'text_ids': text_ids,
                            'labels': tf.constant(label_strs),
                            'tmp': tf.constant([[1, 2], [3, 4]], tf.int32),
                        }, y_vecs
                        text_ids, tok_id_seqs = list(), list()
                        tok_id_seqs_repeat, y_vecs = list(), list()
                        label_strs = list()
                        batch_id += 1
                        # y_vecs = list()
                if len(tok_id_seqs) > 0:
                    # tok_id_seq_batch, input_mask = get_padded_bert_input(tok_id_seqs)
                    tok_id_seq_batch = tf.ragged.constant(tok_id_seqs)
                    tok_id_seqs_repeat_ragged = tf.ragged.constant(
                        tok_id_seqs_repeat)
                    # y_vecs_tensor = tf.concat(y_vecs)
                    # yield {'tok_id_seq_batch': tok_id_seq_batch, 'input_mask': input_mask}, y_vecs
                    yield {
                        'batch_id': batch_id,
                        'tok_id_seq_batch': tok_id_seq_batch,
                        'tok_id_seqs_repeat': tok_id_seqs_repeat_ragged,
                        'text_ids': text_ids,
                        'labels': tf.constant(label_strs),
                        'tmp': tf.constant([[1, 2], [3, 4]], tf.int32),
                    }, y_vecs

        # for v in iter(tok_id_seq_gen()):
        #     print(v)
        dataset = tf.data.Dataset.from_generator(
            tok_id_seq_gen,
            output_signature=(
                {
                    'batch_id':
                    tf.TensorSpec(shape=None, dtype=tf.int32),
                    'tok_id_seq_batch':
                    tf.RaggedTensorSpec(dtype=tf.int32, ragged_rank=1),
                    'tok_id_seqs_repeat':
                    tf.RaggedTensorSpec(dtype=tf.int32, ragged_rank=1),
                    'text_ids':
                    tf.TensorSpec(shape=None, dtype=tf.int32),
                    'labels':
                    tf.TensorSpec(shape=None, dtype=tf.string),
                    'tmp':
                    tf.TensorSpec(shape=None, dtype=tf.int32),
                    # 'block_emb': tf.TensorSpec(shape=block_emb_shape, dtype=tf.float32)},
                },
                tf.TensorSpec(shape=None, dtype=tf.float32)))

        return dataset
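The core from_generator pattern of this snippet, reduced to a self-contained sketch (toy generator, illustrative names):

def toy_gen():
    yield tf.ragged.constant([[1, 2, 3], [4]], dtype=tf.int32)
    yield tf.ragged.constant([[5], [6, 7]], dtype=tf.int32)

toy_dataset = tf.data.Dataset.from_generator(
    toy_gen,
    output_signature=tf.RaggedTensorSpec(shape=[2, None], dtype=tf.int32, ragged_rank=1))
for elem in toy_dataset:
    print(elem)  # each element is a 2-row tf.RaggedTensor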
Example #17
from tensorflow_transform import test_case

_TEST_BATCH_SIZES = [1, 10]
_TEST_DTYPES = [
    tf.int16,
    tf.int32,
    tf.int64,
    tf.float32,
    tf.float64,
    tf.string,
]

_TEST_TENSORS_TYPES = [
    (lambda dtype: tf.TensorSpec([None], dtype=dtype), tf.Tensor, []),
    (lambda dtype: tf.TensorSpec([None, 2], dtype=dtype), tf.Tensor, [2]),
    (lambda dtype: tf.RaggedTensorSpec([None, None], dtype=dtype),
     tf.RaggedTensor, [None]),
    (
        lambda dtype: tf.RaggedTensorSpec(  # pylint: disable=g-long-lambda
            [None, None, 2],
            dtype=dtype,
            ragged_rank=1),
        tf.RaggedTensor,
        [None, 2]),
]


class TF2UtilsTest(test_case.TransformTestCase):
    def test_strip_and_get_tensors_and_control_dependencies(self):
        @tf.function(input_signature=[tf.TensorSpec([], dtype=tf.int64)])
        def func(x):
Example #18
_RAGGED_TENSOR_TEST_CASES = (_MakeRaggedTensorDTypesTestCases() + [
    dict(testcase_name="Simple",
         tensor_representation_textpb="""
        ragged_tensor {
          feature_path {
            step: "ragged_feature"
          }
          row_partition_dtype: INT32
        }
        """,
         record_batch=pa.RecordBatch.from_arrays([
             pa.array([[1], None, [2], [3, 4, 5], []],
                      type=pa.list_(pa.int64()))
         ], ["ragged_feature"]),
         expected_type_spec=tf.RaggedTensorSpec(tf.TensorShape([None, None]),
                                                tf.int64,
                                                ragged_rank=1,
                                                row_splits_dtype=tf.int32),
         expected_ragged_tensor=tf.compat.v1.ragged.RaggedTensorValue(
             values=np.asarray([1, 2, 3, 4, 5]),
             row_splits=np.asarray([0, 1, 1, 2, 5, 5]))),
    dict(testcase_name="3D",
         tensor_representation_textpb="""
        ragged_tensor {
          feature_path {
            step: "ragged_feature"
          }
          row_partition_dtype: INT32
        }
        """,
         record_batch=pa.RecordBatch.from_arrays([
             pa.array([[[1]], None, [[2]], [[3, 4], [5]], []],
Example #19
def type_to_tf_structure(type_spec: computation_types.Type):
    """Returns nested `tf.data.experimental.Structure` for a given TFF type.

    Args:
      type_spec: A `computation_types.Type`, the type specification must be
        composed of only named tuples and tensors. In all named tuples that
        appear in the type spec, all the elements must be named.

    Returns:
      An instance of `tf.data.experimental.Structure`, possibly nested, that
      corresponds to `type_spec`.

    Raises:
      ValueError: if the `type_spec` is composed of something other than named
        tuples and tensors, or if any of the elements in named tuples are
        unnamed.
    """
    py_typecheck.check_type(type_spec, computation_types.Type)
    if type_spec.is_tensor():
        return tf.TensorSpec(type_spec.shape, type_spec.dtype)
    elif type_spec.is_struct():
        elements = structure.to_elements(type_spec)
        if not elements:
            return ()
        element_outputs = [(k, type_to_tf_structure(v)) for k, v in elements]
        named = element_outputs[0][0] is not None
        if not all((e[0] is not None) == named for e in element_outputs):
            raise ValueError('Tuple elements inconsistently named.')
        if type_spec.python_container is None:
            if named:
                return collections.OrderedDict(element_outputs)
            else:
                return tuple(v for _, v in element_outputs)
        else:
            container_type = type_spec.python_container
            if (py_typecheck.is_named_tuple(container_type)
                    or py_typecheck.is_attrs(container_type)):
                return container_type(**dict(element_outputs))
            elif container_type is tf.RaggedTensor:
                flat_values = type_spec.flat_values
                nested_row_splits = type_spec.nested_row_splits
                ragged_rank = len(nested_row_splits)
                return tf.RaggedTensorSpec(
                    shape=tf.TensorShape([None] * (ragged_rank + 1)),
                    dtype=flat_values.dtype,
                    ragged_rank=ragged_rank,
                    row_splits_dtype=nested_row_splits[0].dtype,
                    flat_values_spec=None)
            elif container_type is tf.SparseTensor:
                # We can't generally infer the shape from the type of the tensors, but
                # we *can* infer the rank based on the shapes of `indices` or
                # `dense_shape`.
                if (type_spec.indices.shape is not None
                        and type_spec.indices.shape.dims[1] is not None):
                    rank = type_spec.indices.shape.dims[1]
                    shape = tf.TensorShape([None] * rank)
                elif (type_spec.dense_shape.shape is not None
                      and type_spec.dense_shape.shape.dims[0] is not None):
                    rank = type_spec.dense_shape.shape.dims[0]
                    shape = tf.TensorShape([None] * rank)
                else:
                    shape = None
                return tf.SparseTensorSpec(shape=shape,
                                           dtype=type_spec.values.dtype)
            elif named:
                return container_type(element_outputs)
            else:
                return container_type(e if e[0] is not None else e[1]
                                      for e in element_outputs)
    else:
        raise ValueError('Unsupported type {}.'.format(
            py_typecheck.type_string(type(type_spec))))
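A minimal usage sketch of the conversion above, assuming TensorFlow Federated is installed and exposes `tff.TensorType`/`tff.StructType` (exact reprs may differ by TFF version):

import tensorflow as tf
import tensorflow_federated as tff

# A tensor type maps to a plain tf.TensorSpec.
print(type_to_tf_structure(tff.TensorType(tf.int32, shape=[None])))
# -> TensorSpec(shape=(None,), dtype=tf.int32, name=None)

# A fully named struct with no Python container maps to an OrderedDict of specs.
print(type_to_tf_structure(tff.StructType([('a', tff.TensorType(tf.float32))])))
# -> OrderedDict([('a', TensorSpec(shape=(), dtype=tf.float32, name=None))])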
Example #20
0
"""
@author: landon
"""

import tensorflow as tf
import numpy as np
import json
import argparse
import os
import time

HAS_TITLE = [1,1,0,1,1,1,1,1,1]
N_TRACKS = [0,5,5,10,25,25,100,100,1]
IS_RANDOM = [0,0,0,0,0,1,0,1,0]

TENSOR_SPEC = tf.RaggedTensorSpec(tf.TensorShape([4, None]), tf.int32, 1, tf.int64)



def delete_tensor_by_indices(tensor, indices, n_tracks):
    idxs = tf.reshape(indices, (-1, 1))
    # Build a boolean mask that is False at the given indices and True elsewhere.
    mask = ~tf.scatter_nd(indices=idxs, updates=tf.ones_like(indices, dtype=tf.bool), shape=[n_tracks])
    return tf.boolean_mask(tensor, mask)
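
# A quick check of the scatter-based mask trick above (made-up values):
# tracks = tf.constant([10, 11, 12, 13])
# delete_tensor_by_indices(tracks, tf.constant([0, 2]), n_tracks=4)
# -> tf.Tensor([11 13], shape=(2,), dtype=int32)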

@tf.autograph.experimental.do_not_convert
def map_func(x):
    return {
        'track_ids': x[0],
        'title_ids': x[1],
        'n_tracks': x[2][0],
    }
Example #21
0
class UtilsTest(tf.test.TestCase, parameterized.TestCase):
    @parameterized.parameters(
        ((None, 3), tf.int32),
        ((2, ), tf.float64),
        ((2, None), tf.float32),
    )
    def test_tensor_placeholder(self, shape, dtype):
        spec = tf.TensorSpec(shape=shape, dtype=dtype)
        placeholder = utils.placeholder(spec)
        assert tuple(placeholder.shape) == shape
        assert placeholder.dtype == dtype

    @parameterized.parameters(
        ((2, None), tf.int32, 1),
        ((None, None), tf.int64, 1),
        ((None, None, 2), tf.float64, 1),
        ((None, None, None), tf.float64, 2),
    )
    def test_ragged_placeholder(self, shape, dtype, ragged_rank):
        expected = tf.RaggedTensorSpec(shape,
                                       ragged_rank=ragged_rank,
                                       dtype=dtype)
        placeholder = utils.placeholder(expected)
        assert_specs_equal(utils.type_spec(placeholder), expected)
        assert tuple(placeholder.shape) == shape
        assert placeholder.dtype == dtype

    @parameterized.parameters(
        ((None, None), tf.int32),
        ((None, 3), tf.int32),
        ((2, ), tf.float64),
        ((2, None), tf.float32),
    )
    def test_sparse_placeholder(self, shape, dtype):
        spec = tf.SparseTensorSpec(shape, dtype)
        placeholder = utils.placeholder(spec)
        assert tuple(placeholder.shape) == shape
        assert placeholder.dtype == dtype

    # @parameterized.parameters(
    #     (tf.TensorSpec((2,), tf.float64), 3, None, tf.TensorSpec, (3, 2)),
    #     (tf.TensorSpec((None,), tf.float64), 3, False, tf.TensorSpec, (3, None)),
    #     (tf.TensorSpec((None,), tf.float64), None, False, tf.TensorSpec, (None, None)),
    #     (
    #         tf.TensorSpec((None,), tf.float64),
    #         None,
    #         True,
    #         tf.RaggedTensorSpec,
    #         (None, None),
    #     ),
    #     (
    #         tf.RaggedTensorSpec((None, None), tf.float32),
    #         None,
    #         None,
    #         tf.RaggedTensorSpec,
    #         (None, None, None),
    #     ),
    #     (
    #         tf.SparseTensorSpec((2, 3), tf.float64),
    #         None,
    #         None,
    #         tf.SparseTensorSpec,
    #         (None, 2, 3),
    #     ),
    # )
    # def test_batched_spec(self, spec, batch_size, ragged, expected_cls, expected_shape):
    #     actual = utils.batched_spec(spec, batch_size=batch_size, ragged=ragged)
    #     assert tuple(actual._shape) == expected_shape
    #     assert isinstance(actual, expected_cls)
    #     assert actual._dtype == spec._dtype

    @parameterized.parameters(
        (
            tf.keras.Input((3, ), batch_size=2, dtype=tf.float64),
            tf.TensorSpec((2, 3), tf.float64),
        ),
        (
            tf.keras.Input(
                shape=(None, ), batch_size=3, ragged=True, dtype=tf.float64),
            tf.RaggedTensorSpec((3, None), tf.float64),
        ),
        (
            tf.keras.Input(
                shape=(4, ), batch_size=3, sparse=True, dtype=tf.float64),
            tf.SparseTensorSpec((3, 4), tf.float64),
        ),
    )
    def test_type_spec(self, x, expected):
        assert_specs_equal(utils.type_spec(x), expected)
Example #22
0
    def call(self, ground_truth_boxes, anchors):
        # ground_truth_box [n_gt_boxes, box_dim] or [batch_size, n_gt_boxes, box_dim]
        # anchor [n_anchors, box_dim]
        def iou(ground_truth_box, anchor):
            # [n_anchors, 1]
            y_min_anchors, x_min_anchors, y_max_anchors, x_max_anchors = tf.split(
                anchor, num_or_size_splits=4, axis=-1)
            # [n_gt_boxes, 1] or [batch_size, n_gt_boxes, 1]
            y_min_gt, x_min_gt, y_max_gt, x_max_gt = tf.split(
                ground_truth_box, num_or_size_splits=4, axis=-1)
            # [n_anchors]
            anchor_areas = tf.squeeze((y_max_anchors - y_min_anchors) *
                                      (x_max_anchors - x_min_anchors), [1])
            # [n_gt_boxes, 1] or [batch_size, n_gt_boxes, 1]
            gt_areas = (y_max_gt - y_min_gt) * (x_max_gt - x_min_gt)

            # [n_gt_boxes, n_anchors] or [batch_size, n_gt_boxes, n_anchors]
            max_y_min = tf.maximum(y_min_gt, tf.transpose(y_min_anchors))
            min_y_max = tf.minimum(y_max_gt, tf.transpose(y_max_anchors))
            intersect_heights = tf.maximum(
                tf.constant(0, dtype=ground_truth_box.dtype),
                (min_y_max - max_y_min))

            # [n_gt_boxes, n_anchors] or [batch_size, n_gt_boxes, n_anchors]
            max_x_min = tf.maximum(x_min_gt, tf.transpose(x_min_anchors))
            min_x_max = tf.minimum(x_max_gt, tf.transpose(x_max_anchors))
            intersect_widths = tf.maximum(
                tf.constant(0, dtype=ground_truth_box.dtype),
                (min_x_max - max_x_min))

            # [n_gt_boxes, n_anchors] or [batch_size, n_gt_boxes, n_anchors]
            intersections = intersect_heights * intersect_widths

            # [n_gt_boxes, n_anchors] or [batch_size, n_gt_boxes, n_anchors]
            unions = gt_areas + anchor_areas - intersections

            return tf.cast(tf.truediv(intersections, unions), tf.float32)

        if isinstance(ground_truth_boxes, tf.RaggedTensor):
            if anchors.shape.ndims == 2:
                return tf.map_fn(
                    lambda x: iou(x, anchors),
                    elems=ground_truth_boxes,
                    parallel_iterations=32,
                    back_prop=False,
                    fn_output_signature=tf.RaggedTensorSpec(dtype=tf.float32,
                                                            ragged_rank=0),
                )
            else:
                return tf.map_fn(
                    lambda x: iou(x[0], x[1]),
                    elems=[ground_truth_boxes, anchors],
                    parallel_iterations=32,
                    back_prop=False,
                    fn_output_signature=tf.RaggedTensorSpec(dtype=tf.float32,
                                                            ragged_rank=0),
                )
        if anchors.shape.ndims == 2:
            return iou(ground_truth_boxes, anchors)
        elif anchors.shape.ndims == 3:
            return tf.map_fn(
                lambda x: iou(x[0], x[1]),
                elems=[ground_truth_boxes, anchors],
                dtype=tf.float32,
                parallel_iterations=32,
                back_prop=False,
            )
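
# As a sanity check of the IoU arithmetic above, a hand-computed single-pair
# sketch (boxes are [ymin, xmin, ymax, xmax]; values are made up):
gt = tf.constant([[0., 0., 2., 2.]])      # area 4
anchor = tf.constant([[1., 1., 3., 3.]])  # area 4
inter_h = tf.maximum(0., tf.minimum(gt[:, 2], anchor[:, 2]) - tf.maximum(gt[:, 0], anchor[:, 0]))
inter_w = tf.maximum(0., tf.minimum(gt[:, 3], anchor[:, 3]) - tf.maximum(gt[:, 1], anchor[:, 1]))
intersection = inter_h * inter_w          # 1.0
union = 4. + 4. - intersection            # 7.0
print(float(intersection / union))        # ~0.1429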
Example #23
0
class TensorToArrowTest(tf.test.TestCase, parameterized.TestCase):
    def _assert_tensor_alike_equal(self, left, right):
        self.assertIsInstance(left, type(right))
        if isinstance(left, (tf.SparseTensor, tf.compat.v1.SparseTensorValue)):
            self.assertAllEqual(left.values, right.values)
            self.assertAllEqual(left.indices, right.indices)
            self.assertAllEqual(left.dense_shape, right.dense_shape)
        else:
            self.assertAllEqual(left, right)

    @parameterized.named_parameters(*_CONVERT_TEST_CASES)
    def test_convert(
        self,
        type_specs,
        expected_schema,
        expected_tensor_representations,
        tensor_input,
        expected_record_batch,
        options=tensor_to_arrow.TensorsToRecordBatchConverter.Options()):
        def convert_and_check(tensors, test_values_conversion):
            converter = tensor_to_arrow.TensorsToRecordBatchConverter(
                type_specs, options)

            self.assertEqual(
                {f.name: f.type
                 for f in converter.arrow_schema()}, expected_schema,
                "actual: {}".format(converter.arrow_schema()))

            canonical_expected_tensor_representations = {}
            for n, r in expected_tensor_representations.items():
                if not isinstance(r, schema_pb2.TensorRepresentation):
                    r = text_format.Parse(r, schema_pb2.TensorRepresentation())
                canonical_expected_tensor_representations[n] = r

            self.assertEqual(canonical_expected_tensor_representations,
                             converter.tensor_representations())

            rb = converter.convert(tensors)
            self.assertLen(expected_record_batch, rb.num_columns)
            for i, column in enumerate(rb):
                expected = expected_record_batch[rb.schema[i].name]
                self.assertTrue(
                    column.equals(expected),
                    "{}: actual: {}, expected: {}".format(
                        rb.schema[i].name, column, expected))
            # Test that TensorAdapter(TensorsToRecordBatchConverter()) is identity.
            adapter = tensor_adapter.TensorAdapter(
                tensor_adapter.TensorAdapterConfig(
                    arrow_schema=converter.arrow_schema(),
                    tensor_representations=converter.tensor_representations()))
            adapter_output = adapter.ToBatchTensors(
                rb, produce_eager_tensors=not test_values_conversion)
            self.assertEqual(adapter_output.keys(), tensors.keys())
            for k in adapter_output.keys():
                if "value" not in k:
                    self._assert_tensor_alike_equal(adapter_output[k],
                                                    tensors[k])

        def ragged_tensor_to_value(tensor):
            if isinstance(tensor, tf.RaggedTensor):
                values = tensor.values
                return tf.compat.v1.ragged.RaggedTensorValue(
                    values=ragged_tensor_to_value(values),
                    row_splits=tensor.row_splits.numpy())
            else:
                return tensor.numpy()

        def convert_eager_to_value(tensor):
            if isinstance(tensor, tf.SparseTensor):
                return tf.compat.v1.SparseTensorValue(tensor.indices,
                                                      tensor.values,
                                                      tensor.dense_shape)
            elif isinstance(tensor, tf.Tensor):
                return tensor.numpy()
            elif isinstance(tensor, tf.RaggedTensor):
                return ragged_tensor_to_value(tensor)
            else:
                raise NotImplementedError(
                    "Only support converting SparseTensors, Tensors and RaggedTensors. "
                    "Got: {}".format(type(tensor)))

        if tf.__version__ >= "2":
            convert_and_check(tensor_input, test_values_conversion=False)

        if tf.executing_eagerly():
            values_input = {
                k: convert_eager_to_value(v)
                for k, v in tensor_input.items()
            }
        else:
            some_tensor = next(iter(tensor_input.values()))
            graph = (some_tensor.row_splits[0].graph if isinstance(
                some_tensor, tf.RaggedTensor) else some_tensor.graph)
            with tf.compat.v1.Session(graph=graph) as s:
                values_input = s.run(tensor_input)
        convert_and_check(values_input, test_values_conversion=True)

    def test_relaxed_varlen_sparse_tensor(self):
        # Demonstrates that TensorAdapter(TensorsToRecordBatchConverter()) is not
        # an identity if the second dense dimension of SparseTensor is not tight.
        type_specs = {"sp": tf.SparseTensorSpec([None, None], tf.int32)}
        sp = tf.compat.v1.SparseTensorValue(values=np.array([1, 2], np.int32),
                                            indices=[[0, 0], [2, 0]],
                                            dense_shape=[4, 2])
        if tf.__version__ >= "2":
            sp = tf.SparseTensor.from_value(sp)
        converter = tensor_to_arrow.TensorsToRecordBatchConverter(type_specs)
        rb = converter.convert({"sp": sp})
        adapter = tensor_adapter.TensorAdapter(
            tensor_adapter.TensorAdapterConfig(
                arrow_schema=converter.arrow_schema(),
                tensor_representations=converter.tensor_representations()))
        adapter_output = adapter.ToBatchTensors(
            rb, produce_eager_tensors=tf.__version__ >= "2")
        self.assertAllEqual(sp.values, adapter_output["sp"].values)
        self.assertAllEqual(sp.indices, adapter_output["sp"].indices)
        self.assertAllEqual(adapter_output["sp"].dense_shape, [4, 1])

    def test_unable_to_handle(self):
        with self.assertRaisesRegex(ValueError, "No handler found"):
            tensor_to_arrow.TensorsToRecordBatchConverter(
                {"sp": tf.SparseTensorSpec([None, None, None], tf.int32)})

        with self.assertRaisesRegex(ValueError, "No handler found"):
            tensor_to_arrow.TensorsToRecordBatchConverter(
                {"sp": tf.SparseTensorSpec([None, None], tf.bool)})

    def test_incompatible_type_spec(self):
        converter = tensor_to_arrow.TensorsToRecordBatchConverter(
            {"sp": tf.SparseTensorSpec([None, None], tf.int32)})
        sp_cls = (tf.SparseTensor
                  if tf.__version__ >= "2" else tf.compat.v1.SparseTensorValue)
        with self.assertRaisesRegex(TypeError, "Expected SparseTensorSpec"):
            converter.convert({
                "sp":
                sp_cls(indices=[[0, 1]],
                       values=tf.constant([0], dtype=tf.int64),
                       dense_shape=[4, 1])
            })

    @parameterized.named_parameters(*[
        dict(testcase_name="bool_value_type",
             spec=tf.RaggedTensorSpec(shape=[2, None, None],
                                      dtype=tf.bool,
                                      ragged_rank=2,
                                      row_splits_dtype=tf.int64)),
        dict(testcase_name="2d_leaf_value",
             spec=tf.RaggedTensorSpec(shape=[2, None, None],
                                      dtype=tf.int32,
                                      ragged_rank=1,
                                      row_splits_dtype=tf.int64)),
        dict(testcase_name="ragged_rank_less_than_one",
             spec=tf.RaggedTensorSpec(shape=[2],
                                      dtype=tf.int32,
                                      ragged_rank=0,
                                      row_splits_dtype=tf.int64)),
    ])
    def test_unable_to_handle_ragged(self, spec):
        with self.assertRaisesRegex(ValueError, "No handler found"):
            tensor_to_arrow.TensorsToRecordBatchConverter({"rt": spec})
Example #24
0
def _build_tree_from_leaf(leaf_nodes: tf.Tensor, arity: int) -> tf.RaggedTensor:
  """A function constructs a complete tree given all the leaf nodes.

  The function takes a 1-D array representing the leaf nodes of a tree and the
  tree's arity, and constructs a complete tree by recursively summing the
  adjacent children to get the parent until reaching the root node. Because we
  assume a complete tree, if the number of leaf nodes does not divide arity, the
  leaf nodes will be padded with zeros.

  Args:
    leaf_nodes: A 1-D array storing the leaf nodes of the tree.
    arity: A `int` for the branching factor of the tree, i.e. the number of
      children for each internal node.

  Returns:
    `tf.RaggedTensor` representing the tree. For example, if
    `leaf_nodes=tf.Tensor([1, 2, 3, 4])` and `arity=2`, then the returned value
    should be `tree=tf.RaggedTensor([[10],[3,7],[1,2,3,4]])`. In this way,
    `tree[layer][index]` can be used to access the node indexed by (layer,
    index) in the tree,
  """

  def pad_zero(leaf_nodes, size):
    # `size` arrives as a float32 tensor (it comes from tf.math.pow on floats),
    # so cast it before using it in a shape.
    paddings = tf.zeros(
        shape=(tf.cast(size, tf.int32) - leaf_nodes.shape[0],),
        dtype=leaf_nodes.dtype)
    return tf.concat((leaf_nodes, paddings), axis=0)

  leaf_nodes_size = tf.constant(leaf_nodes.shape[0], dtype=tf.float32)
  num_layers = tf.math.ceil(
      tf.math.log(leaf_nodes_size) /
      tf.math.log(tf.cast(arity, dtype=tf.float32))) + 1
  leaf_nodes = pad_zero(
      leaf_nodes, tf.math.pow(tf.cast(arity, dtype=tf.float32), num_layers - 1))

  def _shrink_layer(layer: tf.Tensor, arity: int) -> tf.Tensor:
    return tf.reduce_sum((tf.reshape(layer, (-1, arity))), 1)

  # The following `tf.while_loop` constructs the tree from bottom up by
  # iteratively applying `_shrink_layer` to each layer of the tree. The reason
  # for the choice of TF1.0-style `tf.while_loop` is that @tf.function does not
  # support auto-translation from python loop to tf loop when loop variables
  # contain a `RaggedTensor` whose shape changes across iterations.

  idx = tf.identity(num_layers)
  loop_cond = lambda i, h: tf.less_equal(2.0, i)

  def _loop_body(i, h):
    return [
        tf.add(i, -1.0),
        tf.concat(([_shrink_layer(h[0], arity)], h), axis=0)
    ]

  _, tree = tf.while_loop(
      loop_cond,
      _loop_body, [idx, tf.RaggedTensor.from_tensor([leaf_nodes])],
      shape_invariants=[
          idx.get_shape(),
          tf.RaggedTensorSpec(dtype=leaf_nodes.dtype, ragged_rank=1)
      ])

  return tree
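
# Reproducing the docstring example above (a sketch; assumes eager execution):
tree = _build_tree_from_leaf(tf.constant([1., 2., 3., 4.]), arity=2)
print(tree)  # expected: [[10.0], [3.0, 7.0], [1.0, 2.0, 3.0, 4.0]]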
Example #25
0
# # all_vals = get_ragged_vals(5)
# for length in lens:
#     vals = list()
#     for _ in range(length):
#         vals.append(random.uniform(-1, 1))
#     blocks_list.append(vals)
blocks_list = [[3, 4], [5, 3, 8], [8]]


def data_gen():
    for vals in blocks_list:
        yield tf.ragged.constant([vals], dtype=tf.float32)

dataset = tf.data.Dataset.from_generator(
    data_gen,
    output_signature=tf.RaggedTensorSpec(ragged_rank=1, dtype=tf.float32))

for v in dataset:
    print(v)
# dataset = dataset.apply(
#     tf.data.experimental.dense_to_ragged_batch(2))
# for v in dataset.batch(4):
#     print(v)
#     # print(v[1].to_tensor())
#     vt = v[1].to_tensor()
#     print(vt)
#     print(tf.squeeze(vt))
# for v in dataset:
#     print(v)

# tf.data.experimental.save(dataset, '/data/hldai/data/tmp/tmp.tfdata', shard_func=lambda x: np.int64(0))
Example #26
0
def input_fn():
    import json
    from locbert import tokenization

    batch_size = 4
    data_file = '/data/hldai/data/ultrafine/uf_data/crowd/test.json'
    type_vocab_file = '/data/hldai/data/ultrafine/uf_data/ontology/types.txt'
    reader_module_path = '/data/hldai/data/realm_data/cc_news_pretrained/bert'
    vocab_file = os.path.join(reader_module_path, 'assets/vocab.txt')
    tokenizer = tokenization.FullTokenizer(vocab_file, do_lower_case=True)

    types, type_id_dict = datautils.load_vocab_file(type_vocab_file)
    n_types = len(types)

    # texts = ['He is a teacher.',
    #          'He teaches his students.',
    #          'He is a lawyer.']
    texts = list()

    all_labels = list()
    with open(data_file, encoding='utf-8') as f:
        for i, line in enumerate(f):
            x = json.loads(line)
            text = '{} {} {}'.format(
                ' '.join(x['left_context_token']), x['mention_span'], ' '.join(x['right_context_token']))
            # print(text)
            texts.append(text)
            labels = x['y_str']
            tids = [type_id_dict.get(t, -1) for t in labels]
            tids = [tid for tid in tids if tid > -1]
            # if i > 5:
            all_labels.append(tids)
            if len(texts) >= 8:
                break
    print(len(texts), 'texts')

    def tok_id_seq_gen():
        tok_id_seqs = list()
        y_vecs = list()
        for i, text in enumerate(texts):
            tokens = tokenizer.tokenize(text)
            # print(tokens)
            tokens_full = ['[CLS]'] + tokens + ['[SEP]']
            tok_id_seq = tokenizer.convert_tokens_to_ids(tokens_full)
            # tok_id_seq = np.array([len(text)], np.float32)
            tok_id_seqs.append(tok_id_seq)
            y_vecs.append(to_one_hot(all_labels[i], n_types))
            if len(tok_id_seqs) >= batch_size:
                # tok_id_seq_batch, input_mask = get_padded_bert_input(tok_id_seqs)
                tok_id_seq_batch = tf.ragged.constant(tok_id_seqs)
                # y_vecs_tensor = tf.concat(y_vecs)
                yield {'tok_id_seq_batch': tok_id_seq_batch,
                       # 'input_mask': input_mask,
                       'vals': np.random.uniform(-1, 1, (3, 5))}, y_vecs
                tok_id_seqs = list()
                y_vecs = list()
        if len(tok_id_seqs) > 0:
            # tok_id_seq_batch, input_mask = get_padded_bert_input(tok_id_seqs)
            # y_vecs_tensor = tf.concat(y_vecs)
            tok_id_seq_batch = tf.ragged.constant(tok_id_seqs)
            yield {'tok_id_seq_batch': tok_id_seq_batch,
                   # 'input_mask': input_mask,
                   'vals': np.random.uniform(-1, 1, (3, 5))}, y_vecs

    # for v in iter(tok_id_seq_gen()):
    #     print(v)
    dataset = tf.data.Dataset.from_generator(
        tok_id_seq_gen,
        output_signature=(
            {
                'tok_id_seq_batch': tf.RaggedTensorSpec(dtype=tf.int32, ragged_rank=1),
                # 'tok_id_seq_batch': tf.TensorSpec(shape=None, dtype=tf.int32),
                # 'input_mask': tf.TensorSpec(shape=None, dtype=tf.int32),
                'vals': tf.TensorSpec(shape=None, dtype=tf.float32)
            },
            tf.TensorSpec(shape=None, dtype=tf.float32)))

    return dataset
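
# A hedged sketch of consuming the dataset built above (the file paths inside
# input_fn are machine-specific, so this only runs where they exist):
# dataset = input_fn()
# for features, y_vecs in dataset.take(1):
#     print(features['tok_id_seq_batch'].bounding_shape())  # ragged: [batch, max_seq_len]
#     print(features['vals'].shape, y_vecs.shape)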
Example #27
0
class Pose3dEstimator(tf.Module):
    def __init__(self, model_path, antialias_factor=4):
        super().__init__()
        self.antialias_factor = antialias_factor
        self.crop_model = tf.saved_model.load(model_path)
        self.crop_side = 256
        self.joint_names = self.crop_model.joint_names
        self.joint_edges = self.crop_model.joint_edges
        joint_names = [b.decode('utf8') for b in self.joint_names.numpy()]
        self.joint_info = data.datasets3d.JointInfo(joint_names, self.joint_edges.numpy())

        self.__call__.get_concrete_function(
            tf.TensorSpec(shape=(None, None, None, 3), dtype=tf.uint8),
            tf.TensorSpec(shape=(None, 3, 3), dtype=tf.float32),
            tf.RaggedTensorSpec(shape=(None, None, 4), ragged_rank=1, dtype=tf.float32),
            tf.TensorSpec(shape=(), dtype=tf.int32),
            tf.TensorSpec(shape=(), dtype=tf.int32))

        self.__call__.get_concrete_function(
            tf.TensorSpec(shape=(None, None, 3), dtype=tf.uint8),
            tf.TensorSpec(shape=(3, 3), dtype=tf.float32),
            tf.TensorSpec(shape=(None, 4), dtype=tf.float32),
            tf.TensorSpec(shape=(), dtype=tf.int32),
            tf.TensorSpec(shape=(), dtype=tf.int32))

    @tf.function
    def __call__(self, image, intrinsic_matrix, boxes, internal_batch_size=64, n_aug=5):
        if image.shape.rank == 3:
            return self.predict_single_image(
                image, intrinsic_matrix, boxes, internal_batch_size, n_aug)
        else:
            return self.predict_multi_image(
                image, intrinsic_matrix, boxes, internal_batch_size, n_aug)

    @tf.function(input_signature=[
        tf.TensorSpec(shape=(None, None, 3), dtype=tf.uint8),
        tf.TensorSpec(shape=(3, 3), dtype=tf.float32),
        tf.TensorSpec(shape=(None, 4), dtype=tf.float32),
        tf.TensorSpec(shape=(), dtype=tf.int32),
        tf.TensorSpec(shape=(), dtype=tf.int32)])
    def predict_single_image(self, image, intrinsic_matrix, boxes, internal_batch_size=64, n_aug=5):
        if tf.size(boxes) == 0:
            return tf.zeros(shape=(0, self.joint_info.n_joints, 3))
        ragged_boxes = tf.RaggedTensor.from_tensor(boxes[np.newaxis])
        return self.predict_multi_image(
            image[np.newaxis], intrinsic_matrix[np.newaxis], ragged_boxes, internal_batch_size,
            n_aug)[0]

    @tf.function(input_signature=[
        tf.TensorSpec(shape=(None, None, None, 3), dtype=tf.uint8),
        tf.TensorSpec(shape=(None, 3, 3), dtype=tf.float32),
        tf.RaggedTensorSpec(shape=(None, None, 4), ragged_rank=1, dtype=tf.float32),
        tf.TensorSpec(shape=(), dtype=tf.int32),
        tf.TensorSpec(shape=(), dtype=tf.int32)])
    def predict_multi_image(self, image, intrinsic_matrix, boxes, internal_batch_size=64, n_aug=5):
        """Estimate 3D human poses in camera space for multiple bounding boxes specified
        for an image.
        """
        n_images = tf.shape(image)[0]
        if tf.size(boxes) == 0:
            # Special case for zero boxes provided
            result_flat = tf.zeros(shape=(0, self.joint_info.n_joints, 3))
            return tf.RaggedTensor.from_row_lengths(result_flat, tf.zeros(n_images, tf.int64))

        boxes_flat = boxes.flat_values
        n_box_per_image = boxes.row_lengths()
        image_id_per_box = boxes.value_rowids()
        n_total_boxes = tf.shape(boxes_flat)[0]
        # image = (tf.cast(image, tf.float32) / np.float32(255))# ** 2.2

        if tf.shape(intrinsic_matrix)[0] == 1:
            intrinsic_matrix = tf.repeat(intrinsic_matrix, n_images, axis=0)

        Ks = tf.repeat(intrinsic_matrix, n_box_per_image, axis=0)
        gammas = tf.cast(tf.linspace(0.6, 1.0, n_aug), tf.float16)
        angle_range = np.float32(np.deg2rad(25))
        angles = tf.linspace(-angle_range, angle_range, n_aug)
        scales = tf.concat([
            linspace_noend(0.8, 1.0, n_aug // 2),
            tf.linspace(1.0, 1.1, n_aug - n_aug // 2)], axis=0)
        flips = (tf.range(n_aug) - n_aug // 2) % 2 != 0
        flipmat = tf.constant([[-1, 0, 0], [0, 1, 0], [0, 0, 1]], np.float32)
        maybe_flip = tf.where(flips[:, np.newaxis, np.newaxis], flipmat, tf.eye(3))
        rotmat = rotation_mat_zaxis(-angles)

        crops_per_batch = internal_batch_size // n_aug

        if crops_per_batch == 0:
            # No batching
            results_flat = self.predict_single_batch(
                image, Ks, boxes_flat, image_id_per_box, n_aug, rotmat, maybe_flip, flips, scales,
                gammas)
            return tf.RaggedTensor.from_row_lengths(results_flat, n_box_per_image)

        n_batches = tf.cast(tf.math.ceil(n_total_boxes / crops_per_batch), tf.int32)
        result_batches = tf.TensorArray(
            tf.float32, size=n_batches, element_shape=(None, self.joint_info.n_joints, 3),
            infer_shape=False)

        for i in tf.range(n_batches):
            box_batch = boxes_flat[i * crops_per_batch:(i + 1) * crops_per_batch]
            image_ids = image_id_per_box[i * crops_per_batch:(i + 1) * crops_per_batch]
            K_batch = Ks[i * crops_per_batch:(i + 1) * crops_per_batch]
            poses = self.predict_single_batch(
                image, K_batch, box_batch, image_ids, n_aug, rotmat, maybe_flip, flips, scales,
                gammas)
            result_batches = result_batches.write(i, poses)

        results_flat = result_batches.concat()
        return tf.RaggedTensor.from_row_lengths(results_flat, n_box_per_image)

    def predict_single_batch(
            self, images, K, boxes, image_ids, n_aug, rotmat, maybe_flip, flips, scales, gammas):
        n_box = tf.shape(boxes)[0]
        center_points = boxes[:, :2] + boxes[:, 2:4] / 2
        box_center_camspace = transf(center_points - K[:, :2, 2], tf.linalg.inv(K[:, :2, :2]))
        box_center_camspace = tf.concat(
            [box_center_camspace, tf.ones_like(box_center_camspace[:, :1])], axis=1)

        new_z = box_center_camspace / tf.linalg.norm(box_center_camspace, axis=-1, keepdims=True)
        new_x = tf.stack([new_z[:, 2], tf.zeros_like(new_z[:, 2]), -new_z[:, 0]], axis=1)
        new_y = tf.linalg.cross(new_z, new_x)
        nonaug_R = tf.stack([new_x, new_y, new_z], axis=1)
        new_R = maybe_flip[:, np.newaxis] @ rotmat[:, np.newaxis] @ nonaug_R
        box_scales = self.crop_side / tf.reduce_max(boxes[:, 2:4], axis=-1)
        new_K_mid = (tf.reshape(scales, [-1, 1, 1, 1]) *
                     tf.reshape(box_scales, [1, -1, 1, 1]) *
                     tf.reshape(K[:, :2, :2], [1, -1, 2, 2]))
        intrinsic_matrix = tf.concat([
            tf.concat([new_K_mid, tf.fill((n_aug, n_box, 2, 1), self.crop_side / 2)], axis=3),
            tf.concat([tf.zeros((n_aug, n_box, 1, 2), tf.float32),
                       tf.ones((n_aug, n_box, 1, 1), tf.float32)], axis=3)], axis=2)
        new_proj_matrix = intrinsic_matrix @ new_R
        homography = K @ tf.linalg.inv(new_proj_matrix)
        intrinsic_matrix_flat = tf.reshape(intrinsic_matrix, [n_aug * n_box, 3, 3])
        homography = tf.reshape(homography, [n_aug, n_box, 3, 3])
        homography = tf.reshape(homography, [-1, 9])
        homography = homography[:, :8] / homography[:, 8:]

        if self.antialias_factor > 1:
            H = homography
            a = self.antialias_factor
            homography = tf.stack([H[:, 0] / a, H[:, 1] / a, H[:, 2] - (a - 1) / 2,
                                   H[:, 3] / a, H[:, 4] / a, H[:, 5] - (a - 1) / 2,
                                   H[:, 6] / a, H[:, 7] / a], axis=1)

        temp_side = self.crop_side * self.antialias_factor
        image_ids = tf.tile(image_ids, [n_aug])
        crops = perspective_transform(
            images, homography, (temp_side, temp_side), 'BILINEAR', image_ids)

        crops = tf.cast(crops, tf.float16) / 255

        if self.antialias_factor > 1:
            crops = tf.image.resize(
                crops, (self.crop_side, self.crop_side), method=tf.image.ResizeMethod.AREA,
                antialias=True)
        crops = tf.reshape(crops, [n_aug, n_box * self.crop_side, self.crop_side, 3])
        crops **= tf.reshape(gammas, [-1, 1, 1, 1])
        crops = tf.reshape(crops, [-1, self.crop_side, self.crop_side, 3])

        poses = self.crop_model(crops, intrinsic_matrix_flat)
        poses = tf.reshape(poses, [n_aug, -1, tf.shape(poses)[1], 3])
        left_right_swapped = tf.gather(poses, self.joint_info.mirror_mapping, axis=2)
        poses = tf.where(tf.reshape(flips, [-1, 1, 1, 1]), left_right_swapped, poses)
        poses_origspace = tf.einsum('...nk,...jk->...nj', poses, tf.linalg.inv(new_R))
        return tf.reduce_mean(poses_origspace, axis=0)
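
# A hedged usage sketch; the SavedModel path and box values below are placeholders,
# and boxes follow the [x, y, width, height] convention implied by the center
# computation in predict_single_batch:
estimator = Pose3dEstimator('/path/to/crop_model')  # hypothetical path
image = tf.zeros([720, 1280, 3], tf.uint8)
intrinsics = tf.constant([[1000., 0., 640.], [0., 1000., 360.], [0., 0., 1.]])
boxes = tf.constant([[300., 200., 200., 400.]])
poses = estimator(image, intrinsics, boxes)  # shape [n_boxes, n_joints, 3], camera space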
Example #28
0
class Prediction(tf.Module):

    NOT_FOUND_INDEX = -1

    def __init__(self, model=None):

        self.item_labels_path = Labels.item_labels_path()
        self.customer_labels_path = Labels.customer_labels_path()

        if model:
            self.model: tf.keras.Model = model
        else:
            self.model: tf.keras.Model = tf.keras.models.load_model(
                settings.get_model_path('exported_model'))
            #print(">>>", self.model)
            self.model.summary()

        # Lookup table: Customer label -> Customer index
        self.customer_labels_lookup = tf.lookup.StaticHashTable(
            tf.lookup.TextFileInitializer(self.customer_labels_path,
                                          tf.string,
                                          tf.lookup.TextFileIndex.WHOLE_LINE,
                                          tf.int64,
                                          tf.lookup.TextFileIndex.LINE_NUMBER,
                                          delimiter=" "),
            Prediction.NOT_FOUND_INDEX)

        # Lookup table: Item label -> Item index
        self.item_labels_lookup = tf.lookup.StaticHashTable(
            tf.lookup.TextFileInitializer(self.item_labels_path,
                                          tf.string,
                                          tf.lookup.TextFileIndex.WHOLE_LINE,
                                          tf.int64,
                                          tf.lookup.TextFileIndex.LINE_NUMBER,
                                          delimiter=" "),
            Prediction.NOT_FOUND_INDEX)

        # Reverse lookup tables (item index -> item string label)
        self.item_indices_lookup = tf.lookup.StaticHashTable(
            tf.lookup.TextFileInitializer(self.item_labels_path,
                                          tf.int64,
                                          tf.lookup.TextFileIndex.LINE_NUMBER,
                                          tf.string,
                                          tf.lookup.TextFileIndex.WHOLE_LINE,
                                          delimiter=" "), "")

    @tf.function(
        input_signature=[tf.TensorSpec(shape=[None, None], dtype=tf.int32)])
    def post_process_items(self, batch_items_indices: tf.Tensor) -> tf.Tensor:

        # Do lookup
        return self.item_indices_lookup.lookup(
            tf.cast(batch_items_indices, tf.int64))

    @tf.function(input_signature=[
        tf.TensorSpec(shape=[None, None], dtype=tf.float32),
        tf.TensorSpec(shape=[], dtype=tf.int64)
    ])
    def _top_predictions_tensor(self, results, n_results):

        # Get most probable item indices
        sorted_indices = tf.argsort(results, direction='DESCENDING')
        top_indices = sorted_indices[:, 0:n_results]
        top_probabilities = tf.gather(results, top_indices, batch_dims=1)

        # Convert item indices to item labels
        #print("top_indices", top_indices)
        top_item_labels = self.post_process_items(top_indices)

        return top_item_labels, top_probabilities

    @staticmethod
    @tf.function(input_signature=[
        tf.RaggedTensorSpec(shape=[None, None], dtype=tf.int64),
        tf.TensorSpec(shape=[None, None], dtype=tf.float32)
    ])
    def _remove_input_items_from_prediction(batch_item_indices, result):
        # batch_item_indices is a ragged tensor with input indices for each batch row. Ex [ [0] , [1, 2] ]
        batch_item_indices = batch_item_indices.to_tensor(
            -1)  # -> [ [0,-1] , [1, 2] ]
        #print(batch_item_indices, batch_item_indices.shape[0])

        # Convert batch_item_indices row indices to (row,column) indices
        row_indices = tf.range(0,
                               tf.shape(batch_item_indices)[0],
                               dtype=tf.int64)  # -> [ 0, 1 ]
        row_indices = tf.repeat(
            row_indices,
            [tf.shape(batch_item_indices)[1]])  # -> [ 0, 0, 1, 1 ]
        #print(">>>", batch_item_indices)
        batch_item_indices = tf.reshape(batch_item_indices,
                                        shape=[-1])  # -> [ 0, -1, 1, 2 ]
        batch_item_indices = tf.stack(
            [row_indices, batch_item_indices],
            axis=1)  # -> [ [0,0] , [0,-1], [1,1], [1,2] ]

        # batch_item_indices.to_tensor(-1) added -1's to pad the matrix. Remove these indices
        # Needed according to tf.tensor_scatter_nd_update doc. (it will fail in CPU execution, if there are out of bound indices)
        # Get indices without -1's:
        gather_idxs = tf.where(
            batch_item_indices[:, 1] != -1)  # -> [[0], [2], [3]]
        batch_item_indices = tf.gather_nd(
            batch_item_indices, gather_idxs)  # -> [ [0,0] , [1,1], [1,2] ]

        # To remove input indices, we will set a probability -1 in their indices
        updates = tf.repeat(
            -1.0,
            tf.shape(batch_item_indices)[0])  # -> [ -1, -1, -1 ]

        # Assign -1's to the input indices:
        return tf.tensor_scatter_nd_update(result, batch_item_indices, updates)
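
    # Worked example of the masking above (made-up values):
    #   batch_item_indices = tf.ragged.constant([[0], [1, 2]], dtype=tf.int64)
    #   result = tf.zeros([2, 3])
    #   _remove_input_items_from_prediction(batch_item_indices, result)
    #   -> [[-1., 0., 0.], [0., -1., -1.]]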

    @staticmethod
    @tf.function(input_signature=[
        tf.RaggedTensorSpec(shape=[None, None], dtype=tf.int64),
        tf.TensorSpec(shape=(), dtype=tf.int64)
    ])
    def count_non_equal(batch: tf.RaggedTensor, value: tf.Tensor) -> tf.Tensor:
        """ Returns the count of elements in 'batch'' distincts to 'value' on each batch row """
        elements_equal_to_value = tf.not_equal(batch, value)
        as_ints = tf.cast(elements_equal_to_value, tf.int64)
        count = tf.reduce_sum(as_ints, axis=1)
        return count
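
    # Hedged example:
    #   count_non_equal(tf.ragged.constant([[1, -1], [2, 3, -1]], dtype=tf.int64),
    #                   tf.constant(-1, tf.int64))  ->  [1, 2]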

    @staticmethod
    @tf.function(input_signature=[
        tf.RaggedTensorSpec(shape=[None, None], dtype=tf.int64),
        tf.TensorSpec(shape=(), dtype=tf.int64)
    ])
    def remove_not_found_index(batch_item_indices: tf.RaggedTensor,
                               not_found_index: tf.Tensor) -> tf.RaggedTensor:
        # Count non -1's on each row
        found_counts_per_row = Prediction.count_non_equal(
            batch_item_indices, not_found_index)
        #print("found_counts_per_row", found_counts_per_row)

        # Get non -1 values batch_item_indices from flat values
        flat_values = batch_item_indices.flat_values
        mask = tf.not_equal(flat_values, not_found_index)
        #print("mask", mask )
        flat_found_indices = tf.boolean_mask(flat_values, mask)
        #print("flat_found_indices", flat_found_indices )
        return tf.RaggedTensor.from_row_lengths(flat_found_indices,
                                                found_counts_per_row)

    @tf.function(input_signature=[
        tf.RaggedTensorSpec(shape=[None, None], dtype=tf.string)
    ])
    def preprocess_items(self, batch_item_labels: tf.RaggedTensor):
        #print("batch_item_labels ->", batch_item_labels)

        # Do lookups item label -> index, -1 if not found
        batch_item_indices = tf.ragged.map_flat_values(
            self.item_labels_lookup.lookup, batch_item_labels)
        #print( "batch_item_indices", batch_item_indices )

        # Remove -1's:
        batch_item_indices = Prediction.remove_not_found_index(
            batch_item_indices, Prediction.NOT_FOUND_INDEX)
        #print( "batch_item_indices >>>", batch_item_indices )

        # Remove duplicated items
        # TODO: UNIMPLEMENTED. tf.unique works only with 1D dimensions...
        # batch_item_indices = tf.map_fn(lambda x: tf.unique(x), batch_item_indices.to_tensor(-1) )
        # print("unique", tf.unique(batch_item_indices))

        return batch_item_indices

    @tf.function
    def prepreprocess_customers(self, batch_customer_labels: tf.Tensor):

        # Get customer label
        batch_customer_indices = self.customer_labels_lookup.lookup(
            batch_customer_labels)

        # Get the "UNKNOWN" customer index
        unknown_customer_index = self.customer_labels_lookup.lookup(
            tf.constant(Labels.UNKNOWN_LABEL, dtype=tf.string))

        # Replace -1 by "UNKNOWN" index
        update_indices = tf.where(
            tf.math.equal(batch_customer_indices, Prediction.NOT_FOUND_INDEX))
        batch_customer_indices = tf.tensor_scatter_nd_update(
            batch_customer_indices, update_indices,
            tf.repeat(unknown_customer_index, tf.size(update_indices)))
        return batch_customer_indices

    @tf.function
    def _run_model_and_postprocess(self, batch_item_indices,
                                   batch_customer_indices, n_results):

        # Run the model
        batch = (batch_item_indices, batch_customer_indices)
        result = self.model(batch, training=False)
        # Convert logits to probabilities
        result = tf.nn.softmax(result)

        if settings.model_type == ModelType.GPT:

            # GPT returns probabilities for each sequence timestep.
            # We need the probabilities for the LAST input timestep.
            # The batch is a ragged tensor (ex. [[1], [2,3]]), so get the probabilities for the
            # last element position in each batch sequence:
            indices = batch_item_indices.row_lengths()
            indices -= 1
            #print("indices", indices)
            result = tf.gather(result, indices, batch_dims=1)
            #print("probs", probs)

        # Set result[ batch_item_indices ] = -1.0:
        #print("batch_item_indices >>>***", batch_item_indices)
        result = Prediction._remove_input_items_from_prediction(
            batch_item_indices, result)

        # Get most probable n results
        return self._top_predictions_tensor(result, n_results)

    @tf.function
    def _run_model_filter_empty_sequences(self,
                                          batch_item_indices: tf.RaggedTensor,
                                          batch_customer_indices, n_results):

        # Check if there are empty sequences
        sequences_lengths = batch_item_indices.row_lengths()
        non_empty_seq_count = tf.math.count_nonzero(sequences_lengths)
        n_sequences = tf.shape(sequences_lengths, tf.int64)[0]

        #print(">>>", non_empty_seq_count, n_results)
        if non_empty_seq_count == 0:
            # All sequences are empty
            label_predictions = tf.zeros([n_sequences, n_results],
                                         dtype=tf.string)
            probs_predictions = tf.zeros([n_sequences, n_results],
                                         dtype=tf.float32)
            return (label_predictions, probs_predictions)

        elif non_empty_seq_count >= n_sequences:
            # There are no empty sequences. Run the model with the full batch
            return self._run_model_and_postprocess(batch_item_indices,
                                                   batch_customer_indices,
                                                   n_results)
        else:
            # There are some empty sequences
            # The model will fail if a sequence is empty, and that seems to be the expected
            # behaviour: do not feed it empty sequences.
            # Get the non-empty sequences mask
            non_empty_mask = tf.math.greater(sequences_lengths, 0)

            # Get non empty sequences
            non_empty_sequences: tf.RaggedTensor = tf.ragged.boolean_mask(
                batch_item_indices, non_empty_mask)
            non_empty_customers = tf.boolean_mask(batch_customer_indices,
                                                  non_empty_mask)

            # Run model
            label_predictions, probs_predictions = self._run_model_and_postprocess(
                non_empty_sequences, non_empty_customers, n_results)

            # Merge real predictions with empty predictions for empty sequences:
            indices = tf.where(non_empty_mask)
            final_shape = [n_sequences, n_results]
            label_predictions = tf.scatter_nd(indices, label_predictions,
                                              final_shape)
            #print(label_predictions)
            probs_predictions = tf.scatter_nd(indices, probs_predictions,
                                              final_shape)
            #print(probs_predictions)
            return (label_predictions, probs_predictions)

    @tf.function(input_signature=[
        tf.RaggedTensorSpec(shape=[None, None], dtype=tf.string),
        tf.TensorSpec(shape=[None], dtype=tf.string),
        tf.TensorSpec(shape=[], dtype=tf.int64)
    ])
    def run_model_prediction(self, batch_item_labels, batch_customer_labels,
                             n_results):

        # Convert labels to indices
        #print(">>> batch_item_labels", batch_item_labels)
        batch_item_indices = self.preprocess_items(batch_item_labels)
        #print(">>> batch_item_indices", batch_item_indices)
        #print(">>> batch_customer_labels", batch_customer_labels)
        batch_customer_indices = self.prepreprocess_customers(
            batch_customer_labels)
        #print(">>> batch_customer_indices", batch_customer_indices)

        # Run the model
        return self._run_model_filter_empty_sequences(batch_item_indices,
                                                      batch_customer_indices,
                                                      n_results)

    @tf.function(input_signature=[
        tf.TensorSpec(shape=[None], dtype=tf.string),
        tf.TensorSpec(shape=[], dtype=tf.string),
        tf.TensorSpec(shape=[], dtype=tf.int64)
    ])
    def run_model_single(self, item_labels, customer_label, n_results):
        # Convert single example to batch
        batch_item_labels = tf.expand_dims(item_labels, axis=0)
        ragged_batch_item_labels = tf.RaggedTensor.from_tensor(
            batch_item_labels)
        batch_customer_labels = tf.expand_dims(customer_label, axis=0)

        predicted_item_labels, predicted_item_probs = self.run_model_prediction(
            ragged_batch_item_labels, batch_customer_labels, n_results)

        # Remove batch dimension
        predicted_item_labels = tf.squeeze(predicted_item_labels, 0)
        predicted_item_probs = tf.squeeze(predicted_item_probs, 0)
        return (predicted_item_labels, predicted_item_probs)

    def predict_batch(self, transactions: List[Transaction],
                      n_items_result: int) -> List:

        # Setup batch
        batch = Transaction.to_net_inputs_batch(transactions)

        #print("*** batch", batch)
        batch = (tf.ragged.constant(batch[0],
                                    dtype=tf.string), tf.constant(batch[1]))
        #print("*** batch", batch)

        results = self.run_model_prediction(batch[0], batch[1], n_items_result)
        #print("raw", results)

        results = (results[0].numpy(), results[1].numpy())
        return results

    def predict_single(self, transaction: Transaction,
                       n_items_result: int) -> List[Tuple[str, float]]:
        results = self.predict_batch([transaction], n_items_result)
        return (results[0][0], results[1][0])
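
# A hedged usage sketch: the model files, label files and `Transaction` type come
# from the surrounding project, so the constants below are assumptions.
if __name__ == '__main__':
    predictor = Prediction()
    labels, probs = predictor.run_model_single(
        tf.constant(['item_a', 'item_b']), tf.constant('customer_1'),
        tf.constant(3, tf.int64))
    print(labels, probs)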
Example #29
0
class Prediction(tf.Module):
    """ Run and process candidates generation model predictions """

    # Directory in "models/[MODEL]/" dir where to export the model
    CHECKPOINTS_DIR = 'checkpoints'

    # Directory in "models/[MODEL]/" dir where to export the model
    EXPORTED_MODEL_DIR = 'exported_model'

    def __init__(self, model: tf.keras.Model = None):

        if model:
            self.model: tf.keras.Model = model
        else:
            self.model: tf.keras.Model = tf.keras.models.load_model(
                settings.settings.get_model_path(
                    False, Prediction.EXPORTED_MODEL_DIR))
            #self.model.summary()

    @tf.function(input_signature=[
        tf.TensorSpec(shape=[None, None], dtype=tf.float32),
        tf.TensorSpec(shape=[], dtype=tf.int64)
    ])
    def _top_predictions_tensor(self, results,
                                n_results) -> Tuple[tf.Tensor, tf.Tensor]:
        """ Returns a tuple (items indices, items probabilities) with most probable items, up to "n_results" """
        # Get most probable item indices
        sorted_indices = tf.argsort(results, direction='DESCENDING')
        top_indices = sorted_indices[:, 0:n_results]
        top_probabilities = tf.gather(results, top_indices, batch_dims=1)
        return top_indices, top_probabilities

    @staticmethod
    @tf.function(input_signature=[
        tf.RaggedTensorSpec(shape=[None, None], dtype=tf.int64),
        tf.TensorSpec(shape=[None, None], dtype=tf.float32)
    ])
    def _remove_input_items_from_prediction(batch_item_indices, result):
        """ Remove input items from predictions, as we don't want to predict them. It is done setting their
            probabilities to -1

            Args:
                batch_item_indices: Batch input item indices 
                result: Batch predicted probabilities
            Returns: 
                Batch predicted probabilities with input items probs. set to -1
        """
        # batch_item_indices is a ragged tensor with input indices for each batch row. Ex [ [0] , [1, 2] ]
        batch_item_indices = batch_item_indices.to_tensor(
            -1)  # -> [ [0,-1] , [1, 2] ]
        #print(batch_item_indices, batch_item_indices.shape[0])

        # Convert batch_item_indices row indices to (row,column) indices
        row_indices = tf.range(0,
                               tf.shape(batch_item_indices)[0],
                               dtype=tf.int64)  # -> [ 0, 1 ]
        row_indices = tf.repeat(
            row_indices,
            [tf.shape(batch_item_indices)[1]])  # -> [ 0, 0, 1, 1 ]
        #print(">>>", batch_item_indices)
        batch_item_indices = tf.reshape(batch_item_indices,
                                        shape=[-1])  # -> [ 0, -1, 1, 2 ]
        batch_item_indices = tf.stack(
            [row_indices, batch_item_indices],
            axis=1)  # -> [ [0,0] , [0,-1], [1,1], [1,2] ]

        # batch_item_indices.to_tensor(-1) added -1's to pad the matrix. Remove these indices
        # Needed according to tf.tensor_scatter_nd_update doc. (it will fail in CPU execution, if there are out of bound indices)
        # Get indices without -1's:
        gather_idxs = tf.where(
            batch_item_indices[:, 1] != -1)  # -> [[0], [2], [3]]
        batch_item_indices = tf.gather_nd(
            batch_item_indices, gather_idxs)  # -> [ [0,0] , [1,1], [1,2] ]

        # To remove input indices, we will set a probability -1 in their indices
        updates = tf.repeat(
            -1.0,
            tf.shape(batch_item_indices)[0])  # -> [ -1, -1, -1 ]

        # Assign -1's to the input indices:
        return tf.tensor_scatter_nd_update(result, batch_item_indices, updates)

    @tf.function
    def _run_model_and_postprocess(self, inputs_batch,
                                   item_indices_feature_name, n_results):

        # Run the model
        result = self.model(inputs_batch, training=False)
        # Convert logits to probabilities
        result = tf.nn.softmax(result)

        # Label indices feature values
        batch_item_indices = inputs_batch[item_indices_feature_name]

        if settings.settings.model_type == settings.ModelType.GPT:

            # GPT returns probabilities for each sequence timestep.
            # We need the probabilities for the LAST input timestep.
            # The batch is a ragged tensor (ex. [[1], [2,3]]), so get the probabilities for the
            # last element position in each batch sequence:
            indices = batch_item_indices.row_lengths()
            indices -= 1
            #print("indices", indices)
            result = tf.gather(result, indices, batch_dims=1)
            #print("probs", probs)

        # Set result[ batch_item_indices ] = -1.0:
        result = Prediction._remove_input_items_from_prediction(
            batch_item_indices, result)

        # Get most probable n results
        return self._top_predictions_tensor(result, n_results)

    def _preprocess_transactions(
        self, transactions: List[Transaction]
    ) -> Tuple[List[Transaction], List[int]]:
        """ Remove unknown items and empty transactions.

            Returns: Non empty transactions, with labels replaced by indices, and indices of empty transactions in
                     the original transactions list
        """
        result, empty_sequences_idxs = [], []
        for idx, trn in enumerate(transactions):
            trn = trn.replace_labels_by_indices()
            # Now, trn item sequence contains -1 for unknown items. Remove sequence feature for these items
            trn = trn.remove_unknown_item_indices()
            if trn.sequence_length() == 0:
                # Empty sequence. If we feed it to the model, it can fail (rnn layers)
                empty_sequences_idxs.append(idx)
            else:
                result.append(trn)

        return result, empty_sequences_idxs

    def _transactions_to_model_inputs(
            self, transactions: List[Transaction]) -> Dict[str, tf.Tensor]:

        # Prepare dictionary with features names
        inputs_dict = {
            feature.name: []
            for feature in settings.settings.features
        }

        # Concatenate transaction values for each feature
        for trn in transactions:
            # Use labels indices instead raw values
            for feature in settings.settings.features:
                inputs_dict[feature.name].append(trn[feature.name])

        # To tensor values
        for feature in settings.settings.features:
            if feature.sequence:
                inputs_dict[feature.name] = tf.ragged.constant(
                    inputs_dict[feature.name], dtype=tf.int64)
            else:
                inputs_dict[feature.name] = tf.constant(
                    inputs_dict[feature.name], dtype=tf.int64)

        # Keras inputs are mapped by position, so return result as a list
        #return [ inputs_dict[feature.name] for feature in settings.settings.features ]
        return inputs_dict

    def predict_raw_batch(
            self, transactions: List[Transaction], n_items_result: int
    ) -> Tuple[np.ndarray, np.ndarray, List[tf.Tensor]]:
        # Transactions to keras inputs
        batch = self._transactions_to_model_inputs(transactions)

        # Run prediction
        if len(batch) > 0:
            top_item_indices, top_probabilities = self._run_model_and_postprocess(
                batch, settings.settings.features.item_label_feature,
                n_items_result)
            top_item_indices = top_item_indices.numpy()
            top_probabilities = top_probabilities.numpy()
        else:
            top_item_indices = np.array([], dtype=int)
            top_probabilities = np.array([], dtype=float)

        return top_item_indices, top_probabilities, batch

    def predict_batch(self,
                      transactions: List[Transaction],
                      n_items_result: int,
                      preprocess=True) -> Tuple[np.ndarray, np.ndarray]:

        if preprocess:
            # Convert labels to indices. Remove unknown items and empty sequences
            transactions, empty_sequences_idxs = self._preprocess_transactions(
                transactions)
        else:
            empty_sequences_idxs = []

        top_item_indices, top_probabilities, _ = self.predict_raw_batch(
            transactions, n_items_result)

        # Convert item indices to labels
        top_item_labels = settings.settings.features.items_sequence_feature(
        ).labels.indices_to_labels(top_item_indices)

        # Insert fake results for empty sequences
        if len(empty_sequences_idxs) > 0:
            empty_labels = np.zeros([n_items_result], dtype=str)
            empty_probs = np.zeros([n_items_result], dtype=float)
            # Inserts cannot all be done at once: np.insert expects indices into the array
            # as it was before the insert, and we don't have those
            for idx in empty_sequences_idxs:
                top_item_labels = np.insert(top_item_labels,
                                            idx,
                                            empty_labels,
                                            axis=0)
                top_probabilities = np.insert(top_probabilities,
                                              idx,
                                              empty_probs,
                                              axis=0)

        return top_item_labels, top_probabilities

    def predict_single(self,
                       transaction: Transaction,
                       n_items_result: int,
                       preprocess=True):
        top_item_labels, top_probabilities = self.predict_batch([transaction],
                                                                n_items_result,
                                                                preprocess)
        return top_item_labels[0], top_probabilities[0]
Example #30
0
class EAMPotential(tf.keras.Model):
    """Base class for all EAM based potentials"""
    def __init__(self,
                 atom_types,
                 build_forces=False,
                 preprocessed_input=False,
                 cutoff=10.,
                 method='partition_stitch',
                 force_method='old'):
        """
        atom_types: list of str used to construct all the model layers. Note
                                the order is important as it determines the
                                corresponding integer representation that is
                                fed to the model.
        build_forces: boolean determines if the graph for evaluating forces
                              should be constructed.
        preprocessed_input: boolean switches between input of atomic positions
                                    and interatomic distances. Preprocessed
                                    input is currently faster, since the
                                    calculation of interatomic distances
                                    inside the model cannot be parallelized.
        cutoff: float if preprocessed_input=False this determines the global
                      cutoff until which interatomic distances are calculated.
        method: str switch for different implementations of the main body
                    options are:
                    - 'partition_stitch' default and probably safest choice
                    - 'where' uses branchless programming; can be significantly
                              faster as it is easier to parallelize
                    - 'gather_scatter' old implementation, no longer maintained
        force_method: str switch for debugging, options are:
                    - 'old' default
                    - 'new'
        """
        super().__init__()

        self.atom_types = atom_types
        self.type_dict = {t: i for i, t in enumerate(atom_types)}
        self.atom_pair_types = []
        self.pair_type_dict = {}
        for i, (t1, t2) in enumerate(
                combinations_with_replacement(self.atom_types, 2)):
            t = ''.join([t1, t2])
            self.atom_pair_types.append(t)
            self.pair_type_dict[t] = i
        self.build_forces = build_forces
        self.preprocessed_input = preprocessed_input
        self.method = method
        self.force_method = force_method

        inputs = {
            'types': tf.keras.Input(shape=(None, 1),
                                    ragged=True,
                                    dtype=tf.int32)
        }
        if self.preprocessed_input:
            inputs['pair_types'] = tf.keras.Input(shape=(None, None, 1),
                                                  ragged=True,
                                                  dtype=inputs['types'].dtype)
            inputs['distances'] = tf.keras.Input(shape=(None, None, 1),
                                                 ragged=True)
            if self.build_forces:
                # [batchsize] x N x (N-1) x N x 3
                inputs['dr_dx'] = tf.keras.Input(shape=(None, None, None, 3),
                                                 ragged=True)
        else:
            self.cutoff = cutoff
            inputs['positions'] = tf.keras.Input(shape=(None, 3), ragged=True)

        self._set_inputs(inputs)
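        # (_set_inputs is a non-public Keras Model API: it registers the
        # ragged input structure so this subclassed model can be called and
        # saved without an explicit build step.)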

        (self.pair_potentials, self.pair_rho, self.embedding_functions,
         self.offsets) = self.build_functions()

    def build_functions(self):
        """Builds and returns the pair potentials, pair densities,
        embedding functions and energy offsets; implemented by subclasses."""
        pass

    def call(self, inputs):
        if self.preprocessed_input:
            return self.shallow_call(inputs)
        return self.deep_call(inputs)

    def deep_call(self, inputs):
        types = inputs['types']
        positions = inputs['positions']
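        # tf.map_fn runs the distance calculation once per structure; the
        # explicit RaggedTensorSpec output signature is needed because the
        # ragged output structure cannot be inferred from the inputs.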
        distances, pair_types, dr_dx = tf.map_fn(
            lambda x: distances_and_pair_types(x[0], x[1], len(
                self.atom_types), self.cutoff), (positions, types),
            fn_output_signature=(tf.RaggedTensorSpec(shape=[None, None, 1],
                                                     ragged_rank=1,
                                                     dtype=positions.dtype),
                                 tf.RaggedTensorSpec(shape=[None, None, 1],
                                                     ragged_rank=1,
                                                     dtype=types.dtype),
                                 tf.RaggedTensorSpec(
                                     shape=[None, None, None, 3],
                                     ragged_rank=2,
                                     dtype=positions.dtype)))

        if self.build_forces:
            return self.main_body_with_forces(types, distances, pair_types,
                                              dr_dx)
        return self.main_body_no_forces(types, distances, pair_types)

    @tf.function
    def shallow_call(self, inputs):
        types = inputs['types']
        distances = inputs['distances']
        pair_types = inputs['pair_types']

        if self.build_forces:
            dr_dx = inputs['dr_dx']
            return self.main_body_with_forces(types, distances, pair_types,
                                              dr_dx)
        return self.main_body_no_forces(types, distances, pair_types)

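    # The fixed input_signature below keeps tf.function from retracing for
    # every new combination of ragged batch shapes; tf.keras.backend.floatx()
    # ties the signature to the configured Keras float precision.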
    @tf.function(input_signature=(tf.RaggedTensorSpec(
        tf.TensorShape([None, None, 1]), tf.int32, 1, tf.int64),
                                  tf.RaggedTensorSpec(
                                      tf.TensorShape([None, None, None, 1]),
                                      tf.keras.backend.floatx(), 2, tf.int64),
                                  tf.RaggedTensorSpec(
                                      tf.TensorShape([None, None, None, 1]),
                                      tf.int32, 2, tf.int64)))
    def main_body_no_forces(self, types, distances, pair_types):
        """Calculates the energy per atom by calling the main body
        """
        if self.method == 'partition_stitch':
            energy = self.body_partition_stitch(types, distances, pair_types)
        elif self.method == 'gather_scatter':
            energy = self.body_gather_scatter(types, distances, pair_types)
        elif self.method == 'where':
            energy = self.body_where(types, distances, pair_types)
        else:
            raise NotImplementedError('Unknown method %s' % self.method)
        number_of_atoms = tf.cast(types.row_lengths(),
                                  energy.dtype,
                                  name='number_of_atoms')
        energy_per_atom = tf.divide(energy,
                                    tf.expand_dims(number_of_atoms, axis=-1),
                                    name='energy_per_atom')

        return {'energy_per_atom': energy_per_atom}

    @tf.function(
        input_signature=(tf.RaggedTensorSpec(tf.TensorShape([None, None, 1]),
                                             tf.int32, 1, tf.int64),
                         tf.RaggedTensorSpec(
                             tf.TensorShape([None, None, None, 1]),
                             tf.keras.backend.floatx(), 2, tf.int64),
                         tf.RaggedTensorSpec(
                             tf.TensorShape([None, None, None, 1]), tf.int32,
                             2, tf.int64),
                         tf.RaggedTensorSpec(
                             tf.TensorShape([None, None, None, None, 3]),
                             tf.keras.backend.floatx(), 3, tf.int64)))
    def main_body_with_forces(self, types, distances, pair_types, dr_dx):
        """Calculates the energy per atom and the derivative of the total
           energy with respect to the distances
        """
        with tf.GradientTape() as tape:
            tape.watch(distances.flat_values)
            if self.method == 'partition_stitch':
                energy = self.body_partition_stitch(types, distances,
                                                    pair_types)
            elif self.method == 'gather_scatter':
                energy = self.body_gather_scatter(types, distances, pair_types)
            elif self.method == 'where':
                energy = self.body_where(types, distances, pair_types)
            else:
                raise NotImplementedError('Unknown method %s' % self.method)
        number_of_atoms = tf.cast(types.row_lengths(),
                                  energy.dtype,
                                  name='number_of_atoms')
        energy_per_atom = tf.divide(energy,
                                    tf.expand_dims(number_of_atoms, axis=-1),
                                    name='energy_per_atom')

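        # Since the energy depends on the positions only through the
        # interatomic distances, the force on atom k follows from the chain
        # rule:
        #   F_k = -dE/dx_k = -sum_{i,j} (dE/dr_ij) * (dr_ij/dx_k)
        # Both branches below evaluate this contraction; they differ only in
        # whether the i and j axes are merged first.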
        if self.force_method == 'old':
            # Probably should not be reshaped to a RaggedTensor only to sum
            # over these dimensions in the next step.
            dE_dr = tf.RaggedTensor.from_nested_row_splits(
                tape.gradient(energy, distances.flat_values),
                distances.nested_row_splits,
                name='dE_dr')
            # dr_dx.shape = (batch_size, None, None, None, 3)
            # dE_dr.shape = (batch_size, None, None, 1)
            # Sum over atom indices i and j. Force is the negative gradient.
            forces = -tf.reduce_sum(dr_dx * tf.expand_dims(dE_dr, -1),
                                    axis=(-3, -4),
                                    name='dE_dr_times_dr_dx')
        elif self.force_method == 'new':
            dr_dx = dr_dx.merge_dims(1, 2)
            dE_dr = tf.RaggedTensor.from_row_splits(tape.gradient(
                energy, distances.flat_values),
                                                    dr_dx.row_splits,
                                                    name='dE_dr')
            # dr_dx.shape = (batch_size, None, None, 3)
            # dE_dr.shape = (batch_size, None, 1)
            forces = -tf.reduce_sum(dr_dx * tf.expand_dims(dE_dr, -1),
                                    axis=-3,
                                    name='dE_dr_times_dr_dx')
        else:
            raise NotImplementedError('Unknown force method %s' %
                                      self.force_method)
        return {'energy_per_atom': energy_per_atom, 'forces': forces}

    @tf.function
    def body_where(self, types, distances, pair_types):
        """"""
        rho = distances
        phi = distances
        for t in self.atom_pair_types:
            cond = tf.equal(pair_types, self.pair_type_dict[t])
            rho = ragged_where(
                cond, tf.ragged.map_flat_values(self.pair_rho[t], rho), rho)
            phi = ragged_where(
                cond, tf.ragged.map_flat_values(self.pair_potentials[t], phi),
                phi)

        # Sum over atoms j
        sum_rho = tf.reduce_sum(rho**2, axis=-2, name='sum_rho')
        sum_phi = tf.reduce_sum(phi, axis=-2, name='sum_phi')

        # Make sure that sum_rho is never exactly zero since this leads to
        # problems in the gradient of the square root embedding function
        embedding_energies = tf.math.maximum(sum_rho, 1e-30)
        for t in self.atom_types:
            cond = tf.equal(types, self.type_dict[t])
            # tf.abs(x) necessary here since most embedding functions do not
            # support negative inputs but embedding energies are typically
            # negative and tf.where applies the function to every vector entry
            embedding_energies = ragged_where(
                cond,
                tf.ragged.map_flat_values(
                    lambda x: (self.embedding_functions[t]
                               (tf.abs(x)) + self.offsets[t](x)),
                    embedding_energies), embedding_energies)
        atomic_energies = sum_phi + embedding_energies

        # Sum over atoms i
        return tf.reduce_sum(atomic_energies, axis=-2, name='energy')

    @tf.function
    def body_partition_stitch(self, types, distances, pair_types):
        """main body using dynamic_partition and dynamic_stitch methods"""
        pair_type_indices = tf.dynamic_partition(tf.expand_dims(
            tf.range(tf.size(distances)), -1),
                                                 pair_types.flat_values,
                                                 len(self.atom_pair_types),
                                                 name='pair_type_indices')
        # Partition distances according to the pair_type
        partitioned_r = tf.dynamic_partition(distances.flat_values,
                                             pair_types.flat_values,
                                             len(self.atom_pair_types),
                                             name='partitioned_r')
        rho = [
            self.pair_rho[t](tf.expand_dims(part, -1))
            for t, part in zip(self.atom_pair_types, partitioned_r)
        ]
        phi = [
            self.pair_potentials[t](tf.expand_dims(part, -1))
            for t, part in zip(self.atom_pair_types, partitioned_r)
        ]

        rho = tf.dynamic_stitch(pair_type_indices, rho)
        phi = tf.dynamic_stitch(pair_type_indices, phi)

        # Reshape to ragged tensors
        phi = tf.RaggedTensor.from_nested_row_splits(
            phi, distances.nested_row_splits)
        rho = tf.RaggedTensor.from_nested_row_splits(
            rho, distances.nested_row_splits)

        # Sum over atoms j
        sum_rho = tf.reduce_sum(rho**2, axis=-2, name='sum_rho')
        sum_phi = tf.reduce_sum(phi, axis=-2, name='sum_phi')

        # Make sure that sum_rho is never exactly zero since this leads to
        # problems in the gradient of the square root embedding function
        sum_rho = tf.math.maximum(sum_rho, 1e-30)

        # Embedding energy
        partitioned_sum_rho = tf.dynamic_partition(sum_rho,
                                                   types,
                                                   len(self.atom_types),
                                                   name='partitioned_sum_rho')
        type_indices = tf.dynamic_partition(tf.expand_dims(
            tf.range(tf.size(sum_rho)), -1),
                                            types.flat_values,
                                            len(self.atom_types),
                                            name='type_indices')
        # energy offset is added here
        embedding_energies = [
            self.embedding_functions[t](tf.expand_dims(rho_t, -1)) +
            self.offsets[t](tf.expand_dims(rho_t, -1))
            for t, rho_t in zip(self.atom_types, partitioned_sum_rho)
        ]
        embedding_energies = tf.dynamic_stitch(type_indices,
                                               embedding_energies,
                                               name='embedding_energies')

        atomic_energies = sum_phi.flat_values + embedding_energies
        # Reshape to ragged
        atomic_energies = tf.RaggedTensor.from_row_splits(
            atomic_energies, types.row_splits, name='atomic_energies')
        # Sum over atoms i
        return tf.reduce_sum(atomic_energies, axis=-2, name='energy')

    @tf.function
    def body_gather_scatter(self, types, distances, pair_types):
        """main body using gather and scatter methods"""
        phi = tf.zeros_like(distances.flat_values)
        rho = tf.zeros_like(distances.flat_values)

        for ij, type_ij in enumerate(self.atom_pair_types):
            # Flat values necessary until where properly supports
            # ragged tensors
            indices = tf.where(tf.equal(pair_types, ij).flat_values)[:, 0]
            masked_distances = tf.gather(distances.flat_values, indices)
            phi = tf.tensor_scatter_nd_update(
                phi, tf.expand_dims(indices, -1),
                self.pair_potentials[type_ij](masked_distances))
            rho = tf.tensor_scatter_nd_update(
                rho, tf.expand_dims(indices, -1),
                self.pair_rho[type_ij](masked_distances))
        # Reshape back to ragged tensors
        phi = tf.RaggedTensor.from_nested_row_splits(
            phi, distances.nested_row_splits)
        rho = tf.RaggedTensor.from_nested_row_splits(
            rho, distances.nested_row_splits)
        # Sum over atoms j and flatten again
        atomic_energies = tf.reduce_sum(phi, axis=-2).flat_values
        sum_rho = tf.reduce_sum(rho**2, axis=-2).flat_values
        # Make sure that sum_rho is never exactly zero since this leads to
        # problems in the gradient of the square root embedding function
        sum_rho = tf.math.maximum(sum_rho, 1e-30)
        for i, t in enumerate(self.atom_types):
            indices = tf.where(tf.equal(types, i).flat_values)[:, 0]
            atomic_energies = tf.tensor_scatter_nd_add(
                atomic_energies, tf.expand_dims(indices, -1),
                self.embedding_functions[t](tf.gather(sum_rho, indices)))
        # Reshape to ragged
        atomic_energies = tf.RaggedTensor.from_row_splits(
            atomic_energies, types.row_splits)
        # Sum over atoms i
        return tf.reduce_sum(atomic_energies, axis=-2, name='energy')

    def tabulate(self,
                 filename,
                 atomic_numbers,
                 atomic_masses,
                 lattice_constants=dict(),
                 lattice_types=dict(),
                 cutoff_rho=120.0,
                 nrho=10000,
                 cutoff=6.2,
                 nr=10000):
        """
        Uses the atsim module to tabulate the potential for use in LAMMPS or
        similar programs.

        filename: str path of the output file. Note that the extension
                     .eam.fs is added
        atomic_numbers: dict containing the atomic numbers of all
                            self.atom_types
        atomic_masses: dict containing the atomic mass of all self.atom_types
        lattice_constants: dict of per-type lattice constants written to the
                           tabulation header (default 0.0)
        lattice_types: dict of per-type lattice types written to the
                       tabulation header (default 'fcc')
        cutoff_rho: float upper bound of the tabulated electron density grid
        nrho: int number of grid points for the density grid
        cutoff: float upper bound of the tabulated distance grid
        nr: int number of grid points for the distance grid

        """
        from atsim.potentials import EAMPotential, Potential
        from atsim.potentials.eam_tabulation import SetFL_FS_EAMTabulation

        warnings.warn('Long range behavior of the tabulated potentials differs'
                      ' from the tensorflow implementation!')

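        # The model sums phi over ordered pairs (i, j), counting each
        # physical pair twice, while atsim expects the full pair interaction
        # once, hence the factor 2 in pair_wrapper. Likewise the model
        # squares pair_rho before summing (sum_rho = sum(rho**2)), so the
        # tabulated density is the square of the model output.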
        def pair_wrapper(fun):
            def wrapped_fun(x):
                return 2 * fun(tf.reshape(x, (1, 1)))

            return wrapped_fun

        def rho_wrapper(fun):
            def wrapped_fun(x):
                return tf.math.square(fun(tf.reshape(x, (1, 1))))

            return wrapped_fun

        def F_wrapper(fun):
            def wrapped_fun(x):
                return fun(tf.reshape(x, (1, 1)))

            return wrapped_fun

        pair_potentials = []
        pair_densities = {t: {} for t in self.atom_types}
        for (t1, t2) in combinations_with_replacement(self.atom_types, 2):
            pair_type = ''.join([t1, t2])

            pair_potentials.append(
                Potential(t1, t2,
                          pair_wrapper(self.pair_potentials[pair_type])))
            pair_densities[t1][t2] = rho_wrapper(self.pair_rho[pair_type])
            pair_densities[t2][t1] = rho_wrapper(self.pair_rho[pair_type])

        eam_potentials = []
        for t in self.atom_types:
            eam_potentials.append(
                EAMPotential(t,
                             atomic_numbers[t],
                             atomic_masses[t],
                             F_wrapper(self.embedding_functions[t]),
                             pair_densities[t],
                             latticeConstant=lattice_constants.get(t, 0.0),
                             latticeType=lattice_types.get(t, 'fcc')))

        tabulation = SetFL_FS_EAMTabulation(pair_potentials, eam_potentials,
                                            cutoff, nr, cutoff_rho, nrho)

        with open(''.join([filename, '.eam.fs']), 'w') as outfile:
            tabulation.write(outfile)
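
# `ragged_where` is used by body_where above but is not defined in this
# excerpt. A minimal sketch of what such a helper could look like (an
# assumption, not the original implementation): apply the elementwise
# tf.where to the flat values and rebuild the ragged structure from the
# condition's nested row splits, which all three arguments share.
import tensorflow as tf

def ragged_where(condition, x, y):
    """Elementwise where for ragged tensors with identical row splits."""
    flat = tf.where(condition.flat_values, x.flat_values, y.flat_values)
    return tf.RaggedTensor.from_nested_row_splits(
        flat, condition.nested_row_splits)

# Example: keep positive entries, zero out the rest
rt = tf.ragged.constant([[1.0, -2.0], [3.0]])
mask = tf.ragged.map_flat_values(lambda v: v > 0, rt)
print(ragged_where(mask, rt, tf.zeros_like(rt)))  # [[1.0, 0.0], [3.0]]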