def main():
    """Run the trained SegNet on the test split and write visualisations.

    Loads the ``150000.pth`` checkpoint from ``args.snapshot_dir`` into a
    ``SegNet(model='resnet50')``. It then runs the model over every sample of
    ``SegDataset(mode='test')`` (batch size 1, fixed order). For each sample,
    ``vis_pred_result`` writes ``<args.test_debug_vis_dir><image_name>.png``.

    Paths are built by plain string concatenation. The directory arguments
    therefore need a trailing path separator for files to land inside them.
    """
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(args.snapshot_dir, exist_ok=True)
    os.makedirs(args.test_debug_vis_dir, exist_ok=True)

    model = SegNet(model='resnet50')
    model.load_state_dict(torch.load(args.snapshot_dir + '150000.pth'))

    # eval() puts layers such as BatchNorm into inference mode. It does NOT
    # disable autograd; that is handled by torch.no_grad() below.
    model.eval()
    model.cuda()

    dataloader = DataLoader(SegDataset(mode='test'),
                            batch_size=1,
                            shuffle=False,
                            num_workers=4)

    # Inference only: without no_grad, autograd records a graph for every
    # forward pass, needlessly holding GPU memory.
    with torch.no_grad():
        for i_iter, batch_data in enumerate(dataloader):
            (input_image, vis_image, gt_mask, _weight_matrix,
             dataset_length, image_name) = batch_data

            pred_mask = model(input_image.cuda())

            print('i_iter/total {}/{}'.format(
                i_iter, int(dataset_length[0])))

            out_path = args.test_debug_vis_dir + image_name[0] + '.png'
            # Image names contain '/'-separated sub-paths. Create every
            # missing parent directory of the output file, not just the first
            # path component. Skip when there is no directory part at all.
            out_dir = os.path.dirname(out_path)
            if out_dir:
                os.makedirs(out_dir, exist_ok=True)

            vis_pred_result(vis_image, gt_mask, pred_mask, out_path)
Example #2
0
            loss_n=cost[6],
            ds=dist_loss_save[0].val,
            dd=dist_loss_save[1].val,
            dn=dist_loss_save[2].val,
        )
        bar.next()
    bar.finish()

    scheduler.step()

    loss_index = (avg_cost[index, 0] + avg_cost[index, 3] +
                  avg_cost[index, 6]) / 3.0
    isbest = loss_index < best_loss

    # evaluating test data
    model.eval()
    with torch.no_grad():  # operations inside don't track history
        nyuv2_test_dataset = iter(nyuv2_test_loader)
        for k in range(test_batch):
            test_data, test_label, test_depth, test_normal = nyuv2_test_dataset.next(
            )
            test_data, test_label = test_data.cuda(), test_label.type(
                torch.LongTensor).cuda()
            test_depth, test_normal = test_depth.cuda(), test_normal.cuda()

            test_pred, _, _ = model(test_data)
            test_loss = model.model_fit(test_pred[0], test_label, test_pred[1],
                                        test_depth, test_pred[2], test_normal)

            cost[12] = test_loss[0].item()
            cost[13] = model.compute_miou(test_pred[0], test_label).item()