Example No. 1
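The snippets below are excerpted from a PyTorch training project and rely on project-specific helpers (DataPrefetcher, CONSOLE, LOG, MTTDataset, MTT_MEAN, MTT_STD, build_model, get_loss, AUCMetric, load_ckpt, apply_state_dict) that are assumed to come from the surrounding repository. For the library-level names, a minimal import sketch would look like this:

import time
from pathlib import Path

import torch
from torch.nn import Sigmoid
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from rich.progress import (Progress, TextColumn, BarColumn,
                           TimeRemainingColumn, TimeElapsedColumn)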
def train_one_epoch(model, optim, loss_fn, loader, epoch, steps, device,
                    writer, global_i, writer_interval=200, normalize=None):
    model.train()
    status_col = TextColumn("")
    running_loss = 0
    lr = optim.param_groups[0]['lr']

    if normalize is not None:
        assert len(normalize) == 2, "mean and std values should be provided to use data normalization"
        fetcher = DataPrefetcher(loader, mean=normalize[0], std=normalize[1], device=device)  # modified behavior - w/ input normalization
    else:
        fetcher = DataPrefetcher(loader, mean=None, std=None, device=device)                  # original behavior - no input normalization
    samples, targets = fetcher.next()  # first batch; returns (None, None) once the loader is exhausted

    with Progress("[progress.description]{task.description}",
                  "[{task.completed}/{task.total}]",
                  BarColumn(),
                  "[progress.percentage]{task.percentage:>3.0f}%",
                  TimeRemainingColumn(),
                  TextColumn("/"),
                  TimeElapsedColumn(),
                  status_col,
                  expand=False, console=CONSOLE, refresh_per_second=5) as progress:
        task = progress.add_task(description=f'[Epoch {epoch}]', total=steps)
        i = 0  # counter
        t_start = time.time()

        while samples is not None:
            # zero the parameter gradients
            optim.zero_grad()
            # forward + backward + optimize
            out = model(samples)
            loss = loss_fn(out, targets)
            loss.backward()
            optim.step()

            # collect running loss
            running_loss += loss.item()
            i += 1
            global_i += 1

            # update tensorboard
            if i % writer_interval == 0:
                writer.add_scalar('Loss/Train', running_loss/i, global_i)

            # pre-fetch next samples
            samples, targets = fetcher.next()

            # update trackbar
            if not progress.finished:
                status_col.text_format = f"Loss: {running_loss/i:.06f} " \
                                         f"speed: {(time.time() - t_start)/i:.4f}s/it " \
                                         f"lr: {lr}"
                progress.update(task, advance=1)

    return running_loss / i, global_i
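A hedged sketch of how train_one_epoch might be driven across epochs. The driver names (n_epochs, steps_per_epoch, train_loader, the log directory) and the StepLR scheduler are illustrative assumptions, not part of the original snippet; MTT_MEAN and MTT_STD are the normalization constants used in Example No. 3.

# illustrative driver loop (assumed names: n_epochs, steps_per_epoch, train_loader)
writer = SummaryWriter(log_dir='runs/example')
scheduler = torch.optim.lr_scheduler.StepLR(optim, step_size=10, gamma=0.1)
global_i = 0
for epoch in range(n_epochs):
    train_loss, global_i = train_one_epoch(
        model, optim, loss_fn, train_loader, epoch, steps_per_epoch,
        device, writer, global_i, normalize=(MTT_MEAN, MTT_STD))
    scheduler.step()  # train_one_epoch reads the current lr from optim.param_groups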
Example No. 2
def evaluate(model, loss_fn, loader, epoch, steps, device, normalize=None):
    model.eval()
    status_col = TextColumn("")
    running_loss = 0

    if normalize is not None:
        assert len(normalize) == 2, "mean and std values should be provided to use data normalization"
        fetcher = DataPrefetcher(loader, mean=normalize[0], std=normalize[1], device=device)  # modified behavior - w/ input normalization
    else:
        fetcher = DataPrefetcher(loader, mean=None, std=None, device=device)                  # original behavior - no input normalization
    samples, targets = fetcher.next()

    with Progress("[progress.description]{task.description}",
                  "[{task.completed}/{task.total}]",
                  BarColumn(),
                  "[progress.percentage]{task.percentage:>3.0f}%",
                  TimeRemainingColumn(),
                  TextColumn("/"),
                  TimeElapsedColumn(),
                  status_col,
                  expand=False, console=CONSOLE, refresh_per_second=5) as progress:
        task = progress.add_task(description=f'[Eval  {epoch}]', total=steps)
        i = 0  # counter
        t_start = time.time()

        with torch.no_grad():
            while samples is not None:

                # forward only
                out = model(samples)
                val_loss = loss_fn(out, targets)

                # collect running loss
                running_loss += val_loss.item()
                i += 1
                # pre-fetch next samples
                samples, targets = fetcher.next()

                if not progress.finished:
                    status_col.text_format = f"Val loss: {running_loss/i:.06f} " \
                                             f"speed: {(time.time() - t_start)/i:.4f}s/it"
                    progress.update(task, advance=1)
    return running_loss / i
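A hedged sketch of pairing evaluate with the training loop above to keep the best checkpoint. The checkpoint keys mirror the ones read back in Example No. 3 (val_loss, epoch, global_i), but the project's exact checkpoint format, and the val_loader/val_steps names, are assumptions.

# inside the epoch loop from the previous sketch; best_val_loss starts at float('inf')
val_loss = evaluate(model, loss_fn, val_loader, epoch, val_steps, device,
                    normalize=(MTT_MEAN, MTT_STD))
writer.add_scalar('Loss/Val', val_loss, global_i)
if val_loss < best_val_loss:
    best_val_loss = val_loss
    torch.save({'model': model.state_dict(),
                'optimizer': optim.state_dict(),
                'val_loss': best_val_loss,
                'epoch': epoch,
                'global_i': global_i},
               'checkpoint_best.pth')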
Example No. 3
def test_on_model(args):
    device = args.device
    if device == 'cpu':
        raise NotImplementedError("CPU training is not implemented.")
    device = torch.device(args.device)
    torch.cuda.set_device(device)

    # build model
    model = build_model(args)
    model.to(device)

    # output dir
    p_out = Path(
        args.p_out).joinpath(f"{model.name}-{args.tensorboard_exp_name}")
    if not p_out.exists():
        p_out.mkdir(exist_ok=True, parents=True)

    # dataset & loader
    test_dataset = MTTDataset(path=args.p_data, split='test')
    test_loader = DataLoader(test_dataset,
                             batch_size=args.batch_size,
                             shuffle=True,
                             num_workers=args.n_workers,
                             pin_memory=True,
                             drop_last=False)  # not dropping last in testing
    test_steps = test_dataset.calc_steps(
        args.batch_size, drop_last=False)  # not dropping last in testing
    LOG.info(f"Total testing steps: {test_steps}")
    LOG.info(f"Testing data size: {len(test_dataset)}")

    # create loss
    loss_fn = get_loss(args.loss)
    # create metric
    metric = AUCMetric()

    # load checkpoint OR init state_dict
    if args.checkpoint is not None:
        state_dict = load_ckpt(args.checkpoint,
                               reset_epoch=args.ckpt_epoch,
                               no_scheduler=args.ckpt_no_scheduler,
                               no_optimizer=args.ckpt_no_optimizer,
                               no_loss_fn=args.ckpt_no_loss_fn,
                               map_values=args.ckpt_map_values)
        model_dict = {'model': model} if 'model' in state_dict else None
        apply_state_dict(state_dict, model=model_dict)
        best_val_loss = state_dict['val_loss']
        epoch = state_dict['epoch']
        global_i = state_dict['global_i']
        LOG.info(
            f"Checkpoint loaded. Epoch trained {epoch}, global_i {global_i}, best val {best_val_loss:.6f}"
        )
    else:
        raise AssertionError("Pre-trained checkpoint must be provided.")

    # summary writer
    writer = SummaryWriter(log_dir=p_out.as_posix(), filename_suffix='-test')

    # start testing
    model.eval()
    sigmoid = Sigmoid().to(device)
    status_col = TextColumn("")
    running_loss = 0
    if args.data_normalization:
        fetcher = DataPrefetcher(test_loader,
                                 mean=MTT_MEAN,
                                 std=MTT_STD,
                                 device=device)
    else:
        fetcher = DataPrefetcher(test_loader,
                                 mean=None,
                                 std=None,
                                 device=device)
    samples, targets = fetcher.next()

    with Progress("[progress.description]{task.description}",
                  "[{task.completed}/{task.total}]",
                  BarColumn(),
                  "[progress.percentage]{task.percentage:>3.0f}%",
                  TimeRemainingColumn(),
                  TextColumn("/"),
                  TimeElapsedColumn(),
                  status_col,
                  expand=False,
                  console=CONSOLE,
                  refresh_per_second=5) as progress:
        task = progress.add_task(description='[Test]', total=test_steps)
        i = 0  # counter
        t_start = time.time()

        with torch.no_grad():
            while samples is not None:
                # forward model
                logits = model(samples)
                out = sigmoid(logits)
                test_loss = loss_fn(logits, targets)

                # collect running loss
                running_loss += test_loss.item()
                i += 1
                writer.add_scalar('Test/Loss', running_loss / i, i)

                # auc metric
                metric.step(targets.cpu().numpy(), out.cpu().numpy())

                # pre-fetch next samples
                samples, targets = fetcher.next()

                if not progress.finished:
                    status_col.text_format = f"Test loss: {running_loss/i:.06f}"
                    progress.update(task, advance=1)

    auc_tag, auc_sample, ap_tag, ap_sample = metric.auc_ap_score
    LOG.info(f"Testing speed: {(time.time() - t_start)/i:.4f}s/it, "
             f"auc_tag: {auc_tag:.04f}, "
             f"auc_sample: {auc_sample:.04f}, "
             f"ap_tag: {ap_tag:.04f}, "
             f"ap_sample: {ap_sample:.04f}")
    writer.close()
    return
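All three examples depend on the project's DataPrefetcher. A minimal sketch of what such a class typically looks like, assuming CUDA-stream prefetching with pinned-memory loaders and optional normalization; the project's actual implementation (tensor shapes, how mean/std are broadcast) may differ.

class DataPrefetcher:
    """Prefetch the next batch onto the GPU while the current one is being used."""

    def __init__(self, loader, mean=None, std=None, device=None):
        self.loader = iter(loader)
        self.device = device
        self.stream = torch.cuda.Stream(device=device)
        # assumed: mean/std are scalars or 1-D sequences broadcastable over the input
        self.mean = None if mean is None else torch.as_tensor(mean, device=device)
        self.std = None if std is None else torch.as_tensor(std, device=device)
        self._preload()

    def _preload(self):
        try:
            self.next_samples, self.next_targets = next(self.loader)
        except StopIteration:
            self.next_samples, self.next_targets = None, None
            return
        with torch.cuda.stream(self.stream):
            # non_blocking copies overlap with compute because the loaders use pin_memory=True
            self.next_samples = self.next_samples.to(self.device, non_blocking=True)
            self.next_targets = self.next_targets.to(self.device, non_blocking=True)
            if self.mean is not None:
                self.next_samples = (self.next_samples - self.mean) / self.std

    def next(self):
        # wait for the async copy, hand out the batch, then start loading the next one
        torch.cuda.current_stream().wait_stream(self.stream)
        samples, targets = self.next_samples, self.next_targets
        self._preload()
        return samples, targets

Returning (None, None) once the loader is exhausted is what lets the `while samples is not None` loops in the examples terminate cleanly.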