예제 #1
0
def make_data_loader(args, verbose=True):
	"""Build train/val/test data loaders for the dataset named in ``args.dataset``.

	Args:
		args: namespace providing at least ``dataset``, ``batch_size`` and
			``test_set`` (the category split used for val/test).
		verbose: forwarded to the dataset constructors.

	Returns:
		``(train_loader, val_loader, test_loader, num_classes)`` where
		``num_classes`` maps each split name to that split's ``NUM_CLASSES``.

	Raises:
		NotImplementedError: if no loader is implemented for ``args.dataset``.
	"""
	if args.dataset == 'pascal':
		# Training uses the augmented train split restricted to seen classes;
		# val/test use the caller-selected category split.
		train_set = pascal.VOCSegmentation(args, split='train_aug', csplit='seen', verbose=verbose)
		val_set = pascal.VOCSegmentation(args, split='val_aug', csplit=args.test_set, verbose=verbose)
		test_set = pascal.VOCSegmentation(args, split='test', csplit=args.test_set, verbose=verbose)

		# Only the training loader shuffles.
		train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True)
		val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False)
		test_loader = DataLoader(test_set, batch_size=args.batch_size, shuffle=False)

		num_classes = {'train': train_set.NUM_CLASSES, 'val': val_set.NUM_CLASSES, 'test': test_set.NUM_CLASSES}
		return train_loader, val_loader, test_loader, num_classes

	# Carry the dataset name inside the exception instead of printing it and
	# raising an empty NotImplementedError separately.
	raise NotImplementedError("Dataloader for {} is not implemented".format(args.dataset))
예제 #2
0
def eval_model(net, save_dir, batch_size=10):
    """Run `net` over the PASCAL val split and hand predictions to `test`.

    Builds the deterministic evaluation transform pipeline (crop around the
    ground-truth mask, resize, extreme-point heatmap, input concatenation),
    wraps the val split in a single-image DataLoader, ensures `save_dir`
    exists, and runs inference without gradients.
    """
    # Crop configuration: enlarge each bounding box by 50 px and zero-pad
    # when the crop runs past the image border.
    relax_crop = 50
    zero_pad_crop = True

    net.eval()

    resolutions = {
        'gt': None,
        'crop_image': (512, 512),
        'crop_gt': (512, 512),
    }
    # pert=0 keeps the extreme points deterministic at evaluation time.
    eval_transform = transforms.Compose([
        tr.CropFromMask(crop_elems=('image', 'gt'), relax=relax_crop, zero_pad=zero_pad_crop),
        tr.FixedResize(resolutions=resolutions),
        tr.ExtremePoints(sigma=10, pert=0, elem='crop_gt'),
        tr.ToImage(norm_elem='extreme_points'),
        tr.ConcatInputs(elems=('crop_image', 'extreme_points')),
        tr.ToTensor(),
    ])

    db_test = pascal.VOCSegmentation(split='val', transform=eval_transform, retname=True)
    testloader = DataLoader(db_test, batch_size=1, shuffle=False, num_workers=2)

    save_dir.mkdir(exist_ok=True)

    with torch.no_grad():
        test(net, testloader, save_dir)
예제 #3
0
    # Training pipeline: random resize, rotation and horizontal flip for
    # augmentation, then ImageNet-style normalisation and tensor conversion.
    composed_transforms_tr = transforms.Compose([
        tr.RandomSized(512),
        tr.RandomRotate(15),
        tr.RandomHorizontalFlip(),
        tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        tr.ToTensor()
    ])

    # Test pipeline: deterministic fixed resize, same normalisation.
    composed_transforms_ts = transforms.Compose([
        tr.FixedResize(size=(512, 512)),
        tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        tr.ToTensor()
    ])

    voc_train = pascal.VOCSegmentation(split='train',
                                       transform=composed_transforms_tr)
    voc_val = pascal.VOCSegmentation(split='val',
                                     transform=composed_transforms_ts)

    if use_sbd:
        # Optionally merge SBD (train+val) into the training set, excluding
        # any image that also appears in the PASCAL val split.
        print("Using SBD dataset")
        sbd_train = sbd.SBDSegmentation(split=['train', 'val'],
                                        transform=composed_transforms_tr)
        db_train = combine_dbs.CombineDBs([voc_train, sbd_train],
                                          excluded=[voc_val])
    else:
        db_train = voc_train

    trainloader = DataLoader(db_train,
                             batch_size=p['trainBatch'],
                             shuffle=True,
예제 #4
0
        tr.ScaleNRotate(rots=(-20, 20), scales=(.75, 1.25)),
        tr.CropFromMask(crop_elems=('image', 'gt'), relax=relax_crop, zero_pad=zero_pad_crop),
        tr.FixedResize(resolutions={'crop_image': (512, 512), 'crop_gt': (512, 512)}),
        tr.ExtremePoints(sigma=10, pert=5, elem='crop_gt'),
        tr.ToImage(norm_elem='extreme_points'),
        tr.ConcatInputs(elems=('crop_image', 'extreme_points')),
        tr.ToTensor()])
    # Evaluation pipeline: same crop/resize/point stages as training but
    # with pert=0, so the extreme points are deterministic.
    composed_transforms_ts = transforms.Compose([
        tr.CropFromMask(crop_elems=('image', 'gt'), relax=relax_crop, zero_pad=zero_pad_crop),
        tr.FixedResize(resolutions={'crop_image': (512, 512), 'crop_gt': (512, 512)}),
        tr.ExtremePoints(sigma=10, pert=0, elem='crop_gt'),
        tr.ToImage(norm_elem='extreme_points'),
        tr.ConcatInputs(elems=('crop_image', 'extreme_points')),
        tr.ToTensor()])

    voc_train = pascal.VOCSegmentation(split='train', transform=composed_transforms_tr)
    voc_val = pascal.VOCSegmentation(split='val', transform=composed_transforms_ts)

    if use_sbd:
        # Use a distinct local name: assigning the result to `sbd` would
        # shadow the imported `sbd` module and break any later
        # SBDSegmentation construction in this scope.
        sbd_train = sbd.SBDSegmentation(split=['train', 'val'], transform=composed_transforms_tr, retname=True)
        db_train = combine_dbs([voc_train, sbd_train], excluded=[voc_val])
    else:
        db_train = voc_train

    # Record dataset / transform descriptions in the report parameters.
    p['dataset_train'] = str(db_train)
    p['transformations_train'] = [str(tran) for tran in composed_transforms_tr.transforms]
    # Fix: this previously recorded the *training* dataset; the test loader
    # below iterates voc_val, so report the validation dataset instead.
    p['dataset_test'] = str(voc_val)
    p['transformations_test'] = [str(tran) for tran in composed_transforms_ts.transforms]

    trainloader = DataLoader(db_train, batch_size=p['trainBatch'], shuffle=True, num_workers=2)
    testloader = DataLoader(voc_val, batch_size=testBatch, shuffle=False, num_workers=2)
예제 #5
0
    # Skip all training setup when resuming at the final epoch.
    if resume_epoch != nEpochs:
        # Logging into Tensorboard: one run directory per timestamp+host.
        time_now = datetime.now().strftime('%b%d_%H-%M-%S')
        hostname = socket.gethostname()
        log_dir = save_dir / 'models' / '{}_{}'.format(time_now, hostname)
        writer = SummaryWriter(log_dir=str(log_dir))

        # Use the following optimizer
        #optimizer = optim.SGD(train_params, lr=p['lr'], momentum=p['momentum'], weight_decay=p['wd'])
        optimizer = optim.Adam(train_params, lr=p['lr'], weight_decay=p['wd'])
        p['optimizer'] = str(optimizer)

        # Preparation of the data loaders
        train_tf, test_tf = create_transforms(relax_crop, zero_pad_crop)
        voc_train = pascal.VOCSegmentation(split='train',
                                           download=True,
                                           transform=train_tf)
        voc_val = pascal.VOCSegmentation(split='val',
                                         download=True,
                                         transform=test_tf)

        if use_sbd:
            # NOTE(review): this rebinds the name `sbd`, shadowing the
            # imported `sbd` module — a second SBDSegmentation construction
            # in this scope would fail; consider renaming the local.
            sbd = sbd.SBDSegmentation(split=['train', 'val'],
                                      retname=True,
                                      transform=train_tf)
            # Combined train set, excluding images shared with PASCAL val.
            db_train = combine_dbs([voc_train, sbd], excluded=[voc_val])
        else:
            db_train = voc_train

        # Record dataset and transform descriptions for the report.
        p['dataset_train'] = str(db_train)
        p['transformations_train'] = [str(t) for t in train_tf.transforms]
예제 #6
0
# Restore the checkpoint saved at the previous epoch and move the net to
# the target device.
pretrain_dict = torch.load(os.path.join(save_dir, 'models', modelName + '_epoch-' + str(resume_epoch - 1) + '.pth'))
print("Initializing weights from: {}".format(
    os.path.join(save_dir, 'models', modelName + '_epoch-' + str(resume_epoch - 1) + '.pth')))
net.load_state_dict(pretrain_dict)
net.to(device)

# Generate result of the validation images
net.eval()
# Evaluation pipeline: crop image/gt/void-pixel masks around the gt mask
# (relax 30 px, zero-padded), resize the crops to 512x512 with bilinear
# interpolation, add IOG point heatmaps, then stack crop + points as the
# network input.
composed_transforms_ts = transforms.Compose([
    tr.CropFromMask(crop_elems=('image', 'gt','void_pixels'), relax=30, zero_pad=True),
    tr.FixedResize(resolutions={'gt': None, 'crop_image': (512, 512), 'crop_gt': (512, 512), 'crop_void_pixels': (512, 512)},flagvals={'gt':cv2.INTER_LINEAR,'crop_image':cv2.INTER_LINEAR,'crop_gt':cv2.INTER_LINEAR,'crop_void_pixels': cv2.INTER_LINEAR}),
    tr.IOGPoints(sigma=10, elem='crop_gt',pad_pixel=10),
    tr.ToImage(norm_elem='IOG_points'),
    tr.ConcatInputs(elems=('crop_image', 'IOG_points')),
    tr.ToTensor()])
db_test = pascal.VOCSegmentation(split='val', transform=composed_transforms_ts, retname=True)
testloader = DataLoader(db_test, batch_size=1, shuffle=False, num_workers=1)

# Results are written under <save_dir>/Results.
save_dir_res = os.path.join(save_dir, 'Results')
if not os.path.exists(save_dir_res):
    os.makedirs(save_dir_res)
save_dir_res_list=[save_dir_res]
print('Testing Network')
with torch.no_grad():
    for ii, sample_batched in enumerate(testloader):
        inputs, gts, metas = sample_batched['concat'], sample_batched['gt'], sample_batched['meta']
        inputs = inputs.to(device)
        # The network returns four coarse side outputs plus a fine output;
        # only the fine output is used below.
        coarse_outs1,coarse_outs2,coarse_outs3,coarse_outs4,fine_out = net.forward(inputs)
        outputs = fine_out.to(torch.device('cpu'))
        # (C, H, W) -> (H, W, C) for image-style post-processing.
        pred = np.transpose(outputs.data.numpy()[0, :, :, :], (1, 2, 0))
        # Sigmoid: map logits to probabilities.
        pred = 1 / (1 + np.exp(-pred))
예제 #7
0
import os.path

from torch.utils.data import DataLoader
from evaluation.eval import eval_one_result
import dataloaders.pascal as pascal

# Root directory containing one sub-directory per experiment run.
exp_root_dir = './'

# Names of the run directories to evaluate.
method_names = []
method_names.append('run_0')

if __name__ == '__main__':

    # Dataloader: raw samples (no transform), with names retained so results
    # can be matched back to images.
    dataset = pascal.VOCSegmentation(transform=None, retname=True)
    dataloader = DataLoader(dataset,
                            batch_size=1,
                            shuffle=False,
                            num_workers=0)

    # Iterate through all the different methods
    for method in method_names:
        # NOTE(review): `ii` (0.1 .. 0.9) is not used in the visible lines —
        # presumably a threshold sweep consumed further down; confirm.
        for ii in [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]:
            results_folder = os.path.join(exp_root_dir, method, 'Results')

            # One summary text file per method under eval_results/.
            filename = os.path.join(exp_root_dir, 'eval_results',
                                    method.replace('/', '-') + '.txt')
            if not os.path.exists(os.path.join(exp_root_dir, 'eval_results')):
                os.makedirs(os.path.join(exp_root_dir, 'eval_results'))

            # (call continues beyond this excerpt)
            jaccards = eval_one_result(dataloader,
예제 #8
0
def main(opts):
    """Multi-scale + flip evaluation of the parsing network on PASCAL.

    Runs the network at six scales and their horizontal flips, averages each
    plain/flipped output pair, sums the per-scale outputs (upsampled back to
    the scale-1 size), writes the argmax label map and a colour visualisation
    for every image under ``opts.output_path``, then runs the final ``eval_``.
    """
    # Adjacency matrices for the graph-transfer heads; each is expanded to a
    # (1, 1, rows, cols) tensor and moved to the GPU.
    adj2_ = torch.from_numpy(graph.cihp2pascal_nlp_adj).float()
    adj2_test = adj2_.unsqueeze(0).unsqueeze(0).expand(1, 1, 7, 20).cuda()

    adj1_ = Variable(
        torch.from_numpy(graph.preprocess_adj(graph.pascal_graph)).float())
    adj1_test = adj1_.unsqueeze(0).unsqueeze(0).expand(1, 1, 7, 7).cuda()

    cihp_adj = graph.preprocess_adj(graph.cihp_graph)
    adj3_ = Variable(torch.from_numpy(cihp_adj).float())
    adj3_test = adj3_.unsqueeze(0).unsqueeze(0).expand(1, 1, 20, 20).cuda()

    p = OrderedDict()  # Parameters to include in report
    p['trainBatch'] = opts.batch  # Training batch size
    p['nAveGrad'] = 1  # Average the gradient of several iterations
    p['lr'] = opts.lr  # Learning rate
    p['lrFtr'] = 1e-5
    p['lraspp'] = 1e-5
    p['lrpro'] = 1e-5
    p['lrdecoder'] = 1e-5
    p['lrother'] = 1e-5
    p['wd'] = 5e-4  # Weight decay
    p['momentum'] = 0.9  # Momentum
    p['epoch_size'] = 10  # How many epochs to change learning rate
    p['num_workers'] = opts.numworker
    backbone = 'xception'  # Use xception or resnet as feature extractor,

    # One image name per line; used below to name the output files.
    with open(opts.txt_file, 'r') as f:
        img_list = f.readlines()

    # Determine the next free run id under run/ (max existing id + 1).
    max_id = 0
    save_dir_root = os.path.join(os.path.dirname(os.path.abspath(__file__)))
    exp_name = os.path.dirname(os.path.abspath(__file__)).split('/')[-1]
    runs = glob.glob(os.path.join(save_dir_root, 'run', 'run_*'))
    for r in runs:
        run_id = int(r.split('_')[-1])
        if run_id >= max_id:
            max_id = run_id + 1
    # run_id = int(runs[-1].split('_')[-1]) + 1 if runs else 0

    # Network definition
    if backbone == 'xception':
        net = deeplab_xception_transfer.deeplab_xception_transfer_projection(
            n_classes=opts.classes,
            os=16,
            hidden_layers=opts.hidden_layers,
            source_classes=20,
        )
    elif backbone == 'resnet':
        # net = deeplab_resnet.DeepLabv3_plus(nInputChannels=3, n_classes=7, os=16, pretrained=True)
        raise NotImplementedError
    else:
        raise NotImplementedError

    if gpu_id >= 0:
        net.cuda()

    # net load weights
    if not opts.loadmodel == '':
        x = torch.load(opts.loadmodel)
        net.load_source_model(x)
        print('load model:', opts.loadmodel)
    else:
        print('no model load !!!!!!!!')

    ## multi scale
    # Build one (plain, flipped) loader pair per scale; batch size is 1.
    scale_list = [1, 0.5, 0.75, 1.25, 1.5, 1.75]
    testloader_list = []
    testloader_flip_list = []
    for pv in scale_list:
        composed_transforms_ts = transforms.Compose(
            [tr.Scale_(pv),
             tr.Normalize_xception_tf(),
             tr.ToTensor_()])

        composed_transforms_ts_flip = transforms.Compose([
            tr.Scale_(pv),
            tr.HorizontalFlip(),
            tr.Normalize_xception_tf(),
            tr.ToTensor_()
        ])

        voc_val = pascal.VOCSegmentation(split='val',
                                         transform=composed_transforms_ts)
        voc_val_f = pascal.VOCSegmentation(
            split='val', transform=composed_transforms_ts_flip)

        testloader = DataLoader(voc_val,
                                batch_size=1,
                                shuffle=False,
                                num_workers=p['num_workers'])
        testloader_flip = DataLoader(voc_val_f,
                                     batch_size=1,
                                     shuffle=False,
                                     num_workers=p['num_workers'])

        testloader_list.append(copy.deepcopy(testloader))
        testloader_flip_list.append(copy.deepcopy(testloader_flip))

    print("Eval Network")

    if not os.path.exists(opts.output_path + 'pascal_output_vis/'):
        os.makedirs(opts.output_path + 'pascal_output_vis/')
    if not os.path.exists(opts.output_path + 'pascal_output/'):
        os.makedirs(opts.output_path + 'pascal_output/')

    start_time = timeit.default_timer()
    # One testing epoch
    total_iou = 0.0
    net.eval()
    # Iterate all loaders in lockstep: for each image we get one batch per
    # scale plus one per flipped scale.
    for ii, large_sample_batched in enumerate(
            zip(*testloader_list, *testloader_flip_list)):
        print(ii)
        #1 0.5 0.75 1.25 1.5 1.75 ; flip:
        sample1 = large_sample_batched[:6]
        sample2 = large_sample_batched[6:]
        for iii, sample_batched in enumerate(zip(sample1, sample2)):
            inputs, labels = sample_batched[0]['image'], sample_batched[0][
                'label']
            inputs_f, _ = sample_batched[1]['image'], sample_batched[1][
                'label']
            # Batch the plain and flipped images into one forward pass.
            inputs = torch.cat((inputs, inputs_f), dim=0)
            if iii == 0:
                # Remember the scale-1 spatial size; later scales are
                # upsampled back to (h, w) before accumulation.
                _, _, h, w = inputs.size()
            # assert inputs.size() == inputs_f.size()

            # Forward pass of the mini-batch
            inputs, labels = Variable(inputs,
                                      requires_grad=False), Variable(labels)

            with torch.no_grad():
                if gpu_id >= 0:
                    inputs, labels = inputs.cuda(), labels.cuda()
                # outputs = net.forward(inputs)
                # pdb.set_trace()
                outputs = net.forward(inputs, adj1_test.cuda(),
                                      adj3_test.cuda(), adj2_test.cuda())
                # Average the plain output with the un-flipped flip output.
                outputs = (outputs[0] + flip(outputs[1], dim=-1)) / 2
                outputs = outputs.unsqueeze(0)

                if iii > 0:
                    # Accumulate per-scale outputs (sum, not mean) at the
                    # reference resolution.
                    outputs = F.upsample(outputs,
                                         size=(h, w),
                                         mode='bilinear',
                                         align_corners=True)
                    outputs_final = outputs_final + outputs
                else:
                    outputs_final = outputs.clone()
        ################ plot pic
        predictions = torch.max(outputs_final, 1)[1]
        prob_predictions = torch.max(outputs_final, 1)[0]
        results = predictions.cpu().numpy()
        prob_results = prob_predictions.cpu().numpy()
        vis_res = decode_labels(results)

        # img_list entries keep their trailing newline; [:-1] strips it.
        parsing_im = Image.fromarray(vis_res[0])
        parsing_im.save(opts.output_path +
                        'pascal_output_vis/{}.png'.format(img_list[ii][:-1]))
        cv2.imwrite(
            opts.output_path +
            'pascal_output/{}.png'.format(img_list[ii][:-1]), results[0, :, :])
        # np.save('../../cihp_prob_output/{}.npy'.format(img_list[ii][:-1]), prob_results[0, :, :])
        # pred_list.append(predictions.cpu())
        # label_list.append(labels.squeeze(1).cpu())
        # loss = criterion(outputs, labels, batch_average=True)
        # running_loss_ts += loss.item()

        # total_iou += utils.get_iou(predictions, labels)
    end_time = timeit.default_timer()
    print('time use for ' + str(ii) + ' is :' + str(end_time - start_time))

    # Eval
    pred_path = opts.output_path + 'pascal_output/'
    eval_(pred_path=pred_path,
          gt_path=opts.gt_path,
          classes=opts.classes,
          txt_file=opts.txt_file)
예제 #9
0
# Training pipeline: flip / scale-rotate augmentation, random resized crop,
# ImageNet normalisation, tensor conversion.
composed_transforms_tr = standard_transforms.Compose([
    tr.RandomHorizontalFlip(),
    tr.ScaleNRotate(rots=(-15, 15), scales=(.75, 1.5)),
    tr.RandomResizedCrop(img_size),
    tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    tr.ToTensor()
])

# NOTE(review): the test pipeline also uses RandomResizedCrop, which is a
# randomized transform — confirm that non-deterministic cropping at test
# time is intended.
composed_transforms_ts = standard_transforms.Compose([
    tr.RandomResizedCrop(img_size),
    tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    tr.ToTensor()
])

voc_train = pascal.VOCSegmentation(base_dir=data_file,
                                   split='train',
                                   transform=composed_transforms_tr)
trainloader = DataLoader(voc_train,
                         batch_size=opt.b,
                         shuffle=True,
                         num_workers=1)
# Directory where model checkpoints are stored.
model_dir = './pth/'


def find_new_file(dir):
    if os.path.exists(dir) is False:
        os.mkdir(model_dir)
        dir = model_dir

    file_lists = os.listdir(dir)
    file_lists.sort(key=lambda fn: os.path.getmtime(dir + fn)
예제 #10
0
def main(opts):
	"""Train a GrapyMutualLearning parsing network on CIHP/PASCAL/ATR.

	Builds the dataset/loader combination selected by ``opts.train_mode``,
	optionally loads or resumes weights, then runs the main training loop
	with gradient averaging, TensorBoard logging, periodic validation on the
	chosen val split, and per-epoch checkpointing (current + best-mIoU).
	"""
	# Some of the settings are not used
	p = OrderedDict()  # Parameters to include in report
	p['trainBatch'] = opts.batch  # Training batch size
	testBatch = 1  # Testing batch size
	useTest = True  # See evolution of the test set when training
	nTestInterval = opts.testInterval  # Run on test set every nTestInterval epochs
	snapshot = 1  # Store a model every snapshot epochs
	p['nAveGrad'] = 1  # Average the gradient of several iterations
	p['lr'] = opts.lr  # Learning rate
	p['lrFtr'] = 1e-5
	p['lraspp'] = 1e-5
	p['lrpro'] = 1e-5
	p['lrdecoder'] = 1e-5
	p['lrother'] = 1e-5
	p['wd'] = 5e-4  # Weight decay
	p['momentum'] = 0.9  # Momentum
	p['epoch_size'] = opts.step  # How many epochs to change learning rate
	p['num_workers'] = opts.numworker
	backbone = 'xception'  # Use xception or resnet as feature extractor,
	nEpochs = opts.epochs

	resume_epoch = opts.resume_epoch

	# Next free run id under run_cihp/ (max existing id + 1).
	max_id = 0
	save_dir_root = os.path.join(os.path.dirname(os.path.abspath(__file__)))
	exp_name = os.path.dirname(os.path.abspath(__file__)).split('/')[-1]
	runs = glob.glob(os.path.join(save_dir_root, 'run_cihp', 'run_*'))
	for r in runs:
		run_id = int(r.split('_')[-1])
		if run_id >= max_id:
			max_id = run_id + 1
	save_dir = os.path.join(save_dir_root, 'run_cihp', 'run_' + str(max_id))

	print(save_dir)

	# Network definition
	net_ = grapy_net.GrapyMutualLearning(os=16, hidden_layers=opts.hidden_graph_layers)

	modelName = 'deeplabv3plus-' + backbone + '-voc' + datetime.now().strftime('%b%d_%H-%M-%S')
	criterion = util.cross_entropy2d

	# TensorBoard writer; record the load-model path and the launch command.
	log_dir = os.path.join(save_dir, 'models', datetime.now().strftime('%b%d_%H-%M-%S') + '_' + socket.gethostname())
	writer = SummaryWriter(log_dir=log_dir)
	writer.add_text('load model', opts.loadmodel, 1)
	writer.add_text('setting', sys.argv[0], 1)

	# Use the following optimizer
	optimizer = optim.SGD(net_.parameters(), lr=p['lr'], momentum=p['momentum'], weight_decay=p['wd'])

	# Train: random resize augmentation; test: plain normalise (+flip variant).
	composed_transforms_tr = transforms.Compose([
		tr.RandomSized_new(512),
		tr.Normalize_xception_tf(),
		tr.ToTensor_()])

	composed_transforms_ts = transforms.Compose([
		tr.Normalize_xception_tf(),
		tr.ToTensor_()])

	composed_transforms_ts_flip = transforms.Compose([
		tr.HorizontalFlip(),
		tr.Normalize_xception_tf(),
		tr.ToTensor_()])

	# Select datasets, sampler and loaders according to the training mode.
	if opts.train_mode == 'cihp_pascal_atr':
		all_train = cihp_pascal_atr.VOCSegmentation(split='train', transform=composed_transforms_tr, flip=True)
		num_cihp, num_pascal, num_atr = all_train.get_class_num()

		voc_val = atr.VOCSegmentation(split='val', transform=composed_transforms_ts)
		voc_val_flip = atr.VOCSegmentation(split='val', transform=composed_transforms_ts_flip)

		# Custom sampler over the three concatenated datasets; shuffle must
		# stay False because the sampler controls the order.
		ss = sam.Sampler_uni(num_cihp, num_pascal, num_atr, opts.batch)

		trainloader = DataLoader(all_train, batch_size=p['trainBatch'], shuffle=False, num_workers=18, sampler=ss, drop_last=True)

	elif opts.train_mode == 'cihp_pascal_atr_1_1_1':
		all_train = cihp_pascal_atr.VOCSegmentation(split='train', transform=composed_transforms_tr, flip=True)
		num_cihp, num_pascal, num_atr = all_train.get_class_num()

		voc_val = pascal.VOCSegmentation(split='val', transform=composed_transforms_ts)
		voc_val_flip = pascal.VOCSegmentation(split='val', transform=composed_transforms_ts_flip)

		# balance_id=1 — presumably a 1:1:1 dataset balancing scheme; see
		# Sampler_uni for the exact semantics.
		ss_uni = sam.Sampler_uni(num_cihp, num_pascal, num_atr, opts.batch, balance_id=1)

		trainloader = DataLoader(all_train, batch_size=p['trainBatch'], shuffle=False, num_workers=1, sampler=ss_uni, drop_last=True)

	elif opts.train_mode == 'cihp':
		voc_train = cihp.VOCSegmentation(split='train', transform=composed_transforms_tr, flip=True)
		voc_val = cihp.VOCSegmentation(split='val', transform=composed_transforms_ts)
		voc_val_flip = cihp.VOCSegmentation(split='val', transform=composed_transforms_ts_flip)

		trainloader = DataLoader(voc_train, batch_size=p['trainBatch'], shuffle=True, num_workers=18, drop_last=True)

	elif opts.train_mode == 'pascal':

		# here we train without flip but test with flip
		voc_train = pascal_flip.VOCSegmentation(split='train', transform=composed_transforms_tr)
		voc_val = pascal.VOCSegmentation(split='val', transform=composed_transforms_ts)
		voc_val_flip = pascal.VOCSegmentation(split='val', transform=composed_transforms_ts_flip)

		trainloader = DataLoader(voc_train, batch_size=p['trainBatch'], shuffle=True, num_workers=18, drop_last=True)

	elif opts.train_mode == 'atr':

		# here we train without flip but test with flip
		voc_train = atr.VOCSegmentation(split='train', transform=composed_transforms_tr, flip=True)
		voc_val = atr.VOCSegmentation(split='val', transform=composed_transforms_ts)
		voc_val_flip = atr.VOCSegmentation(split='val', transform=composed_transforms_ts_flip)

		trainloader = DataLoader(voc_train, batch_size=p['trainBatch'], shuffle=True, num_workers=18, drop_last=True)

	else:
		raise NotImplementedError

	# Load pre-trained source weights (if given) and/or resume a checkpoint.
	if not opts.loadmodel == '':
		x = torch.load(opts.loadmodel)
		net_.load_state_dict_new(x, strict=False)
		print('load model:', opts.loadmodel)
	else:
		print('no model load !!!!!!!!')

	if not opts.resume_model == '':
		x = torch.load(opts.resume_model)
		net_.load_state_dict(x)
		print('resume model:', opts.resume_model)

	else:
		print('we are not resuming from any model')

	# We only validate on pascal dataset to save time
	testloader = DataLoader(voc_val, batch_size=testBatch, shuffle=False, num_workers=3)
	testloader_flip = DataLoader(voc_val_flip, batch_size=testBatch, shuffle=False, num_workers=3)

	num_img_tr = len(trainloader)
	num_img_ts = len(testloader)

	# Set the category relations
	c1, c2, p1, p2, a1, a2 = [[0], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]],\
							 [[0], [1, 2, 4, 13], [5, 6, 7, 10, 11, 12], [3, 14, 15], [8, 9, 16, 17, 18, 19]], \
							 [[0], [1, 2, 3, 4, 5, 6]], [[0], [1], [2], [3, 4], [5, 6]], [[0], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17]],\
							 [[0], [1, 2, 3, 11], [4, 5, 7, 8, 16, 17], [14, 15], [6, 9, 10, 12, 13]]

	net_.set_category_list(c1, c2, p1, p2, a1, a2)
	if gpu_id >= 0:
		# torch.cuda.set_device(device=gpu_id)
		net_.cuda()

	# Running-loss accumulators and best-mIoU bookkeeping.
	running_loss_tr = 0.0
	running_loss_ts = 0.0

	running_loss_tr_main = 0.0
	running_loss_tr_aux = 0.0
	aveGrad = 0
	global_step = 0
	miou = 0
	cur_miou = 0
	print("Training Network")

	# DataParallel wrapper for forward passes; state is saved from net_.
	net = torch.nn.DataParallel(net_)

	# Main Training and Testing Loop
	for epoch in range(resume_epoch, nEpochs):
		start_time = timeit.default_timer()

		# Polynomial LR decay: rebuild the SGD optimizer with the decayed LR
		# every epoch_size epochs.
		if opts.poly:
			if epoch % p['epoch_size'] == p['epoch_size'] - 1:
				lr_ = util.lr_poly(p['lr'], epoch, nEpochs, 0.9)
				optimizer = optim.SGD(net_.parameters(), lr=lr_, momentum=p['momentum'], weight_decay=p['wd'])
				writer.add_scalar('data/lr_', lr_, epoch)
				print('(poly lr policy) learning rate: ', lr_)

		net.train()
		for ii, sample_batched in enumerate(trainloader):

			inputs, labels = sample_batched['image'], sample_batched['label']
			# Forward-Backward of the mini-batch
			inputs, labels = Variable(inputs, requires_grad=True), Variable(labels)
			global_step += inputs.data.shape[0]

			if gpu_id >= 0:
				inputs, labels = inputs.cuda(), labels.cuda()

			# The network needs to know which dataset the batch comes from:
			# 0 = cihp, 1 = pascal, 2 = atr.
			if opts.train_mode == 'cihp_pascal_atr' or opts.train_mode == 'cihp_pascal_atr_1_1_1':
				num_dataset_lbl = sample_batched['pascal'][0].item()

			elif opts.train_mode == 'cihp':
				num_dataset_lbl = 0

			elif opts.train_mode == 'pascal':
				num_dataset_lbl = 1

			else:
				num_dataset_lbl = 2

			outputs, outputs_aux = net.forward((inputs, num_dataset_lbl))

			# print(inputs.shape, labels.shape, outputs.shape, outputs_aux.shape)

			# Weighted sum of the main and auxiliary cross-entropy losses.
			loss_main = criterion(outputs, labels, batch_average=True)
			loss_aux = criterion(outputs_aux, labels, batch_average=True)

			loss = opts.beta_main * loss_main + opts.beta_aux * loss_aux

			running_loss_tr_main += loss_main.item()
			running_loss_tr_aux += loss_aux.item()
			running_loss_tr += loss.item()

			# Print stuff
			if ii % num_img_tr == (num_img_tr - 1):
				running_loss_tr = running_loss_tr / num_img_tr
				running_loss_tr_aux = running_loss_tr_aux / num_img_tr
				running_loss_tr_main = running_loss_tr_main / num_img_tr

				writer.add_scalar('data/total_loss_epoch', running_loss_tr, epoch)

				writer.add_scalars('data/scalar_group', {'loss': running_loss_tr_main,
														 'loss_aux': running_loss_tr_aux}, epoch)

				print('[Epoch: %d, numImages: %5d]' % (epoch, ii * p['trainBatch'] + inputs.data.shape[0]))
				print('Loss: %f' % running_loss_tr)
				running_loss_tr = 0
				stop_time = timeit.default_timer()
				print("Execution time: " + str(stop_time - start_time) + "\n")

			# Backward the averaged gradient
			loss /= p['nAveGrad']
			loss.backward()
			aveGrad += 1

			# Update the weights once in p['nAveGrad'] forward passes
			if aveGrad % p['nAveGrad'] == 0:
				writer.add_scalar('data/total_loss_iter', loss.item(), ii + num_img_tr * epoch)

				if num_dataset_lbl == 0:
					writer.add_scalar('data/total_loss_iter_cihp', loss.item(), global_step)
				if num_dataset_lbl == 1:
					writer.add_scalar('data/total_loss_iter_pascal', loss.item(), global_step)
				if num_dataset_lbl == 2:
					writer.add_scalar('data/total_loss_iter_atr', loss.item(), global_step)

				optimizer.step()
				optimizer.zero_grad()
				aveGrad = 0

			# Show 10 * 3 images results each
			# print(ii, (num_img_tr * 10), (ii % (num_img_tr * 10) == 0))
			if ii % (num_img_tr * 10) == 0:
				grid_image = make_grid(inputs[:3].clone().cpu().data, 3, normalize=True)
				writer.add_image('Image', grid_image, global_step)
				grid_image = make_grid(
					util.decode_seg_map_sequence(torch.max(outputs[:3], 1)[1].detach().cpu().numpy()), 3,
					normalize=False,
					range=(0, 255))
				writer.add_image('Predicted label', grid_image, global_step)
				grid_image = make_grid(
					util.decode_seg_map_sequence(torch.squeeze(labels[:3], 1).detach().cpu().numpy()), 3,
					normalize=False, range=(0, 255))
				writer.add_image('Groundtruth label', grid_image, global_step)
			print('loss is ', loss.cpu().item(), flush=True)

		# Save the model
		# One testing epoch
		if useTest and epoch % nTestInterval == (nTestInterval - 1):

			cur_miou = validation(net_, testloader=testloader, testloader_flip=testloader_flip, classes=opts.classes,
								epoch=epoch, writer=writer, criterion=criterion, dataset=opts.train_mode)

		torch.cuda.empty_cache()

		# Checkpoint every `snapshot` epochs: always overwrite the "current"
		# model, and additionally keep the best-mIoU model.
		if (epoch % snapshot) == snapshot - 1:

			torch.save(net_.state_dict(), os.path.join(save_dir, 'models', modelName + '_epoch' + '_current' + '.pth'))
			print("Save model at {}\n".format(
				os.path.join(save_dir, 'models', modelName + str(epoch) + '_epoch-' + str(epoch) + '.pth as our current model')))

			if cur_miou > miou:
				miou = cur_miou
				torch.save(net_.state_dict(), os.path.join(save_dir, 'models', modelName + '_best' + '.pth'))
				print("Save model at {}\n".format(
					os.path.join(save_dir, 'models', modelName + '_epoch-' + str(epoch) + '.pth as our best model')))

		torch.cuda.empty_cache()
def main(opts):
    """Jointly train deeplab_xception_end2end_3d on CIHP / PASCAL / ATR.

    Phase 1 (epoch < nEpochs) samples the mixed dataset uniformly; phase 2
    (nEpochs <= epoch < 1.5 * nEpochs) continues with a PASCAL-balanced
    sampler.  Checkpoints, TensorBoard logs and models go under
    ./run/run_<id>/.

    NOTE(review): this function reads module-level globals `gpu_id` and
    `resume_epoch` that are not defined here -- confirm both exist at
    module scope before running.
    """
    # Set parameters
    p = OrderedDict()  # Parameters to include in report
    p['trainBatch'] = opts.batch  # Training batch size
    testBatch = 1  # Testing batch size
    useTest = True  # See evolution of the test set when training
    nTestInterval = opts.testInterval # Run on test set every nTestInterval epochs
    snapshot = 1  # Store a model every snapshot epochs
    p['nAveGrad'] = 1  # Average the gradient of several iterations
    p['lr'] = opts.lr  # Learning rate
    p['wd'] = 5e-4  # Weight decay
    p['momentum'] = 0.9  # Momentum
    p['epoch_size'] = opts.step  # How many epochs to change learning rate
    p['num_workers'] = opts.numworker
    model_path = opts.pretrainedModel  # backbone-only pretrained weights (optional)
    backbone = 'xception' # Use xception or resnet as feature extractor
    nEpochs = opts.epochs

    # Pick the next free run directory: max_id ends as
    # (largest existing run_* index) + 1, or 0 when no run_* dirs exist.
    max_id = 0
    save_dir_root = os.path.join(os.path.dirname(os.path.abspath(__file__)))
    exp_name = os.path.dirname(os.path.abspath(__file__)).split('/')[-1]
    runs = glob.glob(os.path.join(save_dir_root, 'run', 'run_*'))
    for r in runs:
        run_id = int(r.split('_')[-1])
        if run_id >= max_id:
            max_id = run_id + 1
    # run_id = int(runs[-1].split('_')[-1]) + 1 if runs else 0
    save_dir = os.path.join(save_dir_root, 'run', 'run_' + str(max_id))

    # Network definition: 20-class target head (CIHP), 7-class source head
    # (PASCAL person parts), 18-class middle head (ATR).
    if backbone == 'xception':
        net_ = deeplab_xception_universal.deeplab_xception_end2end_3d(n_classes=20, os=16,
                                                                      hidden_layers=opts.hidden_layers,
                                                                      source_classes=7,
                                                                      middle_classes=18, )
    elif backbone == 'resnet':
        # net_ = deeplab_resnet.DeepLabv3_plus(nInputChannels=3, n_classes=7, os=16, pretrained=True)
        raise NotImplementedError
    else:
        raise NotImplementedError

    modelName = 'deeplabv3plus-' + backbone + '-voc'+datetime.now().strftime('%b%d_%H-%M-%S')
    criterion = ut.cross_entropy2d

    if gpu_id >= 0:
        # torch.cuda.set_device(device=gpu_id)
        net_.cuda()

    # Load weights: a backbone-pretrained checkpoint and a previously trained
    # model go through different loader helpers on the network object.
    if not model_path == '':
        x = torch.load(model_path)
        net_.load_state_dict_new(x)
        print('load pretrainedModel.')
    else:
        print('no pretrainedModel.')

    if not opts.loadmodel =='':
        x = torch.load(opts.loadmodel)
        net_.load_source_model(x)
        print('load model:' ,opts.loadmodel)
    else:
        print('no trained model load !!!!!!!!')

    log_dir = os.path.join(save_dir, 'models', datetime.now().strftime('%b%d_%H-%M-%S') + '_' + socket.gethostname())
    writer = SummaryWriter(log_dir=log_dir)
    writer.add_text('load model',opts.loadmodel,1)
    writer.add_text('setting',sys.argv[0],1)

    # Use the following optimizer
    optimizer = optim.SGD(net_.parameters(), lr=p['lr'], momentum=p['momentum'], weight_decay=p['wd'])

    # Train-time augmentation; validation uses a plain and a horizontally
    # flipped copy of each image.
    composed_transforms_tr = transforms.Compose([
            tr.RandomSized_new(512),
            tr.Normalize_xception_tf(),
            tr.ToTensor_()])

    composed_transforms_ts = transforms.Compose([
        tr.Normalize_xception_tf(),
        tr.ToTensor_()])

    composed_transforms_ts_flip = transforms.Compose([
        tr.HorizontalFlip(),
        tr.Normalize_xception_tf(),
        tr.ToTensor_()])

    all_train = cihp_pascal_atr.VOCSegmentation(split='train', transform=composed_transforms_tr, flip=True)
    voc_val = pascal.VOCSegmentation(split='val', transform=composed_transforms_ts)
    voc_val_flip = pascal.VOCSegmentation(split='val', transform=composed_transforms_ts_flip)

    num_cihp,num_pascal,num_atr = all_train.get_class_num()
    # Uniform sampler over the three datasets for phase 1.
    ss = sam.Sampler_uni(num_cihp,num_pascal,num_atr,opts.batch)
    # balance datasets based pascal
    ss_balanced = sam.Sampler_uni(num_cihp,num_pascal,num_atr,opts.batch, balance_id=1)

    trainloader = DataLoader(all_train, batch_size=p['trainBatch'], shuffle=False, num_workers=p['num_workers'],
                             sampler=ss, drop_last=True)
    trainloader_balanced = DataLoader(all_train, batch_size=p['trainBatch'], shuffle=False, num_workers=p['num_workers'],
                             sampler=ss_balanced, drop_last=True)
    testloader = DataLoader(voc_val, batch_size=testBatch, shuffle=False, num_workers=p['num_workers'])
    testloader_flip = DataLoader(voc_val_flip, batch_size=testBatch, shuffle=False, num_workers=p['num_workers'])

    num_img_tr = len(trainloader)
    num_img_balanced = len(trainloader_balanced)
    num_img_ts = len(testloader)
    running_loss_tr = 0.0
    running_loss_tr_atr = 0.0
    running_loss_ts = 0.0
    aveGrad = 0
    global_step = 0
    print("Training Network")
    net = torch.nn.DataParallel(net_)

    id_list = torch.LongTensor(range(opts.batch))  # NOTE(review): unused below
    pascal_iter = int(num_img_tr//opts.batch)  # NOTE(review): unused below

    # Adjacency matrices for the graph-transfer heads (train/test variants).
    train_graph, test_graph = get_graphs(opts)
    adj1, adj2, adj3, adj4, adj5, adj6 = train_graph
    adj1_test, adj2_test, adj3_test, adj4_test, adj5_test, adj6_test = test_graph

    # Main Training and Testing Loop
    for epoch in range(resume_epoch, int(1.5*nEpochs)):
        start_time = timeit.default_timer()

        # Poly learning-rate decay, restarted for the balanced phase.
        # NOTE(review): when epoch == nEpochs neither branch matches
        # (`< nEpochs` then `> nEpochs`), so the lr is not updated at
        # exactly that boundary epoch -- confirm this is intentional.
        if epoch % p['epoch_size'] == p['epoch_size'] - 1 and epoch<nEpochs:
            lr_ = ut.lr_poly(p['lr'], epoch, nEpochs, 0.9)
            optimizer = optim.SGD(net_.parameters(), lr=lr_, momentum=p['momentum'], weight_decay=p['wd'])
            print('(poly lr policy) learning rate: ', lr_)
            writer.add_scalar('data/lr_',lr_,epoch)
        elif epoch % p['epoch_size'] == p['epoch_size'] - 1 and epoch > nEpochs:
            lr_ = ut.lr_poly(p['lr'], epoch-nEpochs, int(0.5*nEpochs), 0.9)
            optimizer = optim.SGD(net_.parameters(), lr=lr_, momentum=p['momentum'], weight_decay=p['wd'])
            print('(poly lr policy) learning rate: ', lr_)
            writer.add_scalar('data/lr_', lr_, epoch)

        net_.train()
        if epoch < nEpochs:
            # Phase 1: uniform sampling over CIHP + PASCAL + ATR.
            for ii, sample_batched in enumerate(trainloader):
                inputs, labels = sample_batched['image'], sample_batched['label']
                # Dataset id of this (homogeneous) batch: 0=CIHP (target),
                # 1=PASCAL (source), 2=ATR (middle) -- see branches below.
                dataset_lbl = sample_batched['pascal'][0].item()
                # Forward-Backward of the mini-batch
                # NOTE(review): `Variable` is deprecated in modern PyTorch;
                # plain tensors behave the same here.
                inputs, labels = Variable(inputs, requires_grad=True), Variable(labels)
                global_step += 1

                if gpu_id >= 0:
                    inputs, labels = inputs.cuda(), labels.cuda()

                # Route the batch through the matching head of the 3-branch
                # network; the other two inputs are passed as None.
                if dataset_lbl == 0:
                    # 0 is cihp -- target
                    _, outputs,_ = net.forward(None, input_target=inputs, input_middle=None, adj1_target=adj1, adj2_source=adj2,
                        adj3_transfer_s2t=adj3, adj3_transfer_t2s=adj3.transpose(2,3), adj4_middle=adj4,adj5_transfer_s2m=adj5.transpose(2, 3),
                        adj6_transfer_t2m=adj6.transpose(2, 3),adj5_transfer_m2s=adj5,adj6_transfer_m2t=adj6,)
                elif dataset_lbl == 1:
                    # pascal is source
                    outputs, _, _ = net.forward(inputs, input_target=None, input_middle=None, adj1_target=adj1,
                                                adj2_source=adj2,
                                                adj3_transfer_s2t=adj3, adj3_transfer_t2s=adj3.transpose(2, 3),
                                                adj4_middle=adj4, adj5_transfer_s2m=adj5.transpose(2, 3),
                                                adj6_transfer_t2m=adj6.transpose(2, 3), adj5_transfer_m2s=adj5,
                                                adj6_transfer_m2t=adj6, )
                else:
                    # atr
                    _, _, outputs = net.forward(None, input_target=None, input_middle=inputs, adj1_target=adj1,
                                                adj2_source=adj2,
                                                adj3_transfer_s2t=adj3, adj3_transfer_t2s=adj3.transpose(2, 3),
                                                adj4_middle=adj4, adj5_transfer_s2m=adj5.transpose(2, 3),
                                                adj6_transfer_t2m=adj6.transpose(2, 3), adj5_transfer_m2s=adj5,
                                                adj6_transfer_m2t=adj6, )
                # print(sample_batched['pascal'])
                # print(outputs.size(),)
                # print(labels)
                loss = criterion(outputs, labels,  batch_average=True)
                running_loss_tr += loss.item()

                # End-of-epoch average loss report.
                if ii % num_img_tr == (num_img_tr - 1):
                    running_loss_tr = running_loss_tr / num_img_tr
                    writer.add_scalar('data/total_loss_epoch', running_loss_tr, epoch)
                    # NOTE(review): prints `epoch` twice; the second field is
                    # labelled numImages but holds the epoch number.
                    print('[Epoch: %d, numImages: %5d]' % (epoch, epoch))
                    print('Loss: %f' % running_loss_tr)
                    running_loss_tr = 0
                    stop_time = timeit.default_timer()
                    print("Execution time: " + str(stop_time - start_time) + "\n")

                # Backward the averaged gradient
                loss /= p['nAveGrad']
                loss.backward()
                aveGrad += 1

                # Update the weights once in p['nAveGrad'] forward passes
                if aveGrad % p['nAveGrad'] == 0:
                    writer.add_scalar('data/total_loss_iter', loss.item(), global_step)
                    if dataset_lbl == 0:
                        writer.add_scalar('data/total_loss_iter_cihp', loss.item(), global_step)
                    if dataset_lbl == 1:
                        writer.add_scalar('data/total_loss_iter_pascal', loss.item(), global_step)
                    if dataset_lbl == 2:
                        writer.add_scalar('data/total_loss_iter_atr', loss.item(), global_step)
                    optimizer.step()
                    optimizer.zero_grad()
                    # optimizer_gcn.step()
                    # optimizer_gcn.zero_grad()
                    aveGrad = 0

                # Show 10 * 3 images results each epoch
                # NOTE(review): raises ZeroDivisionError if num_img_tr < 10.
                if ii % (num_img_tr // 10) == 0:
                    grid_image = make_grid(inputs[:3].clone().cpu().data, 3, normalize=True)
                    writer.add_image('Image', grid_image, global_step)
                    grid_image = make_grid(ut.decode_seg_map_sequence(torch.max(outputs[:3], 1)[1].detach().cpu().numpy()), 3, normalize=False,
                                           range=(0, 255))
                    writer.add_image('Predicted label', grid_image, global_step)
                    grid_image = make_grid(ut.decode_seg_map_sequence(torch.squeeze(labels[:3], 1).detach().cpu().numpy()), 3, normalize=False, range=(0, 255))
                    writer.add_image('Groundtruth label', grid_image, global_step)

                print('loss is ',loss.cpu().item(),flush=True)
        else:
            # Phase 2: same loop body, but sampling is balanced toward PASCAL.
            for ii, sample_batched in enumerate(trainloader_balanced):
                inputs, labels = sample_batched['image'], sample_batched['label']
                dataset_lbl = sample_batched['pascal'][0].item()
                # Forward-Backward of the mini-batch
                inputs, labels = Variable(inputs, requires_grad=True), Variable(labels)
                global_step += 1

                if gpu_id >= 0:
                    inputs, labels = inputs.cuda(), labels.cuda()

                if dataset_lbl == 0:
                    # 0 is cihp -- target
                    _, outputs, _ = net.forward(None, input_target=inputs, input_middle=None, adj1_target=adj1,
                                                adj2_source=adj2,
                                                adj3_transfer_s2t=adj3, adj3_transfer_t2s=adj3.transpose(2, 3),
                                                adj4_middle=adj4, adj5_transfer_s2m=adj5.transpose(2, 3),
                                                adj6_transfer_t2m=adj6.transpose(2, 3), adj5_transfer_m2s=adj5,
                                                adj6_transfer_m2t=adj6, )
                elif dataset_lbl == 1:
                    # pascal is source
                    outputs, _, _ = net.forward(inputs, input_target=None, input_middle=None, adj1_target=adj1,
                                                adj2_source=adj2,
                                                adj3_transfer_s2t=adj3, adj3_transfer_t2s=adj3.transpose(2, 3),
                                                adj4_middle=adj4, adj5_transfer_s2m=adj5.transpose(2, 3),
                                                adj6_transfer_t2m=adj6.transpose(2, 3), adj5_transfer_m2s=adj5,
                                                adj6_transfer_m2t=adj6, )
                else:
                    # atr
                    _, _, outputs = net.forward(None, input_target=None, input_middle=inputs, adj1_target=adj1,
                                                adj2_source=adj2,
                                                adj3_transfer_s2t=adj3, adj3_transfer_t2s=adj3.transpose(2, 3),
                                                adj4_middle=adj4, adj5_transfer_s2m=adj5.transpose(2, 3),
                                                adj6_transfer_t2m=adj6.transpose(2, 3), adj5_transfer_m2s=adj5,
                                                adj6_transfer_m2t=adj6, )
                # print(sample_batched['pascal'])
                # print(outputs.size(),)
                # print(labels)
                loss = criterion(outputs, labels, batch_average=True)
                running_loss_tr += loss.item()

                # End-of-epoch average loss report.
                if ii % num_img_balanced == (num_img_balanced - 1):
                    running_loss_tr = running_loss_tr / num_img_balanced
                    writer.add_scalar('data/total_loss_epoch', running_loss_tr, epoch)
                    # NOTE(review): same epoch/numImages mismatch as phase 1.
                    print('[Epoch: %d, numImages: %5d]' % (epoch, epoch))
                    print('Loss: %f' % running_loss_tr)
                    running_loss_tr = 0
                    stop_time = timeit.default_timer()
                    print("Execution time: " + str(stop_time - start_time) + "\n")

                # Backward the averaged gradient
                loss /= p['nAveGrad']
                loss.backward()
                aveGrad += 1

                # Update the weights once in p['nAveGrad'] forward passes
                if aveGrad % p['nAveGrad'] == 0:
                    writer.add_scalar('data/total_loss_iter', loss.item(), global_step)
                    if dataset_lbl == 0:
                        writer.add_scalar('data/total_loss_iter_cihp', loss.item(), global_step)
                    if dataset_lbl == 1:
                        writer.add_scalar('data/total_loss_iter_pascal', loss.item(), global_step)
                    if dataset_lbl == 2:
                        writer.add_scalar('data/total_loss_iter_atr', loss.item(), global_step)
                    optimizer.step()
                    optimizer.zero_grad()

                    aveGrad = 0

                # Show 10 * 3 images results each epoch
                # NOTE(review): raises ZeroDivisionError if num_img_balanced < 10.
                if ii % (num_img_balanced // 10) == 0:
                    grid_image = make_grid(inputs[:3].clone().cpu().data, 3, normalize=True)
                    writer.add_image('Image', grid_image, global_step)
                    grid_image = make_grid(
                        ut.decode_seg_map_sequence(torch.max(outputs[:3], 1)[1].detach().cpu().numpy()), 3,
                        normalize=False,
                        range=(0, 255))
                    writer.add_image('Predicted label', grid_image, global_step)
                    grid_image = make_grid(
                        ut.decode_seg_map_sequence(torch.squeeze(labels[:3], 1).detach().cpu().numpy()), 3,
                        normalize=False, range=(0, 255))
                    writer.add_image('Groundtruth label', grid_image, global_step)

                print('loss is ', loss.cpu().item(), flush=True)

        # Save the model (every `snapshot` epochs)
        if (epoch % snapshot) == snapshot - 1:
            torch.save(net_.state_dict(), os.path.join(save_dir, 'models', modelName + '_epoch-' + str(epoch) + '.pth'))
            print("Save model at {}\n".format(os.path.join(save_dir, 'models', modelName + '_epoch-' + str(epoch) + '.pth')))

        # One testing epoch
        if useTest and epoch % nTestInterval == (nTestInterval - 1):
            val_pascal(net_=net_, testloader=testloader, testloader_flip=testloader_flip, test_graph=test_graph,
                       criterion=criterion, epoch=epoch, writer=writer)
Example #12
0
def train(model_name, gpu_id, learning_rate):
    """Train an extreme-points (DEXTR-style) segmentation net on PASCAL VOC.

    Args:
        model_name: architecture key understood by ``load_model``; also used
            to name the run directory and checkpoint files.
        gpu_id: CUDA device index; set to -1 (or have no CUDA) to run on CPU.
        learning_rate: initial learning rate for the Adam optimizer.

    Side effects: creates RUNS/runs-<model>-<lr>/ with models/, log.csv,
    a parameter report and TensorBoard event files; saves a checkpoint
    every `snapshot` epochs and evaluates every `nTestInterval` epochs.
    """
    # Set gpu_id to -1 to run in CPU mode, otherwise set the id of the corresponding gpu
    device = torch.device("cuda:%d" %
                          gpu_id if torch.cuda.is_available() else "cpu")
    if torch.cuda.is_available():
        tqdm.write('Using GPU: {} '.format(gpu_id))

    # Setting parameters
    use_sbd = False  # Additionally train on SBD annotations
    nEpochs = 100  # Number of epochs for training
    resume_epoch = 0  # Default is 0, change if want to resume

    p = OrderedDict()  # Parameters to include in report
    p['trainBatch'] = 5  # Training batch size
    testBatch = 5  # Testing batch size
    useTest = True  # See evolution of the test set when training?
    nTestInterval = 10  # Run on test set every nTestInterval epochs
    snapshot = 10  # Store a model every snapshot epochs
    relax_crop = 50  # Enlarge the bounding box by relax_crop pixels
    nInputChannels = 4  # Number of input channels (RGB + heatmap of extreme points)
    zero_pad_crop = True  # Insert zero padding when cropping the image
    p['nAveGrad'] = 1  # Average the gradient of several iterations
    p['lr'] = learning_rate  # Learning rate
    p['wd'] = 0.0005  # Weight decay
    p['momentum'] = 0.9  # Momentum (unused by Adam; kept for the report)

    # Results and model directories (a new directory is generated for every run)
    package_path = Path(__file__).resolve()
    folder_name = 'runs-{}-{:f}'.format(model_name, learning_rate)
    save_dir_root = package_path.parent / 'RUNS'
    exp_name = model_name
    save_dir = save_dir_root / folder_name
    (save_dir / 'models').mkdir(parents=True, exist_ok=True)
    # 'wt' truncates any log left over from a previous run of the same config.
    # (local renamed from `csv`: it shadowed the stdlib `csv` module)
    with (save_dir / 'log.csv').open('wt') as log_file:
        log_file.write('train_loss,test_loss\n')
    tqdm.write(str(save_dir))

    # Network definition
    modelName = model_name
    net = load_model(model_name, nInputChannels)
    train_params = [{'params': net.parameters(), 'lr': p['lr']}]

    net.to(device)

    # Training the network
    if resume_epoch != nEpochs:
        # Logging into Tensorboard
        time_now = datetime.now().strftime('%b%d_%H-%M-%S')
        hostname = socket.gethostname()
        log_dir = save_dir / 'models' / '{}_{}'.format(time_now, hostname)
        writer = SummaryWriter(log_dir=str(log_dir))

        # Use the following optimizer
        optimizer = optim.Adam(train_params, lr=p['lr'], weight_decay=p['wd'])
        p['optimizer'] = str(optimizer)

        # Preparation of the data loaders
        train_tf, test_tf = create_transforms(relax_crop, zero_pad_crop)
        voc_train = pascal.VOCSegmentation(split='train',
                                           download=True,
                                           transform=train_tf)
        voc_val = pascal.VOCSegmentation(split='val',
                                         download=True,
                                         transform=test_tf)

        if use_sbd:
            # BUGFIX: the original wrote `sbd = sbd.SBDSegmentation(...)`,
            # which makes `sbd` a function-local name and raises
            # UnboundLocalError before the module can be used.
            sbd_train = sbd.SBDSegmentation(split=['train', 'val'],
                                            retname=True,
                                            transform=train_tf)
            db_train = combine_dbs([voc_train, sbd_train], excluded=[voc_val])
        else:
            db_train = voc_train

        p['dataset_train'] = str(db_train)
        p['transformations_train'] = [str(t) for t in train_tf.transforms]
        # BUGFIX: report the actual test dataset (was copy-pasted db_train).
        p['dataset_test'] = str(voc_val)
        p['transformations_test'] = [str(t) for t in test_tf.transforms]

        trainloader = DataLoader(db_train,
                                 batch_size=p['trainBatch'],
                                 shuffle=True,
                                 num_workers=2)
        testloader = DataLoader(voc_val,
                                batch_size=testBatch,
                                shuffle=False,
                                num_workers=2)
        generate_param_report((save_dir / exp_name).with_suffix('.txt'), p)

        # Train variables
        num_img_tr = len(trainloader)
        num_img_ts = len(testloader)
        running_loss_tr = 0.0
        running_loss_ts = 0.0
        aveGrad = 0
        # Main Training and Testing Loop
        for epoch in trange(resume_epoch, nEpochs):
            start_time = timeit.default_timer()

            net.train()
            for ii, sample_batched in enumerate(tqdm(trainloader,
                                                     leave=False)):
                inputs = sample_batched['concat'].to(device)
                gts = sample_batched['crop_gt'].to(device)

                # Forward-Backward of the mini-batch
                inputs.requires_grad_()

                output = net(inputs)
                # Upsample logits to the fixed crop size before the loss.
                output = interpolate(output,
                                     size=(512, 512),
                                     mode='bilinear',
                                     align_corners=True).to(device)
                # Compute the losses, side outputs and fuse
                loss = class_balanced_cross_entropy_loss(output,
                                                         gts,
                                                         size_average=False,
                                                         batch_average=True)
                running_loss_tr += loss.item()

                # Backward the averaged gradient
                loss /= p['nAveGrad']
                loss.backward()
                aveGrad += 1

                # Update the weights once in p['nAveGrad'] forward passes
                if aveGrad % p['nAveGrad'] == 0:
                    writer.add_scalar('data/total_loss_iter', loss.item(),
                                      ii + num_img_tr * epoch)
                    optimizer.step()
                    optimizer.zero_grad()
                    aveGrad = 0

            # Save the model
            if (epoch + 1) % snapshot == 0:
                weights_path = save_dir / 'models'
                weights_path /= '{}_epoch-{:d}.pth'.format(modelName, epoch)
                torch.save(net.state_dict(), weights_path)

            # One testing epoch
            if useTest and (epoch + 1) % nTestInterval == 0:
                msg = 'Test Loss: {:.3f}'
                test_loss = test(net, testloader, device)
                tqdm.write(msg.format(test_loss))
                running_loss_tr = running_loss_tr / num_img_tr
                with (save_dir / 'log.csv').open('at') as log_file:
                    log_file.write(','.join([str(running_loss_tr), str(test_loss)]))
                    log_file.write('\n')
                writer.add_scalar('data/total_loss_epoch', running_loss_tr,
                                  epoch)
                # `ii`/`inputs` hold the last training batch of this epoch.
                num_images = ii * p['trainBatch'] + inputs.data.shape[0]
                msg = '[Epoch: {:d}, numImages: {:5d}]'
                tqdm.write(msg.format(epoch, num_images))
                tqdm.write('Loss: %f' % running_loss_tr)
                running_loss_tr = 0
                stop_time = timeit.default_timer()
                msg = "Execution time: {:.3f}"
                tqdm.write(msg.format(stop_time - start_time))

        writer.close()