def compute_scale(
    self,
    F,
    data: Tensor,
    observed_indicator: Tensor,  # shapes (N, T, C)
) -> Tensor:
    """
    Compute a per-item, per-dimension scale as the mean absolute value of
    the observed entries along axis 1 (the time axis for (N, T, C) inputs).

    Parameters
    ----------
    F
        Namespace of tensor operations; presumably MXNet's ``mx.sym`` or
        ``mx.nd`` module (TODO confirm against the caller).
    data
        Tensor of shape (N, T, C) holding the values to be scaled.
    observed_indicator
        Tensor of shape (N, T, C), used as an observation mask: it is
        multiplied into ``|data|`` and summed to count observations
        (presumably 1 = observed, 0 = missing).

    Returns
    -------
    Tensor
        Scale of shape (N, C). Items whose observed absolute sum is zero
        (nothing observed, or only zeros observed) fall back to the
        batch-wide mean absolute value of the same dimension. The result
        is floored at ``self.scale_min``.
    """
    # these will have shape (N, C)
    num_observed = F.sum(observed_indicator, axis=1)
    sum_observed = (data.abs() * observed_indicator).sum(axis=1)

    # first compute a global scale per-dimension, shape (C,);
    # clamping the count to 1 avoids a division by zero when nothing in
    # the whole batch is observed
    total_observed = num_observed.sum(axis=0)
    denominator = F.maximum(total_observed, 1.0)
    default_scale = sum_observed.sum(axis=0) / denominator

    # then compute a per-item, per-dimension scale, shape (N, C)
    denominator = F.maximum(num_observed, 1.0)
    scale = sum_observed / denominator

    # use the batch-wide scale when no element is observed
    # or when the sequence contains only zeros
    cond = F.broadcast_greater(sum_observed, F.zeros_like(sum_observed))
    scale = F.where(
        cond,
        scale,
        F.broadcast_mul(default_scale, F.ones_like(num_observed)),
    )

    # floor the scale so callers never divide by (near) zero
    return F.maximum(scale, self.scale_min)
def compute_scale(
    self,
    F,
    data: Tensor,
    observed_indicator: Tensor,  # shapes (N, T, C) or (N, C, T)
) -> Tensor:
    """
    Compute a per-item, per-dimension scale as the mean absolute value of
    the observed entries along ``self.axis``.

    Parameters
    ----------
    F
        A module that can either refer to the Symbol API or the NDArray
        API in MXNet.
    data
        tensor containing the data to be scaled.
    observed_indicator
        binary tensor with the same shape as ``data``, that has 1 in
        correspondence of observed data points, and 0 in correspondence
        of missing data points.

    Returns
    -------
    Tensor
        shape (N, C), computed according to the average absolute value
        over time (``self.axis``) of the observed values. Items whose
        observed absolute sum is zero (nothing observed, or only zeros)
        fall back to ``self.default_scale`` when it is set, otherwise to
        the batch-wide average absolute value of the same dimension. The
        result is floored at ``self.minimum_scale``.
    """
    # these will have shape (N, C)
    num_observed = F.sum(observed_indicator, axis=self.axis)
    sum_observed = (data.abs() * observed_indicator).sum(axis=self.axis)

    # first compute the fallback scale for items without usable observations
    if self.default_scale is not None:
        # user-supplied constant, shaped like num_observed: (N, C)
        default_scale = self.default_scale * F.ones_like(num_observed)
    else:
        # global scale per-dimension, shape (C,). The batch-wide count is
        # only needed on this branch; clamping it to 1 avoids a division
        # by zero when nothing in the whole batch is observed.
        total_observed = num_observed.sum(axis=0)
        denominator = F.maximum(total_observed, 1.0)
        default_scale = sum_observed.sum(axis=0) / denominator

    # then compute a per-item, per-dimension scale, shape (N, C)
    denominator = F.maximum(num_observed, 1.0)
    scale = sum_observed / denominator

    # use the fallback scale when no element is observed
    # or when the sequence contains only zeros
    cond = F.broadcast_greater(sum_observed, F.zeros_like(sum_observed))
    scale = F.where(
        cond,
        scale,
        F.broadcast_mul(default_scale, F.ones_like(num_observed)),
    )

    # floor the scale so callers never divide by (near) zero
    return F.maximum(scale, self.minimum_scale)