示例#1
0
def main(args):
    """Train a PSPNet model, or evaluate a saved checkpoint.

    Args:
        args: parsed CLI namespace. Reads ``cuda``, ``distributed``,
            ``local_rank``, ``evaluation``, ``ignore_mask``, ``lr``,
            ``momentum``, ``weight_decay``, ``gamma``, ``max_iteration``
            and ``warmup_iteration``.
            # NOTE(review): attribute list inferred from the accesses
            # below — confirm against the argument parser.
    """
    train_loader, val_loader = load_data(args)

    model = PSPNet()
    if args.cuda:
        model = model.cuda()
    if args.distributed:
        # SyncBatchNorm conversion must happen before wrapping in DDP.
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model = nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
                                                    output_device=args.local_rank,
                                                    find_unused_parameters=True)

    if args.evaluation:
        # Evaluation-only path: restore a saved checkpoint and score it.
        # `eval` here is the project's evaluation routine (it shadows the
        # builtin at this call site).
        checkpoint = torch.load('./checkpoint/checkpoint_model_50000.pth', map_location='cpu')
        model.load_state_dict(checkpoint['model_state_dict'])
        eval(val_loader, model, args)
    else:
        # Fix: only move the criterion to the GPU when CUDA is requested.
        # The original called .cuda() unconditionally, which crashes on a
        # CPU-only run even though args.cuda is honored everywhere else.
        criterion = nn.CrossEntropyLoss(ignore_index=args.ignore_mask, weight=None)
        if args.cuda:
            criterion = criterion.cuda()

        # Split parameters so the decoder head trains at 10x the backbone
        # learning rate (standard recipe when fine-tuning a pretrained
        # backbone with a fresh decoder). Plain lists are sufficient here:
        # the optimizer only iterates them, no module registration needed.
        backbone_params = []
        decoder_params = []
        for name, param in model.named_parameters():
            if 'backbone' in name:
                backbone_params.append(param)
            else:
                decoder_params.append(param)

        params_list = [{'params': backbone_params},
                       {'params': decoder_params, 'lr': args.lr * 10}]

        optimizer = optim.SGD(params_list,
                              lr=args.lr,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay,
                              nesterov=True)
        scheduler = PolyLr(optimizer, gamma=args.gamma,
                           max_iteration=args.max_iteration,
                           warmup_iteration=args.warmup_iteration)

        global_step = 0
        # makedirs(exist_ok=True) avoids the isdir/mkdir check-then-act race.
        os.makedirs('checkpoint', exist_ok=True)
        # train() returns the updated global step, so this loop advances.
        while global_step < args.max_iteration:
            global_step = train(global_step, train_loader, model, optimizer, criterion, scheduler, args)
        eval(val_loader, model, args)
示例#2
0
import cv2
from model import PSPNet
import torch
from data import make_datapath_list, DataTransform
from PIL import Image
import numpy as np

# Load the trained PSPNet weights for inference.
net = PSPNet(n_classes=21)
# Fix: the original GPU mapping was {'cuda:0': 'cuda: 0'} — 'cuda: 0'
# (with a space) is not a valid device string, so torch.load would fail
# when remapping tensors. Use map_location='cpu' instead for CPU-only runs.
state_dict = torch.load('./pspnet50_49.pth',
                        map_location={'cuda:0': 'cuda:0'})  # for gpu
net.load_state_dict(state_dict)
net.eval()  # inference mode: freezes BatchNorm stats, disables dropout

# Build the PASCAL VOC 2012 file lists.
# NOTE(review): original comments claim 1500 train / 1500 val images —
# confirm against make_datapath_list.
rootpath = "./data/VOCdevkit/VOC2012/"
train_img_list, train_anno_list, val_img_list, val_anno_list = make_datapath_list(
    rootpath)
# ImageNet normalization statistics used by the pretrained backbone.
color_mean = (0.485, 0.456, 0.406)
color_std = (0.229, 0.224, 0.225)

# Grab the palette from one annotation image so predicted class indices
# can be rendered with the standard VOC colors.
# ("palatte" misspelling kept: later code outside this view may reference it.)
anno_file_path = val_anno_list[1]
anno_class_img = Image.open(anno_file_path)
p_palatte = anno_class_img.getpalette()

transform = DataTransform(input_size=475,
                          color_mean=color_mean,
                          color_std=color_std)
示例#3
0
import cv2
import matplotlib.pyplot as plt
from PIL import Image

from model import PSPNet
from dataset import MyDataset
from config import test_parser
from visualize import unnormalize_show, PIL_show

if __name__ == '__main__':
    # Entry point: load a (optionally pretrained) PSPNet and pull one
    # batch from the test set for visualization.
    # NOTE(review): `torch` and `DataLoader` are used below but are not
    # imported in this snippet — confirm the imports exist at file top.
    arg = test_parser()
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model = PSPNet(n_classes=arg.n_classes)
    if arg.path is not None:
        print('load model ...')
        model.load_state_dict(torch.load(arg.path, map_location=device))

    model = model.to(device)
    test_dataset = MyDataset(img_dir=arg.img_dir,
                             anno_dir=arg.anno_dir,
                             phase='test')
    n_test_img = len(test_dataset)
    test_loader = DataLoader(test_dataset,
                             batch_size=arg.batch,
                             shuffle=True,
                             pin_memory=True,
                             num_workers=arg.num_workers)
    model.eval()  # inference mode for the forward pass
    with torch.no_grad():
        # Fix: `iterator.next()` is Python-2 syntax; Python 3 iterators
        # only support the builtin next().
        img, anno_path_list, p_palette = next(iter(test_loader))
        # img.size(): (batch, 3, 475, 475), len(anno_path_list) = batch
示例#4
0
	best_acc = 0

	# LOSS
	# -----------------------------
	if loss_type== 'bce':
		criterion = nn.BCELoss()
	else:
		criterion = nn.MSELoss()
	optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
	print("model and criterion loaded ...")


	
	checkpoint = torch.load('./saved_models/model_arr_20190215_19_0.001_64_mse.ckpt')
	model.load_state_dict(checkpoint['state_dict'])
	optimizer.load_state_dict(checkpoint['optimizer'])
	epoch = checkpoint['epoch']

	model.eval()
	with torch.no_grad():
		count = 0
		total = 0
		prec_count = 0
		blankz_count = 0
		for i in range(num_val_iters-1):
			# x, y = read_csv_batch('../../data/final_plans/final_val.csv', i, val_batch_size)
			x,y = read_batch(rows_val, i, val_batch_size)
			x = torch.tensor(x).to(device).float()
			y = torch.tensor(y).to(device)
			y_rs = y.view(y.size(0), 8, 16)