Example #1
 def _graph_fn_unsquash(self, values):
     if get_backend() == "tf":
         return tf.atanh((values - self.low) /
                         (self.high - self.low) * 2.0 - 1.0)
     elif get_backend() == "pytorch":
         return torch.atanh((values - self.low) /
                            (self.high - self.low) * 2.0 - 1.0)
Example #2
    def get_log_prob(self, state: rlt.FeatureData,
                     squashed_action: torch.Tensor):
        """
        Action is expected to be squashed with tanh
        """
        if self.use_l2_normalization:
            # TODO: calculate log_prob for l2 normalization
            # https://math.stackexchange.com/questions/3120506/on-the-distribution-of-a-normalized-gaussian-vector
            # http://proceedings.mlr.press/v100/mazoure20a/mazoure20a.pdf
            pass

        loc, scale_log = self._get_loc_and_scale_log(state)
        raw_action = torch.atanh(squashed_action)
        r = (raw_action - loc) / scale_log.exp()
        log_prob = self._normal_log_prob(r, scale_log)
        squash_correction = self._squash_correction(squashed_action)
        if SummaryWriterContext._global_step % 1000 == 0:
            SummaryWriterContext.add_histogram("actor/get_log_prob/loc",
                                               loc.detach().cpu())
            SummaryWriterContext.add_histogram("actor/get_log_prob/scale_log",
                                               scale_log.detach().cpu())
            SummaryWriterContext.add_histogram("actor/get_log_prob/log_prob",
                                               log_prob.detach().cpu())
            SummaryWriterContext.add_histogram(
                "actor/get_log_prob/squash_correction",
                squash_correction.detach().cpu())
        return torch.sum(log_prob - squash_correction, dim=1).reshape(-1, 1)
Example #3
def sumlogC(x, eps=1e-5):
    """
    Numerically stable implementation of sum of logarithm of Continuous Bernoulli constant C
    Returns log normalising constant for x in (0, x-eps) and (x+eps, 1)
    Uses Taylor 3rd degree approximation in [x-eps, x+eps].

    Parameter
    ----------
    x : `torch.Tensor`
        [batch_size x C x H x W]. x takes values in (0,1).
    """
    # clip x such that x in (0, 1)
    x = torch.clamp(x, eps, 1. - eps)
    # get mask if x is not in [0.5-eps, 0.5+eps]
    mask = torch.abs(x - .5).ge(eps)
    # points that are (0, 0.5-eps) and (0.5+eps, 1)
    far = x[mask]
    # points that are [0.5-eps, 0.5+eps]
    close = x[~mask]
    # Given by log(|2tanh^-1(1-2x)|) - log(|1-2x|)
    far_values = torch.log(torch.abs(
        2. * torch.atanh(1 - 2. * far))) - torch.log(torch.abs(1 - 2. * far))
    # Taylor expansion of 2*atanh(u)/u = 2 + 2u^2/3 + 2u^4/5 with u = 1 - 2x
    close_values = torch.log(2. + 2 * (1 - 2 * close).pow(2) / 3 +
                             2 * (1 - 2 * close).pow(4) / 5)
    return far_values.sum() + close_values.sum()
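A minimal usage sketch for the function above (the tensor shape and eps values here are illustrative assumptions): entries near 0.5 fall into the Taylor branch, where the closed form log(|2 atanh(1-2x)|) - log(|1-2x|) is indeterminate.

import torch

x = torch.tensor([[0.1, 0.5, 0.5 + 1e-7, 0.9]])  # entries near 0.5 hit the Taylor branch
print(sumlogC(x))             # finite scalar: sum of log C(x) over all elements
print(sumlogC(x, eps=1e-3))   # wider Taylor window around 0.5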
Example #4
 def log_prob(self, value: torch.Tensor) -> torch.Tensor:
     value = torch.clamp(value, -1 + 1e-6, 1 - 1e-6)
     # Log likelihood for Gaussian with change of variable
     log_prob = super().log_prob(torch.atanh(value))
     # Squash correction (from original SAC implementation)
     # this comes from the fact that tanh is bijective and differentiable
     log_prob -= torch.sum(torch.log(1 - value**2), dim=1)
     return log_prob
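The correction subtracted above, log(1 - value**2), equals log(1 - tanh(u)**2) of the pre-tanh sample u. A short stand-alone sketch (new code, not from the repo above) checking the numerically friendlier softplus identity that Example #5 uses instead:

import math
import torch
import torch.nn.functional as F

u = torch.linspace(-3.0, 3.0, 7)                     # pre-tanh samples
direct = torch.log(1 - torch.tanh(u) ** 2)
# identity: log(1 - tanh(u)^2) = 2 * (log 2 - u - softplus(-2u))
stable = 2 * (math.log(2.0) - u - F.softplus(-2 * u))
assert torch.allclose(direct, stable, atol=1e-4)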
Example #5
 def logp(self, x, ac):
     if self.use_squashing:
         ac = torch.atanh(ac)
         m = self.__act_distribution(x)
         logp_pi = m.log_prob(ac).sum(axis=1)
         logp_pi -= (2 * (np.log(2) - ac - F.softplus(-2 * ac))).sum(axis=1)
         return logp_pi
     else:
         m = self.__act_distribution(x)
         return m.log_prob(ac).sum(axis=1)
Example #6
 def evaluation(self, state, action):
     mean, log_std = self.forward(state)
     std = log_std.exp()
     normal = Normal(mean, std)
     y_t = (action - self.action_bias) / self.action_scale
     x_t = torch.atanh(y_t)
     log_prob = normal.log_prob(x_t)
     log_prob -= torch.log(self.action_scale * (1 - y_t.pow(2)) + epsilon)
     log_prob = log_prob.sum(1, keepdim=True)
     return log_prob
Example #7
def cosine_loss(a, v, y):
#     d = nn.functional.cosine_similarity(a, v)
#     print(y)
    a = a / a.norm(p=2, dim=1, keepdim=True)
    v = v / v.norm(p=2, dim=1, keepdim=True)
    cosim = a.matmul(v.transpose(0, 1))
    cosim = cosim[torch.eye(len(cosim)).bool()].unsqueeze(1)
    cosim = torch.atanh(cosim)
    loss = logloss(F.sigmoid(cosim), y)
    return loss
Example #8
    def _compute_actor_loss(self, observation, action, weight):
        dist = self.policy.dist(observation)

        # unnormalize action via inverse tanh function
        unnormalized_action = torch.atanh(action.clamp(-0.999999, 0.999999))

        # compute log probability
        _, log_probs = squash_action(dist, unnormalized_action)

        return -(weight * log_probs).mean()
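The clamp to ±0.999999 before torch.atanh guards against actions that sit exactly on the tanh boundary, where atanh is infinite. A tiny stand-alone illustration:

import torch

a = torch.tensor([1.0, -1.0, 0.5])
print(torch.atanh(a))                             # tensor([inf, -inf, 0.5493])
print(torch.atanh(a.clamp(-0.999999, 0.999999)))  # finite everywhere (about +/-7.25 at the ends)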
Example #9
 def forward(self, z, inverse_mode=False, context=None):
     if inverse_mode:
         squashed = 2 * (z - self.low) / self.range - 1
         # clamp the normalized value (not the raw z) before applying atanh
         return torch.atanh(
             torch.clip(squashed, -1 + SMALL_NUMBER,
                        1 - SMALL_NUMBER)), -self._log_abs_det_jac(squashed)
     else:
         squashed = torch.tanh(z)
         squashed01 = (squashed + 1) / 2
         return squashed01 * self.range + self.low, self._log_abs_det_jac(
             squashed)
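The helper `_log_abs_det_jac` is not shown in this example. Under the mapping used above, out = (tanh(z) + 1) / 2 * range + low, its log-Jacobian would be sum(log(1 - tanh(z)^2) + log(range / 2)); a hypothetical stand-alone version (the name, argument order, and small epsilon are assumptions):

import torch

def log_abs_det_jac(squashed, action_range):
    # d out / d z = range / 2 * (1 - tanh(z)^2), where squashed = tanh(z)
    return torch.sum(
        torch.log(1.0 - squashed ** 2 + 1e-8) +
        torch.log(torch.as_tensor(action_range) / 2.0),
        dim=-1)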
Example #10
    def get_logprob(self, state, action):
        mean, log_std = self.forward(state)

        std = log_std.exp()
        normal = Normal(mean, std)

        x_t = torch.atanh(action)
        y_t = torch.tanh(x_t)
        log_prob = normal.log_prob(x_t)
        log_prob -= torch.log(1 - y_t.pow(2) + 1e-6)

        return log_prob
Example #11
    def reduce_tanh(self, nodes):
        w = torch.exp(self.w)
        msg = w * nodes.mailbox['m'] + self.b
        """ print("MSG max: ", torch.max(msg))
        print("MSG min: ", torch.min(msg)) """

        fsum = torch.clamp(torch.sum(torch.tanh(msg), dim=1), -0.99, 0.99)
        """ print("max: ", torch.max(fsum))
        print("min: ", torch.min(fsum)) """

        out_h = (torch.atanh(fsum) - self.b) / w
        return {'neigh': out_h}
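This aggregation relies on atanh undoing tanh: for a single incoming message the mapping round-trips exactly. A small stand-alone check (DGL's nodes.mailbox replaced by a plain tensor; the parameter values are made up):

import torch

w, b = torch.exp(torch.tensor(0.3)), torch.tensor(0.1)
m = torch.tensor([[0.2, -0.4, 0.1, 0.6]])            # a single 4-feature message
fsum = torch.clamp(torch.sum(torch.tanh(w * m + b), dim=0), -0.99, 0.99)
recovered = (torch.atanh(fsum) - b) / w
assert torch.allclose(recovered, m.squeeze(0), atol=1e-5)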
Example #12
    def _compute_actor_loss(self, obs_t, act_t):
        dist = self.policy.dist(obs_t)

        # unnormalize action via inverse tanh function
        unnormalized_act_t = torch.atanh(act_t.clamp(-0.999999, 0.999999))

        # compute log probability
        _, log_probs = squash_action(dist, unnormalized_act_t)

        # compute exponential weight
        weights = self._compute_weights(obs_t, act_t)

        return -(log_probs * weights).sum()
Example #13
    def reduce_tanh(self, nodes):
        alpha = F.softmax(nodes.mailbox['e'], dim=1)
        alpha = F.dropout(alpha, self.dropout, training=self.training)
        msg = (alpha * nodes.mailbox['z'])

        w = torch.exp(self.w)
        msg = w * msg + self.b

        fsum = torch.clamp(torch.sum(torch.tanh(msg), dim=1), -0.9999999,
                           0.9999999)
        """ print("max: ", torch.max(fsum))
        print("min: ", torch.min(fsum)) """

        out_h = (torch.atanh(fsum) - self.b) / w
        return {'h': out_h}
Example #14
def to_tanh_space(x, box):
    # type: (Union[Variable, torch.FloatTensor], Tuple[float, float]) -> Union[Variable, torch.FloatTensor]
    """
    Convert a batch of tensors to tanh-space. This method complements the
    implementation of the change-of-variable trick in terms of tanh.

    :param x: the batch of tensors, of dimension [B x n_features]
    :param box: a tuple of lower bound and upper bound of the box constraint
    :return: the batch of tensors in tanh-space, of the same dimension;
             the returned tensor is on the same device as ``x``
    """

    _box_mul = (box[1] - box[0]) * 0.5
    _box_plus = (box[1] + box[0]) * 0.5
    return torch.atanh((x - _box_plus) / _box_mul) * 1e4
Example #15
    def compute_actor_loss(self, batch: TorchMiniBatch) -> torch.Tensor:
        assert self._policy is not None

        dist = self._policy.dist(batch.observations)

        # unnormalize action via inverse tanh function
        clipped_actions = batch.actions.clamp(-0.999999, 0.999999)
        unnormalized_act_t = torch.atanh(clipped_actions)

        # compute log probability
        _, log_probs = squash_action(dist, unnormalized_act_t)

        weight = self._compute_weight(batch.observations, batch.actions)

        return -(log_probs * weight).mean()
Example #16
    def _graph_fn_unsquash(self, values):
        """
        Reverse operation of _graph_fn_squash (using arctanh).

        Args:
            values (DataOp): The values to unsquash.

        Returns:
            The unsquashed values.
        """
        if get_backend() == "tf":
            return tf.atanh((values - self.low) /
                            (self.high - self.low) * 2.0 - 1.0)
        elif get_backend() == "pytorch":
            return torch.atanh((values - self.low) /
                               (self.high - self.low) * 2.0 - 1.0)
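A quick round-trip sanity check for the unsquash above, with the matching squash written out explicitly (the bounds are arbitrary example values):

import torch

low, high = -2.0, 3.0
values = torch.tensor([-1.5, 0.0, 2.5])                            # actions inside (low, high)
unsquashed = torch.atanh((values - low) / (high - low) * 2.0 - 1.0)
resquashed = (torch.tanh(unsquashed) + 1.0) / 2.0 * (high - low) + low
assert torch.allclose(resquashed, values, atol=1e-5)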
Example #17
 def get_repara_action(self, obs, action):
     # Generate the latent feature
     phi = self.feature_net(obs)
     # Reparameterize action with epsilon
     action_mean, action_std, _ = self.actor_net.distribution(phi)
     if isinstance(self.actor_net, MLPSquashedGaussianActor):
         action = torch.clamp(action / self.actor_net.action_lim, -0.999,
                              0.999)
         u = torch.atanh(action)
         eps = (u - action_mean) / action_std
         repara_action = self.actor_net.action_lim * torch.tanh(
             action_mean + action_std * eps.detach())
     else:
         eps = (action - action_mean) / action_std
         repara_action = action_mean + action_std * eps.detach()
     return repara_action
Example #18
 def forward(self, phi, action=None, deterministic=False):
     # Compute action distribution and the log_pi of given actions
     action_mean, action_std, action_distribution = self.distribution(phi)
     if action is None:
         if deterministic:
             u = action_mean
         else:
             u = action_distribution.rsample(
             ) if self.rsample else action_distribution.sample()
         action = self.action_lim * torch.tanh(u)
     else:
         u = torch.clamp(action / self.action_lim, -0.999, 0.999)
         u = torch.atanh(u)
     # Compute logprob from Gaussian, and then apply correction for Tanh squashing.
     log_pi = self.log_pi_from_distribution(action_distribution, u)
     return action, action_mean, action_std, log_pi
Example #19
    def _compute_actor_loss(  # type: ignore
        self, obs_t: torch.Tensor, act_t: torch.Tensor
    ) -> torch.Tensor:
        assert self._policy is not None

        dist = self._policy.dist(obs_t)

        # unnormalize action via inverse tanh function
        unnormalized_act_t = torch.atanh(act_t.clamp(-0.999999, 0.999999))

        # compute log probability
        _, log_probs = squash_action(dist, unnormalized_act_t)

        weight = self._compute_weight(obs_t, act_t)

        return -(log_probs * weight).mean()
Example #20
    def test_mod_mapping(self):
        x, y, z = sympy.symbols("x y z")
        expression = x ** 2 + sympy.atanh(sympy.Mod(y + 1, 2) - 1) * 3.2 * z

        module = sympy2torch(expression, [x, y, z])

        X = torch.rand(100, 3).float() * 10

        true_out = (
            X[:, 0] ** 2 + torch.atanh(torch.fmod(X[:, 1] + 1, 2) - 1) * 3.2 * X[:, 2]
        )
        torch_out = module(X)

        np.testing.assert_array_almost_equal(
            true_out.detach(), torch_out.detach(), decimal=4
        )
Example #21
    def rnd2rgb(self, rnd, clip=False):
        rnd = torch.atanh(rnd)
        rnd_arr = torch.zeros(rnd.shape, device=rnd.get_device())

        if not self.linear:
            vals = self.distortion + 1e-4
            eta = vals[4]

            # read the L, a, b channels from the atanh-ed input (not the empty buffer)
            L = rnd[:, 0:1, ]
            a = rnd[:, 1:2, ]
            b = rnd[:, 2:3, ]
            y = (L + vals[1]) / vals[0]
            x = (a / vals[2]) + y
            z = y - (b / vals[3])

            rnd_arr[:, 0:1, ] = x
            rnd_arr[:, 1:2, ] = y
            rnd_arr[:, 2:3, ] = z

            mask = rnd_arr > eta
            rnd_arr[mask] = rnd_arr[mask].pow(3.)
            rnd_arr[~mask] = (rnd_arr[~mask] - (vals[1] / vals[0])) * 3 * (
                    eta ** 2)
        else:
            for i in range(3):
                rnd_arr[:, i:i + 1, ] = rnd[:, i:i + 1, ]

        # rescale to the reference white (illuminant)
        ref_white = self.ref_white + 1e-4
        white_arr = torch.zeros(rnd.shape, device=rnd.get_device())
        for i in range(3):
            white_arr[:, i:i + 1, ] = rnd_arr[:, i:i + 1, ] * ref_white[i]

        rgb = torch.zeros(rnd.shape, device=rnd.get_device())
        rgb_transform = torch.inverse(self.trans_mat)
        for i in range(3):
            x_r = white_arr[:, 0:1, ] * rgb_transform[i, 0]
            y_g = white_arr[:, 1:2, ] * rgb_transform[i, 1]
            z_b = white_arr[:, 2:3, ] * rgb_transform[i, 2]
            rgb[:, i:i + 1, ] = x_r + y_g + z_b

        if clip:
            rgb[torch.isnan(rgb)] = 0
            rgb = (rgb * 0.5) + 0.5
            rgb[rgb < 0] = 0
            rgb[rgb > 1] = 1
        return rgb
Example #22
    def _compute_actor_loss(
        self,
        observation: torch.Tensor,
        action: torch.Tensor,
        weight: torch.Tensor,
    ) -> torch.Tensor:
        assert self._policy is not None

        dist = self._policy.dist(observation)

        # unnormalize action via inverse tanh function
        unnormalized_action = torch.atanh(action.clamp(-0.999999, 0.999999))

        # compute log probability
        _, log_probs = squash_action(dist, unnormalized_action)

        return -(weight * log_probs).mean()
Example #23
    def rnd2rgb(self, y):
        for i in range(self.layers - 1, -1, -1):
            trans_mat = getattr(self, 't%.3d' % i)[0].weight.detach().squeeze()
            trans_mat = torch.inverse(trans_mat)

            y = torch.atanh(y)
            if self.bias:
                bias_vec = getattr(self, 't%.3d' % i)[0].bias.detach().squeeze()
                for c in range(3):
                    y[:, c, ] -= bias_vec[c]
            device = y.get_device()
            device = 'cpu' if device == -1 else device
            x = torch.zeros(y.shape, device=device)
            for c in range(3):
                x_r = y[:, 0:1, ] * trans_mat[c, 0]
                y_g = y[:, 1:2, ] * trans_mat[c, 1]
                z_b = y[:, 2:3, ] * trans_mat[c, 2]
                x[:, c:c + 1, ] = x_r + y_g + z_b
            y = x.clone()
        return x
Example #24
 def forward(self, sample):
     return torch.atanh(sample / self.a) / self.b
Example #25
def atanh(x):
    return _torch.atanh(x)
Example #26
 def pointwise_ops(self):
     a = torch.randn(4)
     b = torch.randn(4)
     t = torch.tensor([-1, -2, 3], dtype=torch.int8)
     r = torch.tensor([0, 1, 10, 0], dtype=torch.int8)
     t = torch.tensor([-1, -2, 3], dtype=torch.int8)
     s = torch.tensor([4, 0, 1, 0], dtype=torch.int8)
     f = torch.zeros(3)
     g = torch.tensor([-1, 0, 1])
     w = torch.tensor([0.3810, 1.2774, -0.2972, -0.3719, 0.4637])
     return (
         torch.abs(torch.tensor([-1, -2, 3])),
         torch.absolute(torch.tensor([-1, -2, 3])),
         torch.acos(a),
         torch.arccos(a),
         torch.acosh(a.uniform_(1.0, 2.0)),
         torch.add(a, 20),
         torch.add(a, torch.randn(4, 1), alpha=10),
         torch.addcdiv(torch.randn(1, 3),
                       torch.randn(3, 1),
                       torch.randn(1, 3),
                       value=0.1),
         torch.addcmul(torch.randn(1, 3),
                       torch.randn(3, 1),
                       torch.randn(1, 3),
                       value=0.1),
         torch.angle(a),
         torch.asin(a),
         torch.arcsin(a),
         torch.asinh(a),
         torch.arcsinh(a),
         torch.atan(a),
         torch.arctan(a),
         torch.atanh(a.uniform_(-1.0, 1.0)),
         torch.arctanh(a.uniform_(-1.0, 1.0)),
         torch.atan2(a, a),
         torch.bitwise_not(t),
         torch.bitwise_and(t, torch.tensor([1, 0, 3], dtype=torch.int8)),
         torch.bitwise_or(t, torch.tensor([1, 0, 3], dtype=torch.int8)),
         torch.bitwise_xor(t, torch.tensor([1, 0, 3], dtype=torch.int8)),
         torch.ceil(a),
         torch.clamp(a, min=-0.5, max=0.5),
         torch.clamp(a, min=0.5),
         torch.clamp(a, max=0.5),
         torch.clip(a, min=-0.5, max=0.5),
         torch.conj(a),
         torch.copysign(a, 1),
         torch.copysign(a, b),
         torch.cos(a),
         torch.cosh(a),
         torch.deg2rad(
             torch.tensor([[180.0, -180.0], [360.0, -360.0], [90.0,
                                                              -90.0]])),
         torch.div(a, b),
         torch.divide(a, b, rounding_mode="trunc"),
         torch.divide(a, b, rounding_mode="floor"),
         torch.digamma(torch.tensor([1.0, 0.5])),
         torch.erf(torch.tensor([0.0, -1.0, 10.0])),
         torch.erfc(torch.tensor([0.0, -1.0, 10.0])),
         torch.erfinv(torch.tensor([0.0, 0.5, -1.0])),
         torch.exp(torch.tensor([0.0, math.log(2.0)])),
         torch.exp2(torch.tensor([0.0, math.log(2.0), 3.0, 4.0])),
         torch.expm1(torch.tensor([0.0, math.log(2.0)])),
         torch.fake_quantize_per_channel_affine(
             torch.randn(2, 2, 2),
             (torch.randn(2) + 1) * 0.05,
             torch.zeros(2),
             1,
             0,
             255,
         ),
         torch.fake_quantize_per_tensor_affine(a, 0.1, 0, 0, 255),
         torch.float_power(torch.randint(10, (4, )), 2),
         torch.float_power(torch.arange(1, 5), torch.tensor([2, -3, 4,
                                                             -5])),
         torch.floor(a),
         # torch.floor_divide(torch.tensor([4.0, 3.0]), torch.tensor([2.0, 2.0])),
         # torch.floor_divide(torch.tensor([4.0, 3.0]), 1.4),
         torch.fmod(torch.tensor([-3, -2, -1, 1, 2, 3]), 2),
         torch.fmod(torch.tensor([1, 2, 3, 4, 5]), 1.5),
         torch.frac(torch.tensor([1.0, 2.5, -3.2])),
         torch.randn(4, dtype=torch.cfloat).imag,
         torch.ldexp(torch.tensor([1.0]), torch.tensor([1])),
         torch.ldexp(torch.tensor([1.0]), torch.tensor([1, 2, 3, 4])),
         torch.lerp(torch.arange(1.0, 5.0),
                    torch.empty(4).fill_(10), 0.5),
         torch.lerp(
             torch.arange(1.0, 5.0),
             torch.empty(4).fill_(10),
             torch.full_like(torch.arange(1.0, 5.0), 0.5),
         ),
         torch.lgamma(torch.arange(0.5, 2, 0.5)),
         torch.log(torch.arange(5) + 10),
         torch.log10(torch.rand(5)),
         torch.log1p(torch.randn(5)),
         torch.log2(torch.rand(5)),
         torch.logaddexp(torch.tensor([-1.0]), torch.tensor([-1, -2, -3])),
         torch.logaddexp(torch.tensor([-100.0, -200.0, -300.0]),
                         torch.tensor([-1, -2, -3])),
         torch.logaddexp(torch.tensor([1.0, 2000.0, 30000.0]),
                         torch.tensor([-1, -2, -3])),
         torch.logaddexp2(torch.tensor([-1.0]), torch.tensor([-1, -2, -3])),
         torch.logaddexp2(torch.tensor([-100.0, -200.0, -300.0]),
                          torch.tensor([-1, -2, -3])),
         torch.logaddexp2(torch.tensor([1.0, 2000.0, 30000.0]),
                          torch.tensor([-1, -2, -3])),
         torch.logical_and(r, s),
         torch.logical_and(r.double(), s.double()),
         torch.logical_and(r.double(), s),
         torch.logical_and(r, s, out=torch.empty(4, dtype=torch.bool)),
         torch.logical_not(torch.tensor([0, 1, -10], dtype=torch.int8)),
         torch.logical_not(
             torch.tensor([0.0, 1.5, -10.0], dtype=torch.double)),
         torch.logical_not(
             torch.tensor([0.0, 1.0, -10.0], dtype=torch.double),
             out=torch.empty(3, dtype=torch.int16),
         ),
         torch.logical_or(r, s),
         torch.logical_or(r.double(), s.double()),
         torch.logical_or(r.double(), s),
         torch.logical_or(r, s, out=torch.empty(4, dtype=torch.bool)),
         torch.logical_xor(r, s),
         torch.logical_xor(r.double(), s.double()),
         torch.logical_xor(r.double(), s),
         torch.logical_xor(r, s, out=torch.empty(4, dtype=torch.bool)),
         torch.logit(torch.rand(5), eps=1e-6),
         torch.hypot(torch.tensor([4.0]), torch.tensor([3.0, 4.0, 5.0])),
         torch.i0(torch.arange(5, dtype=torch.float32)),
         torch.igamma(a, b),
         torch.igammac(a, b),
         torch.mul(torch.randn(3), 100),
         torch.multiply(torch.randn(4, 1), torch.randn(1, 4)),
         torch.mvlgamma(torch.empty(2, 3).uniform_(1.0, 2.0), 2),
         torch.tensor([float("nan"),
                       float("inf"), -float("inf"), 3.14]),
         torch.nan_to_num(w),
         torch.nan_to_num(w, nan=2.0),
         torch.nan_to_num(w, nan=2.0, posinf=1.0),
         torch.neg(torch.randn(5)),
         # torch.nextafter(torch.tensor([1, 2]), torch.tensor([2, 1])) == torch.tensor([eps + 1, 2 - eps]),
         torch.polygamma(1, torch.tensor([1.0, 0.5])),
         torch.polygamma(2, torch.tensor([1.0, 0.5])),
         torch.polygamma(3, torch.tensor([1.0, 0.5])),
         torch.polygamma(4, torch.tensor([1.0, 0.5])),
         torch.pow(a, 2),
         torch.pow(torch.arange(1.0, 5.0), torch.arange(1.0, 5.0)),
         torch.rad2deg(
             torch.tensor([[3.142, -3.142], [6.283, -6.283],
                           [1.570, -1.570]])),
         torch.randn(4, dtype=torch.cfloat).real,
         torch.reciprocal(a),
         torch.remainder(torch.tensor([-3.0, -2.0]), 2),
         torch.remainder(torch.tensor([1, 2, 3, 4, 5]), 1.5),
         torch.round(a),
         torch.rsqrt(a),
         torch.sigmoid(a),
         torch.sign(torch.tensor([0.7, -1.2, 0.0, 2.3])),
         torch.sgn(a),
         torch.signbit(torch.tensor([0.7, -1.2, 0.0, 2.3])),
         torch.sin(a),
         torch.sinc(a),
         torch.sinh(a),
         torch.sqrt(a),
         torch.square(a),
         torch.sub(torch.tensor((1, 2)), torch.tensor((0, 1)), alpha=2),
         torch.tan(a),
         torch.tanh(a),
         torch.trunc(a),
         torch.xlogy(f, g),
         torch.xlogy(f, g),
         torch.xlogy(f, 4),
         torch.xlogy(2, g),
     )
Example #27
 def forward(x):
     return TanhInv.alpha*torch.atanh(x)
Example #28
def l2_attack(input,
              target,
              model,
              targeted,
              use_log,
              use_tanh,
              solver,
              reset_adam_after_found=True,
              abort_early=True,
              batch_size=128,
              max_iter=1000,
              const=0.01,
              confidence=0.0,
              early_stop_iters=100,
              binary_search_steps=9,
              step_size=0.01,
              adam_beta1=0.9,
              adam_beta2=0.999):

    early_stop_iters = early_stop_iters if early_stop_iters != 0 else max_iter // 10

    input = torch.from_numpy(input).cuda()
    target = torch.from_numpy(target).cuda()

    var_len = input.view(-1).size()[0]
    modifier_up = np.zeros(var_len, dtype=np.float32)
    modifier_down = np.zeros(var_len, dtype=np.float32)
    real_modifier = torch.zeros(input.size(), dtype=torch.float32).cuda()
    mt = np.zeros(var_len, dtype=np.float32)
    vt = np.zeros(var_len, dtype=np.float32)
    adam_epoch = np.ones(var_len, dtype=np.int32)
    grad = np.zeros(batch_size, dtype=np.float32)
    hess = np.zeros(batch_size, dtype=np.float32)

    upper_bound = 1e10
    lower_bound = 0.0
    out_best_attack = input.clone().detach().cpu().numpy()
    out_best_const = const
    out_bestl2 = 1e10
    out_bestscore = -1

    if use_tanh:
        input = torch.atanh(input * 1.99999)

    if not use_tanh:
        modifier_up = 0.5 - input.clone().detach().view(-1).cpu().numpy()
        modifier_down = -0.5 - input.clone().detach().view(-1).cpu().numpy()

    def compare(x, y):
        if not isinstance(x, (float, int, np.int64)):
            if targeted:
                x[y] -= confidence
            else:
                x[y] += confidence
            x = np.argmax(x)
        if targeted:
            return x == y
        else:
            return x != y

    for step in range(binary_search_steps):
        bestl2 = 1e10
        prev = 1e6
        bestscore = -1
        last_loss2 = 1.0
        # reset ADAM status
        mt.fill(0)
        vt.fill(0)
        adam_epoch.fill(1)
        stage = 0

        for iter in range(max_iter):
            if (iter + 1) % 100 == 0:
                loss, l2, loss2, _, __ = loss_run(input, target, model,
                                                  real_modifier, use_tanh,
                                                  use_log, targeted,
                                                  confidence, const)
                print(
                    "[STATS][L2] iter = {}, loss = {:.5f}, loss1 = {:.5f}, loss2 = {:.5f}"
                    .format(iter + 1, loss[0], l2[0], loss2[0]))
                sys.stdout.flush()

            var_list = np.array(range(0, var_len), dtype=np.int32)
            indice = var_list[np.random.choice(var_list.size,
                                               batch_size,
                                               replace=False)]
            var = np.repeat(real_modifier.detach().cpu().numpy(),
                            batch_size * 2 + 1,
                            axis=0)
            for i in range(batch_size):
                var[i * 2 + 1].reshape(-1)[indice[i]] += 0.0001
                var[i * 2 + 2].reshape(-1)[indice[i]] -= 0.0001
            var = torch.from_numpy(var)
            var = var.view((-1, ) + input.size()[1:]).cuda()
            losses, l2s, losses2, scores, pert_images = loss_run(
                input, target, model, var, use_tanh, use_log, targeted,
                confidence, const)
            real_modifier_numpy = real_modifier.clone().detach().cpu().numpy()
            if solver == "adam":
                coordinate_ADAM(losses,
                                indice,
                                grad,
                                hess,
                                batch_size,
                                mt,
                                vt,
                                real_modifier_numpy,
                                adam_epoch,
                                modifier_up,
                                modifier_down,
                                step_size,
                                adam_beta1,
                                adam_beta2,
                                proj=not use_tanh)
            if solver == "newton":
                coordinate_Newton(losses,
                                  indice,
                                  grad,
                                  hess,
                                  batch_size,
                                  mt,
                                  vt,
                                  real_modifier_numpy,
                                  adam_epoch,
                                  modifier_up,
                                  modifier_down,
                                  step_size,
                                  adam_beta1,
                                  adam_beta2,
                                  proj=not use_tanh)
            real_modifier = torch.from_numpy(real_modifier_numpy).cuda()

            if losses2[0] == 0.0 and last_loss2 != 0.0 and stage == 0:
                if reset_adam_after_found:
                    mt.fill(0)
                    vt.fill(0)
                    adam_epoch.fill(1)
                stage = 1
            last_loss2 = losses2[0]

            if abort_early and (iter + 1) % early_stop_iters == 0:
                if losses[0] > prev * .9999:
                    print("Early stopping because there is no improvement")
                    break
                prev = losses[0]

            if l2s[0] < bestl2 and compare(
                    scores[0], np.argmax(target.cpu().numpy(), -1)):
                bestl2 = l2s[0]
                bestscore = np.argmax(scores[0])

            if l2s[0] < out_bestl2 and compare(
                    scores[0], np.argmax(target.cpu().numpy(), -1)):
                if out_bestl2 == 1e10:
                    print(
                        "[STATS][L3](First valid attack found!) iter = {}, loss = {:.5f}, loss1 = {:.5f}, loss2 = {:.5f}"
                        .format(iter + 1, losses[0], l2s[0], losses2[0]))
                    sys.stdout.flush()
                out_bestl2 = l2s[0]
                out_bestscore = np.argmax(scores[0])
                out_best_attack = pert_images[0]
                out_best_const = const

        if compare(bestscore, np.argmax(target.cpu().numpy(),
                                        -1)) and bestscore != -1:
            print('old constant: ', const)
            upper_bound = min(upper_bound, const)
            if upper_bound < 1e9:
                const = (lower_bound + upper_bound) / 2
            print('new constant: ', const)
        else:
            print('old constant: ', const)
            lower_bound = max(lower_bound, const)
            if upper_bound < 1e9:
                const = (lower_bound + upper_bound) / 2
            else:
                const *= 10
            print('new constant: ', const)

    return out_best_attack, out_bestscore
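The `torch.atanh(input * 1.99999)` line near the top is the usual box-constraint change of variables for images scaled to [-0.5, 0.5]: the 1.99999 factor keeps the atanh argument strictly inside (-1, 1), and `loss_run` (not shown here) is expected to map back with tanh(.) / 2. A minimal round-trip sketch under that assumption:

import torch

img = torch.rand(3, 8, 8) - 0.5        # pixels in [-0.5, 0.5), as the attack assumes
w = torch.atanh(img * 1.99999)         # unconstrained tanh-space variable
back = torch.tanh(w) / 2               # pixel-space reconstruction used inside loss_run
assert torch.allclose(back, img, atol=1e-4)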
Example #29
 def _inverse(self, y):
     # We do not clamp to the boundary here as it may degrade the performance of certain algorithms.
     # one should use `cache_size=1` instead
     return torch.atanh(y)
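The comment above refers to the caching of torch.distributions transforms: with cache_size=1, TanhTransform remembers the pre-tanh value of a freshly drawn sample, so log_prob never has to evaluate atanh near the boundary. A minimal sketch of that usage (the base distribution here is just an example):

import torch
from torch.distributions import Normal, TransformedDistribution
from torch.distributions.transforms import TanhTransform

base = Normal(torch.zeros(3), torch.ones(3))
squashed = TransformedDistribution(base, [TanhTransform(cache_size=1)])
a = squashed.rsample()        # tanh-squashed sample; the pre-tanh value is cached
logp = squashed.log_prob(a)   # uses the cached inverse instead of calling atanh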
Example #30
 def backward(self, y):
     self._low = torch.min(y)
     self._high = torch.max(y)
     return torch.atanh(y)