def forward(self, scores, features, image_size):
    """Refine target scores into an output map using several feature levels.

    The levels are visited in the iteration order of ``self.ft_channels``. At
    each level the scores are resampled to that level's spatial size. They are
    then passed with the level's features and the previous level's output
    through ``self.TSE[L]``, ``self.RRB1[L]``, ``self.CAB[L]`` and
    ``self.RRB2[L]``. The last level's output goes to ``self.project``.

    When there are more score maps than feature maps (several targets for the
    same features), the features are tiled along the first dimension.

    Args:
        scores: Score tensor. Its first dimension is taken as the number of
            targets.
        features: Mapping from level name (the keys of ``self.ft_channels``)
            to feature tensors. The spatial size is read from their last two
            dimensions.
        image_size: Forwarded unchanged to ``self.project``.

    Returns:
        The value returned by ``self.project`` for the last level's output.
    """
    num_targets = scores.shape[0]
    num_fmaps = features[next(iter(self.ft_channels))].shape[0]
    # More score maps than feature maps means several targets share features.
    multi_targets = num_targets > num_fmaps

    x = None  # the first level has no previous-level output
    for L in self.ft_channels:
        ft = features[L]
        s = interpolate(scores, ft.shape[-2:])  # Resample scores to match features size
        if multi_targets:
            # NOTE(review): repeating by num_targets gives num_fmaps * num_targets
            # rows, which only matches the scores batch when num_fmaps == 1 —
            # confirm callers never pass more than one feature map here.
            ft = ft.repeat(num_targets, 1, 1, 1)
        h, hpool = self.TSE[L](ft, s, x)
        h = self.RRB1[L](h)
        h = self.CAB[L](hpool, h)
        x = self.RRB2[L](h)

    return self.project(x, image_size)
def forward(self, deeper, shallower):
    """Gate ``shallower`` channel-wise, then add ``deeper`` resampled to match.

    Both inputs are summarised with 1x1 adaptive average pooling. ``deeper`` is
    pooled only when ``self.deepest`` is false. The two summaries are
    concatenated along dim 1 and mapped through ``self.convreluconv``. The
    sigmoid of that result scales ``shallower``. Finally ``deeper``, resampled
    to the gated map's spatial size, is added.
    """
    shallow_summary = F.adaptive_avg_pool2d(shallower, (1, 1))
    if self.deepest:
        # NOTE(review): at the deepest level ``deeper`` is used unpooled;
        # presumably it is already 1x1 — confirm with the caller.
        deep_summary = deeper
    else:
        deep_summary = F.adaptive_avg_pool2d(deeper, (1, 1))

    gate_logits = self.convreluconv(torch.cat((shallow_summary, deep_summary), dim=1))
    gated = shallower * torch.sigmoid(gate_logits)
    return gated + interpolate(deeper, gated.shape[-2:])
def _align_time(self):
    """Resample the sensor streams onto the common timeline ``self.time``.

    For each of ``ppg``, ``acc``, ``hr`` and ``activity`` that is not None:

    * Take the stream's own timestamps from ``time_<name>``. Fall back to
      ``time_sensors`` when that attribute is missing or None.
    * Convert both those timestamps and ``self.time`` to 64-bit integer
      milliseconds.
    * Replace the stream with ``interpolate(<self.time ms>, <stream ms>,
      <stream>)``.

    Streams that are None are left untouched.
    """
    # Kept for its side effects. The original chained
    # ``.astype('datetime64[ms]').astype(int)`` onto this call and then
    # discarded the result, so that conversion had no effect.
    # NOTE(review): presumably this populates ``self.time`` — confirm.
    self._get_global_timestamps()

    for attr_name in ('ppg', 'acc', 'hr', 'activity'):
        attr = getattr(self, attr_name)
        if attr is None:
            continue  # no data for this stream, so its timestamps are not needed

        time_attr = getattr(self, f'time_{attr_name}', None)
        if time_attr is None:
            time_attr = getattr(self, 'time_sensors')

        # Put both axes in the same unit (ms) and a 64-bit integer type.
        # ``astype(int)`` maps to the platform C long, which is 32-bit on some
        # platforms and overflows for epoch milliseconds. Casting ``self.time``
        # straight to int would also keep its native unit instead of ms.
        time_for_interp = self.time.astype('datetime64[ms]').astype('int64')
        source_ms = time_attr.astype('datetime64[ms]').astype('int64')

        setattr(self, attr_name, interpolate(time_for_interp, source_ms, attr))
def forward(self, scores, features, image_size):
    """Refine target scores into an output map, one feature level at a time.

    The levels are visited in the iteration order of ``self.ft_channels``. At
    each level the scores are resampled to the level's spatial size. They are
    then passed with the level's features and the previous level's output
    through ``self.TSE``, ``self.RRB1``, ``self.CAB`` and ``self.RRB2``.
    The result of the last level, together with ``image_size``, goes to
    ``self.project``.
    """
    prev_out = None
    for level in self.ft_channels:
        level_ft = features[level]
        # Match the scores to this level's spatial size.
        level_scores = interpolate(scores, level_ft.shape[-2:])
        fused, pooled = self.TSE[level](level_ft, level_scores, prev_out)
        refined = self.CAB[level](pooled, self.RRB1[level](fused))
        prev_out = self.RRB2[level](refined)
    return self.project(prev_out, image_size)
def _forward(self, image):
    """Score each image with its target model, refine, and return probabilities.

    Image ``k`` of the batch is classified by ``self.tmodels[k]``. Pairing
    stops at the shorter of the batch and ``self.tmodels``. The per-image
    scores are concatenated along dim 0 and passed to ``self.refiner`` with
    all backbone features. The result is resampled to ``image.shape[-2:]`` and
    squashed with a sigmoid.
    """
    n_images = image.shape[0]
    features = self.feature_extractor(image)
    # The backbone layer is chosen by the first target model; presumably all
    # target models read the same layer — confirm.
    layer_ft = features[self.tmodels[0].discriminator.layer]

    per_image_scores = [
        tmodel.classify(layer_ft[idx, None])
        for idx, tmodel in zip(range(n_images), self.tmodels)
    ]

    refined = self.refiner(torch.cat(per_image_scores, dim=0), features, image.shape)
    return torch.sigmoid(interpolate(refined, image.shape[-2:]))