Example #1
class Trainer(object):
    def __init__(self, args):
        """initialize the Trainer"""
        # about gpus
        self.cuda = args.cuda
        self.gpu_ids = args.gpu_ids
        self.num_gpus = len(self.gpu_ids)
        self.crf_iter_steps = args.crf_iter_steps
        self.output_dir = args.output_dir
        # run mode string; replaced below by the FullModel wrapper
        self.model = 'test'

        # define dataloader
        self.val_loader = factory.get_dataset(args.data_dir,
                                              batch_size=1,
                                              dataset=args.dataset,
                                              split=args.train_split)
        self.nclass = self.val_loader.NUM_CLASSES
        # define network
        assert args.seg_model in seg_model_obj_dict.keys()
        self.seg_model = args.seg_model
        # keep self.seg_model as the model name; build the network separately
        seg_net = seg_model_obj_dict[self.seg_model](
            num_classes=self.nclass,
            backbone=args.backbone,
            output_stride=args.out_stride,
            norm_layer=torch.nn.BatchNorm2d,
            bn_mom=args.bn_mom,
            freeze_bn=True)

        # define criterion
        #self.criterion = torch.nn.CrossEntropyLoss(weight=None, ignore_index=255, reduction='mean')
        self.model = full_model.FullModel(seg_model=seg_net,
                                          model=self.model)
        # define evaluator
        self.evaluator = Evaluator(self.nclass)

        # using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.gpu_ids)
            #patch_replication_callback(self.model)
            self.model = self.model.cuda()
            #self.criterion = self.criterion.cuda()

        # resuming checkpoint
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            print('Restore parameters from {}'.format(args.resume))
            checkpoint = torch.load(args.resume)
            self.global_step = checkpoint['global_step']

            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
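
A minimal usage sketch of the resume pattern above. The file name and the free-standing `model` variable are hypothetical; `map_location` makes a GPU-saved checkpoint loadable on a CPU-only machine, and the `.module` indirection is needed once the network is wrapped in DataParallel.

import torch

checkpoint = torch.load('trainer_ckpt_10000.pth', map_location='cpu')  # hypothetical file
if isinstance(model, torch.nn.DataParallel):
    # DataParallel wraps the real network in .module
    model.module.load_state_dict(checkpoint['state_dict'])
else:
    model.load_state_dict(checkpoint['state_dict'])
global_step = checkpoint['global_step']
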
Example #2
class Trainer(object):
	def __init__(self, args):
		"""initialize the Trainer"""
		# about gpus
		self.cuda = args.cuda
		self.gpu_ids = args.gpu_ids
		self.num_gpus = len(self.gpu_ids)
		self.crf_iter_steps = args.crf_iter_steps
		self.output_dir = args.output_dir
		# define dataloader
		self.val_loader = factory.get_dataset(args.data_dir,
												batch_size=1,
												dataset=args.dataset,
												split=args.train_split)
		self.nclass = self.val_loader.NUM_CLASSES
		# define network
		assert args.seg_model in seg_model_obj_dict.keys()
		self.seg_model = args.seg_model
		self.model = seg_model_obj_dict[self.seg_model](num_classes=self.nclass,
														backbone=args.backbone,
														output_stride=args.out_stride,
														norm_layer=torch.nn.BatchNorm2d,
														bn_mom=args.bn_mom,
														freeze_bn=True)

		# define evaluator
		self.evaluator = Evaluator(self.nclass)

		# using cuda
		if args.cuda:
			self.model = torch.nn.DataParallel(self.model, device_ids=self.gpu_ids)
			#patch_replication_callback(self.model)
			self.model = self.model.cuda()

		# resuming checkpoint
		if args.resume is not None:
			if not os.path.isfile(args.resume):
				raise RuntimeError("=> no checkpoint found at '{}'".format(args.resume))
			print('Restore parameters from {}'.format(args.resume))
			checkpoint = torch.load(args.resume)
			self.global_step = checkpoint['global_step']

			if args.cuda:
				self.model.module.load_state_dict(checkpoint['state_dict'])
			else:
				self.model.load_state_dict(checkpoint['state_dict'])

	def validation(self):
		"""validation procedure
		"""
		# set validation mode
		self.model.eval()
		self.evaluator.reset()
		crf_100_steps = 0.0  # accumulated CRF time over the last 100 steps
		start = timeit.default_timer()
		for i in range(len(self.val_loader)):
			#for i, sample in enumerate(self.val_loader):
			sample = self.val_loader[i]
			image = sample['image']
			#image = image.repeat(self.num_gpus, 1, 1, 1)
			#print("{}-th sample, Image shape {}, label shape {}".format(i + 1, image.size(), target.size()))
			if self.cuda:
				image = image.cuda()
			image = image.unsqueeze(dim=0)
			# forward
			with torch.no_grad():
				output = self.model(image)
			# the output of pspnet is a tuple; keep the main prediction
			if self.seg_model == 'pspnet':
				output = output[0]

			# get probs, shape [N, C, H, W] --> [N, H, W, C]
			probs = F.softmax(output, dim=1).permute(0, 2, 3, 1).squeeze_()
			probs_np = probs.data.cpu().numpy()
			#pred = output.data.cpu().numpy()

			# CRF post-processing
			image_name = self.val_loader.image_lists[i]
			#real_image = cv2.cvtColor(cv2.imread(image_name), cv2.COLOR_BGR2RGB)

			real_image = Image.open(image_name).convert('RGB')
			real_image = np.array(real_image).astype(np.uint8)

			crf_start = timeit.default_timer()
			pred = crf.dense_crf(real_image=real_image, probs=probs_np, iter_steps=self.crf_iter_steps)
			crf_end = timeit.default_timer()
			crf_100_steps += (crf_end - crf_start)
			# save output
			path_to_output = os.path.join(self.output_dir, self.val_loader.image_ids[i] + '.png')
			result = Image.fromarray(pred.astype(np.uint8))
			result.save(path_to_output)
			#cv2.imwrite(path_to_output, pred)
			# report CRF time every 100 steps
			if i % 100 == 0:
				stop = timeit.default_timer()
				print("current step = {} ({:.3f} sec), crf time {:.3f} sec".
									format(i, stop - start, crf_100_steps))
				crf_100_steps = 0.0
				start = timeit.default_timer()
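
The `crf.dense_crf` helper called above is external to these snippets. Below is a plausible sketch built on the pydensecrf package; the pairwise kernel parameters (sxy, srgb, compat) are assumptions, not values taken from the source.

import numpy as np
import pydensecrf.densecrf as dcrf
from pydensecrf.utils import unary_from_softmax

def dense_crf(real_image, probs, iter_steps=10):
    """Refine softmax probs [H, W, C] with a dense CRF over an RGB uint8 image."""
    h, w, c = probs.shape
    d = dcrf.DenseCRF2D(w, h, c)
    # unary_from_softmax expects class-first probabilities [C, H, W]
    unary = unary_from_softmax(probs.transpose(2, 0, 1))
    d.setUnaryEnergy(np.ascontiguousarray(unary))
    # smoothness kernel (positions only) and appearance kernel (positions + color)
    d.addPairwiseGaussian(sxy=3, compat=3)
    d.addPairwiseBilateral(sxy=80, srgb=13,
                           rgbim=np.ascontiguousarray(real_image), compat=10)
    q = d.inference(iter_steps)
    return np.argmax(q, axis=0).reshape(h, w)
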
Example #3
class Trainer(object):
    def __init__(self, args):
        """initialize the Trainer"""
        # about gpus
        self.cuda = args.cuda
        self.gpu_ids = args.gpu_ids
        self.num_gpus = len(self.gpu_ids)
        self.no_val = args.no_val

        # about training schedule
        self.init_global_step = args.init_global_step
        self.start_epoch = args.start_epoch
        self.train_epochs = args.epochs

        # about the learning rate
        self.init_lr = args.init_lr
        self.lr_scheduler = args.lr_scheduler
        self.slow_start_lr = args.slow_start_lr
        self.lr_multiplier = args.lr_multiplier
        self.accumulation_steps = args.accumulation_steps

        # about the model_dir and checkpoint
        self.model_dir = args.model_dir
        self.save_ckpt_steps = args.save_ckpt_steps
        self.max_ckpt_nums = args.max_ckpt_nums
        self.saved_ckpt_filenames = []
        self.checkname = args.checkname

        # define global step
        self.global_step = 0
        self.main_gpu = args.main_gpu

        # sync bn; either implementation can be used.
        self.norm_layer = syncbn.BatchNorm2d if args.sync_bn else nn.BatchNorm2d
        #self.norm_layer = SynchronizedBatchNorm2d if args.sync_bn else nn.BatchNorm2d

        # define tensorboard summary
        self.train_writer = SummaryWriter(log_dir=self.model_dir)
        self.val_writer = SummaryWriter(
            log_dir=os.path.join(self.model_dir, 'eval'))

        # define dataloader
        self.train_loader, self.nclass = factory.get_data_loader(
            args.data_dir,
            batch_size=args.batch_size,
            crop_size=args.crop_size,
            dataset=args.dataset,
            split=args.train_split,
            num_workers=args.workers,
            pin_memory=True)
        self.val_loader, _ = factory.get_data_loader(
            args.data_dir,
            dataset=args.dataset,
            split="test" if 'camvid' in args.dataset else "val")

        # max iters
        self.steps_per_epochs = len(self.train_loader)
        self.max_iter = self.steps_per_epochs * self.train_epochs

        # define network
        assert args.seg_model in seg_model_obj_dict.keys()
        self.seg_model = args.seg_model
        self.model = seg_model_obj_dict[self.seg_model](
            num_classes=self.nclass,
            backbone=args.backbone,
            output_stride=args.out_stride,
            norm_layer=self.norm_layer,
            bn_mom=args.bn_mom,
            freeze_bn=args.freeze_bn)

        # define criterion
        self.loss_type = args.loss_type
        self.criterion = loss_factory.criterion_choose(
            self.nclass,
            weight=None,
            loss_type=args.loss_type,
            ignore_index=255,
            reduction='mean',
            max_iter=self.max_iter,
            args=args)
        self.model_with_loss = full_model.FullModel(seg_model=self.seg_model,
                                                    model=self.model,
                                                    loss_type=self.loss_type,
                                                    criterion=self.criterion)

        # define evaluator
        self.evaluator = Evaluator(self.nclass)

        # using cuda
        # If you need to move a model to GPU via .cuda(), please do so before constructing optimizers for it.
        # Parameters of a model after .cuda() will be different objects with those before the call.
        # In general, you should make sure that optimized parameters
        # live in consistent locations when optimizers are constructed and used.
        if args.cuda:
            self.model_with_loss = torch.nn.DataParallel(
                self.model_with_loss, device_ids=self.gpu_ids)
            if self.norm_layer is SynchronizedBatchNorm2d:
                patch_replication_callback(self.model_with_loss)
                print("INFO:PyTorch: The batch norm layer is {}".format(
                    self.norm_layer))
            elif self.norm_layer is syncbn.BatchNorm2d:
                parallel.patch_replication_callback(self.model_with_loss)
                print("INFO:PyTorch: The batch norm layer is Hang Zhang's {}".
                      format(self.norm_layer))
            self.model_with_loss = self.model_with_loss.cuda(self.main_gpu)

        # optimizer parameters; construct the optimizer after the model is on the GPU
        self.params_list = []
        self.params_list = model_init.seg_model_get_optim_params(
            self.params_list,
            self.model_with_loss.module.model,
            norm_layer=self.norm_layer,
            seg_model=args.seg_model,
            base_lr=args.init_lr,
            lr_multiplier=self.lr_multiplier,
            weight_decay=args.weight_decay)
        self.optimizer = torch.optim.SGD(self.params_list,
                                         momentum=args.momentum,
                                         nesterov=args.nesterov)

        # define learning rate scheduler.
        # Be careful about the learning rate for different params list, check
        # the `params_list` and the `lr_scheduler` to ensure the strategy is right.
        self.scheduler = train_utils.lr_scheduler(
            init_lr=self.init_lr,
            mode=self.lr_scheduler,
            num_epochs=self.train_epochs,
            max_iter=self.max_iter,
            slow_start_steps=args.slow_start_steps,
            slow_start_lr=args.slow_start_lr,
            multiplier=self.lr_multiplier)
        # resuming checkpoint
        #self.best_pred = 0.0
        if args.resume is not None:
            if os.path.isfile(args.resume):
                #raise RuntimeError("=> no checkpoint found at '{}'" .format(args.resume))
                print("INFO:PyTorch: Restore checkpoint from {}".format(
                    args.resume))
                checkpoint = torch.load(args.resume)
                self.global_step = checkpoint['global_step']
                if args.cuda:
                    self.model_with_loss.module.load_state_dict(
                        checkpoint['state_dict'])
                else:
                    self.model_with_loss.load_state_dict(
                        checkpoint['state_dict'])
                self.start_epoch = (self.global_step +
                                    1) // self.steps_per_epochs
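
The `train_utils.lr_scheduler` factory is not shown in these examples. Below is a minimal sketch of a callable with the same `scheduler(optimizer, global_step, epoch)` interface, assuming the common "poly" policy with a linear slow-start phase; the 0.9 power and the per-group 'lr_multiplier' key are assumptions.

class PolyLR(object):
    """Poly decay with linear warmup; a sketch, not the original implementation."""
    def __init__(self, init_lr, max_iter, slow_start_steps=0,
                 slow_start_lr=1e-4, power=0.9):
        self.init_lr = init_lr
        self.max_iter = max_iter
        self.slow_start_steps = slow_start_steps
        self.slow_start_lr = slow_start_lr
        self.power = power

    def __call__(self, optimizer, global_step, epoch):
        if global_step < self.slow_start_steps:
            # linear warmup from slow_start_lr up to init_lr
            frac = global_step / max(1, self.slow_start_steps)
            lr = self.slow_start_lr + (self.init_lr - self.slow_start_lr) * frac
        else:
            # polynomial decay over the remaining iterations
            done = global_step - self.slow_start_steps
            total = max(1, self.max_iter - self.slow_start_steps)
            lr = self.init_lr * (1.0 - min(done / total, 1.0)) ** self.power
        for group in optimizer.param_groups:
            # honor a hypothetical per-group multiplier if one is present
            group['lr'] = lr * group.get('lr_multiplier', 1.0)
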
Example #4
class Trainer(object):
    def __init__(self, args):
        """initialize the Trainer"""
        # about gpus
        self.cuda = args.cuda
        self.gpu_ids = args.gpu_ids
        self.num_gpus = len(self.gpu_ids)
        self.no_val = args.no_val

        # about training schedule
        self.init_global_step = args.init_global_step
        self.start_epoch = args.start_epoch
        self.train_epochs = args.epochs

        # about the learning rate
        self.init_lr = args.init_lr
        self.lr_scheduler = args.lr_scheduler
        self.slow_start_lr = args.slow_start_lr
        self.lr_multiplier = args.lr_multiplier
        self.accumulation_steps = args.accumulation_steps

        # about the model_dir and checkpoint
        self.model_dir = args.model_dir
        self.save_ckpt_steps = args.save_ckpt_steps
        self.max_ckpt_nums = args.max_ckpt_nums
        self.saved_ckpt_filenames = []
        self.checkname = args.checkname

        # define global step
        self.global_step = 0
        self.main_gpu = args.main_gpu

        # sync bn; either implementation can be used.
        self.norm_layer = syncbn.BatchNorm2d if args.sync_bn else nn.BatchNorm2d
        #self.norm_layer = SynchronizedBatchNorm2d if args.sync_bn else nn.BatchNorm2d

        # define tensorboard summary
        self.train_writer = SummaryWriter(log_dir=self.model_dir)
        self.val_writer = SummaryWriter(
            log_dir=os.path.join(self.model_dir, 'eval'))

        # define dataloader
        self.train_loader, self.nclass = factory.get_data_loader(
            args.data_dir,
            batch_size=args.batch_size,
            crop_size=args.crop_size,
            dataset=args.dataset,
            split=args.train_split,
            num_workers=args.workers,
            pin_memory=True)
        self.val_loader, _ = factory.get_data_loader(
            args.data_dir,
            dataset=args.dataset,
            split="test" if 'camvid' in args.dataset else "val")

        # max iters
        self.steps_per_epochs = len(self.train_loader)
        self.max_iter = self.steps_per_epochs * self.train_epochs

        # define network
        assert args.seg_model in seg_model_obj_dict.keys()
        self.seg_model = args.seg_model
        self.model = seg_model_obj_dict[self.seg_model](
            num_classes=self.nclass,
            backbone=args.backbone,
            output_stride=args.out_stride,
            norm_layer=self.norm_layer,
            bn_mom=args.bn_mom,
            freeze_bn=args.freeze_bn)

        # define criterion
        self.loss_type = args.loss_type
        self.criterion = loss_factory.criterion_choose(
            self.nclass,
            weight=None,
            loss_type=args.loss_type,
            ignore_index=255,
            reduction='mean',
            max_iter=self.max_iter,
            args=args)
        self.model_with_loss = full_model.FullModel(seg_model=self.seg_model,
                                                    model=self.model,
                                                    loss_type=self.loss_type,
                                                    criterion=self.criterion)

        # define evaluator
        self.evaluator = Evaluator(self.nclass)

        # using cuda
        # If you need to move a model to GPU via .cuda(), please do so before constructing optimizers for it.
        # Parameters of a model after .cuda() will be different objects with those before the call.
        # In general, you should make sure that optimized parameters
        # live in consistent locations when optimizers are constructed and used.
        if args.cuda:
            self.model_with_loss = torch.nn.DataParallel(
                self.model_with_loss, device_ids=self.gpu_ids)
            if self.norm_layer is SynchronizedBatchNorm2d:
                patch_replication_callback(self.model_with_loss)
                print("INFO:PyTorch: The batch norm layer is {}".format(
                    self.norm_layer))
            elif self.norm_layer is syncbn.BatchNorm2d:
                parallel.patch_replication_callback(self.model_with_loss)
                print("INFO:PyTorch: The batch norm layer is Hang Zhang's {}".
                      format(self.norm_layer))
            self.model_with_loss = self.model_with_loss.cuda(self.main_gpu)

        # optimizer parameters; construct the optimizer after the model is on the GPU
        self.params_list = []
        self.params_list = model_init.seg_model_get_optim_params(
            self.params_list,
            self.model_with_loss.module.model,
            norm_layer=self.norm_layer,
            seg_model=args.seg_model,
            base_lr=args.init_lr,
            lr_multiplier=self.lr_multiplier,
            weight_decay=args.weight_decay)
        self.optimizer = torch.optim.SGD(self.params_list,
                                         momentum=args.momentum,
                                         nesterov=args.nesterov)

        # define learning rate scheduler.
        # Be careful about the learning rate for different params list, check
        # the `params_list` and the `lr_scheduler` to ensure the strategy is right.
        self.scheduler = train_utils.lr_scheduler(
            init_lr=self.init_lr,
            mode=self.lr_scheduler,
            num_epochs=self.train_epochs,
            max_iter=self.max_iter,
            slow_start_steps=args.slow_start_steps,
            slow_start_lr=args.slow_start_lr,
            multiplier=self.lr_multiplier)
        # resuming checkpoint
        #self.best_pred = 0.0
        if args.resume is not None:
            if os.path.isfile(args.resume):
                #raise RuntimeError("=> no checkpoint found at '{}'" .format(args.resume))
                print("INFO:PyTorch: Restore checkpoint from {}".format(
                    args.resume))
                checkpoint = torch.load(args.resume)
                self.global_step = checkpoint['global_step']
                if args.cuda:
                    self.model_with_loss.module.load_state_dict(
                        checkpoint['state_dict'])
                else:
                    self.model_with_loss.load_state_dict(
                        checkpoint['state_dict'])
                self.start_epoch = (self.global_step +
                                    1) // self.steps_per_epochs
        #	if not args.ft:
        #		self.optimizer.load_state_dict(checkpoint['optimizer'])
        #	self.best_pred = checkpoint['best_pred']
        #	print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
        # clear start epoch if fine-tuning
        #if args.ft:
        #	args.start_epoch = 0

    def training(self, epoch):
        """training procedure
		"""
        # set training mode
        self.model_with_loss.train()
        self.evaluator.reset()
        start_time = time.time()
        self.optimizer.zero_grad()
        # training loop
        for i, sample in enumerate(self.train_loader):
            # gradients are zeroed and the optimizer stepped on every iteration,
            # so accumulation_steps only throttles the global_step counter here
            self.optimizer.zero_grad()
            if (i + 1) % self.accumulation_steps == 0:
                self.global_step += 1
            image, target = sample['image'], sample['label']
            if self.cuda:
                image, target = image.cuda(self.main_gpu), target.cuda(
                    self.main_gpu)
            # adjust learning rate, pass input through the model, update
            self.scheduler(self.optimizer, self.global_step, epoch)
            output, loss = self.model_with_loss(inputs=image,
                                                target=target,
                                                global_step=self.global_step)
            #print(target.size())
            loss = loss.mean()
            loss.backward()
            # update the parameters every step; the commented-out condition below
            # would gate the update for true accumulation (see the sketch after
            # this example)
            #if (i + 1) % self.accumulation_steps == 0:
            self.optimizer.step()
            # add batch sample into evaluator
            pred = np.argmax(output.data.cpu().numpy(), axis=1)
            target = target.cpu().numpy()
            self.evaluator.add_batch(target, pred)

            # log info per 20 steps
            #if self.global_step % 20 == 0 and (i + 1) % self.accumulation_steps == 0:
            if self.global_step % 20 == 0:
                # the used time
                used_time = time.time() - start_time
                px_acc = self.evaluator.pixel_accuracy_np()
                miou = self.evaluator.mean_iou_np()
                lr_now = self.optimizer.param_groups[0]['lr']
                print(
                    "INFO:PyTorch: epoch={}/{}, steps={}, loss={:.5f}, learning_rate={:.5f}, train_miou={:.5f}, px_accuracy={:.5f}"
                    " ({:.3f} sec)".format(epoch + 1,
                                           self.train_epochs, self.global_step,
                                           loss.item(), lr_now, miou, px_acc,
                                           used_time))

                # summary per 100 steps
                if self.global_step % 100 == 0:
                    self.train_writer.add_scalar('train_miou',
                                                 miou,
                                                 global_step=self.global_step)
                    self.train_writer.add_scalar('px_accuracy',
                                                 px_acc,
                                                 global_step=self.global_step)
                    self.train_writer.add_scalar('learning_rate',
                                                 lr_now,
                                                 global_step=self.global_step)
                    self.train_writer.add_scalar('train_loss',
                                                 loss.item(),
                                                 global_step=self.global_step)

                start_time = time.time()

            # save checkpoints
            if (self.global_step % self.save_ckpt_steps
                    == 0) or (i == self.steps_per_epochs - 1):
                filename = os.path.join(
                    self.model_dir,
                    "{0}_ckpt_{1}.pth".format(self.checkname,
                                              self.global_step))
                self.saved_ckpt_filenames.append(filename)
                # drop the oldest checkpoint once more than max_ckpt_nums are kept
                if len(self.saved_ckpt_filenames) > self.max_ckpt_nums:
                    del_filename = self.saved_ckpt_filenames.pop(0)
                    os.remove(del_filename)
                # save new ckpt
                state = {
                    'global_step': self.global_step,
                    'state_dict': self.model_with_loss.module.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                }
                torch.save(state, filename)

    def validation(self, epoch):
        """validation procedure
		"""
        # set validation mode
        self.model_with_loss.eval()
        self.evaluator.reset()
        test_loss = 0.0  # no loss is computed in 'val' mode, so this stays 0.0
        for i, sample in enumerate(self.val_loader):
            image, target = sample['image'], sample['label']
            # to repeat the sample across GPUs, uncomment the next line
            #image, target = image.repeat(self.num_gpus, 1, 1, 1), target.repeat(self.num_gpus, 1, 1)
            if self.cuda:
                image, target = image.cuda(self.main_gpu), target.cuda(
                    self.main_gpu)
            # forward
            with torch.no_grad():
                output = self.model_with_loss(inputs=image, mode='val')

            # Add batch sample into evaluator
            pred = np.argmax(output.data.cpu().numpy(), axis=1)
            target = target.cpu().numpy()
            self.evaluator.add_batch(target, pred)

        # log and summary the validation results
        px_acc = self.evaluator.pixel_accuracy_np()
        val_miou = self.evaluator.mean_iou_np(is_show_per_class=True)
        print(
            "\nINFO:PyTorch: validation results: miou={:.5f}, px_acc={:.5f}, loss={:.5f} \n"
            .format(val_miou, px_acc, test_loss))
        self.val_writer.add_scalar('val_loss', test_loss, self.global_step)
        self.val_writer.add_scalar('val_miou', val_miou, self.global_step)
        self.val_writer.add_scalar('val_px_acc', px_acc, self.global_step)
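
The commented-out condition in the training loop above points at true gradient accumulation. A minimal sketch of that variant, assuming the loss should be averaged over the accumulation window; the free-standing names mirror the Trainer's attributes.

optimizer.zero_grad()
for i, sample in enumerate(train_loader):
    image, target = sample['image'].cuda(), sample['label'].cuda()
    output, loss = model_with_loss(inputs=image, target=target,
                                   global_step=global_step)
    # scale so the accumulated gradient matches a large-batch average
    (loss.mean() / accumulation_steps).backward()
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()       # one update per effective batch
        optimizer.zero_grad()  # clear for the next accumulation window
        global_step += 1
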
Example #5
class Trainer(object):
	def __init__(self, args):
		"""initialize the Trainer"""
		# about gpus
		self.cuda = args.cuda
		self.gpu_ids = args.gpu_ids
		self.num_gpus = len(self.gpu_ids)
		self.crf_iter_steps = args.crf_iter_steps
		self.output_dir = args.output_dir
		# run mode string; replaced below by the FullModel wrapper
		self.model = 'val'
		# define dataloader
		self.val_loader = factory.get_dataset(args.data_dir,
												batch_size=1,
												dataset=args.dataset,
												split=args.train_split)
		self.nclass = self.val_loader.NUM_CLASSES
		# define network
		assert args.seg_model in seg_model_obj_dict.keys()
		self.seg_model = args.seg_model
		# keep self.seg_model as the model name (it is compared to 'pspnet' in
		# validation below); build the network under a separate name
		seg_net = seg_model_obj_dict[self.seg_model](num_classes=self.nclass,
														backbone=args.backbone,
														output_stride=args.out_stride,
														norm_layer=torch.nn.BatchNorm2d,
														bn_mom=args.bn_mom,
														freeze_bn=True)

		# define criterion
		self.criterion = torch.nn.CrossEntropyLoss(weight=None, ignore_index=255, reduction='mean')
		self.model = full_model.FullModel(seg_model=seg_net,
														model=self.model,
														criterion=self.criterion)

		# define evaluator
		self.evaluator = Evaluator(self.nclass)

		# using cuda
		if args.cuda:
			self.model = torch.nn.DataParallel(self.model, device_ids=self.gpu_ids)
			#patch_replication_callback(self.model)
			self.model = self.model.cuda()
			self.criterion = self.criterion.cuda()

		# resuming checkpoint
		if args.resume is not None:
			if not os.path.isfile(args.resume):
				raise RuntimeError("=> no checkpoint found at '{}'".format(args.resume))
			print('Restore parameters from {}'.format(args.resume))
			checkpoint = torch.load(args.resume)
			self.global_step = checkpoint['global_step']

			if args.cuda:
				self.model.module.load_state_dict(checkpoint['state_dict'])
			else:
				self.model.load_state_dict(checkpoint['state_dict'])

	def validation(self):
		"""validation procedure
		"""
		# set validation mode
		self.model.eval()
		self.evaluator.reset()
		test_loss = 0.0
		start = timeit.default_timer()
		for i in range(len(self.val_loader)):
			#for i, sample in enumerate(self.val_loader):
			sample = self.val_loader[i]
			image, target = sample['image'], sample['label']
			image, target = image.repeat(self.num_gpus, 1, 1, 1), target.repeat(self.num_gpus, 1, 1)
			#print("{}-th sample, Image shape {}, label shape {}".format(i + 1, image.size(), target.size()))
			if self.cuda:
				image, target = image.cuda(), target.cuda()
			# forward
			with torch.no_grad():
				output = self.model(image)
			# the output of the pspnet is a tuple
			if self.seg_model == 'pspnet':
				output = output[0]
			loss = self.criterion(output, target.long())
			test_loss += loss.item()

			# squeeze away the batch dim (single-GPU assumption),
			# then take the argmax over the class dim
			output = output.squeeze_()
			pred = output.data.cpu().numpy()
			pred = np.argmax(pred, axis=0)
			target = target.squeeze_().cpu().numpy()

			# save output
			color_img = True  # save a color-coded mask instead of raw class ids
			path_to_output = os.path.join(self.output_dir, self.val_loader.image_ids[i] + '.png')
			pred = pred.astype(np.uint8)
			if color_img:
				pred_color = utils.decode_segmap(pred, dataset='pascal')
				result = Image.fromarray(pred_color.astype(np.uint8))
				result.save(path_to_output)
			else:
				result = Image.fromarray(pred)
				result.save(path_to_output)
			# report time
			if i % 100 == 0:
				stop = timeit.default_timer()
				print("current step = {} ({:.3f} sec)".format(i, stop - start))
				start = timeit.default_timer()

			# Add batch sample into evaluator
			self.evaluator.add_batch(target, pred)

		# log and summary the validation results
		px_acc = self.evaluator.pixel_accuracy_np()
		val_miou = self.evaluator.mean_iou_np(is_show_per_class=True)
		print("\nINFO:PyTorch: validation results: miou={:5f}, px_acc={:5f}, loss={:5f} \n".
			format(val_miou, px_acc, test_loss))
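
`utils.decode_segmap` is also external to these snippets. Below is a plausible sketch for dataset='pascal', assuming the standard Pascal VOC palette generated by the usual bit-interleaving construction.

import numpy as np

def _voc_colormap(n=256):
    """Standard Pascal VOC palette via bit manipulation."""
    cmap = np.zeros((n, 3), dtype=np.uint8)
    for i in range(n):
        c, r, g, b = i, 0, 0, 0
        for j in range(8):
            r |= ((c >> 0) & 1) << (7 - j)
            g |= ((c >> 1) & 1) << (7 - j)
            b |= ((c >> 2) & 1) << (7 - j)
            c >>= 3
        cmap[i] = (r, g, b)
    return cmap

def decode_segmap(label_mask, dataset='pascal'):
    # fancy-index the palette: [H, W] class ids -> [H, W, 3] RGB
    return _voc_colormap()[label_mask]
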
Example #6
class Trainer(object):
    def __init__(self, args):
        """initialize the Trainer"""
        # about gpus
        self.cuda = args.cuda
        self.gpu_ids = args.gpu_ids
        self.num_gpus = len(self.gpu_ids)
        self.crf_iter_steps = args.crf_iter_steps
        self.output_dir = args.output_dir
        # run mode string; replaced below by the FullModel wrapper
        self.model = 'test'

        # define dataloader
        self.val_loader = factory.get_dataset(args.data_dir,
                                              batch_size=1,
                                              dataset=args.dataset,
                                              split=args.train_split)
        self.nclass = self.val_loader.NUM_CLASSES
        # define network
        assert args.seg_model in seg_model_obj_dict.keys()
        self.seg_model = args.seg_model
        # keep self.seg_model as the model name (it is compared to 'pspnet'
        # in validation below); build the network under a separate name
        seg_net = seg_model_obj_dict[self.seg_model](
            num_classes=self.nclass,
            backbone=args.backbone,
            output_stride=args.out_stride,
            norm_layer=torch.nn.BatchNorm2d,
            bn_mom=args.bn_mom,
            freeze_bn=True)

        # define criterion
        #self.criterion = torch.nn.CrossEntropyLoss(weight=None, ignore_index=255, reduction='mean')
        self.model = full_model.FullModel(seg_model=seg_net,
                                          model=self.model)
        # define evaluator
        self.evaluator = Evaluator(self.nclass)

        # using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.gpu_ids)
            #patch_replication_callback(self.model)
            self.model = self.model.cuda()
            #self.criterion = self.criterion.cuda()

        # resuming checkpoint
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            print('Restore parameters from {}'.format(args.resume))
            checkpoint = torch.load(args.resume)
            self.global_step = checkpoint['global_step']

            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])

    def validation(self):
        """validation procedure
		"""
        # set validation mode
        self.model.eval()
        self.evaluator.reset()
        start = timeit.default_timer()
        for i in range(len(self.val_loader)):
            sample = self.val_loader[i]
            image = sample['image']
            if self.cuda:
                image = image.cuda()
            image = image.unsqueeze(dim=0)
            # forward
            with torch.no_grad():
                output = self.model(image)
            # the output of the pspnet is a tuple
            if self.seg_model == 'pspnet':
                output = output[0]

            output = output.squeeze_()  # drop the batch dim
            pred = output.data.cpu().numpy()
            # save output
            pred = np.argmax(pred, axis=0)
            path_to_output = os.path.join(
                self.output_dir, self.val_loader.image_ids[i] + '.png')
            result = Image.fromarray(pred.astype(np.uint8))
            result.save(path_to_output)
            #cv2.imwrite(path_to_output, pred)
            # report time every 100 steps
            if i % 100 == 0:
                stop = timeit.default_timer()
                print("current step = {} ({:.3f} sec)".format(i, stop - start))
                start = timeit.default_timer()
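
The `Evaluator` used throughout these examples is likewise external. Below is a minimal confusion-matrix sketch matching the reset / add_batch / pixel_accuracy_np / mean_iou_np interface seen above; the implementation details are assumptions.

import numpy as np

class Evaluator(object):
    """Confusion-matrix evaluator; a sketch, not the original implementation."""
    def __init__(self, num_class):
        self.num_class = num_class
        self.confusion = np.zeros((num_class, num_class), dtype=np.int64)

    def reset(self):
        self.confusion[:] = 0

    def add_batch(self, target, pred):
        # skip pixels labeled outside [0, num_class), e.g. the 255 ignore index
        mask = (target >= 0) & (target < self.num_class)
        idx = self.num_class * target[mask].astype(np.int64) + pred[mask]
        self.confusion += np.bincount(
            idx, minlength=self.num_class ** 2).reshape(self.num_class,
                                                        self.num_class)

    def pixel_accuracy_np(self):
        return np.diag(self.confusion).sum() / max(1, self.confusion.sum())

    def mean_iou_np(self, is_show_per_class=False):
        inter = np.diag(self.confusion)
        union = self.confusion.sum(axis=1) + self.confusion.sum(axis=0) - inter
        iou = inter / np.maximum(union, 1)
        if is_show_per_class:
            print(iou)
        return iou.mean()
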