def main():
    """Run a fixed PSPNet checkpoint over LSUN validation splits and save results.

    The snapshot named below is loaded from ``ckpt_path`` and evaluated on the
    ``tower_val``, ``church_outdoor_val`` and ``bridge_val`` splits found under
    ``/home/b3-542/LSUN``.  For every image, two files are written into
    ``test_results_path`` (created if it does not exist):

    * ``<n>_img.png`` -- the network input, de-normalized back to a PIL image;
    * ``<n>_out.png`` -- the colorized per-pixel prediction.

    ``<n>`` is the running position in the data loader.  The loader is
    shuffled, so ``<n>`` does not correspond to dataset order.

    Relies on module-level ``num_classes``, ``ckpt_path`` and
    ``test_results_path``.
    """
    images_per_batch = 8
    resize_to = (512, 1024)

    model = PSPNet(pretrained=False, num_classes=num_classes,
                   input_size=resize_to).cuda()
    checkpoint_file = 'epoch_48_validation_loss_5.1326_mean_iu_0.3172_lr_0.00001000.pth'
    model.load_state_dict(torch.load(os.path.join(ckpt_path, checkpoint_file)))
    model.eval()

    # These match the commonly used ImageNet per-channel mean/std.  The same
    # lists drive both normalization and de-normalization.
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    preprocess = transforms.Compose([
        expanded_transform.FreeScale(resize_to),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])
    to_pil = transforms.Compose([
        expanded_transform.DeNormalize(channel_mean, channel_std),
        transforms.ToPILImage(),
    ])

    lsun_splits = ['tower_val', 'church_outdoor_val', 'bridge_val']
    loader = DataLoader(
        LSUN('/home/b3-542/LSUN', lsun_splits, transform=preprocess),
        batch_size=images_per_batch, num_workers=16, shuffle=True)

    if not os.path.exists(test_results_path):
        os.mkdir(test_results_path)

    for batch_no, (images, _labels) in enumerate(loader):
        # volatile=True marks inference-only Variables (pre-0.4 PyTorch idiom).
        images = Variable(images, volatile=True).cuda()
        scores = model(images)
        # Index of the top score along dim 1 (presumably the class axis).
        predicted = scores.cpu().data.max(1)[1].squeeze_(1).numpy()

        base = batch_no * images_per_batch
        for offset, (image, mask) in enumerate(zip(images.cpu().data, predicted)):
            input_png = to_pil(image)
            output_png = colorize_mask(mask)
            input_png.save(os.path.join(test_results_path, '%d_img.png' % (base + offset)))
            output_png.save(os.path.join(test_results_path, '%d_out.png' % (base + offset)))
        print('save the #%d batch, %d images' % (batch_no + 1, offset + 1))
def validate(val_loader, net, criterion, optimizer, epoch, restore):
    """Evaluate ``net`` on the whole validation set; checkpoint it on a new best loss.

    Every validation batch is run through ``net``.  The inputs, raw outputs
    and labels are gathered on the CPU.  ``criterion`` is then applied once to
    the concatenated outputs/labels, and mean IU is computed from the
    arg-max predictions.  Both values are logged to ``writer`` at step
    ``epoch + 1``.

    When the loss beats ``train_record['best_val_loss']``, the function also:

    * records the new best loss, epoch and mean IU in ``train_record``;
    * saves the ``state_dict`` of ``net`` and of ``optimizer`` under
      ``ckpt_path/exp_name``;
    * appends the snapshot name to ``<exp_name>.txt``;
    * recreates ``ckpt_path/exp_name/<epoch + 1>`` and dumps a random sample
      (rate ``val_args['img_sample_rate']``) of input / prediction / label
      images into it;
    * logs those samples as an image grid, if at least one was sampled.

    Finally, a summary is printed, ``net`` is put back into training mode and
    ``criterion`` is moved back to the GPU.

    Args:
        val_loader: iterable yielding ``(inputs, labels)`` pairs.
        net: the network under evaluation.
        criterion: loss module; temporarily moved to the CPU.
        optimizer: optimizer whose ``state_dict`` is saved alongside ``net``.
        epoch: epoch counter; ``epoch + 1`` is used for logging and file names.
        restore: callable mapping an input tensor back to a PIL image.

    Relies on module-level ``num_classes``, ``writer``, ``train_record``,
    ``train_args``, ``val_args``, ``ckpt_path``, ``exp_name`` and the helpers
    ``calculate_mean_iu``, ``colorize_mask``, ``rmrf_mkdir``,
    ``pil_to_tensor`` and ``vutils``.
    """
    net.eval()
    # The loss is evaluated once over the whole CPU-resident validation set,
    # so the criterion has to live on the CPU for the duration.
    criterion.cpu()

    input_batches = []
    output_batches = []
    label_batches = []
    for inputs, labels in val_loader:
        inputs = Variable(inputs, volatile=True).cuda()
        # Labels are only consumed on the CPU; no need to round-trip the GPU.
        labels = Variable(labels, volatile=True).cpu()
        outputs = net(inputs)
        input_batches.append(inputs.cpu().data)
        output_batches.append(outputs.cpu())
        label_batches.append(labels)

    input_batches = torch.cat(input_batches)
    output_batches = torch.cat(output_batches)
    label_batches = torch.cat(label_batches)
    val_loss = criterion(output_batches, label_batches).data[0]

    # NOTE(review): the last channel is sliced off before the arg-max, so that
    # class index is never predicted. Presumably it is an ignore/void class;
    # confirm.
    output_batches = output_batches.data[:, :num_classes - 1, :, :]
    label_batches = label_batches.data.numpy()
    prediction_batches = output_batches.max(1)[1].squeeze_(1).numpy()

    mean_iu = calculate_mean_iu(prediction_batches, label_batches, num_classes)

    writer.add_scalar('loss', val_loss, epoch + 1)
    writer.add_scalar('mean_iu', mean_iu, epoch + 1)

    if val_loss < train_record['best_val_loss']:
        train_record['best_val_loss'] = val_loss
        train_record['corr_epoch'] = epoch + 1
        train_record['corr_mean_iu'] = mean_iu
        snapshot_name = 'epoch_%d_loss_%.4f_mean_iu_%.4f_lr_%.8f' % (
            epoch + 1, val_loss, mean_iu, train_args['new_lr'])
        torch.save(net.state_dict(),
                   os.path.join(ckpt_path, exp_name, snapshot_name + '.pth'))
        torch.save(optimizer.state_dict(),
                   os.path.join(ckpt_path, exp_name, 'opt_' + snapshot_name + '.pth'))
        with open(exp_name + '.txt', 'a') as f:
            f.write(snapshot_name + '\n')

        to_save_dir = os.path.join(ckpt_path, exp_name, str(epoch + 1))
        rmrf_mkdir(to_save_dir)

        sample_tensors = []
        for idx, (image, prediction, label) in enumerate(
                zip(input_batches, prediction_batches, label_batches)):
            # Only a random fraction of the validation images is dumped.
            if random.random() > val_args['img_sample_rate']:
                continue
            pil_input = restore(image)
            pil_output = colorize_mask(prediction)
            pil_label = colorize_mask(label)
            pil_input.save(os.path.join(to_save_dir, '%d_img.png' % idx))
            pil_output.save(os.path.join(to_save_dir, '%d_out.png' % idx))
            pil_label.save(os.path.join(to_save_dir, '%d_label.png' % idx))
            # With nrow=3 below, each grid row shows input | label | prediction.
            sample_tensors.extend([
                pil_to_tensor(pil_input.convert('RGB')),
                pil_to_tensor(pil_label.convert('RGB')),
                pil_to_tensor(pil_output.convert('RGB')),
            ])
        # torch.stack raises on an empty list, which happens whenever the
        # random sampling above skipped every image; only log a grid if needed.
        if sample_tensors:
            grid = vutils.make_grid(torch.stack(sample_tensors, 0), nrow=3, padding=5)
            writer.add_image(snapshot_name, grid)

    print('--------------------------------------------------------')
    print('[val loss %.4f], [mean iu %.4f]' % (val_loss, mean_iu))
    print('[best val loss %.4f], [mean iu %.4f], [epoch %d]' % (
        train_record['best_val_loss'], train_record['corr_mean_iu'],
        train_record['corr_epoch']))
    print('--------------------------------------------------------')

    net.train()
    criterion.cuda()