示例#1
0
    def __init__(self, cfg, device, mode, checkpoint=None):
        """Assemble the model, optimization stack, data pipeline and monitor.

        Args:
            cfg: Experiment configuration. ``DATASET.OUTPUT_PATH`` and
                ``MODEL.PRE_MODEL_ITER`` are read directly; the whole object
                is also forwarded to every ``build_*`` factory.
            device: Device handle forwarded to ``build_model`` and
                ``build_criterion``.
            mode: Run mode. The string ``'train'`` enables the training
                augmentor; any other value leaves ``self.augmentor`` as None.
            checkpoint: Optional checkpoint; when not ``None`` it is handed
                to ``self.update_checkpoint`` right after the optimizer and
                scheduler exist.
        """
        self.cfg = cfg
        self.device = device
        self.output_dir = cfg.DATASET.OUTPUT_PATH
        self.mode = mode

        # Built in dependency order: the optimizer takes the model and the
        # scheduler takes the optimizer.
        self.model = build_model(self.cfg, self.device)
        self.optimizer = build_optimizer(self.cfg, self.model)
        self.lr_scheduler = build_lr_scheduler(self.cfg, self.optimizer)

        # Starting iteration comes from the config; update_checkpoint runs
        # after it is set (presumably so it can override it -- confirm).
        self.start_iter = self.cfg.MODEL.PRE_MODEL_ITER
        if checkpoint is not None:
            self.update_checkpoint(checkpoint)

        # Augmentation is only wired in for training runs.
        self.augmentor = (build_train_augmentor(self.cfg)
                          if self.mode == 'train' else None)
        self.dataloader = build_dataloader(self.cfg, self.augmentor, self.mode)
        self.monitor = build_monitor(self.cfg)
        self.criterion = build_criterion(self.cfg, self.device)

        # Log the configuration through the monitor (tensorboard, per the
        # original note).
        self.monitor.load_config(self.cfg)

        # Keep an iterator over the loader, not the loader itself, so batches
        # can be pulled one at a time.
        self.dataloader = iter(self.dataloader)
    def __init__(self, cfg, device, mode, checkpoint=None):
        """Set up the model, optimization stack and, optionally, data loading.

        Args:
            cfg: Experiment configuration. Fields read directly here:
                ``DATASET.OUTPUT_PATH``, ``DATASET.DO_CHUNK_TITLE``,
                ``MODEL.PRE_MODEL_ITER``, ``SOLVER.ITERATION_TOTAL`` and
                ``INFERENCE.OUTPUT_NAME``. The whole object is also passed to
                the ``build_*`` factories.
            device: Device handle forwarded to ``build_model`` and
                ``build_criterion``.
            mode: Run mode. Only ``'train'`` builds the augmentor, monitor
                and criterion.
            checkpoint: Optional checkpoint handed to ``update_checkpoint``
                when not ``None``.

        ``monitor``, ``criterion`` and ``dataset`` are always defined after
        construction. They default to ``None`` and are replaced only by the
        branches below that actually build them.
        """
        self.cfg = cfg
        self.device = device
        self.output_dir = cfg.DATASET.OUTPUT_PATH
        self.mode = mode

        # Defaults for attributes that are only built conditionally below.
        # Without these, non-train runs had no ``monitor``/``criterion`` and
        # non-chunked runs had no ``dataset``, so touching them raised
        # AttributeError instead of yielding a checkable ``None``.
        self.monitor = None
        self.criterion = None
        self.dataset = None

        self.model = build_model(self.cfg, self.device)
        self.optimizer = build_optimizer(self.cfg, self.model)
        self.lr_scheduler = build_lr_scheduler(self.cfg, self.optimizer)
        self.start_iter = self.cfg.MODEL.PRE_MODEL_ITER
        if checkpoint is not None:
            self.update_checkpoint(checkpoint)

        if self.mode == 'train':
            self.augmentor = build_train_augmentor(self.cfg)
            self.monitor = build_monitor(self.cfg)
            self.criterion = build_criterion(self.cfg, self.device)
            # add config details to tensorboard
            self.monitor.load_config(self.cfg)
        else:
            self.augmentor = None

        # With DO_CHUNK_TITLE == 0 a loader is built now and wrapped in an
        # iterator. Otherwise no loader is built here (presumably the data
        # is loaded chunk by chunk elsewhere -- confirm).
        if cfg.DATASET.DO_CHUNK_TITLE == 0:
            self.dataloader = iter(
                build_dataloader(self.cfg, self.augmentor, self.mode))
        else:
            self.dataset = None
            self.dataloader = None

        # ITERATION_TOTAL minus the (possibly checkpoint-updated) start iter.
        self.total_iter_nums = self.cfg.SOLVER.ITERATION_TOTAL - self.start_iter
        self.inference_output_name = self.cfg.INFERENCE.OUTPUT_NAME
        self.total_time = 0
示例#3
0
    def __init__(self,
                 cfg,
                 device,
                 mode,
                 output_dir='outputs/',
                 checkpoint=None):
        """Construct the model, optimizer, scheduler, data pipeline and monitor.

        Args:
            cfg: Experiment configuration forwarded to every ``build_*``
                factory.
            device: Device handle forwarded to ``build_model`` and
                ``build_criterion``.
            mode: Run mode. ``'train'`` enables the training augmentor; any
                other value leaves ``self.augmentor`` as None.
            output_dir: Stored as ``self.output_dir``; defaults to
                ``'outputs/'``.
            checkpoint: Optional checkpoint forwarded to ``build_model``.
        """
        self.cfg = cfg
        self.device = device
        self.output_dir = output_dir
        self.mode = mode

        # The optimizer is built from the model and the scheduler from the
        # optimizer, so the order here matters.
        self.model = build_model(cfg, device, checkpoint)
        self.optimizer = build_optimizer(cfg, self.model)
        self.lr_scheduler = build_lr_scheduler(cfg, self.optimizer)

        is_training = mode == 'train'
        self.augmentor = build_train_augmentor(cfg) if is_training else None

        loader = build_dataloader(cfg, self.augmentor, mode)
        self.monitor = build_monitor(cfg)
        self.criterion = build_criterion(cfg, device)

        # Store an iterator over the loader so batches can be drawn one at a
        # time.
        self.dataloader = iter(loader)