def train_one_epoch(model, optim, loss_fn, loader, epoch, steps, device, writer,
                    global_i, writer_interval=200, normalize=None):
    model.train()
    status_col = TextColumn("")
    running_loss = 0
    lr = optim.param_groups[0]['lr']

    if normalize is not None:
        assert len(normalize) == 2, "mean and std values should be provided to use data normalization"
        # modified behavior - with input normalization
        fetcher = DataPrefetcher(loader, mean=normalize[0], std=normalize[1], device=device)
    else:
        # original behavior - no input normalization
        fetcher = DataPrefetcher(loader, mean=None, std=None, device=device)
    samples, targets = fetcher.next()

    with Progress("[progress.description]{task.description}",
                  "[{task.completed}/{task.total}]",
                  BarColumn(),
                  "[progress.percentage]{task.percentage:>3.0f}%",
                  TimeRemainingColumn(),
                  TextColumn("/"),
                  TimeElapsedColumn(),
                  status_col,
                  expand=False,
                  console=CONSOLE,
                  refresh_per_second=5) as progress:
        task = progress.add_task(description=f'[Epoch {epoch}]', total=steps)
        i = 0  # counter
        t_start = time.time()
        while samples is not None:
            # zero the parameter gradients
            optim.zero_grad()

            # forward + backward + optimize
            out = model(samples)
            loss = loss_fn(out, targets)
            loss.backward()
            optim.step()

            # collect running loss
            running_loss += loss.item()
            i += 1
            global_i += 1

            # update tensorboard
            if i % writer_interval == 0:
                writer.add_scalar('Loss/Train', running_loss / i, global_i)

            # pre-fetch next samples
            samples, targets = fetcher.next()

            # update progress bar
            if not progress.finished:
                status_col.text_format = f"Loss: {running_loss/i:.06f} " \
                                         f"speed: {(time.time() - t_start)/i:.4f}s/it " \
                                         f"lr: {lr}"
                progress.update(task, advance=1)

    return running_loss / i, global_i
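# The loops in this file depend only on this DataPrefetcher contract: next()
# returns a device-resident (samples, targets) pair, applies (x - mean) / std
# when normalization stats are given, and returns (None, None) once the loader
# is exhausted. Below is a minimal, synchronous sketch of that contract for
# reference; the name _PrefetcherSketch is illustrative, it assumes the loader
# yields (samples, targets) tensor pairs, and the real DataPrefetcher may
# additionally overlap host-to-device copies with compute on a side CUDA stream.
class _PrefetcherSketch:
    def __init__(self, loader, mean=None, std=None, device=None):
        self.it = iter(loader)
        self.mean, self.std, self.device = mean, std, device

    def next(self):
        try:
            samples, targets = next(self.it)
        except StopIteration:
            return None, None  # sentinel that ends the while-loops above/below
        samples = samples.to(self.device, non_blocking=True)
        targets = targets.to(self.device, non_blocking=True)
        if self.mean is not None and self.std is not None:
            samples = (samples - self.mean) / self.std  # input normalization
        return samples, targets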
def evaluate(model, loss_fn, loader, epoch, steps, device, normalize=None):
    model.eval()
    status_col = TextColumn("")
    running_loss = 0

    if normalize is not None:
        assert len(normalize) == 2, "mean and std values should be provided to use data normalization"
        # modified behavior - with input normalization
        fetcher = DataPrefetcher(loader, mean=normalize[0], std=normalize[1], device=device)
    else:
        # original behavior - no input normalization
        fetcher = DataPrefetcher(loader, mean=None, std=None, device=device)
    samples, targets = fetcher.next()

    with Progress("[progress.description]{task.description}",
                  "[{task.completed}/{task.total}]",
                  BarColumn(),
                  "[progress.percentage]{task.percentage:>3.0f}%",
                  TimeRemainingColumn(),
                  TextColumn("/"),
                  TimeElapsedColumn(),
                  status_col,
                  expand=False,
                  console=CONSOLE,
                  refresh_per_second=5) as progress:
        task = progress.add_task(description=f'[Eval {epoch}]', total=steps)
        i = 0  # counter
        t_start = time.time()
        with torch.no_grad():
            while samples is not None:
                # forward only
                out = model(samples)
                val_loss = loss_fn(out, targets)

                # collect running loss
                running_loss += val_loss.item()
                i += 1

                # pre-fetch next samples
                samples, targets = fetcher.next()

                if not progress.finished:
                    status_col.text_format = f"Val loss: {running_loss/i:.06f} " \
                                             f"speed: {(time.time() - t_start)/i:.4f}s/it"
                    progress.update(task, advance=1)

    return running_loss / i
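# A minimal sketch of how the two functions above might be driven together.
# fit() and its checkpoint layout are illustrative, not part of this repo:
# torch.save stands in for whatever checkpoint helper the training script
# actually uses, and the state_dict keys mirror those read back by
# test_on_model below ('model', 'epoch', 'global_i', 'val_loss').
import torch

def fit(model, optim, loss_fn, train_loader, val_loader, train_steps,
        val_steps, n_epochs, device, writer, normalize=None,
        ckpt_path='best.pth'):
    best_val_loss, global_i = float('inf'), 0
    for epoch in range(n_epochs):
        _, global_i = train_one_epoch(model, optim, loss_fn, train_loader,
                                      epoch, train_steps, device, writer,
                                      global_i, normalize=normalize)
        val_loss = evaluate(model, loss_fn, val_loader, epoch, val_steps,
                            device, normalize=normalize)
        writer.add_scalar('Loss/Val', val_loss, global_i)
        if val_loss < best_val_loss:  # keep only the best-validation weights
            best_val_loss = val_loss
            torch.save({'model': model.state_dict(), 'epoch': epoch,
                        'global_i': global_i, 'val_loss': best_val_loss},
                       ckpt_path)
    return best_val_loss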
def test_on_model(args):
    device = args.device
    if device == 'cpu':
        raise NotImplementedError("CPU inference is not implemented.")
    device = torch.device(args.device)
    torch.cuda.set_device(device)

    # build model
    model = build_model(args)
    model.to(device)

    # output dir
    p_out = Path(args.p_out).joinpath(f"{model.name}-{args.tensorboard_exp_name}")
    if not p_out.exists():
        p_out.mkdir(exist_ok=True, parents=True)

    # dataset & loader
    test_dataset = MTTDataset(path=args.p_data, split='test')
    test_loader = DataLoader(test_dataset,
                             batch_size=args.batch_size,
                             shuffle=True,
                             num_workers=args.n_workers,
                             pin_memory=True,
                             drop_last=False)  # not dropping last in testing
    test_steps = test_dataset.calc_steps(args.batch_size, drop_last=False)  # not dropping last in testing
    LOG.info(f"Total testing steps: {test_steps}")
    LOG.info(f"Testing data size: {len(test_dataset)}")

    # create loss
    loss_fn = get_loss(args.loss)

    # create metric
    metric = AUCMetric()

    # load checkpoint OR init state_dict
    if args.checkpoint is not None:
        state_dict = load_ckpt(args.checkpoint,
                               reset_epoch=args.ckpt_epoch,
                               no_scheduler=args.ckpt_no_scheduler,
                               no_optimizer=args.ckpt_no_optimizer,
                               no_loss_fn=args.ckpt_no_loss_fn,
                               map_values=args.ckpt_map_values)
        model_dict = {'model': model} if 'model' in state_dict else None
        apply_state_dict(state_dict, model=model_dict)
        best_val_loss = state_dict['val_loss']
        epoch = state_dict['epoch']
        global_i = state_dict['global_i']
        LOG.info(f"Checkpoint loaded. Epoch trained {epoch}, global_i {global_i}, "
                 f"best val {best_val_loss:.6f}")
    else:
        raise AssertionError("Pre-trained checkpoint must be provided.")

    # summary writer
    writer = SummaryWriter(log_dir=p_out.as_posix(), filename_suffix='-test')

    # start testing
    model.eval()
    sigmoid = Sigmoid().to(device)
    status_col = TextColumn("")
    running_loss = 0
    if args.data_normalization:
        fetcher = DataPrefetcher(test_loader, mean=MTT_MEAN, std=MTT_STD, device=device)
    else:
        fetcher = DataPrefetcher(test_loader, mean=None, std=None, device=device)
    samples, targets = fetcher.next()

    with Progress("[progress.description]{task.description}",
                  "[{task.completed}/{task.total}]",
                  BarColumn(),
                  "[progress.percentage]{task.percentage:>3.0f}%",
                  TimeRemainingColumn(),
                  TextColumn("/"),
                  TimeElapsedColumn(),
                  status_col,
                  expand=False,
                  console=CONSOLE,
                  refresh_per_second=5) as progress:
        task = progress.add_task(description='[Test]', total=test_steps)
        i = 0  # counter
        t_start = time.time()
        with torch.no_grad():
            while samples is not None:
                # forward model
                logits = model(samples)
                out = sigmoid(logits)
                test_loss = loss_fn(logits, targets)

                # collect running loss
                running_loss += test_loss.item()
                i += 1
                writer.add_scalar('Test/Loss', running_loss / i, i)

                # auc metric
                metric.step(targets.cpu().numpy(), out.cpu().numpy())

                # pre-fetch next samples
                samples, targets = fetcher.next()

                if not progress.finished:
                    status_col.text_format = f"Test loss: {running_loss/i:.06f}"
                    progress.update(task, advance=1)

    auc_tag, auc_sample, ap_tag, ap_sample = metric.auc_ap_score
    LOG.info(f"Testing speed: {(time.time() - t_start)/i:.4f}s/it, "
             f"auc_tag: {auc_tag:.04f}, "
             f"auc_sample: {auc_sample:.04f}, "
             f"ap_tag: {ap_tag:.04f}, "
             f"ap_sample: {ap_sample:.04f}")
    writer.close()
    return
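# The test loop assumes this AUCMetric interface: step() accumulates a batch
# of multi-label targets and sigmoid scores, and the auc_ap_score property
# returns (auc_tag, auc_sample, ap_tag, ap_sample), i.e. ROC-AUC and average
# precision averaged over tag columns and over samples. A minimal sketch of
# that contract via scikit-learn; the name _AUCMetricSketch is illustrative,
# and the repo's AUCMetric may be implemented differently.
import numpy as np
from sklearn.metrics import average_precision_score, roc_auc_score

class _AUCMetricSketch:
    def __init__(self):
        self._true, self._score = [], []

    def step(self, y_true, y_score):
        # y_true, y_score: (batch, n_tags) numpy arrays
        self._true.append(y_true)
        self._score.append(y_score)

    @property
    def auc_ap_score(self):
        y_true = np.concatenate(self._true)
        y_score = np.concatenate(self._score)
        # 'macro' averages the metric over tag columns, 'samples' over rows
        return (roc_auc_score(y_true, y_score, average='macro'),
                roc_auc_score(y_true, y_score, average='samples'),
                average_precision_score(y_true, y_score, average='macro'),
                average_precision_score(y_true, y_score, average='samples'))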