    def apply_heuristics(self, node: Node, D: torch.Tensor) -> torch.Tensor:
        """
        Modify D with additional heuristics, such as searching for an obvious win in the next move
        :param node: Node with selected action
        :param D: probability distribution over actions
        :return: potentially modified probability distribution over actions
        """
        if self.game_type == "hex":
            state, player = node.state, node.player

            # If this is the first move of the game, play the center cell
            if sum(state) == 0:
                D.apply_(lambda x: 0)  # zero D in place (same effect as D.zero_())
                D[len(state) // 2] = 1.0
                return D

            # Check for an obvious win in the next move
            for i, p in enumerate(D):
                # Only test actions the search already favors (probability above 50%)
                if p > 0.5:
                    test_state = state.copy()
                    test_state[i] = player
                    if self.verify_winning_state(test_state):
                        D.apply_(lambda x: 0)
                        D[i] = 1.0
                        return D
        return D
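The winning-move branch collapses the distribution to a one-hot vector in place. A minimal sketch of that step in isolation (the distribution and winning index are made up; note that `Tensor.apply_` only works on CPU tensors):

import torch

D = torch.tensor([0.1, 0.7, 0.2])  # hypothetical action distribution
win_idx = 1                        # hypothetical winning action

D.apply_(lambda x: 0)  # in-place, CPU-only; same effect as D.zero_()
D[win_idx] = 1.0
print(D)  # tensor([0., 1., 0.])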
    def __init__(self,
                 data: Tensor,
                 target: Tensor,
                 cut: Optional[Tuple[float, float]] = None,
                 classes: Optional[List[int]] = None,
                 reset_targets: bool = False) -> None:
        # Optionally keep only the slice of the dataset given by the
        # fractional range cut = (start, end), e.g. (0.0, 0.8)
        if cut:
            if not 0 <= cut[0] < cut[1] <= 1:
                raise ValueError(f"cut must satisfy 0 <= cut[0] < cut[1] <= 1, got {cut}")
            length = data.size(0)
            start = round(length * cut[0])
            end = round(length * cut[1])
            data, target = data[start:end], target[start:end]

        # Optionally keep only the samples whose target is in `classes`
        if classes:
            idxs = (target == classes[0])
            for cls in classes[1:]:
                idxs = idxs | (target == cls)
            data, target = data[idxs], target[idxs]

            # Relabel the kept classes to 0..len(classes)-1 in place
            # (Tensor.apply_ is CPU-only)
            if reset_targets:
                mapping = {c: i for (i, c) in enumerate(classes)}
                target.apply_(lambda i: mapping[i])

        self.data, self.target = data, target
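To see what the `classes` branch does end to end, here is a minimal sketch of the filter-and-relabel logic on raw tensors (all values are made up):

import torch

target = torch.tensor([0, 3, 7, 3, 0, 7])
data = torch.arange(6).unsqueeze(1).float()

classes = [3, 7]
idxs = (target == classes[0]) | (target == classes[1])
data, target = data[idxs], target[idxs]

# Relabel 3 -> 0 and 7 -> 1 in place (CPU tensors only)
mapping = {c: i for i, c in enumerate(classes)}
target.apply_(lambda i: mapping[i])
print(target)  # tensor([0, 1, 0, 1])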
Example #3
    def forward(self, game_state: torch.Tensor) -> torch.Tensor:
        """
        Forward the given state through the network and return the network's output
        :param game_state: input to the model
        :return: output from the model
        """
        # Remap player 2's stones from 2 to -1, in place (Tensor.apply_ is CPU-only)
        game_state.apply_(lambda x: -1 if x == 2 else x)
        return self.model(game_state)
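In isolation, the remapping looks like this (the board values are made up; a vectorized `torch.where` version is shown as an out-of-place alternative):

import torch

game_state = torch.tensor([0., 1., 2., 1., 2.])

# In-place remap, as in forward() above (CPU tensors only)
game_state.apply_(lambda x: -1 if x == 2 else x)
print(game_state)  # tensor([ 0.,  1., -1.,  1., -1.])

# Equivalent out-of-place, vectorized version:
board = torch.tensor([0., 1., 2., 1., 2.])
remapped = torch.where(board == 2, torch.full_like(board, -1.0), board)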
Example #4
    def quant_dequant(self, tensor2quantize: Tensor) -> Tensor:
        levels_lst = list(self.quantization_levels)

        # Map each element through quant_dequant_util, in place;
        # apply_ returns the mutated tensor itself
        result = tensor2quantize.apply_(
            lambda x: quant_dequant_util(x, levels_lst))

        return result
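A minimal sketch of this pattern with a stand-in for `quant_dequant_util` (here assumed to snap a value to the nearest level; the levels are made up):

import torch

def nearest_level(x, levels):
    # Stand-in for quant_dequant_util: snap x to the closest level
    return min(levels, key=lambda lvl: abs(lvl - x))

levels_lst = [0.0, 0.25, 0.5, 1.0]  # hypothetical quantization levels
t = torch.tensor([0.1, 0.6, 0.9])

result = t.apply_(lambda x: nearest_level(x, levels_lst))
print(result)  # tensor([0.0000, 0.5000, 1.0000])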
Example #5
    def quantize_APoT(self, tensor2quantize: Tensor):
        # Map float_to_apot over tensor2quantize's elements, in place;
        # apply_ returns the mutated tensor, so result aliases the input
        result = tensor2quantize.apply_(lambda x: float_to_apot(
            x, self.quantization_levels, self.level_indices))

        return result
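Because `Tensor.apply_` mutates in place and returns `self`, `result` above is just another name for `tensor2quantize`. A quick demonstration with a toy mapping:

import torch

t = torch.tensor([1.0, 2.0, 3.0])
result = t.apply_(lambda x: x * 10)  # toy stand-in for float_to_apot

print(result is t)  # True: the very same tensor object
print(t)            # tensor([10., 20., 30.])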
Example #6
    def dequantize(self, float2apot: Tensor):  # type: ignore[override]
        # Convert to float32 (a new tensor when the input dtype differs),
        # so apply_ below works on the converted tensor
        float2apot = float2apot.float()

        quantization_levels = self.quantization_levels
        level_indices = self.level_indices

        # Map apot_to_float over float2apot's elements, in place
        result = float2apot.apply_(lambda x: float(
            apot_to_float(x, quantization_levels, level_indices)))

        return result
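A sketch of this flow with a made-up stand-in for `apot_to_float` that looks levels up by index; since `.float()` on an integer index tensor produces a new tensor, the caller's indices survive the in-place `apply_`:

import torch

quantization_levels = [0.0, 0.25, 0.5, 1.0]  # hypothetical levels

def index_to_float(i, levels):
    # Stand-in for apot_to_float: look the level up by its index
    return levels[int(i)]

indices = torch.tensor([3, 1, 0])  # int64 level indices
as_float = indices.float()         # new float32 tensor
result = as_float.apply_(lambda x: index_to_float(x, quantization_levels))

print(result)   # tensor([1.0000, 0.2500, 0.0000])
print(indices)  # tensor([3, 1, 0]) -- unchanged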
Example #7
    def quantize(self, tensor2quantize: Tensor):
        # Map float_to_apot over tensor2quantize's elements, in place
        tensor2quantize = tensor2quantize.apply_(lambda x: float_to_apot(
            x, self.quantization_levels, self.level_indices, self.alpha))

        from torch.ao.quantization.experimental.APoT_tensor import TensorAPoT

        # Wrap the quantized values in a TensorAPoT
        result = TensorAPoT(self, tensor2quantize)

        return result
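A closing note on all the quantization snippets above: `Tensor.apply_` calls a Python function once per element, so it is convenient but slow and CPU-only. When the levels are sorted, the same nearest-level snapping can be vectorized, for example with `torch.bucketize`; a sketch under that assumption (the level values are made up):

import torch

levels = torch.tensor([0.0, 0.25, 0.5, 1.0])  # sorted, hypothetical levels
boundaries = (levels[:-1] + levels[1:]) / 2   # midpoints between adjacent levels

t = torch.tensor([0.1, 0.6, 0.9])
idx = torch.bucketize(t, boundaries)          # index of the nearest level
snapped = levels[idx]
print(snapped)  # tensor([0.0000, 0.5000, 1.0000])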