def columnwise_clamp( X: Tensor, lower: Optional[Union[float, Tensor]] = None, upper: Optional[Union[float, Tensor]] = None, ) -> Tensor: r"""Clamp values of a Tensor in column-wise fashion (with support for t-batches). This function is useful in conjunction with optimizers from the torch.optim package, which don't natively handle constraints. If you apply this after a gradient step you can be fancy and call it "projected gradient descent". Args: X: The `b x n x d` input tensor. If 2-dimensional, `b` is assumed to be 1. lower: The column-wise lower bounds. If scalar, apply bound to all columns. upper: The column-wise upper bounds. If scalar, apply bound to all columns. Returns: The clamped tensor. """ min_bounds = _expand_bounds(lower, X) max_bounds = _expand_bounds(upper, X) if min_bounds is not None and max_bounds is not None: if torch.any(min_bounds > max_bounds): raise ValueError("Minimum values must be <= maximum values") Xout = X if min_bounds is not None: Xout = Xout.max(min_bounds) if max_bounds is not None: Xout = Xout.min(max_bounds) return Xout
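# --- Added usage sketch (assumes the `_expand_bounds` helper used by
# `columnwise_clamp` above is in scope). A minimal projected-gradient step:
# take an optimizer step, then project the iterate back into the box
# [lower, upper]. Names below are illustrative only. ---
import torch

X = torch.rand(4, 3, 2, requires_grad=True)   # b x n x d candidates
optimizer = torch.optim.SGD([X], lr=0.1)

loss = (X ** 2).sum()
loss.backward()
optimizer.step()
with torch.no_grad():
    X.copy_(columnwise_clamp(X, lower=0.0, upper=1.0))  # projection step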
def _step(
    self,
    positive_predicates: List[str],
    positive_observations: torch.FloatTensor,
    negative_predicates_list: List[List[str]],
    negative_observations_list: List[torch.FloatTensor],
) -> Tuple[torch.Tensor, Dict[str, Any]]:
    """
    Performs a forward pass of the model during training.

    Used in both training_step and validation_step, this function accepts a set
    of predicate names and performs a forward pass of the hypothesis generation
    training routine. This involves generating negative samples for each
    positive example and evaluating metrics that quantify the difference
    between the two.

    Args:
      positive_predicates: List of string names associated with the positive batch.
      positive_observations: Packed tensor containing info corresponding to
        positive_predicates. Shape: <seq_len> X <batch_size> X <dim>
      negative_predicates_list: List of predicate lists.
        negative_predicates_list[i] corresponds to the i'th negative sample
        batch. Each batch should be the same size as the positive batch.
      negative_observations_list: List of packed tensors.
        negative_observations_list[i] corresponds to the i'th negative sample.

    Returns:
      The first element is the loss tensor, used for backpropagation. The
      second element is a dict containing all extra metrics.
    """
    # Do positive checks
    assert isinstance(positive_observations, torch.Tensor), \
        f"Err: positive_observations is {type(positive_observations)}"
    assert len(positive_observations.shape) == 3
    _, actual_batch_size, actual_dim = positive_observations.shape
    assert len(positive_predicates) == actual_batch_size
    assert self.hparams.positives_per_batch == actual_batch_size
    assert self.hparams.dim == actual_dim

    # Do negative checks
    assert len(negative_predicates_list) == len(negative_observations_list)
    for n_preds, n_obs in zip(negative_predicates_list,
                              negative_observations_list):
        assert isinstance(n_obs, torch.Tensor)
        assert len(n_obs.shape) == 3
        _, actual_batch_size, actual_dim = n_obs.shape
        assert len(n_preds) == actual_batch_size
        assert self.hparams.positives_per_batch == actual_batch_size
        assert self.hparams.dim == actual_dim

    positive_predictions = self.forward(positive_observations)
    # We cannot tolerate an error on a positive sample.
    # An error occurs if any positive prediction is _not_ finite.
    # Note that `~` is bitwise "not" for our boolean matrix.
    if torch.any(~torch.isfinite(positive_predictions.detach().cpu())):
        print(positive_predicates)
        raise ValueError("Invalid positive sample")

    partial_losses = []
    for negative_predicates, negative_observations in zip(
            negative_predicates_list, negative_observations_list):
        negative_predictions = self.forward(negative_observations)
        # We CAN tolerate an error on a negative sample.
        if torch.any(~torch.isfinite(negative_predictions.detach().cpu())):
            # Print debug info.
            print("ERROR: Encountered an issue with a negative predicate:")
            print("Negative Predicate Scores:")
            print(negative_predictions)
            print("Negative Predicates")
            print(negative_predicates)
        else:
            partial_losses.append(
                self.loss_fn(
                    positive_predictions,
                    negative_predictions,
                    positive_predictions.new_ones(len(positive_predictions))))

    assert len(partial_losses) > 0, "Failure occurred on all negative batches."

    # End of batch
    loss = torch.mean(torch.stack(partial_losses))
    return (
        loss,
        dict(
            # pbar metrics
        ))
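# --- Added sketch of the positive/negative ranking loss used in `_step` above.
# The call `loss_fn(pos, neg, pos.new_ones(len(pos)))` matches the signature of
# torch.nn.MarginRankingLoss; assuming that is the intended loss, a target of
# +1 asks every positive score to beat the paired negative by `margin`. ---
import torch

loss_fn = torch.nn.MarginRankingLoss(margin=0.1)
pos = torch.tensor([0.9, 0.7, 0.8])   # scores of the positive batch
neg = torch.tensor([0.2, 0.6, 0.9])   # scores of one negative batch
loss = loss_fn(pos, neg, pos.new_ones(len(pos)))
print(loss)  # penalizes pairs where pos does not exceed neg by the margin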
def train_epoch(train_loader: torch.utils.data.DataLoader,
                base_model: torch.nn.Module,
                classification_layer: torch.nn.Module,
                forg_layer: torch.nn.Module,
                epoch: int,
                optimizer: torch.optim.Optimizer,
                lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
                callback: Optional[VisdomLogger],
                device: torch.device,
                args: Any):
    """ Trains the network for one epoch

        Parameters
        ----------
        train_loader: torch.utils.data.DataLoader
            Iterable that loads the training set (x, y) tuples
        base_model: torch.nn.Module
            The model architecture that extracts features from signatures
        classification_layer: torch.nn.Module
            The classification layer (from features to predictions of which
            user wrote the signature)
        forg_layer: torch.nn.Module
            The forgery prediction layer (from features to predictions of
            whether the signature is a forgery). Only used if args.forg is True
        epoch: int
            The current epoch (used for reporting)
        optimizer: torch.optim.Optimizer
            The optimizer (already initialized)
        lr_scheduler: torch.optim.lr_scheduler._LRScheduler
            The learning rate scheduler
        callback: VisdomLogger (optional)
            A callback to report the training progress
        device: torch.device
            The device (CPU or GPU) to use for training
        args: Namespace
            Extra arguments used for training:
            args.forg: bool
                Whether forgeries are being used for training
            args.lamb: float
                The weight used for the forgery loss (training with forgeries only)

        Returns
        -------
        None
    """
    step = 0
    n_steps = len(train_loader)
    for batch in train_loader:
        x, y = batch[0], batch[1]
        x = torch.tensor(x, dtype=torch.float).to(device)
        y = torch.tensor(y, dtype=torch.long).to(device)
        yforg = torch.tensor(batch[2], dtype=torch.float).to(device)

        # Forward propagation
        features = base_model(x)

        if args.forg:
            if args.loss_type == 'L1':
                # Eq (3) in https://arxiv.org/abs/1705.05787
                logits = classification_layer(features)
                class_loss = F.cross_entropy(logits, y)

                forg_logits = forg_layer(features).squeeze()
                forg_loss = F.binary_cross_entropy_with_logits(forg_logits, yforg)

                loss = (1 - args.lamb) * class_loss
                loss += args.lamb * forg_loss
            else:
                # Eq (4) in https://arxiv.org/abs/1705.05787
                if torch.any(yforg == 0):
                    logits = classification_layer(features[yforg == 0])
                    class_loss = F.cross_entropy(logits, y[yforg == 0])
                else:
                    # Keep a tensor (not a Python int) so class_loss.detach()
                    # in the logging below stays valid when the batch contains
                    # no genuine signatures
                    class_loss = features.new_tensor(0.)

                forg_logits = forg_layer(features).squeeze()
                forg_loss = F.binary_cross_entropy_with_logits(forg_logits, yforg)

                loss = (1 - args.lamb) * class_loss
                loss += args.lamb * forg_loss
        else:
            # Eq (1) in https://arxiv.org/abs/1705.05787
            logits = classification_layer(features)
            loss = class_loss = F.cross_entropy(logits, y)

        # Back propagation
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_value_(optimizer.param_groups[0]['params'], 10)

        # Update weights
        optimizer.step()

        # Logging
        if callback and step % 100 == 0:
            iteration = epoch + (step / n_steps)
            callback.scalar('class_loss', iteration, class_loss.detach())

            pred = logits.argmax(1)
            if args.loss_type == 'L1':
                acc = y.eq(pred).float().mean()
            else:
                acc = y[yforg == 0].eq(pred[yforg == 0]).float().mean()
            callback.scalar('train_acc', epoch + (step / n_steps), acc.detach())
            if args.forg:
                forg_pred = forg_logits > 0
                forg_acc = yforg.long().eq(forg_pred.long()).float().mean()
                callback.scalar('forg_loss', iteration, forg_loss.detach())
                callback.scalar('forg_acc', iteration, forg_acc.detach())
        step += 1
    lr_scheduler.step()
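# --- Added minimal sketch of the weighted two-task loss and value clipping
# used in `train_epoch` above. Dummy layers/tensors only; `lamb` plays the
# role of args.lamb (Eq. 3 in https://arxiv.org/abs/1705.05787). ---
import torch
import torch.nn.functional as F

classification_layer = torch.nn.Linear(8, 4)
forg_layer = torch.nn.Linear(8, 1)
features = torch.randn(16, 8)
y = torch.randint(0, 4, (16,))
yforg = torch.randint(0, 2, (16,)).float()

lamb = 0.95
class_loss = F.cross_entropy(classification_layer(features), y)
forg_loss = F.binary_cross_entropy_with_logits(forg_layer(features).squeeze(1), yforg)
loss = (1 - lamb) * class_loss + lamb * forg_loss

loss.backward()
# clip each gradient element to [-10, 10], as in the loop above
torch.nn.utils.clip_grad_value_(
    list(classification_layer.parameters()) + list(forg_layer.parameters()), 10)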
def infill_with_ilm(model,
                    special_tokens_to_ids,
                    x,
                    num_infills=1,
                    max_sequence_length=256,
                    nucleus=0.95):

    _sep_id = special_tokens_to_ids['<|startofinfill|>']
    _end_span_id = special_tokens_to_ids['<|endofinfill|>']
    _special_ids = special_tokens_to_ids.values()

    # Make sure example doesn't already end with [sep]
    if x[-1] == _sep_id:
        x = x[:-1]

    # Count number of blanks
    blank_idxs = []
    for i, tok_id in enumerate(x):
        if tok_id in _special_ids:
            blank_idxs.append(i)
    k = len(blank_idxs)
    if k == 0:
        raise ValueError('Example contains no blanks to infill')

    # Decode until we have that many blanks
    with torch.no_grad():
        device = next(model.parameters()).device
        context = torch.tensor(x + [_sep_id], dtype=torch.long,
                               device=device).unsqueeze(0).repeat(num_infills, 1)

        terminated = []

        while context.shape[0] > 0:
            logits = model(context)[0][:, -1]
            next_tokens = sample_from_logits(logits, nucleus=nucleus)
            context = torch.cat((context, next_tokens), dim=1)

            num_predicted_spans = (context == _end_span_id).long().sum(dim=1)

            terminate_expected = num_predicted_spans >= k
            terminate_toolong = torch.ones_like(context).long().sum(
                dim=1) >= max_sequence_length
            terminate = terminate_expected | terminate_toolong

            if torch.any(terminate):
                terminated_seqs = context[terminate, len(x) + 1:]
                terminated.extend(
                    [list(s) for s in terminated_seqs.cpu().numpy()])
                context = context[~terminate, :]

    # Collect generated spans
    generated_spans = []
    for gen in terminated:
        spans = []
        while _end_span_id in gen:
            spans.append(gen[:gen.index(_end_span_id)])
            gen = gen[gen.index(_end_span_id) + 1:]
        while len(spans) < k:
            spans.append([])
        generated_spans.append(spans)

    # Insert into context
    generated = []
    for spans in generated_spans:
        context = copy.deepcopy(x)
        for i, j in enumerate(blank_idxs[::-1]):
            del context[j]
            context[j:j] = spans[k - 1 - i]
        # Keep the filled-in sequence, deduplicating identical infills
        if context not in generated:
            generated.append(context)

    return generated
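# --- Added toy check of the span-splitting logic in `infill_with_ilm` above.
# Pure Python: a generated id sequence is cut at every end-of-infill id, then
# padded with empty spans up to the number of blanks k. ---
_end = 99  # stand-in for _end_span_id
gen = [5, 6, _end, 7, _end]
k = 3
spans = []
while _end in gen:
    spans.append(gen[:gen.index(_end)])
    gen = gen[gen.index(_end) + 1:]
while len(spans) < k:
    spans.append([])
print(spans)  # [[5, 6], [7], []]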
def main(config_path):
    """
    """
    # load config
    cfg = AttrDict.from_json_path(config_path)

    # make outputs dir
    out_path = os.path.join(cfg.path.output, cfg.exp_name)
    os.makedirs(out_path, exist_ok=True)

    # initialize seed
    if cfg.seed != -1:
        random.seed(cfg.seed)
        np.random.seed(cfg.seed)
        torch.manual_seed(cfg.seed)
        torch.cuda.manual_seed(cfg.seed)
        torch.cuda.manual_seed_all(cfg.seed)
        torch.backends.cudnn.deterministic = True

    # initialize logger
    logger = initialize_logger(os.path.join(out_path, 'log.txt'))
    logger.info(f"Experiment : {cfg.exp_name}")

    # set device
    if cfg.device:
        cfg.device = torch.device(cfg.device)
    else:
        cfg.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
    logger.info(f"Device set to {cfg.device}.")

    #-------------------------------------------
    #       Make Dataset
    #-------------------------------------------
    data_info_df = pd.read_csv(os.path.join(cfg.path.data, 'ct_info.csv'), index_col=0)
    dataset = public_SegICH_Dataset2D(data_info_df, cfg.path.data,
                                      augmentation_transform=[getattr(tf, tf_name)(**tf_kwargs)
                                                              for tf_name, tf_kwargs in cfg.data.augmentation.items()],
                                      output_size=cfg.data.size,
                                      window=(cfg.data.win_center, cfg.data.win_width))

    #-------------------------------------------
    #       Load FCDD Model
    #-------------------------------------------
    cfg_fcdd = AttrDict.from_json_path(cfg.fcdd_cfg_path)
    fcdd_net = FCDD_CNN_VGG(in_shape=(cfg_fcdd.net.in_channels, 256, 256), bias=cfg_fcdd.net.bias)
    loaded_state_dict = torch.load(cfg.fcdd_model_path, map_location=cfg.device)
    fcdd_net.load_state_dict(loaded_state_dict)
    fcdd_net = fcdd_net.to(cfg.device).eval()
    logger.info(f"FCDD model successfully loaded from {cfg.fcdd_model_path}")

    # make FCDD object
    fcdd = FCDD(fcdd_net, batch_size=cfg.batch_size, num_workers=cfg.num_workers,
                device=cfg.device, print_progress=cfg.print_progress)

    #-------------------------------------------
    #       Load Classifier Model
    #-------------------------------------------
    # Load Classifier
    if cfg.classifier_model_path is not None:
        cfg_classifier = AttrDict.from_json_path(os.path.join(cfg.classifier_model_path, 'config.json'))
        classifier = getattr(rn, cfg_classifier.net.resnet)(num_classes=cfg_classifier.net.num_classes,
                                                            input_channels=cfg_classifier.net.input_channels)
        classifier_state_dict = torch.load(os.path.join(cfg.classifier_model_path, 'resnet_state_dict.pt'),
                                           map_location=cfg.device)
        classifier.load_state_dict(classifier_state_dict)
        classifier = classifier.to(cfg.device)
        classifier.eval()
        logger.info(f"ResNet classifier model successfully loaded from "
                    f"{os.path.join(cfg.classifier_model_path, 'resnet_state_dict.pt')}")

    #-------------------------------------------
    #       Generate Heat-Map for each slice
    #-------------------------------------------
    with torch.no_grad():
        # make loader
        loader = torch.utils.data.DataLoader(dataset, batch_size=cfg.batch_size,
                                             num_workers=cfg.num_workers, shuffle=False,
                                             worker_init_fn=lambda _: np.random.seed())
        fcdd_net.eval()

        min_val, max_val = fcdd.get_min_max(loader, **cfg.heatmap_param)

        # computing and saving heatmaps
        out = dict(id=[], slice=[], label=[], ad_map_fn=[], ad_mask_fn=[],
                   TP=[], TN=[], FP=[], FN=[], AUC=[], classifier_pred=[])
        for b, data in enumerate(loader):
            im, mask, id, slice = data
            im = im.to(cfg.device).float()
            mask = mask.to(cfg.device).float()

            # get heatmap
            heatmap = fcdd.generate_heatmap(im, reception=cfg.heatmap_param.reception,
                                            std=cfg.heatmap_param.std, cpu=cfg.heatmap_param.cpu)

            # scaling
            heatmap = ((heatmap - min_val) / (max_val - min_val)).clamp(0, 1)

            # Threshold
            ad_mask = torch.where(heatmap >= cfg.heatmap_threshold,
                                  torch.ones_like(heatmap, device=heatmap.device),
                                  torch.zeros_like(heatmap, device=heatmap.device))

            # Compute CM
            tn, fp, fn, tp = batch_binary_confusion_matrix(ad_mask, mask.to(heatmap.device))

            # Save heatmaps/mask
            map_fn, mask_fn = [], []
            for i in range(im.shape[0]):
                # Save AD Map
                ad_map_fn = f"{id[i]}/{slice[i]}_map_anomalies.png"
                save_path = os.path.join(out_path, 'pred/', ad_map_fn)
                if not os.path.isdir(os.path.dirname(save_path)):
                    os.makedirs(os.path.dirname(save_path), exist_ok=True)
                ad_map = heatmap[i].squeeze().cpu().numpy()
                io.imsave(save_path, img_as_ubyte(ad_map), check_contrast=False)
                # save ad_mask
                ad_mask_fn = f"{id[i]}/{slice[i]}_anomalies.bmp"
                save_path = os.path.join(out_path, 'pred/', ad_mask_fn)
                io.imsave(save_path, img_as_ubyte(ad_mask[i].squeeze().cpu().numpy()), check_contrast=False)

                map_fn.append(ad_map_fn)
                mask_fn.append(ad_mask_fn)

            # apply classifier ResNet-18
            if cfg.classifier_model_path is not None:
                # take column of the softmax for the positive class as score
                pred_score = nn.functional.softmax(classifier(im), dim=1)[:, 1]
                clss_pred = torch.where(pred_score >= cfg.classification_threshold,
                                        torch.ones_like(pred_score, device=pred_score.device),
                                        torch.zeros_like(pred_score, device=pred_score.device))
            else:
                clss_pred = [None] * im.shape[0]

            # Save Values
            out['id'] += id.cpu().tolist()
            out['slice'] += slice.cpu().tolist()
            out['label'] += mask.reshape(mask.shape[0], -1).max(dim=1)[0].cpu().tolist()
            out['ad_map_fn'] += map_fn
            out['ad_mask_fn'] += mask_fn
            out['TN'] += tn.cpu().tolist()
            out['FP'] += fp.cpu().tolist()
            out['FN'] += fn.cpu().tolist()
            out['TP'] += tp.cpu().tolist()
            out['AUC'] += [roc_auc_score(mask[i].cpu().numpy().ravel(), heatmap[i].cpu().numpy().ravel())
                           if torch.any(mask[i] > 0) else 'None'
                           for i in range(im.shape[0])]
            out['classifier_pred'] += clss_pred.cpu().tolist()

            if cfg.print_progress:
                print_progessbar(b, len(loader), Name='Heatmap Generation Batch', Size=100, erase=True)

    # make df and save as csv
    slice_df = pd.DataFrame(out)
    volume_df = slice_df[['id', 'label', 'TP', 'TN', 'FP', 'FN']].groupby('id').agg(
        {'label': 'max', 'TP': 'sum', 'TN': 'sum', 'FP': 'sum', 'FN': 'sum'})

    slice_df['Dice'] = (2 * slice_df.TP + 1) / (2 * slice_df.TP + slice_df.FP + slice_df.FN + 1)
    volume_df['Dice'] = (2 * volume_df.TP + 1) / (2 * volume_df.TP + volume_df.FP + volume_df.FN + 1)
    logger.info(f"Mean slice dice : {slice_df.Dice.mean(axis=0):.3f}")
    logger.info(f"Mean volume dice : {volume_df.Dice.mean(axis=0):.3f}")
    logger.info(f"Mean positive slice AUC {slice_df[slice_df.label == 1].AUC.mean(axis=0):.3f}")

    # Save Scores and Config
    slice_df.to_csv(os.path.join(out_path, 'slice_predictions.csv'))
    logger.info(f"Slice prediction csv saved at {os.path.join(out_path, 'slice_predictions.csv')}")
    volume_df.to_csv(os.path.join(out_path, 'volume_predictions.csv'))
    logger.info(f"Volume prediction csv saved at {os.path.join(out_path, 'volume_predictions.csv')}")
    cfg.device = str(cfg.device)
    with open(os.path.join(out_path, 'config.json'), 'w') as f:
        json.dump(cfg, f)
    logger.info(f"Config file saved at {os.path.join(out_path, 'config.json')}")
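# --- Added: the smoothed Dice used above, in isolation. The +1 in numerator
# and denominator is Laplace smoothing so a slice with no positive pixels
# (TP = FP = FN = 0) scores 1 instead of hitting 0/0. ---
def smoothed_dice(tp, fp, fn):
    return (2 * tp + 1) / (2 * tp + fp + fn + 1)

print(smoothed_dice(0, 0, 0))    # 1.0 (empty prediction on empty ground truth)
print(smoothed_dice(50, 10, 5))  # ~0.87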
def forward(self, *args): # pylint: disable=too-many-statements LOG.debug('loss for %s', self.field_names) x, t = args assert len(x) == 1 + 2 * self.n_vectors + self.n_scales x_intensity = x[0] x_regs = x[1:1 + self.n_vectors] x_spreads = x[1 + self.n_vectors:1 + 2 * self.n_vectors] x_scales = [] if self.n_scales: x_scales = x[1 + 2 * self.n_vectors:1 + 2 * self.n_vectors + self.n_scales] if self.n_scales == 0: t = t[:-self.n_vectors] # assume there are as many scales as vectors and remove them assert len(t) == 1 + self.n_vectors + self.n_scales target_intensity = t[0] target_regs = t[1:1 + self.n_vectors] target_scales = t[1 + self.n_vectors:] bce_masks = (target_intensity[:, :-1] + target_intensity[:, -1:]) > 0.5 if not torch.any(bce_masks): return None, None, None batch_size = x_intensity.shape[0] LOG.debug('batch size = %d', batch_size) bce_x_intensity = x_intensity bce_target_intensity = target_intensity[:, :-1] if self.bce_blackout: bce_x_intensity = bce_x_intensity[:, self.bce_blackout] bce_masks = bce_masks[:, self.bce_blackout] bce_target_intensity = bce_target_intensity[:, self.bce_blackout] LOG.debug('BCE: x = %s, target = %s, mask = %s', x_intensity.shape, bce_target_intensity.shape, bce_masks.shape) bce_target = torch.masked_select(bce_target_intensity, bce_masks) bce_weight = None if self.background_weight != 1.0: bce_weight = torch.ones_like(bce_target) bce_weight[bce_target == 0] = self.background_weight # print('TRIANGE loss', bce_x_intensity.shape, bce_masks.shape) ce_loss = torch.nn.functional.binary_cross_entropy_with_logits( torch.masked_select(bce_x_intensity, bce_masks), bce_target, weight=bce_weight, ) reg_losses = [None for _ in target_regs] reg_masks = target_intensity[:, :-1] > 0.5 if torch.any(reg_masks): weight = None if self.multiplicity_correction: assert len(target_regs) == 2 lengths = torch.norm(target_regs[0] - target_regs[1], dim=2) multiplicity = (lengths - 3.0) / self.independence_scale multiplicity = torch.clamp(multiplicity, min=1.0) multiplicity = torch.masked_select(multiplicity, reg_masks) weight = 1.0 / multiplicity reg_losses = [] for i, (x_reg, x_spread, target_reg) in enumerate(zip(x_regs, x_spreads, target_regs)): if hasattr(self.regression_loss, 'scale'): assert self.scales_to_kp is not None self.regression_loss.scale = torch.masked_select( torch.clamp(target_scales[i], 0.1, 1000.0), # pylint: disable=unsubscriptable-object reg_masks, ) reg_losses.append(self.regression_loss( torch.masked_select(x_reg[:, :, 0], reg_masks), torch.masked_select(x_reg[:, :, 1], reg_masks), torch.masked_select(x_spread, reg_masks), torch.masked_select(target_reg[:, :, 0], reg_masks), torch.masked_select(target_reg[:, :, 1], reg_masks), weight=(weight if weight is not None else 1.0) * 0.1, ) / 100.0 / batch_size) scale_losses = [] if x_scales: scale_losses = [ torch.nn.functional.l1_loss( torch.masked_select(x_scale, torch.isnan(target_scale) == 0), torch.masked_select(target_scale, torch.isnan(target_scale) == 0), reduction='sum', ) / 1000.0 / batch_size for x_scale, target_scale in zip(x_scales, target_scales) ] margin_losses = [None for _ in target_regs] if self.margin else [] if self.margin and torch.any(reg_masks): margin_losses = [] for i, (x_reg, target_reg) in enumerate(zip(x_regs, target_regs)): margin_losses.append(quadrant_margin_loss( torch.masked_select(x_reg[:, :, 0], reg_masks), torch.masked_select(x_reg[:, :, 1], reg_masks), torch.masked_select(target_reg[:, :, 0], reg_masks), torch.masked_select(target_reg[:, :, 1], reg_masks), 
torch.masked_select(target_reg[:, :, 2], reg_masks), torch.masked_select(target_reg[:, :, 3], reg_masks), torch.masked_select(target_reg[:, :, 4], reg_masks), torch.masked_select(target_reg[:, :, 5], reg_masks), ) / 100.0 / batch_size) return [ce_loss] + reg_losses + scale_losses + margin_losses
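# --- Added minimal sketch of the masked-BCE pattern used in the loss above.
# torch.masked_select flattens only the valid positions, so the BCE runs on a
# 1-D tensor of selected logits/targets; shapes here are illustrative. ---
import torch

x_intensity = torch.randn(2, 3, 4, 4)      # logits
target = torch.rand(2, 3, 4, 4).round()    # binary targets
mask = torch.rand(2, 3, 4, 4) > 0.5        # where the loss applies

loss = torch.nn.functional.binary_cross_entropy_with_logits(
    torch.masked_select(x_intensity, mask),
    torch.masked_select(target, mask),
)
print(loss)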
def forward(self, pyolos, targets, imgs_ts=None):
    '''
    :param pyolos: torch.Size([32, 7, 13, 13])  cls-3, box-4
    :param targets:
    :param imgs_ts:
    :return:
    '''
    cfg = self.cfg
    device = pyolos.device
    batch, c, h, w = pyolos.shape
    pyolos = pyolos.view(batch, c, -1).permute(0, 2, 1)

    # cls-num_class, txywh-4, weight-1, gltrb-4
    gdim = cfg.NUM_CLASSES + 4 + 1 + 4
    # overwritten in full every batch, so no need to zero-initialize
    gyolos = torch.empty((batch, h, w, gdim), device=device)

    for i, target in enumerate(targets):  # iterate over the batch
        gboxes_ltrb_b = target['boxes']  # ltrb
        glabels_b = target['labels']

        gyolos[i] = fmatch4yolov1(
            gboxes_ltrb_b=gboxes_ltrb_b,
            glabels_b=glabels_b,
            grid=h,  # 7
            gdim=gdim,
            device=device,
            img_ts=imgs_ts[i],
            cfg=cfg,
            use_conf=False)

        '''Visual sanity check'''
        # if cfg.IS_VISUAL:
        #     # conf-1, cls-1, box-4, weight-1
        #     gyolo_test = gyolos[i].clone()  # torch.Size([32, 13, 13, 9])
        #     gyolo_test = gyolo_test.view(-1, gdim)
        #     gconf_one = gyolo_test[:, 0]
        #     mask_pos = gconf_one == 1  # [169]
        #
        #     # torch.Size([169, 4])
        #     gtxywh = gyolo_test[:, 1 + cfg.NUM_CLASSES:1 + cfg.NUM_CLASSES + 4]
        #
        #     # recover the xy offsets here
        #     _xy_grid = gtxywh[:, :2] + f_mershgrid(h, w, is_rowcol=False).to(device)
        #     hw_ts = torch.tensor((h, w), device=device)
        #     gtxywh[:, :2] = torch.true_divide(_xy_grid, hw_ts)
        #     gtxywh = gtxywh[mask_pos]
        #     gtxywh[:, 2:4] = torch.exp(gtxywh[:, 2:]) / cfg.IMAGE_SIZE[0]
        #
        #     from f_tools.pic.enhance.f_data_pretreatment4pil import f_recover_normalization4ts
        #     img_ts = f_recover_normalization4ts(imgs_ts[i])
        #     from torchvision.transforms import functional as transformsF
        #     img_pil = transformsF.to_pil_image(img_ts).convert('RGB')
        #     import numpy as np
        #     img_np = np.array(img_pil)
        #     f_show_od_np4plt(img_np, gboxes_ltrb=gboxes_ltrb_b.cpu(),
        #                      pboxes_ltrb=xywh2ltrb(gtxywh.cpu()), is_recover_size=True,
        #                      grids=(h, w))

    # [32, 13, 13, 7] -> torch.Size([32, 169, 12])
    gyolos = gyolos.view(batch, -1, gdim)  # h*w
    gcls = gyolos[:, :, 0:cfg.NUM_CLASSES]  # torch.Size([5, 169])
    mask_pos_3d = gcls > 0  # torch.Size([32, 169, 3])
    mask_neg_3d = gcls == 0
    # [32, 169, 3] -> [32, 169]
    mask_pos_2d = torch.any(mask_pos_3d, dim=-1)
    # mask_pos = gconf == 1  # the yolo1 gt is hard-coded to 1
    nums_pos = (mask_pos_2d.sum(-1).to(
        torch.float)).clamp(min=torch.finfo(torch.float16).eps)
    pyolos_pos = pyolos[mask_pos_2d]  # torch.Size([32, 169, 13]) -> [nn, 13]
    gyolos_pos = gyolos[mask_pos_2d]  # torch.Size([32, 169, 13]) -> [nn, 13]

    ''' ---------------- classification loss ---------------- '''
    # cls-num_class, txywh-4, weight-1, gltrb-4
    pcls_sigmoid = pyolos[:, :, 0:cfg.NUM_CLASSES].sigmoid()
    gcls = gyolos[:, :, 0:cfg.NUM_CLASSES]  # torch.Size([32, 169, 3])
    # positive:negative ratio is 1 : 169*3
    # _loss_val = x_bce(pcls_sigmoid, gcls, reduction="none")  # torch.Size([46, 3])
    # l_cls_pos = ((_loss_val * mask_pos_3d).sum(-1).sum(-1) / nums_pos).mean()
    # l_cls_neg = ((_loss_val * mask_neg_3d).sum(-1).sum(-1) / nums_pos).mean()

    # ------------ conf-mse ------------ (works well)
    # _loss_val = F.mse_loss(pconf_sigmoid, gconf, reduction="none")  # MSE works better here
    # _loss_val = x_bce(pconf_sigmoid, gconf, reduction="none")
    # l_conf_pos = ((_loss_val * mask_pos_3d).sum(-1) / nums_pos).mean() * cfg.LOSS_WEIGHT[0]
    # l_conf_neg = ((_loss_val * mask_neg_3d).sum(-1) / nums_pos).mean() * cfg.LOSS_WEIGHT[1]

    # ------------ conf_ohem  ap26_26 ------------
    # _loss_val = x_bce(pconf_sigmoid, gconf)
    # mask_ignore = torch.logical_not(torch.logical_or(mask_pos, mask_neg))
    # mask_neg_hard = f_ohem(_loss_val, nums_pos * 3, mask_pos=mask_pos, mash_ignore=mask_ignore)
    # more positives means more negatives
    # l_conf_pos = ((_loss_val * mask_pos).sum(-1) / nums_pos).mean() * 3
    # l_conf_neg = ((_loss_val * mask_neg_hard).sum(-1) / nums_pos).mean() * 3

    # ------------ focalloss ------------
    l_pos, l_neg = focalloss(pcls_sigmoid, gcls, mask_pos=mask_pos_2d,
                             is_debug=True, alpha=0.75)
    l_cls_pos = (l_pos.sum(-1).sum(-1) / nums_pos).mean() * 30
    l_cls_neg = (l_neg.sum(-1).sum(-1) / nums_pos).mean() * 30

    ''' ---------------- regression loss: BCE for xy, MSE for wh ---------------- '''
    # ------------ mse+bce ------------ (works well)
    # conf-1, cls-num_class, txywh-4, weight-1, gltrb-4
    # ptxty_sigmoid_pos = pyolos_pos[:, cfg.NUM_CLASSES:cfg.NUM_CLASSES + 2].sigmoid()  # needs normalization
    # ptwth_pos = pyolos_pos[:, cfg.NUM_CLASSES + 2:cfg.NUM_CLASSES + 4]
    #
    # # cls-num_class, txywh-4, weight-1, gltrb-4
    # # id = cfg.NUM_CLASSES + 4 + 1 - 1
    # weight_pos = gyolos_pos[:, cfg.NUM_CLASSES + 4]  # torch.Size([32, 845])
    # gtxty_pos = gyolos_pos[:, cfg.NUM_CLASSES:cfg.NUM_CLASSES + 2]  # [nn]
    # gtwth_pos = gyolos_pos[:, cfg.NUM_CLASSES + 2:cfg.NUM_CLASSES + 4]
    #
    # _loss_val = x_bce(ptxty_sigmoid_pos, gtxty_pos, reduction="none")
    # l_txty = (_loss_val.sum(-1) * weight_pos).mean()
    # _loss_val = F.mse_loss(ptwth_pos, gtwth_pos, reduction="none")
    # l_twth = (_loss_val.sum(-1) * weight_pos).mean()

    # ------------ IoU loss ------------
    # decode pxywh and use the IoU between prediction and GT as gconf
    ptxywh_pos = pyolos[:, :, cfg.NUM_CLASSES:cfg.NUM_CLASSES + 4]
    # batch-decode in 3D, then select the positives
    zltrb_pos = boxes_decode4yolo1(ptxywh_pos, h, h, cfg)[mask_pos_2d]
    gltrb_pos = gyolos_pos[:, cfg.NUM_CLASSES + 4 + 1:cfg.NUM_CLASSES + 4 + 1 + 4]
    iou_zg = bbox_iou4one_2d(zltrb_pos, gltrb_pos, is_ciou=True)
    l_reg = (1 - iou_zg).mean()

    ''' ---------------- loss done ---------------- '''
    # loss_total = l_cls_pos + l_cls_neg + l_txty + l_twth
    loss_total = l_cls_pos + l_cls_neg + l_reg

    log_dict = {}
    log_dict['l_total'] = loss_total.item()
    log_dict['l_cls_pos'] = l_cls_pos.item()
    log_dict['l_cls_neg'] = l_cls_neg.item()
    log_dict['l_reg'] = l_reg.item()
    # log_dict['l_xy'] = l_txty.item()
    # log_dict['l_wh'] = l_twth.item()
    log_dict['p_max'] = pcls_sigmoid.max().item()
    log_dict['p_min'] = pcls_sigmoid.min().item()
    log_dict['p_mean'] = pcls_sigmoid.mean().item()
    return loss_total, log_dict
def forward(self,  # forward receives every argument it might need
            support_seqs, support_imgs, support_lens, support_labels,
            query_seqs, query_imgs, query_lens, query_labels,
            epoch=None, metric='euc', return_embeddings=False):

    embedded_support_seqs, embedded_query_seqs, \
    embedded_support_imgs, embedded_query_imgs = self.embed(support_seqs, query_seqs,
                                                            support_lens, query_lens,
                                                            support_imgs, query_imgs)

    k, n, qk = self.TaskParams.k, self.TaskParams.n, self.TaskParams.qk
    # fuse the raw seq and img outputs directly
    support_fused_features = self._fuse(embedded_support_seqs, embedded_support_imgs, fuse_dim=1)
    query_fused_features = self._fuse(embedded_query_seqs, embedded_query_imgs, fuse_dim=1)

    dim = support_fused_features.size(1)

    nClusters = n  # the initial number of clusters equals the number of classes
    nInitialClusters = nClusters

    # batch is fixed to 1 here
    support_labels = support_labels.unsqueeze(0)
    query_labels = query_labels.unsqueeze(0)
    support_fused_features = support_fused_features.view(n * k, -1).unsqueeze(0)
    query_fused_features = query_fused_features.view(qk, -1).unsqueeze(0)

    # create probabilities for points
    # _, idx = np.unique(batch.y_train.squeeze().data.cpu().numpy(), return_inverse=True)
    # initialize the cluster-membership probabilities with the one-hot labels
    prob_support = one_hot(support_labels, nClusters).cuda()

    # make initial radii for labeled clusters
    bsize = support_fused_features.size()[0]
    # the initial radius comes from log_sigma_l (a learnable parameter)
    radii = t.ones(bsize, nClusters).cuda()  # * t.exp(self.Sigma)
    if self.Sigma is not None:
        radii *= t.exp(self.Sigma)

    cluster_labels = t.arange(0, nClusters).cuda().long()

    # compute initial prototypes from labeled examples
    # initially there is one cluster per class and the assignments are one-hot,
    # so the initial prototypes are exactly the class centroids
    # shape: [batch, cluster, dim]
    protos = self._compute_protos(support_fused_features, prob_support)

    # estimate lamda
    # lamda = self.estimate_lambda(protos.data, False)

    # loop for a given number of clustering steps
    for ii in range(self.NumClusterSteps):
        # protos = protos.data
        # iterate over labeled examples to reassign first
        for i, ex in enumerate(support_fused_features[0]):
            # find the index of the cluster matching this sample's label
            idxs = t.nonzero(support_labels[0, i] == cluster_labels)[0]  # TODO: why take [0]?
            #****************************************************************************
            # compute the distance to the cluster matching the label (every
            # non-matching cluster has distance +inf, so it drops out of the min)
            # distances = self._compute_distances(protos[:, idxs, :], ex.data)
            # if t.min(distances) > lamda:
            #****************************************************************************
            distances = self._compute_distances(protos, ex)
            # if the nearest cluster is not one of this sample's own class,
            # spawn a new cluster directly
            if not t.any(t.min(distances, dim=1).indices == idxs).item():
                nClusters, protos, radii = self._add_cluster(nClusters, protos, radii,
                                                             cluster_type='labeled',
                                                             ex=ex.data)
                # use the sample's label as the new cluster's label
                cluster_labels = t.cat([cluster_labels, support_labels[0, [i]].data], dim=0)

        # perform partial reassignment based on newly created labeled clusters
        if nClusters > nInitialClusters:
            # find the cluster actually matching each sample
            # (each row is a bool vector over clusters)
            support_targets = support_labels.data[0, :, None] == cluster_labels
            # probability of each sample belonging to each cluster
            prob_support = assign_cluster_radii_limited(protos, support_fused_features,
                                                        radii, support_targets)

        nTrainClusters = nClusters

        protos = protos.cuda()
        protos = self._compute_protos(support_fused_features, prob_support)
        protos, radii, cluster_labels = self.delete_empty_clusters(protos, prob_support,
                                                                   radii, cluster_labels)

    # compute the cluster logits for the query set
    logits = compute_logits_radii(protos, query_fused_features, radii,
                                  use_sigma=self.Sigma is not None).squeeze()

    # convert class targets into indicators for supports in each class
    labels = query_labels  # batch.y_test.data
    labels[labels >= nInitialClusters] = -1

    # find the cluster matching each query sample's label
    support_targets = labels[0, :, None] == cluster_labels
    # support_targets: cluster indicators for the query labels;
    # cluster_labels: the class label of each cluster
    loss = self.loss(logits, support_targets, cluster_labels)

    # map support predictions back into classes to check accuracy
    _, support_preds = t.max(logits.data, dim=1)
    y_pred = cluster_labels[support_preds]

    return {
        "logits": None,
        "loss": loss,
        "predicts": y_pred
    }
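# --- Added sketch of the prototype computation assumed by `_compute_protos`
# above: given features h [batch, n_points, dim] and assignment probabilities
# prob [batch, n_points, n_clusters], each prototype is the probability-
# weighted mean of its points. With one-hot assignments (as at initialization)
# this reduces to the class centroids. ---
import torch as t

h = t.randn(1, 6, 4)                                   # 6 support points, dim 4
prob = t.eye(3).repeat_interleave(2, 0).unsqueeze(0)   # one-hot: 2 points per cluster
protos = t.bmm(prob.transpose(1, 2), h) / prob.sum(1).unsqueeze(-1)
print(protos.shape)  # torch.Size([1, 3, 4])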
def decode_beam_batch(self, bodies, beam_size=3, max_output_length=100, sample=False, show_each_beam=False): if self.mode != 'eval': print("BEWARE. Model is not in eval mode.") self.eval() ## << Surely you are not training with beam decode? batch_size = len(bodies) N = batch_size * beam_size inputs = self.preprocess_input(bodies) next_words = torch.LongTensor([self.tokenizer.start_id] * N).to( self.device).unsqueeze(1) build_up = None scores = torch.zeros((N)).to(self.device) one_every_k = torch.FloatTensor([1] + [0] * (beam_size - 1)).repeat( batch_size * beam_size).to(self.device) # Sometimes, we process the same input, as we run it once as a sampled, and once as an argmax, in which case we should reuse the computation _, input_past = self.model(input_ids=inputs, past_key_values=None) input_past = [ torch.repeat_interleave(p, repeats=beam_size, dim=1) for p in input_past ] past = input_past while build_up is None or ( build_up.shape[1] < max_output_length and not all([self.tokenizer.end_id in build for build in build_up])): logits, past = self.model(input_ids=next_words, past_key_values=past) probs = torch.nn.functional.softmax(logits, dim=2).squeeze(1) logprobs = torch.nn.functional.log_softmax(logits, dim=2) if sample: all_selects = torch.multinomial(probs, beam_size).unsqueeze(1) else: _, all_selects = torch.topk(logprobs, k=beam_size, dim=2) if build_up is not None: not_finished = (1 - torch.any(build_up == self.tokenizer.end_id, dim=1).float()).to(self.device) else: not_finished = torch.ones_like(scores, dtype=torch.float, device=self.device) expanded_not_finished = torch.repeat_interleave(not_finished, repeats=beam_size) expanded_score = torch.repeat_interleave( scores, repeats=beam_size) # This should be batch_size * beam_size² added_score = logprobs[ torch.repeat_interleave(torch.arange(N), repeats=beam_size), 0, all_selects.view(-1)] expanded_score += (expanded_not_finished * added_score) # We don't want you to select from finished beams expanded_score -= (1 - expanded_not_finished) * ( 1 - one_every_k) * 1000.0 batched_scores = expanded_score.view(batch_size, -1) if build_up is None: choices = torch.arange(beam_size, device=self.device).repeat(batch_size) batched_choices = choices.view(batch_size, beam_size) else: _, batched_choices = torch.topk( batched_scores, k=beam_size, dim=1) # Going from k² choices per element to k choices. 
            # floor division keeps the parent-beam index integral (plain `/`
            # would produce floats that cannot be used for indexing)
            batched_tracks = torch.div(batched_choices, beam_size, rounding_mode='floor')
            tracks = beam_size * torch.repeat_interleave(
                torch.arange(batch_size), repeats=beam_size).to(
                    self.device) + batched_tracks.view(-1)
            tracks = list(tracks)

            selected_scores = batched_scores[torch.repeat_interleave(
                torch.arange(batch_size), repeats=beam_size),
                batched_choices.view(-1)]

            # Figure out the kept words to be added to the build-up
            per_batch_selects = all_selects.view(batch_size, -1)
            next_words = per_batch_selects[torch.repeat_interleave(
                torch.arange(batch_size), repeats=beam_size),
                batched_choices.view(-1)]
            next_words = next_words.unsqueeze(1)

            # [BOOKKEEPING] Going from k^2 to k options at each time step means
            # we have to swap all the caches around: past, build-up
            if build_up is not None:
                if show_each_beam:
                    print(build_up)
                    building_output = []
                    for beams in build_up:
                        out_ = [
                            self.tokenizer.decode(beam.tolist()) + "END"
                            for beam in beams
                        ]
                        out_ = [S[:S.index("END")] for S in out_]
                        building_output.append(out_)
                    for output, score in zip(building_output, selected_scores):
                        print("".join(output), "|%.2f" % float(score))

                build_up = build_up[tracks, :]
            past = [p[:, tracks, :] for p in past]

            # Update the latest scores, and the current build
            if build_up is None:
                build_up = next_words
            else:
                build_up = torch.cat((build_up, next_words), dim=1)
            scores = selected_scores.view(-1)

        batched_build_up = build_up.view(batch_size, beam_size, -1)
        batched_scores = scores.view(batch_size, -1)
        # torch.cuda.empty_cache()

        outputs = []
        for beams in batched_build_up:
            out_beams = [
                self.tokenizer.decode(beam.tolist()) + "END" for beam in beams
            ]
            out_beams = [S[:S.index("END")] for S in out_beams]
            outputs.append(out_beams)

        return outputs, batched_scores.tolist()
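# --- Added toy check of the k^2 -> k reindexing in `decode_beam_batch` above.
# Each of the k*k candidate scores per batch element maps back to a parent
# beam (choice // k) and a token choice (choice % k). ---
import torch

batch_size, k = 2, 3
batched_scores = torch.randn(batch_size, k * k)
_, batched_choices = torch.topk(batched_scores, k=k, dim=1)
parent_beams = torch.div(batched_choices, k, rounding_mode='floor')
token_choices = batched_choices % k
print(parent_beams, token_choices)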
def decode_batch(self, bodies, special_append=None, max_output_length=100, sample=False, return_scores=False, return_logprobs=False, input_past=None): N = len(bodies) current = torch.LongTensor([self.tokenizer.start_id] * N).to( self.device).unsqueeze(1) build_up = None scores = torch.zeros((N)).to(self.device) total_logprobs = [] # Sometimes, we process the same input, as we run it once as a sampled, and once as an argmax, in which case we should reuse the computation if input_past is None: inputs = self.preprocess_input(bodies, special_append) _, input_past = self.model(input_ids=inputs, past_key_values=None) past = input_past while build_up is None or ( build_up.shape[1] < max_output_length and not all([self.tokenizer.end_id in build for build in build_up])): logits, past = self.model(input_ids=current, past_key_values=past) probs = torch.nn.functional.softmax(logits, dim=2).squeeze(1) logprobs = torch.nn.functional.log_softmax(logits, dim=2) if sample: current = torch.multinomial(probs, 1) else: current = torch.argmax(logprobs, dim=2) if build_up is None: build_up = current else: build_up = torch.cat((build_up, current), dim=1) if return_logprobs: selected_logprobs = logprobs[torch.arange(N), 0, current.squeeze()].unsqueeze(1) total_logprobs.append(selected_logprobs) not_finished = (1 - torch.any(build_up == self.tokenizer.end_id, dim=1).float()).to(self.device) scores += not_finished * logprobs[torch.arange(N), :, current.squeeze(1)].squeeze() end_id = self.tokenizer.end_id build_up = [build.tolist() for build in build_up] end_indices = [ max_output_length + 1 if end_id not in build else build.index(end_id) for build in build_up ] outputs = [self.tokenizer.decode(build) + "END" for build in build_up] outputs = [S[:S.index("END")] for S in outputs] if return_logprobs: return outputs, torch.cat(total_logprobs, dim=1), build_up, input_past, end_indices elif return_scores: return outputs, scores.tolist() else: return outputs, end_indices
def __beaver_protocol(op, x, y, *args, **kwargs): """Performs Beaver protocol for additively secret-shared tensors x and y 1. Obtain uniformly random sharings [a],[b] and [c] = [a * b] 2. Additively hide [x] and [y] with appropriately sized [a] and [b] 3. Open ([epsilon] = [x] - [a]) and ([delta] = [y] - [b]) 4. Return [z] = [c] + (epsilon * [b]) + ([a] * delta) + (epsilon * delta) """ assert op in { "mul", "matmul", "conv1d", "conv2d", "conv_transpose1d", "conv_transpose2d", } if x.device != y.device: raise ValueError( f"x lives on device {x.device} but y on device {y.device}") provider = crypten.mpc.get_default_provider() a, b, c = provider.generate_additive_triple(x.size(), y.size(), op, device=x.device, *args, **kwargs) from .arithmetic import ArithmeticSharedTensor if crypten.mpc.config.active_security: """ Reference: "Multiparty Computation from Somewhat Homomorphic Encryption" Link: https://eprint.iacr.org/2011/535.pdf """ f, g, h = provider.generate_additive_triple(x.size(), y.size(), op, device=x.device, *args, **kwargs) t = ArithmeticSharedTensor.PRSS(a.size(), device=x.device) t_plain_text = t.get_plain_text() rho = (t_plain_text * a - f).get_plain_text() sigma = (b - g).get_plain_text() triples_check = t_plain_text * c - h - sigma * f - rho * g - rho * sigma triples_check = triples_check.get_plain_text() if torch.any(triples_check != 0): raise ValueError("Beaver Triples verification failed!") # Vectorized reveal to reduce rounds of communication epsilon, delta = ArithmeticSharedTensor.reveal_batch([x - a, y - b]) # z = c + (a * delta) + (epsilon * b) + epsilon * delta c._tensor += getattr(torch, op)(epsilon, b._tensor, *args, **kwargs) c._tensor += getattr(torch, op)(a._tensor, delta, *args, **kwargs) c += getattr(torch, op)(epsilon, delta, *args, **kwargs) return c
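# --- Added plaintext sanity check of the Beaver identity used above. With a
# random triple (a, b, c = a*b) and epsilon = x - a, delta = y - b, we have
# c + epsilon*b + a*delta + epsilon*delta == x*y. Secret sharing only hides
# the operands; the arithmetic is exactly this. ---
import torch

x, y = torch.randn(3, 3), torch.randn(3, 3)
a, b = torch.randn(3, 3), torch.randn(3, 3)
c = a * b
epsilon, delta = x - a, y - b
z = c + epsilon * b + a * delta + epsilon * delta
assert torch.allclose(z, x * y)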
def initialize_q_batch_nonneg(X: Tensor,
                              Y: Tensor,
                              n: int,
                              eta: float = 1.0,
                              alpha: float = 1e-4) -> Tensor:
    r"""Heuristic for selecting initial conditions for non-neg. acquisition functions.

    This function is similar to `initialize_q_batch`, but designed specifically
    for acquisition functions that are non-negative and possibly zero over
    large areas of the feature space (e.g. qEI). All samples for which
    `Y < alpha * max(Y)` will be ignored (assuming that `Y` contains at least
    one positive value).

    Args:
        X: A `b x q x d` tensor of `b` samples of `q`-batches from a `d`-dim.
            feature space. Typically, these are generated using qMC.
        Y: A tensor of `b` outcomes associated with the samples. Typically, this
            is the value of the batch acquisition function to be maximized.
        n: The number of initial conditions to be generated. Must not be larger
            than `b`.
        eta: Temperature parameter for weighting samples.
        alpha: The threshold (as a fraction of the maximum observed value) under
            which to ignore samples. All input samples for which
            `Y < alpha * max(Y)` will be ignored.

    Returns:
        A `n x q x d` tensor of `n` `q`-batch initial conditions.

    Example:
        # To get `n=10` starting points of q-batch size `q=3` for model with `d=6`:
        >>> qEI = qExpectedImprovement(model, best_f=0.2)
        >>> Xrnd = torch.rand(500, 3, 6)
        >>> Xinit = initialize_q_batch_nonneg(Xrnd, qEI(Xrnd), 10)
    """
    n_samples = X.shape[0]
    if n > n_samples:
        raise RuntimeError("n cannot be larger than the number of provided samples")
    elif n == n_samples:
        return X

    max_val, max_idx = torch.max(Y, dim=0)
    if torch.any(max_val <= 0):
        warnings.warn(
            "All acquisition values for raw sampled points are nonpositive, so "
            "initial conditions are being selected randomly.",
            BadInitialCandidatesWarning,
        )
        return X[torch.randperm(n=n_samples, device=X.device)][:n]

    # make sure there are at least `n` points with positive acquisition values
    pos = Y > 0
    num_pos = pos.sum().item()
    if num_pos < n:
        # select all positive points and then fill remaining quota with randomly
        # selected points
        remaining_indices = (~pos).nonzero().view(-1)
        rand_indices = torch.randperm(remaining_indices.shape[0], device=Y.device)
        sampled_remaining_indices = remaining_indices[rand_indices[:n - num_pos]]
        pos[sampled_remaining_indices] = 1
        return X[pos]

    # select points within alpha of max_val, iteratively decreasing alpha by a
    # factor of 10 as necessary
    alpha_pos = Y >= alpha * max_val
    while alpha_pos.sum() < n:
        alpha = 0.1 * alpha
        alpha_pos = Y >= alpha * max_val
    alpha_pos_idcs = torch.arange(len(Y), device=Y.device)[alpha_pos]
    weights = torch.exp(eta * (Y[alpha_pos] / max_val - 1))
    idcs = alpha_pos_idcs[torch.multinomial(weights, n)]
    if max_idx not in idcs:
        idcs[-1] = max_idx
    return X[idcs]
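# --- Added usage sketch for `initialize_q_batch_nonneg` above (assumes the
# imports used by the function, e.g. torch and Tensor, are in scope). 500
# random q-batches scored by a toy non-negative acquisition value; 10 are
# kept, weighted toward high scores. ---
import torch

Xrnd = torch.rand(500, 3, 6)            # b x q x d raw samples
Yrnd = torch.rand(500)                  # stand-in for qEI(Xrnd), non-negative
Xinit = initialize_q_batch_nonneg(Xrnd, Yrnd, n=10)
print(Xinit.shape)  # torch.Size([10, 3, 6])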
def forward( ctx, representation_tree, dtype, device, matrix_shape, batch_shape=torch.Size(), inv_quad=False, logdet=False, probe_vectors=None, probe_vector_norms=None, *args, ): """ *args - The arguments representing the PSD matrix A (or batch of PSD matrices A) If self.inv_quad is true, the first entry in *args is inv_quad_rhs (Tensor) - the RHS of the matrix solves. Returns: - (Scalar) The inverse quadratic form (or None, if self.inv_quad is False) - (Scalar) The log determinant (or None, self.if logdet is False) """ if not (inv_quad or logdet): raise RuntimeError( "Either inv_quad or logdet must be true (or both)") ctx.representation_tree = representation_tree ctx.dtype = dtype ctx.device = device ctx.matrix_shape = matrix_shape ctx.batch_shape = batch_shape ctx.inv_quad = inv_quad ctx.logdet = logdet matrix_args = None inv_quad_rhs = None if ctx.inv_quad: matrix_args = args[1:] inv_quad_rhs = args[0] else: matrix_args = args # Get closure for matmul lazy_tsr = ctx.representation_tree(*matrix_args) with torch.no_grad(): preconditioner, precond_lt, logdet_correction = lazy_tsr._preconditioner( ) ctx.preconditioner = preconditioner if (probe_vectors is None or probe_vector_norms is None) and logdet: num_random_probes = settings.num_trace_samples.value() if preconditioner is None: if settings.deterministic_probes.on(): warnings.warn( "Deterministic probes will currently work only if you aren't training multiple independent" " models simultaneously.", UserWarning, ) if settings.deterministic_probes.probe_vectors is None: probe_vectors = torch.empty(matrix_shape[-1], num_random_probes, dtype=dtype, device=device) probe_vectors.bernoulli_().mul_(2).add_(-1) settings.deterministic_probes.probe_vectors = probe_vectors else: probe_vectors = settings.deterministic_probes.probe_vectors else: probe_vectors = torch.empty(matrix_shape[-1], num_random_probes, dtype=dtype, device=device) probe_vectors.bernoulli_().mul_(2).add_(-1) probe_vector_norms = torch.norm(probe_vectors, 2, dim=-2, keepdim=True) if batch_shape is not None: probe_vectors = probe_vectors.expand( *batch_shape, matrix_shape[-1], num_random_probes) probe_vector_norms = probe_vector_norms.expand( *batch_shape, 1, num_random_probes) else: # When preconditioning, probe vectors must be drawn from N(0, P) if settings.deterministic_probes.on(): # NOTE: calling precond_lt.root_decomposition() is expensive # because it requires Lanczos # We don't have any other choice for when we want to use deterministic probes, however if precond_lt.size()[-2:] == torch.Size([1, 1]): covar_root = precond_lt.evaluate().sqrt() else: covar_root = precond_lt.root_decomposition().root warnings.warn( "Deterministic probes will currently work only if you aren't training multiple independent" " models simultaneously.", UserWarning, ) base_samples = settings.deterministic_probes.probe_vectors if base_samples is None or covar_root.size( -1) != base_samples.size(-2): base_samples = torch.randn( *precond_lt.batch_shape, covar_root.size(-1), num_random_probes, dtype=precond_lt.dtype, device=precond_lt.device, ) settings.deterministic_probes.probe_vectors = base_samples probe_vectors = covar_root.matmul(base_samples).permute( -1, *range(precond_lt.dim() - 1)) else: probe_vectors = precond_lt.zero_mean_mvn_samples( num_random_probes) probe_vectors = probe_vectors.unsqueeze(-2).transpose( 0, -2).squeeze(0).transpose(-2, -1).contiguous() probe_vector_norms = torch.norm(probe_vectors, p=2, dim=-2, keepdim=True) probe_vectors = probe_vectors.div(probe_vector_norms) 
ctx.probe_vectors = probe_vectors ctx.probe_vector_norms = probe_vector_norms if ctx.logdet and not ctx.probe_vectors.numel(): raise RuntimeError( "Probe vectors were not supplied for logdet computation") # Collect terms for LinearCG # We use LinearCG for both matrix solves and for stochastically estimating the log det rhs_list = [] num_random_probes = 0 num_inv_quad_solves = 0 # RHS for logdet if ctx.logdet: rhs_list.append(ctx.probe_vectors) num_random_probes = ctx.probe_vectors.size(-1) # RHS for inv_quad ctx.is_vector = False if ctx.inv_quad: if inv_quad_rhs.ndimension() == 1: inv_quad_rhs = inv_quad_rhs.unsqueeze(-1) ctx.is_vector = True rhs_list.append(inv_quad_rhs) num_inv_quad_solves = inv_quad_rhs.size(-1) # Perform solves (for inv_quad) and tridiagonalization (for estimating logdet) rhs = torch.cat(rhs_list, -1) t_mat = None if ctx.logdet and settings.skip_logdet_forward.off(): solves, t_mat = lazy_tsr._solve(rhs, preconditioner, num_tridiag=num_random_probes) else: solves = lazy_tsr._solve(rhs, preconditioner, num_tridiag=0) # Final values to return logdet_term = torch.zeros(lazy_tsr.batch_shape, dtype=ctx.dtype, device=ctx.device) inv_quad_term = torch.zeros(lazy_tsr.batch_shape, dtype=ctx.dtype, device=ctx.device) # Compute logdet from tridiagonalization if ctx.logdet and settings.skip_logdet_forward.off(): if torch.any(torch.isnan(t_mat)).item(): logdet_term = torch.tensor(float("nan"), dtype=ctx.dtype, device=ctx.device) else: if ctx.batch_shape is None: t_mat = t_mat.unsqueeze(1) eigenvalues, eigenvectors = lanczos_tridiag_to_diag(t_mat) slq = StochasticLQ() (logdet_term, ) = slq.evaluate(ctx.matrix_shape, eigenvalues, eigenvectors, [lambda x: x.log()]) # Add correction if logdet_correction is not None: logdet_term = logdet_term + logdet_correction # Extract inv_quad solves from all the solves if ctx.inv_quad: inv_quad_solves = solves.narrow(-1, num_random_probes, num_inv_quad_solves) inv_quad_term = (inv_quad_solves * inv_quad_rhs).sum(-2) ctx.num_random_probes = num_random_probes ctx.num_inv_quad_solves = num_inv_quad_solves to_save = list(matrix_args) + [solves] ctx.save_for_backward(*to_save) if settings.memory_efficient.off(): ctx._lazy_tsr = lazy_tsr return inv_quad_term, logdet_term
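# --- Added: the Rademacher probe construction used above, in isolation.
# bernoulli_() fills with {0, 1}; mul_(2).add_(-1) maps that to {-1, +1},
# giving the +-1 probe vectors used for stochastic trace/logdet estimation,
# after which each column is normalised. ---
import torch

n, num_probes = 6, 4
probes = torch.empty(n, num_probes)
probes.bernoulli_().mul_(2).add_(-1)                 # entries in {-1, +1}
norms = torch.norm(probes, 2, dim=-2, keepdim=True)
probes = probes / norms                              # unit-norm columns
print(torch.norm(probes, p=2, dim=-2))               # all ones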
def sinc_inv(x):
    # Switch to a Taylor expansion near zero, where x/sin(x) is numerically
    # unstable. `thresh` is assumed to be a module-level constant.
    usetaylor = (x.abs() < thresh)
    texpand = 1 + (1/6)*x**2 + (7/360)*x**4
    assert not torch.any(torch.isinf(texpand) | torch.isnan(texpand)), \
        f'sinc_inv Taylor expansion contains inf/nan: {texpand}'
    return torch.where(usetaylor, texpand, x/x.sin())
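# --- Added demo of `sinc_inv` above, with the (assumed) module-level `thresh`
# given an illustrative value. Near zero the Taylor series
# 1 + x^2/6 + 7x^4/360 and the exact x/sin(x) agree to high precision, which
# is why switching branches is safe. ---
import torch

thresh = 0.03  # assumption: illustrative value for the global `thresh`
x = torch.tensor([1e-8, 0.01, 0.02, 0.029], dtype=torch.float64)
taylor = 1 + (1/6)*x**2 + (7/360)*x**4
exact = x / x.sin()
print((taylor - exact).abs().max())  # tiny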
def _test_no_pad_and_pad(self, no_pad_features, pad_features): tokenizer = BertTokenizer(self.vocab_file) data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False) batch = data_collator(no_pad_features) self.assertEqual(batch["input_ids"].shape, torch.Size((2, 10))) self.assertEqual(batch["labels"].shape, torch.Size((2, 10))) batch = data_collator(pad_features) self.assertEqual(batch["input_ids"].shape, torch.Size((2, 10))) self.assertEqual(batch["labels"].shape, torch.Size((2, 10))) data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False, pad_to_multiple_of=8) batch = data_collator(no_pad_features) self.assertEqual(batch["input_ids"].shape, torch.Size((2, 16))) self.assertEqual(batch["labels"].shape, torch.Size((2, 16))) batch = data_collator(pad_features) self.assertEqual(batch["input_ids"].shape, torch.Size((2, 16))) self.assertEqual(batch["labels"].shape, torch.Size((2, 16))) tokenizer._pad_token = None data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False) with self.assertRaises(ValueError): # Expect error due to padding token missing data_collator(pad_features) set_seed(42) # For reproducibility tokenizer = BertTokenizer(self.vocab_file) data_collator = DataCollatorForLanguageModeling(tokenizer) batch = data_collator(no_pad_features) self.assertEqual(batch["input_ids"].shape, torch.Size((2, 10))) self.assertEqual(batch["labels"].shape, torch.Size((2, 10))) masked_tokens = batch["input_ids"] == tokenizer.mask_token_id self.assertTrue(torch.any(masked_tokens)) self.assertTrue( all(x == -100 for x in batch["labels"][~masked_tokens].tolist())) batch = data_collator(pad_features) self.assertEqual(batch["input_ids"].shape, torch.Size((2, 10))) self.assertEqual(batch["labels"].shape, torch.Size((2, 10))) masked_tokens = batch["input_ids"] == tokenizer.mask_token_id self.assertTrue(torch.any(masked_tokens)) self.assertTrue( all(x == -100 for x in batch["labels"][~masked_tokens].tolist())) data_collator = DataCollatorForLanguageModeling(tokenizer, pad_to_multiple_of=8) batch = data_collator(no_pad_features) self.assertEqual(batch["input_ids"].shape, torch.Size((2, 16))) self.assertEqual(batch["labels"].shape, torch.Size((2, 16))) masked_tokens = batch["input_ids"] == tokenizer.mask_token_id self.assertTrue(torch.any(masked_tokens)) self.assertTrue( all(x == -100 for x in batch["labels"][~masked_tokens].tolist())) batch = data_collator(pad_features) self.assertEqual(batch["input_ids"].shape, torch.Size((2, 16))) self.assertEqual(batch["labels"].shape, torch.Size((2, 16))) masked_tokens = batch["input_ids"] == tokenizer.mask_token_id self.assertTrue(torch.any(masked_tokens)) self.assertTrue( all(x == -100 for x in batch["labels"][~masked_tokens].tolist()))
def sinkhorn_knopp(a, b, C, reg=1e-1, maxIter=1000, stopThr=1e-9,
                   verbose=True, log=True, warm_start=None, eval_freq=10,
                   print_freq=200, **kwargs):
    r"""
    Solve the entropic regularization optimal transport
    The input should be PyTorch tensors
    The function solves the following optimization problem:

    .. math::
        \gamma = \mathop{\arg\min}_\gamma \; \langle \gamma, C \rangle_F
                 + \mathrm{reg} \cdot \Omega(\gamma)

        s.t. \quad \gamma \mathbf{1} = a, \quad \gamma^T \mathbf{1} = b,
             \quad \gamma \geq 0

    where :

    - C is the (ns, nt) metric cost matrix
    - :math:`\Omega` is the entropic regularization term
      :math:`\Omega(\gamma) = \sum_{i,j} \gamma_{i,j} \log(\gamma_{i,j})`
    - a and b are target and source measures (sum to 1)

    The algorithm used for solving the problem is the Sinkhorn-Knopp matrix
    scaling algorithm as proposed in [1].

    Parameters
    ----------
    a : torch.tensor (na,)
        samples measure in the target domain
    b : torch.tensor (nb,)
        samples in the source domain
    C : torch.tensor (na, nb)
        loss matrix
    reg : float
        Regularization term > 0
    maxIter : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on error ( > 0 )
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True

    Returns
    -------
    gamma : (na x nb) torch.tensor
        Optimal transportation matrix for the given parameters
    log : dict
        log dictionary returned only if log==True in parameters

    References
    ----------
    [1] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal
    Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013

    See Also
    --------
    """

    device = a.device
    na, nb = C.shape

    assert na >= 1 and nb >= 1, 'C needs to be 2d'
    assert na == a.shape[0] and nb == b.shape[0], "Shape of a or b doesn't match that of C"
    assert reg > 0, 'reg should be greater than 0'
    assert a.min() >= 0. and b.min() >= 0., 'Elements in a or b less than 0'
    # unnecessary check for our special case

    if log:
        log = {'err': []}

    if warm_start is not None:
        u = warm_start['u']
        v = warm_start['v']
    else:
        u = torch.ones(na, dtype=a.dtype).to(device) / na
        v = torch.ones(nb, dtype=b.dtype).to(device) / nb

    K = torch.empty(C.shape, dtype=C.dtype).to(device)
    torch.div(C, -reg, out=K)
    torch.exp(K, out=K)

    b_hat = torch.empty(b.shape, dtype=C.dtype).to(device)

    it = 1
    err = 1

    # allocate memory beforehand
    KTu = torch.empty(v.shape, dtype=v.dtype).to(device)
    Kv = torch.empty(u.shape, dtype=u.dtype).to(device)

    while (err > stopThr and it <= maxIter):
        upre, vpre = u, v
        torch.matmul(u, K, out=KTu)
        v = torch.div(b, KTu + M_EPS)
        torch.matmul(K, v, out=Kv)
        u = torch.div(a, Kv + M_EPS)

        if torch.any(torch.isnan(u)) or torch.any(torch.isnan(v)) or \
                torch.any(torch.isinf(u)) or torch.any(torch.isinf(v)):
            print('Warning: numerical errors at iteration', it)
            u, v = upre, vpre
            break

        if log and it % eval_freq == 0:
            # we can speed up the process by checking for the error only every
            # eval_freq iterations
            # below is equivalent to:
            # b_hat = torch.sum(u.reshape(-1, 1) * K * v.reshape(1, -1), 0)
            # but is more memory efficient
            b_hat = torch.matmul(u, K) * v
            err = (b - b_hat).pow(2).sum().item()
            # err = (b - b_hat).abs().sum().item()
            log['err'].append(err)

        if verbose and it % print_freq == 0:
            print('iteration {:5d}, constraint error {:5e}'.format(it, err))

        it += 1

    if log:
        log['u'] = u
        log['v'] = v
        log['alpha'] = reg * torch.log(u + M_EPS)
        log['beta'] = reg * torch.log(v + M_EPS)

    # transport plan
    P = u.reshape(-1, 1) * K * v.reshape(1, -1)
    if log:
        return P, log
    else:
        return P
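# --- Added usage sketch for `sinkhorn_knopp` above. The solver references a
# module-level constant M_EPS; a tiny value such as 1e-16 matches its role of
# guarding the divisions (assumption, not part of the original source). ---
import torch

M_EPS = 1e-16  # assumption: small constant used inside sinkhorn_knopp
a = torch.full((4,), 1 / 4)
b = torch.full((5,), 1 / 5)
C = torch.rand(4, 5)
P, log = sinkhorn_knopp(a, b, C, reg=1e-1, maxIter=500, verbose=False, log=True)
print(P.sum(dim=1))  # row marginals should be close to a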
for gid in gids:  # each image
    mask_gid = datas_dt[:, 0] == gid
    datas_dt_gid = datas_dt[mask_gid]
    datas_gt_gid = datas_gt[datas_gt[:, 0] == gid]
    ious = calc_iou4ts(datas_dt_gid[:, 3:7], datas_gt_gid[:, 2:6])
    # print(ious)
    for i in range(len(datas_gt_gid)):
        cls = datas_gt_gid[i][1]
        # clses.add(int(cls.item()))
        # class label cls is 1-d
        mask_cls = datas_dt_gid[:, 1] == cls
        mask_nomatch = datas_dt_gid[:, -1] != 1
        mask_iou = ious[:, i] > iou_std
        _mask = torch.logical_and(mask_cls, mask_nomatch)
        _mask = torch.logical_and(_mask, mask_iou)
        if torch.any(_mask):
            # highest-scoring matching detection; map the index from the
            # _mask-filtered subset back to this image's detections before
            # writing into the global table
            index_score = datas_dt_gid[:, 2][_mask].max(0)[1]
            idx_in_gid = torch.where(_mask)[0][index_score]
            # datas_dt_gid[idx_in_gid, -1] = 1
            dim0 = torch.where(mask_gid)
            datas_dt[dim0[0][idx_in_gid], -1] = 1

    if not debug:
        continue
    for c in clses:  # debug: TP stats for each class of each image
        num_gt = (datas_gt_gid[:, 1] == c).sum()
        mask_gid_cls = datas_dt_gid[:, 1] == c
        tp = (torch.logical_and(mask_gid_cls, datas_dt_gid[:, -1] == 1)).sum()
        fp = mask_gid_cls.sum() - tp
        fn = num_gt - tp
        print('gid=%s, cls=%s, tp=%s, fp=%s, fn=%s' %
              (gid, c, tp, fp, fn))
def generate_batch(self, batch, gt_prefix=1, enc_state=None, whole_policy=False, special_actions=None, temperature=1, acpt_range=None): # This is to ensure we can stop at EOS for stateful models assert batch.size == 1 # Run the model policy, p_mean, p_logstd = self.model(batch.encoder_intent, batch.encoder_price, batch.encoder_pmask, batch.encoder_dianum) # print('policy is:', policy) # Get embeddings of target # tgt_emb = self.model.encoder.embeddings(batch.target_intent) # tgt_emb = torch.cat([tgt_emb, batch.target_price], ) # policy.sub_(policy.max(1, keepdim=True)[0].expand(policy.size(0), policy.size(1))) # policy[batch.policy_mask == 0] = -100 policy.sub_(policy.max(1, keepdim=True)[0].expand(-1, policy.size(1))) policy = policy * temperature # mask = batch.policy_mask # policy[mask == 0] = -100. # print(batch.policy_mask) # Avoid policy equal to zero ( + 1e-6 ) p_exp = (policy.exp() + 1e-6).mul(batch.policy_mask) # p_exp = (policy.exp() + 1e-6) # if torch.sum(p_exp).item() == 0: # p_exp += torch.ones_like(p_exp).mul(batch.policy_mask) policy = p_exp / (torch.sum(p_exp, keepdim=True, dim=1)) if torch.any(policy < 0) or torch.any(torch.isnan(policy)): print('lots of errors: ', p_exp, batch.policy_mask) intent = torch.multinomial(policy, 1).squeeze(1) # (batch_size,) # Use Normal distribution with constant variance as policy on price # price = p_mean + p_logstd.exp()*0.1 * torch.randn_like(p_mean) price = p_mean + (1.1 - temperature) * torch.randn_like(p_mean) # price = p_mean # print(torch.cat([price.view(-1,1), p_mean.view(-1,1), p_logstd.view(-1,1)], dim=1)) # price = price + LFSampler.var_for_price * torch.randn_like(price).abs() # Use rule for Supervised learning agent if acpt_range is not None: assert len(acpt_range) == 2 # Check if last action is offer if batch.encoder_intent[0, -1] in self.offer: offer_price = batch.encoder_price[0, 0, -1].item() policy = policy * 0 if (acpt_range[0] <= offer_price) and (offer_price <= acpt_range[1]): # Accept act_idx = self.acc_or_rej[0] else: act_idx = self.acc_or_rej[1] intent[0] = act_idx policy[0, act_idx] = 1 # TODO: Not correct, for multiple data. if intent not in self.price_actions: price = None # print('gen output: ',policy, price) ret = { "intent": intent, "price": price, "policy": policy, "price_mean": p_mean, "price_logstd": p_logstd, } ret["batch"] = batch # ret["policies"] = policies # ret["probability"] = probs return ret
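# --- Added sketch of the masked-softmax sampling used in `generate_batch`
# above. The +1e-6 keeps valid actions from having exactly zero probability,
# so torch.multinomial never sees an all-zero row when at least one action is
# allowed by the mask. ---
import torch

logits = torch.tensor([[2.0, 0.5, -1.0, 0.0]])
mask = torch.tensor([[1.0, 1.0, 0.0, 1.0]])        # action 2 is invalid
logits = logits - logits.max(1, keepdim=True)[0]   # stabilize before exp
p_exp = (logits.exp() + 1e-6) * mask
policy = p_exp / p_exp.sum(dim=1, keepdim=True)
intent = torch.multinomial(policy, 1).squeeze(1)
print(policy, intent)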
def assign_one_hot_gt_indices(self,
                              is_bbox_in_gt_core,
                              is_bbox_in_gt_shadow,
                              gt_priority=None):
    """Assign only one gt index to each prior box.

    Gts with large gt_priority are more likely to be assigned.

    Args:
        is_bbox_in_gt_core (Tensor): Bool tensor indicating the bbox center
            is in the core area of a gt (e.g. 0-0.2).
            Shape: (num_prior, num_gt).
        is_bbox_in_gt_shadow (Tensor): Bool tensor indicating the bbox
            center is in the shadowed area of a gt (e.g. 0.2-0.5).
            Shape: (num_prior, num_gt).
        gt_priority (Tensor): Priorities of gts. The gt with a higher priority
            is more likely to be assigned to the bbox when the bbox matches
            with multiple gts. Shape: (num_gt, ).

    Returns:
        tuple: Returns (assigned_gt_inds, shadowed_gt_inds).

            - assigned_gt_inds: The assigned gt index of each prior bbox \
                (i.e. index from 1 to num_gts). Shape: (num_prior, ).
            - shadowed_gt_inds: shadowed gt indices. It is a tensor of \
                shape (num_ignore, 2) with first column being the \
                shadowed prior bbox indices and the second column the \
                shadowed gt indices (1-based).
    """
    num_bboxes, num_gts = is_bbox_in_gt_core.shape

    if gt_priority is None:
        gt_priority = torch.arange(num_gts, device=is_bbox_in_gt_core.device)
    assert gt_priority.size(0) == num_gts
    # The bigger gt_priority, the more preferable to be assigned
    # The assigned inds are by default 0 (background)
    assigned_gt_inds = is_bbox_in_gt_core.new_zeros((num_bboxes, ),
                                                    dtype=torch.long)
    # Shadowed bboxes are assigned to be background. But the corresponding
    # label is ignored during loss calculation, which is done through
    # shadowed_gt_inds
    shadowed_gt_inds = torch.nonzero(is_bbox_in_gt_shadow, as_tuple=False)
    if is_bbox_in_gt_core.sum() == 0:  # No gt match
        shadowed_gt_inds[:, 1] += 1  # 1-based. For consistency issue
        return assigned_gt_inds, shadowed_gt_inds

    # The priority of each prior box and gt pair. If one prior box is
    # matched to multiple gts, only the pair with the highest priority
    # is saved
    pair_priority = is_bbox_in_gt_core.new_full((num_bboxes, num_gts),
                                                -1,
                                                dtype=torch.long)

    # Each bbox could match with multiple gts.
    # The following codes deal with this situation
    # Matched bboxes (to any gt). Shape: (num_pos_anchor, )
    inds_of_match = torch.any(is_bbox_in_gt_core, dim=1)
    # The matched gt index of each positive bbox. Length >= num_pos_anchor,
    # since one bbox could match multiple gts
    matched_bbox_gt_inds = torch.nonzero(is_bbox_in_gt_core,
                                         as_tuple=False)[:, 1]
    # Assign priority to each bbox-gt pair.
    pair_priority[is_bbox_in_gt_core] = gt_priority[matched_bbox_gt_inds]
    _, argmax_priority = pair_priority[inds_of_match].max(dim=1)
    assigned_gt_inds[inds_of_match] = argmax_priority + 1  # 1-based
    # Zero-out the assigned anchor box to filter the shadowed gt indices
    is_bbox_in_gt_core[inds_of_match, argmax_priority] = 0
    # Concat the shadowed indices due to overlapping with those outside of
    # the effective scale. shape: (total_num_ignore, 2)
    shadowed_gt_inds = torch.cat(
        (shadowed_gt_inds, torch.nonzero(is_bbox_in_gt_core,
                                         as_tuple=False)),
        dim=0)
    # `is_bbox_in_gt_core` should be changed back to keep arguments intact.
    is_bbox_in_gt_core[inds_of_match, argmax_priority] = 1
    # 1-based shadowed gt indices, to be consistent with `assigned_gt_inds`
    shadowed_gt_inds[:, 1] += 1
    return assigned_gt_inds, shadowed_gt_inds
def forward(
    self,
    tokens_p: TextFieldTensors,
    tokens_h: TextFieldTensors,
    g_p: SparseAdjacencyFieldTensors,
    g_h: SparseAdjacencyFieldTensors,
    label: torch.Tensor = None,
    return_attention: bool = False,
) -> Dict[str, torch.Tensor]:
    """
    GMN forward pass for NLI.

    Let B be the batch size, N the premise length and M the hypothesis length.

    Input:
        tokens_p: token tensors of shape (B, N)
        g_p["edge_index"]: sparse adjacency of the premise graph

    Output:
        a tensor dict with class probabilities and, when a label is given,
        the loss (and optionally the attention maps)
    """
    # Shape: (batch_size, num_tokens, embedding_dim)
    # The embedder picks out the entry it needs, e.g. ["tokens"]
    embedded_p = self.embedder(tokens_p)
    embedded_h = self.embedder(tokens_h)

    # Shape: (num_batch_edges, edge_embedding_dim)
    # A pass-through embedder can be used when the raw edge type is all the
    # encoder needs; this stays general for any edge representation.
    # https://docs.allennlp.org/master/api/modules/token_embedders/embedding/
    # The edge embedder mutates its input, so deep-copy the graphs first.
    g_p_embedded = deepcopy(g_p)
    g_h_embedded = deepcopy(g_h)
    g_p_embedded["edge_attr"] = self.edge_embedder(g_p["edge_attr"])
    g_h_embedded["edge_attr"] = self.edge_embedder(g_h["edge_attr"])

    assert not torch.any(torch.isnan(embedded_p))
    assert not torch.any(torch.isnan(embedded_h))

    # Shape: (batch_size, num_tokens, projected_dim)
    embedded_p = self.projector(embedded_p)
    embedded_h = self.projector(embedded_h)

    assert not torch.any(torch.isnan(embedded_p))
    assert not torch.any(torch.isnan(embedded_h))

    # Shape:
    #   node_attr : (num_tokens, embedding_dim)
    #   batch_id  : (num_tokens)
    sparse_p = tensor_op.dense2sparse(embedded_p, tokens_p["tokens"]["mask"])
    sparse_h = tensor_op.dense2sparse(embedded_h, tokens_h["tokens"]["mask"])

    assert not torch.any(torch.isnan(sparse_p["data"]))
    assert not torch.any(torch.isnan(sparse_h["data"]))

    # Shape: (batch_size, classifier_in_dim)
    if return_attention:
        cls_vector, attention_dict = self.encoder(
            sparse_p, sparse_h, g_p_embedded, g_h_embedded,
            return_attention=return_attention)
    else:
        cls_vector = self.encoder(sparse_p, sparse_h, g_p_embedded,
                                  g_h_embedded,
                                  return_attention=return_attention)

    # Shape: (batch_size, num_labels)
    logits = self.classifier(cls_vector)
    assert not torch.any(torch.isnan(cls_vector))

    # Shape: (batch_size, num_labels)
    probs = torch.nn.functional.softmax(logits, dim=1)

    # Shape: TensorDict
    output = {'probs': probs}

    if label is not None:
        self.accuracy(logits, label)
        # Accuracy and entropy can differ slightly for numerical reasons
        self.entropy(logits, label)
        output['loss'] = torch.nn.functional.cross_entropy(logits, label)
    if return_attention is True:
        output['attentions'] = attention_dict

    return output
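# --- Added sketch (hypothetical stand-in for tensor_op.dense2sparse) ---------
# The dense-to-sparse step above in its simplest form: keep only unmasked
# token vectors and record which batch row each one came from.
import torch

emb = torch.randn(2, 3, 4)                                   # (B, N, dim)
mask = torch.tensor([[1, 1, 0], [1, 0, 0]], dtype=torch.bool)
data = emb[mask]                                             # (num_tokens, dim)
batch_id = torch.arange(2).unsqueeze(1).expand(2, 3)[mask]   # (num_tokens,)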
def greedy_cos_idf(ref_embedding, ref_masks, ref_idf,
                   hyp_embedding, hyp_masks, hyp_idf,
                   all_layers=False):
    """
    Compute greedy matching based on cosine similarity.

    Args:
        - :param: `ref_embedding` (torch.Tensor):
                   embeddings of reference sentences, BxKxd,
                   B: batch size, K: longest length, d: bert dimension
        - :param: `ref_masks` (torch.LongTensor): BxKxK, BERT attention mask for
                   reference sentences.
        - :param: `ref_idf` (torch.Tensor): BxK, idf score of each word
                   piece in the reference sentence
        - :param: `hyp_embedding` (torch.Tensor):
                   embeddings of candidate sentences, BxKxd,
                   B: batch size, K: longest length, d: bert dimension
        - :param: `hyp_masks` (torch.LongTensor): BxKxK, BERT attention mask for
                   candidate sentences.
        - :param: `hyp_idf` (torch.Tensor): BxK, idf score of each word
                   piece in the candidate sentence
    """
    ref_embedding.div_(torch.norm(ref_embedding, dim=-1).unsqueeze(-1))
    hyp_embedding.div_(torch.norm(hyp_embedding, dim=-1).unsqueeze(-1))

    if all_layers:
        B, _, L, D = hyp_embedding.size()
        hyp_embedding = hyp_embedding.transpose(1, 2).transpose(0, 1)\
            .contiguous().view(L * B, hyp_embedding.size(1), D)
        ref_embedding = ref_embedding.transpose(1, 2).transpose(0, 1)\
            .contiguous().view(L * B, ref_embedding.size(1), D)
    batch_size = ref_embedding.size(0)

    sim = torch.bmm(hyp_embedding, ref_embedding.transpose(1, 2))
    masks = torch.bmm(hyp_masks.unsqueeze(2).float(),
                      ref_masks.unsqueeze(1).float())
    if all_layers:
        masks = masks.unsqueeze(0).expand(L, -1, -1, -1)\
            .contiguous().view_as(sim)
    else:
        masks = masks.expand(batch_size, -1, -1)\
            .contiguous().view_as(sim)

    masks = masks.float().to(sim.device)
    sim = sim * masks

    word_precision = sim.max(dim=2)[0]
    word_recall = sim.max(dim=1)[0]

    hyp_idf.div_(hyp_idf.sum(dim=1, keepdim=True))
    ref_idf.div_(ref_idf.sum(dim=1, keepdim=True))
    precision_scale = hyp_idf.to(word_precision.device)
    recall_scale = ref_idf.to(word_recall.device)
    if all_layers:
        precision_scale = precision_scale.unsqueeze(0)\
            .expand(L, B, -1).contiguous().view_as(word_precision)
        recall_scale = recall_scale.unsqueeze(0)\
            .expand(L, B, -1).contiguous().view_as(word_recall)
    P = (word_precision * precision_scale).sum(dim=1)
    R = (word_recall * recall_scale).sum(dim=1)
    F = 2 * P * R / (P + R)

    # A mask summing to 2 has only [CLS] and [SEP] left, i.e. an empty sentence
    hyp_zero_mask = hyp_masks.sum(dim=1).eq(2)
    ref_zero_mask = ref_masks.sum(dim=1).eq(2)

    if all_layers:
        P = P.view(L, B)
        R = R.view(L, B)
        F = F.view(L, B)

    if torch.any(hyp_zero_mask):
        print("Warning: Empty candidate sentence; Setting precision to be 0.",
              file=sys.stderr)
        P = P.masked_fill(hyp_zero_mask, 0.)
    if torch.any(ref_zero_mask):
        print("Warning: Empty reference sentence; Setting recall to be 0.",
              file=sys.stderr)
        R = R.masked_fill(ref_zero_mask, 0.)

    F = F.masked_fill(torch.isnan(F), 0.)

    return P, R, F
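# --- Added sketch (assumes uniform weights instead of IDF) -------------------
# The greedy-matching core above, stripped of masking and IDF weighting:
# precision takes row-wise maxima of the similarity matrix, recall takes
# column-wise maxima, and F is their harmonic mean.
import torch
import torch.nn.functional as F_

hyp = F_.normalize(torch.randn(1, 4, 8), dim=-1)  # B x K_hyp x d
ref = F_.normalize(torch.randn(1, 5, 8), dim=-1)  # B x K_ref x d
sim = torch.bmm(hyp, ref.transpose(1, 2))         # B x K_hyp x K_ref
P = sim.max(dim=2)[0].mean(dim=1)
R = sim.max(dim=1)[0].mean(dim=1)
F_score = 2 * P * R / (P + R)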
def nucleus_sampling(model, batch, batch_size, threshold=0.9, use_packed=True):
    model.eval()
    text_vecs = batch['text_vecs'].to(current_device)

    if use_packed:
        encoded = model.encoder(text_vecs, batch['text_lens'],
                                use_packed=batch['use_packed'])
    else:
        encoded = model.encoder(text_vecs)
    encoder_output, encoder_hidden, attention_mask = encoded

    # Token id 1 is __start__; expand it to the batch size
    starts = torch.Tensor([1]).long().to(
        model.decoder.embedding.weight.device).expand(batch_size, 1).long()

    decoder_hidden = encoder_hidden

    # Nucleus (top-p) decoding
    preds = [starts]
    scores = []

    # Track whether each sample in the mini-batch is finished;
    # once all are finished, stop predicting
    finish_mask = torch.Tensor([0] * batch_size).byte().to(
        model.decoder.embedding.weight.device)
    xs = starts
    _attn_w_log = []
    for ts in range(100):
        decoder_output, decoder_hidden, attn_w_log = model.decoder(
            xs, decoder_hidden, encoded)

        # decoder_output: [batch, time, vocab]
        _probs = torch.softmax(decoder_output, dim=-1)[0][0]
        _sorted_probs, _sorted_indices = torch.sort(_probs, descending=True)
        cumulative_probs = torch.cumsum(_sorted_probs, dim=-1)
        selected_probs = cumulative_probs < threshold
        # Shift right so the first token crossing the threshold is kept too
        selected_probs[1:] = selected_probs[:-1].clone()
        selected_probs[0] = True
        _sorted_probs[~selected_probs] = 0

        # Renormalise the retained nucleus and sample from it
        P = _sorted_probs.sum()
        _sorted_probs /= P

        chosen_index = torch.multinomial(_sorted_probs, 1)
        _preds = _sorted_indices[chosen_index]
        _scores = torch.log(_probs[_preds])
        _preds = _preds.unsqueeze(0)
        preds.append(_preds)
        _attn_w_log.append(attn_w_log)
        scores.append(_scores.view(-1) * (finish_mask == 0).float())

        # Token id 2 marks the end of a sequence
        finish_mask += (_preds == 2).byte().view(-1)
        if not torch.any(~finish_mask.bool()):
            break
        xs = _preds

    preds = torch.cat(preds, dim=-1)

    return preds, torch.sum(torch.Tensor(scores))
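# --- Added helper sketch (names are illustrative) ----------------------------
# The top-p filtering step from the loop above as a standalone function:
# keep the smallest set of tokens whose cumulative probability reaches the
# threshold (always at least one), zero out the rest, renormalise, sample.
import torch

def top_p_filter(probs: torch.Tensor, threshold: float = 0.9) -> torch.Tensor:
    sorted_probs, sorted_idx = torch.sort(probs, descending=True)
    cum = torch.cumsum(sorted_probs, dim=-1)
    keep = cum < threshold
    keep[1:] = keep[:-1].clone()  # shift so the crossing token is kept too
    keep[0] = True                # always keep the most likely token
    sorted_probs[~keep] = 0.
    sorted_probs = sorted_probs / sorted_probs.sum()
    return sorted_idx[torch.multinomial(sorted_probs, 1)]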
def forward(self, x, t): # pylint: disable=arguments-differ assert len(x) == 1 + 2 * self.n_vectors + self.n_scales x_intensity = x[0] x_regs = x[1:1 + self.n_vectors] x_spreads = x[1 + self.n_vectors:1 + 2 * self.n_vectors] x_scales = [] if self.n_scales: x_scales = x[1 + 2 * self.n_vectors:1 + 2 * self.n_vectors + self.n_scales] assert len(t) == 1 + self.n_vectors + 1 target_intensity = t[0] target_regs = t[1:1 + self.n_vectors] target_scale = t[-1] bce_masks = (target_intensity[:, :-1] + target_intensity[:, -1:]) > 0.5 if not torch.any(bce_masks): return None, None, None batch_size = x_intensity.shape[0] LOG.debug('batch size = %d', batch_size) bce_x_intensity = x_intensity bce_target_intensity = target_intensity[:, :-1] if self.bce_blackout: bce_x_intensity = bce_x_intensity[:, self.bce_blackout] bce_masks = bce_masks[:, self.bce_blackout] bce_target_intensity = bce_target_intensity[:, self.bce_blackout] LOG.debug('BCE: target = %s, mask = %s', bce_target_intensity.shape, bce_masks.shape) bce_target = torch.masked_select(bce_target_intensity, bce_masks) bce_weight = None if self.background_weight != 1.0: bce_weight = torch.ones_like(bce_target) bce_weight[bce_target == 0] = self.background_weight ce_loss = torch.nn.functional.binary_cross_entropy_with_logits( torch.masked_select(bce_x_intensity, bce_masks), bce_target, weight=bce_weight, ) reg_losses = [None for _ in target_regs] reg_masks = target_intensity[:, :-1] > 0.5 if torch.any(reg_masks): weight = None if self.multiplicity_correction: assert len(target_regs) == 2 lengths = torch.norm(target_regs[0] - target_regs[1], dim=2) multiplicity = (lengths - 3.0) / self.independence_scale multiplicity = torch.clamp(multiplicity, min=1.0) multiplicity = torch.masked_select(multiplicity, reg_masks) weight = 1.0 / multiplicity reg_losses = [] for i, (x_reg, x_spread, target_reg) in enumerate( zip(x_regs, x_spreads, target_regs)): if hasattr(self.regression_loss, 'scale'): assert self.scales_to_kp is not None self.regression_loss.scale = torch.masked_select( torch.clamp(target_scale * self.scales_to_kp[i], 0.1, 1000.0), # pylint: disable=unsubscriptable-object reg_masks, ) reg_losses.append( self.regression_loss( torch.masked_select(x_reg[:, :, 0], reg_masks), torch.masked_select(x_reg[:, :, 1], reg_masks), torch.masked_select(x_spread, reg_masks), torch.masked_select(target_reg[:, :, 0], reg_masks), torch.masked_select(target_reg[:, :, 1], reg_masks), weight=weight, ) / 1000.0 / batch_size) scale_losses = [] if x_scales: scale_losses = [ torch.nn.functional.l1_loss( torch.masked_select(x_scale, reg_masks), torch.masked_select(target_scale * scale_to_kp, reg_masks), reduction='sum', ) / 1000.0 / batch_size for x_scale, scale_to_kp in zip(x_scales, self.scales_to_kp) ] return [ce_loss] + reg_losses + scale_losses
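# --- Added sketch of the masked-BCE pattern above (illustrative values) ------
# Only the positions selected by the mask contribute to the loss, exactly as
# the masked_select calls in the forward pass above.
import torch

logits = torch.randn(2, 5)
target = torch.randint(0, 2, (2, 5)).float()
mask = torch.rand(2, 5) > 0.3
loss = torch.nn.functional.binary_cross_entropy_with_logits(
    torch.masked_select(logits, mask),
    torch.masked_select(target, mask))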
def train_with_pruning_callback(
    tmpdir,
    parameters_to_prune=False,
    use_global_unstructured=False,
    pruning_fn="l1_unstructured",
    use_lottery_ticket_hypothesis=False,
    strategy=None,
    accelerator="cpu",
    devices=1,
):
    model = TestModel()

    # Weights are randomly initialised, so none of them should be exactly 0
    assert torch.all(model.layer.mlp_2.weight != 0)

    pruning_kwargs = {
        "pruning_fn": pruning_fn,
        "amount": 0.3,
        "use_global_unstructured": use_global_unstructured,
        "use_lottery_ticket_hypothesis": use_lottery_ticket_hypothesis,
        "verbose": 1,
    }
    if parameters_to_prune:
        pruning_kwargs["parameters_to_prune"] = [(model.layer.mlp_1, "weight"),
                                                 (model.layer.mlp_2, "weight")]
    else:
        if isinstance(pruning_fn, str) and pruning_fn.endswith("_structured"):
            pruning_kwargs["parameter_names"] = ["weight"]
        else:
            pruning_kwargs["parameter_names"] = ["weight", "bias"]
    if isinstance(pruning_fn, str) and pruning_fn.endswith("_structured"):
        pruning_kwargs["pruning_dim"] = 0
    if pruning_fn == "ln_structured":
        pruning_kwargs["pruning_norm"] = 1

    # Misconfiguration checks
    if isinstance(pruning_fn, str) and pruning_fn.endswith(
            "_structured") and use_global_unstructured:
        with pytest.raises(
                MisconfigurationException,
                match="is supported with `use_global_unstructured=True`"):
            ModelPruning(**pruning_kwargs)
        return
    if ModelPruning._is_pruning_method(
            pruning_fn) and not use_global_unstructured:
        with pytest.raises(MisconfigurationException,
                           match="currently only supported with"):
            ModelPruning(**pruning_kwargs)
        return

    pruning = ModelPruning(**pruning_kwargs)

    trainer = Trainer(
        default_root_dir=tmpdir,
        enable_progress_bar=False,
        enable_model_summary=False,
        enable_checkpointing=False,
        logger=False,
        limit_train_batches=10,
        limit_val_batches=2,
        max_epochs=10,
        strategy=strategy,
        accelerator=accelerator,
        devices=devices,
        callbacks=pruning,
    )
    trainer.fit(model)
    trainer.test(model)

    if not strategy:
        # Check that some weights have actually been pruned to zero
        assert torch.any(model.layer.mlp_2.weight == 0)
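# --- Added sanity sketch (plain torch, outside the Lightning callback) -------
# What the final assertion checks, reduced to torch.nn.utils.prune alone:
# after l1_unstructured pruning, some weights are exactly zero.
import torch
import torch.nn.utils.prune as prune

layer = torch.nn.Linear(8, 8)
prune.l1_unstructured(layer, name="weight", amount=0.3)
assert torch.any(layer.weight == 0)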
def _compute_perturbation(  # pylint: disable=W0221
    self, x: "torch.Tensor", y: "torch.Tensor", mask: Optional["torch.Tensor"]
) -> "torch.Tensor":
    """
    Compute perturbations.

    :param x: Current adversarial examples.
    :param y: Target values (class labels) one-hot-encoded of shape `(nb_samples, nb_classes)` or
              indices of shape (nb_samples,). Only provide this parameter if you'd like to use true
              labels when crafting adversarial samples. Otherwise, model predictions are used as
              labels to avoid the "label leaking" effect (explained in this paper:
              https://arxiv.org/abs/1611.01236). Default is `None`.
    :param mask: An array with a mask broadcastable to input `x` defining where to apply adversarial
                 perturbations. Shape needs to be broadcastable to the shape of x and can also be of
                 the same shape as `x`. Any features for which the mask is zero will not be
                 adversarially perturbed.
    :return: Perturbations.
    """
    import torch  # lgtm [py/repeated-import]

    # Pick a small scalar to avoid division by 0
    tol = 10e-8

    # Get the gradient of the loss w.r.t. x; invert it if the attack is targeted
    grad = self.estimator.loss_gradient(x=x, y=y) * (1 - 2 * int(self.targeted))

    # Write summary
    if self.summary_writer is not None:
        self.summary_writer.add_scalar(
            "gradients/norm-L1/batch-{}".format(self._batch_id),
            np.linalg.norm(grad.flatten(), ord=1),
            global_step=self._i_max_iter,
        )
        self.summary_writer.add_scalar(
            "gradients/norm-L2/batch-{}".format(self._batch_id),
            np.linalg.norm(grad.flatten(), ord=2),
            global_step=self._i_max_iter,
        )
        self.summary_writer.add_scalar(
            "gradients/norm-Linf/batch-{}".format(self._batch_id),
            np.linalg.norm(grad.flatten(), ord=np.inf),
            global_step=self._i_max_iter,
        )

        if hasattr(self.estimator, "compute_losses"):
            losses = self.estimator.compute_losses(x=x, y=y)
            for key, value in losses.items():
                self.summary_writer.add_scalar(
                    "loss/{}/batch-{}".format(key, self._batch_id),
                    np.mean(value.detach().cpu().numpy()),
                    global_step=self._i_max_iter,
                )

    # Check for NaN before normalisation and replace with 0
    if torch.any(grad.isnan()):
        logger.warning("Elements of the loss gradient are NaN and have been replaced with 0.0.")
        grad[grad.isnan()] = 0.0

    # Apply mask
    if mask is not None:
        grad = torch.where(mask == 0.0, torch.tensor(0.0).to(self.estimator.device), grad)

    # Apply norm bound
    if self.norm in ["inf", np.inf]:
        grad = grad.sign()
    elif self.norm == 1:
        ind = tuple(range(1, len(x.shape)))
        grad = grad / (torch.sum(grad.abs(), dim=ind, keepdim=True) + tol)
    elif self.norm == 2:
        ind = tuple(range(1, len(x.shape)))
        grad = grad / (torch.sqrt(torch.sum(grad * grad, dim=ind, keepdim=True)) + tol)

    assert x.shape == grad.shape

    return grad
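# --- Added sketch of the norm handling above (hedged, standalone) ------------
# Sign for L-inf, L1/L2 normalisation over the non-batch dimensions with a
# small tolerance guarding against division by zero. Names are illustrative.
import torch

def normalize_grad(grad: torch.Tensor, norm, tol: float = 1e-7) -> torch.Tensor:
    if norm in ("inf", float("inf")):
        return grad.sign()
    dims = tuple(range(1, grad.dim()))
    if norm == 1:
        return grad / (grad.abs().sum(dim=dims, keepdim=True) + tol)
    if norm == 2:
        return grad / ((grad * grad).sum(dim=dims, keepdim=True).sqrt() + tol)
    raise ValueError(f"unsupported norm: {norm}")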
def run(self):
    # ---- Subscribe ----
    with self.neuron:

        # ---- Weights ----
        self.row = self.neuron.metagraph.row

        # --- Run state ---
        self.global_step = 0
        self.best_train_loss = math.inf

        # --- Loop forever ---
        for self.epoch in range(self.config.miner.n_epochs):
            try:
                # ---- Serve ----
                self.neuron.axon.serve(self.model)

                # ---- Train Model ----
                self.train()
                self.scheduler.step()

                # If the model parameters have gone NaN for some reason, we
                # must not emit weights; reload the previously saved version
                # of the model instead.
                if torch.any(
                        torch.isnan(
                            torch.cat([
                                param.view(-1)
                                for param in self.model.parameters()
                            ]))):
                    self.model, self.optimizer = self.model_toolbox.load_model(
                        self.config)
                    continue

                # ---- Emit row-weights ----
                self.neuron.metagraph.set_weights(
                    self.row, wait_for_inclusion=True
                )  # Sets my row-weights on the chain.

                # ---- Sync metagraph ----
                self.neuron.metagraph.sync(
                )  # Pulls the latest metagraph state (with my update).
                self.row = self.neuron.metagraph.row

                # --- Epoch logs ----
                print(self.neuron.axon.__full_str__())
                print(self.neuron.dendrite.__full_str__())
                print(self.neuron.metagraph)

                # ---- Update Tensorboard ----
                self.neuron.dendrite.__to_tensorboard__(
                    self.tensorboard, self.global_step)
                self.neuron.metagraph.__to_tensorboard__(
                    self.tensorboard, self.global_step)
                self.neuron.axon.__to_tensorboard__(
                    self.tensorboard, self.global_step)

                # ---- Save best loss and model ----
                if self.training_loss and self.epoch % 10 == 0:
                    if self.training_loss < self.best_train_loss:
                        self.best_train_loss = self.training_loss  # update best train loss
                        self.model_toolbox.save_model(
                            self.config.miner.full_path, {
                                'epoch': self.epoch,
                                'model_state_dict': self.model.state_dict(),
                                'loss': self.best_train_loss,
                                'optimizer_state_dict': self.optimizer.state_dict(),
                            })
                        self.tensorboard.add_scalar(
                            'Neuron/Train_loss', self.training_loss,
                            self.global_step)

            # --- Catch errors ----
            except Exception as e:
                logger.error('Exception in training script with error: {}', e)
                logger.info(traceback.format_exc())
                logger.info('Continuing to train.')
                time.sleep(1)
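# --- Added helper sketch (function name is illustrative) ---------------------
# The NaN guard from the loop above as a reusable predicate: flatten every
# parameter into one vector and test it for NaNs.
import torch

def has_nan_params(model: torch.nn.Module) -> bool:
    flat = torch.cat([p.view(-1) for p in model.parameters()])
    return bool(torch.any(torch.isnan(flat)))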
def main():
    summary = SummaryWriter('./log')
    tr_set, dv_set, feat_dim, msg = load_dataset(args.njobs, args.gpu,
                                                 args.pin_memory,
                                                 config['hparas']['curriculum'] > 0,
                                                 **config['data'])
    verbose(msg)

    # Model initialisation
    print('Model initializing\n')
    net = torch.nn.DataParallel(
        AttentionModel(120,
                       hidden_size=args.hidden_size,
                       dropout_p=args.dropout_p,
                       use_attn=args.attn_use,
                       stacked_encoder=args.stacked_encoder,
                       attn_len=args.attn_len))
    net = net.cuda()
    print(net)

    optimizer = optim.Adam(net.parameters(), lr=args.learning_rate)
    scheduler = ExponentialLR(optimizer, 0.5)

    # Checkpoint load
    print('Trying Checkpoint Load\n')
    ckpt_dir = 'ckpt_dir'
    if not os.path.exists(ckpt_dir):
        os.makedirs(ckpt_dir)
    best_loss = 200000.
    ckpt_path = os.path.join(ckpt_dir, args.ck_name)
    if os.path.exists(ckpt_path):
        ckpt = torch.load(ckpt_path)
        try:
            net.load_state_dict(ckpt['model'])
            optimizer.load_state_dict(ckpt['optimizer'])
            best_loss = ckpt['best_loss']
            print('checkpoint is loaded !')
            print('current best loss : %.4f' % best_loss)
        except RuntimeError:
            print('wrong checkpoint\n')
    else:
        print('checkpoint does not exist!')
        print('current best loss : %.4f' % best_loss)

    print('Training Start!')
    # Train
    iteration = 0
    train_losses = []
    test_losses = []
    for epoch in range(args.num_epochs):
        n = 0
        avg_loss = 0
        net.train()
        for input in tqdm(tr_set):
            tr_noisy_set, feat_dim = load_noisy_dataset("train", input[0],
                                                        args.njobs, args.gpu,
                                                        args.pin_memory,
                                                        config['hparas']['curriculum'] > 0,
                                                        **config['data_noisy'])
            for input_noisy in tr_noisy_set:
                train_clean_feat, feat_len = fetch_data(input)
                train_noisy_feat, feat_len = fetch_data(input_noisy)
                iteration += 1

                # Feed data
                train_mixed_feat, attn_weight = net(train_noisy_feat)
                if train_mixed_feat.shape == train_clean_feat.shape:
                    loss = F.mse_loss(train_mixed_feat, train_clean_feat)
                    if torch.any(torch.isnan(loss)):
                        # Dump the offending tensors for debugging, then abort
                        torch.save(
                            {'clean_mag': train_clean_feat,
                             'noisy_mag': train_noisy_feat,
                             'out_mag': train_mixed_feat},
                            'nan_mag')
                        raise ValueError('loss is NaN')
                    avg_loss += loss.item()
                    n += 1

                    # Gradient step
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

        avg_loss /= n
        print('result:')
        print('[epoch: {}, iteration: {}] avg_loss : {:.4f}'.format(epoch, iteration, avg_loss))
        summary.add_scalar('Train Loss', avg_loss, iteration)

        train_losses.append(avg_loss)
        # Decay the learning rate when the training loss stops improving
        if (len(train_losses) > 2) and (train_losses[-2] < avg_loss):
            print("Learning rate Decay")
            scheduler.step()

        # Test phase
        n = 0
        avg_test_loss = 0
        net.eval()
        with torch.no_grad():
            for input in tqdm(dv_set):
                dv_noisy_set, feat_dim = load_noisy_dataset("dev", input[0],
                                                            args.njobs, args.gpu,
                                                            args.pin_memory,
                                                            config['hparas']['curriculum'] > 0,
                                                            **config['data_noisy'])
                for input_noisy in dv_noisy_set:
                    test_clean_feat = input[1].to(device='cuda')
                    test_noisy_feat = input_noisy[1].to(device='cuda')
                    test_mixed_feat, logits_attn_weight = net(test_noisy_feat)
                    if test_mixed_feat.shape == test_clean_feat.shape:
                        test_loss = F.mse_loss(test_mixed_feat, test_clean_feat)
                        avg_test_loss += test_loss.item()
                        n += 1

            avg_test_loss /= n
            test_losses.append(avg_test_loss)
            summary.add_scalar('Test Loss', avg_test_loss, iteration)
            print('[epoch: {}, iteration: {}] test loss : {:.4f}'.format(epoch, iteration, avg_test_loss))

            if avg_test_loss < best_loss:
                best_loss = avg_test_loss
                # Note: the optimizer also has state; save it as well so
                # training can resume exactly.
                ckpt = {'model': net.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'best_loss': best_loss}
                torch.save(ckpt, ckpt_path)
                print('checkpoint is saved !')
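# --- Added companion sketch (function name is illustrative) ------------------
# Restoring a checkpoint saved in the format above, so that training resumes
# with both the model and the optimizer state.
import torch

def load_ckpt(path, net, optimizer):
    state = torch.load(path)
    net.load_state_dict(state['model'])
    optimizer.load_state_dict(state['optimizer'])
    return state['best_loss']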
#### END CODE HERE #### # Make sure shapes are correct assert tuple( fake_image_and_labels.shape) == (len(real), fake.detach().shape[1] + image_one_hot_labels.shape[1], 28, 28) assert tuple( real_image_and_labels.shape) == (len(real), real.shape[1] + image_one_hot_labels.shape[1], 28, 28) # Make sure that enough predictions were made assert len(disc_real_pred) == len(real) # Make sure that the inputs are different assert torch.any(fake_image_and_labels != real_image_and_labels) # Shapes must match assert tuple(fake_image_and_labels.shape) == tuple( real_image_and_labels.shape) assert tuple(disc_fake_pred.shape) == tuple(disc_real_pred.shape) disc_fake_loss = criterion(disc_fake_pred, torch.zeros_like(disc_fake_pred)) disc_real_loss = criterion(disc_real_pred, torch.ones_like(disc_real_pred)) disc_loss = (disc_fake_loss + disc_real_loss) / 2 disc_loss.backward(retain_graph=True) disc_opt.step() # Keep track of the average discriminator loss discriminator_losses += [disc_loss.item()]
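# --- Added sketch of the conditioning checked above (illustrative shapes) ----
# A conditional discriminator sees the image with one-hot label maps stacked
# onto its channel dimension, which is what the shape assertions verify.
import torch

image = torch.randn(4, 1, 28, 28)
one_hot = torch.zeros(4, 10, 28, 28)
one_hot[:, 3] = 1.                                # all samples: class 3
conditioned = torch.cat([image, one_hot], dim=1)  # (4, 11, 28, 28)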
def test_retrieval(model, criterion, transforms_cuda, device, epoch, args):
    accuracy = [AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter()]
    model.eval()

    def tr(x):
        B = x.size(0)
        assert B == 1
        test_sample = x.size(2) // (args.seq_len * args.num_seq)
        return transforms_cuda(x)\
            .view(3, test_sample, args.num_seq, args.seq_len, args.img_dim, args.img_dim)\
            .permute(1, 2, 0, 3, 4, 5)

    with torch.no_grad():
        transform = transforms.Compose([
            A.CenterCrop(size=(224, 224)),
            A.Scale(size=(args.img_dim, args.img_dim)),
            A.ColorJitter(0.2, 0.2, 0.2, 0.1, p=0.3, consistent=True),
            A.ToTensor()])

        if args.dataset == 'ucf101':
            d_class = UCF101LMDB
        elif args.dataset == 'ucf101-f':
            d_class = UCF101Flow_LMDB
        elif args.dataset == 'hmdb51':
            d_class = HMDB51LMDB
        elif args.dataset == 'hmdb51-f':
            d_class = HMDB51Flow_LMDB

        train_dataset = d_class(mode='train',
                                transform=transform,
                                num_frames=args.num_seq * args.seq_len,
                                ds=args.ds,
                                which_split=1,
                                return_label=True,
                                return_path=True)
        print('train dataset size: %d' % len(train_dataset))

        test_dataset = d_class(mode='test',
                               transform=transform,
                               num_frames=args.num_seq * args.seq_len,
                               ds=args.ds,
                               which_split=1,
                               return_label=True,
                               return_path=True)
        print('test dataset size: %d' % len(test_dataset))

        train_sampler = data.SequentialSampler(train_dataset)
        test_sampler = data.SequentialSampler(test_dataset)

        train_loader = data.DataLoader(train_dataset,
                                       batch_size=1,
                                       sampler=train_sampler,
                                       shuffle=False,
                                       num_workers=args.workers,
                                       pin_memory=True)
        test_loader = data.DataLoader(test_dataset,
                                      batch_size=1,
                                      sampler=test_sampler,
                                      shuffle=False,
                                      num_workers=args.workers,
                                      pin_memory=True)

        if args.dirname is None:
            dirname = 'feature'
        else:
            dirname = args.dirname

        if os.path.exists(os.path.join(os.path.dirname(args.test), dirname, '%s_test_feature.pth.tar' % args.dataset)):
            test_feature = torch.load(os.path.join(os.path.dirname(args.test), dirname, '%s_test_feature.pth.tar' % args.dataset)).to(device)
            test_label = torch.load(os.path.join(os.path.dirname(args.test), dirname, '%s_test_label.pth.tar' % args.dataset)).to(device)
        else:
            os.makedirs(os.path.join(os.path.dirname(args.test), dirname), exist_ok=True)
            print('Computing test set feature ...')
            test_feature = None
            test_label = []
            test_vname = []
            sample_id = 0
            for idx, (input_seq, target) in tqdm(enumerate(test_loader), total=len(test_loader)):
                B = 1
                input_seq = input_seq.to(device, non_blocking=True)
                input_seq = tr(input_seq)
                current_target, vname = target
                current_target = current_target.to(device, non_blocking=True)

                test_sample = input_seq.size(0)
                input_seq = input_seq.squeeze(1)
                logit, feature = model(input_seq)
                if test_feature is None:
                    test_feature = torch.zeros(len(test_dataset), feature.size(-1), device=feature.device)
                # Average the features over all clips of this video
                test_feature[sample_id, :] = feature.mean(0)
                test_label.append(current_target)
                test_vname.append(vname)
                sample_id += 1

            print(test_feature.size())
            test_label = torch.cat(test_label, dim=0)
            torch.save(test_feature, os.path.join(os.path.dirname(args.test), dirname, '%s_test_feature.pth.tar' % args.dataset))
            torch.save(test_label, os.path.join(os.path.dirname(args.test), dirname, '%s_test_label.pth.tar' % args.dataset))
            with open(os.path.join(os.path.dirname(args.test), dirname, '%s_test_vname.pkl' % args.dataset), 'wb') as fp:
                pickle.dump(test_vname, fp)

        if os.path.exists(os.path.join(os.path.dirname(args.test), dirname, '%s_train_feature.pth.tar' % args.dataset)):
            train_feature = torch.load(os.path.join(os.path.dirname(args.test), dirname, '%s_train_feature.pth.tar' % args.dataset)).to(device)
            train_label = torch.load(os.path.join(os.path.dirname(args.test), dirname, '%s_train_label.pth.tar' % args.dataset)).to(device)
        else:
            print('Computing train set feature ...')
            train_feature = None
            train_label = []
            train_vname = []
            sample_id = 0
            for idx, (input_seq, target) in tqdm(enumerate(train_loader), total=len(train_loader)):
                B = 1
                input_seq = input_seq.to(device, non_blocking=True)
                input_seq = tr(input_seq)
                current_target, vname = target
                current_target = current_target.to(device, non_blocking=True)

                test_sample = input_seq.size(0)
                input_seq = input_seq.squeeze(1)
                logit, feature = model(input_seq)
                if train_feature is None:
                    train_feature = torch.zeros(len(train_dataset), feature.size(-1), device=feature.device)
                train_feature[sample_id, :] = feature.mean(0)
                train_label.append(current_target)
                train_vname.append(vname)
                sample_id += 1

            print(train_feature.size())
            train_label = torch.cat(train_label, dim=0)
            torch.save(train_feature, os.path.join(os.path.dirname(args.test), dirname, '%s_train_feature.pth.tar' % args.dataset))
            torch.save(train_label, os.path.join(os.path.dirname(args.test), dirname, '%s_train_label.pth.tar' % args.dataset))
            with open(os.path.join(os.path.dirname(args.test), dirname, '%s_train_vname.pkl' % args.dataset), 'wb') as fp:
                pickle.dump(train_vname, fp)

        ks = [1, 5, 10, 20, 50]
        NN_acc = []

        # Centre, then L2-normalise the features before the dot product
        test_feature = test_feature - test_feature.mean(dim=0, keepdim=True)
        train_feature = train_feature - train_feature.mean(dim=0, keepdim=True)
        test_feature = F.normalize(test_feature, p=2, dim=1)
        train_feature = F.normalize(train_feature, p=2, dim=1)

        # Cosine similarity between every test and every train feature
        sim = test_feature.matmul(train_feature.t())
        torch.save(sim, os.path.join(os.path.dirname(args.test), dirname, '%s_sim.pth.tar' % args.dataset))

        for k in ks:
            topkval, topkidx = torch.topk(sim, k, dim=1)
            acc = torch.any(train_label[topkidx] == test_label.unsqueeze(1), dim=1).float().mean().item()
            NN_acc.append(acc)
            print('%dNN acc = %.4f' % (k, acc))

        args.logger.log('NN-Retrieval on %s:' % args.dataset)
        for k, acc in zip(ks, NN_acc):
            args.logger.log('\t%dNN acc = %.4f' % (k, acc))

        with open(os.path.join(os.path.dirname(args.test), dirname, '%s_test_vname.pkl' % args.dataset), 'rb') as fp:
            test_vname = pickle.load(fp)
        with open(os.path.join(os.path.dirname(args.test), dirname, '%s_train_vname.pkl' % args.dataset), 'rb') as fp:
            train_vname = pickle.load(fp)

        sys.exit(0)
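# --- Added sketch of the kNN retrieval metric above (random stand-in data) ---
# A query counts as correct at k if any of its k nearest training features
# shares its label.
import torch
import torch.nn.functional as F

train_feat = F.normalize(torch.randn(100, 16), dim=1)
test_feat = F.normalize(torch.randn(10, 16), dim=1)
train_label = torch.randint(0, 5, (100,))
test_label = torch.randint(0, 5, (10,))
sim = test_feat @ train_feat.t()
_, topk = torch.topk(sim, 5, dim=1)
acc = torch.any(train_label[topk] == test_label.unsqueeze(1),
                dim=1).float().mean()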
def infer(
    self,
    phonemes: torch.LongTensor,
    combine_strategy: str = "concat",
    max_len: int = 1024,
    stop_threshold: float = 0.25,
    verbose: bool = False,
    stop_at_stop_token: bool = True,
) -> Tuple[torch.Tensor, torch.LongTensor]:
    """
    Run inference loop to generate a new spectrogram

    :param phonemes: of shape (batch size, phonemes length)
    :param combine_strategy: "concat" to concatenate outputs from consecutive
        iterations, "replace" to replace the previous output with the output
        from the current iteration
    :param max_len: maximum length of the generated spectrogram
    :param stop_threshold: value in (-1, 1); once sigmoid(stop_pred) exceeds
        it, the predicted sequence is considered finished
    :param verbose: if true, prints progress info every 10 steps (useful on cpu)
    :param stop_at_stop_token: if true, stops decoding for a sample once its
        stop prediction crosses the threshold; if false, ignores the stop
        predictions and stops only at max_len
    :return: tuple (spectrogram, stop_idx) where
        spectrogram shape = (batch size, num_mel_coeffs, spectrogram length)
        stop_idx shape = (batch size), contains the end index for each
        spectrogram in the batch
    """
    assert combine_strategy in {"concat", "replace"}
    assert -1. < stop_threshold < 1.
    batch_size = phonemes.shape[0]
    zeros = torch.zeros((batch_size, 1, self.num_mel_coeffs))
    spectrogram = zeros.clone()
    stop = torch.zeros(batch_size, dtype=torch.long)
    while not torch.all(stop > 0):
        iteration = spectrogram.shape[1]
        still_running = stop == 0
        if verbose and iteration % 10 == 0:
            number_of_running_samples = int(still_running.sum())
            print(f"reached {iteration=}, {number_of_running_samples=}...")
        _, generated, stop_pred, _ = self.forward(phonemes, spectrogram)
        stop_pred = stop_pred.view(-1, generated.shape[1])  # view as (batch, len)
        if combine_strategy == "concat":
            generated_slice = generated[:, -1, :].view(
                batch_size, 1, self.num_mel_coeffs)
            spectrogram = torch.cat([spectrogram, generated_slice], dim=1)
            if stop_at_stop_token:
                stops_now = torch.sigmoid(stop_pred[:, -1]) > stop_threshold
                still_running_stops_now = still_running * stops_now
                stop = torch.where(
                    still_running_stops_now,
                    (stops_now.to(dtype=torch.long) * iteration) + 1, stop)
        elif combine_strategy == "replace":
            spectrogram = torch.cat([zeros, generated], dim=1)
            if stop_at_stop_token:
                stops_now = torch.any(
                    torch.sigmoid(stop_pred) > stop_threshold, dim=1)
                still_running_stops_now = stops_now * still_running
                stop = torch.where(still_running_stops_now,
                                   torch.argmax(stop_pred, dim=1) + 1, stop)
        if max(spectrogram.shape) > max_len:
            if verbose:
                print(f"stopped at {max_len=}")
            break
    # Anything still running at the end falls back to max_len
    stop_at_end = torch.ones_like(stop, dtype=torch.long) * max_len
    stop: torch.LongTensor = torch.where(stop == 0, stop_at_end, stop)
    spectrogram: torch.Tensor = spectrogram.transpose(1, 2)
    return spectrogram[:, :, 1:], stop
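# --- Added sketch of the stop bookkeeping above (toy values) -----------------
# Record the first iteration at which each batch element crosses the stop
# threshold; anything still running at the end falls back to max_len.
import torch

stop = torch.zeros(3, dtype=torch.long)
for iteration in (1, 2, 3):
    stops_now = torch.tensor([iteration >= 2, iteration >= 3, False])
    still_running = stop == 0
    stop = torch.where(still_running & stops_now,
                       torch.tensor(iteration), stop)
stop = torch.where(stop == 0, torch.tensor(4), stop)  # max_len fallback
print(stop)  # tensor([2, 3, 4])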
def forward(self, backbone_outputs, match_coords, is_train): # Enumerate network over pyramids. fpn_outputs = [] targets = [] classifications = [] pyramid_output = None num_layers = len(backbone_outputs) if is_train: target_coords = [ME.utils.batched_coordinates([match[i][0] for match in match_coords]) for i in range(num_layers)] ambiguous_coords = [ME.utils.batched_coordinates([match[i][1] for match in match_coords]) for i in range(num_layers)] for layer_idx in reversed(range(num_layers)): conv_feat_layer = self.get_layer('conv_feat', layer_idx) conv_cls_layer = self.get_layer('conv_cls', layer_idx) conv_up_layer = self.get_layer('conv_up', layer_idx) # Current feature curr_feat = backbone_outputs[layer_idx] # Add previous layer output if pyramid_output is not None: assert pyramid_output.tensor_stride == curr_feat.tensor_stride curr_feat = curr_feat + pyramid_output # Two branches: upsample and fpn feature and classification # 1. FPN feature & classification fpn_feat = conv_feat_layer(curr_feat) feat_cls = conv_cls_layer(fpn_feat) pred_prob = F.softmax(feat_cls.F, 1)[:, 1] # target calculation target = None if is_train: target = torch.zeros(len(fpn_feat), dtype=torch.long) pos_ins = utils.map_coordinates(fpn_feat, torch.cat(ambiguous_coords[:layer_idx + 1]), force_stride=True)[0] target[pos_ins] = self.config.ignore_label pos_ins = utils.map_coordinates(fpn_feat, torch.cat(target_coords[:layer_idx + 1]), force_stride=True)[0] target[pos_ins] = 1 # Get keep labels keep = (pred_prob > self.config.sfpn_min_confidence).cpu() if is_train: # Force put GT labels within keep keep |= target == 1 if torch.any(keep): # Prune and upsample pyramid_output = conv_up_layer(self.pruning(curr_feat, keep)) # Generate final feature for current level final_pruned = self.pruning(fpn_feat, keep) else: pyramid_output = None final_pruned = None # Post processing classifications.insert(0, feat_cls) targets.insert(0, target) fpn_outputs.insert(0, final_pruned) return fpn_outputs, targets, classifications
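# --- Added sketch of the keep-mask logic above (toy values) ------------------
# Prune locations by predicted confidence, but during training force the
# ground-truth positives to survive so supervision is never pruned away.
import torch

pred_prob = torch.tensor([0.1, 0.7, 0.2, 0.9])
target = torch.tensor([1, 0, 0, 0])
keep = pred_prob > 0.5   # confidence-based pruning
keep |= target == 1      # training: ground truth must be kept
print(keep)              # tensor([ True,  True, False,  True])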