def step(self,
         batch_idx: torch.Tensor,
         emb_t: torch.Tensor,
         is_shift: torch.Tensor,
         is_reduce: torch.Tensor,
         stack: torch.Tensor,
         stack_ptr: torch.Tensor):
    # Batched shift and reduce operations.

    # Shift: push the current embedding onto the stack of every example
    # whose action at this timestep is SHIFT.
    if is_shift.any():
        shift_stack = stack[is_shift]
        shift_stack_ptr = stack_ptr[is_shift]
        idx = torch.arange(shift_stack.size(0),
                           dtype=shift_stack_ptr.dtype,
                           device=shift_stack_ptr.device)
        shift_stack[idx, shift_stack_ptr] = emb_t[is_shift]
        stack[is_shift] = shift_stack
        stack_ptr[is_shift] = shift_stack_ptr + 1

    # Reduce: pop the top two elements and push their composition.
    if is_reduce.any():
        reduce_stack = stack[is_reduce]
        reduce_stack_ptr = stack_ptr[is_reduce]
        idx = torch.arange(reduce_stack.size(0),
                           dtype=reduce_stack_ptr.dtype,
                           device=reduce_stack_ptr.device)
        r_child = reduce_stack[idx, reduce_stack_ptr - 1]
        l_child = reduce_stack[idx, reduce_stack_ptr - 2]
        parent = self.op(l_child, r_child)
        reduce_stack[idx, reduce_stack_ptr - 2] = parent
        stack[is_reduce] = reduce_stack
        stack_ptr[is_reduce] = reduce_stack_ptr - 1

    return stack, stack_ptr
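# Why the write-back pattern above (`stack[is_shift] = shift_stack`) is needed:
# boolean indexing returns a copy, not a view, so in-place edits to the slice
# never reach the original tensor. A minimal illustration (values made up):
import torch

stack = torch.zeros(2, 3)
sel = torch.tensor([True, False])
selected = stack[sel]      # advanced indexing -> a copy
selected[0, 0] = 1.0       # does NOT modify `stack`
stack[sel] = selected      # the explicit copy-back is what commits the update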
def _recall_at(self, predictions: Batch, targets: Batch,
               matches: torch.Tensor) -> Dict[int, float]:
    # matches.shape = [num_relations_pred, num_relations_gt]
    # matches.argmax(dim=0) returns an arbitrary index if no True value is
    # found, so we use matches.any(dim=0) to ignore those cases.
    # We must also account for the row offset in the matches matrix.
    gt_retrieved = matches.any(dim=0)
    offset = (predictions.n_edges.cumsum(dim=0).repeat_interleave(
        targets.n_edges) - predictions.n_edges[0])
    gt_retrieved_rank = matches.int().argmax(dim=0) - offset

    # [K, E_t]
    gt_retrieved_at = (gt_retrieved_rank[None, :] <
                       self.steps[:, None]) & gt_retrieved[None, :]

    # [K, num_graphs]
    gt_relation_to_graph_assignment = targets.batch[
        targets.relation_indexes[0]]
    recall_at_per_graph = scatter_(
        "mean",
        gt_retrieved_at.float(),
        index=gt_relation_to_graph_assignment,
        dim=1,
        dim_size=targets.num_graphs,
    )

    # [K]
    recall_at = recall_at_per_graph.mean(dim=1)

    return {k: v for k, v in zip(self.steps.numpy(), recall_at.numpy())}
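# The argmax caveat noted above, in isolation: for a column with no True entry
# the argmax result is meaningless, which is why `matches.any(dim=0)` is used
# as a validity mask. Toy values for illustration:
m = torch.tensor([[False, False],
                  [True,  False]])
rank = m.int().argmax(dim=0)   # second column has no match; its rank is garbage
valid = m.any(dim=0)           # tensor([ True, False]) masks those columns out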
def forward(self, q: Tensor, sequence_mask: Tensor) -> Tensor:
    """ Produce a single classification output for a sequence of vectors.

    Parameters
    ----------
    q : [T, B, D]
        Hidden activations after central encoder.
    sequence_mask : [T, B, 1]
        Positive mask for zeroing out padded vectors between operations.

    Returns
    -------
    classification : [B]
        Probability of this particle existing in the data.
    """
    # ------------------------------------------------------------
    # Collapse the sequence vectors into a single vector with a
    # masked mean over the time dimension.
    # hidden        : [1, B, D]
    # sequence_mask : [1, B, 1]
    # ------------------------------------------------------------
    hidden = (q * sequence_mask).sum(0, keepdim=True) / sequence_mask.sum(
        0, keepdim=True)
    sequence_mask = sequence_mask.any(dim=0, keepdim=True)

    # ------------------------------------------------------------
    # Run through the linear layer stack and output the result.
    # classification : [B]
    # ------------------------------------------------------------
    hidden = self.hidden_layers(hidden, sequence_mask).squeeze()
    classification = self.output_layer(hidden).squeeze()

    return classification
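# The masked mean pooling above in isolation; shapes follow the docstring,
# values are illustrative:
import torch

T, B, D = 4, 2, 8
q = torch.randn(T, B, D)
mask = torch.tensor([[1., 1.],
                     [1., 0.],
                     [1., 0.],
                     [0., 0.]]).unsqueeze(-1)   # [T, B, 1]
pooled = (q * mask).sum(0) / mask.sum(0)        # [B, D], mean over unpadded steps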
def rescale(
    self,
    tensor: torch.Tensor,
    mask: torch.Tensor,
    image_name: str,
) -> torch.Tensor:
    # The tensor is cloned as in-place operations will be used
    array = tensor.clone().float().numpy()
    mask = mask.numpy()
    if not mask.any():
        message = (f'Rescaling image "{image_name}" not possible'
                   ' because the mask to compute the statistics is empty')
        warnings.warn(message, RuntimeWarning)
        return tensor
    values = array[mask]
    cutoff = np.percentile(values, self.percentiles)
    np.clip(array, *cutoff, out=array)
    if self.in_min_max is None:
        in_min, in_max = array.min(), array.max()
    else:
        in_min, in_max = self.in_min_max
    in_range = in_max - in_min
    if in_range == 0:  # should this be compared using a tolerance?
        message = (f'Rescaling image "{image_name}" not possible'
                   ' because all the intensity values are the same')
        warnings.warn(message, RuntimeWarning)
        return tensor
    array -= in_min
    array /= in_range
    out_range = self.out_max - self.out_min
    array *= out_range
    array += self.out_min
    return torch.as_tensor(array)
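# A minimal standalone sketch of the same clip-then-rescale recipe, outside the
# class; `percentiles`, `out_min`, and `out_max` stand in for the instance
# attributes used above and are assumed defaults, not values from the source.
import numpy as np
import torch

def rescale_intensity(tensor: torch.Tensor,
                      percentiles=(0.5, 99.5),
                      out_min: float = 0.0,
                      out_max: float = 1.0) -> torch.Tensor:
    array = tensor.clone().float().numpy()
    lo, hi = np.percentile(array, percentiles)    # robust input range
    np.clip(array, lo, hi, out=array)
    in_min, in_max = array.min(), array.max()
    if in_max == in_min:
        return tensor                             # nothing to rescale
    array = (array - in_min) / (in_max - in_min)  # normalize to [0, 1]
    return torch.as_tensor(array * (out_max - out_min) + out_min)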
def forward(self,
            prev_output_tokens: Tensor,
            encoder_out: Tensor,
            encoder_padding_mask: Tensor,
            incremental_state: Optional[Dict[str, Dict[str, Tensor]]] = None):
    # embed positions
    positions = self.embed_positions(
        prev_output_tokens,
        incremental_state=incremental_state,
    ) if self.embed_positions is not None else None

    if incremental_state is not None:
        # Incremental decoding only needs the most recent timestep.
        prev_output_tokens = prev_output_tokens[:, -1:]
        if positions is not None:
            positions = positions[:, -1:]

    # embed tokens and positions
    x = self.embed_scale * self.embed_tokens(prev_output_tokens)

    if positions is not None:
        x += positions
    x = F.dropout(x, p=self.dropout, training=self.training)

    # B x T x C -> T x B x C
    # The tensor needs to be copied when transposed because
    # fused dropout is not capable of handling strided data.
    if self.fuse_dropout_add:
        x = x.transpose(0, 1).contiguous()
    else:
        x = x.transpose(0, 1)
    attn = None

    # decoder layers
    for layer in self.layers:
        x, attn = layer(
            x,
            encoder_out,
            encoder_padding_mask if encoder_padding_mask.any() else None,
            incremental_state,
        )

    if self.normalize:
        x = self.layer_norm(x)

    # T x B x C -> B x T x C
    x = x.transpose(0, 1)

    if self.adaptive_softmax is None:
        # project back to size of vocabulary
        x = F.linear(x, self.embed_out)

    return x, attn
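# The `.contiguous()` call above matters because transposing only rewrites
# strides; a fused kernel that expects dense memory needs an actual copy.
# Quick check with illustrative shapes:
import torch

x = torch.randn(2, 3, 4)                   # B x T x C
t = x.transpose(0, 1)                      # T x B x C view over the same storage
print(t.is_contiguous())                   # False: strided view
print(t.contiguous().is_contiguous())      # True: dense copy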
def forward(self,  # pylint: disable=arguments-differ
            source: Dict[str, torch.Tensor],
            target: Dict[str, torch.Tensor],
            reset: torch.Tensor,
            metadata: List[Dict[str, Any]],
            mention_type: torch.Tensor = None,
            raw_entity_ids: Dict[str, torch.Tensor] = None,
            entity_ids: Dict[str, torch.Tensor] = None,
            parent_ids: Dict[str, torch.Tensor] = None,
            relations: Dict[str, torch.Tensor] = None,
            shortlist: Dict[str, torch.Tensor] = None,
            shortlist_inds: torch.Tensor = None,
            alias_copy_inds: torch.Tensor = None) -> Dict[str, torch.Tensor]:
    # Tensorize the alias_database - this will only perform the operation once.
    alias_database = metadata[0]['alias_database']
    alias_database.tensorize(vocab=self.vocab)

    # Reset the model if needed
    if reset.any() and (self._state is not None):
        for layer in range(self._num_layers):
            h, c = self._state['layer_%i' % layer]
            h[:, reset, :] = torch.zeros_like(h[:, reset, :])
            c[:, reset, :] = torch.zeros_like(c[:, reset, :])
            self._state['layer_%i' % layer] = (h, c)
    self._recent_entities.reset(reset)

    if entity_ids is not None:
        output_dict = self._forward_loop(
            source=source,
            target=target,
            alias_database=alias_database,
            mention_type=mention_type,
            raw_entity_ids=raw_entity_ids,
            entity_ids=entity_ids,
            parent_ids=parent_ids,
            relations=relations,
            shortlist=shortlist,
            shortlist_inds=shortlist_inds,
            alias_copy_inds=alias_copy_inds)
    else:
        # TODO: Figure out what we want here - probably to do some kind of
        # inference on entities / mention types.
        output_dict = {}

    return output_dict
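# The hidden-state reset pattern above in isolation: zero only the batch rows
# whose sequences restarted, leaving the others intact. Shapes are illustrative.
import torch

h = torch.randn(1, 4, 8)                        # [layers, batch, hidden]
reset = torch.tensor([True, False, True, False])
h[:, reset, :] = torch.zeros_like(h[:, reset, :])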
def forward(self,
            prev_output_tokens: Tensor,
            encoder_out: Tensor,
            encoder_padding_mask: Tensor,
            incremental_state: Optional[Dict[str, Dict[str, Tensor]]] = None):
    # embed positions
    positions = self.embed_positions(
        prev_output_tokens,
        incremental_state=incremental_state,
    ) if self.embed_positions is not None else None

    if incremental_state is not None:
        prev_output_tokens = prev_output_tokens[:, -1:]
        if positions is not None:
            positions = positions[:, -1:]

    # embed tokens and positions
    x = self.embed_scale * self.embed_tokens(prev_output_tokens)

    if positions is not None:
        x += positions

    if self.training:
        # Use the platform-specific dropout implementation.
        if self.platform == "npu":
            x, _, _ = torch.dropoutV2(x, self.seed, p=self.prob)
        elif self.platform == "gpu":
            x = self.dropout(x)

    # B x T x C -> T x B x C
    x = x.transpose(0, 1)
    attn = None

    # decoder layers
    for layer in self.layers:
        x, attn = layer(
            x,
            encoder_out,
            encoder_padding_mask if encoder_padding_mask.any() else None,
            incremental_state,
        )

    # T x B x C -> B x T x C
    x = x.transpose(0, 1)

    # project back to size of vocabulary
    x = F.linear(x, self.embed_out)

    return x, attn
def offset_loss(
    preds: torch.Tensor,
    labels: torch.Tensor,
    is_center: torch.Tensor,
    n_objects: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    :param preds: N2HW (float32) predicted center offsets
    :param labels: N2HW (float32) ground-truth center offsets
    :param is_center: NKHW (bool) center mask for K-way classification
    :param n_objects: optional precomputed number of objects
    """
    if n_objects is None:
        n_objects = is_center.sum()
    if not n_objects:
        return 0

    # Collapse the K class channels into a single "any center" channel.
    is_center = is_center.any(1).unsqueeze(1)
    # We could omit `* is_center` for labels if we assume they're 0 except for centers.
    l1_loss = torch.nn.functional.l1_loss(preds * is_center,
                                          labels * is_center,
                                          reduction="sum")
    return l1_loss / n_objects
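# Example invocation of offset_loss with dummy tensors; the shapes (N=1, K=2,
# H=W=4) and the single annotated center are illustrative assumptions only.
import torch

preds = torch.zeros(1, 2, 4, 4)
labels = torch.rand(1, 2, 4, 4)
is_center = torch.zeros(1, 2, 4, 4, dtype=torch.bool)
is_center[0, 0, 1, 2] = True                    # one object center
loss = offset_loss(preds, labels, is_center)    # L1 error at centers / n_objects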
def forward(self,  # pylint: disable=arguments-differ
            source: Dict[str, torch.Tensor],
            target: Dict[str, torch.Tensor],
            reset: torch.Tensor,
            entity_ids: torch.Tensor = None,
            shortlist: Dict[str, torch.Tensor] = None,
            shortlist_inds: torch.Tensor = None,
            alias_copy_inds: torch.Tensor = None,
            alias_tokens: torch.Tensor = None,
            alias_inds: torch.Tensor = None) -> Dict[str, torch.Tensor]:
    alias_tokens = alias_tokens['tokens']
    # Inds have fixed size and don't get truncated on split, so truncate now.
    alias_inds = alias_inds[:, :, :alias_tokens.shape[2]]

    # Reset the model if needed
    if reset.any() and (self._state is not None):
        for layer in range(self._num_layers):
            h, c = self._state['layer_%i' % layer]
            h[:, reset, :] = torch.zeros_like(h[:, reset, :])
            c[:, reset, :] = torch.zeros_like(c[:, reset, :])
            self._state['layer_%i' % layer] = (h, c)

    if entity_ids is not None:
        output_dict = self._forward_loop(
            source=source,
            target=target,
            alias_copy_inds=alias_copy_inds,
            alias_tokens=alias_tokens,
            alias_inds=alias_inds)
    else:
        # TODO: Figure out what we want here - probably to do some kind of
        # inference on entities / mention types.
        output_dict = {}

    return output_dict
def sample(self,
           source: Dict[str, torch.Tensor],
           target: Dict[str, torch.Tensor],
           reset: torch.Tensor,
           metadata: Dict[str, Any],
           alias_copy_inds: torch.Tensor,
           shortlist: Dict[str, torch.Tensor] = None,
           **kwargs) -> Dict[str, Any]:
    # **kwargs intended to eat the other fields if they are provided.
    """
    Sampling annotations for the generative model. Note that unlike forward, this
    function expects inputs from a **generative** dataset reader, not a
    **discriminative** one.
    """
    # Tensorize the alias_database - this will only perform the operation once.
    alias_database = metadata[0]['alias_database']
    alias_database.tensorize(vocab=self.vocab)

    # Reset the model if needed
    if reset.any() and (self._state is not None):
        for layer in range(self._num_layers):
            h, c = self._state['layer_%i' % layer]
            h[:, reset, :] = torch.zeros_like(h[:, reset, :])
            c[:, reset, :] = torch.zeros_like(c[:, reset, :])
            self._state['layer_%i' % layer] = (h, c)
    self._recent_entities.reset(reset)

    mask = get_text_field_mask(target).byte()
    # We encode the target tokens (**not** source) since the discriminative model
    # makes predictions on the current token, but the generative model expects
    # labels for the **next** (i.e. target) token!
    encoded, *_ = self._encode_source(target['tokens'])
    splits = [self.token_embedding_dim] + [self.entity_embedding_dim] * 2
    encoded_token, encoded_head, encoded_relation = encoded.split(splits, dim=-1)

    # Compute new mention logits
    mention_logits = self._fc_mention_type(encoded_token)
    mention_probs = F.softmax(mention_logits, dim=-1)
    mention_type = parallel_sample(mention_probs)
    mention_logp = mention_probs.gather(-1, mention_type.unsqueeze(-1)).log()
    mention_logp[~mask] = 0
    mention_logp = mention_logp.sum()

    # Compute entity logits
    new_entity_mask = mention_type.eq(1)
    new_entity_logits = self._new_entity_logits(encoded_head + encoded_relation,
                                                shortlist)
    if self._use_shortlist:
        # If using the shortlist, then samples are indexed w.r.t. the shortlist
        # and entity_ids must be looked up.
        shortlist_mask = get_text_field_mask(shortlist)
        new_entity_probs = masked_softmax(new_entity_logits, shortlist_mask)
        shortlist_inds = torch.zeros_like(mention_type)
        # Some sequences may be full of padding, in which case the shortlist
        # is empty.
        not_just_padding = shortlist_mask.byte().any(-1)
        shortlist_inds[not_just_padding] = parallel_sample(
            new_entity_probs[not_just_padding])
        shortlist_inds[~new_entity_mask] = 0
        _new_entity_logp = new_entity_probs.gather(
            -1, shortlist_inds.unsqueeze(-1)).log()
        new_entity_samples = shortlist['entity_ids'].gather(1, shortlist_inds)
    else:
        # If not using the shortlist, then samples are indexed w.r.t. the global
        # vocab.
        new_entity_probs = F.softmax(new_entity_logits, dim=-1)
        new_entity_samples = parallel_sample(new_entity_probs)
        _new_entity_logp = new_entity_probs.gather(
            -1, new_entity_samples.unsqueeze(-1)).log()
        shortlist_inds = None
    # Zero out masked tokens and non-new entity predictions
    _new_entity_logp[~mask] = 0
    _new_entity_logp[~new_entity_mask] = 0
    new_entity_logp = _new_entity_logp.sum()

    # Start filling in the entity ids
    entity_ids = torch.zeros_like(target['tokens'])
    entity_ids[new_entity_mask] = new_entity_samples[new_entity_mask]

    # ...UGH we also need the raw ids - remapping time
    raw_entity_ids = torch.zeros_like(target['tokens'])
    for *index, entity_id in nested_enumerate(entity_ids.tolist()):
        token = self.vocab.get_token_from_index(entity_id, 'entity_ids')
        raw_entity_id = self.vocab.get_token_index(token, 'raw_entity_ids')
        raw_entity_ids[tuple(index)] = raw_entity_id

    # Derived mentions need to be computed sequentially.
    parent_ids = torch.zeros_like(target['tokens']).unsqueeze(-1)
    derived_entity_mask = mention_type.eq(2)
    derived_entity_logp = 0.0

    sequence_length = target['tokens'].shape[1]
    for i in range(sequence_length):
        current_mask = derived_entity_mask[:, i] & mask[:, i]

        # ------------------- SAMPLE PARENTS ---------------------

        # Update recent entities with **current** entity only
        current_entity_id = entity_ids[:, i].unsqueeze(1)
        candidate_ids, candidate_mask = self._recent_entities(current_entity_id)

        # If no mentions are derived, there is no point continuing after the
        # recent entities have been updated.
        if not current_mask.any():
            continue

        # Otherwise we proceed
        candidate_embeddings = self._entity_embedder(candidate_ids)

        # Compute logits w.r.t. the **current** hidden state only
        current_head_encoding = encoded_head[:, i].unsqueeze(1)
        selection_logits = torch.bmm(current_head_encoding,
                                     candidate_embeddings.transpose(1, 2))
        selection_probs = masked_softmax(selection_logits, candidate_mask)

        # Only sample if there is at least one viable candidate (e.g. if a
        # sampling distribution has no probability mass we cannot sample from
        # it). Return zero as the parent for non-viable distributions.
        viable_candidate_mask = candidate_mask.any(-1).squeeze()
        _parent_ids = torch.zeros_like(current_entity_id)
        parent_logp = torch.zeros_like(current_entity_id, dtype=torch.float32)
        if viable_candidate_mask.any():
            viable_candidate_ids = candidate_ids[viable_candidate_mask]
            viable_candidate_probs = selection_probs[viable_candidate_mask]
            viable_parent_samples = parallel_sample(viable_candidate_probs)
            viable_logp = viable_candidate_probs.gather(
                -1, viable_parent_samples.unsqueeze(-1)).log()
            viable_parent_ids = viable_candidate_ids.gather(-1, viable_parent_samples)
            _parent_ids[viable_candidate_mask] = viable_parent_ids
            parent_logp[viable_candidate_mask] = viable_logp.squeeze(-1)

        parent_ids[current_mask, i] = _parent_ids[current_mask]  # TODO: Double-check
        derived_entity_logp += parent_logp[current_mask].sum()

        # ---------------------- SAMPLE RELATION -----------------------------

        # Lookup sampled parent ids in the knowledge graph
        indices, parent_ids_list, relations_list, tail_ids_list = \
            self._knowledge_graph_lookup(_parent_ids)
        relation_embeddings = [self._relation_embedder(r) for r in relations_list]

        # Sample tail ids
        current_relation_encoding = encoded_relation[:, i].unsqueeze(1)
        _raw_tail_ids = torch.zeros_like(_parent_ids).squeeze(-1)
        _tail_ids = torch.zeros_like(_parent_ids).squeeze(-1)
        for index, relation_embedding, tail_id_lookup in zip(
                indices, relation_embeddings, tail_ids_list):
            # Compute the score for each relation w.r.t. the current encoding.
            # NOTE: In the loss code, index has a slice. We don't need that here
            # since there is always a **single** parent.
            logits = torch.mv(relation_embedding, current_relation_encoding[index])
            # Convert to probability
            tail_probs = F.softmax(logits, dim=-1)
            # Sample
            tail_sample = torch.multinomial(tail_probs, 1)
            # Get logp. Ignoring the current_mask here is **super** dodgy, but
            # since we forced null parents to zero we shouldn't be accumulating
            # probabilities for unused predictions.
            tail_logp = tail_probs.gather(-1, tail_sample).log()
            # Sum is redundant, just need it to make logp a scalar
            derived_entity_logp += tail_logp.sum()

            # Map back to raw id
            raw_tail_id = tail_id_lookup[tail_sample]
            # Convert raw id to id
            tail_id_string = self.vocab.get_token_from_index(raw_tail_id.item(),
                                                             'raw_entity_ids')
            tail_id = self.vocab.get_token_index(tail_id_string, 'entity_ids')

            _raw_tail_ids[index[:-1]] = raw_tail_id
            _tail_ids[index[:-1]] = tail_id

        raw_entity_ids[current_mask, i] = _raw_tail_ids[current_mask]  # TODO: Double-check
        entity_ids[current_mask, i] = _tail_ids[current_mask]  # TODO: Double-check

        self._recent_entities.insert(_tail_ids, current_mask)

        # --------------------- CONTINUE MENTIONS ----------------------------

        continue_mask = mention_type[:, i].eq(3) & mask[:, i]
        if not continue_mask.any() or i == 0:
            continue
        raw_entity_ids[continue_mask, i] = raw_entity_ids[continue_mask, i - 1]
        entity_ids[continue_mask, i] = entity_ids[continue_mask, i - 1]
        parent_ids[continue_mask, i] = parent_ids[continue_mask, i - 1]
        if self._use_shortlist:
            shortlist_inds[continue_mask, i] = shortlist_inds[continue_mask, i - 1]
        alias_copy_inds[continue_mask, i] = alias_copy_inds[continue_mask, i - 1]

    # Lastly, because entities won't always match the true entity ids,
    # we need to zero out any alias copy ids that won't be valid.
    if 'raw_entity_ids' in kwargs:
        true_raw_entity_ids = kwargs['raw_entity_ids']['raw_entity_ids']
        invalid_id_mask = ~true_raw_entity_ids.eq(raw_entity_ids)
        alias_copy_inds[invalid_id_mask] = 0

    # Pass denotes fields that are passed directly from input to output.
    sample = {
        'source': source,  # Pass
        'target': target,  # Pass
        'reset': reset,  # Pass
        'metadata': metadata,  # Pass
        'mention_type': mention_type,
        'raw_entity_ids': {'raw_entity_ids': raw_entity_ids},
        'entity_ids': {'entity_ids': entity_ids},
        'parent_ids': {'entity_ids': parent_ids},
        'relations': {'relations': None},  # We aren't using them - eventually should remove entirely
        'shortlist': shortlist,  # Pass
        'shortlist_inds': shortlist_inds,
        'alias_copy_inds': alias_copy_inds
    }
    logp = mention_logp + new_entity_logp + derived_entity_logp
    return {'sample': sample, 'logp': logp}
def mask_to_instances(mask: Tensor) -> Tensor:
    r"""Given a binary mask, extract a new mask indicating contiguous instances.
    The resultant mask will use ``0`` to indicate background classes, and will begin
    numbering instance regions with ``1``.

    This method identifies connected components via nearest-neighbor message passing.
    The runtime is a function of the diameter of the largest connected component.

    Args:
        mask (:class:`torch.Tensor`):
            Binary mask to extract instances

    Shapes:
        * ``mask`` - :math:`(H, W)`
        * Output - :math:`(H, W)`
    """
    H, W = mask.shape[-2:]
    mask = mask > 0
    if not mask.any():
        return mask.long()

    # assign each positive location a unique instance label
    devices = [mask.device] if mask.device.type == "cuda" else []
    with torch.random.fork_rng(devices=devices):
        torch.random.manual_seed(42)
        instances = (torch.multinomial(
            torch.ones(H * W, dtype=torch.float, device=mask.device),
            H * W,
            replacement=False).view(H, W).long())
    instances[~mask] = 0

    # get locations of each positive label
    nodes = mask.nonzero()
    N = nodes.shape[0]
    node_idx = torch.split(nodes, [1, 1], dim=-1)

    # build adjacency tensor (8-connectivity plus self)
    delta = torch.tensor(
        [
            [-1, -1], [-1, 0], [-1, 1],
            [0, -1], [0, 0], [0, 1],
            [1, -1], [1, 0], [1, 1],
        ],
        device=mask.device,
    )
    A = delta.shape[-2]
    adjacency = (nodes.view(-1, 1, 2) + delta.view(1, -1, 2)).reshape(-1, 2)
    adjacency[..., 0].clamp_(min=0, max=H - 1)
    adjacency[..., 1].clamp_(min=0, max=W - 1)
    adjacency_idx = torch.split(adjacency, [1, 1], dim=-1)

    # iteratively pass instance label to neighbors
    # adopt instance label of max(self, neighbors)
    # NOTE: try to buffer things and operate in place where possible for speed
    # NOTE: something below fails to script
    old_instances = instances
    new_instances = instances.clone()
    adjacency_buffer = adjacency.new_empty(N)
    while True:
        passed_messages = old_instances[adjacency_idx].view(N, A)
        torch.amax(passed_messages, dim=-1, out=adjacency_buffer)
        new_instances[node_idx] = adjacency_buffer.view(-1, 1)

        # if nothing was updated, we're done
        diff = new_instances[mask] != old_instances[mask]
        if not diff.any():
            break
        old_instances, new_instances = new_instances, old_instances

    # convert unique instance labels into a 1-indexed set of consecutive labels
    unique_instances = torch.unique(instances)
    for new_ins, old_ins in enumerate(unique_instances):
        if old_ins == 0:
            continue
        instances[instances == old_ins] = new_ins

    return instances.long()
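# Toy check for mask_to_instances: two blobs separated by background should come
# out labelled 1 and 2. Values are illustrative only.
import torch

toy = torch.zeros(5, 5, dtype=torch.bool)
toy[0:2, 0:2] = True       # first component
toy[3:5, 3:5] = True       # second component
instances = mask_to_instances(toy)
print(instances.unique())  # tensor([0, 1, 2])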
def forward(self,  # pylint: disable=arguments-differ
            source: Dict[str, torch.Tensor],
            target: Dict[str, torch.Tensor],
            reset: torch.Tensor = None) -> Dict[str, torch.Tensor]:
    # THE BELOW ONLY NEEDS TO BE SATISFIED FOR THE FANCY ITERATOR; MERITY
    # ET AL. JUST PROPAGATE THE HIDDEN STATE NO MATTER WHAT.
    # To make life easier when evaluating the model we use a BasicIterator
    # so that we do not need to worry about the sequence truncation
    # performed by our splitting iterators. To accommodate this, we assume
    # that if reset is not given, then everything gets reset.
    if reset is None:
        self._state = None
    elif reset.all() and (self._state is not None):
        logger.debug('RESET')
        self._state = None
    elif reset.any() and (self._state is not None):
        for layer in range(self.num_layers):
            h, c = self._state['layer_%i' % layer]
            h[:, reset, :] = torch.zeros_like(h[:, reset, :])
            c[:, reset, :] = torch.zeros_like(c[:, reset, :])
            self._state['layer_%i' % layer] = (h, c)

    target_mask = get_text_field_mask(target)
    source = source['tokens']
    target = target['tokens']

    embeddings = embedded_dropout(self.embedder, source,
                                  dropout=self.dropoute if self.training else 0)
    embeddings = self.locked_dropout(embeddings, self.dropouti)

    # Iterate through RNN layers
    current_input = embeddings
    current_hidden = []
    outputs = []
    dropped_outputs = []
    for layer, rnn in enumerate(self.rnns):
        # Bookkeeping
        if self._state is not None:
            prev_hidden = self._state['layer_%i' % layer]
        else:
            prev_hidden = None
        # Forward-pass
        output, hidden = rnn(current_input, prev_hidden)
        # More bookkeeping
        output = output.contiguous()
        outputs.append(output)
        hidden = tuple(h.detach() for h in hidden)
        current_hidden.append(hidden)
        # Apply dropout
        if layer == self.num_layers - 1:
            current_input = self.locked_dropout(output, self.dropout)
            dropped_outputs.append(output)
        else:
            current_input = self.locked_dropout(output, self.dropouth)
            dropped_outputs.append(current_input)

    # Compute logits and loss
    logits = self.decoder(current_input)
    loss = sequence_cross_entropy_with_logits(logits,
                                              target.contiguous(),
                                              target_mask,
                                              average="token")
    num_tokens = target_mask.float().sum() + 1e-13

    # Activation regularization
    if self.alpha:
        loss = loss + self.alpha * current_input.pow(2).mean()
    # Temporal activation regularization (slowness)
    if self.beta:
        loss = loss + self.beta * (output[:, 1:] - output[:, :-1]).pow(2).mean()

    # Update metrics and state
    unks = target.eq(self._unk_index)
    unk_penalty = self._unk_penalty * unks.float().sum()
    self.ppl(loss * num_tokens, num_tokens)
    self.upp(loss * num_tokens + unk_penalty, num_tokens)
    self._state = {'layer_%i' % l: h for l, h in enumerate(current_hidden)}

    return {'loss': loss}
def _has_any(mask: Tensor) -> bool:
    r"""Checks if the mask has any element set to \p True"""
    assert mask.dtype == torch.bool, "Mask should be a Boolean Tensor"
    return bool(mask.any().item())
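# Quick illustration with made-up masks:
import torch

assert _has_any(torch.tensor([False, True, False]))
assert not _has_any(torch.zeros(4, dtype=torch.bool))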
def pgd(net: nn.Module,
        x: torch.Tensor,
        nb_iter: int = 10,
        eps: float = 0.3,
        eps_iter: float = 0.05,
        rand_minmax: float = 0.3,
        clip_min=None,
        clip_max=None,
        y=None,
        ordr=np.inf,
        rand_init=None,
        targeted=False) -> torch.Tensor:
    """
    This function implements either the Basic Iterative Method
    (Kurakin et al. 2016), when rand_init is set to 0, or the
    Madry et al. (2017) method, when rand_minmax is larger than 0.
    Paper link (Kurakin et al. 2016): https://arxiv.org/pdf/1607.02533.pdf
    Paper link (Madry et al. 2017): https://arxiv.org/pdf/1706.06083.pdf

    Arguments
    ---------
    net: the model to attack
    x: input tensor
    nb_iter: number of attack iterations
    eps: maximum distortion of the adversarial example w.r.t. the original input
    eps_iter: step size for each attack iteration
    rand_minmax: bound of the uniform random initialization
    clip_min: minimum value of the valid input range, if any
    clip_max: maximum value of the valid input range, if any
    y: labels; if None, model predictions are used (avoids label leaking)
    ordr: order of the norm ball (np.inf, 1 or 2)
    rand_init: whether to use random initialization
    targeted: whether the attack is targeted

    Returns
    -------
    adv_x : torch.Tensor
        The adversarial example.
    """
    # If a data range was specified, check that the input was in that range
    if clip_min is not None:
        assert torch.all(x >= clip_min)
    if clip_max is not None:
        assert torch.all(x <= clip_max)

    # Initialize loop variables
    if rand_init:
        eta = torch.zeros_like(x).uniform_(-rand_minmax, rand_minmax)
    else:
        eta = torch.zeros_like(x)

    # Clip eta
    eta = clip_eta(eta, ordr, eps)
    adv_x = x + eta
    if clip_min is not None or clip_max is not None:
        adv_x = torch.clamp(adv_x, clip_min, clip_max)

    if y is None:
        # Use model predictions as ground truth labels to avoid label leaking
        _, y = torch.max(net(x), dim=1)
    else:
        # NOTE: providing y switches the attack to targeted mode here.
        targeted = True

    if ordr == 1:
        raise NotImplementedError(
            "It's not clear that FGM is a good inner loop"
            " step for PGD when ord=1, because ord=1 FGM"
            " changes only one pixel at a time. We need"
            " to rigorously test a strong ord=1 PGD"
            " before enabling this feature.")

    for _ in range(nb_iter):
        # Do a projected gradient step.
        adv_x = fgm(net,
                    adv_x,
                    eps=eps_iter,
                    ordr=ordr,
                    clip_min=clip_min,
                    clip_max=clip_max,
                    y=y,
                    targeted=targeted)

        # Clip the perturbation eta to the ord-norm ball
        eta = adv_x - x
        eta = clip_eta(eta, ordr, eps)
        adv_x = x + eta

        # Redo the clipping.
        # FGM already did it, but subtracting and re-adding eta can add some
        # small numerical error.
        if clip_min is not None or clip_max is not None:
            adv_x = torch.clamp(adv_x, clip_min, clip_max)

    # The 1e-6 margin compensates for numerical error; without it this fails
    # when e.g. eps=.2, clip_min=.5, clip_max=.7.
    if ordr == np.inf and clip_min is not None:
        assert eps <= (1e-6 + clip_max - clip_min)

    return adv_x
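# Hypothetical invocation, assuming `fgm` and `clip_eta` are defined elsewhere
# in this module and `model` is a trained classifier over inputs in [0, 1]:
#
#   adv = pgd(model, images, nb_iter=40, eps=8 / 255, eps_iter=2 / 255,
#             rand_init=True, rand_minmax=8 / 255,
#             clip_min=0.0, clip_max=1.0, ordr=np.inf)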