Example 1
def columnwise_clamp(
    X: Tensor,
    lower: Optional[Union[float, Tensor]] = None,
    upper: Optional[Union[float, Tensor]] = None,
) -> Tensor:
    r"""Clamp values of a Tensor in column-wise fashion (with support for t-batches).

    This function is useful in conjunction with optimizers from the torch.optim
    package, which don't natively handle constraints. If you apply this after
    a gradient step you can be fancy and call it "projected gradient descent".

    Args:
        X: The `b x n x d` input tensor. If 2-dimensional, `b` is assumed to be 1.
        lower: The column-wise lower bounds. If scalar, apply bound to all columns.
        upper: The column-wise upper bounds. If scalar, apply bound to all columns.

    Returns:
        The clamped tensor.
    """
    min_bounds = _expand_bounds(lower, X)
    max_bounds = _expand_bounds(upper, X)
    if min_bounds is not None and max_bounds is not None:
        if torch.any(min_bounds > max_bounds):
            raise ValueError("Minimum values must be <= maximum values")
    Xout = X
    if min_bounds is not None:
        Xout = Xout.max(min_bounds)
    if max_bounds is not None:
        Xout = Xout.min(max_bounds)
    return Xout
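
A minimal usage sketch, assuming `columnwise_clamp` and its `_expand_bounds` helper are in scope as defined above; scalar and per-column bounds broadcast against the last dimension of `X`:

import torch

# Hypothetical usage: clamp a 4 x 3 tensor column-wise.
X = torch.randn(4, 3)
lower = torch.tensor([-0.5, -1.0, 0.0])  # one lower bound per column
upper = 0.5                              # scalar upper bound for all columns
X_clamped = columnwise_clamp(X, lower, upper)
assert torch.all(X_clamped <= 0.5) and torch.all(X_clamped >= -1.0)
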
Example 2
    def _step(
        self,
        positive_predicates: List[str],
        positive_observations: torch.FloatTensor,
        negative_predicates_list: List[List[str]],
        negative_observations_list: List[torch.FloatTensor],
    ) -> Tuple[torch.Tensor, Dict[str, Any]]:
        """ Performs a forward pass of the model during training.

    Used in both training_step and validation_step, this function accepts a set
    of predicate names and performs a forward pass of the hypothesis generation
    training routine. This involves generating negative samples for each
    positive example and evaluating metrics that quantify the difference
    between the two.

    Args:
      positive_predicates: List of string names associated with positive batch
      positive_observations: Packed tensor containing info corresponding to
        positive_predicates. Shape: <seq_len> X <batch_size> X <dim>
      negative_predicates_list: List of predicate lists.
        negative_predicates_list[i] corresponds to the i'th negative sample
        batch. Each batch should be the same size as the positive batch.
      negative_observations_list: List of packed tensors.
        negative_observations_list[i] corresponds to the i'th negative sample.

    Returns:
      The first element is the loss tensor, used for back propagation. The
      second element is a dict containing all extra metrics.

    """
        # Validate the positive batch
        assert isinstance(positive_observations, torch.Tensor), \
          f"Err: positive_observations is {type(positive_observations)}"
        assert len(positive_observations.shape) == 3
        _, actual_batch_size, actual_dim = positive_observations.shape
        assert len(positive_predicates) == actual_batch_size
        assert self.hparams.positives_per_batch == actual_batch_size
        assert self.hparams.dim == actual_dim

        # Now validate the negative batches
        assert len(negative_predicates_list) == len(negative_observations_list)
        for n_preds, n_obs in zip(negative_predicates_list,
                                  negative_observations_list):
            assert isinstance(n_obs, torch.Tensor)
            assert len(n_obs.shape) == 3
            _, actual_batch_size, actual_dim = n_obs.shape
            assert len(n_preds) == actual_batch_size
            assert self.hparams.positives_per_batch == actual_batch_size
            assert self.hparams.dim == actual_dim

        positive_predictions = self.forward(positive_observations)
        # We cannot tolerate an error on a positive sample
        # An error occurs if any positive prediction is _not_ finite
        # Note that `~` is bitwise "not" for our boolean matrix
        if torch.any(~torch.isfinite(positive_predictions.detach().cpu())):
            print(positive_predicates)
            raise ValueError("Invalid positive sample")

        partial_losses = []
        for negative_predicates, negative_observations in zip(
                negative_predicates_list, negative_observations_list):
            negative_predictions = self.forward(negative_observations)
            # We CAN tolerate an error on a negative sample
            if torch.any(~torch.isfinite(negative_predictions.detach().cpu())):
                # print debug info
                print("ERROR: Encountered an issue with a negative predicate:")
                print("Negative Predicate Scores:")
                print(negative_predictions)
                print("Negative Predicates")
                print(negative_predicates)
            else:
                partial_losses.append(
                    self.loss_fn(
                        positive_predictions, negative_predictions,
                        positive_predictions.new_ones(
                            len(positive_predictions))))
        assert len(
            partial_losses) > 0, "Failure occurred on all negative batches."
        # End of batch
        loss = torch.mean(torch.stack(partial_losses))
        return (
            loss,
            dict(  # pbar metrics
            ))
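
The call `self.loss_fn(positive, negative, ones)` above matches the calling convention of `torch.nn.MarginRankingLoss`; a minimal standalone sketch under that assumption:

import torch

# Assumes loss_fn behaves like torch.nn.MarginRankingLoss: with target = 1,
# positive predictions are pushed above negative ones by the margin.
loss_fn = torch.nn.MarginRankingLoss(margin=0.1)
positive_predictions = torch.randn(8, requires_grad=True)
negative_predictions = torch.randn(8)
target = positive_predictions.new_ones(len(positive_predictions))
loss = loss_fn(positive_predictions, negative_predictions, target)
loss.backward()
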
Example 3
def train_epoch(
        train_loader: torch.utils.data.DataLoader, base_model: torch.nn.Module,
        classification_layer: torch.nn.Module, forg_layer: torch.nn.Module,
        epoch: int, optimizer: torch.optim.Optimizer,
        lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
        callback: Optional[VisdomLogger], device: torch.device, args: Any):
    """ Trains the network for one epoch

        Parameters
        ----------
        train_loader: torch.utils.data.DataLoader
            Iterable that loads the training set (x, y) tuples
        base_model: torch.nn.Module
            The model architecture that extracts features from signatures
        classification_layer: torch.nn.Module
            The classification layer (from features to predictions of which user
            wrote the signature)
        forg_layer: torch.nn.Module
            The forgery prediction layer (from features to predictions of whether
            the signature is a forgery). Only used if args.forg is True
        epoch: int
            The current epoch (used for reporting)
        optimizer: torch.optim.Optimizer
            The optimizer (already initialized)
        lr_scheduler: torch.optim.lr_scheduler._LRScheduler
            The learning rate scheduler
        callback: VisdomLogger (optional)
            A callback to report the training progress
        device: torch.device
            The device (CPU or GPU) to use for training
        args: Namespace
            Extra arguments used for training:
            args.forg: bool
                Whether forgeries are being used for training
            args.lamb: float
                The weight used for the forgery loss (training with forgeries only)

        Returns
        -------
        None
        """

    step = 0
    n_steps = len(train_loader)
    for batch in train_loader:
        x, y = batch[0], batch[1]
        # as_tensor avoids the copy-construct warning when the loader already yields tensors
        x = torch.as_tensor(x, dtype=torch.float).to(device)
        y = torch.as_tensor(y, dtype=torch.long).to(device)
        yforg = torch.as_tensor(batch[2], dtype=torch.float).to(device)

        # Forward propagation
        features = base_model(x)

        if args.forg:
            if args.loss_type == 'L1':
                # Eq (3) in https://arxiv.org/abs/1705.05787
                logits = classification_layer(features)
                class_loss = F.cross_entropy(logits, y)

                forg_logits = forg_layer(features).squeeze()
                forg_loss = F.binary_cross_entropy_with_logits(
                    forg_logits, yforg)

                loss = (1 - args.lamb) * class_loss
                loss += args.lamb * forg_loss
            else:
                # Eq (4) in https://arxiv.org/abs/1705.05787
                if torch.any(yforg == 0):
                    logits = classification_layer(features[yforg == 0])
                    class_loss = F.cross_entropy(logits, y[yforg == 0])
                else:
                    class_loss = torch.tensor(0., device=device)  # keep a tensor so .detach() in the logging block works

                forg_logits = forg_layer(features).squeeze()
                forg_loss = F.binary_cross_entropy_with_logits(
                    forg_logits, yforg)

                loss = (1 - args.lamb) * class_loss
                loss += args.lamb * forg_loss
        else:
            # Eq (1) in https://arxiv.org/abs/1705.05787
            logits = classification_layer(features)
            loss = class_loss = F.cross_entropy(logits, y)

        # Back propagation
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_value_(optimizer.param_groups[0]['params'],
                                        10)

        # Update weights
        optimizer.step()

        # Logging
        if callback and step % 100 == 0:
            iteration = epoch + (step / n_steps)
            callback.scalar('class_loss', iteration, class_loss.detach())

            pred = logits.argmax(1)
            if args.loss_type == 'L1':
                acc = y.eq(pred).float().mean()
            else:
                acc = y[yforg == 0].eq(pred[yforg == 0]).float().mean()
            callback.scalar('train_acc', iteration, acc.detach())
            if args.forg:
                forg_pred = forg_logits > 0
                forg_acc = yforg.long().eq(forg_pred.long()).float().mean()
                callback.scalar('forg_loss', iteration, forg_loss.detach())
                callback.scalar('forg_acc', iteration, forg_acc.detach())

        step += 1
    lr_scheduler.step()
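
A standalone sketch of the combined loss in the `args.forg` branch (Eq. (3) of https://arxiv.org/abs/1705.05787), with random tensors standing in for the actual layer outputs:

import torch
import torch.nn.functional as F

# Stand-ins for classification_layer(features) and forg_layer(features).squeeze()
logits = torch.randn(16, 10)
forg_logits = torch.randn(16)
y = torch.randint(0, 10, (16,))
yforg = torch.randint(0, 2, (16,)).float()
lamb = 0.95  # weight of the forgery loss
loss = (1 - lamb) * F.cross_entropy(logits, y) \
       + lamb * F.binary_cross_entropy_with_logits(forg_logits, yforg)
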
Example 4
def infill_with_ilm(model,
                    special_tokens_to_ids,
                    x,
                    num_infills=1,
                    max_sequence_length=256,
                    nucleus=0.95):

    _sep_id = special_tokens_to_ids['<|startofinfill|>']
    _end_span_id = special_tokens_to_ids['<|endofinfill|>']
    _special_ids = special_tokens_to_ids.values()

    # Make sure the example doesn't already end with [sep]
    if x[-1] == _sep_id:
        x = x[:-1]

    # Count number of blanks
    blank_idxs = []
    for i, tok_id in enumerate(x):
        if tok_id in _special_ids:
            blank_idxs.append(i)
    k = len(blank_idxs)
    if k == 0:
        raise ValueError("Input contains no blank tokens to infill")

    # Decode until we have that many blanks
    with torch.no_grad():
        device = next(model.parameters()).device
        context = torch.tensor(x + [_sep_id], dtype=torch.long,
                               device=device).unsqueeze(0).repeat(
                                   num_infills, 1)

        terminated = []

        while context.shape[0] > 0:
            logits = model(context)[0][:, -1]
            next_tokens = sample_from_logits(logits, nucleus=nucleus)
            context = torch.cat((context, next_tokens), dim=1)

            num_predicted_spans = (context == _end_span_id).long().sum(dim=1)

            terminate_expected = num_predicted_spans >= k
            terminate_toolong = torch.ones_like(context).long().sum(
                dim=1) >= max_sequence_length
            terminate = terminate_expected | terminate_toolong

            if torch.any(terminate):
                terminated_seqs = context[terminate, len(x) + 1:]
                terminated.extend(
                    [list(s) for s in terminated_seqs.cpu().numpy()])
                context = context[~terminate, :]

    # Collect generated spans
    generated_spans = []
    for gen in terminated:
        spans = []
        while _end_span_id in gen:
            spans.append(gen[:gen.index(_end_span_id)])
            gen = gen[gen.index(_end_span_id) + 1:]
        while len(spans) < k:
            spans.append([])
        generated_spans.append(spans)

    # Insert into context
    generated = []
    for spans in generated_spans:
        context = copy.deepcopy(x)
        for i, j in enumerate(blank_idxs[::-1]):
            del context[j]
            context[j:j] = spans[k - 1 - i]
        spans = [item for sublist in spans for item in sublist]
        if spans not in generated:
            generated.append(spans)
    return generated
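
`sample_from_logits` is used above but not shown; a plausible nucleus (top-p) sampling implementation matching its call site, returning one token id per row:

import torch

def sample_from_logits(logits, nucleus=0.95):
    # Top-p (nucleus) sampling: keep the smallest set of tokens whose
    # cumulative probability covers `nucleus`, then sample from it.
    probs = torch.softmax(logits, dim=-1)
    sorted_probs, sorted_idx = torch.sort(probs, descending=True, dim=-1)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    # Zero out tokens outside the nucleus, always keeping the top token.
    mask = cumulative - sorted_probs > nucleus
    sorted_probs[mask] = 0.0
    sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
    choice = torch.multinomial(sorted_probs, num_samples=1)
    return sorted_idx.gather(-1, choice)  # shape [batch, 1], as the caller expects
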
Example 5
def main(config_path):
    """  """
    # load config
    cfg = AttrDict.from_json_path(config_path)

    # make outputs dir
    out_path = os.path.join(cfg.path.output, cfg.exp_name)
    os.makedirs(out_path, exist_ok=True)

    # initialize seed
    if cfg.seed != -1:
        random.seed(cfg.seed)
        np.random.seed(cfg.seed)
        torch.manual_seed(cfg.seed)
        torch.cuda.manual_seed(cfg.seed)
        torch.cuda.manual_seed_all(cfg.seed)
        torch.backends.cudnn.deterministic = True

    # initialize logger
    logger = initialize_logger(os.path.join(out_path, 'log.txt'))
    logger.info(f"Experiment : {cfg.exp_name}")

    # set device
    if cfg.device:
        cfg.device = torch.device(cfg.device)
    else:
        cfg.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
    logger.info(f"Device set to {cfg.device}.")

    #-------------------------------------------
    #       Make Dataset
    #-------------------------------------------

    data_info_df = pd.read_csv(os.path.join(cfg.path.data, 'ct_info.csv'), index_col=0)
    dataset = public_SegICH_Dataset2D(data_info_df, cfg.path.data,
                    augmentation_transform=[getattr(tf, tf_name)(**tf_kwargs) for tf_name, tf_kwargs in cfg.data.augmentation.items()],
                    output_size=cfg.data.size, window=(cfg.data.win_center, cfg.data.win_width))

    #-------------------------------------------
    #       Load FCDD Model
    #-------------------------------------------

    cfg_fcdd = AttrDict.from_json_path(cfg.fcdd_cfg_path)
    fcdd_net = FCDD_CNN_VGG(in_shape=(cfg_fcdd.net.in_channels, 256, 256), bias=cfg_fcdd.net.bias)
    loaded_state_dict = torch.load(cfg.fcdd_model_path, map_location=cfg.device)
    fcdd_net.load_state_dict(loaded_state_dict)
    fcdd_net = fcdd_net.to(cfg.device).eval()
    logger.info(f"FCDD model succesfully loaded from {cfg.fcdd_model_path}")

    # make FCDD object
    fcdd = FCDD(fcdd_net, batch_size=cfg.batch_size, num_workers=cfg.num_workers,
                device=cfg.device, print_progress=cfg.print_progress)

    #-------------------------------------------
    #       Load Classifier Model
    #-------------------------------------------

    # Load Classifier
    if cfg.classifier_model_path is not None:
        cfg_classifier = AttrDict.from_json_path(os.path.join(cfg.classifier_model_path, 'config.json'))
        classifier = getattr(rn, cfg_classifier.net.resnet)(num_classes=cfg_classifier.net.num_classes, input_channels=cfg_classifier.net.input_channels)
        classifier_state_dict = torch.load(os.path.join(cfg.classifier_model_path, 'resnet_state_dict.pt'), map_location=cfg.device)
        classifier.load_state_dict(classifier_state_dict)
        classifier = classifier.to(cfg.device)
        classifier.eval()
        logger.info(f"ResNet classifier model succesfully loaded from {os.path.join(cfg.classifier_model_path, 'resnet_state_dict.pt')}")

    #-------------------------------------------
    #       Generate Heat-Map for each slice
    #-------------------------------------------

    with torch.no_grad():
        # make loader
        loader = torch.utils.data.DataLoader(dataset, batch_size=cfg.batch_size, num_workers=cfg.num_workers,
                                             shuffle=False, worker_init_fn=lambda _: np.random.seed())
        fcdd_net.eval()

        min_val, max_val = fcdd.get_min_max(loader, **cfg.heatmap_param)

        # computing and saving heatmaps
        out = dict(id=[], slice=[], label=[], ad_map_fn=[], ad_mask_fn=[],
                   TP=[], TN=[], FP=[], FN=[], AUC=[], classifier_pred=[])
        for b, data in enumerate(loader):
            im, mask, id, slice = data
            im = im.to(cfg.device).float()
            mask = mask.to(cfg.device).float()

            # get heatmap
            heatmap = fcdd.generate_heatmap(im, reception=cfg.heatmap_param.reception, std=cfg.heatmap_param.std,
                                            cpu=cfg.heatmap_param.cpu)
            # scaling
            heatmap = ((heatmap - min_val) / (max_val - min_val)).clamp(0,1)

            # Threshold
            ad_mask = torch.where(heatmap >= cfg.heatmap_threshold, torch.ones_like(heatmap, device=heatmap.device),
                                                                    torch.zeros_like(heatmap, device=heatmap.device))

            # Compute CM
            tn, fp, fn, tp  = batch_binary_confusion_matrix(ad_mask, mask.to(heatmap.device))

            # Save heatmaps/mask
            map_fn, mask_fn = [], []
            for i in range(im.shape[0]):
                # Save AD Map
                ad_map_fn = f"{id[i]}/{slice[i]}_map_anomalies.png"
                save_path = os.path.join(out_path, 'pred/', ad_map_fn)
                if not os.path.isdir(os.path.dirname(save_path)):
                    os.makedirs(os.path.dirname(save_path), exist_ok=True)

                ad_map = heatmap[i].squeeze().cpu().numpy()
                io.imsave(save_path, img_as_ubyte(ad_map), check_contrast=False)
                # save ad_mask
                ad_mask_fn = f"{id[i]}/{slice[i]}_anomalies.bmp"
                save_path = os.path.join(out_path, 'pred/', ad_mask_fn)
                io.imsave(save_path, img_as_ubyte(ad_mask[i].squeeze().cpu().numpy()), check_contrast=False)

                map_fn.append(ad_map_fn)
                mask_fn.append(ad_mask_fn)

            # apply classifier ResNet-18
            if cfg.classifier_model_path is not None:
                pred_score = nn.functional.softmax(classifier(im), dim=1)[:,1] # softmax score of the positive class
                clss_pred = torch.where(pred_score >= cfg.classification_threshold, torch.ones_like(pred_score, device=pred_score.device),
                                                                                    torch.zeros_like(pred_score, device=pred_score.device))
            else:
                clss_pred = [None]*im.shape[0]

            # Save Values
            out['id'] += id.cpu().tolist()
            out['slice'] += slice.cpu().tolist()
            out['label'] += mask.reshape(mask.shape[0], -1).max(dim=1)[0].cpu().tolist()
            out['ad_map_fn'] += map_fn
            out['ad_mask_fn'] += mask_fn
            out['TN'] += tn.cpu().tolist()
            out['FP'] += fp.cpu().tolist()
            out['FN'] += fn.cpu().tolist()
            out['TP'] += tp.cpu().tolist()
            out['AUC'] += [roc_auc_score(mask[i].cpu().numpy().ravel(), heatmap[i].cpu().numpy().ravel()) if torch.any(mask[i]>0) else 'None' for i in range(im.shape[0])]
            out['classifier_pred'] += clss_pred.cpu().tolist()

            if cfg.print_progress:
                print_progessbar(b, len(loader), Name='Heatmap Generation Batch', Size=100, erase=True)

    # make df and save as csv
    slice_df = pd.DataFrame(out)
    volume_df = slice_df[['id', 'label', 'TP', 'TN', 'FP', 'FN']].groupby('id').agg({'label':'max', 'TP':'sum', 'TN':'sum', 'FP':'sum', 'FN':'sum'})

    slice_df['Dice'] = (2*slice_df.TP + 1) / (2*slice_df.TP + slice_df.FP + slice_df.FN + 1)
    volume_df['Dice'] = (2*volume_df.TP + 1) / (2*volume_df.TP + volume_df.FP + volume_df.FN + 1)
    logger.info(f"Mean slice dice : {slice_df.Dice.mean(axis=0):.3f}")
    logger.info(f"Mean volume dice : {volume_df.Dice.mean(axis=0):.3f}")
    logger.info(f"Mean posiitve slice AUC {slice_df[slice_df.label == 1].AUC.mean(axis=0):.3f}")

    # Save Scores and Config
    slice_df.to_csv(os.path.join(out_path, 'slice_predictions.csv'))
    logger.info(f"Slice prediction csv saved at {os.path.join(out_path, 'slice_predictions.csv')}")
    volume_df.to_csv(os.path.join(out_path, 'volume_predictions.csv'))
    logger.info(f"Volume prediction csv saved at {os.path.join(out_path, 'volume_predictions.csv')}")
    cfg.device = str(cfg.device)
    with open(os.path.join(out_path, 'config.json'), 'w') as f:
        json.dump(cfg, f)
    logger.info(f"Config file saved at {os.path.join(out_path, 'config.json')}")
Example 6
    def forward(self, *args):  # pylint: disable=too-many-statements
        LOG.debug('loss for %s', self.field_names)

        x, t = args

        assert len(x) == 1 + 2 * self.n_vectors + self.n_scales
        x_intensity = x[0]
        x_regs = x[1:1 + self.n_vectors]
        x_spreads = x[1 + self.n_vectors:1 + 2 * self.n_vectors]
        x_scales = []
        if self.n_scales:
            x_scales = x[1 + 2 * self.n_vectors:1 + 2 * self.n_vectors + self.n_scales]

        if self.n_scales == 0:
            t = t[:-self.n_vectors]  # assume there are as many scales as vectors and remove them
        assert len(t) == 1 + self.n_vectors + self.n_scales
        target_intensity = t[0]
        target_regs = t[1:1 + self.n_vectors]
        target_scales = t[1 + self.n_vectors:]

        bce_masks = (target_intensity[:, :-1] + target_intensity[:, -1:]) > 0.5
        if not torch.any(bce_masks):
            return None, None, None

        batch_size = x_intensity.shape[0]
        LOG.debug('batch size = %d', batch_size)

        bce_x_intensity = x_intensity
        bce_target_intensity = target_intensity[:, :-1]
        if self.bce_blackout:
            bce_x_intensity = bce_x_intensity[:, self.bce_blackout]
            bce_masks = bce_masks[:, self.bce_blackout]
            bce_target_intensity = bce_target_intensity[:, self.bce_blackout]

        LOG.debug('BCE: x = %s, target = %s, mask = %s',
                  x_intensity.shape, bce_target_intensity.shape, bce_masks.shape)
        bce_target = torch.masked_select(bce_target_intensity, bce_masks)
        bce_weight = None
        if self.background_weight != 1.0:
            bce_weight = torch.ones_like(bce_target)
            bce_weight[bce_target == 0] = self.background_weight
        ce_loss = torch.nn.functional.binary_cross_entropy_with_logits(
            torch.masked_select(bce_x_intensity, bce_masks),
            bce_target,
            weight=bce_weight,
        )

        reg_losses = [None for _ in target_regs]
        reg_masks = target_intensity[:, :-1] > 0.5
        if torch.any(reg_masks):
            weight = None
            if self.multiplicity_correction:
                assert len(target_regs) == 2
                lengths = torch.norm(target_regs[0] - target_regs[1], dim=2)
                multiplicity = (lengths - 3.0) / self.independence_scale
                multiplicity = torch.clamp(multiplicity, min=1.0)
                multiplicity = torch.masked_select(multiplicity, reg_masks)
                weight = 1.0 / multiplicity

            reg_losses = []
            for i, (x_reg, x_spread, target_reg) in enumerate(zip(x_regs, x_spreads, target_regs)):
                if hasattr(self.regression_loss, 'scale'):
                    assert self.scales_to_kp is not None
                    self.regression_loss.scale = torch.masked_select(
                        torch.clamp(target_scales[i], 0.1, 1000.0),  # pylint: disable=unsubscriptable-object
                        reg_masks,
                    )

                reg_losses.append(self.regression_loss(
                    torch.masked_select(x_reg[:, :, 0], reg_masks),
                    torch.masked_select(x_reg[:, :, 1], reg_masks),
                    torch.masked_select(x_spread, reg_masks),
                    torch.masked_select(target_reg[:, :, 0], reg_masks),
                    torch.masked_select(target_reg[:, :, 1], reg_masks),
                    weight=(weight if weight is not None else 1.0) * 0.1,
                ) / 100.0 / batch_size)

        scale_losses = []
        if x_scales:
            scale_losses = [
                torch.nn.functional.l1_loss(
                    torch.masked_select(x_scale, ~torch.isnan(target_scale)),
                    torch.masked_select(target_scale, ~torch.isnan(target_scale)),
                    reduction='sum',
                ) / 1000.0 / batch_size
                for x_scale, target_scale in zip(x_scales, target_scales)
            ]

        margin_losses = [None for _ in target_regs] if self.margin else []
        if self.margin and torch.any(reg_masks):
            margin_losses = []
            for i, (x_reg, target_reg) in enumerate(zip(x_regs, target_regs)):
                margin_losses.append(quadrant_margin_loss(
                    torch.masked_select(x_reg[:, :, 0], reg_masks),
                    torch.masked_select(x_reg[:, :, 1], reg_masks),
                    torch.masked_select(target_reg[:, :, 0], reg_masks),
                    torch.masked_select(target_reg[:, :, 1], reg_masks),
                    torch.masked_select(target_reg[:, :, 2], reg_masks),
                    torch.masked_select(target_reg[:, :, 3], reg_masks),
                    torch.masked_select(target_reg[:, :, 4], reg_masks),
                    torch.masked_select(target_reg[:, :, 5], reg_masks),
                ) / 100.0 / batch_size)

        return [ce_loss] + reg_losses + scale_losses + margin_losses
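
A toy illustration of the masked-BCE pattern used for the intensity loss above: `torch.masked_select` flattens only the valid positions before the loss is computed:

import torch

x = torch.randn(2, 3, 4)                 # predicted logits
target = torch.rand(2, 3, 4).round()     # binary targets
mask = torch.rand(2, 3, 4) > 0.3         # validity mask
loss = torch.nn.functional.binary_cross_entropy_with_logits(
    torch.masked_select(x, mask),
    torch.masked_select(target, mask),
)
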
Example 7
    def forward(self, pyolos, targets, imgs_ts=None):
        '''Compute the YOLOv1-style loss.

        :param pyolos: torch.Size([32, 7, 13, 13]) cls-3, box-4
        :param targets: list of per-image dicts with 'boxes' (ltrb) and 'labels'
        :param imgs_ts: optional batch of image tensors (only used for visualization)
        :return: (loss_total, log_dict)
        '''
        cfg = self.cfg
        device = pyolos.device
        batch, c, h, w = pyolos.shape
        pyolos = pyolos.view(batch, c, -1).permute(0, 2, 1)

        # cls-num_class, txywh-4, weight-1, gltrb-4
        gdim = cfg.NUM_CLASSES + 4 + 1 + 4
        gyolos = torch.empty((batch, h, w, gdim),
                             device=device)  # fully overwritten every batch, so no zero-init needed

        for i, target in enumerate(targets):  # iterate over the batch
            gboxes_ltrb_b = target['boxes']  # ltrb
            glabels_b = target['labels']

            gyolos[i] = fmatch4yolov1(
                gboxes_ltrb_b=gboxes_ltrb_b,
                glabels_b=glabels_b,
                grid=h,  # 7
                gdim=gdim,
                device=device,
                img_ts=imgs_ts[i],
                cfg=cfg,
                use_conf=False)
            '''visual sanity check'''
            # if cfg.IS_VISUAL:
            #     # conf-1, cls-1, box-4, weight-1
            #     gyolo_test = gyolos[i].clone()  # torch.Size([32, 13, 13, 9])
            #     gyolo_test = gyolo_test.view(-1, gdim)
            #     gconf_one = gyolo_test[:, 0]
            #     mask_pos = gconf_one == 1  # [169]
            #
            #     # torch.Size([169, 4])
            #     gtxywh = gyolo_test[:, 1 + cfg.NUM_CLASSES:1 + cfg.NUM_CLASSES + 4]
            #
            #     # recover the xy offsets to grid coordinates
            #     _xy_grid = gtxywh[:, :2] + f_mershgrid(h, w, is_rowcol=False).to(device)
            #     hw_ts = torch.tensor((h, w), device=device)
            #     gtxywh[:, :2] = torch.true_divide(_xy_grid, hw_ts)
            #     gtxywh = gtxywh[mask_pos]
            #     gtxywh[:, 2:4] = torch.exp(gtxywh[:, 2:]) / cfg.IMAGE_SIZE[0]
            #
            #     from f_tools.pic.enhance.f_data_pretreatment4pil import f_recover_normalization4ts
            #     img_ts = f_recover_normalization4ts(imgs_ts[i])
            #     from torchvision.transforms import functional as transformsF
            #     img_pil = transformsF.to_pil_image(img_ts).convert('RGB')
            #     import numpy as np
            #     img_np = np.array(img_pil)
            #     f_show_od_np4plt(img_np, gboxes_ltrb=gboxes_ltrb_b.cpu()
            #                      , pboxes_ltrb=xywh2ltrb(gtxywh.cpu()), is_recover_size=True,
            #                      grids=(h, w))

        # [32, 13, 13, 12] -> torch.Size([32, 169, 12])
        gyolos = gyolos.view(batch, -1, gdim)  # h*w
        gcls = gyolos[:, :, 0:cfg.NUM_CLASSES]  # torch.Size([32, 169, 3])
        mask_pos_3d = gcls > 0  # torch.Size([32, 169, 3])
        mask_neg_3d = gcls == 0
        # [32, 169, 3] -> [32, 169]
        mask_pos_2d = torch.any(mask_pos_3d, dim=-1)
        # mask_pos = gconf == 1  # in yolo1 the gt conf is hard-coded to 1

        nums_pos = (mask_pos_2d.sum(-1).to(
            torch.float)).clamp(min=torch.finfo(torch.float16).eps)
        pyolos_pos = pyolos[
            mask_pos_2d]  # torch.Size([32, 169, 13]) -> [nn, 13]
        gyolos_pos = gyolos[
            mask_pos_2d]  # torch.Size([32, 169, 13]) -> [nn, 13]
        ''' ---------------- classification loss ---------------- '''
        # cls-num_class, txywh-4, weight-1, gltrb-4
        pcls_sigmoid = pyolos[:, :, 0:cfg.NUM_CLASSES].sigmoid()
        gcls = gyolos[:, :, 0:cfg.NUM_CLASSES]  # torch.Size([32, 169, 3])
        # positive:negative ratio is 1:169*3
        # _loss_val = x_bce(pcls_sigmoid, gcls, reduction="none")  # torch.Size([46, 3])
        # l_cls_pos = ((_loss_val * mask_pos_3d).sum(-1).sum(-1) / nums_pos).mean()
        # l_cls_neg = ((_loss_val * mask_neg_3d).sum(-1).sum(-1) / nums_pos).mean()

        # ------------ conf-mse ------------
        # _loss_val = F.mse_loss(pconf_sigmoid, gconf, reduction="none")  # MSE works better here
        # _loss_val = x_bce(pconf_sigmoid, gconf, reduction="none")
        # l_conf_pos = ((_loss_val * mask_pos_3d).sum(-1) / nums_pos).mean() * cfg.LOSS_WEIGHT[0]
        # l_conf_neg = ((_loss_val * mask_neg_3d).sum(-1) / nums_pos).mean() * cfg.LOSS_WEIGHT[1]

        # ------------ conf_ohem  ap26_26 ------------
        # _loss_val = x_bce(pconf_sigmoid, gconf)
        # mask_ignore = torch.logical_not(torch.logical_or(mask_pos, mask_neg))
        # mask_neg_hard = f_ohem(_loss_val, nums_pos * 3, mask_pos=mask_pos, mash_ignore=mask_ignore)
        # l_conf_pos = ((_loss_val * mask_pos).sum(-1) / nums_pos).mean() * 3  # more positives bring more negatives
        # l_conf_neg = ((_loss_val * mask_neg_hard).sum(-1) / nums_pos).mean() * 3

        # ------------ focalloss   ------------
        l_pos, l_neg = focalloss(pcls_sigmoid,
                                 gcls,
                                 mask_pos=mask_pos_2d,
                                 is_debug=True,
                                 alpha=0.75)
        l_cls_pos = (l_pos.sum(-1).sum(-1) / nums_pos).mean() * 30
        l_cls_neg = (l_neg.sum(-1).sum(-1) / nums_pos).mean() * 30
        ''' ---------------- regression loss: bce for xy, mse for wh ---------------- '''
        # ------------ mse+bce ------------
        # conf-1, cls-num_class, txywh-4, weight-1, gltrb-4
        # ptxty_sigmoid_pos = pyolos_pos[:, cfg.NUM_CLASSES:cfg.NUM_CLASSES + 2].sigmoid()  # needs normalization
        # ptwth_pos = pyolos_pos[:, cfg.NUM_CLASSES + 2:cfg.NUM_CLASSES + 4]
        #
        # # cls-num_class, txywh-4, weight-1, gltrb-4
        # # id = cfg.NUM_CLASSES + 4 +1 -1
        # weight_pos = gyolos_pos[:, cfg.NUM_CLASSES + 4]  # torch.Size([32, 845])
        # gtxty_pos = gyolos_pos[:, cfg.NUM_CLASSES:cfg.NUM_CLASSES + 2]  # [nn]
        # gtwth_pos = gyolos_pos[:, cfg.NUM_CLASSES + 2:cfg.NUM_CLASSES + 4]
        #
        # _loss_val = x_bce(ptxty_sigmoid_pos, gtxty_pos, reduction="none")
        # l_txty = (_loss_val.sum(-1) * weight_pos).mean()
        # _loss_val = F.mse_loss(ptwth_pos, gtwth_pos, reduction="none")
        # l_twth = (_loss_val.sum(-1) * weight_pos).mean()

        # ------------ iou loss ------------
        # decode pxywh and use the IoU between prediction and GT as gconf
        ptxywh_pos = pyolos[:, :, cfg.NUM_CLASSES:cfg.NUM_CLASSES + 4]
        # batched 3D decode: decode everything first, then filter by the positive mask
        zltrb_pos = boxes_decode4yolo1(ptxywh_pos, h, h, cfg)[mask_pos_2d]
        gltrb_pos = gyolos_pos[:, cfg.NUM_CLASSES + 4 + 1:cfg.NUM_CLASSES + 4 +
                               1 + 4]
        iou_zg = bbox_iou4one_2d(zltrb_pos, gltrb_pos, is_ciou=True)
        l_reg = (1 - iou_zg).mean()
        ''' ---------------- total loss ---------------- '''
        # loss_total = l_cls_pos + l_cls_neg + l_txty + l_twth
        loss_total = l_cls_pos + l_cls_neg + l_reg

        log_dict = {}
        log_dict['l_total'] = loss_total.item()
        log_dict['l_cls_pos'] = l_cls_pos.item()
        log_dict['l_cls_neg'] = l_cls_neg.item()
        log_dict['l_reg'] = l_reg.item()
        # log_dict['l_xy'] = l_txty.item()
        # log_dict['l_wh'] = l_twth.item()

        log_dict['p_max'] = pcls_sigmoid.max().item()
        log_dict['p_min'] = pcls_sigmoid.min().item()
        log_dict['p_mean'] = pcls_sigmoid.mean().item()
        return loss_total, log_dict
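
`focalloss` is not shown above; a hedged sketch with the same signature and the same (positive, negative) loss split (the real f_tools implementation may differ):

import torch

def focalloss(pcls_sigmoid, gcls, mask_pos=None, is_debug=False,
              alpha=0.75, gamma=2.0):
    # Element-wise focal terms, split into positive and negative parts so the
    # caller can normalize and weight them separately, as done above.
    # mask_pos and is_debug are accepted for signature compatibility only.
    eps = 1e-6
    p = pcls_sigmoid.clamp(eps, 1.0 - eps)
    l_pos = -alpha * (1.0 - p) ** gamma * gcls * torch.log(p)
    l_neg = -(1.0 - alpha) * p ** gamma * (1.0 - gcls) * torch.log(1.0 - p)
    return l_pos, l_neg
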
Example 8
    def forward(self,                       # forward accepts every argument it might need
                support_seqs, support_imgs, support_lens, support_labels,
                query_seqs, query_imgs, query_lens, query_labels,
                epoch=None, metric='euc', return_embeddings=False):

        embedded_support_seqs, embedded_query_seqs, \
        embedded_support_imgs, embedded_query_imgs = self.embed(support_seqs, query_seqs,
                                                                support_lens, query_lens,
                                                                support_imgs, query_imgs)

        k, n, qk = self.TaskParams.k, self.TaskParams.n, self.TaskParams.qk

        # fuse the raw seq and img outputs directly
        support_fused_features = self._fuse(embedded_support_seqs, embedded_support_imgs, fuse_dim=1)
        query_fused_features = self._fuse(embedded_query_seqs, embedded_query_imgs, fuse_dim=1)
        dim = support_fused_features.size(1)

        nClusters = n  # the initial number of clusters equals the number of classes
        nInitialClusters = nClusters

        # batch is fixed to 1 here
        support_labels = support_labels.unsqueeze(0)
        query_labels = query_labels.unsqueeze(0)
        support_fused_features = support_fused_features.view(n * k, -1).unsqueeze(0)
        query_fused_features = query_fused_features.view(qk, -1).unsqueeze(0)

        # create probabilities for points
        # _, idx = np.unique(batch.y_train.squeeze().data.cpu().numpy(), return_inverse=True)
        prob_support = one_hot(support_labels, nClusters).cuda()  # initialize cluster-membership probabilities as the one-hot labels

        # make initial radii for labeled clusters
        bsize = support_fused_features.size()[0]
        radii = t.ones(bsize, nClusters).cuda()  # initial radii are scaled by the learnable self.Sigma below

        if self.Sigma is not None:
            radii *= t.exp(self.Sigma)

        cluster_labels = t.arange(0, nClusters).cuda().long()

        # compute initial prototypes from labeled examples
        # with one cluster per class and one-hot assignments, the initial clusters are exactly the class means
        # shape: [batch, cluster, dim]
        protos = self._compute_protos(support_fused_features, prob_support)

        # estimate lamda
        # lamda = self.estimate_lambda(protos.data, False)

        # loop for a given number of clustering steps
        for ii in range(self.NumClusterSteps):
            # protos = protos.data
            # iterate over labeled examples to reassign first
            for i, ex in enumerate(support_fused_features[0]):
                # find the index of the cluster matching this sample's label
                idxs = t.nonzero(support_labels[0, i] == cluster_labels)[0]  # TODO: why take element 0?

                #****************************************************************************
                # distance to the label's own cluster (all other clusters sit at +inf, so min ignores them)
                # distances = self._compute_distances(protos[:, idxs, :], ex.data)
                # if t.min(distances) > lamda:
                #****************************************************************************

                distances = self._compute_distances(protos, ex)
                # if the nearest cluster is not one of this sample's class clusters, add a new cluster
                if not t.any(t.min(distances, dim=1).indices == idxs).item():

                    nClusters, protos, radii = self._add_cluster(nClusters, protos, radii,
                                                                 cluster_type='labeled', ex=ex.data)
                    cluster_labels = t.cat([cluster_labels, support_labels[0, [i]].data], dim=0)  # the sample's label becomes the new cluster's label

            # perform partial reassignment based on newly created labeled clusters
            if nClusters > nInitialClusters:
                support_targets = support_labels.data[0, :, None] == cluster_labels  # bool mask: each row marks the clusters matching that sample
                prob_support = assign_cluster_radii_limited(protos, support_fused_features, radii,
                                                            support_targets)  # probability of each sample belonging to each cluster

            nTrainClusters = nClusters
            protos = protos.cuda()
            protos = self._compute_protos(support_fused_features, prob_support)
            protos, radii, cluster_labels = self.delete_empty_clusters(protos, prob_support, radii, cluster_labels)

        # compute the query-to-cluster logits
        logits = compute_logits_radii(protos, query_fused_features, radii, use_sigma=self.Sigma is not None).squeeze()

        # convert class targets into indicators for supports in each class
        labels = query_labels  # batch.y_test.data
        labels[labels >= nInitialClusters] = -1

        support_targets = labels[0, :, None] == cluster_labels  # the clusters matching each query label
        loss = self.loss(logits, support_targets,
                         cluster_labels)  # support_targets: cluster indicators for the query labels; cluster_labels: per-cluster labels

        # map support predictions back into classes to check accuracy
        _, support_preds = t.max(logits.data, dim=1)
        y_pred = cluster_labels[support_preds]

        return {
            "logits": None,
            "loss": loss,
            "predicts": y_pred
        }
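
A toy version of the prototype update performed by `self._compute_protos` (a hypothetical helper, shapes as in the comments above): a probability-weighted mean per cluster, which reduces to the per-class mean under one-hot assignments:

import torch as t

def compute_protos(features, prob):
    # features: [batch, n_points, dim]; prob: [batch, n_points, n_clusters]
    weighted = prob.transpose(1, 2) @ features          # [batch, n_clusters, dim]
    counts = prob.sum(dim=1).unsqueeze(-1).clamp(min=1e-8)
    return weighted / counts
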
Example 9
    def decode_beam_batch(self,
                          bodies,
                          beam_size=3,
                          max_output_length=100,
                          sample=False,
                          show_each_beam=False):
        if self.mode != 'eval':
            print("BEWARE. Model is not in eval mode.")
        self.eval()  ## << Surely you are not training with beam decode?

        batch_size = len(bodies)
        N = batch_size * beam_size
        inputs = self.preprocess_input(bodies)
        next_words = torch.LongTensor([self.tokenizer.start_id] * N).to(
            self.device).unsqueeze(1)
        build_up = None
        scores = torch.zeros((N)).to(self.device)

        one_every_k = torch.FloatTensor([1] + [0] * (beam_size - 1)).repeat(
            batch_size * beam_size).to(self.device)

        # Sometimes we process the same input twice, once sampled and once argmax; in that case we reuse the computation
        _, input_past = self.model(input_ids=inputs, past_key_values=None)
        input_past = [
            torch.repeat_interleave(p, repeats=beam_size, dim=1)
            for p in input_past
        ]

        past = input_past
        while build_up is None or (
                build_up.shape[1] < max_output_length and
                not all([self.tokenizer.end_id in build
                         for build in build_up])):
            logits, past = self.model(input_ids=next_words,
                                      past_key_values=past)
            probs = torch.nn.functional.softmax(logits, dim=2).squeeze(1)
            logprobs = torch.nn.functional.log_softmax(logits, dim=2)

            if sample:
                all_selects = torch.multinomial(probs, beam_size).unsqueeze(1)
            else:
                _, all_selects = torch.topk(logprobs, k=beam_size, dim=2)

            if build_up is not None:
                not_finished = (1 -
                                torch.any(build_up == self.tokenizer.end_id,
                                          dim=1).float()).to(self.device)
            else:
                not_finished = torch.ones_like(scores,
                                               dtype=torch.float,
                                               device=self.device)

            expanded_not_finished = torch.repeat_interleave(not_finished,
                                                            repeats=beam_size)

            expanded_score = torch.repeat_interleave(
                scores,
                repeats=beam_size)  # This should be batch_size * beam_size²
            added_score = logprobs[
                torch.repeat_interleave(torch.arange(N), repeats=beam_size), 0,
                all_selects.view(-1)]
            expanded_score += (expanded_not_finished * added_score)

            # We don't want you to select from finished beams
            expanded_score -= (1 - expanded_not_finished) * (
                1 - one_every_k) * 1000.0

            batched_scores = expanded_score.view(batch_size, -1)

            if build_up is None:
                choices = torch.arange(beam_size,
                                       device=self.device).repeat(batch_size)
                batched_choices = choices.view(batch_size, beam_size)

            else:
                _, batched_choices = torch.topk(
                    batched_scores, k=beam_size,
                    dim=1)  # Going from k² choices per element to k choices.

            batched_tracks = batched_choices // beam_size  # integer index of the parent beam (floor division, so it stays usable for indexing)
            tracks = beam_size * torch.repeat_interleave(
                torch.arange(batch_size), repeats=beam_size).to(
                    self.device) + batched_tracks.view(-1)
            tracks = list(tracks)

            selected_scores = batched_scores[torch.repeat_interleave(
                torch.arange(batch_size), repeats=beam_size),
                                             batched_choices.view(-1)]

            # Figure out the kept words to be added to the build-up
            per_batch_selects = all_selects.view(batch_size, -1)
            next_words = per_batch_selects[torch.repeat_interleave(
                torch.arange(batch_size), repeats=beam_size),
                                           batched_choices.view(-1)]
            next_words = next_words.unsqueeze(1)

            # [BOOKKEEPING] Going from k² to k options at each time means we have to swap all the caches around: past, build-up
            if build_up is not None:
                #######
                if show_each_beam is True:
                    print(build_up)
                    building_output = []
                    for beams in build_up:
                        out_ = [
                            self.tokenizer.decode(beam.tolist()) + "END"
                            for beam in beams
                        ]
                        out_ = [S[:S.index("END")] for S in out_]
                        building_output.append(out_)
                    for output, score in zip(building_output, selected_scores):
                        print("".join(output), "|%.2f" % float(score))

                build_up = build_up[tracks, :]
            past = [p[:, tracks, :] for p in past]

            # Update the latest scores, and the current_build
            if build_up is None:
                build_up = next_words
            else:
                build_up = torch.cat((build_up, next_words), dim=1)
            scores = selected_scores.view(-1)

        batched_build_up = build_up.view(batch_size, beam_size, -1)
        batched_scores = scores.view(batch_size, -1)
        # torch.cuda.empty_cache()

        outputs = []
        for beams in batched_build_up:
            out_beams = [
                self.tokenizer.decode(beam.tolist()) + "END" for beam in beams
            ]
            out_beams = [S[:S.index("END")] for S in out_beams]
            outputs.append(out_beams)

        return outputs, batched_scores.tolist()
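
A toy walk-through of the k -> k^2 -> k bookkeeping above: every live beam proposes `beam_size` continuations, the top `beam_size` per batch element survive, and `tracks` records the parent beam of each survivor:

import torch

batch_size, beam_size = 2, 3
scores = torch.randn(batch_size * beam_size)                   # one score per live beam
expanded = torch.repeat_interleave(scores, repeats=beam_size)  # k*k candidates per batch element
batched = expanded.view(batch_size, -1)
topk_scores, choices = torch.topk(batched, k=beam_size, dim=1)
tracks = choices // beam_size   # parent-beam index of each surviving candidate
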
Example 10
    def decode_batch(self,
                     bodies,
                     special_append=None,
                     max_output_length=100,
                     sample=False,
                     return_scores=False,
                     return_logprobs=False,
                     input_past=None):
        N = len(bodies)
        current = torch.LongTensor([self.tokenizer.start_id] * N).to(
            self.device).unsqueeze(1)
        build_up = None
        scores = torch.zeros((N)).to(self.device)
        total_logprobs = []

        # Sometimes we process the same input twice, once sampled and once argmax; in that case we reuse the computation
        if input_past is None:
            inputs = self.preprocess_input(bodies, special_append)
            _, input_past = self.model(input_ids=inputs, past_key_values=None)

        past = input_past
        while build_up is None or (
                build_up.shape[1] < max_output_length and
                not all([self.tokenizer.end_id in build
                         for build in build_up])):
            logits, past = self.model(input_ids=current, past_key_values=past)
            probs = torch.nn.functional.softmax(logits, dim=2).squeeze(1)
            logprobs = torch.nn.functional.log_softmax(logits, dim=2)
            if sample:
                current = torch.multinomial(probs, 1)
            else:
                current = torch.argmax(logprobs, dim=2)

            if build_up is None:
                build_up = current
            else:
                build_up = torch.cat((build_up, current), dim=1)

            if return_logprobs:
                selected_logprobs = logprobs[torch.arange(N), 0,
                                             current.squeeze()].unsqueeze(1)
                total_logprobs.append(selected_logprobs)

            not_finished = (1 - torch.any(build_up == self.tokenizer.end_id,
                                          dim=1).float()).to(self.device)
            scores += not_finished * logprobs[torch.arange(N), :,
                                              current.squeeze(1)].squeeze()

        end_id = self.tokenizer.end_id
        build_up = [build.tolist() for build in build_up]
        end_indices = [
            max_output_length +
            1 if end_id not in build else build.index(end_id)
            for build in build_up
        ]
        outputs = [self.tokenizer.decode(build) + "END" for build in build_up]
        outputs = [S[:S.index("END")] for S in outputs]

        if return_logprobs:
            return outputs, torch.cat(total_logprobs,
                                      dim=1), build_up, input_past, end_indices
        elif return_scores:
            return outputs, scores.tolist()
        else:
            return outputs, end_indices
Example 11
def __beaver_protocol(op, x, y, *args, **kwargs):
    """Performs Beaver protocol for additively secret-shared tensors x and y

    1. Obtain uniformly random sharings [a],[b] and [c] = [a * b]
    2. Additively hide [x] and [y] with appropriately sized [a] and [b]
    3. Open ([epsilon] = [x] - [a]) and ([delta] = [y] - [b])
    4. Return [z] = [c] + (epsilon * [b]) + ([a] * delta) + (epsilon * delta)
    """
    assert op in {
        "mul",
        "matmul",
        "conv1d",
        "conv2d",
        "conv_transpose1d",
        "conv_transpose2d",
    }
    if x.device != y.device:
        raise ValueError(
            f"x lives on device {x.device} but y on device {y.device}")

    provider = crypten.mpc.get_default_provider()
    a, b, c = provider.generate_additive_triple(x.size(),
                                                y.size(),
                                                op,
                                                device=x.device,
                                                *args,
                                                **kwargs)

    from .arithmetic import ArithmeticSharedTensor

    if crypten.mpc.config.active_security:
        """
        Reference: "Multiparty Computation from Somewhat Homomorphic Encryption"
        Link: https://eprint.iacr.org/2011/535.pdf
        """
        f, g, h = provider.generate_additive_triple(x.size(),
                                                    y.size(),
                                                    op,
                                                    device=x.device,
                                                    *args,
                                                    **kwargs)

        t = ArithmeticSharedTensor.PRSS(a.size(), device=x.device)
        t_plain_text = t.get_plain_text()

        rho = (t_plain_text * a - f).get_plain_text()
        sigma = (b - g).get_plain_text()
        triples_check = t_plain_text * c - h - sigma * f - rho * g - rho * sigma
        triples_check = triples_check.get_plain_text()

        if torch.any(triples_check != 0):
            raise ValueError("Beaver Triples verification failed!")

    # Vectorized reveal to reduce rounds of communication
    epsilon, delta = ArithmeticSharedTensor.reveal_batch([x - a, y - b])

    # z = c + (a * delta) + (epsilon * b) + epsilon * delta
    c._tensor += getattr(torch, op)(epsilon, b._tensor, *args, **kwargs)
    c._tensor += getattr(torch, op)(a._tensor, delta, *args, **kwargs)
    c += getattr(torch, op)(epsilon, delta, *args, **kwargs)

    return c
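
A plaintext sanity check of the Beaver identity used above (no secret sharing, just the algebra): with c = a*b, epsilon = x - a and delta = y - b, the reconstruction c + epsilon*b + a*delta + epsilon*delta equals x*y:

import torch

x, y = torch.randn(4), torch.randn(4)
a, b = torch.randn(4), torch.randn(4)   # the triple's masks, with c = a * b
c = a * b
epsilon, delta = x - a, y - b           # the values that get opened
z = c + epsilon * b + a * delta + epsilon * delta
assert torch.allclose(z, x * y, atol=1e-6)
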
Example 12
def initialize_q_batch_nonneg(X: Tensor,
                              Y: Tensor,
                              n: int,
                              eta: float = 1.0,
                              alpha: float = 1e-4) -> Tensor:
    r"""Heuristic for selecting initial conditions for non-neg. acquisition functions.

    This function is similar to `initialize_q_batch`, but designed specifically
    for acquisition functions that are non-negative and possibly zero over
    large areas of the feature space (e.g. qEI). All samples for which
    `Y < alpha * max(Y)` will be ignored (assuming that `Y` contains at least
    one positive value).

    Args:
        X: A `b x q x d` tensor of `b` samples of `q`-batches from a `d`-dim.
            feature space. Typically, these are generated using qMC.
        Y: A tensor of `b` outcomes associated with the samples. Typically, this
            is the value of the batch acquisition function to be maximized.
        n: The number of initial conditions to be generated. Must not be larger than `b`.
        eta: Temperature parameter for weighting samples.
        alpha: The threshold (as a fraction of the maximum observed value) under
            which to ignore samples. All input samples for which
            `Y < alpha * max(Y)` will be ignored.

    Returns:
        A `n x q x d` tensor of `n` `q`-batch initial conditions.

    Example:
        # To get `n=10` starting points of q-batch size `q=3` for model with `d=6`:
        >>> qEI = qExpectedImprovement(model, best_f=0.2)
        >>> Xrnd = torch.rand(500, 3, 6)
        >>> Xinit = initialize_q_batch_nonneg(Xrnd, qEI(Xrnd), 10)
    """
    n_samples = X.shape[0]
    if n > n_samples:
        raise RuntimeError(
            "n cannot be larger than the number of provided samples")
    elif n == n_samples:
        return X

    max_val, max_idx = torch.max(Y, dim=0)
    if torch.any(max_val <= 0):
        warnings.warn(
            "All acquisition values for raw sampled points are nonpositive, so "
            "initial conditions are being selected randomly.",
            BadInitialCandidatesWarning,
        )
        return X[torch.randperm(n=n_samples, device=X.device)][:n]

    # make sure there are at least `n` points with positive acquisition values
    pos = Y > 0
    num_pos = pos.sum().item()
    if num_pos < n:
        # select all positive points and then fill remaining quota with randomly
        # selected points
        remaining_indices = (~pos).nonzero().view(-1)
        rand_indices = torch.randperm(remaining_indices.shape[0],
                                      device=Y.device)
        sampled_remaining_indices = remaining_indices[rand_indices[:n -
                                                                   num_pos]]
        pos[sampled_remaining_indices] = 1
        return X[pos]
    # select points within alpha of max_val, iteratively decreasing alpha by a
    # factor of 10 as necessary
    alpha_pos = Y >= alpha * max_val
    while alpha_pos.sum() < n:
        alpha = 0.1 * alpha
        alpha_pos = Y >= alpha * max_val
    alpha_pos_idcs = torch.arange(len(Y), device=Y.device)[alpha_pos]
    weights = torch.exp(eta * (Y[alpha_pos] / max_val - 1))
    idcs = alpha_pos_idcs[torch.multinomial(weights, n)]
    if max_idx not in idcs:
        idcs[-1] = max_idx
    return X[idcs]
Example 13
    def forward(
        ctx,
        representation_tree,
        dtype,
        device,
        matrix_shape,
        batch_shape=torch.Size(),
        inv_quad=False,
        logdet=False,
        probe_vectors=None,
        probe_vector_norms=None,
        *args,
    ):
        """
        *args - The arguments representing the PSD matrix A (or batch of PSD matrices A)
        If self.inv_quad is true, the first entry in *args is inv_quad_rhs (Tensor)
        - the RHS of the matrix solves.

        Returns:
        - (Scalar) The inverse quadratic form (or None, if self.inv_quad is False)
        - (Scalar) The log determinant (or None, if self.logdet is False)
        """

        if not (inv_quad or logdet):
            raise RuntimeError(
                "Either inv_quad or logdet must be true (or both)")

        ctx.representation_tree = representation_tree
        ctx.dtype = dtype
        ctx.device = device
        ctx.matrix_shape = matrix_shape
        ctx.batch_shape = batch_shape
        ctx.inv_quad = inv_quad
        ctx.logdet = logdet

        matrix_args = None
        inv_quad_rhs = None
        if ctx.inv_quad:
            matrix_args = args[1:]
            inv_quad_rhs = args[0]
        else:
            matrix_args = args

        # Get closure for matmul
        lazy_tsr = ctx.representation_tree(*matrix_args)
        with torch.no_grad():
            preconditioner, precond_lt, logdet_correction = lazy_tsr._preconditioner(
            )

        ctx.preconditioner = preconditioner

        if (probe_vectors is None or probe_vector_norms is None) and logdet:
            num_random_probes = settings.num_trace_samples.value()
            if preconditioner is None:
                if settings.deterministic_probes.on():
                    warnings.warn(
                        "Deterministic probes will currently work only if you aren't training multiple independent"
                        " models simultaneously.",
                        UserWarning,
                    )
                    if settings.deterministic_probes.probe_vectors is None:
                        probe_vectors = torch.empty(matrix_shape[-1],
                                                    num_random_probes,
                                                    dtype=dtype,
                                                    device=device)
                        probe_vectors.bernoulli_().mul_(2).add_(-1)
                        settings.deterministic_probes.probe_vectors = probe_vectors
                    else:
                        probe_vectors = settings.deterministic_probes.probe_vectors
                else:
                    probe_vectors = torch.empty(matrix_shape[-1],
                                                num_random_probes,
                                                dtype=dtype,
                                                device=device)
                    probe_vectors.bernoulli_().mul_(2).add_(-1)

                probe_vector_norms = torch.norm(probe_vectors,
                                                2,
                                                dim=-2,
                                                keepdim=True)
                if batch_shape is not None:
                    probe_vectors = probe_vectors.expand(
                        *batch_shape, matrix_shape[-1], num_random_probes)
                    probe_vector_norms = probe_vector_norms.expand(
                        *batch_shape, 1, num_random_probes)
            else:  # When preconditioning, probe vectors must be drawn from N(0, P)
                if settings.deterministic_probes.on():
                    # NOTE: calling precond_lt.root_decomposition() is expensive
                    # because it requires Lanczos
                    # We don't have any other choice for when we want to use deterministic probes, however
                    if precond_lt.size()[-2:] == torch.Size([1, 1]):
                        covar_root = precond_lt.evaluate().sqrt()
                    else:
                        covar_root = precond_lt.root_decomposition().root

                    warnings.warn(
                        "Deterministic probes will currently work only if you aren't training multiple independent"
                        " models simultaneously.",
                        UserWarning,
                    )
                    base_samples = settings.deterministic_probes.probe_vectors
                    if base_samples is None or covar_root.size(
                            -1) != base_samples.size(-2):
                        base_samples = torch.randn(
                            *precond_lt.batch_shape,
                            covar_root.size(-1),
                            num_random_probes,
                            dtype=precond_lt.dtype,
                            device=precond_lt.device,
                        )
                        settings.deterministic_probes.probe_vectors = base_samples

                    probe_vectors = covar_root.matmul(base_samples).permute(
                        -1, *range(precond_lt.dim() - 1))
                else:
                    probe_vectors = precond_lt.zero_mean_mvn_samples(
                        num_random_probes)
                probe_vectors = probe_vectors.unsqueeze(-2).transpose(
                    0, -2).squeeze(0).transpose(-2, -1).contiguous()
                probe_vector_norms = torch.norm(probe_vectors,
                                                p=2,
                                                dim=-2,
                                                keepdim=True)
            probe_vectors = probe_vectors.div(probe_vector_norms)

        ctx.probe_vectors = probe_vectors
        ctx.probe_vector_norms = probe_vector_norms

        if ctx.logdet and not ctx.probe_vectors.numel():
            raise RuntimeError(
                "Probe vectors were not supplied for logdet computation")

        # Collect terms for LinearCG
        # We use LinearCG for both matrix solves and for stochastically estimating the log det
        rhs_list = []
        num_random_probes = 0
        num_inv_quad_solves = 0

        # RHS for logdet
        if ctx.logdet:
            rhs_list.append(ctx.probe_vectors)
            num_random_probes = ctx.probe_vectors.size(-1)

        # RHS for inv_quad
        ctx.is_vector = False
        if ctx.inv_quad:
            if inv_quad_rhs.ndimension() == 1:
                inv_quad_rhs = inv_quad_rhs.unsqueeze(-1)
                ctx.is_vector = True
            rhs_list.append(inv_quad_rhs)
            num_inv_quad_solves = inv_quad_rhs.size(-1)

        # Perform solves (for inv_quad) and tridiagonalization (for estimating logdet)
        rhs = torch.cat(rhs_list, -1)
        t_mat = None
        if ctx.logdet and settings.skip_logdet_forward.off():
            solves, t_mat = lazy_tsr._solve(rhs,
                                            preconditioner,
                                            num_tridiag=num_random_probes)

        else:
            solves = lazy_tsr._solve(rhs, preconditioner, num_tridiag=0)

        # Final values to return
        logdet_term = torch.zeros(lazy_tsr.batch_shape,
                                  dtype=ctx.dtype,
                                  device=ctx.device)
        inv_quad_term = torch.zeros(lazy_tsr.batch_shape,
                                    dtype=ctx.dtype,
                                    device=ctx.device)

        # Compute logdet from tridiagonalization
        if ctx.logdet and settings.skip_logdet_forward.off():
            if torch.any(torch.isnan(t_mat)).item():
                logdet_term = torch.tensor(float("nan"),
                                           dtype=ctx.dtype,
                                           device=ctx.device)
            else:
                if ctx.batch_shape is None:
                    t_mat = t_mat.unsqueeze(1)
                eigenvalues, eigenvectors = lanczos_tridiag_to_diag(t_mat)
                slq = StochasticLQ()
                (logdet_term, ) = slq.evaluate(ctx.matrix_shape, eigenvalues,
                                               eigenvectors,
                                               [lambda x: x.log()])

                # Add correction
                if logdet_correction is not None:
                    logdet_term = logdet_term + logdet_correction

        # Extract inv_quad solves from all the solves
        if ctx.inv_quad:
            inv_quad_solves = solves.narrow(-1, num_random_probes,
                                            num_inv_quad_solves)
            inv_quad_term = (inv_quad_solves * inv_quad_rhs).sum(-2)

        ctx.num_random_probes = num_random_probes
        ctx.num_inv_quad_solves = num_inv_quad_solves

        to_save = list(matrix_args) + [solves]
        ctx.save_for_backward(*to_save)

        if settings.memory_efficient.off():
            ctx._lazy_tsr = lazy_tsr

        return inv_quad_term, logdet_term
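
For context, a minimal self-contained sketch of the estimator behind the probe vectors above (not the library's code): logdet(A) = tr(log A) is approximated with Rademacher probes. Here log(A) is formed densely via an eigendecomposition purely for illustration, whereas the forward pass above avoids that with Lanczos tridiagonalization on matrix-vector products.

import torch

torch.manual_seed(0)
n, num_probes = 50, 2000

B = torch.randn(n, n)
A = B @ B.t() + n * torch.eye(n)  # small SPD matrix for the demo

# Dense log(A) via eigendecomposition (illustration only)
evals, evecs = torch.linalg.eigh(A)
logA = evecs @ torch.diag(evals.log()) @ evecs.t()

# Rademacher probes, built the same way as probe_vectors above
z = torch.empty(n, num_probes).bernoulli_().mul_(2).add_(-1)
estimate = (z * (logA @ z)).sum(0).mean()  # mean of z_i^T log(A) z_i
print(estimate.item(), torch.logdet(A).item())  # the two values are close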
Ejemplo n.º 14
0
def sinc_inv(x):
    # Use the Taylor expansion of x/sin(x) near 0 to avoid the 0/0 indeterminacy.
    usetaylor = (x.abs() < thresh)
    texpand = 1 + (1 / 6) * x**2 + (7 / 360) * x**4
    assert not torch.any(torch.isinf(texpand) | torch.isnan(texpand)), \
        'sinc_inv: texpand contains inf/nan'
    return torch.where(usetaylor, texpand, x / x.sin())
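
A quick sanity check of the Taylor branch (a sketch; `thresh` is assumed to be a small module-level cutoff in the original):

import torch

thresh = 1e-4  # assumed cutoff; the original defines this at module level

x = torch.tensor([0.0, 0.1, 0.3])
print(sinc_inv(x))          # finite at x = 0 thanks to the Taylor branch
print(x[1:] / x[1:].sin())  # matches the exact values away from 0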
Ejemplo n.º 15
0
    def _test_no_pad_and_pad(self, no_pad_features, pad_features):
        tokenizer = BertTokenizer(self.vocab_file)
        data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)
        batch = data_collator(no_pad_features)
        self.assertEqual(batch["input_ids"].shape, torch.Size((2, 10)))
        self.assertEqual(batch["labels"].shape, torch.Size((2, 10)))

        batch = data_collator(pad_features)
        self.assertEqual(batch["input_ids"].shape, torch.Size((2, 10)))
        self.assertEqual(batch["labels"].shape, torch.Size((2, 10)))

        data_collator = DataCollatorForLanguageModeling(tokenizer,
                                                        mlm=False,
                                                        pad_to_multiple_of=8)
        batch = data_collator(no_pad_features)
        self.assertEqual(batch["input_ids"].shape, torch.Size((2, 16)))
        self.assertEqual(batch["labels"].shape, torch.Size((2, 16)))

        batch = data_collator(pad_features)
        self.assertEqual(batch["input_ids"].shape, torch.Size((2, 16)))
        self.assertEqual(batch["labels"].shape, torch.Size((2, 16)))

        tokenizer._pad_token = None
        data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)
        with self.assertRaises(ValueError):
            # Expect error due to padding token missing
            data_collator(pad_features)

        set_seed(42)  # For reproducibility
        tokenizer = BertTokenizer(self.vocab_file)
        data_collator = DataCollatorForLanguageModeling(tokenizer)
        batch = data_collator(no_pad_features)
        self.assertEqual(batch["input_ids"].shape, torch.Size((2, 10)))
        self.assertEqual(batch["labels"].shape, torch.Size((2, 10)))

        masked_tokens = batch["input_ids"] == tokenizer.mask_token_id
        self.assertTrue(torch.any(masked_tokens))
        self.assertTrue(
            all(x == -100 for x in batch["labels"][~masked_tokens].tolist()))

        batch = data_collator(pad_features)
        self.assertEqual(batch["input_ids"].shape, torch.Size((2, 10)))
        self.assertEqual(batch["labels"].shape, torch.Size((2, 10)))

        masked_tokens = batch["input_ids"] == tokenizer.mask_token_id
        self.assertTrue(torch.any(masked_tokens))
        self.assertTrue(
            all(x == -100 for x in batch["labels"][~masked_tokens].tolist()))

        data_collator = DataCollatorForLanguageModeling(tokenizer,
                                                        pad_to_multiple_of=8)
        batch = data_collator(no_pad_features)
        self.assertEqual(batch["input_ids"].shape, torch.Size((2, 16)))
        self.assertEqual(batch["labels"].shape, torch.Size((2, 16)))

        masked_tokens = batch["input_ids"] == tokenizer.mask_token_id
        self.assertTrue(torch.any(masked_tokens))
        self.assertTrue(
            all(x == -100 for x in batch["labels"][~masked_tokens].tolist()))

        batch = data_collator(pad_features)
        self.assertEqual(batch["input_ids"].shape, torch.Size((2, 16)))
        self.assertEqual(batch["labels"].shape, torch.Size((2, 16)))

        masked_tokens = batch["input_ids"] == tokenizer.mask_token_id
        self.assertTrue(torch.any(masked_tokens))
        self.assertTrue(
            all(x == -100 for x in batch["labels"][~masked_tokens].tolist()))
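
A hedged usage sketch of the padding behaviour this test pins down (the tokenizer checkpoint name is illustrative):

from transformers import BertTokenizer, DataCollatorForLanguageModeling

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
collator = DataCollatorForLanguageModeling(tokenizer, mlm=False,
                                           pad_to_multiple_of=8)
features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]
batch = collator(features)
print(batch["input_ids"].shape)  # torch.Size([2, 16]): length 10 padded up to 16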
Ejemplo n.º 16
0
def sinkhorn_knopp(a,
                   b,
                   C,
                   reg=1e-1,
                   maxIter=1000,
                   stopThr=1e-9,
                   verbose=True,
                   log=True,
                   warm_start=None,
                   eval_freq=10,
                   print_freq=200,
                   **kwargs):
    """
    Solve the entropic regularization optimal transport
    The input should be PyTorch tensors
    The function solves the following optimization problem:

    .. math::
        \gamma = \arg\min_\gamma \; \langle \gamma, C \rangle_F + \mathrm{reg} \cdot \Omega(\gamma)

        \text{s.t.} \quad \gamma \mathbf{1} = a, \quad \gamma^T \mathbf{1} = b, \quad \gamma \geq 0

    where:
    - C is the (na, nb) metric cost matrix
    - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
    - a and b are the target and source measures (each summing to 1)
    The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [1].

    Parameters
    ----------
    a : torch.tensor (na,)
        samples measure in the target domain
    b : torch.tensor (nb,)
        samples in the source domain
    C : torch.tensor (na,nb)
        loss matrix
    reg : float
        Regularization term > 0
    maxIter : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on error (> 0)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True

    Returns
    -------
    gamma : (na x nb) torch.tensor
        Optimal transportation matrix for the given parameters
    log : dict
        log dictionary return only if log==True in parameters

    References
    ----------
    [1] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
    See Also
    --------

    """

    device = a.device
    na, nb = C.shape

    assert na >= 1 and nb >= 1, 'C needs to be 2d'
    assert na == a.shape[0] and nb == b.shape[0], \
        "Shape of a or b doesn't match that of C"
    assert reg > 0, 'reg should be greater than 0'
    assert a.min() >= 0. and b.min() >= 0., 'Elements in a or b less than 0'
    # unnecessary check for our special case
    if log:
        log = {'err': []}

    if warm_start is not None:
        u = warm_start['u']
        v = warm_start['v']
    else:
        u = torch.ones(na, dtype=a.dtype).to(device) / na
        v = torch.ones(nb, dtype=b.dtype).to(device) / nb

    K = torch.empty(C.shape, dtype=C.dtype).to(device)
    torch.div(C, -reg, out=K)
    torch.exp(K, out=K)

    b_hat = torch.empty(b.shape, dtype=C.dtype).to(device)

    it = 1
    err = 1

    # allocate memory beforehand
    KTu = torch.empty(v.shape, dtype=v.dtype).to(device)
    Kv = torch.empty(u.shape, dtype=u.dtype).to(device)

    while (err > stopThr and it <= maxIter):
        upre, vpre = u, v
        torch.matmul(u, K, out=KTu)
        v = torch.div(b, KTu + M_EPS)
        torch.matmul(K, v, out=Kv)
        u = torch.div(a, Kv + M_EPS)

        if torch.any(torch.isnan(u)) or torch.any(torch.isnan(v)) or \
                torch.any(torch.isinf(u)) or torch.any(torch.isinf(v)):
            print('Warning: numerical errors at iteration', it)
            u, v = upre, vpre
            break

        if log and it % eval_freq == 0:
            # Checking the error only every eval_freq iterations speeds things up.
            # The line below is equivalent to:
            # b_hat = torch.sum(u.reshape(-1, 1) * K * v.reshape(1, -1), 0)
            # but is more memory efficient.
            b_hat = torch.matmul(u, K) * v
            err = (b - b_hat).pow(2).sum().item()
            # err = (b - b_hat).abs().sum().item()
            log['err'].append(err)

        if verbose and it % print_freq == 0:
            print('iteration {:5d}, constraint error {:5e}'.format(it, err))

        it += 1

    if log:
        log['u'] = u
        log['v'] = v
        log['alpha'] = reg * torch.log(u + M_EPS)
        log['beta'] = reg * torch.log(v + M_EPS)

    # transport plan
    P = u.reshape(-1, 1) * K * v.reshape(1, -1)
    if log:
        return P, log
    else:
        return P
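
A tiny usage sketch (`M_EPS` is assumed to be a small module-level constant, e.g. 1e-16, in the original):

import torch

M_EPS = 1e-16  # assumed; defined at module level in the original

a = torch.full((3,), 1.0 / 3)
b = torch.full((4,), 1.0 / 4)
C = torch.rand(3, 4)
P, log = sinkhorn_knopp(a, b, C, reg=0.1, verbose=False)
print(P.sum(dim=1))  # ~= a: the row marginals match the target measure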
Ejemplo n.º 17
0
for gid in gids:  # for each image
    mask_gid = datas_dt[:, 0] == gid
    datas_dt_gid = datas_dt[mask_gid]
    datas_gt_gid = datas_gt[datas_gt[:, 0] == gid]
    ious = calc_iou4ts(datas_dt_gid[:, 3:7], datas_gt_gid[:, 2:6])
    # print(ious)
    for i in range(len(datas_gt_gid)):
        cls = datas_gt_gid[i][1]
        # clses.add(int(cls.item()))
        # class mask (1-d)
        mask_cls = datas_dt_gid[:, 1] == cls
        mask_nomatch = datas_dt_gid[:, -1] != 1
        mask_iou = ious[:, i] > iou_std
        _mask = torch.logical_and(mask_cls, mask_nomatch)
        _mask = torch.logical_and(_mask, mask_iou)
        if torch.any(_mask):
            # score
            index_score = datas_dt_gid[:, 2][_mask].max(0)[1]
            # datas_dt_gid[index_score, -1] = 1
            dim0 = torch.where(mask_gid)
            datas_dt[dim0[0][index_score], -1] = 1

    if not debug:
        continue
    for c in clses:  # debug: per-image, per-class tp stats
        num_gt = (datas_gt_gid[:, 1] == c).sum()
        mask_gid_cls = datas_dt_gid[:, 1] == c
        tp = (torch.logical_and(mask_gid_cls, datas_dt_gid[:, -1] == 1)).sum()
        fp = mask_gid_cls.sum() - tp
        fn = num_gt - tp
        print('gid=%s, cls=%s, tp=%s, fp=%s, fn=%s' %
              (gid, c, tp, fp, fn))
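
calc_iou4ts is project-specific; as a sketch, torchvision's box_iou computes the same kind of pairwise IoU matrix for (x1, y1, x2, y2) boxes:

import torch
from torchvision.ops import box_iou

dt_boxes = torch.tensor([[0., 0., 10., 10.], [5., 5., 15., 15.]])
gt_boxes = torch.tensor([[0., 0., 10., 10.]])
ious = box_iou(dt_boxes, gt_boxes)  # shape (num_dt, num_gt)
print(ious)  # [[1.0000], [0.1429]]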
Ejemplo n.º 18
0
    def generate_batch(self,
                       batch,
                       gt_prefix=1,
                       enc_state=None,
                       whole_policy=False,
                       special_actions=None,
                       temperature=1,
                       acpt_range=None):
        # This is to ensure we can stop at EOS for stateful models
        assert batch.size == 1

        # Run the model
        policy, p_mean, p_logstd = self.model(batch.encoder_intent,
                                              batch.encoder_price,
                                              batch.encoder_pmask,
                                              batch.encoder_dianum)
        # print('policy is:', policy)

        # Get embeddings of target
        # tgt_emb = self.model.encoder.embeddings(batch.target_intent)
        # tgt_emb = torch.cat([tgt_emb, batch.target_price], )

        # policy.sub_(policy.max(1, keepdim=True)[0].expand(policy.size(0), policy.size(1)))

        # policy[batch.policy_mask == 0] = -100
        policy.sub_(policy.max(1, keepdim=True)[0].expand(-1, policy.size(1)))
        policy = policy * temperature
        # mask = batch.policy_mask
        # policy[mask == 0] = -100.
        # print(batch.policy_mask)

        # Avoid policy equal to zero ( + 1e-6 )
        p_exp = (policy.exp() + 1e-6).mul(batch.policy_mask)
        # p_exp = (policy.exp() + 1e-6)
        # if torch.sum(p_exp).item() == 0:
        #     p_exp += torch.ones_like(p_exp).mul(batch.policy_mask)

        policy = p_exp / (torch.sum(p_exp, keepdim=True, dim=1))
        if torch.any(policy < 0) or torch.any(torch.isnan(policy)):
            print('lots of errors: ', p_exp, batch.policy_mask)
        intent = torch.multinomial(policy, 1).squeeze(1)  # (batch_size,)

        # Use Normal distribution with constant variance as policy on price

        # price = p_mean + p_logstd.exp()*0.1 * torch.randn_like(p_mean)
        price = p_mean + (1.1 - temperature) * torch.randn_like(p_mean)
        # price = p_mean
        # print(torch.cat([price.view(-1,1), p_mean.view(-1,1), p_logstd.view(-1,1)], dim=1))
        # price = price + LFSampler.var_for_price * torch.randn_like(price).abs()

        # Use rule for Supervised learning agent
        if acpt_range is not None:
            assert len(acpt_range) == 2
            # Check if last action is offer
            if batch.encoder_intent[0, -1] in self.offer:
                offer_price = batch.encoder_price[0, 0, -1].item()
                policy = policy * 0

                if (acpt_range[0] <= offer_price) and (offer_price <=
                                                       acpt_range[1]):
                    # Accept
                    act_idx = self.acc_or_rej[0]
                else:
                    act_idx = self.acc_or_rej[1]

                intent[0] = act_idx
                policy[0, act_idx] = 1

        # TODO: not correct for batched inputs with multiple samples.
        if intent not in self.price_actions:
            price = None

        # print('gen output: ',policy, price)
        ret = {
            "intent": intent,
            "price": price,
            "policy": policy,
            "price_mean": p_mean,
            "price_logstd": p_logstd,
        }

        ret["batch"] = batch
        # ret["policies"] = policies
        # ret["probability"] = probs
        return ret
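
The masked-sampling core above, in isolation (a sketch, not the author's exact pipeline):

import torch

logits = torch.randn(1, 5)
mask = torch.tensor([[1., 1., 0., 1., 0.]])        # 0 marks invalid intents
logits = logits - logits.max(1, keepdim=True)[0]   # shift for stability
p = (logits.exp() + 1e-6) * mask                   # zero out invalid actions
p = p / p.sum(dim=1, keepdim=True)                 # renormalise
intent = torch.multinomial(p, 1).squeeze(1)
print(p, intent)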
Ejemplo n.º 19
0
    def assign_one_hot_gt_indices(self,
                                  is_bbox_in_gt_core,
                                  is_bbox_in_gt_shadow,
                                  gt_priority=None):
        """Assign only one gt index to each prior box.

        Gts with large gt_priority are more likely to be assigned.

        Args:
            is_bbox_in_gt_core (Tensor): Bool tensor indicating the bbox center
              is in the core area of a gt (e.g. 0-0.2).
              Shape: (num_prior, num_gt).
            is_bbox_in_gt_shadow (Tensor): Bool tensor indicating the bbox
              center is in the shadowed area of a gt (e.g. 0.2-0.5).
              Shape: (num_prior, num_gt).
            gt_priority (Tensor): Priorities of gts. The gt with a higher
              priority is more likely to be assigned to the bbox when the bbox
              match with multiple gts. Shape: (num_gt, ).

        Returns:
            tuple: Returns (assigned_gt_inds, shadowed_gt_inds).

                - assigned_gt_inds: The assigned gt index of each prior bbox \
                    (i.e. index from 1 to num_gts). Shape: (num_prior, ).
                - shadowed_gt_inds: shadowed gt indices. It is a tensor of \
                    shape (num_ignore, 2) with first column being the \
                    shadowed prior bbox indices and the second column the \
                    shadowed gt indices (1-based).
        """
        num_bboxes, num_gts = is_bbox_in_gt_core.shape

        if gt_priority is None:
            gt_priority = torch.arange(num_gts,
                                       device=is_bbox_in_gt_core.device)
        assert gt_priority.size(0) == num_gts
        # The bigger gt_priority, the more preferable to be assigned
        # The assigned inds are by default 0 (background)
        assigned_gt_inds = is_bbox_in_gt_core.new_zeros((num_bboxes, ),
                                                        dtype=torch.long)
        # Shadowed bboxes are assigned to be background. But the corresponding
        #   label is ignored during loss calculation, which is done through
        #   shadowed_gt_inds
        shadowed_gt_inds = torch.nonzero(is_bbox_in_gt_shadow, as_tuple=False)
        if is_bbox_in_gt_core.sum() == 0:  # No gt match
            shadowed_gt_inds[:, 1] += 1  # 1-based, for consistency
            return assigned_gt_inds, shadowed_gt_inds

        # The priority of each prior box and gt pair. If one prior box is
        #  matched to multiple gts, only the pair with the highest priority
        #  is saved
        pair_priority = is_bbox_in_gt_core.new_full((num_bboxes, num_gts),
                                                    -1,
                                                    dtype=torch.long)

        # Each bbox could match with multiple gts.
        # The following code deals with this situation
        # Matched bboxes (to any gt). Shape: (num_pos_anchor, )
        inds_of_match = torch.any(is_bbox_in_gt_core, dim=1)
        # The matched gt index of each positive bbox. Length >= num_pos_anchor
        #   , since one bbox could match multiple gts
        matched_bbox_gt_inds = torch.nonzero(is_bbox_in_gt_core,
                                             as_tuple=False)[:, 1]
        # Assign priority to each bbox-gt pair.
        pair_priority[is_bbox_in_gt_core] = gt_priority[matched_bbox_gt_inds]
        _, argmax_priority = pair_priority[inds_of_match].max(dim=1)
        assigned_gt_inds[inds_of_match] = argmax_priority + 1  # 1-based
        # Zero-out the assigned anchor box to filter the shadowed gt indices
        is_bbox_in_gt_core[inds_of_match, argmax_priority] = 0
        # Concatenate the shadowed indices due to overlap with gts outside of
        #   the effective scale. shape: (total_num_ignore, 2)
        shadowed_gt_inds = torch.cat(
            (shadowed_gt_inds, torch.nonzero(is_bbox_in_gt_core,
                                             as_tuple=False)),
            dim=0)
        # `is_bbox_in_gt_core` should be changed back to keep arguments intact.
        is_bbox_in_gt_core[inds_of_match, argmax_priority] = 1
        # 1-based shadowed gt indices, to be consistent with `assigned_gt_inds`
        shadowed_gt_inds[:, 1] += 1
        return assigned_gt_inds, shadowed_gt_inds
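
A toy run of the priority tie-break, assuming two gts and three priors (a sketch):

import torch

is_in_core = torch.tensor([[True, True],
                           [False, True],
                           [False, False]])
gt_priority = torch.tensor([0, 1])
pair_priority = is_in_core.new_full(is_in_core.shape, -1, dtype=torch.long)
matched = torch.any(is_in_core, dim=1)
gt_inds = torch.nonzero(is_in_core, as_tuple=False)[:, 1]
pair_priority[is_in_core] = gt_priority[gt_inds]
assigned = torch.zeros(3, dtype=torch.long)
assigned[matched] = pair_priority[matched].max(dim=1)[1] + 1  # 1-based
print(assigned)  # tensor([2, 2, 0]): the higher-priority gt wins the tie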
Ejemplo n.º 20
0
 def forward(
     self,
     tokens_p: TextFieldTensors,
     tokens_h: TextFieldTensors,
     g_p: SparseAdjacencyFieldTensors,
     g_h: SparseAdjacencyFieldTensors,
     label: torch.Tensor = None,
     return_attention: bool = False,
 ) -> Dict[str, torch.Tensor]:
     """
     GMN for NLI
     let B be batch size
     let N be p length
     let M be h length
     
     input :
         tokens_p in shape [*] (B, N)
         g_p["edge_index"]
      output : tensor dict
     """
     # Shape: (batch_size, num_tokens, embedding_dim)
     # embedder will take out the desired entry, for ex: ["tokens"]
     embedded_p = self.embedder(tokens_p)
     embedded_h = self.embedder(tokens_h)
     # Shape: (num_batch_edges, edge_embedding_dim)
      # can use pass-through if type is the information required
      # this is general for any kind of edge representation to GP2Vencoder
      # https://docs.allennlp.org/master/api/modules/token_embedders/embedding/
      # in-place ops would mutate the input, so deepcopy before embedding edges
     g_p_embedded = deepcopy(g_p)
     g_h_embedded = deepcopy(g_h)
     g_p_embedded["edge_attr"] = self.edge_embedder(g_p["edge_attr"])
     g_h_embedded["edge_attr"] = self.edge_embedder(g_h["edge_attr"])
     assert (not torch.any(torch.isnan(embedded_p)))
     assert (not torch.any(torch.isnan(embedded_h)))
     # Shape: (batch_size, num_tokens, projected_dim)
     embedded_p = self.projector(embedded_p)
     embedded_h = self.projector(embedded_h)
     assert (not torch.any(torch.isnan(embedded_p)))
     assert (not torch.any(torch.isnan(embedded_h)))
     # Shape:
     # node_attr : (num_tokens, embedding_dim)
     # batch_id : (num_tokens)
     sparse_p = tensor_op.dense2sparse(
         embedded_p,
         tokens_p["tokens"]["mask"])  #need to overload indexer for this
     sparse_h = tensor_op.dense2sparse(embedded_h,
                                       tokens_h["tokens"]["mask"])
     assert (not torch.any(torch.isnan(sparse_p["data"])))
     assert (not torch.any(torch.isnan(sparse_h["data"])))
     # Shape: (batch_size, classifier_in_dim)
     if return_attention:
         cls_vector, attention_dict = self.encoder(
             sparse_p,
             sparse_h,
             g_p_embedded,
             g_h_embedded,
             return_attention=return_attention)
     else:
         cls_vector = self.encoder(sparse_p,
                                   sparse_h,
                                   g_p_embedded,
                                   g_h_embedded,
                                   return_attention=return_attention)
     # Shape: (batch_size, num_labels)
     logits = self.classifier(cls_vector)
     assert (not torch.any(torch.isnan(cls_vector)))
     # Shape: (batch_size, num_labels)
     probs = torch.nn.functional.softmax(logits, dim=1)
     # Shape: TensorDict
     output = {'probs': probs}
     if label is not None:
         #print(logits.size(), label.size())
         self.accuracy(logits, label)
          # the two values can differ slightly for numerical reasons
         self.entropy(logits, label)
         output['loss'] = torch.nn.functional.cross_entropy(logits, label)
     if return_attention is True:
         output['attentions'] = attention_dict
     return output
Ejemplo n.º 21
0
def greedy_cos_idf(ref_embedding, ref_masks, ref_idf,
                   hyp_embedding, hyp_masks, hyp_idf,
                   all_layers=False):
    """
    Compute greedy matching based on cosine similarity.

    Args:
        - :param: `ref_embedding` (torch.Tensor):
                   embeddings of reference sentences, BxKxd,
                   B: batch size, K: longest length, d: bert dimension
        - :param: `ref_lens` (list of int): list of reference sentence lengths.
        - :param: `ref_masks` (torch.LongTensor): BxKxK, BERT attention mask for
                   reference sentences.
        - :param: `ref_idf` (torch.Tensor): BxK, idf score of each word
                   piece in the reference sentence
        - :param: `hyp_embedding` (torch.Tensor):
                   embeddings of candidate sentences, BxKxd,
                   B: batch size, K: longest length, d: bert dimension
        - :param: `hyp_lens` (list of int): list of candidate sentence lengths.
        - :param: `hyp_masks` (torch.LongTensor): BxKxK, BERT attention mask for
                   candidate sentences.
        - :param: `hyp_idf` (torch.Tensor): BxK, idf score of each word
                   piece in the candidate sentence
    """
    ref_embedding.div_(torch.norm(ref_embedding, dim=-1).unsqueeze(-1))
    hyp_embedding.div_(torch.norm(hyp_embedding, dim=-1).unsqueeze(-1))

    if all_layers:
        B, _, L, D = hyp_embedding.size()
        hyp_embedding = hyp_embedding.transpose(1, 2).transpose(0, 1)\
            .contiguous().view(L*B, hyp_embedding.size(1), D)
        ref_embedding = ref_embedding.transpose(1, 2).transpose(0, 1)\
            .contiguous().view(L*B, ref_embedding.size(1), D)
    batch_size = ref_embedding.size(0)
    sim = torch.bmm(hyp_embedding, ref_embedding.transpose(1, 2))
    masks = torch.bmm(hyp_masks.unsqueeze(2).float(), ref_masks.unsqueeze(1).float())
    if all_layers:
        masks = masks.unsqueeze(0).expand(L, -1, -1, -1)\
                                  .contiguous().view_as(sim)
    else:
        masks = masks.expand(batch_size, -1, -1)\
                                  .contiguous().view_as(sim)

    masks = masks.float().to(sim.device)
    sim = sim * masks

    word_precision = sim.max(dim=2)[0]
    word_recall = sim.max(dim=1)[0]

    hyp_idf.div_(hyp_idf.sum(dim=1, keepdim=True))
    ref_idf.div_(ref_idf.sum(dim=1, keepdim=True))
    precision_scale = hyp_idf.to(word_precision.device)
    recall_scale = ref_idf.to(word_recall.device)
    if all_layers:
        precision_scale = precision_scale.unsqueeze(0)\
            .expand(L, B, -1).contiguous().view_as(word_precision)
        recall_scale = recall_scale.unsqueeze(0)\
            .expand(L, B, -1).contiguous().view_as(word_recall)
    P = (word_precision * precision_scale).sum(dim=1)
    R = (word_recall * recall_scale).sum(dim=1)
    F = 2 * P * R / (P + R)

    hyp_zero_mask = hyp_masks.sum(dim=1).eq(2)
    ref_zero_mask = ref_masks.sum(dim=1).eq(2)

    if all_layers:
        P = P.view(L, B)
        R = R.view(L, B)
        F = F.view(L, B)

    if torch.any(hyp_zero_mask):
        print("Warning: Empty candidate sentence; Setting precision to be 0.", file=sys.stderr)
        P = P.masked_fill(hyp_zero_mask, 0.)

    if torch.any(ref_zero_mask):
        print("Warning: Empty candidate sentence; Setting recall to be 0.", file=sys.stderr)
        R = R.masked_fill(ref_zero_mask, 0.)

    F = F.masked_fill(torch.isnan(F), 0.)

    return P, R, F
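
The greedy matching step in isolation, with uniform weights standing in for the idf scores (a sketch):

import torch
import torch.nn.functional as F

hyp = F.normalize(torch.randn(2, 4, 8), dim=-1)   # (B, hyp_len, d)
ref = F.normalize(torch.randn(2, 6, 8), dim=-1)   # (B, ref_len, d)
sim = torch.bmm(hyp, ref.transpose(1, 2))         # (B, hyp_len, ref_len)
P = sim.max(dim=2)[0].mean(dim=1)                 # best match per hyp token
R = sim.max(dim=1)[0].mean(dim=1)                 # best match per ref token
F1 = 2 * P * R / (P + R)
print(P, R, F1)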
Ejemplo n.º 22
0
def nucleus_sampling(model, batch, batch_size, threshold=0.9, use_packed=True):
    model.eval()

    text_vecs = batch['text_vecs'].to(current_device)
    if use_packed:
        encoded = model.encoder(text_vecs,
                                batch['text_lens'],
                                use_packed=batch['use_packed'])
    else:
        encoded = model.encoder(text_vecs)
    encoder_output, encoder_hidden, attention_mask = encoded

    # 1 is __start__
    starts = torch.Tensor(
        [1]).long().to(model.decoder.embedding.weight.device).expand(
            batch_size, 1).long()  # expand to batch size
    decoder_hidden = encoder_hidden

    # nucleus (top-p) sampling decoding loop
    preds = [starts]
    scores = []

    # track if each sample in the mini batch is finished
    # if all finished, stop predicting

    finish_mask = torch.Tensor([0] * batch_size).byte().to(
        model.decoder.embedding.weight.device)
    xs = starts
    _attn_w_log = []

    for ts in range(100):
        decoder_output, decoder_hidden, attn_w_log = model.decoder(
            xs, decoder_hidden,
            encoded)  # decoder_output: [batch, time, vocab]
        _probs = torch.softmax(decoder_output, dim=-1)[0][0]
        _sorted_probs, _sorted_indices = torch.sort(_probs, descending=True)
        cumulative_probs = torch.cumsum(_sorted_probs, dim=-1)
        selected_probs = cumulative_probs < threshold
        selected_probs[1:] = selected_probs[:-1].clone()
        selected_probs[0] = True
        _sorted_probs[~selected_probs] = 0
        P = _sorted_probs.sum()
        _sorted_probs /= P
        chosen_index = torch.multinomial(_sorted_probs, 1)
        _preds = _sorted_indices[chosen_index]
        _scores = torch.log(_probs[_preds])
        _preds = _preds.unsqueeze(0)

        preds.append(_preds)
        _attn_w_log.append(attn_w_log)
        scores.append(_scores.view(-1) * (finish_mask == 0).float())

        finish_mask += (_preds == 2).byte().view(-1)

        if not (torch.any(~finish_mask.bool())):
            break

        xs = _preds

    preds = torch.cat(preds, dim=-1)

    return preds, torch.sum(torch.Tensor(scores))
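
The top-p (nucleus) filter from the loop above, in isolation (a sketch):

import torch

probs = torch.tensor([0.5, 0.2, 0.15, 0.1, 0.05])
sorted_probs, sorted_idx = torch.sort(probs, descending=True)
keep = torch.cumsum(sorted_probs, dim=-1) < 0.9
keep[1:] = keep[:-1].clone()  # shift right so the boundary token stays in
keep[0] = True
sorted_probs[~keep] = 0
sorted_probs /= sorted_probs.sum()
token = sorted_idx[torch.multinomial(sorted_probs, 1)]
print(sorted_probs, token)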
Ejemplo n.º 23
0
    def forward(self, x, t):  # pylint: disable=arguments-differ
        assert len(x) == 1 + 2 * self.n_vectors + self.n_scales
        x_intensity = x[0]
        x_regs = x[1:1 + self.n_vectors]
        x_spreads = x[1 + self.n_vectors:1 + 2 * self.n_vectors]
        x_scales = []
        if self.n_scales:
            x_scales = x[1 + 2 * self.n_vectors:1 + 2 * self.n_vectors +
                         self.n_scales]

        assert len(t) == 1 + self.n_vectors + 1
        target_intensity = t[0]
        target_regs = t[1:1 + self.n_vectors]
        target_scale = t[-1]

        bce_masks = (target_intensity[:, :-1] + target_intensity[:, -1:]) > 0.5
        if not torch.any(bce_masks):
            return None, None, None

        batch_size = x_intensity.shape[0]
        LOG.debug('batch size = %d', batch_size)

        bce_x_intensity = x_intensity
        bce_target_intensity = target_intensity[:, :-1]
        if self.bce_blackout:
            bce_x_intensity = bce_x_intensity[:, self.bce_blackout]
            bce_masks = bce_masks[:, self.bce_blackout]
            bce_target_intensity = bce_target_intensity[:, self.bce_blackout]

        LOG.debug('BCE: target = %s, mask = %s', bce_target_intensity.shape,
                  bce_masks.shape)
        bce_target = torch.masked_select(bce_target_intensity, bce_masks)
        bce_weight = None
        if self.background_weight != 1.0:
            bce_weight = torch.ones_like(bce_target)
            bce_weight[bce_target == 0] = self.background_weight
        ce_loss = torch.nn.functional.binary_cross_entropy_with_logits(
            torch.masked_select(bce_x_intensity, bce_masks),
            bce_target,
            weight=bce_weight,
        )

        reg_losses = [None for _ in target_regs]
        reg_masks = target_intensity[:, :-1] > 0.5
        if torch.any(reg_masks):
            weight = None
            if self.multiplicity_correction:
                assert len(target_regs) == 2
                lengths = torch.norm(target_regs[0] - target_regs[1], dim=2)
                multiplicity = (lengths - 3.0) / self.independence_scale
                multiplicity = torch.clamp(multiplicity, min=1.0)
                multiplicity = torch.masked_select(multiplicity, reg_masks)
                weight = 1.0 / multiplicity

            reg_losses = []
            for i, (x_reg, x_spread, target_reg) in enumerate(
                    zip(x_regs, x_spreads, target_regs)):
                if hasattr(self.regression_loss, 'scale'):
                    assert self.scales_to_kp is not None
                    self.regression_loss.scale = torch.masked_select(
                        torch.clamp(target_scale * self.scales_to_kp[i], 0.1,
                                    1000.0),  # pylint: disable=unsubscriptable-object
                        reg_masks,
                    )

                reg_losses.append(
                    self.regression_loss(
                        torch.masked_select(x_reg[:, :, 0], reg_masks),
                        torch.masked_select(x_reg[:, :, 1], reg_masks),
                        torch.masked_select(x_spread, reg_masks),
                        torch.masked_select(target_reg[:, :, 0], reg_masks),
                        torch.masked_select(target_reg[:, :, 1], reg_masks),
                        weight=weight,
                    ) / 1000.0 / batch_size)

        scale_losses = []
        if x_scales:
            scale_losses = [
                torch.nn.functional.l1_loss(
                    torch.masked_select(x_scale, reg_masks),
                    torch.masked_select(target_scale * scale_to_kp, reg_masks),
                    reduction='sum',
                ) / 1000.0 / batch_size
                for x_scale, scale_to_kp in zip(x_scales, self.scales_to_kp)
            ]

        return [ce_loss] + reg_losses + scale_losses
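
The masked BCE at the heart of the intensity loss, as a standalone sketch:

import torch

torch.manual_seed(0)
logits = torch.randn(2, 5)
target = torch.randint(0, 2, (2, 5)).float()
bce_masks = torch.rand(2, 5) > 0.5          # only these cells contribute
loss = torch.nn.functional.binary_cross_entropy_with_logits(
    torch.masked_select(logits, bce_masks),
    torch.masked_select(target, bce_masks),
)
print(loss)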
Ejemplo n.º 24
0
def train_with_pruning_callback(
    tmpdir,
    parameters_to_prune=False,
    use_global_unstructured=False,
    pruning_fn="l1_unstructured",
    use_lottery_ticket_hypothesis=False,
    strategy=None,
    accelerator="cpu",
    devices=1,
):
    model = TestModel()

    # Weights are random; none are 0
    assert torch.all(model.layer.mlp_2.weight != 0)

    pruning_kwargs = {
        "pruning_fn": pruning_fn,
        "amount": 0.3,
        "use_global_unstructured": use_global_unstructured,
        "use_lottery_ticket_hypothesis": use_lottery_ticket_hypothesis,
        "verbose": 1,
    }
    if parameters_to_prune:
        pruning_kwargs["parameters_to_prune"] = [(model.layer.mlp_1, "weight"),
                                                 (model.layer.mlp_2, "weight")]
    else:
        if isinstance(pruning_fn, str) and pruning_fn.endswith("_structured"):
            pruning_kwargs["parameter_names"] = ["weight"]
        else:
            pruning_kwargs["parameter_names"] = ["weight", "bias"]
    if isinstance(pruning_fn, str) and pruning_fn.endswith("_structured"):
        pruning_kwargs["pruning_dim"] = 0
    if pruning_fn == "ln_structured":
        pruning_kwargs["pruning_norm"] = 1

    # Misconfiguration checks
    if isinstance(pruning_fn, str) and pruning_fn.endswith(
            "_structured") and use_global_unstructured:
        with pytest.raises(
                MisconfigurationException,
                match="is supported with `use_global_unstructured=True`"):
            ModelPruning(**pruning_kwargs)
        return
    if ModelPruning._is_pruning_method(
            pruning_fn) and not use_global_unstructured:
        with pytest.raises(MisconfigurationException,
                           match="currently only supported with"):
            ModelPruning(**pruning_kwargs)
        return

    pruning = ModelPruning(**pruning_kwargs)

    trainer = Trainer(
        default_root_dir=tmpdir,
        enable_progress_bar=False,
        enable_model_summary=False,
        enable_checkpointing=False,
        logger=False,
        limit_train_batches=10,
        limit_val_batches=2,
        max_epochs=10,
        strategy=strategy,
        accelerator=accelerator,
        devices=devices,
        callbacks=pruning,
    )
    trainer.fit(model)
    trainer.test(model)

    if not strategy:
        # Check some have been pruned
        assert torch.any(model.layer.mlp_2.weight == 0)
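
A minimal Lightning usage sketch of the callback under test (the LightningModule class is assumed to exist elsewhere):

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelPruning

trainer = Trainer(
    max_epochs=5,
    callbacks=[ModelPruning("l1_unstructured", amount=0.3)],
)
# trainer.fit(MyLightningModule())  # hypothetical module; weights get zeroed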
Ejemplo n.º 25
0
    def _compute_perturbation(  # pylint: disable=W0221
        self, x: "torch.Tensor", y: "torch.Tensor", mask: Optional["torch.Tensor"]
    ) -> "torch.Tensor":
        """
        Compute perturbations.

        :param x: Current adversarial examples.
        :param y: Target values (class labels) one-hot-encoded of shape `(nb_samples, nb_classes)` or indices of shape
                  (nb_samples,). Only provide this parameter if you'd like to use true labels when crafting adversarial
                  samples. Otherwise, model predictions are used as labels to avoid the "label leaking" effect
                  (explained in this paper: https://arxiv.org/abs/1611.01236). Default is `None`.
        :param mask: An array with a mask broadcastable to input `x` defining where to apply adversarial perturbations.
                     Shape needs to be broadcastable to the shape of x and can also be of the same shape as `x`. Any
                     features for which the mask is zero will not be adversarially perturbed.
        :return: Perturbations.
        """
        import torch  # lgtm [py/repeated-import]

        # Pick a small scalar to avoid division by 0
        tol = 10e-8

        # Get gradient wrt loss; invert it if attack is targeted
        grad = self.estimator.loss_gradient(x=x, y=y) * (1 - 2 * int(self.targeted))

        # Write summary
        if self.summary_writer is not None:
            self.summary_writer.add_scalar(
                "gradients/norm-L1/batch-{}".format(self._batch_id),
                np.linalg.norm(grad.flatten(), ord=1),
                global_step=self._i_max_iter,
            )
            self.summary_writer.add_scalar(
                "gradients/norm-L2/batch-{}".format(self._batch_id),
                np.linalg.norm(grad.flatten(), ord=2),
                global_step=self._i_max_iter,
            )
            self.summary_writer.add_scalar(
                "gradients/norm-Linf/batch-{}".format(self._batch_id),
                np.linalg.norm(grad.flatten(), ord=np.inf),
                global_step=self._i_max_iter,
            )

            if hasattr(self.estimator, "compute_losses"):
                losses = self.estimator.compute_losses(x=x, y=y)

                for key, value in losses.items():
                    self.summary_writer.add_scalar(
                        "loss/{}/batch-{}".format(key, self._batch_id),
                        np.mean(value.detach().cpu().numpy()),
                        global_step=self._i_max_iter,
                    )

        # Check for nan before normalisation and replace with 0
        if torch.any(grad.isnan()):
            logger.warning("Elements of the loss gradient are NaN and have been replaced with 0.0.")
            grad[grad.isnan()] = 0.0

        # Apply mask
        if mask is not None:
            grad = torch.where(mask == 0.0, torch.tensor(0.0).to(self.estimator.device), grad)

        # Apply norm bound
        if self.norm in ["inf", np.inf]:
            grad = grad.sign()

        elif self.norm == 1:
            ind = tuple(range(1, len(x.shape)))
            grad = grad / (torch.sum(grad.abs(), dim=ind, keepdims=True) + tol)  # type: ignore

        elif self.norm == 2:
            ind = tuple(range(1, len(x.shape)))
            grad = grad / (torch.sqrt(torch.sum(grad * grad, axis=ind, keepdims=True)) + tol)  # type: ignore

        assert x.shape == grad.shape

        return grad
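
The three norm normalisations in isolation (a sketch):

import torch

grad = torch.randn(2, 3, 4)
tol = 10e-8                          # same guard against division by 0
ind = tuple(range(1, grad.dim()))
g_inf = grad.sign()                                                    # L-inf
g_l1 = grad / (grad.abs().sum(dim=ind, keepdim=True) + tol)            # L1
g_l2 = grad / ((grad * grad).sum(dim=ind, keepdim=True).sqrt() + tol)  # L2
print(g_l1.abs().sum(dim=ind))  # ~1 per sample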
Ejemplo n.º 26
0
    def run(self):

        # ---- Subscribe ----
        with self.neuron:

            # ---- Weights ----
            self.row = self.neuron.metagraph.row

            # --- Run state ---
            self.global_step = 0
            self.best_train_loss = math.inf

            # --- Loop forever ---
            for self.epoch in range(self.config.miner.n_epochs):
                try:
                    # ---- Serve ----
                    self.neuron.axon.serve(self.model)

                    # ---- Train Model ----
                    self.train()
                    self.scheduler.step()

                    # If model has borked for some reason, we need to make sure it doesn't emit weights
                    # Instead, reload into previous version of model
                    if torch.any(
                            torch.isnan(
                                torch.cat([
                                    param.view(-1)
                                    for param in self.model.parameters()
                                ]))):
                        self.model, self.optimizer = self.model_toolbox.load_model(
                            self.config)
                        continue

                    # ---- Emit row-weights ----
                    self.neuron.metagraph.set_weights(
                        self.row, wait_for_inclusion=True
                    )  # Sets my row-weights on the chain.

                    # ---- Sync metagraph ----
                    self.neuron.metagraph.sync(
                    )  # Pulls the latest metagraph state (with my update.)
                    self.row = self.neuron.metagraph.row

                    # --- Epoch logs ----
                    print(self.neuron.axon.__full_str__())
                    print(self.neuron.dendrite.__full_str__())
                    print(self.neuron.metagraph)

                    # ---- Update Tensorboard ----
                    self.neuron.dendrite.__to_tensorboard__(
                        self.tensorboard, self.global_step)
                    self.neuron.metagraph.__to_tensorboard__(
                        self.tensorboard, self.global_step)
                    self.neuron.axon.__to_tensorboard__(
                        self.tensorboard, self.global_step)

                    # ---- Save best loss and model ----
                    if self.training_loss and self.epoch % 10 == 0:
                        if self.training_loss < self.best_train_loss:
                            self.best_train_loss = self.training_loss  # update best train loss
                            self.model_toolbox.save_model(
                                self.config.miner.full_path, {
                                    'epoch':
                                    self.epoch,
                                    'model_state_dict':
                                    self.model.state_dict(),
                                    'loss':
                                    self.best_train_loss,
                                    'optimizer_state_dict':
                                    self.optimizer.state_dict(),
                                })
                            self.tensorboard.add_scalar(
                                'Neuron/Train_loss', self.training_loss,
                                self.global_step)

                # --- Catch Errors ----
                except Exception as e:
                    logger.error('Exception in training script with error: {}',
                                 e)
                    logger.info(traceback.print_exc())
                    logger.info('Continuing to train.')
                    time.sleep(1)
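
The NaN guard used before emitting weights, reduced to its core:

import torch

model = torch.nn.Linear(4, 2)
flat = torch.cat([param.view(-1) for param in model.parameters()])
if torch.any(torch.isnan(flat)):
    print("model has NaN parameters: reload the last checkpoint")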
Ejemplo n.º 27
0
def main():
    summary = SummaryWriter('./log')
    tr_set, dv_set, feat_dim, msg = load_dataset(args.njobs, args.gpu, args.pin_memory,
                                                 config['hparas']['curriculum'] > 0,
                                                 **config['data'])

    verbose(msg)
    # model select
    print('Model initializing\n')
    net = torch.nn.DataParallel(
        AttentionModel(120, hidden_size=args.hidden_size, dropout_p=args.dropout_p, use_attn=args.attn_use,
                       stacked_encoder=args.stacked_encoder, attn_len=args.attn_len))
    # net = AttentionModel(257, 112, dropout_p=args.dropout_p, use_attn=args.attn_use)
    net = net.cuda()
    print(net)

    optimizer = optim.Adam(net.parameters(), lr=args.learning_rate)

    scheduler = ExponentialLR(optimizer, 0.5)

    # Checkpoint load

    print('Trying Checkpoint Load\n')
    ckpt_dir = 'ckpt_dir'
    if not os.path.exists(ckpt_dir):
        os.makedirs(ckpt_dir)
    best_loss = 200000.
    ckpt_path = os.path.join(ckpt_dir, args.ck_name)
    if os.path.exists(ckpt_path):
        ckpt = torch.load(ckpt_path)
        try:
            net.load_state_dict(ckpt['model'])
            optimizer.load_state_dict(ckpt['optimizer'])
            best_loss = ckpt['best_loss']

            print('checkpoint is loaded !')
            print('current best loss : %.4f' % best_loss)
        except RuntimeError as e:
            print('wrong checkpoint\n')
    else:
        print('checkpoint not exist!')
        print('current best loss : %.4f' % best_loss)

    print('Training Start!')
    # train
    iteration = 0
    train_losses = []
    test_losses = []
    for epoch in range(args.num_epochs):
        n = 0
        avg_loss = 0
        net.train()
        for input in tqdm(tr_set):
            tr_noisy_set, feat_dim = load_noisy_dataset("train", input[0], args.njobs,
                                                        args.gpu,
                                                        args.pin_memory,
                                                        config['hparas']['curriculum'] > 0,
                                                        **config['data_noisy'])
            for input_noisy in tr_noisy_set:
                train_clean_feat, feat_len = fetch_data(input)
                train_noisy_feat, feat_len = fetch_data(input_noisy)

                iteration += 1

                # feed data
                train_mixed_feat, attn_weight = net(train_noisy_feat)
                if train_mixed_feat.shape == train_clean_feat.shape:
                    loss = F.mse_loss(train_mixed_feat, train_clean_feat, True)

                    if torch.any(torch.isnan(loss)):
                        torch.save(
                            {'clean_mag': train_clean_feat, 'noisy_mag': train_noisy_feat, 'out_mag': train_mixed_feat},
                            'nan_mag')
                        raise ValueError('loss is NaN')
                    avg_loss += loss.item()

                    n += 1
                    # gradient optimizer
                    optimizer.zero_grad()

                    loss.backward()

                    # update weight
                    optimizer.step()

        avg_loss /= n
        print('result:')
        print('[epoch: {}, iteration: {}] avg_loss : {:.4f}'.format(epoch, iteration, avg_loss))

        summary.add_scalar('Train Loss', avg_loss, iteration)

        train_losses.append(avg_loss)
        if (len(train_losses) > 2) and (train_losses[-2] < avg_loss):
            print("Learning rate Decay")
            scheduler.step()

        # test phase
        n = 0
        avg_test_loss = 0
        net.eval()
        with torch.no_grad():
            for input in tqdm(dv_set):
                dv_noisy_set, feat_dim = load_noisy_dataset("dev", input[0], args.njobs,
                                                            args.gpu,
                                                            args.pin_memory,
                                                            config['hparas']['curriculum'] > 0,
                                                            **config['data_noisy'])
                for input_noisy in dv_noisy_set:
                    test_clean_feat = input[1].to(device='cuda')
                    test_noisy_feat = input_noisy[1].to(device='cuda')

                    test_mixed_feat, logits_attn_weight = net(test_noisy_feat)
                    if test_mixed_feat.shape == test_clean_feat.shape:
                        test_loss = F.mse_loss(test_mixed_feat, test_clean_feat, True)

                        avg_test_loss += test_loss.item()
                        n += 1

            avg_test_loss /= n

            test_losses.append(avg_test_loss)
            summary.add_scalar('Test Loss', avg_test_loss, iteration)

            print('[epoch: {}, iteration: {}] test loss : {:.4f} '.format(epoch, iteration, avg_test_loss))
            if avg_test_loss < best_loss:
                best_loss = avg_test_loss
                # Note: optimizer also has states ! don't forget to save them as well.
                ckpt = {'model': net.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'best_loss': best_loss}
                torch.save(ckpt, ckpt_path)
                print('checkpoint is saved !')
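
The checkpoint round-trip the script relies on, as a standalone sketch:

import torch

net = torch.nn.Linear(8, 8)
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
ckpt = {'model': net.state_dict(),
        'optimizer': optimizer.state_dict(),  # optimizer state matters too
        'best_loss': 0.123}
torch.save(ckpt, 'ckpt.pt')
restored = torch.load('ckpt.pt')
net.load_state_dict(restored['model'])
optimizer.load_state_dict(restored['optimizer'])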
Ejemplo n.º 28
0
        #### END CODE HERE ####

        # Make sure shapes are correct
        assert tuple(
            fake_image_and_labels.shape) == (len(real),
                                             fake.detach().shape[1] +
                                             image_one_hot_labels.shape[1], 28,
                                             28)
        assert tuple(
            real_image_and_labels.shape) == (len(real), real.shape[1] +
                                             image_one_hot_labels.shape[1], 28,
                                             28)
        # Make sure that enough predictions were made
        assert len(disc_real_pred) == len(real)
        # Make sure that the inputs are different
        assert torch.any(fake_image_and_labels != real_image_and_labels)
        # Shapes must match
        assert tuple(fake_image_and_labels.shape) == tuple(
            real_image_and_labels.shape)
        assert tuple(disc_fake_pred.shape) == tuple(disc_real_pred.shape)

        disc_fake_loss = criterion(disc_fake_pred,
                                   torch.zeros_like(disc_fake_pred))
        disc_real_loss = criterion(disc_real_pred,
                                   torch.ones_like(disc_real_pred))
        disc_loss = (disc_fake_loss + disc_real_loss) / 2
        disc_loss.backward(retain_graph=True)
        disc_opt.step()

        # Keep track of the average discriminator loss
        discriminator_losses += [disc_loss.item()]
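
The discriminator objective in isolation: fake predictions labelled 0, real labelled 1, losses averaged (a sketch):

import torch

criterion = torch.nn.BCEWithLogitsLoss()
disc_fake_pred = torch.randn(16, 1)
disc_real_pred = torch.randn(16, 1)
disc_fake_loss = criterion(disc_fake_pred, torch.zeros_like(disc_fake_pred))
disc_real_loss = criterion(disc_real_pred, torch.ones_like(disc_real_pred))
disc_loss = (disc_fake_loss + disc_real_loss) / 2
print(disc_loss)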
Ejemplo n.º 29
0
def test_retrieval(model, criterion, transforms_cuda, device, epoch, args):
    accuracy = [AverageMeter(),AverageMeter(),AverageMeter(),AverageMeter()]
    model.eval()
    
    def tr(x):
        B = x.size(0); assert B == 1
        test_sample = x.size(2)//(args.seq_len*args.num_seq)
        return transforms_cuda(x)\
        .view(3,test_sample,args.num_seq,args.seq_len,args.img_dim,args.img_dim).permute(1,2,0,3,4,5)

    with torch.no_grad():
        transform = transforms.Compose([
                    A.CenterCrop(size=(224,224)),
                    A.Scale(size=(args.img_dim,args.img_dim)),
                    A.ColorJitter(0.2, 0.2, 0.2, 0.1, p=0.3, consistent=True),
                    A.ToTensor()])

        if args.dataset == 'ucf101':
            d_class = UCF101LMDB
        elif args.dataset == 'ucf101-f':
            d_class = UCF101Flow_LMDB
        elif args.dataset == 'hmdb51':
            d_class = HMDB51LMDB
        elif args.dataset == 'hmdb51-f':
            d_class = HMDB51Flow_LMDB

        train_dataset = d_class(mode='train', 
                            transform=transform, 
                            num_frames=args.num_seq*args.seq_len,
                            ds=args.ds,
                            which_split=1,
                            return_label=True,
                            return_path=True)
        print('train dataset size: %d' % len(train_dataset))

        test_dataset = d_class(mode='test', 
                            transform=transform, 
                            num_frames=args.num_seq*args.seq_len,
                            ds=args.ds,
                            which_split=1,
                            return_label=True,
                            return_path=True)
        print('test dataset size: %d' % len(test_dataset))

        train_sampler = data.SequentialSampler(train_dataset)
        test_sampler = data.SequentialSampler(test_dataset)

        train_loader = data.DataLoader(train_dataset,
                                      batch_size=1,
                                      sampler=train_sampler,
                                      shuffle=False,
                                      num_workers=args.workers,
                                      pin_memory=True)
        test_loader = data.DataLoader(test_dataset,
                                      batch_size=1,
                                      sampler=test_sampler,
                                      shuffle=False,
                                      num_workers=args.workers,
                                      pin_memory=True)
        if args.dirname is None:
            dirname = 'feature'
        else:
            dirname = args.dirname

        if os.path.exists(os.path.join(os.path.dirname(args.test), dirname, '%s_test_feature.pth.tar' % args.dataset)): 
            test_feature = torch.load(os.path.join(os.path.dirname(args.test), dirname, '%s_test_feature.pth.tar' % args.dataset)).to(device)
            test_label = torch.load(os.path.join(os.path.dirname(args.test), dirname, '%s_test_label.pth.tar' % args.dataset)).to(device)
        else:
            os.makedirs(os.path.join(os.path.dirname(args.test), dirname), exist_ok=True)

            print('Computing test set feature ... ')
            test_feature = None
            test_label = []
            test_vname = []
            sample_id = 0 
            for idx, (input_seq, target) in tqdm(enumerate(test_loader), total=len(test_loader)):
                B = 1
                input_seq = input_seq.to(device, non_blocking=True)
                input_seq = tr(input_seq)
                current_target, vname = target
                current_target = current_target.to(device, non_blocking=True)

                test_sample = input_seq.size(0)
                input_seq = input_seq.squeeze(1)
                logit, feature = model(input_seq)
                if test_feature is None:
                    test_feature = torch.zeros(len(test_dataset), feature.size(-1), device=feature.device)

                test_feature[sample_id,:] = feature.mean(0)
                test_label.append(current_target)
                test_vname.append(vname)
                sample_id += 1

            print(test_feature.size())
            # test_feature = torch.stack(test_feature, dim=0)
            test_label = torch.cat(test_label, dim=0)
            torch.save(test_feature, os.path.join(os.path.dirname(args.test), dirname, '%s_test_feature.pth.tar' % args.dataset))
            torch.save(test_label, os.path.join(os.path.dirname(args.test), dirname, '%s_test_label.pth.tar' % args.dataset))
            with open(os.path.join(os.path.dirname(args.test), dirname, '%s_test_vname.pkl' % args.dataset), 'wb') as fp:
                pickle.dump(test_vname, fp)


        if os.path.exists(os.path.join(os.path.dirname(args.test), dirname, '%s_train_feature.pth.tar' % args.dataset)): 
            train_feature = torch.load(os.path.join(os.path.dirname(args.test), dirname, '%s_train_feature.pth.tar' % args.dataset)).to(device)
            train_label = torch.load(os.path.join(os.path.dirname(args.test), dirname, '%s_train_label.pth.tar' % args.dataset)).to(device)
        else:
            print('Computing train set feature ... ')
            train_feature = None
            train_label = []
            train_vname = []
            sample_id = 0
            for idx, (input_seq, target) in tqdm(enumerate(train_loader), total=len(train_loader)):
                input_seq = input_seq.to(device, non_blocking=True)
                input_seq = tr(input_seq)
                current_target, vname = target
                current_target = current_target.to(device, non_blocking=True)

                # same procedure as the test set: one averaged feature per video
                input_seq = input_seq.squeeze(1)
                logit, feature = model(input_seq)
                if train_feature is None:
                    train_feature = torch.zeros(len(train_dataset), feature.size(-1), device=feature.device)

                train_feature[sample_id, :] = feature.mean(0)
                train_label.append(current_target)
                train_vname.append(vname)
                sample_id += 1

            print(train_feature.size())
            train_label = torch.cat(train_label, dim=0)
            torch.save(train_feature, os.path.join(feature_dir, '%s_train_feature.pth.tar' % args.dataset))
            torch.save(train_label, os.path.join(feature_dir, '%s_train_label.pth.tar' % args.dataset))
            with open(os.path.join(feature_dir, '%s_train_vname.pkl' % args.dataset), 'wb') as fp:
                pickle.dump(train_vname, fp)

        ks = [1, 5, 10, 20, 50]
        NN_acc = []

        # center both feature banks
        test_feature = test_feature - test_feature.mean(dim=0, keepdim=True)
        train_feature = train_feature - train_feature.mean(dim=0, keepdim=True)

        # L2-normalize rows so the dot product below is cosine similarity
        test_feature = F.normalize(test_feature, p=2, dim=1)
        train_feature = F.normalize(train_feature, p=2, dim=1)

        # similarity matrix: <num_test> x <num_train>
        sim = test_feature.matmul(train_feature.t())

        torch.save(sim, os.path.join(feature_dir, '%s_sim.pth.tar' % args.dataset))

        for k in ks:
            topkval, topkidx = torch.topk(sim, k, dim=1)
            # a test video counts as a hit if any of its k nearest train videos shares its label
            acc = torch.any(train_label[topkidx] == test_label.unsqueeze(1), dim=1).float().mean().item()
            NN_acc.append(acc)
            print('%dNN acc = %.4f' % (k, acc))

        args.logger.log('NN-Retrieval on %s:' % args.dataset)
        for k,acc in zip(ks, NN_acc):
            args.logger.log('\t%dNN acc = %.4f' % (k, acc))

        with open(os.path.join(feature_dir, '%s_test_vname.pkl' % args.dataset), 'rb') as fp:
            test_vname = pickle.load(fp)

        with open(os.path.join(feature_dir, '%s_train_vname.pkl' % args.dataset), 'rb') as fp:
            train_vname = pickle.load(fp)

        sys.exit(0)
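
The retrieval evaluation above reduces to: center each feature bank, L2-normalize the rows, take a dot product (which then equals cosine similarity), and count a test sample as correct if any of its k nearest training neighbours shares its label. A minimal self-contained sketch with toy tensors (all shapes and labels below are made up for illustration):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
test_feature = torch.randn(4, 8)                 # 4 test samples, 8-dim features
train_feature = torch.randn(6, 8)                # 6 train samples
test_label = torch.tensor([0, 1, 2, 0])
train_label = torch.tensor([0, 0, 1, 1, 2, 2])

# center, then L2-normalize so the dot product is cosine similarity
test_feature = F.normalize(test_feature - test_feature.mean(dim=0, keepdim=True), dim=1)
train_feature = F.normalize(train_feature - train_feature.mean(dim=0, keepdim=True), dim=1)

sim = test_feature @ train_feature.t()           # (4, 6) similarity matrix

for k in [1, 3, 5]:
    _, topkidx = torch.topk(sim, k, dim=1)
    # hit if any of the k nearest train samples carries the test label
    acc = torch.any(train_label[topkidx] == test_label.unsqueeze(1), dim=1).float().mean().item()
    print('%dNN acc = %.4f' % (k, acc))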
Ejemplo n.º 30
0
    def infer(
        self,
        phonemes: torch.LongTensor,
        combine_strategy: str = "concat",
        max_len: int = 1024,
        stop_threshold: float = 0.25,
        verbose: bool = False,
        stop_at_stop_token: bool = True,
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        """
        Run inference loop to generate a new spectrogram

        :param phonemes: of shape (batch size, phonemes length)
        :param combine_strategy: "concat" to concatenate outputs from consecutive iterations,
            "replace" to replace previous output using output from current iteration
        :param max_len: maximum length of generated spectrogram
        :param stop_threshold: value in (-1, 1) above which sigmoid(stop_pred)
            is considered to indicate the end of the predicted sequence
        :param verbose: if true, prints progress info every 10 steps (useful on cpu)
        :param stop_at_stop_token: if true, stops a sample once sigmoid(stop_pred)
            exceeds stop_threshold; if false, generates until max_len is reached
        :return: tuple (spectrogram, stop_idx) where
            spectrogram shape = (batch size, num_mel_coeffs, spectrogram length)
            stop_idx shape = (batch size), contains end index for each spectrogram in batch
        """
        assert combine_strategy in {"concat", "replace"}
        assert -1. < stop_threshold < 1.

        batch_size = phonemes.shape[0]
        # keep working tensors on the same device as the input phonemes
        zeros = torch.zeros((batch_size, 1, self.num_mel_coeffs), device=phonemes.device)
        spectrogram = zeros.clone()
        stop = torch.zeros(batch_size, dtype=torch.long, device=phonemes.device)

        while not torch.all(stop > 0):
            iteration = spectrogram.shape[1]
            still_running = stop == 0
            if verbose and iteration % 10 == 0:
                number_of_running_samples = still_running.sum().item()
                print(f"reached {iteration=}, {number_of_running_samples=}...")

            _, generated, stop_pred, _ = self.forward(phonemes, spectrogram)
            stop_pred = stop_pred.view(
                -1, generated.shape[1])  # view as (batch, len)

            if combine_strategy == "concat":
                generated_slice = generated[:, -1, :].view(
                    batch_size, 1, self.num_mel_coeffs)
                spectrogram = torch.cat([spectrogram, generated_slice], dim=1)

                if stop_at_stop_token:
                    stops_now = torch.sigmoid(stop_pred[:, -1]) > stop_threshold
                    still_running_stops_now = still_running * stops_now
                    # record the (1-based) index at which each sample first stops
                    stop = torch.where(
                        still_running_stops_now,
                        (stops_now.to(dtype=torch.long) * iteration) + 1, stop)
            elif combine_strategy == "replace":
                spectrogram = torch.cat([zeros, generated], dim=1)

                if stop_at_stop_token:
                    stops_now = torch.any(
                        torch.sigmoid(stop_pred) > stop_threshold, dim=1)
                    still_running_stops_now = stops_now * still_running
                    stop = torch.where(still_running_stops_now,
                                       torch.argmax(stop_pred, dim=1) + 1,
                                       stop)

            if spectrogram.shape[1] > max_len:  # check the length dim, not the max over all dims
                if verbose:
                    print(f"stopped at {max_len=}")
                break

        # samples that never emitted a stop token are clamped to max_len
        stop_at_end = torch.ones_like(stop, dtype=torch.long) * max_len
        stop: torch.LongTensor = torch.where(stop == 0, stop_at_end, stop)
        spectrogram: torch.Tensor = spectrogram.transpose(1, 2)
        return spectrogram[:, :, 1:], stop
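
The trickiest part of infer is the per-sample stop bookkeeping: each batch element records the step at which it first crossed the stop threshold, and elements that never stop are clamped to max_len afterwards. A minimal sketch of just that pattern, with random logits standing in for a real model's stop predictions (everything below is a hypothetical stand-in):

import torch

batch_size, max_len, stop_threshold = 3, 20, 0.25
stop = torch.zeros(batch_size, dtype=torch.long)  # 0 means "still running"
length = 1

while not torch.all(stop > 0) and length < max_len:
    length += 1
    stop_pred = torch.randn(batch_size)           # dummy per-sample stop logits
    still_running = stop == 0
    stops_now = torch.sigmoid(stop_pred) > stop_threshold
    # record the current length only for samples stopping right now;
    # samples that already stopped keep their earlier index
    stop = torch.where(still_running & stops_now,
                       torch.full_like(stop, length), stop)

# samples that never emitted a stop are clamped to max_len
stop = torch.where(stop == 0, torch.full_like(stop, max_len), stop)
print(stop)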
Ejemplo n.º 31
0
  def forward(self, backbone_outputs, match_coords, is_train):
    # Run the network top-down over the feature pyramid.
    fpn_outputs = []
    targets = []
    classifications = []
    pyramid_output = None
    num_layers = len(backbone_outputs)
    if is_train:
      target_coords = [ME.utils.batched_coordinates([match[i][0] for match in match_coords])
                       for i in range(num_layers)]
      ambiguous_coords = [ME.utils.batched_coordinates([match[i][1] for match in match_coords])
                          for i in range(num_layers)]

    for layer_idx in reversed(range(num_layers)):
      conv_feat_layer = self.get_layer('conv_feat', layer_idx)
      conv_cls_layer = self.get_layer('conv_cls', layer_idx)
      conv_up_layer = self.get_layer('conv_up', layer_idx)

      # Current feature
      curr_feat = backbone_outputs[layer_idx]

      # Add previous layer output
      if pyramid_output is not None:
        assert pyramid_output.tensor_stride == curr_feat.tensor_stride
        curr_feat = curr_feat + pyramid_output

      # Two branches per level: (1) FPN feature & classification,
      # (2) prune & upsample to feed the next (finer) level.
      fpn_feat = conv_feat_layer(curr_feat)
      feat_cls = conv_cls_layer(fpn_feat)
      pred_prob = F.softmax(feat_cls.F, 1)[:, 1]

      # target calculation
      target = None
      if is_train:
        target = torch.zeros(len(fpn_feat), dtype=torch.long)
        # mark ambiguous coordinates first so true positives can overwrite them below
        pos_ins = utils.map_coordinates(fpn_feat, torch.cat(ambiguous_coords[:layer_idx + 1]),
                                        force_stride=True)[0]
        target[pos_ins] = self.config.ignore_label
        pos_ins = utils.map_coordinates(fpn_feat, torch.cat(target_coords[:layer_idx + 1]),
                                        force_stride=True)[0]
        target[pos_ins] = 1

      # Keep mask: positions confident enough to survive pruning
      keep = (pred_prob > self.config.sfpn_min_confidence).cpu()
      if is_train:  # always keep ground-truth positives
        keep |= target == 1

      if torch.any(keep):
        # Prune and upsample
        pyramid_output = conv_up_layer(self.pruning(curr_feat, keep))
        # Generate final feature for current level
        final_pruned = self.pruning(fpn_feat, keep)
      else:
        pyramid_output = None
        final_pruned = None

      # Insert at the front so outputs match the backbone's fine-to-coarse layer order
      classifications.insert(0, feat_cls)
      targets.insert(0, target)
      fpn_outputs.insert(0, final_pruned)

    return fpn_outputs, targets, classifications
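
Stripped of the MinkowskiEngine specifics, the forward pass above is a top-down pyramid walk: fuse the upsampled coarser level into the current one, classify, then prune low-confidence positions before upsampling to the next finer level. A dense-tensor analogy of that control flow (the layer modules, sizes, and confidence value below are hypothetical placeholders, and boolean masking stands in for sparse pruning):

import torch
import torch.nn as nn
import torch.nn.functional as F

num_layers, channels, min_confidence = 3, 16, 0.5
conv_feat = nn.ModuleList([nn.Conv2d(channels, channels, 3, padding=1) for _ in range(num_layers)])
conv_cls = nn.ModuleList([nn.Conv2d(channels, 2, 1) for _ in range(num_layers)])
conv_up = nn.ModuleList([nn.ConvTranspose2d(channels, channels, 2, stride=2) for _ in range(num_layers)])

# backbone outputs from fine (8x8) to coarse (2x2)
backbone_outputs = [torch.randn(1, channels, 8 // 2 ** i, 8 // 2 ** i) for i in range(num_layers)]

pyramid_output = None
classifications = []
for layer_idx in reversed(range(num_layers)):
    curr_feat = backbone_outputs[layer_idx]
    if pyramid_output is not None:
        curr_feat = curr_feat + pyramid_output    # fuse the coarser level

    fpn_feat = conv_feat[layer_idx](curr_feat)
    feat_cls = conv_cls[layer_idx](fpn_feat)
    pred_prob = F.softmax(feat_cls, dim=1)[:, 1]  # foreground probability

    # dense analogue of pruning: zero out low-confidence positions
    # before upsampling to the next (finer) level
    keep = (pred_prob > min_confidence).unsqueeze(1).float()
    pyramid_output = conv_up[layer_idx](curr_feat * keep)

    classifications.insert(0, feat_cls)           # keep backbone layer order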