示例#1
0
def do_validation_step(model, input, target, target_weight=None, flip=False):
    """Evaluate ``model`` on one batch and return its heatmaps and loss.

    Args:
        model: Network in evaluation mode. Calling it must return a sequence of
            outputs; every element is scored against ``target`` with
            ``joints_mse_loss`` and the last element is used as the prediction.
        input: Batch of inputs; must have the same length as ``target``.
        target: Ground truth passed to ``joints_mse_loss``.
        target_weight: Optional weights forwarded to ``joints_mse_loss``.
        flip: If true, also run the model on a horizontally flipped copy of
            ``input`` (via ``fliplr``), pass that prediction through
            ``flip_back`` (presumably mapping it back to the original
            orientation) and average it with the plain prediction.

    Returns:
        ``(heatmaps, loss)``: the last model output moved to the CPU
        (flip-averaged when ``flip`` is true) and the summed loss as a Python
        number.
    """
    assert not model.training, 'model must be in evaluation mode.'
    assert len(input) == len(
        target), 'input and target must contain the same number of examples.'

    # Forward pass and loss calculation.
    output = model(input)
    loss = sum(joints_mse_loss(o, target, target_weight) for o in output)

    # Get the heatmaps.
    heatmaps = output[-1].cpu()
    if flip:
        # If `flip` is true, perform horizontally flipped inference as well. This should
        # result in more robust predictions at the expense of additional compute.
        # Clone first: for a CPU tensor `.numpy()` shares memory with `input`, so
        # this keeps `fliplr` from touching the caller's batch.
        flip_input = fliplr(input.clone().cpu().numpy())
        # Send the flipped batch to the same device as the original batch, which
        # the model has just accepted.
        flip_input = torch.as_tensor(flip_input,
                                     dtype=torch.float32,
                                     device=input.device)
        flip_output = model(flip_input)[-1].cpu()
        flip_output = flip_back(flip_output)
        heatmaps = (heatmaps + flip_output) / 2

    return heatmaps, loss.item()
def do_validation_step(model,
                       input,
                       target,
                       data_info,
                       target_weight=None,
                       flip=False):
    """Evaluate ``model`` on one batch, timing the forward pass.

    Args:
        model: Network in evaluation mode. Calling it must return a sequence of
            outputs; every element is scored against ``target`` with
            ``joints_mse_loss`` and the last element is used as the prediction.
        input: Batch of inputs; must have the same length as ``target``.
        target: Ground truth passed to ``joints_mse_loss``.
        data_info: Object whose ``hflip_indices`` attribute is passed to
            ``flip_back`` when ``flip`` is true (presumably the joint
            permutation for a horizontal flip -- confirm against the caller).
        target_weight: Optional weights forwarded to ``joints_mse_loss``.
        flip: If true, also run the model on a horizontally flipped copy of
            ``input`` and average the un-flipped prediction with the plain one.

    Returns:
        ``(heatmaps, loss, inference_time)``: the last model output moved to
        the CPU (flip-averaged when ``flip`` is true), the summed loss as a
        Python number, and the elapsed time of the first (unflipped) forward
        pass in milliseconds. The flipped forward pass is not included in
        ``inference_time``.
    """
    import torch  # local import keeps this block self-contained

    assert not model.training, 'model must be in evaluation mode.'
    assert len(input) == len(
        target), 'input and target must contain the same number of examples.'

    # CUDA kernels are launched asynchronously; without synchronising, the
    # timer would mostly measure launch overhead instead of the forward pass.
    # The first sync also keeps earlier queued work (e.g. the host-to-device
    # copy of `input`) out of the measurement.
    sync_cuda = getattr(input, 'is_cuda', False)

    # Forward pass and loss calculation.
    if sync_cuda:
        torch.cuda.synchronize(input.device)
    # perf_counter is monotonic and high resolution, unlike time.time().
    start = time.perf_counter()
    output = model(input)
    if sync_cuda:
        torch.cuda.synchronize(input.device)
    inference_time = (time.perf_counter() - start) * 1000
    loss = sum(joints_mse_loss(o, target, target_weight) for o in output)

    # Get the heatmaps.
    heatmaps = output[-1].cpu()
    if flip:
        # If `flip` is true, perform horizontally flipped inference as well. This should
        # result in more robust predictions at the expense of additional compute.
        flip_input = fliplr(input)
        flip_output = model(flip_input)[-1].cpu()
        flip_output = flip_back(flip_output.detach(), data_info.hflip_indices)
        heatmaps = (heatmaps + flip_output) / 2

    return heatmaps, loss.item(), inference_time
示例#3
0
def do_training_step(model, optimiser, input, target, target_weight=None):
    """Run a single optimisation step of ``model`` on one batch.

    Args:
        model: Network in training mode. Calling it must return a sequence of
            outputs; each element is scored against ``target`` with
            ``joints_mse_loss`` and the scores are summed.
        optimiser: Optimiser whose ``zero_grad()`` and ``step()`` are called
            around the backward pass.
        input: Batch of inputs; must have the same length as ``target``.
        target: Ground truth passed to ``joints_mse_loss``.
        target_weight: Optional weights forwarded to ``joints_mse_loss``.

    Returns:
        ``(output, loss)``: the last element of the model's outputs (not
        detached) and the summed loss as a Python number.
    """
    assert model.training, 'model must be in training mode.'
    assert len(input) == len(
        target), 'input and target must contain the same number of examples.'

    # Gradients are forced on here so the step works even if the caller is
    # inside a no-grad context.
    with torch.enable_grad():
        stage_outputs = model(input)
        # One loss term per element of the model's output sequence.
        total_loss = sum(
            joints_mse_loss(stage_out, target, target_weight)
            for stage_out in stage_outputs
        )

        optimiser.zero_grad()
        total_loss.backward()
        optimiser.step()

    final_output = stage_outputs[-1]
    return final_output, total_loss.item()