Example 1
0
File: runner.py Project: pgsrv/aum
    def train_epoch(self,
                    model,
                    optimizer,
                    epoch,
                    num_epochs,
                    batch_size=256,
                    num_workers=0,
                    aum_calculator=None,
                    aum_wtr=False,
                    rand_weight=False):
        """Run one training epoch and return the epoch-average stats.

        :param model: network to train (assumed already on the right device)
        :param optimizer: optimizer updating ``model``'s parameters
        :param epoch: zero-based index of the current epoch
        :param num_epochs: total number of epochs (progress-bar display only)
        :param batch_size: minibatch size for the training DataLoader
        :param num_workers: number of DataLoader worker processes
        :param aum_calculator: optional AUM tracker; updated after epoch 0
            (skipped on epoch 0 due to random-init variability)
        :param aum_wtr: falsy for standard weighting; otherwise a directory
            path (or comma-separated string / list of directories) containing
            ``aum_details.csv``, used to zero-out flagged training samples
        :param rand_weight: if True, use rectified-normal random weights
        :return: ``result_class(error, loss)`` with the epoch averages
        """
        stats = ["error", "loss"]
        meters = [util.AverageMeter() for _ in stats]
        result_class = util.result_class(stats)

        # Weighting - set up from GMM
        # NOTE: This is only used when removing indicator samples
        # TODO: some of this probably needs to be changed?
        if aum_wtr:
            counts = torch.zeros(len(self.train_set))
            bad_probs = torch.zeros(len(self.train_set))
            if isinstance(aum_wtr, str):
                aum_wtr = aum_wtr.split(",")
            for sub_aum_wtr in aum_wtr:
                aums_path = os.path.join(sub_aum_wtr, "aum_details.csv")
                if not os.path.exists(aums_path):
                    # Lazily generate the CSV if it does not exist yet
                    self.generate_aum_details(load=sub_aum_wtr)
                aums_data = pd.read_csv(aums_path).drop(
                    ["True Target", "Observed Target", "Label Flipped"],
                    axis=1)
                # Only non-indicator samples contribute to counts/probs
                counts += torch.tensor(
                    ~aums_data["Is Indicator Sample"].values).float()
                bad_probs += torch.tensor(
                    aums_data["AUM_WTR"].values *
                    ~aums_data["Is Indicator Sample"].values).float()
            # Avoid 0/0 for samples that are indicators in every run
            counts.clamp_min_(1)
            # ceil() turns the soft scores into hard 0/1 keep-indicators
            good_probs = (1 - bad_probs / counts).to(
                next(model.parameters()).dtype).ceil()
            if torch.cuda.is_available():
                good_probs = good_probs.cuda()
            logging.info("AUM WTR Score")
            logging.info(
                f"(Num samples removed: {good_probs.ne(1.).sum().item()})")
        elif rand_weight:
            logging.info("Rectified Normal Random Weighting")
        else:
            logging.info("Standard weighting")

        # Setup loader
        train_set = self.train_set
        loader = tqdm.tqdm(torch.utils.data.DataLoader(
            train_set,
            batch_size=batch_size,
            shuffle=True,
            num_workers=num_workers),
                           desc=f"Train (Epoch {epoch + 1}/{num_epochs})")

        # Model on train mode
        model.train()
        for inputs, targets, indices in loader:
            optimizer.zero_grad()

            # Get types right
            if torch.cuda.is_available():
                inputs = inputs.cuda()
                targets = targets.cuda()

            # Compute output and per-sample (unreduced) losses
            outputs = model(inputs)
            losses = self.loss_func(outputs, targets, reduction="none")
            preds = outputs.argmax(dim=-1)

            # Compute loss weights (normalized to sum to 1 over the batch)
            if aum_wtr:
                weights = good_probs[indices.to(good_probs.device)]
                # clamp guards against an all-removed batch (sum == 0),
                # matching the rand_weight branch below
                weights = weights.div(weights.sum().clamp_min_(1e-10))
            elif rand_weight:
                weights = torch.randn(targets.size(),
                                      dtype=outputs.dtype,
                                      device=outputs.device).clamp_min_(0)
                weights = weights.div(weights.sum().clamp_min_(1e-10))
            else:
                weights = torch.ones(targets.size(),
                                     dtype=outputs.dtype,
                                     device=outputs.device).div_(
                                         targets.numel())

            # Backward through model
            loss = torch.dot(weights, losses)
            error = torch.ne(targets, preds).float().mean()
            loss.backward()

            # Update the model
            optimizer.step()

            # Update AUM values (after the first epoch due to variability of random initialization)
            if aum_calculator and epoch > 0:
                aum_calculator.update(
                    logits=outputs.detach().cpu().half().float(),
                    targets=targets.detach().cpu(),
                    sample_ids=indices.tolist())

            # measure and record stats (weighted by actual batch length,
            # which may be smaller than batch_size for the last batch)
            num_samples = outputs.size(0)
            stat_vals = [error.item(), loss.item()]
            for stat_val, meter in zip(stat_vals, meters):
                meter.update(stat_val, num_samples)

            # log stats
            res = dict((name, f"{meter.val:.3f} ({meter.avg:.3f})")
                       for name, meter in zip(stats, meters))
            loader.set_postfix(**res)

        # Return summary statistics
        return result_class(*[meter.avg for meter in meters])
Example 2
0
File: runner.py Project: pgsrv/aum
    def test(self,
             model=None,
             split="test",
             batch_size=512,
             dataset=None,
             epoch=None,
             num_workers=0):
        """Evaluate ``model`` on a dataset split and save per-sample results.

        Writes ``results_<split>.csv`` (loss, top-1 prediction, confidence,
        label per sample) into ``self.savedir``.

        :param model: model to evaluate; defaults to ``self.model`` (moved to
            CUDA / wrapped in DataParallel when available)
        :param split: which ``<split>_set`` attribute to use when ``dataset``
            is not given
        :param batch_size: minibatch size for the evaluation DataLoader
        :param dataset: explicit dataset overriding ``split`` lookup
        :param epoch: unused; kept for call-site compatibility
        :param num_workers: number of DataLoader worker processes
        :return: ``result_class(error, top5_error, loss)`` with averages
        :raises ValueError: if ``split`` names no ``<split>_set`` attribute
        """
        stats = ['error', 'top5_error', 'loss']
        meters = [util.AverageMeter() for _ in stats]
        result_class = util.result_class(stats)

        # Get model
        if model is None:
            model = self.model
            # Model on cuda
            if torch.cuda.is_available():
                model = model.cuda()
                if torch.cuda.is_available() and torch.cuda.device_count() > 1:
                    model = torch.nn.DataParallel(model).cuda()

        # Get dataset/loader
        if dataset is None:
            try:
                dataset = getattr(self, f"{split}_set")
            except AttributeError:
                # getattr only raises AttributeError here; keep the catch narrow
                raise ValueError(f"Invalid split '{split}'")
        loader = tqdm.tqdm(torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers),
                           desc=split.title())

        # For storing results
        all_losses = []
        all_confs = []
        all_preds = []
        all_targets = []

        # Model on eval mode (no grad tracking during evaluation)
        model.eval()
        with torch.no_grad():
            for inputs, targets, indices in loader:
                # Get types right
                if torch.cuda.is_available():
                    inputs = inputs.cuda()
                    targets = targets.cuda()

                # Calculate per-sample losses and top-5 predictions
                outputs = model(inputs)
                losses = self.loss_func(outputs, targets, reduction="none")
                confs, preds = outputs.topk(5,
                                            dim=-1,
                                            largest=True,
                                            sorted=True)
                is_correct = preds.eq(targets.unsqueeze(-1)).float()
                loss = losses.mean()
                # top-1 error from the highest-ranked prediction; top-5
                # error counts a hit anywhere in the top 5
                error = 1 - is_correct[:, 0].mean()
                top5_error = 1 - is_correct.sum(dim=-1).mean()

                # measure and record stats
                batch_size = inputs.size(0)
                stat_vals = [error.item(), top5_error.item(), loss.item()]
                for stat_val, meter in zip(stat_vals, meters):
                    meter.update(stat_val, batch_size)

                # Record losses (top-1 confidence/prediction only)
                all_losses.append(losses.cpu())
                all_confs.append(confs[:, 0].cpu())
                all_preds.append(preds[:, 0].cpu())
                all_targets.append(targets.cpu())

                # log stats
                res = dict((name, f"{meter.val:.3f} ({meter.avg:.3f})")
                           for name, meter in zip(stats, meters))
                loader.set_postfix(**res)

        # Save the outputs
        pd.DataFrame({
            "Loss": torch.cat(all_losses).numpy(),
            "Prediction": torch.cat(all_preds).numpy(),
            "Confidence": torch.cat(all_confs).numpy(),
            "Label": torch.cat(all_targets).numpy(),
        }).to_csv(os.path.join(self.savedir, f"results_{split}.csv"),
                  index_label="index")

        # Return summary statistics and outputs
        return result_class(*[meter.avg for meter in meters])