Example #1
    def get_dataloader_sampler(self, klass, split, dataset):
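        """
          Build a dataset loader for `split` and, when needed, a matching sampler
          (distributed or multi-dataset).
        """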

        from lib.datasets.loader.multi_dataset_loader import MultiDatasetLoader, MultiDatasetTrainingSampler

        root_dir = self.configer.get('data', 'data_dir')
        if isinstance(root_dir, list) and len(root_dir) == 1:
            root_dir = root_dir[0]

        kwargs = dict(
            dataset=dataset,
            aug_transform=(self.aug_train_transform if split == 'train' else self.aug_val_transform),
            img_transform=self.img_transform,
            torch_img_transform=(self.torch_img_transform if split == 'train' else None),
            label_transform=self.label_transform,
            configer=self.configer
        )

        if isinstance(root_dir, str):
            loader = klass(root_dir, **kwargs)
            multi_dataset = False
        elif isinstance(root_dir, list):
            loader = MultiDatasetLoader(root_dir, klass, **kwargs)
            multi_dataset = True
            Log.info('use multi-dataset for {}...'.format(dataset))
        else:
            raise RuntimeError('Unknown root dir {}'.format(root_dir))

        if split == 'train':
            if is_distributed() and multi_dataset:
                raise RuntimeError('Currently multi dataset doesn\'t support distributed.')

            if is_distributed():
                sampler = torch.utils.data.distributed.DistributedSampler(loader)
            elif multi_dataset:
                sampler = MultiDatasetTrainingSampler(loader)
            else:
                sampler = None

        elif split == 'val':
            if is_distributed():
                sampler = torch.utils.data.distributed.DistributedSampler(loader)
            else:
                sampler = None

        else:
            raise ValueError('Unknown split: {}'.format(split))

        return loader, sampler
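
A minimal sketch of how the returned pair might be consumed (hypothetical call site: the loader class, the split/dataset arguments and the configer keys are illustrative, and torch is assumed to be imported; Example #11's default branch builds a DataLoader with the same sampler/shuffle pattern):

loader, sampler = self.get_dataloader_sampler(DefaultLoader, 'train', 'train')
trainloader = torch.utils.data.DataLoader(
    loader,
    batch_size=self.configer.get('train', 'batch_size'),
    sampler=sampler,
    shuffle=(sampler is None),  # DataLoader rejects shuffle=True together with a sampler
    num_workers=self.configer.get('data', 'workers'),
    pin_memory=True,
    drop_last=self.configer.get('data', 'drop_last'))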
Example #2
    def _parallel(self, loss):
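        """
          Return the loss unchanged under torch.distributed (each process computes
          its own loss); otherwise optionally wrap it in DataParallelCriterion so
          the loss is computed on each GPU.
        """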
        if is_distributed():
            Log.info('use distributed loss')
            return loss
            
        if self.configer.get('network', 'loss_balance') and len(self.configer.get('gpu')) > 1:
            Log.info('use DataParallelCriterion loss')
            from lib.extensions.parallel.data_parallel import DataParallelCriterion
            loss = DataParallelCriterion(loss)

        return loss
Example #3
    def to_device(self, *params, force_list=False):
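        """
          Move every argument (tensor or module) onto the training device.
        """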
        if is_distributed():
            device = torch.device('cuda:{}'.format(get_rank()))
        else:
            device = torch.device('cpu' if self.configer.get('gpu') is None else 'cuda')
        return_list = [param.to(device) for param in params]

        if force_list:
            return return_list
        else:
            return return_list[0] if len(params) == 1 else return_list
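
A minimal usage sketch (hypothetical call site; the variable names are placeholders). Note the return convention: a single argument comes back as the object itself, several arguments come back as a list, and force_list=True always yields a list:

inputs, targets = self.to_device(inputs, targets)     # two args -> list of two, unpacked here
criterion = self.to_device(criterion)                 # one arg  -> the object itself
(inputs,) = self.to_device(inputs, force_list=True)   # force_list -> always a list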
Example #4
    def save_net(self, net, save_mode='iters'):
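        """
          Serialize the config and the network's state dict; `save_mode` selects
          the checkpoint policy ('performance', 'val_loss', 'iters' or 'epoch').
        """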
        if is_distributed() and get_rank() != 0:
            return

        state = {
            'config_dict': self.configer.to_dict(),
            'state_dict': net.state_dict(),
        }
        if self.configer.get('checkpoints', 'checkpoints_root') is None:
            checkpoints_dir = os.path.join(self.configer.get('project_dir'),
                                           self.configer.get('checkpoints', 'checkpoints_dir'))
        else:
            checkpoints_dir = os.path.join(self.configer.get('checkpoints', 'checkpoints_root'),
                                           self.configer.get('checkpoints', 'checkpoints_dir'))

        if not os.path.exists(checkpoints_dir):
            os.makedirs(checkpoints_dir)

        latest_name = '{}_latest.pth'.format(self.configer.get('checkpoints', 'checkpoints_name'))
        torch.save(state, os.path.join(checkpoints_dir, latest_name))
        if save_mode == 'performance':
            if self.configer.get('performance') > self.configer.get('max_performance'):
                latest_name = '{}_max_performance.pth'.format(self.configer.get('checkpoints', 'checkpoints_name'))
                torch.save(state, os.path.join(checkpoints_dir, latest_name))
                self.configer.update(['max_performance'], self.configer.get('performance'))

        elif save_mode == 'val_loss':
            if self.configer.get('val_loss') < self.configer.get('min_val_loss'):
                latest_name = '{}_min_loss.pth'.format(self.configer.get('checkpoints', 'checkpoints_name'))
                torch.save(state, os.path.join(checkpoints_dir, latest_name))
                self.configer.update(['min_val_loss'], self.configer.get('val_loss'))

        elif save_mode == 'iters':
            if self.configer.get('iters') - self.configer.get('last_iters') >= \
                    self.configer.get('checkpoints', 'save_iters'):
                latest_name = '{}_iters{}.pth'.format(self.configer.get('checkpoints', 'checkpoints_name'),
                                                 self.configer.get('iters'))
                torch.save(state, os.path.join(checkpoints_dir, latest_name))
                self.configer.update(['last_iters'], self.configer.get('iters'))

        elif save_mode == 'epoch':
            if self.configer.get('epoch') - self.configer.get('last_epoch') >= \
                    self.configer.get('checkpoints', 'save_epoch'):
                latest_name = '{}_epoch{}.pth'.format(self.configer.get('checkpoints', 'checkpoints_name'),
                                                 self.configer.get('epoch'))
                torch.save(state, os.path.join(checkpoints_dir, latest_name))
                self.configer.update(['last_epoch'], self.configer.get('epoch'))

        else:
            Log.error('Save mode {} is invalid.'.format(save_mode))
            exit(1)
Example #5
    def _make_parallel(self, net):
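        """
          Wrap the network: DistributedDataParallel when launched with
          torch.distributed, DataParallelModel otherwise.
        """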
        if is_distributed():
            local_rank = get_rank()

            return torch.nn.parallel.DistributedDataParallel(
                net,
                device_ids=[local_rank],
                output_device=local_rank,
            )

        if len(self.configer.get('gpu')) == 1:
            self.configer.update(['network', 'gathered'], True)

        return DataParallelModel(net, gather_=self.configer.get('network', 'gathered'))
Example #6
    def _init_model(self):
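        """
          Build the segmentation network, parameter groups, optimizer/scheduler,
          data loaders and the pixel-wise loss.
        """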
        self.seg_net = self.model_manager.semantic_segmentor()
        self.seg_net = self.module_runner.load_net(self.seg_net)

        Log.info('Params Group Method: {}'.format(self.configer.get('optim', 'group_method')))
        if self.configer.get('optim', 'group_method') == 'decay':
            params_group = self.group_weight(self.seg_net)
        else:
            assert self.configer.get('optim', 'group_method') is None
            params_group = self._get_parameters()

        self.optimizer, self.scheduler = self.optim_scheduler.init_optimizer(params_group)

        self.train_loader = self.data_loader.get_trainloader()
        self.val_loader = self.data_loader.get_valloader()
        self.pixel_loss = self.loss_manager.get_seg_loss()
        if is_distributed():
            self.pixel_loss = self.module_runner.to_device(self.pixel_loss)        
Example #7
    def _init_model(self):
        self.seg_net = self.model_manager.semantic_segmentor()
        self.seg_net = self.module_runner.load_net(self.seg_net)

        Log.info('Params Group Method: {}'.format(
            self.configer.get('optim', 'group_method')))
        if self.configer.get('optim', 'group_method') == 'decay':
            params_group = self.group_weight(self.seg_net)
        else:
            assert self.configer.get('optim', 'group_method') is None
            params_group = self._get_parameters()

        self.optimizer, self.scheduler = self.optim_scheduler.init_optimizer(
            params_group)

        self.train_loader = self.data_loader.get_trainloader()
        self.val_loader = self.data_loader.get_valloader()
        self.pixel_loss = self.loss_manager.get_seg_loss()
        if is_distributed():
            self.pixel_loss = self.module_runner.to_device(self.pixel_loss)

        self.with_contrast = bool(self.configer.exists("contrast"))
        if self.configer.exists("contrast", "warmup_iters"):
            self.contrast_warmup_iters = self.configer.get(
                "contrast", "warmup_iters")
        else:
            self.contrast_warmup_iters = 0

        self.with_memory = self.configer.exists('contrast', 'with_memory')
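        # `memory_size` and `pixel_update_freq` configure the pixel memory bank
        # used by the contrastive loss.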
        if self.with_memory:
            self.memory_size = self.configer.get('contrast', 'memory_size')
            self.pixel_update_freq = self.configer.get('contrast',
                                                       'pixel_update_freq')

        self.network_stride = self.configer.get('network', 'stride')

        Log.info("with_contrast: {}, warmup_iters: {}, with_memory: {}".format(
            self.with_contrast, self.contrast_warmup_iters, self.with_memory))
Example #8
    def load_net(self, net):
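        """
          Move the network to its device, wrap it for (distributed) data parallelism
          and optionally restore weights from the 'resume' checkpoint.
        """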
        net = self.to_device(net)
        net = self._make_parallel(net)

        if not is_distributed():
            net = net.to(torch.device('cpu' if self.configer.get('gpu') is None else 'cuda'))

        net.float()
        if self.configer.get('network', 'resume') is not None:
            Log.info('Loading checkpoint from {}...'.format(self.configer.get('network', 'resume')))
            resume_dict = torch.load(self.configer.get('network', 'resume'))
            if 'state_dict' in resume_dict:
                checkpoint_dict = resume_dict['state_dict']

            elif 'model' in resume_dict:
                checkpoint_dict = resume_dict['model']

            elif isinstance(resume_dict, OrderedDict):
                checkpoint_dict = resume_dict

            else:
                raise RuntimeError(
                    'No state_dict found in checkpoint file {}'.format(self.configer.get('network', 'resume')))

            if list(checkpoint_dict.keys())[0].startswith('module.'):
                checkpoint_dict = {k[7:]: v for k, v in checkpoint_dict.items()}

            # load state_dict
            if hasattr(net, 'module'):
                self.load_state_dict(net.module, checkpoint_dict, self.configer.get('network', 'resume_strict'))
            else:
                self.load_state_dict(net, checkpoint_dict, self.configer.get('network', 'resume_strict'))

            if self.configer.get('network', 'resume_continue'):
                self.configer.resume(resume_dict['config_dict'])

        return net
Example #9
    def __val(self, data_loader=None):
        """
          Validation function during the train phase.
        """
        self.seg_net.eval()
        self.pixel_loss.eval()
        start_time = time.time()
        replicas = self.evaluator.prepare_validaton()
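        # Per-device replicas of the network, used below when validation batches
        # contain images of different sizes (the diverse_size branch).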

        data_loader = self.val_loader if data_loader is None else data_loader
        for j, data_dict in enumerate(data_loader):
            if j % 10 == 0:
                Log.info('{} images processed\n'.format(j))

            if self.configer.get('dataset') == 'lip':
                (inputs, targets, inputs_rev, targets_rev), batch_size = self.data_helper.prepare_data(data_dict, want_reverse=True)
            else:
                (inputs, targets), batch_size = self.data_helper.prepare_data(data_dict)

            with torch.no_grad():
                if self.configer.get('dataset') == 'lip':
                    inputs = torch.cat([inputs[0], inputs_rev[0]], dim=0)
                    outputs = self.seg_net(inputs)        
                    outputs_ = self.module_runner.gather(outputs)
                    if isinstance(outputs_, (list, tuple)):
                        outputs_ = outputs_[-1]
                    outputs = outputs_[0:int(outputs_.size(0)/2),:,:,:].clone()
                    outputs_rev = outputs_[int(outputs_.size(0)/2):int(outputs_.size(0)),:,:,:].clone()
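                    # Swap the left/right paired part channels of the flipped branch
                    # so its prediction lines up with the un-flipped image.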
                    if outputs_rev.shape[1] == 20:
                        outputs_rev[:,14,:,:] = outputs_[int(outputs_.size(0)/2):int(outputs_.size(0)),15,:,:]
                        outputs_rev[:,15,:,:] = outputs_[int(outputs_.size(0)/2):int(outputs_.size(0)),14,:,:]
                        outputs_rev[:,16,:,:] = outputs_[int(outputs_.size(0)/2):int(outputs_.size(0)),17,:,:]
                        outputs_rev[:,17,:,:] = outputs_[int(outputs_.size(0)/2):int(outputs_.size(0)),16,:,:]
                        outputs_rev[:,18,:,:] = outputs_[int(outputs_.size(0)/2):int(outputs_.size(0)),19,:,:]
                        outputs_rev[:,19,:,:] = outputs_[int(outputs_.size(0)/2):int(outputs_.size(0)),18,:,:]
                    outputs_rev = torch.flip(outputs_rev, [3])
                    outputs = (outputs + outputs_rev) / 2.
                    self.evaluator.update_score(outputs, data_dict['meta'])

                elif self.data_helper.conditions.diverse_size:
                    outputs = nn.parallel.parallel_apply(replicas[:len(inputs)], inputs)

                    for i in range(len(outputs)):
                        loss = self.pixel_loss(outputs[i], targets[i])
                        self.val_losses.update(loss.item(), 1)
                        outputs_i = outputs[i]
                        if isinstance(outputs_i, torch.Tensor):
                            outputs_i = [outputs_i]
                        self.evaluator.update_score(outputs_i, data_dict['meta'][i:i+1])
                            
                else:
                    outputs = self.seg_net(*inputs)

                    try:
                        loss = self.pixel_loss(
                            outputs, targets,
                            gathered=self.configer.get('network', 'gathered')
                        )
                    except AssertionError:
                        # Print the mismatched lengths for debugging, then re-raise
                        # instead of falling through with `loss` undefined.
                        print(len(outputs), len(targets))
                        raise

                    if not is_distributed():
                        outputs = self.module_runner.gather(outputs)
                    self.val_losses.update(loss.item(), batch_size)
                    self.evaluator.update_score(outputs, data_dict['meta'])

            self.batch_time.update(time.time() - start_time)
            start_time = time.time()

        self.evaluator.update_performance()
        
        self.configer.update(['val_loss'], self.val_losses.avg)
        self.module_runner.save_net(self.seg_net, save_mode='performance')
        self.module_runner.save_net(self.seg_net, save_mode='val_loss')
        cudnn.benchmark = True

        # Print the log info & reset the states.
        if not is_distributed() or get_rank() == 0:
            Log.info(
                'Test Time {batch_time.sum:.3f}s, ({batch_time.avg:.3f})\t'
                'Loss {loss.avg:.8f}\n'.format(
                    batch_time=self.batch_time, loss=self.val_losses))
            self.evaluator.print_scores()
            
        self.batch_time.reset()
        self.val_losses.reset()
        self.evaluator.reset()
        self.seg_net.train()
        self.pixel_loss.train()
Example #10
    def __train(self):
        """
          Train function of every epoch during train phase.
        """
        self.seg_net.train()
        self.pixel_loss.train()
        start_time = time.time()

        if "swa" in self.configer.get('lr', 'lr_policy'):
            normal_max_iters = int(self.configer.get('solver', 'max_iters') * 0.75)
            swa_step_max_iters = (self.configer.get('solver', 'max_iters') - normal_max_iters) // 5 + 1

        if hasattr(self.train_loader.sampler, 'set_epoch'):
            self.train_loader.sampler.set_epoch(self.configer.get('epoch'))

        for i, data_dict in enumerate(self.train_loader):
            if self.configer.get('lr', 'metric') == 'iters':
                self.scheduler.step(self.configer.get('iters'))
            else:
                self.scheduler.step(self.configer.get('epoch'))


            if self.configer.get('lr', 'is_warm'):
                self.module_runner.warm_lr(
                    self.configer.get('iters'),
                    self.scheduler, self.optimizer, backbone_list=[0,]
                )

            (inputs, targets), batch_size = self.data_helper.prepare_data(data_dict)
            self.data_time.update(time.time() - start_time)

            foward_start_time = time.time()
            outputs = self.seg_net(*inputs)
            self.foward_time.update(time.time() - foward_start_time)

            loss_start_time = time.time()
            if is_distributed():
                import torch.distributed as dist
                def reduce_tensor(inp):
                    """
                    Reduce the loss from all processes so that 
                    process with rank 0 has the averaged results.
                    """
                    world_size = get_world_size()
                    if world_size < 2:
                        return inp
                    with torch.no_grad():
                        reduced_inp = inp
                        dist.reduce(reduced_inp, dst=0)
                    return reduced_inp
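                # Backpropagate the local loss (DDP averages the gradients across
                # processes); the all-reduce is only used for the logged value.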
                loss = self.pixel_loss(outputs, targets)
                backward_loss = loss
                display_loss = reduce_tensor(backward_loss) / get_world_size()
            else:
                backward_loss = display_loss = self.pixel_loss(outputs, targets, gathered=self.configer.get('network', 'gathered'))

            self.train_losses.update(display_loss.item(), batch_size)
            self.loss_time.update(time.time() - loss_start_time)

            backward_start_time = time.time()
            self.optimizer.zero_grad()
            backward_loss.backward()
            self.optimizer.step()
            self.backward_time.update(time.time() - backward_start_time)

            # Update the vars of the train phase.
            self.batch_time.update(time.time() - start_time)
            start_time = time.time()
            self.configer.plus_one('iters')

            # Print the log info & reset the states.
            if self.configer.get('iters') % self.configer.get('solver', 'display_iter') == 0 and \
                (not is_distributed() or get_rank() == 0):
                Log.info('Train Epoch: {0}\tTrain Iteration: {1}\t'
                         'Time {batch_time.sum:.3f}s / {2}iters, ({batch_time.avg:.3f})\t'
                         'Forward Time {foward_time.sum:.3f}s / {2}iters, ({foward_time.avg:.3f})\t'
                         'Backward Time {backward_time.sum:.3f}s / {2}iters, ({backward_time.avg:.3f})\t'
                         'Loss Time {loss_time.sum:.3f}s / {2}iters, ({loss_time.avg:.3f})\t'
                         'Data load {data_time.sum:.3f}s / {2}iters, ({data_time.avg:.3f})\n'
                         'Learning rate = {3}\tLoss = {loss.val:.8f} (ave = {loss.avg:.8f})\n'.format(
                         self.configer.get('epoch'), self.configer.get('iters'),
                         self.configer.get('solver', 'display_iter'),
                         self.module_runner.get_lr(self.optimizer), batch_time=self.batch_time,
                         foward_time=self.foward_time, backward_time=self.backward_time, loss_time=self.loss_time,
                         data_time=self.data_time, loss=self.train_losses))
                self.batch_time.reset()
                self.foward_time.reset()
                self.backward_time.reset()
                self.loss_time.reset()
                self.data_time.reset()
                self.train_losses.reset()

            # save checkpoints for swa
            if 'swa' in self.configer.get('lr', 'lr_policy') and \
               self.configer.get('iters') > normal_max_iters and \
               ((self.configer.get('iters') - normal_max_iters) % swa_step_max_iters == 0 or \
                self.configer.get('iters') == self.configer.get('solver', 'max_iters')):
                self.optimizer.update_swa()

            if self.configer.get('iters') == self.configer.get('solver', 'max_iters'):
                break

            # Check to val the current model.
            # if self.configer.get('epoch') % self.configer.get('solver', 'test_interval') == 0:
            if self.configer.get('iters') % self.configer.get('solver', 'test_interval') == 0:
                self.__val()

        self.configer.plus_one('epoch')
Example #11
    def get_trainloader(self):
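        """
          Return the training DataLoader: a CE2P edge loader, an ADE20K-style loader
          for diverse input sizes, or the default fixed-size loader.
        """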
        if self.configer.exists('data', 'use_edge') and self.configer.get(
                'data', 'use_edge') == 'ce2p':
            """
            ce2p manner:
            load both the ground-truth label and edge.
            """
            Log.info('use edge (follow ce2p) for train...')
            trainloader = data.DataLoader(
                LipLoader(root_dir=self.configer.get('data', 'data_dir'),
                          dataset='train',
                          aug_transform=self.aug_train_transform,
                          img_transform=self.img_transform,
                          label_transform=self.label_transform,
                          configer=self.configer),
                batch_size=self.configer.get('train', 'batch_size'),
                pin_memory=True,
                num_workers=self.configer.get('data', 'workers'),
                shuffle=True,
                drop_last=self.configer.get('data', 'drop_last'),
                collate_fn=lambda *args: collate(
                    *args,
                    trans_dict=self.configer.get('train', 'data_transformer')))
            return trainloader

        elif self.configer.exists('train', 'loader') and \
                self.configer.get('train', 'loader') in (
                    'ade20k', 'pascal_context', 'pascal_voc', 'coco_stuff'):
            """
            ADE20KLoader manner:
            support input images of different shapes.
            """
            Log.info('use ADE20KLoader (diverse input shape) for train...')
            trainloader = data.DataLoader(
                ADE20KLoader(root_dir=self.configer.get('data', 'data_dir'),
                             dataset='train',
                             aug_transform=self.aug_train_transform,
                             img_transform=self.img_transform,
                             label_transform=self.label_transform,
                             configer=self.configer),
                batch_size=self.configer.get('train', 'batch_size'),
                pin_memory=True,
                num_workers=self.configer.get('data', 'workers'),
                shuffle=True,
                drop_last=self.configer.get('data', 'drop_last'),
                collate_fn=lambda *args: collate(
                    *args,
                    trans_dict=self.configer.get('train', 'data_transformer')))
            return trainloader

        else:
            """
            Default manner:
            support input images of the same shapes.
            """
            dataset = DefaultLoader(root_dir=self.configer.get('data', 'data_dir'),
                                    dataset='train',
                                    aug_transform=self.aug_train_transform,
                                    img_transform=self.img_transform,
                                    label_transform=self.label_transform,
                                    configer=self.configer)
            if is_distributed():
                sampler = torch.utils.data.distributed.DistributedSampler(dataset)
            else:
                sampler = None
            Log.info('use the DefaultLoader for train...')
            trainloader = data.DataLoader(
                dataset,
                batch_size=self.configer.get('train', 'batch_size') // get_world_size(),
                pin_memory=True,
                num_workers=self.configer.get('data', 'workers') // get_world_size(),
                sampler=sampler,
                shuffle=(sampler is None),
                drop_last=self.configer.get('data', 'drop_last'),
                collate_fn=lambda *args: collate(
                    *args,
                    trans_dict=self.configer.get('train', 'data_transformer')))
            return trainloader