Пример #1
0
    def construct(self, x, rewards):
        """Pick the highest-scoring action and return noise-perturbed update terms.

        Args:
            x: Feature matrix; row ``i`` is used as the features of action
                ``i`` (presumably shape ``(num_actions, dim)`` -- TODO confirm).
            rewards: Per-action rewards, indexed by the selected action.

        Returns:
            A tuple ``(xa xa^T + B, xa * y + Xi, max_a)`` where ``xa`` is the
            selected row of ``x``, ``y`` is its reward, ``B`` is a symmetric
            Gaussian noise matrix, ``Xi`` is a Gaussian noise vector and
            ``max_a`` is the index of the selected action.

        Side effects:
            Writes the regret of this step to ``self._current_regret`` and
            accumulates it into ``self._regret``.
        """
        # Score every action: linear estimate plus a beta-weighted quadratic form.
        x_t = self.transpose(x, (1, 0))
        theta_column = self.expand_dims(self._theta, 1)
        linear_term = self.squeeze(self.matmul(x, theta_column))
        quadratic_term = self.reduce_sum(
            x_t * self.matmul(self._Vc_inv, x_t), 0)
        action_scores = linear_term + self._beta * quadratic_term

        # Select the best action and gather its features and reward.
        chosen = self.argmax(action_scores)
        chosen_feat = x[chosen]
        outer = self.matmul(
            self.expand_dims(chosen_feat, -1), self.expand_dims(chosen_feat, 0))
        reward = rewards[chosen]

        # Regret is the gap between the best available reward and the one taken.
        gap = self.reduce_max(rewards) - reward
        self._current_regret = float(gap.asnumpy())
        self._regret += self._current_regret

        # Symmetric matrix noise: keep the upper triangle and mirror it below
        # the diagonal (the diagonal itself is counted once).
        upper = np.triu(np.random.normal(0, self._sigma, size=outer.shape))
        mirrored = upper + (upper.transpose() - np.diag(upper.diagonal()))
        matrix_noise = Tensor(mirrored.astype(np.float32))

        # Vector noise is drawn after the matrix noise, as in the sampling
        # order used for the matrix above.
        raw_vector_noise = np.random.normal(
            0, self._sigma, size=chosen_feat.shape)
        vector_noise = Tensor(raw_vector_noise.astype(np.float32))

        return outer + matrix_noise, chosen_feat * reward + vector_noise, chosen
Пример #2
0
def to_tensor(obj, dtype=None):
    """Wrap ``obj`` in a ``Tensor``, narrowing 64-bit defaults to 32-bit.

    Args:
        obj: Value handed to the ``Tensor`` constructor.
        dtype: Optional explicit dtype. When given, it is passed straight to
            ``Tensor`` and no narrowing takes place.

    Returns:
        The resulting ``Tensor``. Without an explicit ``dtype``, a float64
        result is cast to float32 and an int64 result is cast to int32.
    """
    if dtype is not None:
        return Tensor(obj, dtype)

    tensor = Tensor(obj)
    # Checked in order: float64 -> float32 first, then int64 -> int32.
    narrowing_rules = (
        (mnp.float64, mnp.float32),
        (mnp.int64, mnp.int32),
    )
    for wide, narrow in narrowing_rules:
        if tensor.dtype == wide:
            tensor = tensor.astype(narrow)
    return tensor
Пример #3
0
    class Net(nn.Cell):
        """Cell holding a fixed 2x3 float32 tensor that it casts in ``construct``."""

        def __init__(self):
            super().__init__()
            data = [[1, 2, 3], [4, 5, 6]]
            self.value = Tensor(data, dtype=mstype.float32)

        def construct(self):
            # NOTE(review): "float88" is not a standard dtype name; presumably
            # this exercises the invalid-dtype path of astype -- confirm
            # against the enclosing test.
            converted = self.value.astype("float88")
            return converted
Пример #4
0
    class Net(nn.Cell):
        """Cell holding a fixed 2x3 int64 tensor that it casts to bool."""

        def __init__(self):
            super().__init__()
            data = [[1, 2, 3], [4, 5, 6]]
            self.value = Tensor(data, dtype=mstype.int64)

        def construct(self):
            # Cast the stored int64 tensor to the boolean dtype.
            as_bool = self.value.astype(mstype.bool_)
            return as_bool
Пример #5
0
def create_network(name, thres_filename, *args, **kwargs):
    """Build the network registered under ``name``.

    Only ``'resnet20'`` is supported: its thresholds are read with ``np.load``
    from ``thres_filename``, cast to float32, wrapped in a ``Tensor`` and
    passed to ``resnet20`` as ``thres``. Extra positional and keyword
    arguments are accepted but not used.

    Args:
        name: Identifier of the network to build.
        thres_filename: Path handed to ``np.load`` to read the thresholds.

    Returns:
        The object returned by ``resnet20(thres=...)``.

    Raises:
        NotImplementedError: If ``name`` is anything other than ``'resnet20'``.
    """
    # Reject unsupported names up front, before touching the filesystem.
    if name != 'resnet20':
        raise NotImplementedError(f"{name} is not implemented in the repo")

    thres_array = np.load(thres_filename).astype(np.float32)
    return resnet20(thres=Tensor(thres_array))