def epoch_pass(
    loader,
    model,
    train=False,
    optimizer=None,
    scheduler=None,
    epoch=0,
    img_folder=None,
    fig=None,
    display_freq=10,
    epoch_display_freq=1,
    lr_decay_gamma=0,
    loader_nb=2,
    freeze_batchnorm=True,
    premodel=None,
):
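    """Run one epoch of `loader` through `premodel`.

    In train mode, the losses of `loader_nb` consecutive batches are summed
    before each optimizer step. Fully supervised ("data") and
    consistency-supervised ("consist") batches are tracked in separate meters
    and evaluators, and visualizations plus evaluation results are written
    under `img_folder`.
    """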
    if train:
        prefix = "train"
    else:
        prefix = "val"
    evaluators = {
        # "joints2d_trans": EvalUtil(),
        "joints2d_base": EvalUtil(),
        "corners2d_base": EvalUtil(),
        "verts2d_base": EvalUtil(),
        "joints3d_cent": EvalUtil(),
        "joints3d": EvalUtil(),
    }
    consist_evaluators = {
        # "joints2d_trans": EvalUtil(),
        "joints2d_base": EvalUtil(),
        "corners2d_base": EvalUtil(),
        "verts2d_base": EvalUtil(),
        "joints3d_cent": EvalUtil(),
        "joints3d": EvalUtil(),
    }
    avg_meters = AverageMeters()
    consist_avg_meters = AverageMeters()
    if train and not freeze_batchnorm:
        model.train()
    else:
        model.eval()
    for batch_idx, batch in enumerate(tqdm(loader)):
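        # Reset the loss buffer at the start of each group of loader_nb
        # consecutive batches; losses are accumulated across the group.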
        if batch_idx % loader_nb == 0:
            losses = []
        loss, all_losses, results, pair_results = premodel.forward(batch)
        losses.append(loss.flatten())
        if train and ((batch_idx % loader_nb) == (loader_nb - 1)):
            if torch.isnan(loss):
                raise ValueError(f"Loss made of {all_losses} became nan!")
            optimizer.zero_grad()
            loss = torch.stack(losses).sum()
            if loss.requires_grad:
                loss.backward()
                optimizer.step()
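        # Fully supervised batches ("data") feed the main meters and
        # evaluators; consistency-supervised batches ("consist") feed the
        # consist_* counterparts.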
        if "data" in batch["supervision"]:
            for loss_name, loss_val in all_losses.items():
                if loss_val is not None:
                    avg_meters.add_loss_value(loss_name, loss_val.mean().item())
            for sample, res in zip(batch["data"], results):
                evaluate.feed_avg_meters(avg_meters, sample, res)
                evaluate.feed_evaluators(evaluators, sample, res)
        elif "consist" in batch["supervision"]:
            for loss_name, loss_val in all_losses.items():
                if loss_val is not None:
                    consist_avg_meters.add_loss_value(loss_name, loss_val.mean().item())
            # Only compute metrics for weakly supervised frames (e.g. the first in the sequence)
            evaluate.feed_avg_meters(consist_avg_meters, batch["data"][0], results[0])
            evaluate.feed_evaluators(consist_evaluators, batch["data"][0], results[0])
        else:
            raise ValueError(f"Supervision {batch['supervision']} not in [data|consist]")
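        # Visualize the first loader_nb batches of every display_freq window,
        # on epochs selected by epoch_display_freq.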
        if (batch_idx % display_freq) < loader_nb and epoch % epoch_display_freq == 0:
            img_prefix = f"{prefix}_epoch{epoch:04d}_batch{batch_idx:06d}"
            save_img_path = os.path.join(img_folder, img_prefix)
            warpvis.sample_vis(batch, results, pair_results, fig=fig, save_img_prefix=save_img_path)
    if lr_decay_gamma and scheduler is not None:
        scheduler.step()
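    # Gather averaged losses and evaluator metrics into a nested
    # {name: {split: value}} dict, keeping "consist" entries separate from the
    # train/val prefix.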
    save_dict = {}
    for loss_name, avg_meter in avg_meters.average_meters.items():
        save_dict[loss_name] = {}
        loss_val = avg_meter.avg
        save_dict[loss_name][prefix] = loss_val
    for loss_name, avg_meter in consist_avg_meters.average_meters.items():
        if loss_name not in save_dict:
            save_dict[loss_name] = {}
        loss_val = avg_meter.avg
        save_dict[loss_name]["consist"] = loss_val
    evaluator_results = evaluate.parse_evaluators(evaluators)
    show_metrics = ["epe_mean"]  # "auc"
    for eval_name, eval_res in evaluator_results.items():
        for met in show_metrics:
            loss_name = f"{eval_name}_{met}"
            # Filter nans
            if eval_res[met] == eval_res[met]:
                save_dict[loss_name] = {}
                save_dict[loss_name][prefix] = eval_res[met]
    consist_evaluator_results = evaluate.parse_evaluators(consist_evaluators)
    for eval_name, eval_res in consist_evaluator_results.items():
        for met in show_metrics:
            loss_name = f"{eval_name}_{met}"
            # Filter nans
            if eval_res[met] == eval_res[met]:
                if loss_name not in save_dict:
                    save_dict[loss_name] = {}
                save_dict[loss_name]["consist"] = eval_res[met]
    img_filepath = f"{prefix}_epoch{epoch:04d}_eval.png"
    save_img_path = os.path.join(img_folder, img_filepath)
    # Filter out Nan pck curves
    evaluator_results = {
        eval_name: res for eval_name, res in evaluator_results.items() if res["epe_mean"] == res["epe_mean"]
    }
    evalvis.eval_vis(evaluator_results, save_img_path, fig=fig)
    pickle_path = save_img_path.replace(".png", ".pkl")
    with open(pickle_path, "wb") as p_f:
        pickle.dump(evaluator_results, p_f)
    return save_dict, avg_meters, evaluator_results
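A minimal driver sketch for the variant above, assuming hypothetical train_loader, val_loader, model, and premodel objects built elsewhere (premodel is assumed to be an nn.Module whose forward returns loss, all_losses, results, pair_results, as used above); the optimizer, scheduler, epoch count, and paths are illustrative, not taken from the original code.

import os

import matplotlib.pyplot as plt
import torch

num_epochs = 100  # illustrative value
optimizer = torch.optim.Adam(premodel.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)
fig = plt.figure(figsize=(10, 10))  # figure reused by the visualization calls
for folder in ("checkpoints/train_vis", "checkpoints/val_vis"):  # placeholder paths
    os.makedirs(folder, exist_ok=True)

for epoch in range(num_epochs):
    epoch_pass(
        train_loader,
        model,
        train=True,
        optimizer=optimizer,
        scheduler=scheduler,
        epoch=epoch,
        img_folder="checkpoints/train_vis",
        fig=fig,
        lr_decay_gamma=0.5,
        premodel=premodel,
    )
    # epoch_pass above does not disable gradient tracking in eval mode,
    # so wrap the validation pass in no_grad here.
    with torch.no_grad():
        epoch_pass(
            val_loader,
            model,
            train=False,
            epoch=epoch,
            img_folder="checkpoints/val_vis",
            fig=fig,
            premodel=premodel,
        )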
Example #2
def epoch_pass(
    loader,
    model,
    train=False,
    optimizer=None,
    scheduler=None,
    epoch=0,
    img_folder=None,
    fig=None,
    display_freq=10,
    epoch_display_freq=1,
    lr_decay_gamma=0,
    freeze_batchnorm=True,
):
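    """Run one epoch of `loader` through `model`.

    In train mode the loss is backpropagated and the optimizer is stepped on
    every batch; in eval mode the forward pass runs under torch.no_grad().
    Averaged losses and evaluator metrics are returned as a nested
    {name: {split: value}} dict, and visualizations and evaluation curves are
    saved under `img_folder`.
    """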
    if train:
        prefix = "train"
    else:
        prefix = "val"
    evaluators = {
        # "joints2d_trans": EvalUtil(),
        "joints2d_base": EvalUtil(),
        "corners2d_base": EvalUtil(),
        "verts2d_base": EvalUtil(),
        "joints3d_cent": EvalUtil(),
        "joints3d": EvalUtil(),
    }
    if train and not freeze_batchnorm:
        model.train()
    else:
        model.eval()
    avg_meters = AverageMeters()
    # Loop over dataset
    for batch_idx, batch in enumerate(tqdm(loader)):
        # Compute outputs and losses
        if train:
            loss, results, losses = model(batch)
        else:
            with torch.no_grad():
                loss, results, losses = model(batch)
        # Optimize model if needed
        if train:
            if torch.isnan(loss):
                raise ValueError(f"Loss made of {losses} became nan!")
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        for loss_name, loss_val in losses.items():
            if loss_val is not None:
                avg_meters.add_loss_value(loss_name, loss_val.mean().item())
        evaluate.feed_avg_meters(avg_meters, batch, results)

        # Visualize outputs
        if batch_idx % display_freq == 0 and epoch % epoch_display_freq == 0:
            img_filepath = f"{prefix}_epoch{epoch:04d}_batch{batch_idx:06d}.png"
            save_img_path = os.path.join(img_folder, img_filepath)
            samplevis.sample_vis(batch,
                                 results,
                                 fig=fig,
                                 save_img_path=save_img_path)
        evaluate.feed_evaluators(evaluators, batch, results)
    if lr_decay_gamma and scheduler is not None:
        scheduler.step()
    save_dict = {}
    for loss_name, avg_meter in avg_meters.average_meters.items():
        save_dict[loss_name] = {}
        loss_val = avg_meter.avg
        save_dict[loss_name][prefix] = loss_val
    evaluator_results = evaluate.parse_evaluators(evaluators)
    for eval_name, eval_res in evaluator_results.items():
        for met in ["epe_mean", "auc"]:
            loss_name = f"{eval_name}_{met}"
            # Filter nans
            if eval_res[met] == eval_res[met]:
                save_dict[loss_name] = {}
                save_dict[loss_name][prefix] = eval_res[met]
    img_filepath = f"{prefix}_epoch{epoch:04d}_eval.png"
    save_img_path = os.path.join(img_folder, img_filepath)
    # Filter out Nan pck curves
    evaluator_results = {
        eval_name: res
        for eval_name, res in evaluator_results.items()
        if res["epe_mean"] == res["epe_mean"]
    }
    evalvis.eval_vis(evaluator_results, save_img_path, fig=fig)
    pickle_path = save_img_path.replace(".png", ".pkl")
    with open(pickle_path, "wb") as p_f:
        pickle.dump(evaluator_results, p_f)
    return save_dict, avg_meters, evaluator_results
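Both variants return a save_dict with the same {metric: {split: value}} layout, so per-epoch train and val results can be merged and logged together. A small, self-contained sketch of that aggregation; merge_epoch_results and the toy values are illustrative and not part of the original code.

import json


def merge_epoch_results(*result_dicts):
    """Merge several {metric: {split: value}} dicts into one."""
    merged = {}
    for result_dict in result_dicts:
        for metric, split_vals in result_dict.items():
            merged.setdefault(metric, {}).update(split_vals)
    return merged


# Toy values shaped like the save_dict returned by epoch_pass
# (e.g. "joints3d_epe_mean" comes from eval_name "joints3d" + metric "epe_mean").
train_dict = {"total_loss": {"train": 0.12}, "joints3d_epe_mean": {"train": 14.2}}
val_dict = {"total_loss": {"val": 0.15}, "joints3d_epe_mean": {"val": 16.8}}
with open("epoch_0000_metrics.json", "w") as j_f:
    json.dump(merge_epoch_results(train_dict, val_dict), j_f, indent=2)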