def extract_features(self, tokens):
    # support passing in a single sentence
    torch._assert(
        tokens.dim() == 1 or tokens.dim() == 2, "tokens should be a 1D or 2D tensor"
    )
    tokens = tokens.view(-1, tokens.shape[-1])
    return self.transformer(tokens)
def forward(self, fmap1: Tensor, fmap2: Tensor) -> List[Tensor]:
    """Build the correlation pyramid from two feature maps.

    The correlation volume is first computed as the dot product of each pair
    (pixel_in_fmap1, pixel_in_fmap2) on the same row.
    The last 2 dimensions of the correlation volume are then pooled num_levels times
    at different resolutions to build the correlation pyramid.
    """
    torch._assert(
        fmap1.shape == fmap2.shape,
        f"Input feature maps should have the same shape, instead got {fmap1.shape} (fmap1.shape) != {fmap2.shape} (fmap2.shape)",
    )

    batch_size, num_channels, h, w = fmap1.shape
    fmap1 = fmap1.view(batch_size, num_channels, h, w)
    fmap2 = fmap2.view(batch_size, num_channels, h, w)

    corr = torch.einsum("aijk,aijh->ajkh", fmap1, fmap2)
    corr = corr.view(batch_size, h, w, 1, w)
    corr_volume = corr / torch.sqrt(torch.tensor(num_channels, device=corr.device))

    corr_volume = corr_volume.reshape(batch_size * h * w, 1, 1, w)
    corr_pyramid = [corr_volume]
    for _ in range(self.num_levels - 1):
        corr_volume = F.avg_pool2d(corr_volume, kernel_size=(1, 2), stride=(1, 2))
        corr_pyramid.append(corr_volume)

    return corr_pyramid
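
# A minimal standalone sketch of the per-row correlation step above, using random
# feature maps. All shapes and values below are illustrative assumptions, not taken
# from any model.
import torch
import torch.nn.functional as F

batch_size, num_channels, h, w = 2, 8, 4, 6
fmap1 = torch.randn(batch_size, num_channels, h, w)
fmap2 = torch.randn(batch_size, num_channels, h, w)

# Dot product of every (pixel_in_fmap1, pixel_in_fmap2) pair that lies on the same row.
corr = torch.einsum("aijk,aijh->ajkh", fmap1, fmap2)  # (batch_size, h, w, w)
corr_volume = corr.reshape(batch_size * h * w, 1, 1, w) / num_channels ** 0.5

# Pool the last dimension once, giving a two-level pyramid.
pyramid = [corr_volume, F.avg_pool2d(corr_volume, kernel_size=(1, 2), stride=(1, 2))]
print([tuple(t.shape) for t in pyramid])  # [(48, 1, 1, 6), (48, 1, 1, 3)]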
def index_pyramid(self, centroids_coords):
    """Return correlation features by indexing from the pyramid."""
    neighborhood_side_len = 2 * self.radius + 1  # see note in __init__ about out_channels
    di = torch.linspace(-self.radius, self.radius, neighborhood_side_len)
    dj = torch.linspace(-self.radius, self.radius, neighborhood_side_len)
    delta = torch.stack(torch.meshgrid(di, dj, indexing="ij"), dim=-1).to(centroids_coords.device)
    delta = delta.view(1, neighborhood_side_len, neighborhood_side_len, 2)

    batch_size, _, h, w = centroids_coords.shape  # _ = 2
    centroids_coords = centroids_coords.permute(0, 2, 3, 1).reshape(batch_size * h * w, 1, 1, 2)

    indexed_pyramid = []
    for corr_volume in self.corr_pyramid:
        sampling_coords = centroids_coords + delta  # end shape is (batch_size * h * w, side_len, side_len, 2)
        indexed_corr_volume = grid_sample(corr_volume, sampling_coords, align_corners=True, mode="bilinear").view(
            batch_size, h, w, -1
        )
        indexed_pyramid.append(indexed_corr_volume)
        centroids_coords = centroids_coords / 2

    corr_features = torch.cat(indexed_pyramid, dim=-1).permute(0, 3, 1, 2).contiguous()

    expected_output_shape = (batch_size, self.out_channels, h, w)
    torch._assert(
        corr_features.shape == expected_output_shape,
        f"Output shape of index pyramid is incorrect. Should be {expected_output_shape}, got {corr_features.shape}",
    )

    return corr_features
def replace_ph(x):
    nonlocal cnt
    cnt += 1
    param = sig.parameters[name]
    default = () if param.default is inspect.Parameter.empty else (param.default,)
    out = self.create_proxy('placeholder', f'{name}_{str(cnt)}', default, {})
    if x == PH:
        return out
    # Union[int, bool] == bool in Python <= 3.6
    if type(x) == bool or type(x) in base_types and type(x) != torch.Tensor:
        torch._assert(out == x, f"{name} has been specialized to have value {x} but got another value")
    elif type(x) == type(None):
        args = (out, f"{name} has been specialized to have value None but got another value")
        self.create_proxy('call_function', _assert_is_none, args, {})
    else:
        torch.warnings.warn(
            f"Was not able to add assertion to guarantee correct input {name} to "
            f"specialized function. It is up to the user to make sure that your inputs match the "
            f"inputs you specialized the function with."
        )
    return x
def forward(self, centroids_coords: Tensor, corr_pyramid: List[Tensor]) -> Tensor:
    """Return correlation features by indexing from the pyramid."""
    neighborhood_side_len = 2 * self.radius + 1  # see note in __init__ about out_channels
    di = torch.linspace(-self.radius, self.radius, neighborhood_side_len, device=centroids_coords.device)
    di = di.view(1, 1, neighborhood_side_len, 1).to(centroids_coords.device)

    batch_size, _, h, w = centroids_coords.shape  # _ = 2 but we only use the first one
    # We only consider 1d and take the first dim only
    centroids_coords = centroids_coords[:, :1].permute(0, 2, 3, 1).reshape(batch_size * h * w, 1, 1, 1)

    indexed_pyramid = []
    for corr_volume in corr_pyramid:
        x0 = centroids_coords + di  # end shape is (batch_size * h * w, 1, side_len, 1)
        y0 = torch.zeros_like(x0)
        sampling_coords = torch.cat([x0, y0], dim=-1)
        indexed_corr_volume = grid_sample(corr_volume, sampling_coords, align_corners=True, mode="bilinear").view(
            batch_size, h, w, -1
        )
        indexed_pyramid.append(indexed_corr_volume)
        centroids_coords = centroids_coords / 2

    corr_features = torch.cat(indexed_pyramid, dim=-1).permute(0, 3, 1, 2).contiguous()

    expected_output_shape = (batch_size, self.out_channels, h, w)
    torch._assert(
        corr_features.shape == expected_output_shape,
        f"Output shape of index pyramid is incorrect. Should be {expected_output_shape}, got {corr_features.shape}",
    )
    return corr_features
def forward(self, input: torch.Tensor):
    torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
    x = self.ln_1(input)
    x, _ = self.self_attention(query=x, key=x, value=x, need_weights=False)
    x = self.dropout(x)
    x = x + input

    y = self.ln_2(x)
    y = self.mlp(y)
    return x + y
def forward(self, x):
    if x.dim() == 4:
        torch._assert(
            list(x.shape[2:]) == [1, 1],
            f"x has the wrong shape, expecting the last two dimensions to be [1,1] instead of {list(x.shape[2:])}",
        )
    x = x.flatten(start_dim=1)
    scores = self.cls_score(x)
    bbox_deltas = self.bbox_pred(x)

    return scores, bbox_deltas
def destructure_any_list(client_batch: List[int], result_any_list: List[Any]):  # -> List[List[Any]]
    res_list: List[List[Any]] = []  # torch.jit.annotate(List[List[Any]], [])
    start = 0
    for elems in client_batch:
        torch._assert(elems > 0, "zero or negative range error")
        end = start + elems
        res_list.append(result_any_list[start:end])
        start = end
    return res_list
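
# Hypothetical usage sketch of destructure_any_list (the values are made up): split a
# flat result list back into per-client groups of the requested sizes.
client_batch = [2, 3, 1]
results = ["a", "b", "c", "d", "e", "f"]
print(destructure_any_list(client_batch, results))
# [['a', 'b'], ['c', 'd', 'e'], ['f']]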
def forward(self, image1, image2, num_flow_updates: int = 12):

    batch_size, _, h, w = image1.shape
    torch._assert((h, w) == image2.shape[-2:], "input images should have the same shape")
    torch._assert((h % 8 == 0) and (w % 8 == 0), "input image H and W should be divisible by 8")

    fmaps = self.feature_encoder(torch.cat([image1, image2], dim=0))
    fmap1, fmap2 = torch.chunk(fmaps, chunks=2, dim=0)
    torch._assert(fmap1.shape[-2:] == (h // 8, w // 8), "The feature encoder should downsample H and W by 8")

    self.corr_block.build_pyramid(fmap1, fmap2)

    context_out = self.context_encoder(image1)
    torch._assert(context_out.shape[-2:] == (h // 8, w // 8), "The context encoder should downsample H and W by 8")

    # As in the original paper, the actual output of the context encoder is split in 2 parts:
    # - one part is used to initialize the hidden state of the recurrent units of the update block
    # - the rest is the "actual" context.
    hidden_state_size = self.update_block.hidden_state_size
    out_channels_context = context_out.shape[1] - hidden_state_size
    torch._assert(
        out_channels_context > 0,
        f"The context encoder outputs {context_out.shape[1]} channels, but it should have strictly more than "
        f"hidden_state={hidden_state_size} channels",
    )
    hidden_state, context = torch.split(context_out, [hidden_state_size, out_channels_context], dim=1)
    hidden_state = torch.tanh(hidden_state)
    context = F.relu(context)

    coords0 = make_coords_grid(batch_size, h // 8, w // 8).to(fmap1.device)
    coords1 = make_coords_grid(batch_size, h // 8, w // 8).to(fmap1.device)

    flow_predictions = []
    for _ in range(num_flow_updates):
        coords1 = coords1.detach()  # Don't backpropagate gradients through this branch, see paper
        corr_features = self.corr_block.index_pyramid(centroids_coords=coords1)

        flow = coords1 - coords0
        hidden_state, delta_flow = self.update_block(hidden_state, context, corr_features, flow)

        coords1 = coords1 + delta_flow

        up_mask = None if self.mask_predictor is None else self.mask_predictor(hidden_state)
        upsampled_flow = upsample_flow(flow=(coords1 - coords0), up_mask=up_mask)
        flow_predictions.append(upsampled_flow)

    return flow_predictions
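
# Hedged usage sketch of the RAFT forward pass above, assuming torchvision's
# raft_large builder is available (weights=None avoids any download; input sizes
# are arbitrary but must be divisible by 8, as the asserts above require).
import torch
from torchvision.models.optical_flow import raft_large

model = raft_large(weights=None).eval()
img1 = torch.rand(1, 3, 128, 128)
img2 = torch.rand(1, 3, 128, 128)
with torch.no_grad():
    flow_predictions = model(img1, img2, num_flow_updates=12)
print(len(flow_predictions), flow_predictions[-1].shape)  # 12 torch.Size([1, 2, 128, 128])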
def build_pyramid(self, fmap1, fmap2):
    """Build the correlation pyramid from two feature maps.

    The correlation volume is first computed as the dot product of each pair (pixel_in_fmap1, pixel_in_fmap2).
    The last 2 dimensions of the correlation volume are then pooled num_levels times at different resolutions
    to build the correlation pyramid.
    """
    torch._assert(fmap1.shape == fmap2.shape, "Input feature maps should have the same shapes")
    corr_volume = self._compute_corr_volume(fmap1, fmap2)

    batch_size, h, w, num_channels, _, _ = corr_volume.shape  # _, _ = h, w
    corr_volume = corr_volume.reshape(batch_size * h * w, num_channels, h, w)
    self.corr_pyramid = [corr_volume]
    for _ in range(self.num_levels - 1):
        corr_volume = F.avg_pool2d(corr_volume, kernel_size=2, stride=2)
        self.corr_pyramid.append(corr_volume)
def replace_ph(x):
    nonlocal cnt
    cnt += 1
    out = self.create_proxy('placeholder', f'{name}_{str(cnt)}', (), {})
    if x == PH:
        return out
    # Union[int, bool] == bool in Python <= 3.6
    if type(x) == bool or type(x) in base_types and type(x) != torch.Tensor:
        torch._assert(out == x, f"{name} has been specialized to have value {x}")
    else:
        torch.warnings.warn(
            "Was not able to add assertion to guarantee correct inputs to "
            "specialized function. It is up to the user to make sure that your inputs match the "
            "inputs you specialized the function with."
        )
    return x
def test_assert_true(self):
    # verify assertions work as expected
    # bool argument
    torch._assert(True, "foo")
    with self.assertRaisesRegex(AssertionError, "bar"):
        torch._assert(False, "bar")
    # tensor argument
    torch._assert(torch.tensor([True], dtype=torch.bool), "foo")
    with self.assertRaisesRegex(AssertionError, "bar"):
        torch._assert(torch.tensor([False], dtype=torch.bool), "bar")
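
# Small sketch of why code uses torch._assert rather than a plain Python assert:
# torch._assert is symbolically traceable, so the check is recorded in an FX graph
# instead of being evaluated once at trace time. This example is an assumption about
# typical usage, not taken from the test file above.
import torch
import torch.fx


def check_batch(x: torch.Tensor) -> torch.Tensor:
    torch._assert(x.shape[0] > 0, "expected a non-empty batch")
    return x * 2


traced = torch.fx.symbolic_trace(check_batch)
print(traced.graph)  # the graph contains a call_function node for torch._assert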
def __call__(self, match_quality_matrix: Tensor) -> Tensor:
    """
    Args:
        match_quality_matrix (Tensor[float]): an MxN tensor, containing the
            pairwise quality between M ground-truth elements and N predicted elements.

    Returns:
        matches (Tensor[int64]): an N tensor where N[i] is a matched gt in
            [0, M - 1] or a negative value indicating that prediction i could not
            be matched.
    """
    if match_quality_matrix.numel() == 0:
        # empty targets or proposals not supported during training
        if match_quality_matrix.shape[0] == 0:
            raise ValueError("No ground-truth boxes available for one of the images during training")
        else:
            raise ValueError("No proposal boxes available for one of the images during training")

    # match_quality_matrix is M (gt) x N (predicted)
    # Max over gt elements (dim 0) to find best gt candidate for each prediction
    matched_vals, matches = match_quality_matrix.max(dim=0)
    if self.allow_low_quality_matches:
        all_matches = matches.clone()
    else:
        all_matches = None  # type: ignore[assignment]

    # Assign candidate matches with low quality to negative (unassigned) values
    below_low_threshold = matched_vals < self.low_threshold
    between_thresholds = (matched_vals >= self.low_threshold) & (matched_vals < self.high_threshold)
    matches[below_low_threshold] = self.BELOW_LOW_THRESHOLD
    matches[between_thresholds] = self.BETWEEN_THRESHOLDS

    if self.allow_low_quality_matches:
        if all_matches is None:
            torch._assert(False, "all_matches should not be None")
        else:
            self.set_low_quality_matches_(matches, all_matches, match_quality_matrix)

    return matches
def _vgg_extractor(backbone: VGG, highres: bool, trainable_layers: int):
    backbone = backbone.features
    # Gather the indices of maxpools. These are the locations of output blocks.
    stage_indices = [0] + [i for i, b in enumerate(backbone) if isinstance(b, nn.MaxPool2d)][:-1]
    num_stages = len(stage_indices)

    # find the index of the layer from which we won't freeze
    torch._assert(
        0 <= trainable_layers <= num_stages,
        f"trainable_layers should be in the range [0, {num_stages}]. Instead got {trainable_layers}",
    )
    freeze_before = len(backbone) if trainable_layers == 0 else stage_indices[num_stages - trainable_layers]

    for b in backbone[:freeze_before]:
        for parameter in b.parameters():
            parameter.requires_grad_(False)

    return SSDFeatureExtractorVGG(backbone, highres)
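
# Hedged sketch of how stage_indices comes out for a concrete backbone, assuming
# torchvision's vgg16 builder is available (weights=None avoids any download).
import torch.nn as nn
from torchvision.models import vgg16

features = vgg16(weights=None).features
stage_indices = [0] + [i for i, b in enumerate(features) if isinstance(b, nn.MaxPool2d)][:-1]
print(stage_indices)  # [0, 4, 9, 16, 23] for vgg16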
def forward(
    self, images: List[Tensor], targets: Optional[List[Dict[str, Tensor]]] = None
) -> Tuple[ImageList, Optional[List[Dict[str, Tensor]]]]:
    images = [img for img in images]
    if targets is not None:
        # make a copy of targets to avoid modifying it in-place
        # once torchscript supports dict comprehension
        # this can be simplified as follows
        # targets = [{k: v for k,v in t.items()} for t in targets]
        targets_copy: List[Dict[str, Tensor]] = []
        for t in targets:
            data: Dict[str, Tensor] = {}
            for k, v in t.items():
                data[k] = v
            targets_copy.append(data)
        targets = targets_copy
    for i in range(len(images)):
        image = images[i]
        target_index = targets[i] if targets is not None else None

        if image.dim() != 3:
            raise ValueError(f"images is expected to be a list of 3d tensors of shape [C, H, W], got {image.shape}")
        image = self.normalize(image)
        image, target_index = self.resize(image, target_index)
        images[i] = image
        if targets is not None and target_index is not None:
            targets[i] = target_index

    image_sizes = [img.shape[-2:] for img in images]
    images = self.batch_images(images, size_divisible=self.size_divisible)
    image_sizes_list: List[Tuple[int, int]] = []
    for image_size in image_sizes:
        torch._assert(
            len(image_size) == 2,
            f"Input tensors expected to have in the last two elements H and W, instead got {image_size}",
        )
        image_sizes_list.append((image_size[0], image_size[1]))

    image_list = ImageList(images, image_sizes_list)
    return image_list, targets
def _box_loss(
    type: str,
    box_coder: BoxCoder,
    anchors_per_image: Tensor,
    matched_gt_boxes_per_image: Tensor,
    bbox_regression_per_image: Tensor,
    cnf: Optional[Dict[str, float]] = None,
) -> Tensor:
    torch._assert(type in ["l1", "smooth_l1", "ciou", "diou", "giou"], f"Unsupported loss: {type}")

    if type == "l1":
        target_regression = box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image)
        return F.l1_loss(bbox_regression_per_image, target_regression, reduction="sum")
    elif type == "smooth_l1":
        target_regression = box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image)
        beta = cnf["beta"] if cnf is not None and "beta" in cnf else 1.0
        return F.smooth_l1_loss(bbox_regression_per_image, target_regression, reduction="sum", beta=beta)
    else:
        bbox_per_image = box_coder.decode_single(bbox_regression_per_image, anchors_per_image)
        eps = cnf["eps"] if cnf is not None and "eps" in cnf else 1e-7
        if type == "ciou":
            return complete_box_iou_loss(bbox_per_image, matched_gt_boxes_per_image, reduction="sum", eps=eps)
        if type == "diou":
            return distance_box_iou_loss(bbox_per_image, matched_gt_boxes_per_image, reduction="sum", eps=eps)
        # otherwise giou
        return generalized_box_iou_loss(bbox_per_image, matched_gt_boxes_per_image, reduction="sum", eps=eps)
def _process_input(self, x: torch.Tensor) -> torch.Tensor:
    n, c, h, w = x.shape
    p = self.patch_size
    torch._assert(h == self.image_size, "Wrong image height!")
    torch._assert(w == self.image_size, "Wrong image width!")
    n_h = h // p
    n_w = w // p

    # (n, c, h, w) -> (n, hidden_dim, n_h, n_w)
    x = self.conv_proj(x)
    # (n, hidden_dim, n_h, n_w) -> (n, hidden_dim, (n_h * n_w))
    x = x.reshape(n, self.hidden_dim, n_h * n_w)

    # (n, hidden_dim, (n_h * n_w)) -> (n, (n_h * n_w), hidden_dim)
    # The self attention layer expects inputs in the format (N, S, E)
    # where S is the source sequence length, N is the batch size, E is the
    # embedding dimension
    x = x.permute(0, 2, 1)

    return x
def decode(self, rel_codes: Tensor, boxes: List[Tensor]) -> Tensor:
    torch._assert(
        isinstance(boxes, (list, tuple)),
        "This function expects boxes of type list or tuple.",
    )
    torch._assert(
        isinstance(rel_codes, torch.Tensor),
        "This function expects rel_codes of type torch.Tensor.",
    )
    boxes_per_image = [b.size(0) for b in boxes]
    concat_boxes = torch.cat(boxes, dim=0)
    box_sum = 0
    for val in boxes_per_image:
        box_sum += val
    if box_sum > 0:
        rel_codes = rel_codes.reshape(box_sum, -1)
    pred_boxes = self.decode_single(rel_codes, concat_boxes)
    if box_sum > 0:
        pred_boxes = pred_boxes.reshape(box_sum, -1, 4)
    return pred_boxes
def grid_anchors(self, grid_sizes: List[List[int]], strides: List[List[Tensor]]) -> List[Tensor]:
    anchors = []
    cell_anchors = self.cell_anchors
    torch._assert(cell_anchors is not None, "cell_anchors should not be None")
    torch._assert(
        len(grid_sizes) == len(strides) == len(cell_anchors),
        "Anchors should be Tuple[Tuple[int]] because each feature "
        "map could potentially have different sizes and aspect ratios. "
        "There needs to be a match between the number of "
        "feature maps passed and the number of sizes / aspect ratios specified.",
    )

    for size, stride, base_anchors in zip(grid_sizes, strides, cell_anchors):
        grid_height, grid_width = size
        stride_height, stride_width = stride
        device = base_anchors.device

        # For output anchor, compute [x_center, y_center, x_center, y_center]
        shifts_x = torch.arange(0, grid_width, dtype=torch.int32, device=device) * stride_width
        shifts_y = torch.arange(0, grid_height, dtype=torch.int32, device=device) * stride_height
        shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x, indexing="ij")
        shift_x = shift_x.reshape(-1)
        shift_y = shift_y.reshape(-1)
        shifts = torch.stack((shift_x, shift_y, shift_x, shift_y), dim=1)

        # For every (base anchor, output anchor) pair,
        # offset each zero-centered base anchor by the center of the output anchor.
        anchors.append((shifts.view(-1, 1, 4) + base_anchors.view(1, -1, 4)).reshape(-1, 4))

    return anchors
def __init__(self, high_threshold: float, low_threshold: float, allow_low_quality_matches: bool = False) -> None:
    """
    Args:
        high_threshold (float): quality values greater than or equal to
            this value are candidate matches.
        low_threshold (float): a lower quality threshold used to stratify
            matches into three levels:
            1) matches >= high_threshold
            2) BETWEEN_THRESHOLDS matches in [low_threshold, high_threshold)
            3) BELOW_LOW_THRESHOLD matches in [0, low_threshold)
        allow_low_quality_matches (bool): if True, produce additional matches
            for predictions that have only low-quality match candidates. See
            set_low_quality_matches_ for more details.
    """
    self.BELOW_LOW_THRESHOLD = -1
    self.BETWEEN_THRESHOLDS = -2
    torch._assert(low_threshold <= high_threshold, "low_threshold should be <= high_threshold")
    self.high_threshold = high_threshold
    self.low_threshold = low_threshold
    self.allow_low_quality_matches = allow_low_quality_matches
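
# Hypothetical usage sketch of the Matcher whose __init__ and __call__ are shown above
# (the quality values are made up): rows are ground-truth boxes, columns are predictions,
# entries are IoU-like scores.
import torch

match_quality_matrix = torch.tensor(
    [[0.9, 0.1, 0.4],
     [0.2, 0.6, 0.3]]
)
matcher = Matcher(high_threshold=0.7, low_threshold=0.3, allow_low_quality_matches=False)
print(matcher(match_quality_matrix))
# tensor([ 0, -2, -2]): prediction 0 matches gt 0, the other two fall between the thresholds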
def forward(
    self, images: List[torch.Tensor], targets: List[Dict[str, Tensor]]
) -> Tuple[List[torch.Tensor], List[Dict[str, Tensor]]]:
    torch._assert(
        isinstance(images, (list, tuple)) and all([isinstance(v, torch.Tensor) for v in images]),
        "images should be a list of tensors",
    )
    torch._assert(
        isinstance(targets, (list, tuple)) and len(images) == len(targets),
        "targets should be a list of the same size as images",
    )
    for target in targets:
        # Cannot check for instance of type dict inside torch.jit.script:
        # torch._assert(isinstance(target, dict), "targets item should be a dict")
        for k in ["masks", "boxes", "labels"]:
            torch._assert(k in target, f"Key {k} should be present in targets")
            torch._assert(isinstance(target[k], torch.Tensor), f"Value for the key {k} should be a tensor")

    # images = [t1, t2, ..., tN]
    # Let's define paste_images as shifted list of input images
    # paste_images = [t2, t3, ..., tN, t1]
    # FYI: in TF they mix data on the dataset level
    images_rolled = images[-1:] + images[:-1]
    targets_rolled = targets[-1:] + targets[:-1]

    output_images: List[torch.Tensor] = []
    output_targets: List[Dict[str, Tensor]] = []

    for image, target, paste_image, paste_target in zip(images, targets, images_rolled, targets_rolled):
        output_image, output_data = _copy_paste(
            image,
            target,
            paste_image,
            paste_target,
            blending=self.blending,
            resize_interpolation=self.resize_interpolation,
        )
        output_images.append(output_image)
        output_targets.append(output_data)

    return output_images, output_targets
def check_roi_boxes_shape(boxes: Union[Tensor, List[Tensor]]):
    if isinstance(boxes, (list, tuple)):
        for _tensor in boxes:
            torch._assert(
                _tensor.size(1) == 4, "The shape of the tensor in the boxes list is not correct as List[Tensor[L, 4]]"
            )
    elif isinstance(boxes, torch.Tensor):
        torch._assert(boxes.size(1) == 5, "The boxes tensor shape is not correct as Tensor[K, 5]")
    else:
        torch._assert(False, "boxes is expected to be a Tensor[L, 5] or a List[Tensor[K, 4]]")
    return
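
# Quick sketch of the two accepted box formats (random values; only the shapes matter here):
import torch

boxes_list = [torch.rand(3, 4), torch.rand(5, 4)]  # List[Tensor[L, 4]], one tensor per image
check_roi_boxes_shape(boxes_list)                  # passes

boxes_tensor = torch.rand(8, 5)                    # Tensor[K, 5]: (batch_index, x1, y1, x2, y2)
check_roi_boxes_shape(boxes_tensor)                # passes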
def forward(self, images, targets=None):
    # type: (List[Tensor], Optional[List[Dict[str, Tensor]]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
    """
    Args:
        images (list[Tensor]): images to be processed
        targets (list[Dict[str, Tensor]]): ground-truth boxes present in the image (optional)

    Returns:
        result (list[BoxList] or dict[Tensor]): the output from the model.
            During training, it returns a dict[Tensor] which contains the losses.
            During testing, it returns a list[BoxList] containing additional fields
            like `scores`, `labels` and `mask` (for Mask R-CNN models).
    """
    if self.training:
        if targets is None:
            torch._assert(False, "targets should not be none when in training mode")
        else:
            for target in targets:
                boxes = target["boxes"]
                if isinstance(boxes, torch.Tensor):
                    torch._assert(
                        len(boxes.shape) == 2 and boxes.shape[-1] == 4,
                        f"Expected target boxes to be a tensor of shape [N, 4], got {boxes.shape}.",
                    )
                else:
                    torch._assert(False, f"Expected target boxes to be of type Tensor, got {type(boxes)}.")

    original_image_sizes: List[Tuple[int, int]] = []
    for img in images:
        val = img.shape[-2:]
        torch._assert(
            len(val) == 2,
            f"expecting the last two dimensions of the Tensor to be H and W instead got {img.shape[-2:]}",
        )
        original_image_sizes.append((val[0], val[1]))

    images, targets = self.transform(images, targets)

    # Check for degenerate boxes
    # TODO: Move this to a function
    if targets is not None:
        for target_idx, target in enumerate(targets):
            boxes = target["boxes"]
            degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
            if degenerate_boxes.any():
                # print the first degenerate box
                bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0]
                degen_bb: List[float] = boxes[bb_idx].tolist()
                torch._assert(
                    False,
                    "All bounding boxes should have positive height and width."
                    f" Found invalid box {degen_bb} for target at index {target_idx}.",
                )

    features = self.backbone(images.tensors)
    if isinstance(features, torch.Tensor):
        features = OrderedDict([("0", features)])
    proposals, proposal_losses = self.rpn(images, features, targets)
    detections, detector_losses = self.roi_heads(features, proposals, images.image_sizes, targets)
    detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)  # type: ignore[operator]

    losses = {}
    losses.update(detector_losses)
    losses.update(proposal_losses)

    if torch.jit.is_scripting():
        if not self._has_warned:
            warnings.warn("RCNN always returns a (Losses, Detections) tuple in scripting")
            self._has_warned = True
        return losses, detections
    else:
        return self.eager_outputs(losses, detections)
def forward(self, input: torch.Tensor):
    torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
    input = input + self.pos_embedding
    return self.ln(self.layers(self.dropout(input)))
def __init__(
    self,
    image_size: int,
    patch_size: int,
    num_layers: int,
    num_heads: int,
    hidden_dim: int,
    mlp_dim: int,
    dropout: float = 0.0,
    attention_dropout: float = 0.0,
    num_classes: int = 1000,
    representation_size: Optional[int] = None,
    norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
    conv_stem_configs: Optional[List[ConvStemConfig]] = None,
):
    super().__init__()
    _log_api_usage_once(self)
    torch._assert(image_size % patch_size == 0, "Input shape indivisible by patch size!")
    self.image_size = image_size
    self.patch_size = patch_size
    self.hidden_dim = hidden_dim
    self.mlp_dim = mlp_dim
    self.attention_dropout = attention_dropout
    self.dropout = dropout
    self.num_classes = num_classes
    self.representation_size = representation_size
    self.norm_layer = norm_layer

    if conv_stem_configs is not None:
        # As per https://arxiv.org/abs/2106.14881
        seq_proj = nn.Sequential()
        prev_channels = 3
        for i, conv_stem_layer_config in enumerate(conv_stem_configs):
            seq_proj.add_module(
                f"conv_bn_relu_{i}",
                Conv2dNormActivation(
                    in_channels=prev_channels,
                    out_channels=conv_stem_layer_config.out_channels,
                    kernel_size=conv_stem_layer_config.kernel_size,
                    stride=conv_stem_layer_config.stride,
                    norm_layer=conv_stem_layer_config.norm_layer,
                    activation_layer=conv_stem_layer_config.activation_layer,
                ),
            )
            prev_channels = conv_stem_layer_config.out_channels
        seq_proj.add_module(
            "conv_last", nn.Conv2d(in_channels=prev_channels, out_channels=hidden_dim, kernel_size=1)
        )
        self.conv_proj: nn.Module = seq_proj
    else:
        self.conv_proj = nn.Conv2d(
            in_channels=3, out_channels=hidden_dim, kernel_size=patch_size, stride=patch_size
        )

    seq_length = (image_size // patch_size) ** 2

    # Add a class token
    self.class_token = nn.Parameter(torch.zeros(1, 1, hidden_dim))
    seq_length += 1

    self.encoder = Encoder(
        seq_length,
        num_layers,
        num_heads,
        hidden_dim,
        mlp_dim,
        dropout,
        attention_dropout,
        norm_layer,
    )
    self.seq_length = seq_length

    heads_layers: OrderedDict[str, nn.Module] = OrderedDict()
    if representation_size is None:
        heads_layers["head"] = nn.Linear(hidden_dim, num_classes)
    else:
        heads_layers["pre_logits"] = nn.Linear(hidden_dim, representation_size)
        heads_layers["act"] = nn.Tanh()
        heads_layers["head"] = nn.Linear(representation_size, num_classes)

    self.heads = nn.Sequential(heads_layers)

    if isinstance(self.conv_proj, nn.Conv2d):
        # Init the patchify stem
        fan_in = self.conv_proj.in_channels * self.conv_proj.kernel_size[0] * self.conv_proj.kernel_size[1]
        nn.init.trunc_normal_(self.conv_proj.weight, std=math.sqrt(1 / fan_in))
        if self.conv_proj.bias is not None:
            nn.init.zeros_(self.conv_proj.bias)
    elif self.conv_proj.conv_last is not None and isinstance(self.conv_proj.conv_last, nn.Conv2d):
        # Init the last 1x1 conv of the conv stem
        nn.init.normal_(
            self.conv_proj.conv_last.weight, mean=0.0, std=math.sqrt(2.0 / self.conv_proj.conv_last.out_channels)
        )
        if self.conv_proj.conv_last.bias is not None:
            nn.init.zeros_(self.conv_proj.conv_last.bias)

    if hasattr(self.heads, "pre_logits") and isinstance(self.heads.pre_logits, nn.Linear):
        fan_in = self.heads.pre_logits.in_features
        nn.init.trunc_normal_(self.heads.pre_logits.weight, std=math.sqrt(1 / fan_in))
        nn.init.zeros_(self.heads.pre_logits.bias)

    if isinstance(self.heads.head, nn.Linear):
        nn.init.zeros_(self.heads.head.weight)
        nn.init.zeros_(self.heads.head.bias)
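
# Hedged usage sketch, assuming the constructor above belongs to torchvision's
# VisionTransformer class; the hyper-parameters roughly match ViT-B/16.
import torch

model = VisionTransformer(
    image_size=224, patch_size=16, num_layers=12, num_heads=12, hidden_dim=768, mlp_dim=3072
)
x = torch.rand(2, 3, 224, 224)
print(model(x).shape)  # torch.Size([2, 1000])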
def __init__(
    self,
    image_size: int,
    patch_size: int,
    num_layers: int,
    num_heads: int,
    hidden_dim: int,
    mlp_dim: int,
    dropout: float = 0.0,
    attention_dropout: float = 0.0,
    num_classes: int = 1000,
    representation_size: Optional[int] = None,
    norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
):
    super().__init__()
    _log_api_usage_once(self)
    torch._assert(image_size % patch_size == 0, "Input shape indivisible by patch size!")
    self.image_size = image_size
    self.patch_size = patch_size
    self.hidden_dim = hidden_dim
    self.mlp_dim = mlp_dim
    self.attention_dropout = attention_dropout
    self.dropout = dropout
    self.num_classes = num_classes
    self.representation_size = representation_size
    self.norm_layer = norm_layer

    input_channels = 3

    # The conv_proj is a more efficient version of reshaping, permuting
    # and projecting the input
    self.conv_proj = nn.Conv2d(input_channels, hidden_dim, kernel_size=patch_size, stride=patch_size)

    seq_length = (image_size // patch_size) ** 2

    # Add a class token
    self.class_token = nn.Parameter(torch.zeros(1, 1, hidden_dim))
    seq_length += 1

    self.encoder = Encoder(
        seq_length,
        num_layers,
        num_heads,
        hidden_dim,
        mlp_dim,
        dropout,
        attention_dropout,
        norm_layer,
    )
    self.seq_length = seq_length

    heads_layers: OrderedDict[str, nn.Module] = OrderedDict()
    if representation_size is None:
        heads_layers["head"] = nn.Linear(hidden_dim, num_classes)
    else:
        heads_layers["pre_logits"] = nn.Linear(hidden_dim, representation_size)
        heads_layers["act"] = nn.Tanh()
        heads_layers["head"] = nn.Linear(representation_size, num_classes)

    self.heads = nn.Sequential(heads_layers)

    self._init_weights()
def forward(self, x):
    torch._assert(x.sum() > 0, "foo")
    return x
def interpolate_embeddings(
    image_size: int,
    patch_size: int,
    model_state: "OrderedDict[str, torch.Tensor]",
    interpolation_mode: str = "bicubic",
    reset_heads: bool = False,
) -> "OrderedDict[str, torch.Tensor]":
    """This function helps interpolate positional embeddings during checkpoint loading,
    especially when you want to apply a pre-trained model on images with a different resolution.

    Args:
        image_size (int): Image size of the new model.
        patch_size (int): Patch size of the new model.
        model_state (OrderedDict[str, torch.Tensor]): State dict of the pre-trained model.
        interpolation_mode (str): The algorithm used for upsampling. Default: bicubic.
        reset_heads (bool): If true, the state of the heads is not copied. Default: False.

    Returns:
        OrderedDict[str, torch.Tensor]: A state dict which can be loaded into the new model.
    """
    # Shape of pos_embedding is (1, seq_length, hidden_dim)
    pos_embedding = model_state["encoder.pos_embedding"]
    n, seq_length, hidden_dim = pos_embedding.shape
    if n != 1:
        raise ValueError(f"Unexpected position embedding shape: {pos_embedding.shape}")

    new_seq_length = (image_size // patch_size) ** 2 + 1

    # Need to interpolate the weights for the position embedding.
    # We do this by reshaping the position embeddings to a 2d grid, performing
    # an interpolation in the (h, w) space and then reshaping back to a 1d grid.
    if new_seq_length != seq_length:
        # The class token embedding shouldn't be interpolated, so we split it up.
        seq_length -= 1
        new_seq_length -= 1
        pos_embedding_token = pos_embedding[:, :1, :]
        pos_embedding_img = pos_embedding[:, 1:, :]

        # (1, seq_length, hidden_dim) -> (1, hidden_dim, seq_length)
        pos_embedding_img = pos_embedding_img.permute(0, 2, 1)
        seq_length_1d = int(math.sqrt(seq_length))
        torch._assert(seq_length_1d * seq_length_1d == seq_length, "seq_length is not a perfect square!")

        # (1, hidden_dim, seq_length) -> (1, hidden_dim, seq_l_1d, seq_l_1d)
        pos_embedding_img = pos_embedding_img.reshape(1, hidden_dim, seq_length_1d, seq_length_1d)
        new_seq_length_1d = image_size // patch_size

        # Perform interpolation.
        # (1, hidden_dim, seq_l_1d, seq_l_1d) -> (1, hidden_dim, new_seq_l_1d, new_seq_l_1d)
        new_pos_embedding_img = nn.functional.interpolate(
            pos_embedding_img,
            size=new_seq_length_1d,
            mode=interpolation_mode,
            align_corners=True,
        )

        # (1, hidden_dim, new_seq_l_1d, new_seq_l_1d) -> (1, hidden_dim, new_seq_length)
        new_pos_embedding_img = new_pos_embedding_img.reshape(1, hidden_dim, new_seq_length)

        # (1, hidden_dim, new_seq_length) -> (1, new_seq_length, hidden_dim)
        new_pos_embedding_img = new_pos_embedding_img.permute(0, 2, 1)
        new_pos_embedding = torch.cat([pos_embedding_token, new_pos_embedding_img], dim=1)

        model_state["encoder.pos_embedding"] = new_pos_embedding

        if reset_heads:
            model_state_copy: "OrderedDict[str, torch.Tensor]" = OrderedDict()
            for k, v in model_state.items():
                if not k.startswith("heads"):
                    model_state_copy[k] = v
            model_state = model_state_copy

    return model_state
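
# Hedged usage sketch: adapt a ViT-B/16 checkpoint trained at 224x224 to 384x384 inputs.
# Assumes torchvision's vit_b_16 builder, its IMAGENET1K_V1 weights, and the
# interpolate_embeddings helper are available (the weight download is only illustrative).
from torchvision.models import ViT_B_16_Weights, vit_b_16
from torchvision.models.vision_transformer import interpolate_embeddings

state = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1).state_dict()
new_state = interpolate_embeddings(image_size=384, patch_size=16, model_state=state)

model_384 = vit_b_16(image_size=384)  # builds a 384x384 model with the longer pos_embedding
model_384.load_state_dict(new_state)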
def forward(self, x):
    torch._assert(x.shape[1] > 4, "assert_foobar")
    return x
def forward(
    self,
    images: List[Tensor],
    targets: Optional[List[Dict[str, Tensor]]] = None,
) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]:
    """
    Args:
        images (list[Tensor]): images to be processed
        targets (list[Dict[Tensor]]): ground-truth boxes present in the image (optional)

    Returns:
        result (list[BoxList] or dict[Tensor]): the output from the model.
            During training, it returns a dict[Tensor] which contains the losses.
            During testing, it returns a list[BoxList] containing additional fields
            like `scores`, `labels` and `mask` (for Mask R-CNN models).
    """
    if self.training:
        if targets is None:
            torch._assert(False, "targets should not be none when in training mode")
        else:
            for target in targets:
                boxes = target["boxes"]
                torch._assert(isinstance(boxes, torch.Tensor), "Expected target boxes to be of type Tensor.")
                torch._assert(
                    len(boxes.shape) == 2 and boxes.shape[-1] == 4,
                    f"Expected target boxes to be a tensor of shape [N, 4], got {boxes.shape}.",
                )

    original_image_sizes: List[Tuple[int, int]] = []
    for img in images:
        val = img.shape[-2:]
        torch._assert(
            len(val) == 2,
            f"expecting the last two dimensions of the Tensor to be H and W instead got {img.shape[-2:]}",
        )
        original_image_sizes.append((val[0], val[1]))

    # transform the input
    images, targets = self.transform(images, targets)

    # Check for degenerate boxes
    if targets is not None:
        for target_idx, target in enumerate(targets):
            boxes = target["boxes"]
            degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
            if degenerate_boxes.any():
                # print the first degenerate box
                bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0]
                degen_bb: List[float] = boxes[bb_idx].tolist()
                torch._assert(
                    False,
                    f"All bounding boxes should have positive height and width. Found invalid box {degen_bb} for target at index {target_idx}.",
                )

    # get the features from the backbone
    features = self.backbone(images.tensors)
    if isinstance(features, torch.Tensor):
        features = OrderedDict([("0", features)])

    features = list(features.values())

    # compute the fcos heads outputs using the features
    head_outputs = self.head(features)

    # create the set of anchors
    anchors = self.anchor_generator(images, features)
    # recover level sizes
    num_anchors_per_level = [x.size(2) * x.size(3) for x in features]

    losses = {}
    detections: List[Dict[str, Tensor]] = []
    if self.training:
        if targets is None:
            torch._assert(False, "targets should not be none when in training mode")
        else:
            # compute the losses
            losses = self.compute_loss(targets, head_outputs, anchors, num_anchors_per_level)
    else:
        # split outputs per level
        split_head_outputs: Dict[str, List[Tensor]] = {}
        for k in head_outputs:
            split_head_outputs[k] = list(head_outputs[k].split(num_anchors_per_level, dim=1))
        split_anchors = [list(a.split(num_anchors_per_level)) for a in anchors]

        # compute the detections
        detections = self.postprocess_detections(split_head_outputs, split_anchors, images.image_sizes)
        detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)

    if torch.jit.is_scripting():
        if not self._has_warned:
            warnings.warn("FCOS always returns a (Losses, Detections) tuple in scripting")
            self._has_warned = True
        return losses, detections
    return self.eager_outputs(losses, detections)