import torch


def shift_img(img, shft_int=1):
    """
    Shifts the pixels of an object within a grayscale image by 'shft_int'
    columns: to the left if 'shft_int' is negative, to the right otherwise.

    Input: 'img' tuple of a torch tensor of shape (1, H, W) and an integer label,
           'shft_int' no. of pixels to shift along the x-axis
    Output: tuple with the shifted torch tensor and the unmodified label
    """
    no_rows = img[0].shape[1]
    no_cols = img[0].shape[2]  # columns are the last axis of a (1, H, W) tensor
    col_sty = no_cols - abs(shft_int)

    # shift object to the left: the leftmost columns wrap around to the right,
    # so they must be empty (all zeros) for the shift to be valid
    if shft_int < 0:
        shft_int = abs(shft_int)
        col_idx = torch.cat([torch.ones(shft_int, dtype=torch.bool),
                             torch.zeros(col_sty, dtype=torch.bool)])
        cols = torch.reshape(img[0][0, :, col_idx], (no_rows, shft_int))
        cols_sum = torch.sum(cols)
        inval_shft = torch.is_nonzero(cols_sum)
        if inval_shft:
            raise ValueError('Consider shifting to the right for this image.')
        mod_img = torch.cat([img[0][0, :, ~col_idx], cols], dim=1)
        mod_img = torch.reshape(mod_img, (1, no_rows, no_cols))
        return (mod_img, img[1])
    
    # shift object to the right: the rightmost columns wrap around to the left
    col_idx = torch.cat([torch.zeros(col_sty, dtype=torch.bool),
                         torch.ones(shft_int, dtype=torch.bool)])
    cols = torch.reshape(img[0][0, :, col_idx], (no_rows, shft_int))
    cols_sum = torch.sum(cols)
    inval_shft = torch.is_nonzero(cols_sum)
    if inval_shft:
        raise ValueError('Consider shifting to the left for this image.')

    mod_img = torch.cat([cols, img[0][0, :, ~col_idx]], dim=1)
    mod_img = torch.reshape(mod_img, (1, no_rows, no_cols))
    return (mod_img, img[1])

### CHECK
#
#img = train_data[0]
#plot_it(img)
#plot_it(shift_img(img, shft_int = 5))
#plot_it(shift_img(img, shft_int = -3))
#plot_it(shift_img(img, shft_int = 3))
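
The commented check above relies on `train_data` and `plot_it` from the
surrounding notebook. A self-contained smoke test (synthetic tensor and label,
purely illustrative) could look like this:

import torch

t = torch.zeros(1, 6, 6)
t[0, :, 2] = 1.0                      # vertical bright stripe in column 2
sample = (t, 7)                       # (tensor, label) tuple, as shift_img expects

shifted, label = shift_img(sample, shft_int=2)
assert label == 7                     # the label passes through unmodified
assert shifted[0, :, 4].sum() == 6.0  # the stripe moved two columns right
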
Example #2
def compute_stats(pred, target):
    num_classes = pred.shape[1]
    if pred.device.type == 'hpu':
        # accumulate on the CPU while running on an HPU device
        scores = torch.zeros(num_classes - 1,
                             device=torch.device("cpu"),
                             dtype=torch.float32)
    else:
        scores = torch.zeros(num_classes - 1,
                             device=pred.device,
                             dtype=torch.float32)
    for i in range(1, num_classes):
        if (target != i).all():
            # no foreground class
            _, _pred = torch.max(pred, 1)
            scores[i - 1] += 1 if (_pred != i).all() else 0
            continue
        _tp, _fp, _tn, _fn, _ = stat_scores(pred=pred,
                                            target=target,
                                            class_index=i)
        denom = (2 * _tp + _fp + _fn).to(torch.float)
        score_cls = (2 * _tp).to(torch.float) / denom if torch.is_nonzero(
            denom) else 0.0
        if pred.device.type == 'hpu':
            # float() handles both the tensor case and the 0.0 fallback;
            # the original .item() call fails when score_cls is a plain float
            scores[i - 1] = scores[i - 1] + float(score_cls)
        else:
            scores[i - 1] += score_cls
    if pred.device.type == 'hpu':
        scores = scores.to(torch.device("hpu"))
    return scores
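
All of these snippets lean on the same contract of torch.is_nonzero: it only
accepts a tensor with exactly one element, returns a Python bool, and raises a
RuntimeError otherwise. A quick illustration (generic, not tied to the snippet
above):

import torch

assert torch.is_nonzero(torch.tensor([1.5]))    # single element, non-zero
assert not torch.is_nonzero(torch.tensor(0))    # zero-dim tensors work too
try:
    torch.is_nonzero(torch.tensor([1, 0]))      # more than one element
except RuntimeError:
    pass  # "Boolean value of Tensor with more than one element is ambiguous"
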
Example #3
def dice_score(
    pred: torch.Tensor,
    target: torch.Tensor,
    bg: bool = False,
    nan_score: float = 0.0,
    no_fg_score: float = 0.0,
    reduction: str = 'elementwise_mean',
) -> torch.Tensor:
    """
    .. deprecated::
        Use :func:`torchmetrics.functional.dice_score`. Will be removed in v1.4.0.
    """
    num_classes = pred.shape[1]
    bg = (1 - int(bool(bg)))
    scores = torch.zeros(num_classes - bg,
                         device=pred.device,
                         dtype=torch.float32)
    for i in range(bg, num_classes):
        if not (target == i).any():
            # no foreground class
            scores[i - bg] += no_fg_score
            continue

        tp, fp, tn, fn, sup = stat_scores(pred=pred,
                                          target=target,
                                          class_index=i)
        denom = (2 * tp + fp + fn).to(torch.float)
        # nan result
        score_cls = (2 * tp).to(torch.float) / denom if torch.is_nonzero(
            denom) else nan_score

        scores[i - bg] += score_cls
    return reduce(scores, reduction=reduction)
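
The per-class quantity computed above is the Dice coefficient,
2*TP / (2*TP + FP + FN). A hand-rolled check on hard labels (the helper below
is illustrative, not part of any library):

import torch

def dice_for_class(pred_labels, target, c):
    # Dice = 2*TP / (2*TP + FP + FN) for class index c
    tp = ((pred_labels == c) & (target == c)).sum().float()
    fp = ((pred_labels == c) & (target != c)).sum().float()
    fn = ((pred_labels != c) & (target == c)).sum().float()
    denom = 2 * tp + fp + fn
    return (2 * tp / denom).item() if torch.is_nonzero(denom) else 0.0

pred_labels = torch.tensor([0, 1, 2, 2])
target = torch.tensor([0, 1, 3, 2])
print(dice_for_class(pred_labels, target, 2))  # 2*1 / (2*1 + 1 + 0) = 0.666...
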
Example #4
def assert_gradients_match_and_reset_gradient(ort_model,
                                              pt_model,
                                              none_pt_params=None,
                                              reset_gradient=True,
                                              rtol=1e-05,
                                              atol=1e-06):
    if none_pt_params is None:  # avoid the mutable-default-argument pitfall
        none_pt_params = []
    ort_named_params = list(ort_model.named_parameters())
    pt_named_params = list(pt_model.named_parameters())
    assert len(ort_named_params) == len(pt_named_params)

    for ort_named_param, pt_named_param in zip(ort_named_params,
                                               pt_named_params):
        ort_name, ort_param = ort_named_param
        pt_name, pt_param = pt_named_param

        assert pt_name in ort_name
        if pt_name in none_pt_params:
            assert pt_param.grad is None
            assert ort_param.grad is None or not torch.is_nonzero(
                torch.count_nonzero(ort_param.grad))
        else:
            assert_values_are_close(ort_param.grad,
                                    pt_param.grad,
                                    rtol=rtol,
                                    atol=atol)

        if reset_gradient:
            ort_param.grad = None
            pt_param.grad = None
Example #5
    def assert_gradients_match_and_reset_gradient(self,
                                                  ort_model,
                                                  pt_model,
                                                  none_pt_params=None,
                                                  reset_gradient=True,
                                                  rtol=1e-05,
                                                  atol=1e-06):
        if none_pt_params is None:
            none_pt_params = []
        ort_named_params = list(ort_model.named_parameters())
        pt_named_params = list(pt_model.named_parameters())
        self.assertEqual(len(ort_named_params), len(pt_named_params))

        for ort_named_param, pt_named_param in zip(ort_named_params,
                                                   pt_named_params):
            ort_name, ort_param = ort_named_param
            pt_name, pt_param = pt_named_param

            self.assertIn(pt_name, ort_name)
            if pt_name in none_pt_params:
                # the PyTorch gradient is expected to be absent here
                # (cf. the assert-based variant in Example #4)
                self.assertIsNone(pt_param.grad)
                if ort_param.grad is not None:
                    self.assertFalse(
                        torch.is_nonzero(torch.count_nonzero(ort_param.grad)))
            else:
                self.assert_values_are_close(ort_param.grad,
                                             pt_param.grad,
                                             rtol=rtol,
                                             atol=atol)

            if reset_gradient:
                ort_param.grad = None
                pt_param.grad = None
Example #6
def compute_stats(pred, target):
    num_classes = pred.shape[1]
    _bg = 1
    scores = torch.zeros(num_classes - _bg,
                         device=pred.device,
                         dtype=torch.float32)
    # precision and recall are only updated in the no-foreground branch below
    # and are never returned; only the dice scores are
    precision = torch.zeros(num_classes - _bg,
                            device=pred.device,
                            dtype=torch.float32)
    recall = torch.zeros(num_classes - _bg,
                         device=pred.device,
                         dtype=torch.float32)
    for i in range(_bg, num_classes):
        if not (target == i).any():
            # no foreground class
            _, _pred = torch.max(pred, 1)
            scores[i - _bg] += 1 if not (_pred == i).any() else 0
            recall[i - _bg] += 1 if not (_pred == i).any() else 0
            precision[i - _bg] += 1 if not (_pred == i).any() else 0
            continue
        _tp, _fp, _tn, _fn, _ = stat_scores(pred=pred,
                                            target=target,
                                            class_index=i)
        denom = (2 * _tp + _fp + _fn).to(torch.float)
        score_cls = (2 * _tp).to(torch.float) / denom if torch.is_nonzero(
            denom) else 0.0
        scores[i - _bg] += score_cls
    return scores
Example #7
def dice_score(
    pred: torch.Tensor,
    target: torch.Tensor,
    bg: bool = False,
    nan_score: float = 0.0,
    no_fg_score: float = 0.0,
    reduction: str = 'elementwise_mean',
) -> torch.Tensor:
    """
    Compute dice score from prediction scores

    Args:
        pred: estimated probabilities
        target: ground-truth labels
        bg: whether to also compute dice for the background
        nan_score: score to return, if a NaN occurs during computation
        no_fg_score: score to return, if no foreground pixel was found in target
        reduction: a method to reduce metric score over labels.

            - ``'elementwise_mean'``: takes the mean (default)
            - ``'sum'``: takes the sum
            - ``'none'``: no reduction will be applied

    Return:
        Tensor containing dice score

    Example:

        >>> from pytorch_lightning.metrics.functional import dice_score
        >>> pred = torch.tensor([[0.85, 0.05, 0.05, 0.05],
        ...                      [0.05, 0.85, 0.05, 0.05],
        ...                      [0.05, 0.05, 0.85, 0.05],
        ...                      [0.05, 0.05, 0.05, 0.85]])
        >>> target = torch.tensor([0, 1, 3, 2])
        >>> dice_score(pred, target)
        tensor(0.3333)

    """
    num_classes = pred.shape[1]
    bg = (1 - int(bool(bg)))
    scores = torch.zeros(num_classes - bg,
                         device=pred.device,
                         dtype=torch.float32)
    for i in range(bg, num_classes):
        if not (target == i).any():
            # no foreground class
            scores[i - bg] += no_fg_score
            continue

        tp, fp, tn, fn, sup = stat_scores(pred=pred,
                                          target=target,
                                          class_index=i)
        denom = (2 * tp + fp + fn).to(torch.float)
        # nan result
        score_cls = (2 * tp).to(torch.float) / denom if torch.is_nonzero(
            denom) else nan_score

        scores[i - bg] += score_cls
    return reduce(scores, reduction=reduction)
Example #8
    def compute(self):
        nom = self.logit_signs_sum
        denom = self.num
        self.reset()

        if torch.is_nonzero(denom):
            return (nom / denom).item()
        return float('nan')
Example #9
    def get_val_batches(self, dataset: Dataset) -> list:
        val_batches = []
        for plot in self.plots:
            batch = [
                get_positive_example(dataset)
                for _ in range(plot['batch_size'])
            ]
            for _, label, _ in batch:
                assert torch.is_nonzero(label)
            val_batches.append(batch)
        return val_batches
Example #10
def dice_score(
    preds: Tensor,
    target: Tensor,
    bg: bool = False,
    nan_score: float = 0.0,
    no_fg_score: float = 0.0,
    reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
) -> Tensor:
    """Compute dice score from prediction scores.

    Args:
        preds: estimated probabilities
        target: ground-truth labels
        bg: whether to also compute dice for the background
        nan_score: score to return, if a NaN occurs during computation
        no_fg_score: score to return, if no foreground pixel was found in target
        reduction: a method to reduce metric score over labels.

            - ``'elementwise_mean'``: takes the mean (default)
            - ``'sum'``: takes the sum
            - ``'none'`` or ``None``: no reduction will be applied

    Return:
        Tensor containing dice score

    Example:
        >>> from torchmetrics.functional import dice_score
        >>> pred = torch.tensor([[0.85, 0.05, 0.05, 0.05],
        ...                      [0.05, 0.85, 0.05, 0.05],
        ...                      [0.05, 0.05, 0.85, 0.05],
        ...                      [0.05, 0.05, 0.05, 0.85]])
        >>> target = torch.tensor([0, 1, 3, 2])
        >>> dice_score(pred, target)
        tensor(0.3333)
    """
    num_classes = preds.shape[1]
    bg_inv = 1 - int(bg)
    scores = torch.zeros(num_classes - bg_inv, device=preds.device, dtype=torch.float32)
    for i in range(bg_inv, num_classes):
        if not (target == i).any():
            # no foreground class
            scores[i - bg_inv] += no_fg_score
            continue

        # TODO: rewrite to use general `stat_scores`
        tp, fp, _, fn, _ = _stat_scores(preds=preds, target=target, class_index=i)
        denom = (2 * tp + fp + fn).to(torch.float)
        # nan result
        score_cls = (2 * tp).to(torch.float) / denom if torch.is_nonzero(denom) else nan_score

        scores[i - bg_inv] += score_cls
    return reduce(scores, reduction=reduction)
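
The docstring example averages over the three foreground classes; calling with
reduction='none' should expose the per-class scores instead (class 1 is
predicted exactly, classes 2 and 3 are swapped, hence the mean of 1/3):

    >>> dice_score(pred, target, reduction='none')
    tensor([1., 0., 0.])
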
Example #11
import torch


def shift_image(img, shft_int=1):
    """
    Shifts the pixels of a grayscale image to the right by 'shft_int' columns
    if the last 'shft_int' columns of the tensor contain only zeros (i.e. are
    black). If they contain any non-zeros, the pixels are shifted to the left
    by 'shft_int' instead. Raises a ValueError if neither shift is possible.

    Input: 'img' tuple of a torch tensor of shape (1, H, W) and an integer label,
           'shft_int' no. of cols to shift along the x-axis
    Output: tuple with the shifted torch tensor and the unmodified label
    """
    no_rows = img[0].shape[1]
    no_cols = img[0].shape[2]  # columns are the last axis of a (1, H, W) tensor
    col_sty = no_cols - shft_int
    col_idx = torch.cat([torch.zeros(col_sty, dtype=torch.bool),
                         torch.ones(shft_int, dtype=torch.bool)])
    cols = torch.reshape(img[0][0, :, col_idx], (no_rows, shft_int))
    cols_sum = torch.sum(cols)
    inval_shft = torch.is_nonzero(cols_sum)

    if inval_shft:
        # right edge is occupied; try wrapping the leftmost columns instead
        col_idx = torch.cat([torch.ones(shft_int, dtype=torch.bool),
                             torch.zeros(col_sty, dtype=torch.bool)])
        cols = torch.reshape(img[0][0, :, col_idx], (no_rows, shft_int))
        cols_sum = torch.sum(cols)
        inval_shft = torch.is_nonzero(cols_sum)
        if inval_shft:
            raise ValueError('Consider shifting along another axis.')
        mod_img = torch.cat([img[0][0, :, ~col_idx], cols], dim=1)
        mod_img = torch.reshape(mod_img, (1, no_rows, no_cols))
        return (mod_img, img[1])

    mod_img = torch.cat([cols, img[0][0, :, ~col_idx]], dim=1)
    mod_img = torch.reshape(mod_img, (1, no_rows, no_cols))
    return (mod_img, img[1])
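
A hedged check of the automatic direction choice, again on a synthetic sample:
an object hugging the right edge forces shift_image to wrap to the left.

import torch

t = torch.zeros(1, 6, 6)
t[0, :, 5] = 1.0                      # bright stripe in the last column
shifted, label = shift_image((t, 3), shft_int=1)
assert label == 3
assert shifted[0, :, 4].sum() == 6.0  # stripe moved one column to the left
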
Example #12
    def compute_stats(self, preds, target):
        scores = torch.zeros(self.n_class,
                             device=preds.device,
                             dtype=torch.float32)
        preds = torch.argmax(preds, dim=1)
        for i in range(1, self.n_class + 1):
            if (target != i).all():
                # no foreground class
                scores[i - 1] += 1 if (preds != i).all() else 0
                continue
            tp, fn, fp = self.get_stats(preds, target, i)
            denom = (2 * tp + fp + fn).to(torch.float)
            score_cls = (2 * tp).to(torch.float) / denom if torch.is_nonzero(
                denom) else 0.0
            scores[i - 1] += score_cls
        return scores
Example #13
    def compute_stats_brats(self, p, y):
        scores = torch.zeros(self.n_class,
                             device=p.device,
                             dtype=torch.float32)
        p = (torch.sigmoid(p) > 0.5).int()
        # BraTS regions: whole tumour (y > 0), tumour core (labels 1 and 3),
        # enhancing tumour (label 3), evaluated as three binary channels
        y_wt, y_tc, y_et = y > 0, ((y == 1) + (y == 3)) > 0, y == 3
        y = torch.stack([y_wt, y_tc, y_et], dim=1)

        for i in range(self.n_class):
            p_i, y_i = p[:, i], y[:, i]
            if (y_i != 1).all():
                # no foreground class; i is already 0-based here, so index
                # scores[i] directly (scores[i - 1] would wrap around)
                scores[i] += 1 if (p_i != 1).all() else 0
                continue
            tp, fn, fp = self.get_stats(p_i, y_i, 1)
            denom = (2 * tp + fp + fn).to(torch.float)
            score_cls = (2 * tp).to(torch.float) / denom if torch.is_nonzero(
                denom) else 0.0
            scores[i] += score_cls
        return scores
Example #14
    def optimize_tabular(self, agent, trajectory_buffer, update_target=False):
        with torch.no_grad():
            N = len(trajectory_buffer.buffer)
            inner_states, outer_states, actions, action_distributions, rewards, dones, next_inner_states, \
                next_outer_states = trajectory_buffer.sample(None, random_sample=False)

            PS_s = agent.concept_architecture(inner_states, outer_states)[0]
            concepts = PS_s.argmax(1).detach().cpu().numpy()

            next_PS_s = agent.concept_architecture(
                next_inner_states[-1, :].view(1, -1),
                next_outer_states[-1, :, :, :].unsqueeze(0))[0]
            next_concept = next_PS_s.argmax(1).detach().cpu().numpy()
            next_concepts = np.concatenate([concepts[1:], next_concept])

            PA_S, log_PA_S = agent.PA_S()
            HA_gS = -(PA_S * log_PA_S).sum(1)
            HA_S = (self.PS.view(-1) * HA_gS).sum()
            Alpha = agent.log_Alpha.exp().item()

            assert torch.isfinite(HA_S).all(), 'Alahuakbar'

            # PA_s = agent.second_level_architecture.actor(inner_states, outer_states)[0]
            ratios = PA_S[concepts,
                          actions] / action_distributions[np.arange(0, N),
                                                          actions]
            if self.clip_ratios:
                ratios = ratios.clip(0.0, 1.0)

            assert torch.isfinite(ratios).all(), 'Alahuakbar 1'

            Q = (1. - self.forgetting_factor) * agent.Q_table.detach().clone()
            C = (1. - self.forgetting_factor) * agent.C_table.detach().clone()

            assert torch.isfinite(Q).all(), 'Alahuakbar 2'
            assert torch.isfinite(C).all(), 'Alahuakbar 3'

            if N > 0:
                G = 0
                WIS_trajectory = 1
                for i in range(N - 1, -1, -1):
                    S, A, R, WIS_step, nS = concepts[i], actions[i], rewards[
                        i], ratios[i], next_concepts[i]
                    G = self.discount_factor * G + R
                    if self.MC_entropy:
                        dH = HA_gS[nS] - HA_S
                        G += self.discount_factor * Alpha * dH
                    C[S, A] = C[S, A] + WIS_trajectory
                    if torch.is_nonzero(C[S, A]):
                        assert torch.isfinite(C[S, A]), 'Infinity and beyond!'
                        Q[S, A] = Q[S, A] + (WIS_trajectory /
                                             C[S, A]) * (G - Q[S, A])
                        WIS_trajectory = WIS_trajectory * WIS_step
                        if self.clip_ratios:
                            WIS_trajectory = WIS_trajectory.clip(0.0, 10.0)
                    if not torch.is_nonzero(WIS_trajectory):
                        break

            dQ = (Q - agent.Q_table).pow(2).mean()

            agent.update_Q(Q, C)
            if update_target:
                agent.update_target(self.MC_update_rate)

            Pi = agent.Pi_table.detach().clone()
            log_Pi = torch.log(Pi)
            HA_gS = -(Pi * log_Pi).sum(1)
            HA_S = (self.PS.view(-1) * HA_gS).sum()

            assert torch.isfinite(HA_S).all(), 'Alahuakbar'

            # Optimize Alpha
            agent.update_Alpha(HA_S)

            # Optimize policy
            Alpha = agent.log_Alpha.exp().item()
            duals = (1e-3) * torch.ones_like(self.PS.view(-1, 1))
            found_policy = False
            iters_left = 8
            while not found_policy and iters_left > 0:
                Q_adjusted = (Q + Alpha * duals * log_Pi) / (1. + duals)
                Pi_new = torch.exp(Q_adjusted / (Alpha + 1e-10))
                Pi_new = Pi_new / Pi_new.sum(1, keepdim=True)
                log_Pi_new = torch.log(Pi_new + 1e-10)
                KL_div = (Pi_new * (log_Pi_new - log_Pi)).sum(1, keepdim=True)
                valid_policies = KL_div <= self.policy_divergence_limit
                if torch.all(valid_policies):
                    found_policy = True
                else:
                    iters_left -= 1
                    duals = 10**(1. - valid_policies.float()) * duals

            if found_policy:
                agent.update_policy(Pi_new)

            metrics = {
                'Q_change': dQ.item(),
                'entropy': HA_S.item(),
                'Alpha': Alpha,
                'found_policy': float(found_policy),
                'max_dual': duals.max().item(),
            }

            return metrics
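
The inner loop above implements the incremental weighted-importance-sampling
update from off-policy Monte Carlo control: C accumulates the trajectory
weights and Q moves by (W / C) * (G - Q). A stripped-down sketch of just that
recurrence, with toy numbers rather than the agent's tables:

import torch

# After processing returns G_k with weights W_k, Q equals
# sum(W_k * G_k) / sum(W_k) -- a weighted average built incrementally.
Q = torch.tensor(0.0)
C = torch.tensor(0.0)
for G, W in [(1.0, 0.5), (3.0, 1.0), (2.0, 0.25)]:
    C = C + W
    if torch.is_nonzero(C):
        Q = Q + (W / C) * (G - Q)

expected = (0.5 * 1.0 + 1.0 * 3.0 + 0.25 * 2.0) / (0.5 + 1.0 + 0.25)
assert torch.isclose(Q, torch.tensor(expected))
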
Example #15
    def optimize_tabular(self, agent, trajectory_buffer):
        with torch.no_grad():
            N = len(trajectory_buffer.buffer)
            inner_states, outer_states, actions, rewards, dones, next_inner_states, \
                next_outer_states = trajectory_buffer.sample(None, random_sample=False)
            
            PS_s = agent.concept_architecture(inner_states, outer_states)[0]
            concepts = PS_s.argmax(1).detach().cpu().numpy()
            
            next_PS_s = agent.concept_architecture(next_inner_states[-1,:].view(1,-1), next_outer_states[-1,:,:,:].unsqueeze(0))[0]
            next_concept = next_PS_s.argmax(1).detach().cpu().numpy()
            next_concepts = np.concatenate([concepts[1:], next_concept])
            
            PA_S, log_PA_S = agent.PA_S()
            HA_gS = -(PA_S * log_PA_S).sum(1)
            HA_S = (self.PS.view(-1) * HA_gS).sum()
            Alpha = agent.log_Alpha.exp().item()
            
            assert torch.isfinite(HA_S).all(), 'Alahuakbar'

            PA_s = agent.second_level_architecture.actor(inner_states, outer_states)[0]
            ratios = PA_S[concepts, actions] / PA_s[np.arange(0,N),actions]

            assert torch.isfinite(ratios).all(), 'Alahuakbar 1'

            Q = agent.Q_table.detach().clone()
            C = agent.C_table.detach().clone()
            Q0 = Q.clone()

            assert torch.isfinite(Q).all(), 'Alahuakbar 2'
            assert torch.isfinite(C).all(), 'Alahuakbar 3'

            if N > 0:
                G = 0
                WIS_trajectory = 1
                for i in range(N-1, -1, -1):
                    S, A, R, nS = concepts[i], actions[i], rewards[i], next_concepts[i]
                    dH = HA_gS[nS] - HA_S
                    G = self.discount_factor * G + R + self.discount_factor * Alpha * dH 
                    C[S,A] = C[S,A] + WIS_trajectory
                    if torch.is_nonzero(C[S,A]):
                        assert torch.isfinite(C[S,A]), 'Infinity and beyond!'
                        Q[S,A] = Q[S,A] + (WIS_trajectory/C[S,A]) * (G - Q[S,A])
                        agent.update_Q(Q, C)

                        PA_S = agent.PA_S()[0]
                        HA_gS = -(PA_S * log_PA_S).sum(1)
                        HA_S = (self.PS.view(-1) * HA_gS).sum()

                        WIS_step = PA_S[S,A] / PA_s[i,A]
                        WIS_trajectory = WIS_trajectory * WIS_step
                    if not torch.is_nonzero(WIS_trajectory):
                        break 

            agent.update_Q(Q, 0.9*C)
            dQ = (Q - Q0).pow(2).mean()            

            PA_S, log_PA_S = agent.PA_S()
            HA_gS = -(PA_S * log_PA_S).sum(1)
            HA_S = (self.PS.view(-1) * HA_gS).sum()

            assert torch.isfinite(HA_S).all(), 'Alahuakbar'

            # Optimize Alpha
            agent.update_Alpha(HA_S)

            metrics = {
                'Q_change': dQ.item(),
                'entropy': HA_S.item(),
                'Alpha': Alpha,
            }

            return metrics
Example #16
    def tensor_general_ops(self):
        a = torch.randn(4)
        b = torch.tensor([1.5])
        x = torch.ones((2, ))
        c = torch.randn(4, dtype=torch.cfloat)
        w = torch.rand(4, 4, 4, 4)
        v = torch.rand(4, 4, 4, 4)
        # len() takes a single argument, so the exercised ops are gathered in
        # a tuple (restoring a parenthesis the listing appears to have dropped)
        return len((
            # torch.is_tensor(a),
            # torch.is_storage(a),
            torch.is_complex(a),
            torch.is_conj(a),
            torch.is_floating_point(a),
            torch.is_nonzero(b),
            # torch.set_default_dtype(torch.float32),
            # torch.get_default_dtype(),
            # torch.set_default_tensor_type(torch.DoubleTensor),
            torch.numel(a),
            # torch.set_printoptions(),
            # torch.set_flush_denormal(False),
            # https://pytorch.org/docs/stable/tensors.html#tensor-class-reference
            # x.new_tensor([[0, 1], [2, 3]]),
            x.new_full((3, 4), 3.141592),
            x.new_empty((2, 3)),
            x.new_ones((2, 3)),
            x.new_zeros((2, 3)),
            x.is_cuda,
            x.is_quantized,
            x.is_meta,
            x.device,
            x.dim(),
            c.real,
            c.imag,
            # x.backward(),
            x.clone(),
            w.contiguous(),
            w.contiguous(memory_format=torch.channels_last),
            w.copy_(v),
            w.copy_(1),
            w.copy_(0.5),
            x.cpu(),
            # x.cuda(),
            # x.data_ptr(),
            x.dense_dim(),
            w.fill_diagonal_(0),
            w.element_size(),
            w.exponential_(),
            w.fill_(0),
            w.geometric_(0.5),
            a.index_fill(0, torch.tensor([0, 2]), 1),
            a.index_put_([torch.argmax(a)], torch.tensor(1.0)),
            a.index_put([torch.argmax(a)], torch.tensor(1.0)),
            w.is_contiguous(),
            c.is_complex(),
            w.is_conj(),
            w.is_floating_point(),
            w.is_leaf,
            w.is_pinned(),
            w.is_set_to(w),
            # w.is_shared,
            w.is_coalesced(),
            w.coalesce(),
            w.is_signed(),
            w.is_sparse,
            torch.tensor([1]).item(),
            x.log_normal_(),
            # x.masked_scatter_(),
            # x.masked_scatter(),
            # w.normal(),
            w.numel(),
            # w.pin_memory(),
            # w.put_(0, torch.tensor([0, 1], w)),
            x.repeat(4, 2),
            a.clamp_(0),
            a.clamp(0),
            a.clamp_min(0),
            a.hardsigmoid_(),
            a.hardsigmoid(),
            a.hardswish_(),
            a.hardswish(),
            a.hardtanh_(),
            a.hardtanh(),
            a.leaky_relu_(),
            a.leaky_relu(),
            a.relu_(),
            a.relu(),
            a.resize_as_(a),
            a.type_as(a),
            a._shape_as_tensor(),
            a.requires_grad_(False),
        ))
Example #17
    def get_m(self, observations, comm_actions, prev_actions=None):

        #comm_rewards = K.zeros((observations.shape[0], observations.shape[1], 1),
        #                    dtype=observations.dtype, device=observations.device)

        if self.medium_type == 'obs_only':  # compare strings with '==', not 'is'
            medium = K.zeros(
                (1, observations.shape[1], observations.shape[2] + 1),
                dtype=observations.dtype,
                device=observations.device)
        else:
            medium = K.zeros(
                (1, observations.shape[1],
                 observations.shape[2] + prev_actions.shape[2] + 1),
                dtype=observations.dtype,
                device=observations.device)

        granted_agent = comm_actions.argmax(dim=0)[:, 0]
        for i in range(self.num_agents):
            #if competitive_comm:
            #    comm_rewards[i, granted_agent == i, :] = 1
            if self.medium_type == 'obs_only':
                medium[:, granted_agent == i, :] = K.cat([
                    observations[[i], ][:, granted_agent == i, :],
                    (i + 1) * K.ones((1, (granted_agent == i).sum().item(), 1),
                                     dtype=observations.dtype,
                                     device=observations.device)
                ],
                                                         dim=-1)
            else:
                medium[:, granted_agent == i, :] = K.cat([
                    observations[[i], ][:, granted_agent == i, :],
                    prev_actions[[i], ][:, granted_agent == i, :],
                    (i + 1) * K.ones((1, (granted_agent == i).sum().item(), 1),
                                     dtype=observations.dtype,
                                     device=observations.device)
                ],
                                                         dim=-1)

        if K.is_nonzero(
            ((comm_actions < 0.00001).sum(dim=0) == self.num_agents)[:,
                                                                     0].sum()):

            #comm_rewards[:,((comm_actions>0.5).sum(dim=0) == 0)[:,0],:] = -1
            if self.medium_type == 'obs_only':
                medium[:, ((comm_actions < 0.00001).sum(
                    dim=0) == self.num_agents)[:, 0], :] = K.cat([
                        K.zeros((1, 1, observations.shape[2]),
                                dtype=observations.dtype,
                                device=observations.device),
                        K.zeros((1, 1, 1),
                                dtype=observations.dtype,
                                device=observations.device)
                    ],
                                                                 dim=-1)
            else:
                medium[:, ((comm_actions < 0.00001).sum(
                    dim=0) == self.num_agents)[:, 0], :] = K.cat([
                        K.zeros((1, 1, observations.shape[2]),
                                dtype=observations.dtype,
                                device=observations.device),
                        K.zeros((1, 1, prev_actions.shape[2]),
                                dtype=observations.dtype,
                                device=observations.device),
                        K.zeros((1, 1, 1),
                                dtype=observations.dtype,
                                device=observations.device)
                    ],
                                                                 dim=-1)

        if K.is_nonzero(
            ((comm_actions > 0.99999).sum(dim=0) == self.num_agents)[:,
                                                                     0].sum()):
            #comm_rewards[:,((comm_actions>0.5).sum(dim=0) > 1)[:,0],:] = -1
            if self.medium_type == 'obs_only':
                medium[:, ((comm_actions > 0.99999).sum(
                    dim=0) == self.num_agents)[:, 0], :] = K.cat([
                        K.zeros((1, 1, observations.shape[2]),
                                dtype=observations.dtype,
                                device=observations.device),
                        K.ones(
                            (1, 1, 1),
                            dtype=observations.dtype,
                            device=observations.device) * (self.num_agents + 1)
                    ],
                                                                 dim=-1)
            else:
                medium[:, ((comm_actions > 0.99999).sum(
                    dim=0) == self.num_agents)[:, 0], :] = K.cat([
                        K.zeros((1, 1, observations.shape[2]),
                                dtype=observations.dtype,
                                device=observations.device),
                        K.zeros((1, 1, prev_actions.shape[2]),
                                dtype=observations.dtype,
                                device=observations.device),
                        K.ones(
                            (1, 1, 1),
                            dtype=observations.dtype,
                            device=observations.device) * (self.num_agents + 1)
                    ],
                                                                 dim=-1)

        # building a new tensor from an existing one via K.tensor() is
        # deprecated; detach-and-clone keeps the values and re-enables autograd
        return medium.clone().detach().requires_grad_(True)