Exemplo n.º 1
0
def run():
    """Run one evaluation pass over ``testloader`` and save visual results.

    For every test batch this computes the angular error of the final
    prediction, accumulates it into a ``StatisticalValue``, and writes a
    4-row image grid (input, input/prediction, prediction, label) under
    ``TMP_ROOT/test/<timestamp>/``.

    Relies on module-level globals: ``testloader``, ``model``, ``DEVICE``,
    ``multi_angular_loss``, ``StatisticalValue``, ``TMP_ROOT``,
    ``torchvision``.
    """
    statistical_angular_errors = StatisticalValue()
    sub_dir = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())

    # Create the output directory once, up front. exist_ok=True avoids the
    # race-prone isdir()-then-makedirs() check the original repeated on
    # every batch inside the loop.
    output_dir = os.path.join(TMP_ROOT, 'test', sub_dir)
    os.makedirs(output_dir, exist_ok=True)

    print('Test start.')

    with torch.no_grad():
        for idx, (images, labels, names) in enumerate(testloader):
            images, labels = images.to(DEVICE), labels.to(DEVICE)

            predictions = model(images)

            # Only the last (final-stage) prediction is scored.
            angular_error = multi_angular_loss(predictions[-1], labels)
            statistical_angular_errors.update(angular_error.item(),
                                              names,
                                              sort=True)

            # Stack input / corrected input / prediction / label for viewing.
            # NOTE(review): the squeeze() calls assume batch size 1 — confirm
            # testloader uses batch_size=1.
            view_data = torch.zeros((4, *images.shape[1:]))
            view_data[0, :, :, :] = images.squeeze()
            view_data[
                1, :, :, :] = images.squeeze() / predictions[-1].squeeze()
            view_data[2, :, :, :] = predictions[-1].squeeze()
            view_data[3, :, :, :] = labels.squeeze()

            torchvision.utils.save_image(
                view_data, os.path.join(output_dir, names[0]))

            print(
                'Angular Error: mean: {errors.avg}, mid: {errors.mid}, worst: {errors.max[0]}, best: {errors.min[0]}'
                .format(errors=statistical_angular_errors))

    print('Test end.')
Exemplo n.º 2
0
def run(epoch):
    """Train the model for one epoch with gradient accumulation.

    Gradients are accumulated over ``ITERATION_SIZE`` mini-batches before
    each optimizer step; every per-batch loss is pre-scaled by
    ``1 / ITERATION_SIZE`` so the applied step averages over the
    accumulation window.

    Relies on module-level globals: ``trainloader``, ``model``,
    ``criterion``, ``optimizer``, ``scheduler``, ``DEVICE``,
    ``ITERATION_SIZE``, ``iteration_writer``, ``multi_angular_loss``,
    ``StatisticalValue``, ``print_training_status``.

    Args:
        epoch: 1-based epoch index, used only for logging offsets.

    Returns:
        The ``StatisticalValue`` accumulating per-iteration loss values.
    """
    statistical_losses = StatisticalValue()
    statistical_angular_errors = StatisticalValue()

    optimizer.zero_grad()
    pending_step = False  # True while gradients are accumulated but unapplied

    for idx, (images, labels, names) in enumerate(trainloader):
        images, labels = images.to(DEVICE), labels.to(DEVICE)

        predictions = model(images)

        # Deep supervision: every intermediate prediction contributes to the
        # loss. The division makes the accumulated gradient an average over
        # the ITERATION_SIZE-batch window.
        loss = torch.zeros(1).to(DEVICE)
        for p in predictions:
            loss += criterion(p, labels) / ITERATION_SIZE

        loss.backward()
        pending_step = True

        # Angular error is tracked on the final prediction only.
        angular_error = multi_angular_loss(predictions[-1], labels)
        statistical_angular_errors.update(angular_error.item(), names)
        statistical_losses.update(loss.item())

        iteration_writer.add_scalar(
            'Loss/Iteration',
            statistical_losses.val[0],
            (epoch - 1) * len(trainloader) + idx + 1
        )

        if (idx + 1) % ITERATION_SIZE == 0:
            optimizer.step()
            optimizer.zero_grad()
            pending_step = False

            print_training_status(epoch, idx + 1, len(trainloader),
                                  statistical_losses.val[0],
                                  statistical_losses.avg)
            print(
                'Angular Error: mean: {errors.avg}, worst: {errors.max[0]}, best: {errors.min[0]}'.format(
                    errors=statistical_angular_errors))

    # BUG FIX: when len(trainloader) is not a multiple of ITERATION_SIZE the
    # original silently discarded the last partial accumulation (the next
    # call's zero_grad() wiped it). Apply the remaining gradients instead.
    if pending_step:
        optimizer.step()
        optimizer.zero_grad()

    scheduler.step()

    return statistical_losses