def test_exception_raises(self):
    """An error raised by a patched-in migrate_params() must escape
    migrate_params_thrift() with its type and message intact."""

    def exploding_migrate_params(params):
        raise RuntimeError("huh")

    patcher = patch.object(module, "migrate_params", exploding_migrate_params)
    with patcher, self.assertRaisesRegex(RuntimeError, "huh"):
        module.migrate_params_thrift({})
def test_default_returns_params(self):
    """With no user override, migrate_params_thrift() echoes its input."""

    def make_thrift_params():
        # Build a fresh Thrift value on every call so the input and the
        # expected value are distinct objects.
        return arrow_raw_params_to_thrift(RawParams({"A": [1], "B": "x"}))

    actual = module.migrate_params_thrift(make_thrift_params())
    self.assertEqual(actual, make_thrift_params())
def run_in_sandbox(compiled_module: CompiledModule, function: str, args: List[Any]) -> None:
    """Run `function` with `args`, and write the (Thrift) result to stdout.

    The user's compiled code is executed in a fresh module. Any of its
    functions whose names match the default module's overridable entry points
    replace those entry points on `cjwkernel.pandas.module`. The requested
    Thrift entry point is then called. Its result, if not None, is serialized
    with the Thrift binary protocol to `sys.__stdout__.buffer`. This is the
    interpreter's original stdout, not whatever `sys.stdout` is currently bound
    to (presumably so user code that rebinds `sys.stdout` cannot intercept the
    result -- confirm).

    Args:
        compiled_module: user code. Its `code_object` is exec'd, its
            `module_slug` names the new module, and its `module_spec_dict` is
            loaded into the `ModuleSpec` global.
        function: one of "render_thrift", "migrate_params_thrift",
            "validate_thrift" or "fetch_thrift".
        args: positional arguments passed to that entry point.

    Raises:
        NotImplementedError: if `function` is not a supported entry point. This
            is checked only after user code has been executed.
        Exception: anything the (untrusted) user code raises propagates. A
            parent process is expected to catch every outcome.
    """
    # TODO sandbox -- will need an OS `clone()` with namespace, cgroups, ....

    # Run the user's code in a new (programmatic) module.
    #
    # This gives the user code a blank namespace -- exactly what we want.
    module_name = f"rawmodule.{compiled_module.module_slug}"
    user_code_module = types.ModuleType(module_name)
    sys.modules[module_name] = user_code_module  # simulate "import"
    exec(compiled_module.code_object, user_code_module.__dict__)

    # And now ... now we're unsafe! Because `code_object` may be malicious, any
    # line of code from here on out gives undefined behavior. Luckily, a parent
    # is catching all possible outcomes....

    # Now override the pieces of the _default_ module with the user-supplied
    # ones. That way, when the default `render_pandas()` calls `render()`, that
    # `render()` is the user-code `render()` (if supplied).
    #
    # Good thing we've forked! This totally messes with global variables.
    module = cjwkernel.pandas.module
    user_namespace = user_code_module.__dict__
    for fn in (
        "fetch",
        "fetch_arrow",
        "fetch_pandas",
        "fetch_thrift",
        "migrate_params",
        "migrate_params_thrift",
        "render",
        "render_arrow",
        "render_arrow_v1",
        "render_pandas",
        "render_thrift",
    ):
        if fn in user_namespace:
            module.__dict__[fn] = user_namespace[fn]

    # Set ModuleSpec global parameter -- module frameworks use it for params
    module.__dict__["ModuleSpec"] = load_spec(compiled_module.module_spec_dict)

    # Only these entry points may be invoked. Name the rejected function so the
    # parent process gets a useful diagnostic instead of a bare exception.
    if function not in (
        "render_thrift",
        "migrate_params_thrift",
        "validate_thrift",
        "fetch_thrift",
    ):
        raise NotImplementedError(f"Unsupported sandbox function: {function!r}")
    # Look the function up now, after the overrides above, so a user-supplied
    # implementation wins.
    result = getattr(module, function)(*args)

    transport = thrift.transport.TTransport.TFileObjectTransport(sys.__stdout__.buffer)
    protocol = thrift.protocol.TBinaryProtocol.TBinaryProtocol(transport)
    if result is not None:
        result.write(protocol)
    transport.flush()
def test_default_returns_params(self):
    """The default migrate_params_thrift() round-trips params unchanged."""
    thrift_in = RawParams({"A": [1], "B": "x"}).to_thrift()
    round_tripped = RawParams.from_thrift(module.migrate_params_thrift(thrift_in))
    self.assertEqual(round_tripped.params, {"A": [1], "B": "x"})
def test_allow_override(self):
    """A patched-in migrate_params() receives decoded params, and its return
    value is re-encoded into the MigrateParamsResult."""

    def fake_migrate_params(params):
        self.assertEqual(params, {"x": "y"})
        return {"y": "z"}

    with patch.object(module, "migrate_params", fake_migrate_params):
        actual = module.migrate_params_thrift({"x": ttypes.Json(string_value="y")})
        expected = ttypes.MigrateParamsResult({"y": ttypes.Json(string_value="z")})
        self.assertEqual(actual, expected)
def _test(self, fn, params=None):
    """Run migrate_params_thrift() with `fn` patched in as migrate_params.

    Args:
        fn: replacement for ``module.migrate_params`` during the call.
        params: raw params dict to migrate. Defaults to a new empty dict per
            call; the old ``params={}`` default shared one dict across calls.

    Returns:
        The migrated params, converted back from their Thrift form.
    """
    if params is None:
        params = {}
    with patch.object(module, "migrate_params", fn):
        thrift_result = module.migrate_params_thrift(
            arrow_raw_params_to_thrift(RawParams(params))
        )
        return thrift_raw_params_to_arrow(thrift_result).params
def _test(self, fn, params=None):
    """Run migrate_params_thrift() with `fn` patched in as migrate_params.

    Args:
        fn: replacement for ``module.migrate_params`` during the call.
        params: raw params dict to migrate. Defaults to a new empty dict per
            call; the old ``params={}`` default shared one dict across calls.

    Returns:
        The migrated params, decoded from the Thrift result.
    """
    if params is None:
        params = {}
    with patch.object(module, "migrate_params", fn):
        thrift_result = module.migrate_params_thrift(RawParams(params).to_thrift())
        return RawParams.from_thrift(thrift_result).params
def test_default_returns_params(self):
    """With no override, migrate_params_thrift() wraps its input unchanged."""

    def json_params():
        # A fresh dict of fresh Json values each call, so the input and the
        # expected value never share objects.
        return {"x": ttypes.Json(string_value="y")}

    actual = module.migrate_params_thrift(json_params())
    self.assertEqual(actual, ttypes.MigrateParamsResult(json_params()))