def run_chunk(self, mode):
    """Run chunk-based training or inference over a large dataset.

    ``self.dataset`` is (re)built with ``get_dataset`` and advanced one chunk
    at a time with ``updatechunk``. A fresh dataloader iterator is built over
    each chunk's ``dataset`` and stored in ``self.dataloader``. It is reset to
    ``None`` once the chunk has been processed so the iterator can be released
    before the next chunk is loaded.

    Args:
        mode: ``'train'`` trains chunk by chunk; any other value runs
            inference chunk by chunk.

    Training mode:
        The ``self.total_iter_nums`` iteration budget is split into chunks of
        ``cfg.DATASET.DATA_CHUNK_ITER`` iterations. A final, shorter chunk
        covers any remainder, so the whole budget is used even when it is not
        a multiple of the chunk size. Before each ``self.train()`` call,
        ``self.total_iter_nums`` is overwritten with that chunk's size.
        Afterwards, ``self.start_iter`` is advanced by the same amount. A
        non-positive budget trains nothing.

    Inference mode:
        Visits ``len(self.dataset.chunk_num_ind)`` chunks. Each chunk's output
        name is ``cfg.INFERENCE.OUTPUT_NAME + <chunk coord name> + '.h5'`` and
        is stored in ``self.inference_output_name``. Chunks whose output file
        already exists in ``self.output_dir`` are skipped before their data is
        loaded.
    """
    self.dataset = get_dataset(self.cfg, self.augmentor, mode)

    if mode == 'train':
        chunk_iter = self.cfg.DATASET.DATA_CHUNK_ITER
        # Full chunks plus one shorter trailing chunk for the remainder, so
        # no iterations of the budget are silently dropped. Clamp at 0 so a
        # non-positive budget yields no chunks (divmod of a negative number
        # would otherwise produce a positive remainder).
        num_full, remainder = divmod(max(self.total_iter_nums, 0), chunk_iter)
        chunk_sizes = [chunk_iter] * num_full
        if remainder > 0:
            chunk_sizes.append(remainder)

        for chunk, chunk_size in enumerate(chunk_sizes):
            self.dataset.updatechunk()
            self.dataloader = iter(
                build_dataloader(
                    self.cfg,
                    self.augmentor,
                    mode,
                    dataset=self.dataset.dataset))
            # NOTE(review): presumably self.train() runs
            # self.total_iter_nums iterations starting at self.start_iter;
            # confirm against train().
            self.total_iter_nums = chunk_size
            print('start train', chunk)
            self.train()
            print('finished train', chunk)
            self.start_iter += chunk_size
            # Drop this chunk's iterator before loading the next chunk.
            self.dataloader = None
        return

    num_chunk = len(self.dataset.chunk_num_ind)
    for _ in range(num_chunk):
        # Advance to the next chunk without loading it; the data is loaded
        # explicitly with loadchunk() only if its output is still missing.
        self.dataset.updatechunk(do_load=False)
        self.inference_output_name = (
            self.cfg.INFERENCE.OUTPUT_NAME
            + self.dataset.get_coord_name()
            + '.h5')
        output_path = os.path.join(self.output_dir, self.inference_output_name)
        if os.path.exists(output_path):
            # Output for this chunk was already written; skip it.
            continue
        self.dataset.loadchunk()
        self.dataloader = iter(
            build_dataloader(
                self.cfg,
                self.augmentor,
                mode,
                dataset=self.dataset.dataset))
        self.test()
        self.dataloader = None
예제 #2
0
    def __init__(self, cfg, device, mode, checkpoint=None):
        """Set up model, optimisation state, data pipeline and monitoring.

        Args:
            cfg: configuration node. ``DATASET.OUTPUT_PATH`` becomes
                ``self.output_dir`` and ``MODEL.PRE_MODEL_ITER`` becomes
                ``self.start_iter``; the node is passed to every builder.
            device: device handed to ``build_model`` and ``build_criterion``.
            mode: ``'train'`` builds a training augmentor; any other mode
                runs without an augmentor.
            checkpoint: optional checkpoint, passed to
                ``self.update_checkpoint`` once the model, optimizer,
                LR scheduler and ``self.start_iter`` exist.

        After construction, ``self.dataloader`` holds an iterator over the
        loader returned by ``build_dataloader``.
        """
        self.cfg = cfg
        self.device = device
        self.output_dir = cfg.DATASET.OUTPUT_PATH
        self.mode = mode

        # Each builder consumes the object built just before it.
        self.model = build_model(cfg, device)
        self.optimizer = build_optimizer(cfg, self.model)
        self.lr_scheduler = build_lr_scheduler(cfg, self.optimizer)
        self.start_iter = cfg.MODEL.PRE_MODEL_ITER
        if checkpoint is not None:
            # Runs after the state above exists; presumably it may restore
            # into it.
            self.update_checkpoint(checkpoint)

        is_train = self.mode == 'train'
        self.augmentor = build_train_augmentor(cfg) if is_train else None
        loader = build_dataloader(cfg, self.augmentor, self.mode)
        self.monitor = build_monitor(cfg)
        self.criterion = build_criterion(cfg, device)

        # Record the configuration details with the monitor (tensorboard).
        self.monitor.load_config(cfg)

        # Keep an iterator rather than the loader itself; batches are
        # presumably pulled with next() elsewhere.
        self.dataloader = iter(loader)
    def __init__(self, cfg, device, mode, checkpoint=None):
        """Set up model, optimisation state, monitoring and data loading.

        Args:
            cfg: configuration node. ``DATASET.OUTPUT_PATH`` becomes
                ``self.output_dir``, ``MODEL.PRE_MODEL_ITER`` seeds
                ``self.start_iter``, ``DATASET.DO_CHUNK_TITLE`` selects
                whole-volume vs. chunked loading, ``SOLVER.ITERATION_TOTAL``
                sets the iteration budget and ``INFERENCE.OUTPUT_NAME`` the
                default output name.
            device: device handed to ``build_model`` and ``build_criterion``.
            mode: ``'train'`` builds the augmentor, monitor and criterion;
                in any other mode those attributes are ``None``.
            checkpoint: optional checkpoint passed to
                ``self.update_checkpoint`` once the model, optimizer,
                LR scheduler and ``self.start_iter`` exist.

        Every attribute below is defined in every mode (``None`` when the
        mode does not use it). Callers can therefore test for availability
        instead of hitting ``AttributeError``.
        """
        self.cfg = cfg
        self.device = device
        self.output_dir = cfg.DATASET.OUTPUT_PATH
        self.mode = mode

        self.model = build_model(self.cfg, self.device)
        self.optimizer = build_optimizer(self.cfg, self.model)
        self.lr_scheduler = build_lr_scheduler(self.cfg, self.optimizer)
        self.start_iter = self.cfg.MODEL.PRE_MODEL_ITER
        if checkpoint is not None:
            # May update self.start_iter, which the iteration budget below
            # is derived from -- keep this before total_iter_nums.
            self.update_checkpoint(checkpoint)

        # Mode-dependent attributes default to None so they always exist.
        self.augmentor = None
        self.monitor = None
        self.criterion = None
        self.dataset = None
        self.dataloader = None

        if self.mode == 'train':
            self.augmentor = build_train_augmentor(self.cfg)
            self.monitor = build_monitor(self.cfg)
            self.criterion = build_criterion(self.cfg, self.device)
            # add config details to tensorboard
            self.monitor.load_config(self.cfg)

        if self.cfg.DATASET.DO_CHUNK_TITLE == 0:
            # Whole-volume mode: a single dataloader iterator built up front.
            self.dataloader = iter(
                build_dataloader(self.cfg, self.augmentor, self.mode))
        # Otherwise (chunked mode) the dataset and per-chunk dataloaders are
        # created later, so both stay None here.

        self.total_iter_nums = self.cfg.SOLVER.ITERATION_TOTAL - self.start_iter
        self.inference_output_name = self.cfg.INFERENCE.OUTPUT_NAME
        self.total_time = 0
예제 #4
0
    def __init__(self, cfg, device, mode, output_dir='outputs/', checkpoint=None):
        """Set up model, optimisation state, data pipeline and monitoring.

        Args:
            cfg: configuration node passed to every builder.
            device: device handed to ``build_model`` and ``build_criterion``.
            mode: ``'train'`` builds a training augmentor; any other mode
                runs without an augmentor.
            output_dir: stored as ``self.output_dir``.
            checkpoint: optional checkpoint forwarded to ``build_model``.

        After construction, ``self.dataloader`` holds an iterator over the
        loader returned by ``build_dataloader``.
        """
        self.cfg, self.device = cfg, device
        self.output_dir = output_dir
        self.mode = mode

        # Each builder consumes the object built just before it.
        self.model = build_model(cfg, device, checkpoint)
        self.optimizer = build_optimizer(cfg, self.model)
        self.lr_scheduler = build_lr_scheduler(cfg, self.optimizer)

        if mode != 'train':
            self.augmentor = None
        else:
            self.augmentor = build_train_augmentor(cfg)

        loader = build_dataloader(cfg, self.augmentor, mode)
        self.monitor = build_monitor(cfg)
        self.criterion = build_criterion(cfg, device)

        # Keep an iterator rather than the loader itself; batches are
        # presumably pulled with next() elsewhere.
        self.dataloader = iter(loader)