def test_register_derived_space():
    """A derived space returns its callback applied to the base observation."""
    ir_text = "Hello, world!"

    def ir_length(base):
        # Derived observation: the length of the base value, wrapped in a list.
        return [len(base)]

    ir_space = ObservationSpace(
        name="ir",
        string_size_range=ScalarRange(min=ScalarLimit(value=0)),
    )
    backend = MockGetObservation(ret=[Observation(string_value=ir_text)])
    view = ObservationView(backend, [ir_space])

    view.register_derived_space(
        base_name="ir",
        derived_name="ir_len",
        derived_space=Box(low=0, high=float("inf"), shape=(1,), dtype=int),
        cb=ir_length,
    )

    derived = view["ir_len"]
    assert isinstance(derived, list)
    assert derived == [len("Hello, world!")]
def shape2space(space: ObservationSpace) -> Space:
    """Build the gym Space described by an ObservationSpace message.

    The populated field of the message's ``shape`` oneof selects the result:
    the two ``*_range_list`` shapes become a ``Box`` with one element per
    range, and the two ``*_size_range`` shapes become a ``Sequence``.

    :param space: The message to translate.
    :return: A ``Box`` or ``Sequence``.
    :raises TypeError: If the ``shape`` oneof holds none of the four
        supported fields.
    """
    # shape field name -> (element dtype, fallback (min, max) bounds)
    box_shapes = {
        "int64_range_list": (
            np.int64,
            (np.iinfo(np.int64).min, np.iinfo(np.int64).max),
        ),
        "double_range_list": (np.float64, (-np.inf, np.inf)),
    }
    # shape field name -> element dtype
    sequence_shapes = {
        "string_size_range": str,
        "binary_size_range": bytes,
    }

    kind = space.WhichOneof("shape")

    if kind in box_shapes:
        dtype, fallback = box_shapes[kind]
        limits = [
            scalar_range2tuple(scalar_range, fallback)
            for scalar_range in getattr(space, kind).range
        ]
        lows = np.array([limit[0] for limit in limits], dtype=dtype)
        highs = np.array([limit[1] for limit in limits], dtype=dtype)
        return Box(low=lows, high=highs, dtype=dtype)

    if kind in sequence_shapes:
        return Sequence(
            size_range=scalar_range2tuple(getattr(space, kind), (0, None)),
            dtype=sequence_shapes[kind],
            opaque_data_format=space.opaque_data_format,
        )

    raise TypeError(f"Cannot determine shape of ObservationSpace: {space}")
def test_invalid_observation_index():
    """Looking up an integer key on the view raises KeyError."""
    ir_space = ObservationSpace(
        name="ir",
        string_size_range=ScalarRange(min=ScalarLimit(value=0)),
    )
    view = ObservationView(MockGetObservation(), [ir_space])

    with pytest.raises(KeyError):
        _ = view[100]
def test_invalid_observation_name():
    """Looking up an unregistered name raises KeyError naming that key."""
    ir_space = ObservationSpace(
        name="ir",
        string_size_range=ScalarRange(min=ScalarLimit(value=0)),
    )
    view = ObservationView(MockGetObservation(), [ir_space])

    with pytest.raises(KeyError) as raised:
        _ = view["invalid"]

    # str() of a KeyError is the repr of its key, hence the inner quotes.
    assert str(raised.value) == "'invalid'"
def test_observed_value_types():
    """Observations come back with the expected values for each space.

    The same checks run twice: once through item lookup on the view and once
    through the per-space bound methods.
    """
    ir_space = ObservationSpace(
        name="ir",
        space=Space(string_value=StringSpace(length_range=Int64Range(min=0))),
    )
    features_space = ObservationSpace(
        name="features",
        space=Space(
            int64_box=Int64Box(
                low=Int64Tensor(shape=[2], value=[-100, -100]),
                high=Int64Tensor(shape=[2], value=[100, 100]),
            ),
        ),
    )
    dfeat_space = ObservationSpace(
        name="dfeat",
        space=Space(
            double_box=DoubleBox(
                low=DoubleTensor(shape=[1], value=[0.5]),
                high=DoubleTensor(shape=[1], value=[2.5]),
            ),
        ),
    )
    binary_space = ObservationSpace(
        name="binary",
        space=Space(int64_value=Int64Range(min=5, max=5)),
    )

    # One canned return value per lookup, in the order the lookups are made.
    backend = MockRawStep(
        ret=[
            "Hello, IR",
            [1.0, 2.0],
            [-5, 15],
            b"Hello, bytes\0",
            "Hello, IR",
            [1.0, 2.0],
            [-5, 15],
            b"Hello, bytes\0",
        ]
    )
    view = ObservationView(
        backend, [ir_space, features_space, dfeat_space, binary_space]
    )
    query_order = ["ir", "dfeat", "features", "binary"]

    def check_round(read):
        ir = read("ir")
        assert isinstance(ir, str)
        assert ir == "Hello, IR"

        np.testing.assert_array_almost_equal(read("dfeat"), [1.0, 2.0])
        np.testing.assert_array_equal(read("features"), [-5, 15])
        assert read("binary") == b"Hello, bytes\0"

        # Check that the correct observation_space_list indices were used.
        assert backend.called_observation_spaces == query_order

    check_round(lambda name: view[name])

    # Repeat the above checks using the generated bound methods.
    backend.called_observation_spaces = []
    check_round(lambda name: getattr(view, name)())
class UnrollingCompilationSession(CompilationSession):
    """An interactive compilation session that unrolls loops in LLVM-IR.

    When the session is created the benchmark's C source is written to the
    working directory and compiled to LLVM-IR with clang. Each action then
    rewrites that IR file in place with a loop unrolling pass, and every
    observation is computed from the current contents of the IR file.
    """

    compiler_version: str = "1.0.0"

    # The list of actions that are supported by this service.
    action_spaces = [
        ActionSpace(
            name="unrolling",
            choice=[
                ChoiceSpace(
                    name="unroll_choice",
                    named_discrete_space=NamedDiscreteSpace(
                        value=[
                            "-loop-unroll -unroll-count=2",
                            "-loop-unroll -unroll-count=4",
                            "-loop-unroll -unroll-count=8",
                        ],
                    ),
                )
            ],
        )
    ]

    # A list of observation spaces supported by this service. Each of these
    # ObservationSpace protos describes an observation space.
    observation_spaces = [
        ObservationSpace(
            name="ir",
            string_size_range=ScalarRange(min=ScalarLimit(value=0)),
            deterministic=True,
            platform_dependent=False,
            default_value=Observation(string_value=""),
        ),
        ObservationSpace(
            name="features",
            int64_range_list=ScalarRangeList(
                range=[
                    ScalarRange(
                        min=ScalarLimit(value=0), max=ScalarLimit(value=1e5)
                    ),
                    ScalarRange(
                        min=ScalarLimit(value=0), max=ScalarLimit(value=1e5)
                    ),
                    ScalarRange(
                        min=ScalarLimit(value=0), max=ScalarLimit(value=1e5)
                    ),
                ]
            ),
        ),
        ObservationSpace(
            name="runtime",
            scalar_double_range=ScalarRange(min=ScalarLimit(value=0)),
            deterministic=False,
            platform_dependent=True,
            default_value=Observation(
                scalar_double=0,
            ),
        ),
        ObservationSpace(
            name="size",
            scalar_double_range=ScalarRange(min=ScalarLimit(value=0)),
            deterministic=True,
            platform_dependent=True,
            default_value=Observation(
                scalar_double=0,
            ),
        ),
    ]

    def __init__(
        self,
        working_directory: Path,
        action_space: ActionSpace,
        benchmark: Benchmark,
        use_custom_opt: bool = True,
    ):
        """Write the benchmark to disk and compile it to LLVM-IR.

        :param working_directory: Forwarded to the base class constructor.
        :param action_space: The action space that actions index into.
        :param benchmark: The benchmark to compile; its
            ``program.contents`` is written out as ``benchmark.c``.
        :param use_custom_opt: If true, actions are applied with the custom
            ``loop_unroller`` executable, otherwise with LLVM's ``opt``.
        """
        super().__init__(working_directory, action_space, benchmark)
        logging.info("Started a compilation session for %s", benchmark.uri)
        self._benchmark = benchmark
        self._action_space = action_space

        # Resolve the paths to LLVM binaries once now.
        self._clang = str(llvm.clang_path())
        self._llc = str(llvm.llc_path())
        self._llvm_diff = str(llvm.llvm_diff_path())
        self._opt = str(llvm.opt_path())
        # LLVM's opt does not always enforce the unrolling options passed as
        # cli arguments. Hence, we created our own executable with a custom
        # unrolling pass in examples/example_unrolling_service/loop_unroller
        # that enforces the unrolling factors passed in its cli.
        # If self._use_custom_opt is true, use our custom executable,
        # otherwise use LLVM's opt.
        self._use_custom_opt = use_custom_opt

        # Dump the benchmark source to disk.
        self._src_path = str(self.working_dir / "benchmark.c")
        with open(self._src_path, "wb") as f:
            f.write(benchmark.program.contents)

        self._llvm_path = str(self.working_dir / "benchmark.ll")
        self._llvm_before_path = str(self.working_dir / "benchmark.previous.ll")
        self._obj_path = str(self.working_dir / "benchmark.o")
        self._exe_path = str(self.working_dir / "benchmark.exe")

        run_command(
            [
                self._clang,
                "-Xclang",
                "-disable-O0-optnone",
                "-emit-llvm",
                "-S",
                self._src_path,
                "-o",
                self._llvm_path,
            ],
            timeout=30,
        )

    def apply_action(
        self, action: Action
    ) -> Tuple[bool, Optional[ActionSpace], bool]:
        """Run the selected unrolling pass over the IR file, in place.

        :param action: An action carrying exactly one choice whose
            ``named_discrete_value_index`` selects an entry of the action
            space.
        :return: A tuple ``(end_of_session, new_action_space,
            action_had_no_effect)``. The first two are always ``False`` and
            ``None``; the last is true when ``llvm-diff`` exits successfully
            on the IR before and after the action.
        :raises ValueError: If the action does not carry exactly one choice,
            or the chosen index is out of range.
        """
        choices = self._action_space.choice[0].named_discrete_space.value

        if len(action.choice) != 1:
            raise ValueError("Invalid choice count")

        # The index of the command line that the user selected.
        choice_index = action.choice[0].named_discrete_value_index
        if not 0 <= choice_index < len(choices):
            raise ValueError("Out-of-range")

        args = choices[choice_index]
        logging.info(
            "Applying action %d, equivalent command-line arguments: '%s'",
            choice_index,
            args,
        )
        args = args.split()

        # Keep a copy of the IR so that the effect of the action can be
        # detected afterwards.
        shutil.copyfile(self._llvm_path, self._llvm_before_path)

        if self._use_custom_opt:
            # Our custom unroller has an additional `f` at the beginning of
            # each argument: convert -<argument> to -f<argument>.
            args = [arg[0] + "f" + arg[1:] for arg in args]
            # NOTE(review): this path is relative to the process working
            # directory, not to this file - confirm the service is always
            # started from the expected directory.
            command = [
                "../loop_unroller/loop_unroller",
                self._llvm_path,
                *args,
                "-S",
                "-o",
                self._llvm_path,
            ]
        else:
            command = [
                self._opt,
                *args,
                self._llvm_path,
                "-S",
                "-o",
                self._llvm_path,
            ]
        run_command(command, timeout=30)

        # Compare the IR files to check if the action had an effect. A
        # non-zero exit status from llvm-diff is reported as "had an effect".
        try:
            subprocess.check_call(
                [self._llvm_diff, self._llvm_before_path, self._llvm_path],
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
                timeout=60,
            )
        except subprocess.CalledProcessError:
            action_had_no_effect = False
        else:
            action_had_no_effect = True

        # TODO: this needs investigation: for how long can we apply loop
        # unrolling? e.g., detect if there are no more loops in the IR?
        end_of_session = False
        new_action_space = None
        return (end_of_session, new_action_space, action_had_no_effect)

    @property
    def ir(self) -> str:
        """The current contents of the LLVM-IR file."""
        with open(self._llvm_path) as f:
            return f.read()

    def _build_executable(self, opt_flag: str) -> None:
        """Compile the current IR to an object file and link it.

        The link step uses the clang binary resolved in ``__init__`` rather
        than whatever ``clang`` happens to be first on PATH, so the whole
        pipeline runs with one toolchain.

        :param opt_flag: Optimization flag passed to the link step, e.g.
            ``"-O3"``.
        """
        run_command(
            [
                self._llc,
                "-filetype=obj",
                self._llvm_path,
                "-o",
                self._obj_path,
            ],
            timeout=30,
        )
        run_command(
            [
                self._clang,
                self._obj_path,
                opt_flag,
                "-o",
                self._exe_path,
            ],
            timeout=30,
        )

    def _average_runtime(self):
        """Run the executable five times and average the middle three.

        The benchmark is expected to print its own execution time, as an
        integer, to stdout.

        :raises ValueError: If the program output is not an integer.
        """
        # TODO: add documentation that benchmarks need print out execution time
        exec_times = []
        for _ in range(5):
            stdout = run_command(
                [self._exe_path],
                timeout=30,
            )
            try:
                exec_times.append(int(stdout))
            except ValueError as err:
                raise ValueError(
                    "Error in parsing execution time from output of command\n"
                    "Please ensure that the source code of the benchmark measures execution time and prints to stdout\n"
                    f"Stdout of the program: {stdout}"
                ) from err
        # Drop the fastest and the slowest run.
        exec_times = np.sort(exec_times)
        return np.mean(exec_times[1:4])

    def get_observation(self, observation_space: ObservationSpace) -> Observation:
        """Compute an observation from the current IR.

        :param observation_space: The space to observe; only its ``name`` is
            consulted.
        :return: The observation for ``"ir"``, ``"features"``, ``"runtime"``
            or ``"size"``.
        :raises KeyError: If the name is none of the above.
        :raises ValueError: For ``"runtime"``, if the benchmark does not
            print an integer execution time.
        """
        logging.info("Computing observation from space %s", observation_space.name)
        name = observation_space.name

        if name == "ir":
            return Observation(string_value=self.ir)

        if name == "features":
            stats = utils.extract_statistics_from_ir(self.ir)
            observation = Observation()
            observation.int64_list.value[:] = list(stats.values())
            return observation

        if name == "runtime":
            self._build_executable("-O3")
            return Observation(scalar_double=self._average_runtime())

        if name == "size":
            self._build_executable("-Oz")
            binary_size = os.path.getsize(self._exe_path)
            return Observation(scalar_double=binary_size)

        raise KeyError(observation_space.name)
class LoopToolCompilationSession(CompilationSession):
    """An interactive loop_tool session over a single element-wise addition.

    The session state is a loop order: a list of ``(size, tail)`` pairs from
    the outermost to the innermost loop. Actions move a cursor over that list
    and resize the loop under the cursor; observations lower the order to a
    loop_tool loop tree.
    """

    compiler_version: str = pkg_resources.get_distribution("loop-tool-py").version

    # keep it simple for now: 1 variable, 1 nest
    action_spaces = [
        ActionSpace(
            # shift around a single pre-split order, changing the size of splits
            name="simple",
            choice=[
                ChoiceSpace(
                    name="controls",
                    named_discrete_space=NamedDiscreteSpace(
                        value=["toggle_mode", "up", "down", "toggle_thread"],
                    ),
                )
            ],
        ),
        ActionSpace(
            # potentially define new splits
            name="split",
            choice=[
                ChoiceSpace(
                    name="controls",
                    named_discrete_space=NamedDiscreteSpace(
                        value=["toggle_mode", "up", "down", "toggle_thread", "split"],
                    ),
                )
            ],
        ),
    ]

    observation_spaces = [
        ObservationSpace(
            name="flops",
            scalar_double_range=ScalarRange(),
            deterministic=False,
            platform_dependent=True,
            default_value=Observation(
                scalar_double=0,
            ),
        ),
        ObservationSpace(
            name="loop_tree",
            string_size_range=ScalarRange(),
            deterministic=True,
            platform_dependent=False,
            default_value=Observation(
                string_value="",
            ),
        ),
        # NOTE(review): this space declares a single range, but
        # get_observation() emits three values (cursor, size, tail) for it -
        # confirm which of the two is intended.
        ObservationSpace(
            name="action_state",
            int64_range_list=ScalarRangeList(
                range=[
                    ScalarRange(
                        min=ScalarLimit(value=0),
                        max=ScalarLimit(value=2**36),
                    ),
                ]
            ),
            deterministic=True,
            platform_dependent=False,
            default_value=Observation(
                int64_list=Int64List(value=[0]),
            ),
        ),
    ]

    def __init__(
        self, working_directory: Path, action_space: ActionSpace, benchmark: Benchmark
    ):
        """Set up the IR and the initial loop order for a benchmark.

        The backend is "cuda" if that word occurs in the benchmark URI and
        "cpu" otherwise. The problem size is the integer after the last "/"
        of the URI.

        :raises EnvironmentNotSupported: If the selected backend is not among
            ``lt.backends()``.
        """
        super().__init__(working_directory, action_space, benchmark)
        self.action_space = action_space
        if "cuda" in benchmark.uri:
            self.backend = "cuda"
            lt.set_default_hardware("cuda")
        else:
            self.backend = "cpu"
        if self.backend not in lt.backends():
            raise EnvironmentNotSupported(
                f"Failed to load {self.backend} dataset for loop_tool. Have you installed all required dependecies? See <https://facebookresearch.github.io/CompilerGym/envs/loop_tool.html#installation> for details. "
            )

        # Build the computation: write(add(read, read)) over one variable.
        self.ir = lt.IR()
        self.var = self.ir.create_var("a")
        r0 = self.ir.create_node("read", [], [self.var])
        r1 = self.ir.create_node("read", [], [self.var])
        add = self.ir.create_node("add", [r0, r1], [self.var])
        w = self.ir.create_node("write", [add], [self.var])
        self.ir.set_inputs([r0, r1])
        self.ir.set_outputs([w])

        self.size = int(benchmark.uri.split("/")[-1])
        self.Ap = np.random.randn(self.size)
        self.Bp = np.random.randn(self.size)
        # One (size, tail) pair per loop, outermost first.
        self.order = [(self.size, 0), (1, 0), (1, 0)]
        self.thread = [1, 0, 0]
        self.cursor = 0
        self.mode = "size"
        logger.info("Started a compilation session for %s", benchmark.uri)

    def resize(self, increment):
        """Grow or shrink the loop under the cursor, rebalancing its parent.

        The idea is to pull from or add to the parent loop. Three mutations
        are possible to any size:

          A) x, y -> x + 1, 0
             remove the tail, increase loop size, shrink parent
          B) x, y -> x, 0
             only remove the tail, add to parent
          C) x, 0 -> x - 1, 0
             if no tail, shrink the loop size, increase parent

        Note: this means tails can never exist on innermost loops.

        With ``k`` the total size of all loops nested inside the cursor and
        ``x'`` the new size, the parent ``(a, b)`` becomes ``(a', b')``:

          [(a, b), (x, y), ...k] -> [(a', b'), (x', 0), ...k]
          a * (x * k + y) + b = a' * x' * k + b'
          a' = (a * (x * k + y) + b) // (x' * k)
          b' = (a * (x * k + y) + b) %  (x' * k)

        Example interaction model:

          1. cursor = 1        [1024, 1, 1]
          2. up                [512, 2, 1]
          3. up                [(341, 1), 3, 1]
          4. up                [256, 4, 1]
          5. cursor = 2, up    [256, 2, 2]
          6. up                [256, (1, 1), 3]
          7. cursor = 1, down  [(341, 1), 1, 3]
          8. cursor = 2, down  [(341, 1), (1, 1), 2]
          9. cursor = 1, down  [512, 1, 2]

        The call is a no-op when the cursor is on the outermost loop, or when
        the new size or the new parent size would fall outside
        ``[1, self.size]``.

        :param increment: ``1`` to grow the loop, ``-1`` to shrink it.
        """
        if self.cursor == 0:
            return
        parent_size = self.order[self.cursor - 1]
        a = parent_size[0]
        b = parent_size[1]
        size = self.order[self.cursor]
        x = size[0]
        y = size[1]

        def lam(v, x):
            return v * x[0] + x[1]

        # Total number of iterations of everything nested inside the cursor.
        k = reduce(lam, self.order[self.cursor + 1 :][::-1], 1)

        # Shrinking a loop that has a tail only drops the tail (mutation B).
        if increment == -1 and y:
            increment = 0
        if (x + increment) < 1:
            return
        if (x + increment) > self.size:
            return

        # Iterations covered by the parent and this loop together. The tail
        # y must be included, otherwise the iterations it covered are lost
        # whenever a loop with a tail is resized.
        n = a * (x * k + y) + b
        d = (x + increment) * k
        a_ = n // d
        b_ = n % d
        if a_ < 1:
            return
        if a_ > self.size:
            return
        self.order[self.cursor - 1] = (a_, b_)
        self.order[self.cursor] = (x + increment, 0)

        # Internal invariant: resizing never changes the total size.
        end_size = reduce(lam, self.order[::-1], 1)
        assert (
            end_size == self.size
        ), f"{end_size} != {self.size} ({a}, {b}), ({x}, {y}) -> ({a_}, {b_}), ({x + increment}, 0)"

    def apply_action(
        self, action: Action
    ) -> Tuple[bool, Optional[ActionSpace], bool]:
        """Apply one control action to the cursor, mode or thread state.

        In "size" mode, "up" and "down" resize the loop under the cursor; in
        "select" mode they move the cursor, wrapping around.

        :return: Always ``(False, None, False)``.
        :raises ValueError: If the action does not carry exactly one choice,
            or the chosen index is out of range.
        :raises RuntimeError: If the session is in an unknown mode.
        """
        if len(action.choice) != 1:
            raise ValueError("Invalid choice count")

        choice_index = action.choice[0].named_discrete_value_index
        controls = self.action_space.choice[0].named_discrete_space.value
        if choice_index < 0 or choice_index >= len(controls):
            raise ValueError("Out-of-range")

        logger.info("Applied action %d", choice_index)

        act = controls[choice_index]
        if self.mode not in ["size", "select"]:
            raise RuntimeError("Invalid mode set: {}".format(self.mode))

        if act == "toggle_mode":
            if self.mode == "size":
                self.mode = "select"
            elif self.mode == "select":
                self.mode = "size"
        if act == "toggle_thread":
            self.thread[self.cursor] = not self.thread[self.cursor]
        if act == "down":
            # always loop around
            if self.mode == "size":
                self.resize(-1)
            elif self.mode == "select":
                self.cursor = (self.cursor - 1) % len(self.order)
        if act == "up":
            # always loop around
            if self.mode == "size":
                self.resize(1)
            elif self.mode == "select":
                self.cursor = (self.cursor + 1) % len(self.order)

        return False, None, False

    def lower(self):
        """Apply the current order to the IR and build its loop tree.

        :return: A tuple ``(loop_tree, parallel)`` where ``parallel`` is the
            set of loop tree nodes whose ``self.thread`` flag is set.
        """
        for n in self.ir.nodes:
            o = [(self.var, k) for k in self.order]
            self.ir.set_order(n, o)
            # always disable innermost
            self.ir.disable_reuse(n, len(o) - 1)
        loop_tree = lt.LoopTree(self.ir)
        parallel = set()
        # Walk down the first child of each level, one level per thread flag.
        t = loop_tree.roots[0]
        for b in self.thread:
            if b:
                parallel.add(t)
                if self.backend == "cpu":
                    loop_tree.annotate(t, "cpu_parallel")
            t = loop_tree.children(t)[0]
        return loop_tree, parallel

    def flops(self):
        """Time the lowered computation and return its throughput.

        :return: ``self.size * iters / elapsed_seconds / 1e9`` measured over
            1000 timed calls, after 50 untimed warm-up calls.
        """
        loop_tree, parallel = self.lower()
        if self.backend == "cuda":
            c = lt.cuda(loop_tree, parallel)
        else:
            c = lt.cpu(loop_tree)
        A = lt.Tensor(self.size)
        B = lt.Tensor(self.size)
        C = lt.Tensor(self.size)
        A.set(self.Ap)
        B.set(self.Bp)
        iters = 1000
        warmup = 50
        for _ in range(warmup):
            c([A, B, C])
        t = time.time()
        # NOTE(review): the second argument presumably controls whether the
        # call waits for completion - confirm against the loop_tool API.
        for _ in range(iters - 1):
            c([A, B, C], False)
        c([A, B, C])
        t_ = time.time()
        flops = self.size * iters / (t_ - t) / 1e9
        return flops

    def get_observation(self, observation_space: ObservationSpace) -> Observation:
        """Compute an observation from the current session state.

        :param observation_space: The space to observe; only its ``name`` is
            consulted.
        :return: The observation for ``"action_state"``, ``"flops"`` or
            ``"loop_tree"``.
        :raises KeyError: If the name is none of the above.
        """
        if observation_space.name == "action_state":
            observation = Observation()
            # cursor, (size, tail)
            o = self.order[self.cursor]
            observation.int64_list.value[:] = [self.cursor, o[0], o[1]]
            return observation
        elif observation_space.name == "flops":
            return Observation(scalar_double=self.flops())
        elif observation_space.name == "loop_tree":
            loop_tree, parallel = self.lower()
            return Observation(
                string_value=loop_tree.dump(
                    lambda x: "[thread]" if x in parallel else ""
                )
            )
        else:
            raise KeyError(observation_space.name)
class LoopsOptCompilationSession(CompilationSession):
    """An interactive session applying loop optimizations to LLVM-IR.

    When the session is created the benchmark's C source is written to the
    working directory and compiled to LLVM-IR with clang. Each action then
    rewrites that IR file in place with a loop unrolling or vectorization
    pass, and every observation is computed from the current IR file.
    """

    compiler_version: str = "1.0.0"

    # The list of actions that are supported by this service.
    action_spaces = [
        ActionSpace(
            name="loop-opt",
            space=Space(
                named_discrete=NamedDiscreteSpace(
                    name=[
                        "--loop-unroll --unroll-count=2",
                        "--loop-unroll --unroll-count=4",
                        "--loop-unroll --unroll-count=8",
                        "--loop-unroll --unroll-count=16",
                        "--loop-unroll --unroll-count=32",
                        "--loop-vectorize -force-vector-width=2",
                        "--loop-vectorize -force-vector-width=4",
                        "--loop-vectorize -force-vector-width=8",
                        "--loop-vectorize -force-vector-width=16",
                        "--loop-vectorize -force-vector-width=32",
                    ]
                ),
            ),
        )
    ]

    # A list of observation spaces supported by this service. Each of these
    # ObservationSpace protos describes an observation space.
    observation_spaces = [
        ObservationSpace(
            name="ir",
            space=Space(string_value=StringSpace(length_range=Int64Range(min=0))),
            deterministic=True,
            platform_dependent=False,
            default_observation=Event(string_value=""),
        ),
        ObservationSpace(
            name="Inst2vec",
            space=Space(
                int64_sequence=Int64SequenceSpace(length_range=Int64Range(min=0)),
            ),
        ),
        ObservationSpace(
            name="Autophase",
            space=Space(
                int64_sequence=Int64SequenceSpace(
                    length_range=Int64Range(
                        min=len(AUTOPHASE_FEATURE_NAMES),
                        max=len(AUTOPHASE_FEATURE_NAMES),
                    )
                ),
            ),
            deterministic=True,
            platform_dependent=False,
        ),
        ObservationSpace(
            name="AutophaseDict",
            space=Space(
                space_dict=DictSpace(
                    space={
                        name: Space(int64_value=Int64Range(min=0))
                        for name in AUTOPHASE_FEATURE_NAMES
                    }
                )
            ),
            deterministic=True,
            platform_dependent=False,
        ),
        ObservationSpace(
            name="Programl",
            space=Space(string_value=StringSpace(length_range=Int64Range(min=0))),
            deterministic=True,
            platform_dependent=False,
            default_observation=Event(string_value=""),
        ),
        ObservationSpace(
            name="runtime",
            space=Space(double_value=DoubleRange(min=0)),
            deterministic=False,
            platform_dependent=True,
            default_observation=Event(
                double_value=0,
            ),
        ),
        ObservationSpace(
            name="size",
            space=Space(double_value=DoubleRange(min=0)),
            deterministic=True,
            platform_dependent=True,
            default_observation=Event(
                double_value=0,
            ),
        ),
    ]

    def __init__(
        self,
        working_directory: Path,
        action_space: ActionSpace,
        benchmark: Benchmark,
        use_custom_opt: bool = True,
    ):
        """Write the benchmark to disk and compile it to LLVM-IR.

        :param working_directory: Forwarded to the base class constructor.
        :param action_space: The action space that actions index into.
        :param benchmark: The benchmark to compile; its
            ``program.contents`` is written out as ``benchmark.c``.
        :param use_custom_opt: If true, actions are applied with the custom
            ``opt_loops`` executable, otherwise with LLVM's ``opt``.
        """
        super().__init__(working_directory, action_space, benchmark)
        logging.info("Started a compilation session for %s", benchmark.uri)
        self._benchmark = benchmark
        self._action_space = action_space

        self.inst2vec = _INST2VEC_ENCODER

        # Resolve the paths to LLVM binaries once now.
        self._clang = str(llvm.clang_path())
        self._llc = str(llvm.llc_path())
        self._llvm_diff = str(llvm.llvm_diff_path())
        self._opt = str(llvm.opt_path())
        # LLVM's opt does not always enforce the loop optimization options
        # passed as cli arguments. Hence, we created our own executable with
        # custom unrolling and vectorization passes in
        # examples/loops_opt_service/opt_loops that enforces the unrolling
        # and vectorization factors passed in its cli.
        # If self._use_custom_opt is true, use our custom executable,
        # otherwise use LLVM's opt.
        self._use_custom_opt = use_custom_opt

        # Dump the benchmark source to disk.
        self._src_path = str(self.working_dir / "benchmark.c")
        with open(self._src_path, "wb") as f:
            f.write(benchmark.program.contents)

        self._llvm_path = str(self.working_dir / "benchmark.ll")
        self._llvm_before_path = str(self.working_dir / "benchmark.previous.ll")
        self._obj_path = str(self.working_dir / "benchmark.o")
        self._exe_path = str(self.working_dir / "benchmark.exe")

        run_command(
            [
                self._clang,
                "-Xclang",
                "-disable-O0-optnone",
                "-emit-llvm",
                "-S",
                self._src_path,
                "-o",
                self._llvm_path,
            ],
            timeout=30,
        )

    def apply_action(
        self, action: Event
    ) -> Tuple[bool, Optional[ActionSpace], bool]:
        """Run the selected loop optimization over the IR file, in place.

        :param action: An event whose ``int64_value`` selects an entry of the
            action space.
        :return: A tuple ``(end_of_session, new_action_space,
            action_had_no_effect)``. The first two are always ``False`` and
            ``None``; the last is true when ``llvm-diff`` exits successfully
            on the IR before and after the action.
        :raises ValueError: If the chosen index is out of range.
        """
        choices = self._action_space.space.named_discrete.name

        # The index of the command line that the user selected.
        choice_index = action.int64_value
        if not 0 <= choice_index < len(choices):
            raise ValueError("Out-of-range")

        args = choices[choice_index]
        logging.info(
            "Applying action %d, equivalent command-line arguments: '%s'",
            choice_index,
            args,
        )
        args = args.split()

        # Keep a copy of the IR so that the effect of the action can be
        # detected afterwards.
        shutil.copyfile(self._llvm_path, self._llvm_before_path)

        if self._use_custom_opt:
            # Our custom opt-loops has an additional `f` at the beginning of
            # each argument: an `f` is inserted after the first two
            # characters, turning --<argument> into --f<argument>.
            args = [arg[0:2] + "f" + arg[2:] for arg in args]
            command = [
                os.path.join(os.path.dirname(__file__), "../opt_loops/opt_loops"),
                self._llvm_path,
                *args,
                "-S",
                "-o",
                self._llvm_path,
            ]
        else:
            command = [
                self._opt,
                *args,
                self._llvm_path,
                "-S",
                "-o",
                self._llvm_path,
            ]
        run_command(command, timeout=30)

        # Compare the IR files to check if the action had an effect. A
        # non-zero exit status from llvm-diff is reported as "had an effect".
        try:
            subprocess.check_call(
                [self._llvm_diff, self._llvm_before_path, self._llvm_path],
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
                timeout=60,
            )
        except subprocess.CalledProcessError:
            action_had_no_effect = False
        else:
            action_had_no_effect = True

        # TODO: this needs investigation: for how long can we apply loop
        # optimizations? e.g., detect if there are no more loops in the IR?
        # or look at the metadata?
        end_of_session = False
        new_action_space = None
        return (end_of_session, new_action_space, action_had_no_effect)

    @property
    def ir(self) -> str:
        """The current contents of the LLVM-IR file."""
        with open(self._llvm_path) as f:
            return f.read()

    def _run_tool_on_ir(self, tool_relpath: str) -> str:
        """Run a tool, located relative to this file, on the IR file.

        :param tool_relpath: Path of the executable relative to the directory
            containing this file.
        :return: Whatever ``run_command`` returns for the invocation.
        """
        return run_command(
            [
                os.path.join(os.path.dirname(__file__), tool_relpath),
                self._llvm_path,
            ],
            timeout=30,
        )

    def _autophase_features(self) -> list:
        """Compute the Autophase feature values of the current IR."""
        output = self._run_tool_on_ir(
            "../../../compiler_gym/third_party/autophase/compute_autophase-prelinked"
        )
        # Split on any run of whitespace so that repeated separators do not
        # yield empty tokens.
        return [int(token) for token in output.split()]

    def _build_executable(self, opt_flag: str) -> None:
        """Compile the current IR to an object file and link it.

        :param opt_flag: Optimization flag passed to the link step, e.g.
            ``"-O3"``.
        """
        run_command(
            [
                self._llc,
                "-filetype=obj",
                self._llvm_path,
                "-o",
                self._obj_path,
            ],
            timeout=30,
        )
        run_command(
            [
                self._clang,
                self._obj_path,
                opt_flag,
                "-o",
                self._exe_path,
            ],
            timeout=30,
        )

    def _average_runtime(self):
        """Run the executable five times and average the middle three.

        The benchmark is expected to print its own execution time, as an
        integer, to stdout.

        :raises ValueError: If the program output is not an integer.
        """
        # TODO: add documentation that benchmarks need print out execution time
        exec_times = []
        for _ in range(5):
            stdout = run_command(
                [self._exe_path],
                timeout=30,
            )
            try:
                exec_times.append(int(stdout))
            except ValueError as err:
                raise ValueError(
                    "Error in parsing execution time from output of command\n"
                    "Please ensure that the source code of the benchmark measures execution time and prints to stdout\n"
                    f"Stdout of the program: {stdout}"
                ) from err
        # Drop the fastest and the slowest run.
        exec_times = np.sort(exec_times)
        return np.mean(exec_times[1:4])

    def get_observation(self, observation_space: ObservationSpace) -> Event:
        """Compute an observation from the current IR.

        :param observation_space: The space to observe; only its ``name`` is
            consulted.
        :return: The event for one of ``"ir"``, ``"Inst2vec"``,
            ``"Autophase"``, ``"AutophaseDict"``, ``"Programl"``,
            ``"runtime"`` or ``"size"``.
        :raises KeyError: If the name is none of the above.
        :raises ValueError: For ``"runtime"``, if the benchmark does not
            print an integer execution time.
        """
        logging.info("Computing observation from space %s", observation_space.name)
        name = observation_space.name

        if name == "ir":
            return Event(string_value=self.ir)

        if name == "Inst2vec":
            preprocessed = self.inst2vec.preprocess(self.ir)
            token_ids = self.inst2vec.encode(preprocessed)
            return Event(
                int64_tensor=Int64Tensor(shape=[len(token_ids)], value=token_ids)
            )

        if name == "Autophase":
            features = self._autophase_features()
            return Event(
                int64_tensor=Int64Tensor(shape=[len(features)], value=features)
            )

        if name == "AutophaseDict":
            features = self._autophase_features()
            by_name = {
                feature_name: Event(int64_value=value)
                for feature_name, value in zip(AUTOPHASE_FEATURE_NAMES, features)
            }
            return Event(event_dict=DictEvent(event=by_name))

        if name == "Programl":
            graph = self._run_tool_on_ir(
                "../../../compiler_gym/third_party/programl/compute_programl"
            )
            return Event(string_value=graph)

        if name == "runtime":
            self._build_executable("-O3")
            return Event(double_value=self._average_runtime())

        if name == "size":
            self._build_executable("-Oz")
            binary_size = os.path.getsize(self._exe_path)
            return Event(double_value=binary_size)

        raise KeyError(observation_space.name)
class ExampleCompilationSession(CompilationSession):
    """A minimal compilation session that returns canned observations.

    Actions are validated but have no effect, and every observation is a
    fixed value.
    """

    compiler_version: str = "1.0.0"

    # A single action space, "default", standing in for a command line with
    # three options: "a", "b", and "c".
    action_spaces = [
        ActionSpace(
            name="default",
            space=Space(
                named_discrete=NamedDiscreteSpace(
                    name=[
                        "a",
                        "b",
                        "c",
                    ],
                ),
            ),
        )
    ]

    # The observation spaces offered by this service, one ObservationSpace
    # proto each.
    observation_spaces = [
        ObservationSpace(
            name="ir",
            space=Space(
                string_value=StringSpace(length_range=Int64Range(min=0)),
            ),
            deterministic=True,
            platform_dependent=False,
            default_observation=Event(string_value=""),
        ),
        ObservationSpace(
            name="features",
            space=Space(
                int64_box=Int64Box(
                    low=Int64Tensor(shape=[3], value=[-100, -100, -100]),
                    high=Int64Tensor(shape=[3], value=[100, 100, 100]),
                ),
            ),
        ),
        ObservationSpace(
            name="runtime",
            space=Space(
                double_value=DoubleRange(min=0),
            ),
            deterministic=False,
            platform_dependent=True,
            default_observation=Event(
                double_value=0,
            ),
        ),
    ]

    def __init__(
        self,
        working_directory: Path,
        action_space: ActionSpace,
        benchmark: Benchmark,
    ):
        """Forward all arguments to the base class and log the benchmark URI."""
        super().__init__(working_directory, action_space, benchmark)
        logging.info("Started a compilation session for %s", benchmark.uri)

    def apply_action(
        self, action: Event
    ) -> Tuple[bool, Optional[ActionSpace], bool]:
        """Validate the chosen action index; the action itself is a no-op.

        :param action: An event whose ``int64_value`` is the selected index
            into the first action space, e.g. 0 -> "a", 1 -> "b", 2 -> "c".
        :return: Always ``(False, None, False)``.
        :raises ValueError: If the index is out of range.
        """
        option_count = len(self.action_spaces[0].space.named_discrete.name)
        selected = action.int64_value
        logging.info("Applying action %d", selected)

        if not 0 <= selected < option_count:
            raise ValueError("Out-of-range")

        # A real service would update the environment's state here.
        return False, None, False

    def get_observation(self, observation_space: ObservationSpace) -> Event:
        """Return the canned observation for the named space.

        :param observation_space: The space to observe; selected by ``name``.
        :return: A fixed event for ``"ir"``, ``"features"`` or ``"runtime"``.
        :raises KeyError: If the name is none of the above.
        """
        logging.info("Computing observation from space %s", observation_space)

        # Builders are deferred so that only the requested event is created.
        builders = {
            "ir": lambda: Event(string_value="Hello, world!"),
            "features": lambda: Event(
                int64_tensor=Int64Tensor(shape=[3], value=[0, 0, 0])
            ),
            "runtime": lambda: Event(double_value=0),
        }
        build = builders.get(observation_space.name)
        if build is None:
            raise KeyError(observation_space.name)
        return build()
def from_proto(cls, index: int, proto: ObservationSpace):
    """Build an instance of ``cls`` from an ObservationSpace message.

    Three things are derived from the message: the gym space describing the
    observations, a decoder from an Observation message to a python value,
    and a function rendering such a value as a string. The message's
    ``opaque_data_format`` takes precedence over its ``shape`` oneof when it
    names one of the two JSON formats.

    :param index: Passed through to ``cls`` as ``index``.
    :param proto: The message to translate.
    :raises TypeError: If neither the opaque data format nor the shape is
        recognized.
    """
    shape_type = proto.WhichOneof("shape")
    opaque_format = proto.opaque_data_format

    def bounded_box(ranges, dtype, fallback):
        limits = [scalar_range2tuple(item, fallback) for item in ranges]
        lows = np.array([limit[0] for limit in limits], dtype=dtype)
        highs = np.array([limit[1] for limit in limits], dtype=dtype)
        return Box(low=lows, high=highs, dtype=dtype)

    def sized_sequence(size_range, dtype):
        return Sequence(
            size_range=scalar_range2tuple(size_range, (0, None)),
            dtype=dtype,
            opaque_data_format=proto.opaque_data_format,
        )

    # Unless a branch below says otherwise, values are printed with str().
    render = str

    if opaque_format == "json://networkx/MultiDiGraph":
        # TODO(cummins): Add a Graph space.
        space = sized_sequence(proto.string_size_range, str)

        def decode(message):
            payload = json.loads(message.string_value)
            return nx.readwrite.json_graph.node_link_graph(
                payload, multigraph=True, directed=True
            )

        def render(graph):
            node_link = nx.readwrite.json_graph.node_link_data(graph)
            return json.dumps(node_link, indent=2)

    elif opaque_format == "json://":
        space = sized_sequence(proto.string_size_range, str)

        def decode(message):
            return json.loads(message.string_value)

        def render(value):
            return json.dumps(value, indent=2)

    elif shape_type == "int64_range_list":
        int64_limits = np.iinfo(np.int64)
        space = bounded_box(
            proto.int64_range_list.range,
            np.int64,
            (int64_limits.min, int64_limits.max),
        )

        def decode(message):
            return np.array(message.int64_list.value, dtype=np.int64)

    elif shape_type == "double_range_list":
        space = bounded_box(
            proto.double_range_list.range, np.float64, (-np.inf, np.inf)
        )

        def decode(message):
            return np.array(message.double_list.value, dtype=np.float64)

    elif shape_type == "string_size_range":
        space = sized_sequence(proto.string_size_range, str)

        def decode(message):
            return message.string_value

    elif shape_type == "binary_size_range":
        space = sized_sequence(proto.binary_size_range, bytes)

        def decode(message):
            return message.binary_value

    else:
        raise TypeError(f"Cannot determine shape of ObservationSpace: {proto}")

    return cls(
        id=proto.name,
        index=index,
        space=space,
        cb=decode,
        to_string=render,
        deterministic=proto.deterministic,
        platform_dependent=proto.platform_dependent,
        default_value=decode(proto.default_value),
    )
def test_observed_value_types():
    """Observations come back with the expected values for each space.

    The same checks run twice: once through item lookup on the view and once
    through the per-space bound methods.
    """
    ir_space = ObservationSpace(
        name="ir",
        string_size_range=ScalarRange(min=ScalarLimit(value=0)),
    )
    features_space = ObservationSpace(
        name="features",
        int64_range_list=ScalarRangeList(
            range=[
                ScalarRange(
                    min=ScalarLimit(value=-100), max=ScalarLimit(value=100)
                ),
                ScalarRange(
                    min=ScalarLimit(value=-100), max=ScalarLimit(value=100)
                ),
            ]
        ),
    )
    dfeat_space = ObservationSpace(
        name="dfeat",
        double_range_list=ScalarRangeList(
            range=[
                ScalarRange(min=ScalarLimit(value=0.5), max=ScalarLimit(value=2.5))
            ]
        ),
    )
    binary_space = ObservationSpace(
        name="binary",
        binary_size_range=ScalarRange(
            min=ScalarLimit(value=5), max=ScalarLimit(value=5)
        ),
    )

    # One canned return value per lookup, in the order the lookups are made.
    backend = MockRawStep(
        ret=[
            "Hello, IR",
            [1.0, 2.0],
            [-5, 15],
            b"Hello, bytes\0",
            "Hello, IR",
            [1.0, 2.0],
            [-5, 15],
            b"Hello, bytes\0",
        ]
    )
    view = ObservationView(
        backend, [ir_space, features_space, dfeat_space, binary_space]
    )
    query_order = ["ir", "dfeat", "features", "binary"]

    def check_round(read):
        ir = read("ir")
        assert isinstance(ir, str)
        assert ir == "Hello, IR"

        np.testing.assert_array_almost_equal(read("dfeat"), [1.0, 2.0])
        np.testing.assert_array_equal(read("features"), [-5, 15])
        assert read("binary") == b"Hello, bytes\0"

        # Check that the correct observation_space_list indices were used.
        assert backend.called_observation_spaces == query_order

    check_round(lambda name: view[name])

    # Repeat the above checks using the generated bound methods.
    backend.called_observation_spaces = []
    check_round(lambda name: getattr(view, name)())
def make_gcc_compilation_session(gcc_bin: str):
    """Create a class to represent a GCC compilation service.

    The returned class closes over a single ``Gcc`` instance built from
    ``gcc_bin``; its action list and observation spaces are derived from
    that instance's option spec.

    :param gcc_bin: Path to the gcc executable. This can be a command name,
        like "gcc", or it can be path to the executable. Finally, if prefixed
        with "docker:" it can be the name of a docker image, e.g.
        "docker:gcc:11.2.0"
    :return: The ``GccCompilationSession`` class (not an instance).
    """
    gcc = Gcc(gcc_bin)

    # The available actions
    actions = []

    # Actions that are small will have all their various choices made as
    # explicit actions.
    # Actions that are not small will have the ability to increment the choice
    # by different amounts.
    for i, option in enumerate(gcc.spec.options):
        if len(option) < 10:
            for j in range(len(option)):
                actions.append(SimpleAction(option, i, j))
        if len(option) >= 10:
            actions.append(IncrAction(option, i, 1))
            actions.append(IncrAction(option, i, -1))
        if len(option) >= 50:
            actions.append(IncrAction(option, i, 10))
            actions.append(IncrAction(option, i, -10))
        if len(option) >= 500:
            actions.append(IncrAction(option, i, 100))
            actions.append(IncrAction(option, i, -100))
        if len(option) >= 5000:
            actions.append(IncrAction(option, i, 1000))
            actions.append(IncrAction(option, i, -1000))

    action_spaces_ = [
        ActionSpace(
            name="default",
            space=Space(
                named_discrete=NamedDiscreteSpace(name=[str(a) for a in actions]),
            ),
        ),
    ]

    observation_spaces_ = [
        # A string of the source code
        ObservationSpace(
            name="source",
            space=Space(string_value=StringSpace(length_range=Int64Range(min=0))),
            deterministic=True,
            platform_dependent=False,
            default_observation=Event(string_value=""),
        ),
        # A string of the rtl code
        ObservationSpace(
            name="rtl",
            space=Space(string_value=StringSpace(length_range=Int64Range(min=0))),
            deterministic=True,
            platform_dependent=True,
            default_observation=Event(string_value=""),
        ),
        # A string of the assembled code
        ObservationSpace(
            name="asm",
            space=Space(string_value=StringSpace(length_range=Int64Range(min=0))),
            deterministic=True,
            platform_dependent=True,
            default_observation=Event(string_value=""),
        ),
        # The size of the assembled code
        ObservationSpace(
            name="asm_size",
            space=Space(int64_value=Int64Range(min=-1)),
            deterministic=True,
            platform_dependent=True,
            default_observation=Event(
                int64_value=-1,
            ),
        ),
        # The hash of the assembled code
        ObservationSpace(
            name="asm_hash",
            space=Space(
                string_value=StringSpace(length_range=Int64Range(min=0, max=200)),
            ),
            deterministic=True,
            platform_dependent=True,
            default_observation=Event(string_value=""),
        ),
        # Asm instruction counts - Counter as a JSON string
        ObservationSpace(
            name="instruction_counts",
            space=Space(
                string_value=StringSpace(length_range=Int64Range(min=0)),
            ),
            deterministic=True,
            platform_dependent=True,
            default_observation=Event(string_value=""),
        ),
        # A bytes of the object code
        ObservationSpace(
            name="obj",
            space=Space(
                byte_sequence=ByteSequenceSpace(length_range=Int64Range(min=0)),
            ),
            deterministic=True,
            platform_dependent=False,
            default_observation=Event(byte_tensor=ByteTensor(shape=[0], value=b"")),
        ),
        # The size of the object code
        ObservationSpace(
            name="obj_size",
            space=Space(int64_value=Int64Range(min=-1)),
            deterministic=True,
            platform_dependent=True,
            default_observation=Event(
                int64_value=-1,
            ),
        ),
        # The hash of the object code
        ObservationSpace(
            name="obj_hash",
            space=Space(
                string_value=StringSpace(length_range=Int64Range(min=0, max=200)),
            ),
            deterministic=True,
            platform_dependent=True,
            default_observation=Event(string_value=""),
        ),
        # A list of the choices. Each element corresponds to an option in the spec.
        # '-1' indicates that this is empty on the command line (e.g. if the choice
        # corresponding to the '-O' option is -1, then no -O flag will be emitted.)
        # If a nonnegative number is given then that particular choice is used
        # (e.g. for the -O flag, 5 means use '-Ofast' on the command line.)
        # NOTE(review): the declared range starts at 0 although -1 is a value
        # this session emits (it is the initial state of every choice) -
        # confirm whether min should be -1.
        ObservationSpace(
            name="choices",
            space=Space(
                space_list=ListSpace(
                    space=[
                        Space(int64_value=Int64Range(min=0, max=len(option) - 1))
                        for option in gcc.spec.options
                    ]
                ),
            ),
        ),
        # The command line for compiling the object file as a string
        ObservationSpace(
            name="command_line",
            space=Space(
                string_value=StringSpace(length_range=Int64Range(min=0, max=200)),
            ),
            deterministic=True,
            platform_dependent=True,
            default_observation=Event(string_value=""),
        ),
    ]

    class GccCompilationSession(CompilationSession):
        """A GCC interactive compilation session."""

        compiler_version: str = gcc.spec.version
        action_spaces = action_spaces_
        observation_spaces = observation_spaces_

        def __init__(
            self,
            working_directory: Path,
            action_space: ActionSpace,
            benchmark: Benchmark,
        ):
            """Start a session with no choices set and nothing compiled yet."""
            super().__init__(working_directory, action_space, benchmark)
            # The benchmark being used
            self.benchmark = benchmark
            # Timeout value for compilation (in seconds)
            self._timeout = None
            # The source code
            self._source = None
            # The rtl code
            self._rtl = None
            # The assembled code
            self._asm = None
            # Size of the assembled code
            self._asm_size = None
            # Hash of the assembled code
            self._asm_hash = None
            # The object binary
            self._obj = None
            # size of the object binary
            self._obj_size = None
            # Hash of the object binary
            self._obj_hash = None
            # Set the path to the GCC executable
            self._gcc_bin = "gcc"
            # Initially the choices and the spec, etc are empty. They will be
            # initialised lazily
            self._choices = None

        @property
        def num_actions(self) -> int:
            """Number of named actions in the first (default) action space."""
            return len(self.action_spaces[0].space.named_discrete.name)

        @property
        def choices(self) -> List[int]:
            """Current choice per option; created as all -1 on first access."""
            if self._choices is None:
                self._choices = [-1] * len(gcc.spec.options)
            return self._choices

        @choices.setter
        def choices(self, value: List[int]):
            self._choices = value

        @property
        def source(self) -> str:
            """Get the benchmark source"""
            self.prepare_files()
            return self._source

        @property
        def rtl(self) -> str:
            """Get the RTL code (decoded text of the RTL dump)"""
            self.dump_rtl()
            return self._rtl

        @property
        def asm(self) -> str:
            """Get the assembled code (decoded text of the assembly file)"""
            self.assemble()
            return self._asm

        @property
        def asm_size(self) -> int:
            """Get the assembled code size"""
            self.assemble()
            return self._asm_size

        @property
        def asm_hash(self) -> str:
            """Get the assembled code hash"""
            self.assemble()
            return self._asm_hash

        @property
        def instruction_counts(self) -> str:
            """Get the instruction counts as a JSON string"""
            self.assemble()

            # A counted line is a tab followed by a mnemonic-like token; the
            # whole line has to match.
            insn_pat = re.compile("\t([a-zA-Z-0-9.-]+).*")
            insn_cnts = Counter()
            lines = self._asm.split("\n")
            for line in lines:
                m = insn_pat.fullmatch(line)
                if m:
                    insn_cnts[m.group(1)] += 1

            return json.dumps(insn_cnts)

        @property
        def obj(self) -> bytes:
            """Get the compiled code"""
            self.compile()
            return self._obj

        @property
        def obj_size(self) -> int:
            """Get the compiled code size"""
            self.compile()
            return self._obj_size

        @property
        def obj_hash(self) -> str:
            """Get the compiled code hash"""
            self.compile()
            return self._obj_hash

        @property
        def src_path(self) -> Path:
            """Get the path to the source file"""
            return self.working_dir / "src.c"

        @property
        def obj_path(self) -> Path:
            """Get the path to object file"""
            return self.working_dir / "obj.o"

        @property
        def asm_path(self) -> Path:
            """Get the path to the assembly"""
            return self.working_dir / "asm.s"

        @property
        def rtl_path(self) -> Path:
            """Get the path to the rtl"""
            return self.working_dir / "rtl.lsp"

        def _selected_options(self) -> list:
            """Gather the option values for every choice that is set (>= 0)."""
            return [
                option[choice]
                for option, choice in zip(gcc.spec.options, self.choices)
                if choice >= 0
            ]

        def obj_command_line(
            self, src_path: Optional[Path] = None, obj_path: Optional[Path] = None
        ) -> List[str]:
            """Get the command line to create the object file.

            The 'src_path' and 'obj_path' give the input and output paths. If
            not set, then they are taken from 'self.src_path' and
            'self.obj_path'. This is useful for printing where the actual paths
            are not important."""
            src_path = src_path or self.src_path
            obj_path = obj_path or self.obj_path
            return self._selected_options() + ["-w", "-c", src_path, "-o", obj_path]

        def asm_command_line(
            self, src_path: Optional[Path] = None, asm_path: Optional[Path] = None
        ) -> List[str]:
            """Get the command line to create the assembly file.

            The 'src_path' and 'asm_path' give the input and output paths. If
            not set, then they are taken from 'self.src_path' and
            'self.asm_path'. This is useful for printing where the actual paths
            are not important."""
            src_path = src_path or self.src_path
            asm_path = asm_path or self.asm_path
            return self._selected_options() + ["-w", "-S", src_path, "-o", asm_path]

        def rtl_command_line(
            self,
            src_path: Optional[Path] = None,
            rtl_path: Optional[Path] = None,
            asm_path: Optional[Path] = None,
        ) -> List[str]:
            """Get the command line to create the rtl file - might as well do
            the asm at the same time.

            The 'src_path', 'rtl_path', 'asm_path' give the input and output
            paths. If not set, then they are taken from 'self.src_path',
            'self.rtl_path' and 'self.asm_path'. This is useful for printing
            where the actual paths are not important."""
            src_path = src_path or self.src_path
            rtl_path = rtl_path or self.rtl_path
            asm_path = asm_path or self.asm_path
            return self._selected_options() + [
                "-w",
                "-S",
                src_path,
                f"-fdump-rtl-dfinish={rtl_path}",
                "-o",
                asm_path,
            ]

        def prepare_files(self):
            """Copy the source to the working directory.

            The source is taken from the benchmark's inline contents when
            present, otherwise fetched from the benchmark's URI. Nothing is
            done if the source has already been loaded."""
            if not self._source:
                if self.benchmark.program.contents:
                    self._source = self.benchmark.program.contents.decode()
                else:
                    with urlopen(self.benchmark.program.uri) as r:
                        self._source = r.read().decode()

                with open(self.src_path, "w") as f:
                    print(self._source, file=f)

        def _load_asm(self) -> None:
            """Read the assembly file and cache its text, size and MD5 hash."""
            with open(self.asm_path, "rb") as f:
                asm_bytes = f.read()
                self._asm = asm_bytes.decode()
                self._asm_size = os.path.getsize(self.asm_path)
                self._asm_hash = hashlib.md5(asm_bytes).hexdigest()

        def compile(self) -> None:
            """Compile the benchmark, unless an object file is already cached."""
            if not self._obj:
                self.prepare_files()
                logger.debug(
                    "Compiling: %s", " ".join(map(str, self.obj_command_line()))
                )
                gcc(
                    *self.obj_command_line(),
                    cwd=self.working_dir,
                    timeout=self._timeout,
                )
                with open(self.obj_path, "rb") as f:
                    # Set the internal variables
                    self._obj = f.read()
                    self._obj_size = os.path.getsize(self.obj_path)
                    self._obj_hash = hashlib.md5(self._obj).hexdigest()

        def assemble(self) -> None:
            """Assemble the benchmark, unless assembly is already cached."""
            if not self._asm:
                self.prepare_files()
                logger.debug(
                    "Assembling: %s", " ".join(map(str, self.asm_command_line()))
                )
                gcc(
                    *self.asm_command_line(),
                    cwd=self.working_dir,
                    timeout=self._timeout,
                )
                self._load_asm()

        def dump_rtl(self) -> None:
            """Dump the RTL (and assemble the benchmark), unless RTL is cached."""
            if not self._rtl:
                self.prepare_files()
                logger.debug(
                    "Dumping RTL: %s", " ".join(map(str, self.rtl_command_line()))
                )
                gcc(
                    *self.rtl_command_line(),
                    cwd=self.working_dir,
                    timeout=self._timeout,
                )
                # The same invocation also writes the assembly file.
                self._load_asm()
                with open(self.rtl_path, "rb") as f:
                    # Set the internal variables
                    rtl_bytes = f.read()
                    self._rtl = rtl_bytes.decode()

        def reset_cached(self):
            """Reset the cached values (the loaded source is kept)"""
            self._obj = None
            self._obj_size = None
            self._obj_hash = None
            self._rtl = None
            self._asm = None
            self._asm_size = None
            self._asm_hash = None

        def apply_action(
            self, action_proto: Event
        ) -> Tuple[bool, Optional[ActionSpace], bool]:
            """Apply an action.

            :param action_proto: Event whose ``int64_value`` is an index into
                the action list.
            :return: Always ``(False, None, False)``.
                NOTE(review): presumably (end_of_session, new_action_space,
                action_had_no_effect) as defined by the base class - confirm.
            :raises ValueError: If ``int64_value`` is unset or out of range.
            """
            if not action_proto.HasField("int64_value"):
                raise ValueError("Invalid action, int64_value expected.")

            choice_index = action_proto.int64_value
            if choice_index < 0 or choice_index >= self.num_actions:
                raise ValueError("Out-of-range")

            # Get the action
            action = actions[choice_index]
            # Apply the action to this session and check if we changed anything
            old_choices = self.choices.copy()
            action(self)
            logger.debug("Applied action %s", action)

            # Reset the internal variables if this action has caused a change in the
            # choices
            if old_choices != self.choices:
                self.reset_cached()

            # The action has not changed anything yet. That waits until an
            # observation is taken
            return False, None, False

        def get_observation(self, observation_space: ObservationSpace) -> Event:
            """Get one of the observations

            :param observation_space: The space to compute, selected by name.
            :return: An Event holding the observation value.
            :raises KeyError: If the space name is not recognised.
            """
            if observation_space.name == "source":
                return Event(string_value=self.source or "")
            elif observation_space.name == "rtl":
                return Event(string_value=self.rtl or "")
            elif observation_space.name == "asm":
                return Event(string_value=self.asm or "")
            elif observation_space.name == "asm_size":
                return Event(int64_value=self.asm_size or -1)
            elif observation_space.name == "asm_hash":
                return Event(string_value=self.asm_hash or "")
            elif observation_space.name == "instruction_counts":
                return Event(string_value=self.instruction_counts or "{}")
            elif observation_space.name == "obj":
                value = self.obj or b""
                return Event(byte_tensor=ByteTensor(shape=[len(value)], value=value))
            elif observation_space.name == "obj_size":
                return Event(int64_value=self.obj_size or -1)
            elif observation_space.name == "obj_hash":
                return Event(string_value=self.obj_hash or "")
            elif observation_space.name == "choices":
                observation = Event(
                    int64_tensor=Int64Tensor(
                        shape=[len(self.choices)], value=self.choices
                    )
                )
                return observation
            elif observation_space.name == "command_line":
                # Fixed placeholder paths keep the string independent of the
                # working directory.
                return Event(
                    string_value=gcc.bin
                    + " "
                    + " ".join(map(str, self.obj_command_line("src.c", "obj.o")))
                )
            else:
                raise KeyError(observation_space.name)

        def handle_session_parameter(self, key: str, value: str) -> Optional[str]:
            """Handle a session parameter sent by the client.

            Supported keys:

            * ``"gcc_spec"``: returns the pickled GCC spec, base64 encoded.
            * ``"choices"``: sets all choices from a comma-separated list of
              integers, one per option, each in ``[-1, len(option) - 1]``.
            * ``"timeout"``: sets the compilation timeout in seconds; the
              empty string clears it.

            :return: The reply string, or None if the key is not recognised.
            """
            if key == "gcc_spec":
                return codecs.encode(pickle.dumps(gcc.spec), "base64").decode()
            elif key == "choices":
                choices = list(map(int, value.split(",")))
                assert len(choices) == len(gcc.spec.options)
                # The upper bound is exclusive: a choice is used as an index
                # into its option, so len(option) itself is not valid.
                assert all(
                    -1 <= p < len(option)
                    for p, option in zip(choices, gcc.spec.options)
                )
                if choices != self.choices:
                    self.choices = choices
                    self.reset_cached()
                return ""
            elif key == "timeout":
                self._timeout = None if value == "" else int(value)
                return ""
            return None

    return GccCompilationSession
def test_observed_value_types():
    """Check the value and dtype produced for each kind of observation space.

    The view is subscripted by name in the order ir, dfeat, features, binary.
    The mock is then expected to have recorded the matching positions of
    those spaces in the space list.
    """

    def bounded(low, high):
        # Build a ScalarRange with both a lower and an upper limit.
        return ScalarRange(min=ScalarLimit(value=low), max=ScalarLimit(value=high))

    ir_space = ObservationSpace(
        name="ir",
        string_size_range=ScalarRange(min=ScalarLimit(value=0)),
    )
    features_space = ObservationSpace(
        name="features",
        int64_range_list=ScalarRangeList(
            range=[bounded(-100, 100), bounded(-100, 100)]
        ),
    )
    dfeat_space = ObservationSpace(
        name="dfeat",
        double_range_list=ScalarRangeList(range=[bounded(0.5, 2.5)]),
    )
    binary_space = ObservationSpace(
        name="binary",
        binary_size_range=bounded(5, 5),
    )
    spaces = [ir_space, features_space, dfeat_space, binary_space]

    # Replies are listed in the order the spaces are accessed below.
    replies = [
        Observation(string_value="Hello, IR"),
        Observation(double_list=DoubleList(value=[1.0, 2.0])),
        Observation(int64_list=Int64List(value=[-5, 15])),
        Observation(binary_value=b"Hello, bytes\0"),
    ]
    mock = MockGetObservation(ret=replies)
    observation = ObservationView(mock, spaces)

    ir_value = observation["ir"]
    assert isinstance(ir_value, str)
    assert ir_value == "Hello, IR"

    dfeat_value = observation["dfeat"]
    np.testing.assert_array_almost_equal(dfeat_value, [1.0, 2.0])
    assert dfeat_value.dtype == np.float64

    features_value = observation["features"]
    np.testing.assert_array_equal(features_value, [-5, 15])
    assert features_value.dtype == np.int64

    binary_value = observation["binary"]
    assert binary_value == b"Hello, bytes\0"

    # Check that the correct observation_space_list indices were used.
    assert mock.called_observation_spaces == [0, 2, 1, 3]
class ExampleCompilationSession(CompilationSession):
    """An interactive compilation session for the example compiler service.

    Actions are validated but have no effect, and every observation is a
    fixed placeholder value.
    """

    compiler_version: str = "1.0.0"

    # Only one action space, "default", is offered. It models a command line
    # with three options: "a", "b", and "c".
    action_spaces = [
        ActionSpace(
            name="default",
            choice=[
                ChoiceSpace(
                    name="optimization_choice",
                    named_discrete_space=NamedDiscreteSpace(
                        value=["a", "b", "c"],
                    ),
                ),
            ],
        ),
    ]

    # The observation spaces offered by this service, one ObservationSpace
    # proto per space.
    observation_spaces = [
        ObservationSpace(
            name="ir",
            string_size_range=ScalarRange(min=ScalarLimit(value=0)),
            deterministic=True,
            platform_dependent=False,
            default_value=Observation(string_value=""),
        ),
        ObservationSpace(
            name="features",
            # Three identical ranges, each covering [-100, 100].
            int64_range_list=ScalarRangeList(
                range=[
                    ScalarRange(
                        min=ScalarLimit(value=-100), max=ScalarLimit(value=100)
                    )
                    for _ in range(3)
                ]
            ),
        ),
        ObservationSpace(
            name="runtime",
            scalar_double_range=ScalarRange(min=ScalarLimit(value=0)),
            deterministic=False,
            platform_dependent=True,
            default_value=Observation(scalar_double=0),
        ),
    ]

    def __init__(
        self,
        working_directory: Path,
        action_space: ActionSpace,
        benchmark: Benchmark,
    ):
        """Initialise the base session and log the benchmark URI."""
        super().__init__(working_directory, action_space, benchmark)
        logging.info("Started a compilation session for %s", benchmark.uri)

    def apply_action(
        self, action: Action
    ) -> Tuple[bool, Optional[ActionSpace], bool]:
        """Validate an action; the example environment's state never changes.

        Raises ValueError if the action does not carry exactly one choice, or
        if the chosen index is outside the default action space.
        """
        default_space = self.action_spaces[0]
        num_choices = len(default_space.choice[0].named_discrete_space.value)

        if len(action.choice) != 1:
            raise ValueError("Invalid choice count")

        # Index into the action space's values that the user selected,
        # e.g. 0 -> "a", 1 -> "b", 2 -> "c".
        choice_index = action.choice[0].named_discrete_value_index
        logging.info("Applying action %d", choice_index)

        if not 0 <= choice_index < num_choices:
            raise ValueError("Out-of-range")

        # A real service would run the action here to update its state.
        return False, None, False

    def get_observation(self, observation_space: ObservationSpace) -> Observation:
        """Return the fixed observation for the named space.

        Raises KeyError if the space name is not one of "ir", "features" or
        "runtime".
        """
        logging.info("Computing observation from space %s", observation_space)
        space_name = observation_space.name

        if space_name == "ir":
            return Observation(string_value="Hello, world!")

        if space_name == "features":
            features = Observation()
            features.int64_list.value[:] = [0, 0, 0]
            return features

        if space_name == "runtime":
            return Observation(scalar_double=0)

        raise KeyError(space_name)