コード例 #1
0
 def test_splits_on_intrinsic_noarg_function(self):
     """Splits a no-argument lambda whose body holds one called broadcast.

     The lambda has no parameter (name and type are both `None`). Its body is a
     one-element struct wrapping a whimsy called federated broadcast. The
     lambda is checked with `assert_splits_on` against the null federated
     broadcast.
     """
     broadcast_call = (
         building_block_test_utils.create_whimsy_called_federated_broadcast())
     no_arg_fn = building_blocks.Lambda(
         None, None, building_blocks.Struct([broadcast_call]))
     self.assert_splits_on(
         no_arg_fn, building_block_factory.create_null_federated_broadcast())
コード例 #2
0
 def test_splits_on_selected_intrinsic_broadcast(self):
     """Splits a one-parameter lambda whose body holds one called broadcast.

     The lambda takes a parameter `a` of type `tf.int32`. Its body is a
     one-element struct wrapping a whimsy called federated broadcast. The
     lambda is checked with `assert_splits_on` against the null federated
     broadcast.
     """
     broadcast_call = (
         building_block_test_utils.create_whimsy_called_federated_broadcast())
     int_param_fn = building_blocks.Lambda(
         'a', tf.int32, building_blocks.Struct([broadcast_call]))
     self.assert_splits_on(
         int_param_fn,
         building_block_factory.create_null_federated_broadcast())
コード例 #3
0
 def test_splits_on_selected_intrinsic_nested_in_tuple_broadcast(self):
     """Splits a broadcast of a selection taken from a struct holding a broadcast.

     A server-placed `Data` value and a whimsy called broadcast are packed into
     a struct. Element 0 (the `Data`) is selected and broadcast again, and the
     result is normalized with `to_call_dominant`. It is then wrapped in a
     lambda over an int32 parameter `a` and checked with `assert_splits_on`
     against the null federated broadcast.
     """
     inner_broadcast = (
         building_block_test_utils.create_whimsy_called_federated_broadcast())
     server_value = building_blocks.Data(
         'a', computation_types.at_server(tf.int32))
     packed = building_blocks.Struct([server_value, inner_broadcast])
     # Index 0 picks the server-placed `Data`, not the nested broadcast.
     outer_broadcast = building_block_factory.create_federated_broadcast(
         building_blocks.Selection(packed, index=0))
     wrapped = building_blocks.Lambda(
         'a', tf.int32, transformations.to_call_dominant(outer_broadcast))
     self.assert_splits_on(
         wrapped, building_block_factory.create_null_federated_broadcast())
コード例 #4
0
def _split_ast_on_broadcast(bb):
    """Splits an AST on the `broadcast` intrinsic.

    Args:
      bb: An AST of arbitrary shape, potentially containing a broadcast.

    Returns:
      A `(before, after)` pair of ASTs, as returned by
      `_untuple_broadcast_only_before_after`. `before` maps `bb`'s input to
      the argument of the broadcast. `after` maps `bb`'s input and the
      broadcast's output to `bb`'s output.
    """
    # The null federated broadcast is passed as the only intrinsic to align
    # and split on.
    before, after = transformations.force_align_and_split_by_intrinsics(
        bb, [building_block_factory.create_null_federated_broadcast()])
    return _untuple_broadcast_only_before_after(before, after)