Пример #1
0
def main():
    """Create the model and start the training.

    Builds ``Res_Deeplab`` and initialises it from torchvision's ImageNet
    ResNet-101 checkpoint (ImageNet classifier ``fc.*`` excluded). It then
    trains on ``GTA5DataSet`` with SGD, using the 1x-lr and 10x-lr parameter
    groups.

    Snapshots ``VOC12_scenes_<iter>.pth`` are written to ``args.snapshot_dir``
    every ``args.save_pred_every`` iterations (never at iteration 0). After
    iteration ``args.num_steps - 1``, a final ``VOC12_scenes_<num_steps>.pth``
    is written and training stops.

    All configuration is read from the module-level ``args``.
    """
    # Time the run from here. Previously ``start`` was never assigned in this
    # function, so the final report depended on an external global (NameError
    # if none existed).
    start = timeit.default_timer()

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    logger = util.set_logger(args.snapshot_dir, args.log_file, args.debug)
    logger.info('start with arguments %s', args)

    # Command-line strings of the form "h,w" and "low,high".
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)
    lscale, hscale = map(float, args.train_scale.split(','))
    train_scale = (lscale, hscale)

    cudnn.enabled = True

    # Create network.
    model = Res_Deeplab(num_classes=args.num_classes)

    # Initialise from the ImageNet-pretrained ResNet-101 weights. Every tensor
    # except the ImageNet classifier ("fc.*") overwrites the model's own entry.
    # load_state_dict() is strict, so the remaining checkpoint keys must match
    # Res_Deeplab's parameter names.
    # For COCO-pretrained weights, use torch.load(args.restore_from) instead.
    resnet101_url = 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth'
    saved_state_dict = torch.utils.model_zoo.load_url(resnet101_url)
    new_params = model.state_dict().copy()
    for key, tensor in saved_state_dict.items():
        name = str(key)
        if name.split('.')[0] != 'fc':
            new_params[name] = tensor
    model.load_state_dict(new_params)
    # eval() makes BatchNorm use its stored running statistics throughout
    # training (Caffe's use_global_stats = True). Gradients still flow.
    model.eval()
    device = torch.device("cuda:" + str(args.gpu))
    model.to(device)

    cudnn.benchmark = True

    # max_iters = num_steps * batch_size: presumably the dataset is sized to
    # yield num_steps batches (TODO confirm in GTA5DataSet).
    trainloader = data.DataLoader(GTA5DataSet(args.data_dir,
                                              args.data_list,
                                              max_iters=args.num_steps *
                                              args.batch_size,
                                              crop_size=input_size,
                                              train_scale=train_scale,
                                              scale=args.random_scale,
                                              mirror=args.random_mirror,
                                              mean=IMG_MEAN,
                                              std=IMG_STD),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=5,
                                  pin_memory=args.pin_memory)
    optimizer = optim.SGD([{
        'params': get_1x_lr_params_NOscale(model),
        'lr': args.learning_rate
    }, {
        'params': get_10x_lr_params(model),
        'lr': 10 * args.learning_rate
    }],
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)

    # Resize predictions to input_size (the crop size) before the loss.
    interp = nn.Upsample(size=input_size, mode='bilinear', align_corners=True)

    for i_iter, batch in enumerate(trainloader):
        images, labels, _, _ = batch
        images = images.to(device)
        labels = labels.long().to(device)

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)
        pred = interp(model(images))
        loss = loss_calc(pred, labels)
        loss.backward()
        optimizer.step()

        logger.info('iter = {} of {} completed, loss = {:.4f}'.format(
            i_iter, args.num_steps, loss.item()))

        if i_iter >= args.num_steps - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'VOC12_scenes_' + str(args.num_steps) + '.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'VOC12_scenes_' + str(i_iter) + '.pth'))

    end = timeit.default_timer()
    print(end - start, 'seconds')
Пример #2
0
def main():
    """Create the model and start the evaluation process.

    Runs the segmentation model over every image in ``args.data_list``.
    Inference is multi-scale (comma-separated ``args.test_scale``) with
    optional horizontal-flip averaging (``args.test_flipping``).

    Per-image scores are accumulated in a ``ScoreUpdater``. For each image,
    the predicted label map and its colourised version (``*_color.png``)
    are written to ``args.save``.

    Configuration is read from the module-level ``args``.

    Raises:
        ValueError: if ``args.data_src`` or ``args.model`` is not supported.
    """
    device = torch.device("cuda:" + str(args.gpu))

    if not os.path.exists(args.save):
        os.makedirs(args.save)

    logger = util.set_logger(args.save, args.log_file, args.debug)
    logger.info('start with arguments %s', args)

    # Number of images to evaluate = number of lines in the list file.
    with open(args.data_list) as f:
        x_num = sum(1 for _ in f)

    sys.path.insert(0, 'dataset/helpers')
    if args.data_src in ('gta', 'cityscapes'):
        from labels import id2label, trainId2label
    elif args.data_src == 'synthia':
        from labels_cityscapes_synthia import id2label, trainId2label
    else:
        # Previously this fell through and failed later with an
        # UnboundLocalError on id2label.
        raise ValueError('unsupported data_src: %r' % (args.data_src,))

    # Lookup table from raw label id to train id. Ids not listed in id2label
    # stay 255. Ids -1 and 255 are skipped so they cannot overwrite entries
    # (index -1 would alias slot 255).
    label_2_id = 255 * np.ones((256, ))
    for label_id in id2label:
        if label_id in (-1, 255):
            continue
        label_2_id[label_id] = id2label[label_id].trainId
    id_2_label = np.array([
        trainId2label[train_id].id for train_id in trainId2label
        if train_id not in (-1, 255)
    ])
    valid_labels = sorted(set(id_2_label.ravel()))
    scorer = ScoreUpdater(valid_labels, args.num_classes, x_num, logger)
    scorer.reset()

    if args.model == 'DeeplabRes':
        model = Res_Deeplab(num_classes=args.num_classes)
    else:
        # Only DeeplabRes is wired up. A DeeplabVGG variant existed but is
        # disabled.
        raise ValueError('unsupported model: %r' % (args.model,))

    if args.restore_from.startswith('http'):
        # Remote checkpoint: copy every tensor except the ImageNet
        # classifier ("fc.*") over the model's own parameters.
        saved_state_dict = model_zoo.load_url(args.restore_from)
        new_params = model.state_dict().copy()
        for key, tensor in saved_state_dict.items():
            name = str(key)
            if name.split('.')[0] != 'fc':
                new_params[name] = tensor
    else:
        # Local checkpoint file, loaded directly onto the evaluation GPU.
        loc = "cuda:" + str(args.gpu)
        saved_state_dict = torch.load(args.restore_from, map_location=loc)
        new_params = saved_state_dict.copy()
    model.load_state_dict(new_params)
    model.eval()
    model.to(device)

    testloader = data.DataLoader(GTA5TestDataSet(args.data_dir,
                                                 args.data_list,
                                                 test_scale=1.0,
                                                 test_size=(1024, 512),
                                                 mean=IMG_MEAN,
                                                 std=IMG_STD,
                                                 scale=False,
                                                 mirror=False),
                                 batch_size=1,
                                 shuffle=False,
                                 pin_memory=True)

    test_scales = [float(_) for _ in str(args.test_scale).split(',')]
    num_scales = len(test_scales)

    # Every scale's prediction is resized to (h, w) so the outputs can be summed.
    h, w = map(int, args.test_image_size.split(','))
    if version.parse(torch.__version__) >= version.parse('0.4.0'):
        interp = nn.Upsample(size=(h, w), mode='bilinear', align_corners=True)
    else:
        interp = nn.Upsample(size=(h, w), mode='bilinear')
    # F.interpolate is available only after 0.4.0. Decide once, not per scale.
    use_interpolate = version.parse(torch.__version__) > version.parse('0.4.0')

    with torch.no_grad():
        for index, batch in enumerate(testloader):
            image, label, _, name = batch
            # Every scale must be derived from this unscaled input. Resizing
            # the previous scale's `image` would compound the scale factors.
            img = image.clone()
            for scale_idx, scale in enumerate(test_scales):
                if use_interpolate:
                    image = F.interpolate(img,
                                          scale_factor=scale,
                                          mode='bilinear',
                                          align_corners=True)
                else:
                    test_size = (int(h * scale), int(w * scale))
                    interp_tmp = nn.Upsample(size=test_size,
                                             mode='bilinear',
                                             align_corners=True)
                    image = interp_tmp(img)
                output2 = model(image.to(device))
                coutput = interp(output2).cpu().data[0].numpy()
                if args.test_flipping:
                    # Average with the prediction for the input flipped along
                    # its last (width) axis, un-flipping that prediction first.
                    output2 = model(
                        torch.from_numpy(
                            image.numpy()[:, :, :, ::-1].copy()).to(device))
                    coutput = 0.5 * (
                        coutput +
                        interp(output2).cpu().data[0].numpy()[:, :, ::-1])
                if scale_idx == 0:
                    output = coutput.copy()
                else:
                    output += coutput

            # (C, H, W) class scores -> (H, W) uint8 label map.
            output = output / num_scales
            output = output.transpose(1, 2, 0)
            output = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)
            pred_label = output.copy()
            label = label_2_id[np.asarray(label.numpy(), dtype=np.uint8)]
            scorer.update(pred_label.flatten(), label.flatten(), index)

            output_col = colorize_mask(output)
            output = Image.fromarray(output)

            name = name[0].split('/')[-1]
            output.save('%s/%s' % (args.save, name))
            output_col.save('%s/%s_color.png' %
                            (args.save, name.split('.')[0]))