예제 #1
0
    def __init__(self, config):
        """Set up the model, TensorBoard writer and training data loader.

        Args:
            config: Nested experiment configuration. Keys read here:
                ``name``; optional ``random_seed`` (default 30);
                ``trainer.scheduler.{start_epoch, niter, niter_decay}``;
                ``arch.type``;
                ``trainer.{log_dir, display_freq, print_freq, evaluate_freq}``;
                optional ``trainer.{save_epoch_freq, save_step_freq}``
                (default 0); and
                ``datasets.train.{dataset.type, dataset.args, loader}``.
                The caller's object is left untouched: a deep copy is taken
                before ``name`` and ``arch.type`` are popped from it.

        Raises:
            KeyError: presumably, when a required key is missing (assuming
                ``config`` behaves like a plain dict).
        """
        import copy

        # Pop keys from a private copy only. Popping from the caller's dict
        # would make the same config unusable for a second construction
        # (KeyError on 'name') and strip 'name'/'arch.type' from it for any
        # later use by the caller (e.g. saving it next to a checkpoint).
        config = copy.deepcopy(config)

        self.experiment_name = config.pop('name')
        # Stored on the instance; this method does not seed any RNG with it.
        self.random_seed = config.get('random_seed', 30)

        scheduler_cfg = config["trainer"]["scheduler"]
        self.start_epoch = scheduler_cfg["start_epoch"]
        self.niter = scheduler_cfg["niter"]
        self.niter_decay = scheduler_cfg["niter_decay"]

        # 'type' is removed from config["arch"] before the model receives
        # the config, exactly as before.
        model_name = get_model_name(config["arch"].pop("type"))
        self.model = getattr(models, model_name)(config)
        logger.info("model init success")

        # Re-read after model construction, in case the model touched config.
        trainer_cfg = config["trainer"]
        # NOTE(review): SummaryWriter documents ``comment`` as a suffix for the
        # auto-generated log dir only; when log_dir is given it has no effect.
        self.writer = SummaryWriter(log_dir=trainer_cfg["log_dir"],
                                    comment=self.experiment_name)
        self.display_freq = trainer_cfg["display_freq"]
        self.print_freq = trainer_cfg["print_freq"]
        self.evaluate_freq = trainer_cfg["evaluate_freq"]
        # Both save frequencies fall back to 0 when absent from the config.
        self.save_epoch_freq = trainer_cfg.get("save_epoch_freq", 0)
        self.save_step_freq = trainer_cfg.get("save_step_freq", 0)

        # Dataset class is looked up by name in `datasets` and built from
        # its "args" mapping; "loader" holds the DataLoader keyword arguments.
        train_cfg = config["datasets"]["train"]
        dataset_cfg = train_cfg["dataset"]
        dataset = getattr(datasets, dataset_cfg["type"])(**dataset_cfg["args"])
        self.train_loader = DataLoader(dataset=dataset, **train_cfg["loader"])
        self.model.set_mode()

        self.global_step = 0
예제 #2
0
    def __init__(self, config):
        """Set up train/val data loaders, the model and the TensorBoard writer.

        Args:
            config: Nested experiment configuration. Keys read here:
                ``name``; optional ``random_seed`` (default 30);
                ``trainer.scheduler.{start_epoch, niter, niter_decay}``;
                ``datasets.{train,val}.{dataset.type, dataset.args, loader}``;
                optional ``arch.srn.{img_h, img_w, input_channel}``
                (defaults 48, 800, 1) passed to ``PaddingCollate``;
                ``arch.type``;
                ``trainer.{log_dir, display_val_freq, print_freq,
                evaluate_freq}``; optional
                ``trainer.{save_epoch_freq, save_step_freq}`` (default 0).
                The caller's object is left untouched: a deep copy is taken
                before ``name`` and ``arch.type`` are popped from it.

        Raises:
            KeyError: presumably, when a required key is missing (assuming
                ``config`` behaves like a plain dict).
        """
        import copy

        # Pop keys from a private copy only. Popping from the caller's dict
        # would make the same config unusable for a second construction
        # (KeyError on 'name') and strip 'name'/'arch.type' from it for any
        # later use by the caller (e.g. saving it next to a checkpoint).
        config = copy.deepcopy(config)

        self.experiment_name = config.pop('name')
        # Stored on the instance; this method does not seed any RNG with it.
        self.random_seed = config.get('random_seed', 30)

        scheduler_cfg = config["trainer"]["scheduler"]
        self.start_epoch = scheduler_cfg["start_epoch"]
        self.niter = scheduler_cfg["niter"]
        self.niter_decay = scheduler_cfg["niter_decay"]

        def _make_loader(split):
            """Build the padded-collate DataLoader for config["datasets"][split]."""
            split_cfg = config["datasets"][split]
            dataset_cfg = split_cfg["dataset"]
            dataset = getattr(datasets, dataset_cfg["type"])(**dataset_cfg["args"])
            loader_args = split_cfg["loader"]
            srn_cfg = config["arch"]["srn"]
            # A separate collate instance per loader, as in the original code.
            collate_fn = PaddingCollate(imgH=srn_cfg.get("img_h", 48),
                                        imgW=srn_cfg.get("img_w", 800),
                                        nc=srn_cfg.get("input_channel", 1))
            # Renders as "train dataset len:N" / "val dataset len:N".
            logger.info("{} dataset len:{}".format(split, len(dataset)))
            return DataLoader(dataset=dataset, collate_fn=collate_fn, **loader_args)

        self.train_loader = _make_loader("train")
        self.val_loader = _make_loader("val")

        # 'type' is removed from config["arch"] before the model receives
        # the config, exactly as before.
        model_name = get_model_name(config["arch"].pop("type"))
        self.model = getattr(models, model_name)(config)
        logger.info("model init success")

        # Re-read after model construction, in case the model touched config.
        trainer_cfg = config["trainer"]
        # NOTE(review): SummaryWriter documents ``comment`` as a suffix for the
        # auto-generated log dir only; when log_dir is given it has no effect.
        self.writer = SummaryWriter(log_dir=trainer_cfg["log_dir"], comment=self.experiment_name)
        self.display_val_freq = trainer_cfg["display_val_freq"]
        self.print_freq = trainer_cfg["print_freq"]
        self.evaluate_freq = trainer_cfg["evaluate_freq"]
        # Both save frequencies fall back to 0 when absent from the config.
        self.save_epoch_freq = trainer_cfg.get("save_epoch_freq", 0)
        self.save_step_freq = trainer_cfg.get("save_step_freq", 0)
        self.model.set_mode()

        self.global_step = 0
예제 #3
0
 def __init__(self, config):
     """Set up the model, input transform and input/output locations for prediction.

     Args:
         config: Nested configuration. Keys read here: ``name``; optional
             ``random_seed`` (default 30); ``arch.type``;
             ``predictor.long_side``; optional ``predictor.test_img_dir``;
             ``predictor.out_img_dir``; optional ``predictor.combine_res``
             (default True). The caller's object is left untouched: a deep
             copy is taken before ``name`` and ``arch.type`` are popped.

     Side effects:
         Creates ``predictor.out_img_dir`` (including parents) if missing.
     """
     import copy

     # Pop keys from a private copy only; popping from the caller's dict
     # would make the same config unusable for another construction
     # (KeyError on 'name') and strip 'name'/'arch.type' for the caller.
     config = copy.deepcopy(config)

     self.experiment_name = config.pop('name')
     # Stored on the instance; this method does not seed any RNG with it.
     self.random_seed = config.get('random_seed', 30)
     # 'type' is removed from config["arch"] before the model receives it.
     model_name = get_model_name(config["arch"].pop("type"))
     self.model = getattr(models, model_name)(config)
     logger.info("model init success")
     self.transform = default_transform()

     predictor_cfg = config["predictor"]
     self.long_side = predictor_cfg["long_side"]
     # Batch mode is enabled only when a test_img_dir is configured; only
     # then is input_img_paths populated (via get_file_list).
     self.batch_flag = "test_img_dir" in predictor_cfg
     if self.batch_flag:
         # NOTE(review): suffix list includes '.JPG' but not '.PNG', '.JPEG'
         # or '.TIF' -- confirm whether get_file_list matches case-insensitively.
         self.input_img_paths = get_file_list(
             predictor_cfg["test_img_dir"],
             p_postfix=['.jpg', '.png', ".tif", ".JPG", ".jpeg"])
     out_img_dir = predictor_cfg["out_img_dir"]
     os.makedirs(out_img_dir, exist_ok=True)
     self.out_img_dir = out_img_dir
     self.combine_res = predictor_cfg.get("combine_res", True)
     self.model.set_mode()