def test_federated_collect_with_client_int(self):
  """Collecting a clients-placed int32 yields a value typed 'int32*@SERVER'."""
  clients_value = _mock_data_of_type(computation_types.at_clients(tf.int32))
  collected = intrinsics.federated_collect(clients_value)
  # The expected type string is compared by the test helper `assert_value`.
  self.assert_value(collected, 'int32*@SERVER')
def test_federated_collect_with_server_int_fails(self):
  """federated_collect rejects a server-placed int32 with a TypeError."""
  server_value = _mock_data_of_type(computation_types.at_server(tf.int32))
  # Callable form of assertRaises: invokes federated_collect(server_value)
  # and expects it to raise TypeError.
  self.assertRaises(TypeError, intrinsics.federated_collect, server_value)
def after_merge_with_collect(original_arg, merged_arg):
  """Returns federated_collect applied to the second element of original_arg.

  Args:
    original_arg: An indexable argument; element 1 is passed to
      `intrinsics.federated_collect`.
    merged_arg: Ignored.

  Returns:
    The result of `intrinsics.federated_collect(original_arg[1])`.
  """
  del merged_arg  # Unused.
  # Second element in original arg is the clients-placed value.
  clients_placed_value = original_arg[1]
  return intrinsics.federated_collect(clients_placed_value)