Example 1
def main():
    batch_size = 8

    net = PSPNet(pretrained=False, num_classes=num_classes, input_size=(512, 1024)).cuda()
    snapshot = 'epoch_48_validation_loss_5.1326_mean_iu_0.3172_lr_0.00001000.pth'
    net.load_state_dict(torch.load(os.path.join(ckpt_path, snapshot)))
    net.eval()

    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    transform = transforms.Compose([
        expanded_transform.FreeScale((512, 1024)),
        transforms.ToTensor(),
        transforms.Normalize(*mean_std)
    ])
    restore = transforms.Compose([
        expanded_transform.DeNormalize(*mean_std),
        transforms.ToPILImage()
    ])

    lsun_path = '/home/b3-542/LSUN'

    dataset = LSUN(lsun_path, ['tower_val', 'church_outdoor_val', 'bridge_val'], transform=transform)
    dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=16, shuffle=True)

    if not os.path.exists(test_results_path):
        os.mkdir(test_results_path)

    for vi, data in enumerate(dataloader, 0):
        inputs, labels = data
        inputs = Variable(inputs, volatile=True).cuda()
        outputs = net(inputs)

        prediction = outputs.cpu().data.max(1)[1].squeeze_(1).numpy()

        for idx, tensor in enumerate(zip(inputs.cpu().data, prediction)):
            pil_input = restore(tensor[0])
            pil_output = colorize_mask(tensor[1])
            pil_input.save(os.path.join(test_results_path, '%d_img.png' % (vi * batch_size + idx)))
            pil_output.save(os.path.join(test_results_path, '%d_out.png' % (vi * batch_size + idx)))
            print 'save the #%d batch, %d images' % (vi + 1, idx + 1)
Example 2
def validate(val_loader, net, criterion, optimizer, epoch, restore):
    net.eval()
    criterion.cpu()
    input_batches = []
    output_batches = []
    label_batches = []

    for vi, data in enumerate(val_loader, 0):
        inputs, labels = data
        inputs = Variable(inputs, volatile=True).cuda()
        labels = Variable(labels, volatile=True).cuda()

        outputs = net(inputs)

        input_batches.append(inputs.cpu().data)
        output_batches.append(outputs.cpu())
        label_batches.append(labels.cpu())

    input_batches = torch.cat(input_batches)
    output_batches = torch.cat(output_batches)
    label_batches = torch.cat(label_batches)
    val_loss = criterion(output_batches, label_batches)
    val_loss = val_loss.data[0]

    output_batches = output_batches.cpu().data[:, :num_classes - 1, :, :]
    label_batches = label_batches.cpu().data.numpy()
    prediction_batches = output_batches.max(1)[1].squeeze_(1).numpy()

    mean_iu = calculate_mean_iu(prediction_batches, label_batches, num_classes)

    writer.add_scalar('loss', val_loss, epoch + 1)
    writer.add_scalar('mean_iu', mean_iu, epoch + 1)

    if val_loss < train_record['best_val_loss']:
        train_record['best_val_loss'] = val_loss
        train_record['corr_epoch'] = epoch + 1
        train_record['corr_mean_iu'] = mean_iu
        snapshot_name = 'epoch_%d_loss_%.4f_mean_iu_%.4f_lr_%.8f' % (
            epoch + 1, val_loss, mean_iu, train_args['new_lr'])
        torch.save(net.state_dict(),
                   os.path.join(ckpt_path, exp_name, snapshot_name + '.pth'))
        torch.save(
            optimizer.state_dict(),
            os.path.join(ckpt_path, exp_name, 'opt_' + snapshot_name + '.pth'))

        with open(exp_name + '.txt', 'a') as f:
            f.write(snapshot_name + '\n')

        to_save_dir = os.path.join(ckpt_path, exp_name, str(epoch + 1))
        rmrf_mkdir(to_save_dir)

        x = []
        for idx, tensor in enumerate(
                zip(input_batches, prediction_batches, label_batches)):
            if random.random() > val_args['img_sample_rate']:
                continue
            pil_input = restore(tensor[0])
            pil_output = colorize_mask(tensor[1])
            pil_label = colorize_mask(tensor[2])
            pil_input.save(os.path.join(to_save_dir, '%d_img.png' % idx))
            pil_output.save(os.path.join(to_save_dir, '%d_out.png' % idx))
            pil_label.save(os.path.join(to_save_dir, '%d_label.png' % idx))
            x.extend([
                pil_to_tensor(pil_input.convert('RGB')),
                pil_to_tensor(pil_label.convert('RGB')),
                pil_to_tensor(pil_output.convert('RGB'))
            ])
        x = torch.stack(x, 0)
        x = vutils.make_grid(x, nrow=3, padding=5)
        writer.add_image(snapshot_name, x)

    print '--------------------------------------------------------'
    print '[val loss %.4f], [mean iu %.4f]' % (val_loss, mean_iu)
    print '[best val loss %.4f], [mean iu %.4f], [epoch %d]' % (
        train_record['best_val_loss'], train_record['corr_mean_iu'],
        train_record['corr_epoch'])
    print '--------------------------------------------------------'

    net.train()
    criterion.cuda()