def shift_img(img, shft_int = 1):
    """
    Shifts the pixels of an object within a grayscale image 'shft_int' pixels
    along the x-axis: to the left for a negative 'shft_int', to the right for
    a positive one.

    The columns pushed out on one side are re-attached on the opposite side,
    so they must not contain any non-zero pixel.

    Input: 'img' tuple with a torch tensor of shape (1, rows, cols) and an
           integer, 'shft_int' no. of pixels to shift along x-axis
    Output: tuple with a shifted torch tensor (same shape) and an unmodified
            integer
    Raises: ValueError if the columns that would wrap around hold a non-zero
            pixel, or if abs(shft_int) exceeds the number of columns
    """
    plane = img[0][0]
    # The x-axis is the LAST dimension; rows and columns are kept apart so
    # that non-square images work as well.
    no_cols = plane.shape[1]
    width = abs(shft_int)
    if width > no_cols:
        raise ValueError('shft_int exceeds the number of columns of this image.')

    if shft_int < 0:
        # shift object to the left: the leading columns wrap to the right edge
        edge, rest = plane[:, :width], plane[:, width:]
        # any() instead of sum(): positive and negative pixels must not cancel
        if bool(torch.any(edge != 0)):
            raise ValueError('Consider shifting to the right for this image.')
        shifted = torch.cat([rest, edge], dim = 1)
    else:
        # shift object to the right: the trailing columns wrap to the left edge
        split = no_cols - width
        edge, rest = plane[:, split:], plane[:, :split]
        if bool(torch.any(edge != 0)):
            raise ValueError('Consider shifting to the left for this image.')
        shifted = torch.cat([edge, rest], dim = 1)

    # restore the leading channel dimension and pair with the untouched integer
    return (shifted.unsqueeze(0), img[1])

### CHECK
#
# img = train_data[0]
# plot_it(img)
# plot_it(shift_img(img, shft_int = 5))
# plot_it(shift_img(img, shft_int = -3))
# plot_it(shift_img(img, shft_int = 3))
def compute_stats(pred, target):
    """
    Return one score per non-background class (classes 1 .. C-1, where C is
    ``pred.shape[1]``).

    For a class that does not occur in ``target`` the score is 1 when the
    arg-max prediction never picks that class either, otherwise 0. For a
    class that does occur the score is ``2*tp / (2*tp + fp + fn)`` computed
    from ``stat_scores``.

    When ``pred`` lives on an ``'hpu'`` device the scores are accumulated on
    the CPU and moved to the HPU right before returning.
    """
    on_hpu = pred.device.type == 'hpu'
    n_foreground = pred.shape[1] - 1
    accum_device = torch.device("cpu") if on_hpu else pred.device
    scores = torch.zeros(n_foreground, device=accum_device, dtype=torch.float32)

    for slot in range(n_foreground):
        cls = slot + 1
        if (target != cls).all():
            # class absent from the target: reward predicting none of it
            _, hard_pred = torch.max(pred, 1)
            scores[slot] += int(bool((hard_pred != cls).all()))
            continue

        _tp, _fp, _tn, _fn, _ = stat_scores(pred=pred, target=target, class_index=cls)
        denom = (2 * _tp + _fp + _fn).to(torch.float)
        if torch.is_nonzero(denom):
            score_cls = (2 * _tp).to(torch.float) / denom
        else:
            score_cls = 0.0

        if on_hpu:
            # out-of-place update with a Python scalar on the CPU accumulator
            scores[slot] = scores[slot] + score_cls.item()
        else:
            scores[slot] += score_cls

    if on_hpu:
        scores = scores.to(torch.device("hpu"))
    return scores
def dice_score(
    pred: torch.Tensor,
    target: torch.Tensor,
    bg: bool = False,
    nan_score: float = 0.0,
    no_fg_score: float = 0.0,
    reduction: str = 'elementwise_mean',
) -> torch.Tensor:
    """
    Per-class dice score ``2*tp / (2*tp + fp + fn)``, reduced with ``reduce``.

    Class 0 is skipped unless ``bg`` is truthy. A class missing from
    ``target`` contributes ``no_fg_score``; a zero denominator contributes
    ``nan_score``.

    .. deprecated::
        Use :func:`torchmetrics.functional.dice_score`. Will be removed in v1.4.0.
    """
    first_cls = 0 if bg else 1
    n_cls = pred.shape[1]
    scores = torch.zeros(n_cls - first_cls, device=pred.device, dtype=torch.float32)

    for slot, cls in enumerate(range(first_cls, n_cls)):
        if (target == cls).any():
            tp, fp, _tn, fn, _sup = stat_scores(pred=pred, target=target, class_index=cls)
            denom = (2 * tp + fp + fn).to(torch.float)
            if torch.is_nonzero(denom):
                scores[slot] += (2 * tp).to(torch.float) / denom
            else:
                # nan result
                scores[slot] += nan_score
        else:
            # no foreground class
            scores[slot] += no_fg_score

    return reduce(scores, reduction=reduction)
def assert_gradients_match_and_reset_gradient(
    ort_model,
    pt_model,
    none_pt_params=None,
    reset_gradient=True,
    rtol=1e-05,
    atol=1e-06,
):
    """
    Assert that both models carry matching gradients, parameter by parameter.

    Parameters are paired in ``named_parameters()`` order; each name from
    ``pt_model`` must be contained in the paired name from ``ort_model``.

    Args:
        ort_model: first model; exposes ``named_parameters()``.
        pt_model: second (reference) model; exposes ``named_parameters()``.
        none_pt_params: names (as reported by ``pt_model``) whose reference
            gradient must be ``None``. For those, the ``ort_model`` gradient
            must be ``None`` or all zeros. ``None`` means "no such names".
        reset_gradient: when true, set ``.grad`` of every visited parameter
            of both models to ``None``.
        rtol: relative tolerance forwarded to ``assert_values_are_close``.
        atol: absolute tolerance forwarded to ``assert_values_are_close``.

    Raises:
        AssertionError: if any of the checks above fails.
    """
    # Immutable fallback instead of a shared mutable default argument.
    if none_pt_params is None:
        none_pt_params = ()

    ort_named_params = list(ort_model.named_parameters())
    pt_named_params = list(pt_model.named_parameters())
    assert len(ort_named_params) == len(pt_named_params)

    for (ort_name, ort_param), (pt_name, pt_param) in zip(ort_named_params, pt_named_params):
        assert pt_name in ort_name
        if pt_name in none_pt_params:
            assert pt_param.grad is None
            # A missing gradient and an all-zero gradient are both accepted.
            assert ort_param.grad is None or not torch.is_nonzero(
                torch.count_nonzero(ort_param.grad))
        else:
            assert_values_are_close(ort_param.grad, pt_param.grad, rtol=rtol, atol=atol)

        if reset_gradient:
            ort_param.grad = None
            pt_param.grad = None
def assert_gradients_match_and_reset_gradient(
    self,
    ort_model,
    pt_model,
    none_pt_params=None,
    reset_gradient=True,
    rtol=1e-05,
    atol=1e-06,
):
    """
    Assert that both models carry matching gradients, parameter by parameter.

    Parameters are paired in ``named_parameters()`` order; each name from
    ``pt_model`` must be contained in the paired name from ``ort_model``.

    Args:
        ort_model: first model; exposes ``named_parameters()``.
        pt_model: second (reference) model; exposes ``named_parameters()``.
        none_pt_params: names (as reported by ``pt_model``) whose reference
            gradient must be ``None``. For those, the ``ort_model`` gradient
            must be ``None`` or all zeros. ``None`` means "no such names".
        reset_gradient: when true, set ``.grad`` of every visited parameter
            of both models to ``None``.
        rtol: relative tolerance forwarded to ``self.assert_values_are_close``.
        atol: absolute tolerance forwarded to ``self.assert_values_are_close``.

    Raises:
        AssertionError (the test-case failure exception) if a check fails.
    """
    if none_pt_params is None:
        none_pt_params = []

    ort_named_params = list(ort_model.named_parameters())
    pt_named_params = list(pt_model.named_parameters())
    self.assertEqual(len(ort_named_params), len(pt_named_params))

    for (ort_name, ort_param), (pt_name, pt_param) in zip(ort_named_params, pt_named_params):
        self.assertIn(pt_name, ort_name)
        if pt_name in none_pt_params:
            # The reference gradient is expected to be absent for these names.
            self.assertIsNone(pt_param.grad)
            # Guard on the gradient itself (not the parameter): a missing
            # gradient is fine, an existing one must be all zeros.
            if ort_param.grad is not None:
                self.assertFalse(
                    torch.is_nonzero(torch.count_nonzero(ort_param.grad)))
        else:
            self.assert_values_are_close(ort_param.grad, pt_param.grad, rtol=rtol, atol=atol)

        if reset_gradient:
            ort_param.grad = None
            pt_param.grad = None
def compute_stats(pred, target):
    """
    Return one score per non-background class (classes 1 .. C-1, where C is
    ``pred.shape[1]``).

    For a class that does not occur in ``target`` the score is 1 when the
    arg-max prediction never picks that class either, otherwise 0. For a
    class that does occur the score is ``2*tp / (2*tp + fp + fn)`` computed
    from ``stat_scores`` (0.0 if the denominator is zero).

    Returns:
        float32 tensor of length ``C - 1`` on ``pred``'s device.
    """
    num_classes = pred.shape[1]
    _bg = 1  # index of the first foreground class; class 0 is skipped
    scores = torch.zeros(num_classes - _bg, device=pred.device, dtype=torch.float32)

    for i in range(_bg, num_classes):
        if not (target == i).any():
            # no foreground class: perfect score only if nothing was predicted
            _, _pred = torch.max(pred, 1)
            scores[i - _bg] += 1 if not (_pred == i).any() else 0
            continue

        _tp, _fp, _tn, _fn, _ = stat_scores(pred=pred, target=target, class_index=i)
        denom = (2 * _tp + _fp + _fn).to(torch.float)
        score_cls = (2 * _tp).to(torch.float) / denom if torch.is_nonzero(denom) else 0.0
        scores[i - _bg] += score_cls

    return scores
def dice_score(
    pred: torch.Tensor,
    target: torch.Tensor,
    bg: bool = False,
    nan_score: float = 0.0,
    no_fg_score: float = 0.0,
    reduction: str = 'elementwise_mean',
) -> torch.Tensor:
    """
    Compute the dice score per class from prediction scores and reduce it.

    Args:
        pred: estimated probabilities, with classes along dimension 1
        target: ground-truth labels
        bg: whether the background (class 0) is scored as well
        nan_score: value used for a class whose dice denominator is zero
        no_fg_score: value used for a class that never occurs in ``target``
        reduction: how the per-class scores are combined:

            - ``'elementwise_mean'``: takes the mean (default)
            - ``'sum'``: takes the sum
            - ``'none'``: no reduction will be applied

    Return:
        Tensor containing dice score

    Example:

        >>> from pytorch_lightning.metrics.functional import dice_score
        >>> pred = torch.tensor([[0.85, 0.05, 0.05, 0.05],
        ...                      [0.05, 0.85, 0.05, 0.05],
        ...                      [0.05, 0.05, 0.85, 0.05],
        ...                      [0.05, 0.05, 0.05, 0.85]])
        >>> target = torch.tensor([0, 1, 3, 2])
        >>> dice_score(pred, target)
        tensor(0.3333)
    """
    skip = 1 - int(bool(bg))  # number of leading classes left out (0 or 1)
    class_ids = range(skip, pred.shape[1])
    scores = torch.zeros(len(class_ids), device=pred.device, dtype=torch.float32)

    for cls in class_ids:
        slot = cls - skip
        present = (target == cls).any()
        if not present:
            # no foreground class
            scores[slot] += no_fg_score
            continue

        tp, fp, tn, fn, sup = stat_scores(pred=pred, target=target, class_index=cls)
        numer = (2 * tp).to(torch.float)
        denom = (2 * tp + fp + fn).to(torch.float)
        # nan result
        scores[slot] += numer / denom if torch.is_nonzero(denom) else nan_score

    return reduce(scores, reduction=reduction)
def compute(self):
    """
    Return ``logit_signs_sum / num`` as a Python float and reset the state.

    Returns:
        The ratio as a float, or ``nan`` when ``num`` is zero.

    The ratio is evaluated BEFORE ``reset()`` runs: ``nom`` and ``denom`` are
    only references to the accumulators, so a ``reset()`` that clears them in
    place would otherwise zero the values that are about to be read.
    ``reset()`` is still always invoked, including when the evaluation raises.
    """
    nom = self.logit_signs_sum
    denom = self.num
    try:
        if torch.is_nonzero(denom):
            return (nom / denom).item()
        return float('nan')
    finally:
        self.reset()
def get_val_batches(self, dataset: Dataset) -> list:
    """
    Build one validation batch per entry of ``self.plots``.

    Each batch holds ``plot['batch_size']`` items obtained from
    ``get_positive_example(dataset)``. Every item is unpacked as a 3-tuple
    and its middle element (the label) is asserted to be non-zero.

    Returns:
        list with one batch (itself a list of examples) per plot.
    """
    collected = []
    for plot_cfg in self.plots:
        examples = []
        for _ in range(plot_cfg['batch_size']):
            examples.append(get_positive_example(dataset))

        # validate only after the whole batch has been drawn
        for example in examples:
            _, label, _ = example
            assert torch.is_nonzero(label)

        collected.append(examples)
    return collected
def dice_score(
    preds: Tensor,
    target: Tensor,
    bg: bool = False,
    nan_score: float = 0.0,
    no_fg_score: float = 0.0,
    reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
) -> Tensor:
    """Compute the dice score per class from prediction scores and reduce it.

    Args:
        preds: estimated probabilities, with classes along dimension 1
        target: ground-truth labels
        bg: whether the background (class 0) is scored as well
        nan_score: value used for a class whose dice denominator is zero
        no_fg_score: value used for a class that never occurs in ``target``
        reduction: how the per-class scores are combined:

            - ``'elementwise_mean'``: takes the mean (default)
            - ``'sum'``: takes the sum
            - ``'none'`` or ``None``: no reduction will be applied

    Return:
        Tensor containing dice score

    Example:
        >>> from torchmetrics.functional import dice_score
        >>> pred = torch.tensor([[0.85, 0.05, 0.05, 0.05],
        ...                      [0.05, 0.85, 0.05, 0.05],
        ...                      [0.05, 0.05, 0.85, 0.05],
        ...                      [0.05, 0.05, 0.05, 0.85]])
        >>> target = torch.tensor([0, 1, 3, 2])
        >>> dice_score(pred, target)
        tensor(0.3333)
    """
    total_classes = preds.shape[1]
    offset = 1 - int(bg)  # 1 when the background class is left out
    per_class = torch.zeros(total_classes - offset, device=preds.device, dtype=torch.float32)

    cls = offset
    while cls < total_classes:
        slot = cls - offset
        if (target == cls).any():
            # TODO: rewrite to use general `stat_scores`
            tp, fp, _, fn, _ = _stat_scores(preds=preds, target=target, class_index=cls)
            denom = (2 * tp + fp + fn).to(torch.float)
            if torch.is_nonzero(denom):
                per_class[slot] += (2 * tp).to(torch.float) / denom
            else:
                # nan result
                per_class[slot] += nan_score
        else:
            # no foreground class
            per_class[slot] += no_fg_score
        cls += 1

    return reduce(per_class, reduction=reduction)
def shift_image(img, shft_int = 1):
    """
    Shifts the pixels of a grayscale image to the right by shft_int if the
    last shft_int columns of the tensor have zero entries only (i.e. black).
    If they contain any non-zeros, the pixels are shifted to the left by
    shft_int instead, provided the first shft_int columns are all zero.
    Aborts if neither side is empty.

    Input: 'img' tuple with a torch tensor of shape (1, rows, cols) and an
           integer, 'shft_int' no. of cols to shift along x-axis
           (0 <= shft_int <= cols)
    Output: tuple with a shifted torch tensor (same shape) and an unmodified
            integer
    Raises: ValueError if both the leading and the trailing columns contain
            non-zero pixels, or if shft_int is outside [0, cols]
    """
    plane = img[0][0]
    # The x-axis is the LAST dimension; rows and columns are kept apart so
    # that non-square images work as well.
    no_cols = plane.shape[1]
    if shft_int < 0 or shft_int > no_cols:
        raise ValueError('shft_int must lie between 0 and the number of columns.')

    split = no_cols - shft_int
    trailing, head = plane[:, split:], plane[:, :split]
    # any() instead of sum(): positive and negative pixels must not cancel
    if not bool(torch.any(trailing != 0)):
        # right shift: the empty trailing columns wrap to the left edge
        shifted = torch.cat([trailing, head], dim = 1)
    else:
        leading, tail = plane[:, :shft_int], plane[:, shft_int:]
        if bool(torch.any(leading != 0)):
            raise ValueError('Consider shifting along another axis.')
        # left shift: the empty leading columns wrap to the right edge
        shifted = torch.cat([tail, leading], dim = 1)

    # restore the leading channel dimension and pair with the untouched integer
    return (shifted.unsqueeze(0), img[1])
def compute_stats(self, preds, target):
    """
    Return one score per class 1 .. ``self.n_class``.

    ``preds`` is reduced to hard labels with an arg-max over dimension 1.
    For a class that does not occur in ``target`` the score is 1 when the
    hard labels never contain that class either, otherwise 0. For a class
    that does occur the score is ``2*tp / (2*tp + fp + fn)`` with the counts
    taken from ``self.get_stats`` (0.0 if the denominator is zero).
    """
    n_class = self.n_class
    hard_labels = torch.argmax(preds, dim=1)
    scores = torch.zeros(n_class, device=preds.device, dtype=torch.float32)

    for slot in range(n_class):
        cls = slot + 1
        if (target != cls).all():
            # no foreground class
            scores[slot] += int(bool((hard_labels != cls).all()))
            continue

        tp, fn, fp = self.get_stats(hard_labels, target, cls)
        denom = (2 * tp + fp + fn).to(torch.float)
        if torch.is_nonzero(denom):
            scores[slot] += (2 * tp).to(torch.float) / denom
        else:
            scores[slot] += 0.0

    return scores
def compute_stats_brats(self, p, y):
    """
    Return one score per output channel for the three label groups built
    from ``y``: ``y > 0``, ``y in {1, 3}`` and ``y == 3`` (in that order).

    ``p`` holds logits with channels along dimension 1; a channel predicts
    foreground where ``sigmoid(p) > 0.5``. For a channel whose target mask is
    empty the score is 1 when the prediction is empty too, otherwise 0. For a
    non-empty mask the score is ``2*tp / (2*tp + fp + fn)`` with the counts
    taken from ``self.get_stats`` (0.0 if the denominator is zero).

    Returns:
        float32 tensor of length ``self.n_class``; entry ``i`` belongs to
        channel ``i``.
    """
    scores = torch.zeros(self.n_class, device=p.device, dtype=torch.float32)
    p = (torch.sigmoid(p) > 0.5).int()
    y_wt, y_tc, y_et = y > 0, ((y == 1) + (y == 3)) > 0, y == 3
    y = torch.stack([y_wt, y_tc, y_et], dim=1)

    for i in range(self.n_class):
        p_i, y_i = p[:, i], y[:, i]
        if (y_i != 1).all():
            # no foreground class
            scores[i] += 1 if (p_i != 1).all() else 0
            continue

        tp, fn, fp = self.get_stats(p_i, y_i, 1)
        denom = (2 * tp + fp + fn).to(torch.float)
        score_cls = (2 * tp).to(torch.float) / denom if torch.is_nonzero(denom) else 0.0
        # Channels are enumerated from 0 here, so the slot is `i` itself;
        # `i - 1` would wrap channel 0 around to the last slot.
        scores[i] += score_cls

    return scores
def optimize_tabular(self, agent, trajectory_buffer, update_target=False):
    """
    Update the agent's tabular Q/C tables, entropy temperature and policy
    from the contents of ``trajectory_buffer``, without tracking gradients.

    Steps, in order:
      1. Fetch the buffer via ``sample(None, random_sample=False)`` and map
         every state to a concept index (arg-max of
         ``agent.concept_architecture(...)[0]``).
      2. Walk the trajectory backwards, accumulating the discounted return
         ``G`` and an importance weight, and update ``Q``/``C`` in the style
         of an incremental weighted-importance-sampling estimate.
      3. Push the new tables to the agent (``update_Q``), optionally call
         ``update_target``, then ``update_Alpha`` with the policy entropy.
      4. Search (at most 8 iterations) for a new policy whose per-state KL
         divergence from the current one stays within
         ``self.policy_divergence_limit``; apply it if found.

    Returns:
        dict with keys ``'Q_change'``, ``'entropy'``, ``'Alpha'``,
        ``'found_policy'`` and ``'max_dual'``.

    NOTE(review): the block nesting below was reconstructed from a
    whitespace-collapsed source — confirm it against the original file,
    in particular what belongs inside ``if torch.is_nonzero(C[S, A])`` and
    ``if N > 0``.
    """
    with torch.no_grad():
        N = len(trajectory_buffer.buffer)
        inner_states, outer_states, actions, action_distributions, rewards, dones, next_inner_states, \
            next_outer_states = trajectory_buffer.sample(None, random_sample=False)

        # Concept index of every visited state, plus the one following the last step.
        PS_s = agent.concept_architecture(inner_states, outer_states)[0]
        concepts = PS_s.argmax(1).detach().cpu().numpy()
        next_PS_s = agent.concept_architecture(
            next_inner_states[-1, :].view(1, -1),
            next_outer_states[-1, :, :, :].unsqueeze(0))[0]
        next_concept = next_PS_s.argmax(1).detach().cpu().numpy()
        next_concepts = np.concatenate([concepts[1:], next_concept])

        # Per-concept entropy of the current policy and its PS-weighted sum.
        PA_S, log_PA_S = agent.PA_S()
        HA_gS = -(PA_S * log_PA_S).sum(1)
        HA_S = (self.PS.view(-1) * HA_gS).sum()
        Alpha = agent.log_Alpha.exp().item()
        assert torch.isfinite(HA_S).all(), 'Alahuakbar'

        # PA_s = agent.second_level_architecture.actor(inner_states, outer_states)[0]
        # Per-step ratio: current policy probability over the stored action distribution.
        ratios = PA_S[concepts, actions] / action_distributions[np.arange(0, N), actions]
        if self.clip_ratios:
            ratios = ratios.clip(0.0, 1.0)
        assert torch.isfinite(ratios).all(), 'Alahuakbar 1'

        # Work on decayed copies of the agent's tables.
        Q = (1. - self.forgetting_factor) * agent.Q_table.detach().clone()
        C = (1. - self.forgetting_factor) * agent.C_table.detach().clone()
        assert torch.isfinite(Q).all(), 'Alahuakbar 2'
        assert torch.isfinite(C).all(), 'Alahuakbar 3'

        if N > 0:
            G = 0
            WIS_trajectory = 1
            # Backward pass from the last stored step to the first.
            for i in range(N - 1, -1, -1):
                S, A, R, WIS_step, nS = concepts[i], actions[i], rewards[
                    i], ratios[i], next_concepts[i]
                G = self.discount_factor * G + R
                if self.MC_entropy:
                    # Entropy bonus relative to the average policy entropy.
                    dH = HA_gS[nS] - HA_S
                    G += self.discount_factor * Alpha * dH
                C[S, A] = C[S, A] + WIS_trajectory
                if torch.is_nonzero(C[S, A]):
                    assert torch.isfinite(C[S, A]), 'Infinity and beyond!'
                    Q[S, A] = Q[S, A] + (WIS_trajectory / C[S, A]) * (G - Q[S, A])
                WIS_trajectory = WIS_trajectory * WIS_step
                if self.clip_ratios:
                    WIS_trajectory = WIS_trajectory.clip(0.0, 10.0)
                # A zero weight makes every earlier step contribute nothing.
                if not torch.is_nonzero(WIS_trajectory):
                    break

        dQ = (Q - agent.Q_table).pow(2).mean()
        agent.update_Q(Q, C)
        if update_target:
            agent.update_target(self.MC_update_rate)

        Pi = agent.Pi_table.detach().clone()
        log_Pi = torch.log(Pi)
        HA_gS = -(Pi * log_Pi).sum(1)
        HA_S = (self.PS.view(-1) * HA_gS).sum()
        assert torch.isfinite(HA_S).all(), 'Alahuakbar'

        # Optimize Alpha
        agent.update_Alpha(HA_S)

        # Optimize policy
        Alpha = agent.log_Alpha.exp().item()
        duals = (1e-3) * torch.ones_like(self.PS.view(-1, 1))
        found_policy = False
        iters_left = 8
        while not found_policy and iters_left > 0:
            Q_adjusted = (Q + Alpha * duals * log_Pi) / (1. + duals)
            Pi_new = torch.exp(Q_adjusted / (Alpha + 1e-10))
            Pi_new = Pi_new / Pi_new.sum(1, keepdim=True)
            log_Pi_new = torch.log(Pi_new + 1e-10)
            KL_div = (Pi_new * (log_Pi_new - log_Pi)).sum(1, keepdim=True)
            valid_policies = KL_div <= self.policy_divergence_limit
            if torch.all(valid_policies):
                found_policy = True
            else:
                iters_left -= 1
                # Multiply the dual by 10 only for rows that violate the limit.
                duals = 10**(1. - valid_policies.float()) * duals

        if found_policy:
            agent.update_policy(Pi_new)

        metrics = {
            'Q_change': dQ.item(),
            'entropy': HA_S.item(),
            'Alpha': Alpha,
            'found_policy': float(found_policy),
            'max_dual': duals.max().item(),
        }
        return metrics
def optimize_tabular(self, agent, trajectory_buffer):
    """
    Update the agent's tabular Q/C tables and entropy temperature from the
    contents of ``trajectory_buffer``, without tracking gradients.

    Steps, in order:
      1. Fetch the buffer via ``sample(None, random_sample=False)`` and map
         every state to a concept index (arg-max of
         ``agent.concept_architecture(...)[0]``).
      2. Walk the trajectory backwards, accumulating an entropy-augmented
         discounted return ``G`` and an importance weight, updating ``Q``/``C``
         in the style of an incremental weighted-importance-sampling estimate
         and pushing the tables to the agent along the way.
      3. Store ``Q`` together with ``0.9*C`` and call ``update_Alpha`` with
         the resulting policy entropy.

    Returns:
        dict with keys ``'Q_change'``, ``'entropy'`` and ``'Alpha'``.

    NOTE(review): the block nesting below was reconstructed from a
    whitespace-collapsed source — confirm it against the original file,
    in particular what belongs inside ``if torch.is_nonzero(C[S,A])`` and
    ``if N > 0``.
    """
    with torch.no_grad():
        N = len(trajectory_buffer.buffer)
        inner_states, outer_states, actions, rewards, dones, next_inner_states, \
            next_outer_states = trajectory_buffer.sample(None, random_sample=False)

        # Concept index of every visited state, plus the one following the last step.
        PS_s = agent.concept_architecture(inner_states, outer_states)[0]
        concepts = PS_s.argmax(1).detach().cpu().numpy()
        next_PS_s = agent.concept_architecture(next_inner_states[-1,:].view(1,-1), next_outer_states[-1,:,:,:].unsqueeze(0))[0]
        next_concept = next_PS_s.argmax(1).detach().cpu().numpy()
        next_concepts = np.concatenate([concepts[1:], next_concept])

        # Per-concept entropy of the current policy and its PS-weighted sum.
        PA_S, log_PA_S = agent.PA_S()
        HA_gS = -(PA_S * log_PA_S).sum(1)
        HA_S = (self.PS.view(-1) * HA_gS).sum()
        Alpha = agent.log_Alpha.exp().item()
        assert torch.isfinite(HA_S).all(), 'Alahuakbar'

        # Action probabilities from the second-level actor, used as the denominator of the ratios.
        PA_s = agent.second_level_architecture.actor(inner_states, outer_states)[0]
        ratios = PA_S[concepts, actions] / PA_s[np.arange(0,N),actions]
        assert torch.isfinite(ratios).all(), 'Alahuakbar 1'

        Q = agent.Q_table.detach().clone()
        C = agent.C_table.detach().clone()
        Q0 = Q.clone()  # snapshot for the 'Q_change' metric
        assert torch.isfinite(Q).all(), 'Alahuakbar 2'
        assert torch.isfinite(C).all(), 'Alahuakbar 3'

        if N > 0:
            G = 0
            WIS_trajectory = 1
            # Backward pass from the last stored step to the first.
            for i in range(N-1, -1, -1):
                S, A, R, nS = concepts[i], actions[i], rewards[i], next_concepts[i]
                dH = HA_gS[nS] - HA_S
                G = self.discount_factor * G + R + self.discount_factor * Alpha * dH
                C[S,A] = C[S,A] + WIS_trajectory
                if torch.is_nonzero(C[S,A]):
                    assert torch.isfinite(C[S,A]), 'Infinity and beyond!'
                    Q[S,A] = Q[S,A] + (WIS_trajectory/C[S,A]) * (G - Q[S,A])
                    agent.update_Q(Q, C)
                    PA_S = agent.PA_S()[0]
                    # NOTE(review): log_PA_S is the value fetched before the loop and is
                    # not refreshed together with PA_S — confirm this is intended.
                    HA_gS = -(PA_S * log_PA_S).sum(1)
                    HA_S = (self.PS.view(-1) * HA_gS).sum()
                WIS_step = PA_S[S,A] / PA_s[i,A]
                WIS_trajectory = WIS_trajectory * WIS_step
                # A zero weight makes every earlier step contribute nothing.
                if not torch.is_nonzero(WIS_trajectory):
                    break

        agent.update_Q(Q, 0.9*C)
        dQ = (Q - Q0).pow(2).mean()

        PA_S, log_PA_S = agent.PA_S()
        HA_gS = -(PA_S * log_PA_S).sum(1)
        HA_S = (self.PS.view(-1) * HA_gS).sum()
        assert torch.isfinite(HA_S).all(), 'Alahuakbar'

        # Optimize Alpha
        agent.update_Alpha(HA_S)

        metrics = {
            'Q_change': dQ.item(),
            'entropy': HA_S.item(),
            'Alpha': Alpha,
        }
        return metrics
def tensor_general_ops(self):
    """
    Invoke a fixed list of general tensor query, creation, copy and in-place
    operations on a few small tensors and hand all results to ``len``.

    The order of the entries matters: several of them mutate ``a``, ``w`` or
    ``x`` in place. Entries that are commented out are kept as a record of
    operations deliberately left out of the list.

    NOTE(review): the built-in ``len`` accepts a single argument in eager
    Python, so this call shape presumably relies on the method being
    compiled/scripted rather than executed directly — confirm with the
    caller before running it eagerly.
    """
    a = torch.randn(4)
    b = torch.tensor([1.5])  # single element, as required by torch.is_nonzero
    x = torch.ones((2, ))
    c = torch.randn(4, dtype=torch.cfloat)  # complex tensor for .real / .imag
    w = torch.rand(4, 4, 4, 4)
    v = torch.rand(4, 4, 4, 4)
    return len(
        # torch.is_tensor(a),
        # torch.is_storage(a),
        torch.is_complex(a),
        torch.is_conj(a),
        torch.is_floating_point(a),
        torch.is_nonzero(b),
        # torch.set_default_dtype(torch.float32),
        # torch.get_default_dtype(),
        # torch.set_default_tensor_type(torch.DoubleTensor),
        torch.numel(a),
        # torch.set_printoptions(),
        # torch.set_flush_denormal(False),
        # https://pytorch.org/docs/stable/tensors.html#tensor-class-reference
        # x.new_tensor([[0, 1], [2, 3]]),
        x.new_full((3, 4), 3.141592),
        x.new_empty((2, 3)),
        x.new_ones((2, 3)),
        x.new_zeros((2, 3)),
        x.is_cuda,
        x.is_quantized,
        x.is_meta,
        x.device,
        x.dim(),
        c.real,
        c.imag,
        # x.backward(),
        x.clone(),
        w.contiguous(),
        w.contiguous(memory_format=torch.channels_last),
        w.copy_(v),
        w.copy_(1),
        w.copy_(0.5),
        x.cpu(),
        # x.cuda(),
        # x.data_ptr(),
        x.dense_dim(),
        w.fill_diagonal_(0),
        w.element_size(),
        w.exponential_(),
        w.fill_(0),
        w.geometric_(0.5),
        a.index_fill(0, torch.tensor([0, 2]), 1),
        a.index_put_([torch.argmax(a)], torch.tensor(1.0)),
        a.index_put([torch.argmax(a)], torch.tensor(1.0)),
        w.is_contiguous(),
        c.is_complex(),
        w.is_conj(),
        w.is_floating_point(),
        w.is_leaf,
        w.is_pinned(),
        w.is_set_to(w),
        # w.is_shared,
        w.is_coalesced(),
        w.coalesce(),
        w.is_signed(),
        w.is_sparse,
        torch.tensor([1]).item(),
        x.log_normal_(),
        # x.masked_scatter_(),
        # x.masked_scatter(),
        # w.normal(),
        w.numel(),
        # w.pin_memory(),
        # w.put_(0, torch.tensor([0, 1], w)),
        x.repeat(4, 2),
        a.clamp_(0),
        a.clamp(0),
        a.clamp_min(0),
        a.hardsigmoid_(),
        a.hardsigmoid(),
        a.hardswish_(),
        a.hardswish(),
        a.hardtanh_(),
        a.hardtanh(),
        a.leaky_relu_(),
        a.leaky_relu(),
        a.relu_(),
        a.relu(),
        a.resize_as_(a),
        a.type_as(a),
        a._shape_as_tensor(),
        a.requires_grad_(False),
    )
def get_m(self, observations, comm_actions, prev_actions=None):
    """
    Build the shared communication medium for one step.

    For every batch column the medium row is the observation (and, unless
    ``self.medium_type == 'obs_only'``, the previous action) of the agent
    with the highest ``comm_actions`` value, followed by that agent's
    1-based id. Two special cases override the row afterwards:

    * every agent's ``comm_actions`` value is below ``0.00001``: the row is
      all zeros (id 0);
    * every agent's ``comm_actions`` value is above ``0.99999``: the payload
      is zero and the id is ``self.num_agents + 1``.

    Args:
        observations: tensor indexed as ``[agent, batch, feature]``.
        comm_actions: tensor indexed as ``[agent, batch, 0]``.
        prev_actions: tensor indexed as ``[agent, batch, feature]``; required
            unless ``self.medium_type == 'obs_only'``.

    Returns:
        A new leaf tensor of shape ``(1, batch, payload + 1)`` with
        ``requires_grad=True``, using ``observations``' dtype and device.
    """
    # Compare strings by value; identity (`is`) only works by accident of
    # string interning.
    obs_only = self.medium_type == 'obs_only'
    dtype, device = observations.dtype, observations.device
    batch_size, obs_dim = observations.shape[1], observations.shape[2]
    payload_dim = obs_dim if obs_only else obs_dim + prev_actions.shape[2]

    medium = K.zeros((1, batch_size, payload_dim + 1), dtype=dtype, device=device)

    # Agent with the largest communication action per batch column.
    granted_agent = comm_actions.argmax(dim=0)[:, 0]
    for i in range(self.num_agents):
        granted = granted_agent == i
        parts = [observations[i:i + 1][:, granted, :]]
        if not obs_only:
            parts.append(prev_actions[i:i + 1][:, granted, :])
        # Last feature carries the 1-based id of the granted agent.
        parts.append((i + 1) * K.ones((1, int(granted.sum()), 1),
                                      dtype=dtype, device=device))
        medium[:, granted, :] = K.cat(parts, dim=-1)

    # Nobody asked for the medium: blank row with id 0.
    nobody = ((comm_actions < 0.00001).sum(dim=0) == self.num_agents)[:, 0]
    if nobody.any():
        medium[:, nobody, :] = K.zeros((1, 1, payload_dim + 1),
                                       dtype=dtype, device=device)

    # Everybody asked for the medium: blank payload with id num_agents + 1.
    everybody = ((comm_actions > 0.99999).sum(dim=0) == self.num_agents)[:, 0]
    if everybody.any():
        collision = K.zeros((1, 1, payload_dim + 1), dtype=dtype, device=device)
        collision[..., -1] = self.num_agents + 1
        medium[:, everybody, :] = collision

    # Fresh leaf tensor; the recommended way to copy-construct from a tensor.
    return medium.clone().detach().requires_grad_(True)