def compute_loss(y, gt_map):
    # 1. segmentation loss
    target_labels = torch.argmax(gt_map[:, MapOrdering.SEG_WORD:MapOrdering.SEG_BACKGROUND + 1], dim=1)
    loss_seg = F.cross_entropy(y[:, MapOrdering.SEG_WORD:MapOrdering.SEG_BACKGROUND + 1], target_labels)

    # 2. geometry loss
    # distances to all sides of the aabb
    t = torch.minimum(y[:, MapOrdering.GEO_TOP], gt_map[:, MapOrdering.GEO_TOP])
    b = torch.minimum(y[:, MapOrdering.GEO_BOTTOM], gt_map[:, MapOrdering.GEO_BOTTOM])
    l = torch.minimum(y[:, MapOrdering.GEO_LEFT], gt_map[:, MapOrdering.GEO_LEFT])
    r = torch.minimum(y[:, MapOrdering.GEO_RIGHT], gt_map[:, MapOrdering.GEO_RIGHT])

    # area of predicted aabb
    y_width = y[:, MapOrdering.GEO_LEFT] + y[:, MapOrdering.GEO_RIGHT]
    y_height = y[:, MapOrdering.GEO_TOP] + y[:, MapOrdering.GEO_BOTTOM]
    area1 = y_width * y_height

    # area of gt aabb
    gt_width = gt_map[:, MapOrdering.GEO_LEFT] + gt_map[:, MapOrdering.GEO_RIGHT]
    gt_height = gt_map[:, MapOrdering.GEO_TOP] + gt_map[:, MapOrdering.GEO_BOTTOM]
    area2 = gt_width * gt_height

    # compute intersection over union, restricted to word pixels
    intersection = (r + l) * (b + t)
    union = area1 + area2 - intersection
    eps = 0.01  # avoid division by 0
    iou = intersection / (union + eps)
    iou = iou[gt_map[:, MapOrdering.SEG_WORD] > 0]
    loss_aabb = -torch.log(torch.mean(iou))

    # total loss is simply the sum of both losses
    loss = loss_seg + loss_aabb
    return loss

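# A quick smoke test for compute_loss, assuming a MapOrdering layout like the
# one below (hypothetical indices; the real enum comes from the surrounding
# project) and maps with 3 segmentation channels followed by 4 geometry channels.
import torch
import torch.nn.functional as F

class MapOrdering:
    SEG_WORD, SEG_SURROUNDING, SEG_BACKGROUND = 0, 1, 2
    GEO_TOP, GEO_BOTTOM, GEO_LEFT, GEO_RIGHT = 3, 4, 5, 6

y = torch.rand(2, 7, 32, 32)       # raw network output
gt_map = torch.rand(2, 7, 32, 32)  # ground-truth maps with the same layout
print(compute_loss(y, gt_map))     # a finite scalar tensor
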
def small_IoU(box_i, box_j):
    # IoU of two boxes given as [x1, y1, x2, y2]
    device = box_i.device
    dtype = box_i.dtype

    pw_a = box_i[2] - box_i[0]
    ph_a = box_i[3] - box_i[1]
    area_p = pw_a * ph_a

    bw_a = box_j[2] - box_j[0]
    bh_a = box_j[3] - box_j[1]
    area_b = bw_a * bh_a

    area_u = area_p + area_b

    # overlap extents, clamped to zero for disjoint boxes
    x_val = torch.minimum(box_i[2], box_j[2]) - torch.maximum(box_i[0], box_j[0])
    x_val_zero = torch.zeros(x_val.shape, device=device, dtype=dtype)
    y_val = torch.minimum(box_i[3], box_j[3]) - torch.maximum(box_i[1], box_j[1])
    y_val_zero = torch.zeros(y_val.shape, device=device, dtype=dtype)
    area_i = torch.maximum(x_val, x_val_zero) * torch.maximum(y_val, y_val_zero)

    area_u -= area_i
    return area_i / area_u

def compute_iou(pred, gt):
    """
    Calculates IoU (Jaccard index) of two sets of bboxes:
        IoU = pred ∩ gt / (area(pred) + area(gt) - pred ∩ gt)

    Parameters:
        Coordinates of bboxes are supposed to be in the following form: [x1, y1, x2, y2]
        pred (torch.Tensor): predicted bboxes
        gt (torch.Tensor): ground truth bboxes

    Return value:
        iou (torch.Tensor): intersection over union
    """
    def get_box_area(box):
        return (box[:, 2] - box[:, 0] + 1.) * (box[:, 3] - box[:, 1] + 1.)

    # _gt = torch.tile(gt, (pred.shape[0], 1))
    _gt = gt.repeat(pred.shape[0], 1)
    _pred = torch.repeat_interleave(pred, gt.shape[0], dim=0)

    ixmin = torch.maximum(_gt[:, 0], _pred[:, 0])
    iymin = torch.maximum(_gt[:, 1], _pred[:, 1])
    ixmax = torch.minimum(_gt[:, 2], _pred[:, 2])
    iymax = torch.minimum(_gt[:, 3], _pred[:, 3])

    width = torch.maximum(ixmax - ixmin + 1., torch.tensor(0.))
    height = torch.maximum(iymax - iymin + 1., torch.tensor(0.))
    intersection_area = width * height

    union_area = get_box_area(_gt) + get_box_area(_pred) - intersection_area
    iou = (intersection_area / union_area).reshape(pred.shape[0], gt.shape[0])
    return iou

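# Hand-checked example for compute_iou: with the inclusive "+1" pixel
# convention above, identical boxes give IoU 1 and disjoint boxes give 0.
pred = torch.tensor([[0., 0., 9., 9.],
                     [20., 20., 29., 29.]])
gt = torch.tensor([[0., 0., 9., 9.]])
print(compute_iou(pred, gt))  # tensor([[1.], [0.]])
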
def find_active_constraints(x, lb, ub, rtol=1e-10):
    """Determine which constraints are active in a given point.

    The threshold is computed using `rtol` and the absolute value of the
    closest bound.

    Returns
    -------
    active : tensor of int with shape of x
        Each component shows whether the corresponding constraint is active:

            *  0 - a constraint is not active.
            * -1 - a lower bound is active.
            *  1 - an upper bound is active.
    """
    active = torch.zeros_like(x, dtype=torch.long)

    if rtol == 0:
        active[x <= lb] = -1
        active[x >= ub] = 1
        return active

    lower_dist = x - lb
    upper_dist = ub - x

    lower_threshold = rtol * lb.abs().clamp(1, None)
    upper_threshold = rtol * ub.abs().clamp(1, None)

    lower_active = (lb.isfinite() &
                    (lower_dist <= torch.minimum(upper_dist, lower_threshold)))
    active[lower_active] = -1

    upper_active = (ub.isfinite() &
                    (upper_dist <= torch.minimum(lower_dist, upper_threshold)))
    active[upper_active] = 1

    return active

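# Example: with box constraints [0, 1] per coordinate, a point sitting on a
# bound (within the relative tolerance) is flagged, interior points are not.
x = torch.tensor([0.0, 0.5, 1.0])
lb = torch.zeros(3)
ub = torch.ones(3)
print(find_active_constraints(x, lb, ub))  # tensor([-1,  0,  1])
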
def rbox_loss(pred_geometry_map: Tensor, target_geometry_map: Tensor,
              train_ignore_mask: Tensor, train_boundary_mask: Tensor,
              angle_lambda: int = 10) -> Tensor:
    pred_top, pred_right, pred_bottom, pred_left, pred_angle = torch.split(
        pred_geometry_map, split_size_or_sections=[1, 1, 1, 1, 1], dim=2)
    target_top, target_right, target_bottom, target_left, target_angle = torch.split(
        target_geometry_map, split_size_or_sections=[1, 1, 1, 1, 1], dim=2)

    pred_area = (pred_top + pred_bottom) * (pred_left + pred_right)
    target_area = (target_top + target_bottom) * (target_left + target_right)

    h_inter = torch.minimum(pred_right, target_right) + torch.minimum(
        pred_left, target_left)
    w_inter = torch.minimum(pred_top, target_top) + torch.minimum(
        pred_bottom, target_bottom)
    intersection = h_inter * w_inter
    union = pred_area + target_area - intersection

    box_loss = -torch.log((intersection + 1e-6) / (union + 1e-6))
    angle_loss = 1 - torch.cos(pred_angle - target_angle)
    rbox_loss = box_loss + angle_lambda * angle_loss

    train_mask = torch.unsqueeze(train_ignore_mask * train_boundary_mask, dim=2)
    rbox_loss = torch.mean(
        torch.sum(rbox_loss * target_geometry_map * train_mask))
    return rbox_loss

def bbox_iou(bboxes1, bboxes2):
    eps = 0.001
    EPS = 1.0e+20

    # boxes are given as (center x, center y, width, height)
    x1, y1 = bboxes1[..., 0], bboxes1[..., 1]
    w1, h1 = bboxes1[..., 2], bboxes1[..., 3]
    x2, y2 = bboxes2[..., 0], bboxes2[..., 1]
    w2, h2 = bboxes2[..., 2], bboxes2[..., 3]

    # corner coordinates
    xmin1 = x1 - (w1 * 0.5)
    xmax1 = x1 + (w1 * 0.5)
    ymin1 = y1 - (h1 * 0.5)
    ymax1 = y1 + (h1 * 0.5)
    xmin2 = x2 - (w2 * 0.5)
    xmax2 = x2 + (w2 * 0.5)
    ymin2 = y2 - (h2 * 0.5)
    ymax2 = y2 + (h2 * 0.5)

    # intersection rectangle
    xmini = torch.maximum(xmin1, xmin2)
    ymini = torch.maximum(ymin1, ymin2)
    xmaxi = torch.minimum(xmax1, xmax2)
    ymaxi = torch.minimum(ymax1, ymax2)
    wi = torch.clamp(xmaxi - xmini, min=0.0)
    hi = torch.clamp(ymaxi - ymini, min=0.0)

    # areas, clamped to avoid overflow
    area1 = torch.clamp(w1 * h1, max=EPS)
    area2 = torch.clamp(w2 * h2, max=EPS)
    areai = torch.clamp(wi * hi, max=EPS)
    areau = area1 + area2 - areai

    iou = areai / (areau + eps)
    return iou

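# Sanity check for bbox_iou: two unit squares whose centers are half a side
# apart overlap by half, giving IoU just under 1/3 (the eps in the divisor
# biases the result down slightly).
a = torch.tensor([0.5, 0.5, 1.0, 1.0])
b = torch.tensor([1.0, 0.5, 1.0, 1.0])
print(bbox_iou(a, b))  # ≈ 0.333
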
def multiplex(self, subframes: torch.Tensor):
    """
    Assuming that the number of subframes is divisible by the number of
    pixels in a neighborhood (n_pixels), bucket 1 is active in the
    following cases:
        * pixel 1 of each neighborhood in the first S / n_pixels
        * pixel 2 of each neighborhood in the second S / n_pixels
        * pixel 3 of each neighborhood in the third S / n_pixels
        ...
        * last pixel of each neighborhood in the last S / n_pixels

    subframes: intensities of subframes of a frame as a tensor with shape (S, H, W)
    nbhd_size: tuple of width and height of the neighborhood
        (for now, W must be divisible by the width of the neighborhood and
        H must be divisible by the height of the neighborhood)
    """
    # S, height, width = raw_subframes.shape
    # scheme_ = torch.tile(self.scheme, (height // self.nbhd_height, width // self.nbhd_width)).to(raw_subframes.device)  # shape: (n_pixels, height, width)
    # scheme_ = scheme_.unsqueeze(1).repeat(1, S // n_pixels, 1, 1).view(S, height, width)  # shape: (S, height, width)
    c2b_frame_bucket0 = torch.minimum(
        torch.sum(subframes * self.scheme_.unsqueeze(1), dim=1),
        self.max_intensity.to(self.device))
    c2b_frame_bucket1 = torch.minimum(
        torch.sum(subframes * (1 - self.scheme_.unsqueeze(1)), dim=1),
        self.max_intensity.to(self.device))
    if self.noise_std is not None:
        c2b_frame_bucket0 += torch.randn_like(c2b_frame_bucket0) * self.noise_std
        c2b_frame_bucket1 += torch.randn_like(c2b_frame_bucket1) * self.noise_std
    return c2b_frame_bucket0, c2b_frame_bucket1

def non_maximum_suppress(confidences, coordinates, over_thres=.4):
    """
    Params
    ------
    confidences: {Tensor: (n_posi_m,)}
    coordinates: {Tensor: (n_posi_m, x1-y1-x2-y2)}
    over_thres: float

    TODO: wrap non_maximum_suppress and channel NMS as an integral module of the neural network
    TODO: consider whether a different overlap threshold can be applied to RBC cells and platelets
    """
    x1, y1, x2, y2 = coordinates.permute(1, 0)
    areas = (x2 - x1) * (y2 - y1)
    order = confidences.argsort()
    pick = []
    while len(order):
        pick.append(idx := order[-1])
        xx1 = torch.maximum(x1[order[:-1]], x1[idx])
        yy1 = torch.maximum(y1[order[:-1]], y1[idx])
        xx2 = torch.minimum(x2[order[:-1]], x2[idx])
        yy2 = torch.minimum(y2[order[:-1]], y2[idx])
        # inter: w * h
        inter = torch.clamp(xx2 - xx1, min=0.) * torch.clamp(yy2 - yy1, min=0.)
        over = inter / (areas[idx] + areas[order[:-1]] - inter)
        order = order[:-1][over < over_thres]
    return pick

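# Greedy NMS demo: the two heavily overlapping boxes collapse to the one with
# the higher confidence, while the distant box survives.
confidences = torch.tensor([0.9, 0.8, 0.7])
coordinates = torch.tensor([[0., 0., 10., 10.],
                            [1., 1., 10., 10.],
                            [20., 20., 30., 30.]])
keep = non_maximum_suppress(confidences, coordinates, over_thres=0.4)
print([int(i) for i in keep])  # [0, 2]
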
def compute_crop_pad_image_location(
    bbox_tight: "torch.Tensor", image: "torch.Tensor"
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
    """
    Get the valid image coordinates for the context region in target or
    search region in full image.
    :param bbox_tight: Coordinates of bounding box [x1, y1, x2, y2].
    :param image: Frame to be cropped and padded.
    :return: Coordinates [x1, y1, x2, y2] of the crop-and-pad region in the
        original image.
    """
    device = image.device  # this is a free function, so take the device from the input

    # Center of the bounding box
    bbox_center_x = get_center_x_f(bbox_tight)
    bbox_center_y = get_center_y_f(bbox_tight)

    image_height = image.shape[0]
    image_width = image.shape[1]

    # Padded output width and height
    output_width = compute_output_width_f(bbox_tight)
    output_height = compute_output_height_f(bbox_tight)

    roi_left = torch.maximum(torch.tensor(0.0, device=device),
                             bbox_center_x - (output_width / 2.0))
    roi_bottom = torch.maximum(torch.tensor(0.0, device=device),
                               bbox_center_y - (output_height / 2.0))

    # New ROI width
    # -------------
    # 1. left_half should not go out of bound on the left side of the image
    # 2. right_half should not go out of bound on the right side of the image
    left_half = torch.minimum(output_width / 2.0, bbox_center_x)
    right_half = torch.minimum(output_width / 2.0, image_width - bbox_center_x)
    roi_width = torch.maximum(torch.tensor(1.0, device=device),
                              left_half + right_half)

    # New ROI height: same logic as for the ROI width
    top_half = torch.minimum(output_height / 2.0, bbox_center_y)
    bottom_half = torch.minimum(output_height / 2.0, image_height - bbox_center_y)
    roi_height = torch.maximum(torch.tensor(1.0, device=device),
                               top_half + bottom_half)

    # Padded image location in the original image
    return roi_left, roi_bottom, roi_left + roi_width, roi_bottom + roi_height

def forward(self, x):
    start, end = self._logits(x)
    _, start_idx = start.max(dim=1)
    _, end_idx = end.max(dim=1)
    # clamp the predicted span to the context length, i.e. the position of
    # the first separator token
    _, ctxs_len = (x[:, 0] == self.info.sep_token).max(dim=1)
    start_idx = torch.minimum(start_idx, ctxs_len)
    end_idx = torch.minimum(end_idx, ctxs_len)
    return start_idx, end_idx

def limiter(cr):
    # superbee flux limiter: max(0, min(1, 2r), min(2, r))
    return torch.maximum(
        torch.tensor([0.0], device=cr.device),
        torch.maximum(
            torch.minimum(torch.tensor([1.0], device=cr.device), 2 * cr),
            torch.minimum(torch.tensor([2.0], device=cr.device), cr),
        ),
    )

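# A few reference values for the superbee limiter
# phi(r) = max(0, min(1, 2r), min(2, r)):
r = torch.tensor([-1.0, 0.25, 0.5, 1.0, 4.0])
print(limiter(r))  # tensor([0.0000, 0.5000, 1.0000, 1.0000, 2.0000])
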
def layout_bbox(self, final_pred, batch_size, num_bboxes, num_classes,
                output_height, output_width):
    # e.g. (batch_size, num_bboxes, 4 + num_classes) = (5, 188, 20)
    final_pred = torch.reshape(final_pred, [batch_size, num_bboxes, 4 + num_classes])
    return self.rectangle_render(final_pred)

    # NOTE: everything below is unreachable (the function returns above); it
    # sketches a differentiable rendering of box edges from normalized
    # [x_c, y_c, w, h] regressions and is kept for reference.
    final_pred = torch.reshape(final_pred, [batch_size, 4 + num_classes, num_bboxes])
    bbox_reg = final_pred[:, :4, :]
    cls_prob = final_pred[:, 4:, :]
    bbox_reg = torch.reshape(bbox_reg, [batch_size, num_bboxes, 4])

    # denormalize box centers and sizes to pixel coordinates
    x_c = bbox_reg[:, :, 0] * output_width
    y_c = bbox_reg[:, :, 1] * output_height
    w = bbox_reg[:, :, 2] * output_width
    h = bbox_reg[:, :, 3] * output_height
    x1 = x_c - 0.5 * w
    x2 = x_c + 0.5 * w
    y1 = y_c - 0.5 * h
    y2 = y_c + 0.5 * h

    # pixel-coordinate grids (torch.range is deprecated; arange yields
    # exactly output_width / output_height samples)
    xt = torch.reshape(torch.arange(0, output_width, dtype=torch.float32),
                       [1, 1, 1, -1])
    xt = torch.reshape(torch.tile(xt, [batch_size, num_bboxes, output_height, 1]),
                       [batch_size, num_bboxes, -1])
    yt = torch.reshape(torch.arange(0, output_height, dtype=torch.float32),
                       [1, 1, -1, 1])
    yt = torch.reshape(torch.tile(yt, [batch_size, num_bboxes, 1, output_width]),
                       [batch_size, num_bboxes, -1])

    x1_diff = torch.reshape(xt - x1.unsqueeze(-1),
                            [batch_size, num_bboxes, output_height, output_width, 1])
    y1_diff = torch.reshape(yt - y1.unsqueeze(-1),
                            [batch_size, num_bboxes, output_height, output_width, 1])
    x2_diff = torch.reshape(x2.unsqueeze(-1) - xt,
                            [batch_size, num_bboxes, output_height, output_width, 1])
    y2_diff = torch.reshape(y2.unsqueeze(-1) - yt,
                            [batch_size, num_bboxes, output_height, output_width, 1])

    # left edge: a one-pixel-wide vertical line clipped to [y1, y2]
    x1_line = self.relu(1.0 - torch.abs(x1_diff)) \
        * torch.minimum(self.relu(y1_diff), torch.ones_like(y1_diff)) \
        * torch.minimum(self.relu(y2_diff), torch.ones_like(y2_diff))

def hinge_adv_loss(real_fake_logits_real, real_fake_logits_fake, device):
    """
    The hinge version of the adversarial loss.
    :param real_fake_logits_real: ``Tensor([1, 5, 5])``
    :param real_fake_logits_fake: ``Tensor([1, 5, 5])``
    :param device: torch device
    :return: ``float``, discriminator loss
    """
    threshold = torch.Tensor([0.0]).to(device)
    real_loss = -1 * torch.mean(torch.minimum(threshold, -1 + real_fake_logits_real))
    fake_loss = -1 * torch.mean(torch.minimum(threshold, -1 - real_fake_logits_fake))
    return real_loss + fake_loss

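# The two terms implement the standard hinge discriminator objective
# L_D = E[max(0, 1 - D(real))] + E[max(0, 1 + D(fake))], written with
# torch.minimum via -min(0, -1 + d) == max(0, 1 - d). Worked example:
d_real = torch.tensor([[2.0, 0.5]])   # one confident, one weak real logit
d_fake = torch.tensor([[-2.0, 0.5]])  # one confident, one fooled fake logit
print(hinge_adv_loss(d_real, d_fake, torch.device("cpu")))
# mean([0, 0.5]) + mean([0, 1.5]) = 0.25 + 0.75 = tensor(1.)
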
def bbox_overlaps_ciou(bboxes1, bboxes2):
    bboxes1 = convert_box(bboxes1)
    bboxes2 = convert_box(bboxes2)
    rows = bboxes1.shape[0]
    cols = bboxes2.shape[0]
    cious = torch.zeros((rows, cols))
    if rows * cols == 0:
        return cious
    exchange = False
    if bboxes1.shape[0] > bboxes2.shape[0]:
        bboxes1, bboxes2 = bboxes2, bboxes1
        cious = torch.zeros((cols, rows))
        exchange = True

    w1 = bboxes1[:, 2] - bboxes1[:, 0]
    h1 = bboxes1[:, 3] - bboxes1[:, 1]
    w2 = bboxes2[:, 2] - bboxes2[:, 0]
    h2 = bboxes2[:, 3] - bboxes2[:, 1]
    area1 = w1 * h1
    area2 = w2 * h2

    center_x1 = (bboxes1[:, 2] + bboxes1[:, 0]) / 2
    center_y1 = (bboxes1[:, 3] + bboxes1[:, 1]) / 2
    center_x2 = (bboxes2[:, 2] + bboxes2[:, 0]) / 2
    center_y2 = (bboxes2[:, 3] + bboxes2[:, 1]) / 2

    inter_max_xy = torch.minimum(bboxes1[:, 2:], bboxes2[:, 2:])
    inter_min_xy = torch.maximum(bboxes1[:, :2], bboxes2[:, :2])
    out_max_xy = torch.maximum(bboxes1[:, 2:], bboxes2[:, 2:])
    out_min_xy = torch.minimum(bboxes1[:, :2], bboxes2[:, :2])

    inter = torch.clamp((inter_max_xy - inter_min_xy), min=0)
    inter_area = inter[:, 0] * inter[:, 1]
    inter_diag = (center_x2 - center_x1)**2 + (center_y2 - center_y1)**2
    outer = torch.clamp((out_max_xy - out_min_xy), min=0)
    outer_diag = (outer[:, 0]**2) + (outer[:, 1]**2)
    union = area1 + area2 - inter_area
    u = inter_diag / outer_diag
    iou = inter_area / union

    with torch.no_grad():
        arctan = torch.atan(w2 / h2) - torch.atan(w1 / h1)
        v = (4 / (math.pi**2)) * torch.pow(
            (torch.atan(w2 / h2) - torch.atan(w1 / h1)), 2)
        S = 1 - iou
        alpha = v / (S + v)
        w_temp = 2 * w1
    ar = (8 / (math.pi**2)) * arctan * ((w1 - w_temp) * h1)
    cious = iou - (u + alpha * ar)
    cious = torch.clamp(cious, min=-1.0, max=1.0)
    if exchange:
        cious = cious.T
    return cious

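# Tiny check of bbox_overlaps_ciou. convert_box is assumed to map inputs to
# corner format [x1, y1, x2, y2]; for illustration we stub it as the identity.
import math
convert_box = lambda boxes: boxes
a = torch.tensor([[0., 0., 2., 2.]])
b = torch.tensor([[1., 1., 3., 3.]])
print(bbox_overlaps_ciou(a, b))  # IoU 1/7 minus center-distance term 1/9 ≈ 0.0317
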
def forward(self, discriminator_prediction_real: torch.Tensor,
            discriminator_prediction_fake: torch.Tensor, **kwargs) -> torch.Tensor:
    """
    Forward pass.
    :param discriminator_prediction_real: (torch.Tensor) Raw discriminator prediction for real samples
    :param discriminator_prediction_fake: (torch.Tensor) Raw discriminator predictions for fake samples
    :return: (torch.Tensor) Hinge discriminator GAN loss
    """
    return - torch.minimum(torch.tensor(0., dtype=torch.float,
                                        device=discriminator_prediction_real.device),
                           discriminator_prediction_real - 1.).mean() \
           - torch.minimum(torch.tensor(0., dtype=torch.float,
                                        device=discriminator_prediction_fake.device),
                           - discriminator_prediction_fake - 1.).mean()

def forward(ctx, X, rank: int = 100):
    U, S, V = torch.svd(X, compute_uv=True, some=False)
    S = torch.diag(S[0:(rank - 1)])
    U = torch.matmul(U[:, 0:(rank - 1)], S)
    V = torch.transpose(V, 0, 1)[0:(rank - 1), :]
    x, y = X.shape

    Unew = U[:, 0]
    Vnew = V[0, :]
    __U = torch.where(torch.less(torch.min(V[0, :]), torch.min(-V[0, :])),
                      -(Unew.view(x, 1)), Unew.view(x, 1))
    __V = torch.where(torch.less(torch.min(V[0, :]), torch.min(-V[0, :])),
                      -(Vnew.view(1, y)), Vnew.view(1, y))
    if rank > 2:
        for i in range(1, rank - 1):
            Unew = Unew.view(x, 1)
            Vnew = Vnew.view(1, y)
            __U = torch.where(
                torch.less(torch.min(V[0, :]), torch.min(-V[0, :])),
                torch.cat((__U, -Unew), dim=1), torch.cat((__U, Unew), dim=1))
            __V = torch.where(
                torch.less(torch.min(V[0, :]), torch.min(-V[0, :])),
                torch.cat((__V, -Vnew), dim=0), torch.cat((__V, Vnew), dim=0))

    if rank == 2:
        A = torch.cat((U, -U), dim=1)
    else:
        Un = torch.transpose(-(torch.sum(U, dim=1)), 0, -1).view(x, 1)
        A = torch.cat((U, Un), dim=1)
    B = torch.cat((V, torch.zeros((1, y))), dim=0)

    if rank >= 3:
        b, _ = torch.min(V, dim=0)
        B = torch.subtract(B, torch.minimum(torch.tensor(0.), b))
    else:
        B = torch.subtract(B, torch.minimum(torch.tensor(0.), V))

    x = torch.tensor(x)
    y = torch.tensor(y)
    normalize = torch.sqrt(torch.multiply(x, y).type(torch.FloatTensor))
    norm = torch.norm(A)
    return (torch.multiply(torch.div(A, norm), normalize),
            torch.div(torch.multiply(B, norm), normalize))

def _train(self, BATCH):
    q1 = self.critic(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, A]
    q2 = self.critic2(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, A]
    q1_eval = (q1 * BATCH.action).sum(-1, keepdim=True)  # [T, B, 1]
    q2_eval = (q2 * BATCH.action).sum(-1, keepdim=True)  # [T, B, 1]

    q1_log_probs = (q1 / (self.alpha + th.finfo().eps)).log_softmax(-1)  # [T, B, A]
    q1_entropy = -(q1_log_probs.exp() * q1_log_probs).sum(-1, keepdim=True).mean()  # 1

    q1_target = self.critic.t(BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, A]
    q2_target = self.critic2.t(BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, A]
    q1_target_max = q1_target.max(-1, keepdim=True)[0]  # [T, B, 1]
    q1_target_log_probs = (q1_target / (self.alpha + th.finfo().eps)).log_softmax(-1)  # [T, B, A]
    q1_target_entropy = -(q1_target_log_probs.exp() * q1_target_log_probs).sum(-1, keepdim=True)  # [T, B, 1]

    q2_target_max = q2_target.max(-1, keepdim=True)[0]  # [T, B, 1]
    # q2_target_log_probs = q2_target.log_softmax(-1)
    # q2_target_log_max = q2_target_log_probs.max(1, keepdim=True)[0]

    q_target = th.minimum(q1_target_max, q2_target_max) + self.alpha * q1_target_entropy  # [T, B, 1]
    dc_r = n_step_return(BATCH.reward, self.gamma, BATCH.done, q_target,
                         BATCH.begin_mask).detach()  # [T, B, 1]
    td_error1 = q1_eval - dc_r  # [T, B, 1]
    td_error2 = q2_eval - dc_r  # [T, B, 1]
    q1_loss = (td_error1.square() * BATCH.get('isw', 1.0)).mean()  # 1
    q2_loss = (td_error2.square() * BATCH.get('isw', 1.0)).mean()  # 1
    loss = 0.5 * (q1_loss + q2_loss)
    self.critic_oplr.optimize(loss)

    summaries = {
        'LEARNING_RATE/critic_lr': self.critic_oplr.lr,
        'LOSS/loss': loss,
        'Statistics/log_alpha': self.log_alpha,
        'Statistics/alpha': self.alpha,
        'Statistics/q1_entropy': q1_entropy,
        'Statistics/q_min': th.minimum(q1, q2).mean(),
        'Statistics/q_mean': q1.mean(),
        'Statistics/q_max': th.maximum(q1, q2).mean()
    }
    if self.auto_adaption:
        alpha_loss = -(self.alpha * (self.target_entropy - q1_entropy).detach()).mean()
        self.alpha_oplr.optimize(alpha_loss)
        summaries.update({
            'LOSS/alpha_loss': alpha_loss,
            'LEARNING_RATE/alpha_lr': self.alpha_oplr.lr
        })
    return (td_error1 + td_error2) / 2, summaries

def calculate_iou(best_pred, preds, areas):
    # Remove bboxes with IoU >= thresh; columns 1..4 hold the box corners
    # [x1, y1, x2, y2]
    max_min_x = torch.maximum(best_pred[1], preds[1:, 1])
    max_min_y = torch.maximum(best_pred[2], preds[1:, 2])
    min_max_x = torch.minimum(best_pred[3], preds[1:, 3])
    min_max_y = torch.minimum(best_pred[4], preds[1:, 4])
    # clamp to zero so disjoint boxes don't yield a spurious positive area
    intersection_x = torch.clamp(min_max_x - max_min_x, min=0.)
    intersection_y = torch.clamp(min_max_y - max_min_y, min=0.)
    intersection_area = intersection_x * intersection_y
    iou = intersection_area / (areas[0] + areas[1:] - intersection_area)
    return iou

def reduce_relu(self, nodes):
    w = torch.exp(self.w)
    R = torch.clamp(self.R, 0.000001, 0.999999)
    msg = w * nodes.mailbox['m'] + self.b
    fsum = torch.sum(torch.maximum(msg, R * msg), dim=1)
    out_h = (torch.minimum(fsum, fsum / R) - self.b) / w
    return {'sum_sigma_h': out_h}

def comparison_ops(self):
    a = torch.randn(4)
    b = torch.randn(4)
    return (
        torch.allclose(a, b),
        torch.argsort(a),
        torch.eq(a, b),
        torch.equal(a, b),
        torch.ge(a, b),
        torch.greater_equal(a, b),
        torch.gt(a, b),
        torch.greater(a, b),
        torch.isclose(a, b),
        torch.isfinite(a),
        torch.isin(a, b),
        torch.isinf(a),
        torch.isposinf(a),
        torch.isneginf(a),
        torch.isnan(a),
        torch.isreal(a),
        torch.kthvalue(a, 1),
        torch.le(a, b),
        torch.less_equal(a, b),
        torch.lt(a, b),
        torch.less(a, b),
        torch.maximum(a, b),
        torch.minimum(a, b),
        torch.fmax(a, b),
        torch.fmin(a, b),
        torch.ne(a, b),
        torch.not_equal(a, b),
        torch.sort(a),
        torch.topk(a, 1),
        torch.msort(a),
    )

@classmethod
def orthogonalized_raised_cosines(cls, dt, last_time_peak, n, b, a=1e0, weight=None):
    # peak locations are spaced uniformly in log time
    range_locs = torch.log(torch.tensor([0, last_time_peak]) + b)
    delta = (range_locs[1] - range_locs[0]) / (n - 1)
    locs = torch.linspace(range_locs[0], range_locs[1], n)

    last_time = torch.exp(range_locs[1] + 2 * delta / a) - b
    t = torch.arange(0, last_time, dt)
    support = torch.tensor([t[0], t[-1] + dt])

    pi_torch = torch.tensor([pi])
    raised_cosines = torch.minimum(
        a * (torch.log(t[:, None] + b) - locs[None, :]) * pi / delta / 2, pi_torch)
    raised_cosines = (1 + torch.cos(torch.maximum(-pi_torch, raised_cosines))) / 2
    raised_cosines = raised_cosines / torch.sqrt(torch.sum(raised_cosines**2, 0))

    # orthogonalize the basis via SVD
    u, s, v = torch.linalg.svd(raised_cosines)
    basis = u[:, :n]
    return cls(basis=basis, support=support, weight=weight)

def _get_slate_size(self, state: rlt.FeatureData) -> torch.Tensor:
    """Get the actual size of each slate (ignoring padded items) by summing the item mask."""
    mask = self._get_item_mask(state)
    return torch.minimum(
        mask.sum(1, keepdim=True),
        torch.tensor([self.slate_size], device=mask.device),
    )

@torch.no_grad()  # parameter updates must not be tracked by autograd
def step(self, closure=None):
    """Performs a single optimization step.

    Args:
        closure (callable, optional): A closure that reevaluates the model and returns the loss.
    """
    loss = None
    if closure is not None:
        with torch.enable_grad():
            loss = closure()

    device = self.param_groups[0]['params'][0].device
    one_tensor = torch.tensor(1.0, device=device)  # because torch.where doesn't handle scalars correctly

    for group in self.param_groups:
        weight_decay = group['weight_decay']
        momentum = group['momentum']
        dampening = group['dampening']
        nesterov = group['nesterov']
        trust_coeff = group['trust_coeff']
        eps = group['eps']

        for p in group['params']:
            if p.grad is None:
                continue
            grad = p.grad

            # apply LARS LR adaptation, LARC clipping, weight decay
            # ref: https://github.com/NVIDIA/apex/blob/master/apex/parallel/LARC.py
            if weight_decay != 0 or group['always_adapt']:
                w_norm = p.norm(2.0)
                g_norm = grad.norm(2.0)
                trust_ratio = trust_coeff * w_norm / (g_norm + w_norm * weight_decay + eps)
                # FIXME nested where required since logical and/or not working in PT XLA
                trust_ratio = torch.where(
                    w_norm > 0,
                    torch.where(g_norm > 0, trust_ratio, one_tensor),
                    one_tensor,
                )
                if group['trust_clip']:
                    trust_ratio = torch.minimum(trust_ratio / group['lr'], one_tensor)
                # in-place add_: a plain add() would silently discard the weight-decay term
                grad.add_(p, alpha=weight_decay)
                grad.mul_(trust_ratio)

            # apply SGD update https://github.com/pytorch/pytorch/blob/1.7/torch/optim/sgd.py#L100
            if momentum != 0:
                param_state = self.state[p]
                if 'momentum_buffer' not in param_state:
                    buf = param_state['momentum_buffer'] = torch.clone(grad).detach()
                else:
                    buf = param_state['momentum_buffer']
                    buf.mul_(momentum).add_(grad, alpha=1. - dampening)
                if nesterov:
                    grad = grad.add(buf, alpha=momentum)
                else:
                    grad = buf

            p.add_(grad, alpha=-group['lr'])

    return loss

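# Hand-computing the LARS trust ratio from the branch above:
# ratio = trust_coeff * ||w|| / (||g|| + ||w|| * weight_decay + eps),
# i.e. the learning rate is rescaled per parameter by the ratio of weight
# norm to gradient norm (illustrative numbers only).
w = torch.full((10,), 2.0)
g = torch.full((10,), 0.1)
trust_coeff, weight_decay, eps = 0.001, 1e-4, 1e-8
ratio = trust_coeff * w.norm() / (g.norm() + w.norm() * weight_decay + eps)
print(ratio)  # ≈ 0.0200
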
def boundaries_detection(self, inputs):
    y_fwd, y_bwd, seq_len_y, *_ = self.forward(inputs)
    seq_mask = compute_mask(y_fwd, seq_len_y, batch_axis=0, sequence_axis=-1)
    return torch.minimum(y_fwd * seq_mask, y_bwd * seq_mask), seq_len_y

def cdf_tail(self, z: torch.Tensor, left_tail: bool = True) -> torch.Tensor:
    r"""
    Computes the quantile level alpha_tilde such that

        alpha_tilde = q^{-1}(z)        if z is in the tail region
                    = qk_x_l or qk_x_r if z is in the non-tail region

    Parameters
    ----------
    z
        Observation, shape = (*batch_shape,)
    left_tail
        If True, compute alpha_tilde for the left tail;
        otherwise, compute alpha_tilde for the right tail.

    Returns
    -------
    alpha_tilde
        Corresponding quantile level, shape = (*batch_shape,)
    """
    if left_tail:
        tail_a, tail_b, qk_x = self.tail_al, self.tail_bl, self.qk_x_l
    else:
        tail_a, tail_b, qk_x = self.tail_ar, self.tail_br, 1 - self.qk_x_r
    log_alpha_tilde = torch.minimum((z - tail_b) / tail_a, torch.log(qk_x))
    alpha_tilde = torch.exp(log_alpha_tilde)
    return alpha_tilde if left_tail else 1 - alpha_tilde

def quantile_spline(
    self,
    alpha: torch.Tensor,
    dim: Optional[int] = None,
) -> torch.Tensor:
    # Refer to the description in quantile_internal
    qk_y = self.qk_y
    sk_x, delta_sk_x, delta_sk_y = (
        self.sk_x,
        self.delta_sk_x,
        self.delta_sk_y,
    )

    if dim is not None:
        qk_y = qk_y.unsqueeze(dim=0 if dim == 0 else -1)
        sk_x = sk_x.unsqueeze(dim=dim)
        delta_sk_x = delta_sk_x.unsqueeze(dim=dim)
        delta_sk_y = delta_sk_y.unsqueeze(dim=dim)

    if dim is None or dim == 0:
        alpha = alpha.unsqueeze(dim=-1)

    alpha = alpha.unsqueeze(dim=-1)

    spline_val = (alpha - sk_x) / delta_sk_x
    spline_val = torch.maximum(
        torch.minimum(spline_val, torch.ones_like(spline_val)),
        torch.zeros_like(spline_val),
    )

    return qk_y + torch.sum(spline_val * delta_sk_y, dim=-1)

def streaming_perception(self, d, cd=None, ece=None, ce=None, x=None, rgb=None, **args):
    if self.training:
        w_pow_d, w_pow_e = self.prepare_weights()
    else:
        w_pow_d, w_pow_e = self.weights
    if cd is None:
        cd = (d > 0).float()

    # first pass through the inner model
    outs1 = self.inner_model.streaming_perception(d, cd, ece, ce, x, rgb, **args)

    # re-weight the depth confidence by how much the first pass shrank each depth value
    cd = cd * torch.pow(
        torch.minimum(outs1['d'], d) / (d + self.eps), w_pow_d)
    e1 = outs1['e']
    ce1 = torch.pow(outs1['ce'], w_pow_e)

    # second pass with the refined confidences
    return self.inner_model.streaming_perception(d, cd, e1 * ce1, ce1, x, rgb, **args)

def _policy_update(self, advantage):
    for epoch in range(self._iteration_op_policy):
        loss = torch.zeros(1)
        len_data = 0
        for n, traj in enumerate(self._data_buffer):
            len_data += len(traj)
            # per-batch normalization
            advantage_ = (advantage[n, :] - torch.mean(advantage[n, :])) / torch.std(advantage[n, :])
            for t, experience in enumerate(traj):
                log_p = self._policy_network.log_prob(
                    Tensor(experience.observation), Tensor(experience.action))
                log_p_old = self._policy_old.log_prob(
                    Tensor(experience.observation), Tensor(experience.action)).detach()
                ratio = torch.exp(log_p - log_p_old)
                tmp = torch.minimum(
                    ratio,
                    torch.clip(ratio, max=1 + self._eps_policy_clip,
                               min=1 - self._eps_policy_clip))
                loss += -tmp * advantage_[t]
        loss /= len_data
        print(f"policy loss {epoch + 1}/{self._iteration_op_policy}: {len_data * loss.detach()}")
        self._optimizer_policy.zero_grad()
        loss.backward()
        self._optimizer_policy.step()

    # copy old policy
    self._policy_old.load_state_dict(self._policy_network.state_dict())

def sample_cdf(z_vals, weights, det=False):
    # bin centers between consecutive depth values
    bins = 0.5 * (z_vals[..., 1:] + z_vals[..., :-1])

    # build a CDF over the bins
    weights = weights + 1e-5
    pdf = weights / torch.sum(weights, -1, keepdim=True)
    cdf = torch.cumsum(pdf, -1)
    cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)

    if det:
        u = torch.linspace(0., 1., config.N_samples)
        u = torch.broadcast_to(u, list(cdf.shape[:-1]) + [config.N_samples])
    else:
        u = torch.rand(cdf.shape[0], config.N_samples)
    u = u.to(cdf.device)

    # invert the CDF
    idxs = torch.searchsorted(cdf, u, right=True)
    below = torch.maximum(torch.zeros_like(idxs), idxs - 1).long()
    above = torch.minimum(torch.ones_like(idxs) * (cdf.shape[-1] - 1), idxs).long()
    # idxs_g = torch.stack([below, above], -1)

    cdf_below = torch.gather(cdf, dim=1, index=below)
    cdf_above = torch.gather(cdf, dim=1, index=above)
    # bins has one entry fewer than cdf, so clamp the indices before gathering
    bin_below = torch.gather(bins, dim=1, index=torch.clamp(below, max=bins.shape[-1] - 1))
    bin_above = torch.gather(bins, dim=1, index=torch.clamp(above, max=bins.shape[-1] - 1))

    denom = cdf_above - cdf_below
    denom = torch.clamp(denom, 1e-5, 99999)
    t = (u - cdf_below) / denom
    samples = bin_below + t * (bin_above - bin_below)
    return samples

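# Inverse-transform sampling demo. `config.N_samples` is assumed to come from
# the surrounding project, so we stub it here. With all the probability mass
# on the middle bin, the deterministic samples (apart from the u = 0 endpoint)
# interpolate between the neighboring bin centers 1.5 and 2.5.
class config:
    N_samples = 8

z_vals = torch.tensor([[0.0, 1.0, 2.0, 3.0]])
weights = torch.tensor([[0.0, 1.0, 0.0]])  # all mass in the middle bin
print(sample_cdf(z_vals, weights, det=True))
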
def forward(self, x, *args, **kwargs):
    b, device = x.shape[0], x.device
    D = np.cumprod(x.shape[1:])[-1]
    shape = x.shape
    t = torch.randint(1, self.T - self.tao, (b,), device=device).long()
    # print(t.device, self.sqrt_cumprod_alpha.device, self.betas.device)

    alpha_n = extract(self.sqrt_cumprod_alpha, t, shape)
    beta_np1 = 1. - (extract(self.sqrt_cumprod_alpha, t + self.tao, shape) / alpha_n)**2
    delta_n = torch.sqrt(1. - alpha_n**2)

    noise = torch.randn_like(x)
    xn = alpha_n * x + delta_n * noise

    denoise_par = self.get_denoise_par(alpha_n, *args, **kwargs)
    epsilon_theta = self.denoise_fn(xn, *denoise_par)

    beta_n = torch.minimum(delta_n**2, beta_np1) * reshape(
        self.sigma_phi(xn), delta_n.shape)
    Cn = 0.25 * torch.log(delta_n**2 / beta_n) + 0.5 * (beta_n / delta_n**2 - 1.)
    # print(delta_n.shape, beta_n.shape, noise.shape, epsilon_theta.shape, beta_np1.shape, self.sigma_phi(xn).shape)
    term1 = 0.5 / (delta_n**2 - beta_n) * (
        delta_n * noise - beta_n / delta_n * epsilon_theta)**2
    # print(Cn.shape, term1.shape)

    loss1 = torch.mean(Cn.squeeze())
    # loss2 = torch.mean(torch.sum(term1, dim=[-i for i in range(1, len(shape))]))
    loss2 = torch.mean(term1)
    # return torch.mean(Cn.squeeze() + torch.sum(term1, dim=[-i for i in range(1, len(shape))]))
    return loss1, loss2