Example #1
def do_validation_pass(epoch, model, tel, loader):
    vis_images = None

    model.eval()
    with torch.no_grad():
        for batch in progress_iter(loader, 'Validation'):
            # Move the batch tensors onto the configured device
            in_var = batch['input'].to(global_opts['device'], torch.float32)
            target_var = batch['target'].to(global_opts['device'],
                                            torch.float32)
            mask_var = batch['joint_mask'].to(global_opts['device'],
                                              torch.float32)

            # Calculate predictions and loss
            out_var = model(in_var)
            loss = forward_loss(model, out_var, target_var, mask_var,
                                batch['valid_depth'])
            tel['val_loss'].add(loss.sum().item())

            calculate_performance_metrics(
                batch, loader.dataset,
                ensure_homogeneous(out_var.to(CPU, torch.float64).detach(),
                                   d=3), tel['val_mpjpe'], tel['val_pck'])

            # Keep predictions from the first batch for visualisation
            if vis_images is None:
                preds = out_var.to(CPU, torch.float64).detach()
                vis_images = visualise_predictions(preds, batch,
                                                   loader.dataset)

    # Publish up to eight example visualisations for this epoch
    tel['val_examples'].set_value(vis_images[:8])
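
For reference, tel only needs to map metric names to meter-like objects exposing add() and set_value(); the timer and generator_timer helpers presumably feed elapsed times through the same interface, and calculate_performance_metrics is assumed to call add() on the meters it receives. A minimal stand-in under those assumptions (SimpleMeter and the defaultdict construction are hypothetical, not part of the original code) might look like this:

import collections


class SimpleMeter:
    """Assumed stand-in for the meter objects stored in tel."""

    def __init__(self):
        self.total = 0.0
        self.count = 0
        self.value = None

    def add(self, value):
        # Accumulate scalar observations (per-batch losses, timings, ...).
        self.total += value
        self.count += 1

    def set_value(self, value):
        # Store an arbitrary payload (here: visualisation image grids).
        self.value = value

    def mean(self):
        return self.total / max(self.count, 1)


# Any metric name looked up in tel gets its own meter on first use.
tel = collections.defaultdict(SimpleMeter)
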
Example #2
def do_training_pass(epoch, model, tel, loader, scheduler, on_progress):
    # Apply any per-epoch learning rate schedule
    if hasattr(scheduler, 'step'):
        scheduler.step(epoch)
    optimiser = scheduler.optimizer

    vis_images = None
    samples_processed = 0

    model.train()
    for batch in generator_timer(progress_iter(loader, 'Training'),
                                 tel['data_load_time']):
        # Apply any per-batch learning rate schedule
        if hasattr(scheduler, 'batch_step'):
            scheduler.batch_step()

        with timer(tel['data_transfer_time']):
            in_var = batch['input'].to(global_opts['device'], torch.float32)
            target_var = batch['target'].to(global_opts['device'],
                                            torch.float32)
            mask_var = batch['joint_mask'].to(global_opts['device'],
                                              torch.float32)

        # Calculate predictions and loss
        with timer(tel['forward_time']):
            out_var = model(in_var)
            loss = forward_loss(model, out_var, target_var, mask_var,
                                batch['valid_depth'])
            tel['train_loss'].add(loss.sum().item())

        # Calculate accuracy metrics
        with timer(tel['eval_time']):
            calculate_performance_metrics(
                batch, loader.dataset,
                ensure_homogeneous(out_var.to(CPU, torch.float64).detach(),
                                   d=3), tel['train_mpjpe'], tel['train_pck'])

        # Calculate gradients
        with timer(tel['backward_time']):
            optimiser.zero_grad()
            loss.backward()

        # Update parameters
        with timer(tel['optim_time']):
            optimiser.step()

        # Update progress
        samples_processed += len(batch['input'])
        on_progress(samples_processed)

        # Keep predictions from the first batch for visualisation
        if vis_images is None:
            preds = out_var.to(CPU, torch.float64).detach()
            vis_images = visualise_predictions(preds, batch, loader.dataset)

    # Publish up to eight example visualisations for this epoch
    tel['train_examples'].set_value(vis_images[:8])
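
Both passes are typically driven by an outer epoch loop. The following is an illustrative sketch only, assuming standard PyTorch data loaders and a learning rate scheduler that wraps the optimiser; run_training and the progress print are hypothetical names, not part of the original code:

def run_training(model, tel, train_loader, val_loader, scheduler, num_epochs):
    # Alternate one training pass and one validation pass per epoch.
    for epoch in range(num_epochs):
        def on_progress(samples_processed):
            # Hypothetical progress hook; a real one might drive a progress bar.
            print('epoch {}: {} samples processed'.format(epoch, samples_processed))

        do_training_pass(epoch, model, tel, train_loader, scheduler, on_progress)
        do_validation_pass(epoch, model, tel, val_loader)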