Example #1
    def get_trainloader(self):
        if self.configer.exists('data', 'use_edge') and self.configer.get('data', 'use_edge') == 'ce2p':
            """
            ce2p manner:
            load both the ground-truth label and edge.
            """
            Log.info('use edge (follow ce2p) for train...')
            klass = LipLoader

        elif self.configer.exists('data', 'use_dt_offset') or self.configer.exists('data', 'pred_dt_offset'):
            """
            dt-offset manner:
            load both the ground-truth label and offset (based on distance transform).
            """
            Log.info('use distance transform offset loader for train...')
            klass = DTOffsetLoader

        elif self.configer.exists('train', 'loader') and \
            self.configer.get('train', 'loader') in ('ade20k', 'pascal_context',
                                                     'pascal_voc', 'coco_stuff'):
            """
            ADE20KLoader manner:
            support input images of different shapes.
            """
            Log.info('use ADE20KLoader (diverse input shape) for train...')
            klass = ADE20KLoader
        else:
            """
            Default manner:
            + support input images of the same shapes.
            + support distributed training (the performance is less stable than in the non-distributed manner)
            """
            Log.info('use the DefaultLoader for train...')
            klass = DefaultLoader
        loader, sampler = self.get_dataloader_sampler(klass, 'train', 'train')
        trainloader = data.DataLoader(
            loader,
            batch_size=self.configer.get('train', 'batch_size') // get_world_size(), pin_memory=True,
            num_workers=self.configer.get('data', 'workers') // get_world_size(),
            sampler=sampler,
            shuffle=(sampler is None),
            drop_last=self.configer.get('data', 'drop_last'),
            collate_fn=lambda *args: collate(
                *args, trans_dict=self.configer.get('train', 'data_transformer')
            )
        )
        return trainloader
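The branches above all key off nested config lookups (configer.exists / configer.get). Below is a minimal sketch of the kind of nested config those lookups assume; the MiniConfiger stand-in and the keys shown are illustrative assumptions, not the project's actual Configer class or schema.

from typing import Any

class MiniConfiger:
    """Toy stand-in for a configer: nested-dict lookups by key path."""
    def __init__(self, cfg: dict):
        self._cfg = cfg

    def exists(self, *keys: str) -> bool:
        node: Any = self._cfg
        for k in keys:
            if not isinstance(node, dict) or k not in node:
                return False
            node = node[k]
        return True

    def get(self, *keys: str) -> Any:
        node: Any = self._cfg
        for k in keys:
            node = node[k]
        return node

# Example: a config like this would select the ADE20KLoader branch above.
cfg = MiniConfiger({
    'train': {'loader': 'ade20k', 'batch_size': 16},
    'data': {'workers': 8, 'drop_last': True},
})
assert cfg.exists('train', 'loader') and cfg.get('train', 'loader') == 'ade20k'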
Example #2
    def get_valloader(self, dataset=None):
        dataset = 'val' if dataset is None else dataset

        if self.configer.exists('data', 'use_dt_offset') or self.configer.exists('data', 'pred_dt_offset'):
            """
            dt-offset manner:
            load both the ground-truth label and offset (based on distance transform).
            """   
            Log.info('use distance transform based offset loader for val ...')
            klass = DTOffsetLoader

        elif self.configer.get('method') == 'fcn_segmentor':
            """
            default manner:
            load the ground-truth label.
            """   
            Log.info('use DefaultLoader for val ...')
            klass = DefaultLoader
        else:
            Log.error('Method: {} loader is invalid.'.format(self.configer.get('method')))
            return None

        loader, sampler = self.get_dataloader_sampler(klass, 'val', dataset)
        valloader = data.DataLoader(
            loader,
            sampler=sampler,
            batch_size=self.configer.get('val', 'batch_size') // get_world_size(), pin_memory=True,
            num_workers=self.configer.get('data', 'workers'), shuffle=False,
            collate_fn=lambda *args: collate(
                *args, trans_dict=self.configer.get('val', 'data_transformer')
            )
        )
        return valloader
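For context, here is how a loader like the one returned above is typically consumed during validation; the toy dataset and model below are placeholders, not part of the repository.

import torch
from torch import nn
from torch.utils import data

class ToyDataset(data.Dataset):
    """Random features with binary labels, just to drive the loop."""
    def __init__(self, n=32):
        self.x = torch.randn(n, 8)
        self.y = torch.randint(0, 2, (n,))
    def __len__(self):
        return len(self.x)
    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]

valloader = data.DataLoader(ToyDataset(), batch_size=8, shuffle=False)
model = nn.Linear(8, 2)
model.eval()
with torch.no_grad():
    correct = total = 0
    for inputs, targets in valloader:
        preds = model(inputs).argmax(dim=1)
        correct += (preds == targets).sum().item()
        total += targets.numel()
print(f'val accuracy: {correct / total:.3f}')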
Example #3
def setup_logging(output_dir=None):
    """
    Set up logging for multiple processes: enable logging only for the master
    process and suppress it for the non-master processes.
    """
    # Set up logging format.
    _FORMAT = "[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s"

    if du.is_master_proc():
        # Enable logging for the master process.
        logging.root.handlers = []
    else:
        # Suppress logging for non-master processes.
        _suppress_print()

    logger = logging.getLogger()
    logger.setLevel(logging.DEBUG)
    logger.propagate = False
    plain_formatter = logging.Formatter(
        "[%(asctime)s][%(levelname)s] %(filename)s: %(lineno)3d: %(message)s",
        datefmt="%m/%d %H:%M:%S",
    )

    if du.is_master_proc():
        ch = logging.StreamHandler(stream=sys.stdout)
        ch.setLevel(logging.DEBUG)
        ch.setFormatter(plain_formatter)
        logger.addHandler(ch)

    if output_dir is not None and du.is_master_proc(du.get_world_size()):
        filename = os.path.join(output_dir, "stdout.log")
        fh = logging.StreamHandler(_cached_log_stream(filename))
        fh.setLevel(logging.DEBUG)
        fh.setFormatter(plain_formatter)
        logger.addHandler(fh)
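A self-contained sketch of the same handler setup reduced to a single process (no distributed or master-process checks); the function name and file path below are illustrative only, not part of the repository.

import logging
import sys

def setup_simple_logging(log_file=None):
    """Attach a stdout handler (and optionally a file handler) to the root logger."""
    logger = logging.getLogger()
    logger.setLevel(logging.DEBUG)
    logger.handlers = []  # drop any pre-existing handlers
    formatter = logging.Formatter(
        "[%(asctime)s][%(levelname)s] %(filename)s: %(lineno)3d: %(message)s",
        datefmt="%m/%d %H:%M:%S",
    )
    ch = logging.StreamHandler(stream=sys.stdout)
    ch.setFormatter(formatter)
    logger.addHandler(ch)
    if log_file is not None:
        fh = logging.FileHandler(log_file)
        fh.setFormatter(formatter)
        logger.addHandler(fh)
    return logger

logger = setup_simple_logging()
logger.info("logging configured")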
Example #4
def reduce_tensor(inp):
    """
    Reduce the loss across all processes so that the process with rank 0
    holds the summed result; the caller divides by the world size to average.
    """
    world_size = get_world_size()
    if world_size < 2:
        return inp
    with torch.no_grad():
        reduced_inp = inp
        dist.reduce(reduced_inp, dst=0)
    return reduced_inp
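Note that torch.distributed.reduce sums by default, which is why callers divide by the world size to obtain an average (see Example #5). Below is a minimal, self-contained variant that uses torch.distributed directly and falls back to a no-op when no process group is initialized; it is a sketch, not the repository's helper.

import torch
import torch.distributed as dist

def reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
    """Return the cross-process mean of `tensor` on rank 0 (sum, then divide)."""
    if not (dist.is_available() and dist.is_initialized()) or dist.get_world_size() < 2:
        return tensor
    with torch.no_grad():
        reduced = tensor.clone()
        dist.reduce(reduced, dst=0)       # default op is SUM
        reduced /= dist.get_world_size()  # convert the sum into an average
    return reduced

loss = torch.tensor(1.5)
print(reduce_mean(loss))  # single process: returned unchanged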
Example #5
    def __train(self):
        """
          Train function of every epoch during train phase.
        """
        self.seg_net.train()
        self.pixel_loss.train()
        start_time = time.time()

        if "swa" in self.configer.get('lr', 'lr_policy'):
            normal_max_iters = int(self.configer.get('solver', 'max_iters') * 0.75)
            swa_step_max_iters = (self.configer.get('solver', 'max_iters') - normal_max_iters) // 5 + 1

        if hasattr(self.train_loader.sampler, 'set_epoch'):
            self.train_loader.sampler.set_epoch(self.configer.get('epoch'))

        for i, data_dict in enumerate(self.train_loader):
            if self.configer.get('lr', 'metric') == 'iters':
                self.scheduler.step(self.configer.get('iters'))
            else:
                self.scheduler.step(self.configer.get('epoch'))


            if self.configer.get('lr', 'is_warm'):
                self.module_runner.warm_lr(
                    self.configer.get('iters'),
                    self.scheduler, self.optimizer, backbone_list=[0,]
                )

            (inputs, targets), batch_size = self.data_helper.prepare_data(data_dict)
            self.data_time.update(time.time() - start_time)

            foward_start_time = time.time()
            outputs = self.seg_net(*inputs)
            self.foward_time.update(time.time() - foward_start_time)

            loss_start_time = time.time()
            if is_distributed():
                import torch.distributed as dist
                def reduce_tensor(inp):
                    """
                    Reduce the loss across all processes so that the process
                    with rank 0 holds the summed result; the caller divides
                    by the world size to average.
                    """
                    world_size = get_world_size()
                    if world_size < 2:
                        return inp
                    with torch.no_grad():
                        reduced_inp = inp
                        dist.reduce(reduced_inp, dst=0)
                    return reduced_inp
                loss = self.pixel_loss(outputs, targets)
                backward_loss = loss
                display_loss = reduce_tensor(backward_loss) / get_world_size()
            else:
                backward_loss = display_loss = self.pixel_loss(outputs, targets, gathered=self.configer.get('network', 'gathered'))

            self.train_losses.update(display_loss.item(), batch_size)
            self.loss_time.update(time.time() - loss_start_time)

            backward_start_time = time.time()
            self.optimizer.zero_grad()
            backward_loss.backward()
            self.optimizer.step()
            self.backward_time.update(time.time() - backward_start_time)

            # Update the vars of the train phase.
            self.batch_time.update(time.time() - start_time)
            start_time = time.time()
            self.configer.plus_one('iters')

            # Print the log info & reset the states.
            if self.configer.get('iters') % self.configer.get('solver', 'display_iter') == 0 and \
                (not is_distributed() or get_rank() == 0):
                Log.info('Train Epoch: {0}\tTrain Iteration: {1}\t'
                         'Time {batch_time.sum:.3f}s / {2}iters, ({batch_time.avg:.3f})\t'
                         'Forward Time {foward_time.sum:.3f}s / {2}iters, ({foward_time.avg:.3f})\t'
                         'Backward Time {backward_time.sum:.3f}s / {2}iters, ({backward_time.avg:.3f})\t'
                         'Loss Time {loss_time.sum:.3f}s / {2}iters, ({loss_time.avg:.3f})\t'
                         'Data load {data_time.sum:.3f}s / {2}iters, ({data_time.avg:.3f})\n'
                         'Learning rate = {3}\tLoss = {loss.val:.8f} (ave = {loss.avg:.8f})\n'.format(
                         self.configer.get('epoch'), self.configer.get('iters'),
                         self.configer.get('solver', 'display_iter'),
                         self.module_runner.get_lr(self.optimizer), batch_time=self.batch_time,
                         foward_time=self.foward_time, backward_time=self.backward_time, loss_time=self.loss_time,
                         data_time=self.data_time, loss=self.train_losses))
                self.batch_time.reset()
                self.foward_time.reset()
                self.backward_time.reset()
                self.loss_time.reset()
                self.data_time.reset()
                self.train_losses.reset()

            # save checkpoints for swa
            if 'swa' in self.configer.get('lr', 'lr_policy') and \
               self.configer.get('iters') > normal_max_iters and \
               ((self.configer.get('iters') - normal_max_iters) % swa_step_max_iters == 0 or \
                self.configer.get('iters') == self.configer.get('solver', 'max_iters')):
               self.optimizer.update_swa()

            if self.configer.get('iters') == self.configer.get('solver', 'max_iters'):
                break

            # Check to val the current model.
            # if self.configer.get('epoch') % self.configer.get('solver', 'test_interval') == 0:
            if self.configer.get('iters') % self.configer.get('solver', 'test_interval') == 0:
                self.__val()

        self.configer.plus_one('epoch')
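The timing and loss trackers used above (batch_time, foward_time, loss_time, train_losses) expose update / reset / val / avg / sum. Below is a minimal sketch of such an "average meter", written as an assumption about that interface rather than the repository's actual class.

class AverageMeter:
    """Track the latest value plus a running sum, count, and average."""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

meter = AverageMeter()
meter.update(0.8)
meter.update(0.4)
print(meter.avg)  # approximately 0.6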
Example #6
    def get_trainloader(self):
        if self.configer.exists('data', 'use_edge') and self.configer.get('data', 'use_edge') == 'ce2p':
            """
            ce2p manner:
            load both the ground-truth label and edge.
            """
            Log.info('use edge (follow ce2p) for train...')
            trainloader = data.DataLoader(
                LipLoader(root_dir=self.configer.get('data', 'data_dir'),
                          dataset='train',
                          aug_transform=self.aug_train_transform,
                          img_transform=self.img_transform,
                          label_transform=self.label_transform,
                          configer=self.configer),
                batch_size=self.configer.get('train', 'batch_size'),
                pin_memory=True,
                num_workers=self.configer.get('data', 'workers'),
                shuffle=True,
                drop_last=self.configer.get('data', 'drop_last'),
                collate_fn=lambda *args: collate(
                    *args,
                    trans_dict=self.configer.get('train', 'data_transformer')))
            return trainloader

        elif self.configer.exists('train', 'loader') and \
            self.configer.get('train', 'loader') in ('ade20k', 'pascal_context',
                                                     'pascal_voc', 'coco_stuff'):
            """
            ADE20KLoader manner:
            support input images of different shapes.
            """
            Log.info('use ADE20KLoader (diverse input shape) for train...')
            trainloader = data.DataLoader(
                ADE20KLoader(root_dir=self.configer.get('data', 'data_dir'),
                             dataset='train',
                             aug_transform=self.aug_train_transform,
                             img_transform=self.img_transform,
                             label_transform=self.label_transform,
                             configer=self.configer),
                batch_size=self.configer.get('train', 'batch_size'),
                pin_memory=True,
                num_workers=self.configer.get('data', 'workers'),
                shuffle=True,
                drop_last=self.configer.get('data', 'drop_last'),
                collate_fn=lambda *args: collate(
                    *args,
                    trans_dict=self.configer.get('train', 'data_transformer')))
            return trainloader

        else:
            """
            Default manner:
            support input images of the same shapes.
            """
            dataset = DefaultLoader(
                root_dir=self.configer.get('data', 'data_dir'),
                dataset='train',
                aug_transform=self.aug_train_transform,
                img_transform=self.img_transform,
                label_transform=self.label_transform,
                configer=self.configer)
            if is_distributed():
                sampler = torch.utils.data.distributed.DistributedSampler(dataset)
            else:
                sampler = None
            Log.info('use the DefaultLoader for train...')
            trainloader = data.DataLoader(
                dataset,
                batch_size=self.configer.get('train', 'batch_size') // get_world_size(),
                pin_memory=True,
                num_workers=self.configer.get('data', 'workers') // get_world_size(),
                sampler=sampler,
                shuffle=(sampler is None),
                drop_last=self.configer.get('data', 'drop_last'),
                collate_fn=lambda *args: collate(
                    *args,
                    trans_dict=self.configer.get('train', 'data_transformer')))
            return trainloader
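When the distributed branch above is taken, per-epoch shuffling relies on calling sampler.set_epoch every epoch, as Example #5 does. Below is a minimal sketch of that pattern; num_replicas and rank are passed explicitly so it runs without an initialized process group, and the toy dataset is a placeholder.

import torch
from torch.utils import data
from torch.utils.data.distributed import DistributedSampler

dataset = data.TensorDataset(torch.arange(16).float())
# Pretend to be rank 0 of a 2-process job; normally num_replicas/rank come from the process group.
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)
loader = data.DataLoader(dataset, batch_size=4, sampler=sampler)

for epoch in range(2):
    sampler.set_epoch(epoch)  # reseed the shuffle so each epoch sees a different order
    for (batch,) in loader:
        pass  # the training step would go here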