コード例 #1
0
    # Make gpu_id the current CUDA device so the argument-less .cuda() call
    # below moves the network's parameters onto that GPU.
    torch.cuda.set_device(device=gpu_id)
    net.cuda()
    # Wrap the network for data-parallel execution over every visible GPU;
    # range(device_count()) yields device ids 0 .. device_count() - 1.
    # NOTE(review): DataParallel expects the module's parameters to live on
    # device_ids[0] (cuda:0 here). If gpu_id can be non-zero, forward() will
    # likely fail with a device-mismatch error — confirm gpu_id == 0, or pass
    # device_ids that start at gpu_id.
    net = torch.nn.DataParallel(net,
                                device_ids=range(torch.cuda.device_count()))

if resume_epoch != nEpochs:
    # Training setup; skipped entirely when resume_epoch already equals
    # nEpochs (presumably: resuming at the final epoch leaves nothing to train).
    # Logging into Tensorboard: build a per-run directory named
    # <save_dir>/models/<MonDD_HH-MM-SS>_<hostname> from the local time and
    # the machine's hostname.
    # NOTE(review): no TensorBoard writer is created in these lines; log_dir
    # is presumably consumed further down the script — confirm.
    log_dir = os.path.join(
        save_dir, 'models',
        datetime.now().strftime('%b%d_%H-%M-%S') + '_' + socket.gethostname())

    # Adam over all network parameters, with learning rate and weight decay
    # taken from the hyper-parameter dict p. The optimizer's string form is
    # written back into p (presumably so it is recorded with the run config).
    optimizer = optim.Adam(net.parameters(), lr=p['lr'], weight_decay=p['wd'])
    p['optimizer'] = str(optimizer)

    # Training pipeline, applied in this order: random resize, random
    # rotation, random horizontal flip, per-channel normalization, tensor
    # conversion. The 512 / 15 arguments are interpreted by the project's
    # `tr` module (presumably target size in pixels and max rotation in
    # degrees — confirm there). Normalize runs before ToTensor, so
    # tr.Normalize presumably works on the pre-tensor sample. The mean/std
    # values are the standard ImageNet RGB statistics.
    composed_transforms_tr = transforms.Compose([
        tr.RandomSized(512),
        tr.RandomRotate(15),
        tr.RandomHorizontalFlip(),
        tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        tr.ToTensor()
    ])

    # Evaluation pipeline: deterministic resize to 512x512 with no random
    # augmentation, then the same normalization and tensor conversion.
    composed_transforms_ts = transforms.Compose([
        tr.FixedResize(size=(512, 512)),
        tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        tr.ToTensor()
    ])

    # Pascal VOC datasets from an earlier setup, left commented out:
    # voc_train = pascal.VOCSegmentation(split='train', transform=composed_transforms_tr)
    # voc_val = pascal.VOCSegmentation(split='val', transform=composed_transforms_ts)
    # Dataset root as a relative path (resolved against the working
    # directory); presumably used to build the datasets after this excerpt.
    ROOT = 'dataset/ORIGA'
コード例 #2
0
		sample = {'image': _img, 'label': _target}
		if self.transform is not None: 
			sample = self.transform(sample)
		return sample 

	def _make_img_gt_point_pair(self, index):
		"""Read the image and its target for sample ``index``.

		Args:
			index: position into ``self.images`` and ``self.categories``,
				whose entries are passed straight to ``Image.open``.

		Returns:
			A ``(image, target)`` pair: the image converted to RGB, and the
			target opened as stored (no mode conversion applied).
		"""
		# Convert the opened image, not the path: the previous code called
		# .convert('RGB') on the path argument itself, which has no such
		# method and raised AttributeError before Image.open ever ran.
		_img = Image.open(self.images[index]).convert('RGB')
		_target = Image.open(self.categories[index])
		return _img, _target

	def __str__(self):
		return 'VOC2012(split=' + str(self.split) + ')'

if __name__ == '__main__':
	composed_transforms_tr = transforms.Compose([tr.RandomHorizontalFlip(), tr.RandomSized(512), tr.RandomRotate(15), tr.ToTensor()])

	voc_train = PascalVOC(split='train', transform=composed_transforms_tr)
	dataloader = DataLoader(voc_train, batch_size=5, shuffle=True, num_workers=2)

	for ii, sample in tqdm(enumberate(dataloader)):
		for jj in tqdm(range(sample["image"].size()[0])):
			img = sample['image'].numpy()
			gt = sample['label'].numpy()
			tmp = np.array(get[jj]).astype(np.uint8)
			tmp = np.squeeze(tmp, axis=0)
			segmap = decode_segmap(tmp, dataset = 'pascal')
			img_tmp = np.transpose(img[jj], axes=[1,2,0]).astype(np.uint8)
			plt.figure()
			plt.title('display')
			plt.subplot(211)