def train_epoch(self, model, optimizer, epoch, num_epochs, batch_size=256,
                num_workers=0, aum_calculator=None, aum_wtr=False,
                rand_weight=False):
    """Run one training epoch, optionally re-weighting per-sample losses.

    :param model: network to train
    :param optimizer: optimizer updating ``model``'s parameters
    :param epoch: zero-based index of the current epoch
    :param num_epochs: total number of epochs (progress-bar display only)
    :param batch_size: training mini-batch size
    :param num_workers: DataLoader worker processes
    :param aum_calculator: if supplied, AUM statistics are updated each batch,
        but only after the first epoch (epoch 0 is skipped because logits are
        dominated by the random initialization)
    :param aum_wtr: falsy for uniform weighting; otherwise a directory path
        (or comma-separated string / list of paths) containing
        ``aum_details.csv`` files whose AUM_WTR scores select samples to
        remove (weight zero)
    :param rand_weight: use rectified-normal random sample weights instead
    :return: result object with the epoch-average ``error`` and ``loss``
    """
    stats = ["error", "loss"]
    meters = [util.AverageMeter() for _ in stats]
    result_class = util.result_class(stats)

    # Weighting - set up from GMM
    # NOTE: This is only used when removing indicator samples
    if aum_wtr:
        counts = torch.zeros(len(self.train_set))
        bad_probs = torch.zeros(len(self.train_set))
        if isinstance(aum_wtr, str):
            aum_wtr = aum_wtr.split(",")
        for sub_aum_wtr in aum_wtr:
            aums_path = os.path.join(sub_aum_wtr, "aum_details.csv")
            if not os.path.exists(aums_path):
                # Lazily materialize the CSV if this run hasn't produced it.
                self.generate_aum_details(load=sub_aum_wtr)
            aums_data = pd.read_csv(aums_path).drop(
                ["True Target", "Observed Target", "Label Flipped"], axis=1)
            # Indicator samples are masked out of both the observation count
            # and the accumulated bad-sample probability.
            counts += torch.tensor(
                ~aums_data["Is Indicator Sample"].values).float()
            bad_probs += torch.tensor(
                aums_data["AUM_WTR"].values
                * ~aums_data["Is Indicator Sample"].values).float()
        # Guard against division by zero for samples never counted above.
        counts.clamp_min_(1)
        # ceil() hardens the average into a 0/1 mask: any nonzero bad
        # probability removes the sample entirely.
        good_probs = (1 - bad_probs / counts).to(
            next(model.parameters()).dtype).ceil()
        if torch.cuda.is_available():
            good_probs = good_probs.cuda()
        logging.info("AUM WTR Score")
        logging.info(
            f"(Num samples removed: {good_probs.ne(1.).sum().item()})")
    elif rand_weight:
        logging.info("Rectified Normal Random Weighting")
    else:
        logging.info("Standard weighting")

    # Setup loader
    train_set = self.train_set
    loader = tqdm.tqdm(torch.utils.data.DataLoader(
        train_set, batch_size=batch_size, shuffle=True,
        num_workers=num_workers),
        desc=f"Train (Epoch {epoch + 1}/{num_epochs})")

    # Model on train mode
    model.train()
    for inputs, targets, indices in loader:
        optimizer.zero_grad()

        # Get types right
        if torch.cuda.is_available():
            inputs = inputs.cuda()
            targets = targets.cuda()

        # Compute output and losses
        outputs = model(inputs)
        losses = self.loss_func(outputs, targets, reduction="none")
        preds = outputs.argmax(dim=-1)

        # Compute loss weights
        if aum_wtr:
            weights = good_probs[indices.to(good_probs.device)]
            # Clamp the normalizer so a batch whose samples were all
            # removed yields zero weights rather than NaNs (consistent
            # with the rand_weight branch below).
            weights = weights.div(weights.sum().clamp_min(1e-10))
        elif rand_weight:
            # Rectified normal: negative draws are clamped to zero.
            weights = torch.randn(
                targets.size(), dtype=outputs.dtype,
                device=outputs.device).clamp_min_(0)
            weights = weights.div(weights.sum().clamp_min_(1e-10))
        else:
            # Uniform weighting, equivalent to a plain mean of the losses.
            weights = torch.ones(
                targets.size(), dtype=outputs.dtype,
                device=outputs.device).div_(targets.numel())

        # Backward through model
        loss = torch.dot(weights, losses)
        error = torch.ne(targets, preds).float().mean()
        loss.backward()

        # Update the model
        optimizer.step()

        # Update AUM values (after the first epoch due to variability of
        # random initialization)
        if aum_calculator and epoch > 0:
            # half().float() round-trip deliberately quantizes logits to
            # fp16 precision before recording.
            aum_calculator.update(
                logits=outputs.detach().cpu().half().float(),
                targets=targets.detach().cpu(),
                sample_ids=indices.tolist())

        # measure and record stats; use a separate local so the
        # batch_size parameter is not shadowed
        num_samples = outputs.size(0)
        stat_vals = [error.item(), loss.item()]
        for stat_val, meter in zip(stat_vals, meters):
            meter.update(stat_val, num_samples)

        # log stats
        res = dict(
            (name, f"{meter.val:.3f} ({meter.avg:.3f})")
            for name, meter in zip(stats, meters))
        loader.set_postfix(**res)

    # Return summary statistics
    return result_class(*[meter.avg for meter in meters])
def test(self, model=None, split="test", batch_size=512, dataset=None,
         epoch=None, num_workers=0):
    """Evaluate a model on a dataset split and dump per-sample results.

    :param model: network to evaluate (default: ``self.model``)
    :param split: name of the split attribute (looks up ``self.{split}_set``)
        when ``dataset`` is not given; also names the output CSV
    :param batch_size: evaluation mini-batch size
    :param dataset: explicit dataset overriding the split lookup
    :param epoch: unused; accepted for call-site compatibility
    :param num_workers: DataLoader worker processes
    :return: result object with average ``error``, ``top5_error``, ``loss``

    Side effect: writes ``results_{split}.csv`` into ``self.savedir`` with
    the per-sample loss, top-1 prediction, its raw output score, and label.
    """
    stats = ['error', 'top5_error', 'loss']
    meters = [util.AverageMeter() for _ in stats]
    result_class = util.result_class(stats)

    # Get model
    if model is None:
        model = self.model

    # Model on cuda
    if torch.cuda.is_available():
        model = model.cuda()
    if torch.cuda.is_available() and torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model).cuda()

    # Get dataset/loader
    if dataset is None:
        try:
            dataset = getattr(self, f"{split}_set")
        except Exception:
            raise ValueError(f"Invalid split '{split}'")
    loader = tqdm.tqdm(torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=False,
        num_workers=num_workers),
        desc=split.title())

    # For storing results
    all_losses = []
    all_confs = []
    all_preds = []
    all_targets = []

    # Model on eval mode
    model.eval()
    with torch.no_grad():
        for inputs, targets, indices in loader:
            # Get types right
            if torch.cuda.is_available():
                inputs = inputs.cuda()
                targets = targets.cuda()

            # Calculate loss
            outputs = model(inputs)
            losses = self.loss_func(outputs, targets, reduction="none")
            # Clamp k so models with fewer than 5 classes don't crash;
            # "top5_error" then degrades to top-k over all classes.
            k = min(5, outputs.size(-1))
            confs, preds = outputs.topk(k, dim=-1, largest=True, sorted=True)
            is_correct = preds.eq(targets.unsqueeze(-1)).float()
            loss = losses.mean()
            error = 1 - is_correct[:, 0].mean()
            top5_error = 1 - is_correct.sum(dim=-1).mean()

            # measure and record stats
            batch_size = inputs.size(0)
            stat_vals = [error.item(), top5_error.item(), loss.item()]
            for stat_val, meter in zip(stat_vals, meters):
                meter.update(stat_val, batch_size)

            # Record losses
            all_losses.append(losses.cpu())
            all_confs.append(confs[:, 0].cpu())
            all_preds.append(preds[:, 0].cpu())
            all_targets.append(targets.cpu())

            # log stats
            res = dict(
                (name, f"{meter.val:.3f} ({meter.avg:.3f})")
                for name, meter in zip(stats, meters))
            loader.set_postfix(**res)

    # Save the outputs
    pd.DataFrame({
        "Loss": torch.cat(all_losses).numpy(),
        "Prediction": torch.cat(all_preds).numpy(),
        "Confidence": torch.cat(all_confs).numpy(),
        "Label": torch.cat(all_targets).numpy(),
    }).to_csv(os.path.join(self.savedir, f"results_{split}.csv"),
              index_label="index")

    # Return summary statistics and outputs
    return result_class(*[meter.avg for meter in meters])