Example #1
    def __init__(self, context: det.TrialContext) -> None:
        self.context = context
        self.data_config = context.get_data_config()
        self.hparams = context.get_hparams()
        self.criterion = torch.nn.functional.cross_entropy
        # The last epoch is only used for logging.
        self._last_epoch = -1
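
The constructor above is only a fragment; in Determined it would normally live inside a PyTorchTrial subclass that also builds data loaders and defines the train/eval steps. Below is a minimal self-contained sketch of that surrounding class, assuming a toy linear model, random tensor data, and an "lr" hyperparameter; the class name and those details are illustrative and not from the original example.

import torch
from determined import pytorch


class ToyTrial(pytorch.PyTorchTrial):
    def __init__(self, context: pytorch.PyTorchTrialContext) -> None:
        self.context = context
        self.hparams = context.get_hparams()
        self.criterion = torch.nn.functional.cross_entropy
        self.model = context.wrap_model(torch.nn.Linear(10, 2))
        self.optimizer = context.wrap_optimizer(
            torch.optim.SGD(self.model.parameters(), lr=self.hparams["lr"]))

    def build_training_data_loader(self) -> pytorch.DataLoader:
        # Random data stands in for a real dataset in this sketch.
        data = torch.randn(64, 10)
        labels = torch.randint(0, 2, (64,))
        dataset = torch.utils.data.TensorDataset(data, labels)
        return pytorch.DataLoader(dataset, batch_size=self.context.get_per_slot_batch_size())

    def build_validation_data_loader(self) -> pytorch.DataLoader:
        return self.build_training_data_loader()

    def train_batch(self, batch, epoch_idx, batch_idx):
        data, labels = batch
        loss = self.criterion(self.model(data), labels)
        self.context.backward(loss)
        self.context.step_optimizer(self.optimizer)
        return {"loss": loss}

    def evaluate_batch(self, batch):
        data, labels = batch
        loss = self.criterion(self.model(data), labels)
        return {"validation_loss": loss}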
Example #2
    def __init__(self, context: det.TrialContext) -> None:
        self.context = context
        self.data_config = context.get_data_config()
        self.hparams = context.get_hparams()
        self.criterion = torch.nn.functional.cross_entropy
        # The last epoch is only used for logging.
        self._last_epoch = -1
        self.results = {
            "loss": float("inf"),
            "top1_accuracy": 0,
            "top5_accuracy": 0,
        }
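
Example #2 additionally tracks top-1 and top-5 accuracy in self.results. As an illustration (not taken from the original example), accuracies like these are usually computed from the logits along the following lines:

import torch

def topk_accuracy(logits: torch.Tensor, labels: torch.Tensor, ks=(1, 5)):
    """Return {k: accuracy} for each k, given logits of shape (N, C) and labels of shape (N,)."""
    max_k = max(ks)
    _, pred = logits.topk(max_k, dim=1)         # (N, max_k) predicted class indices
    correct = pred.eq(labels.unsqueeze(1))      # (N, max_k) boolean hit matrix
    return {k: correct[:, :k].any(dim=1).float().mean().item() for k in ks}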
Example #3
    def __init__(self, context: det.TrialContext) -> None:
        self.context = context
        self.data_config = context.get_data_config()
        self.hparams = AttrDict(context.get_hparams())

        # Use the configured download directory (intended to be unique per rank so ranks don't overwrite each other).
        self.download_directory = self.data_config["data_download_dir"]
        data.download_data(self.download_directory)
        corpus = data_util.Corpus(self.download_directory)
        self.corpus = corpus
        self.ntokens = len(corpus.dictionary)
        self.hidden = None

        # This is used to store eval history and will switch to ASGD
        # once validation perplexity stops improving.
        self._last_loss = None
        self._eval_history = []
        self._last_epoch = -1
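
The _last_loss / _eval_history bookkeeping refers to the AWD-LSTM-style trick of switching the optimizer to ASGD once validation perplexity stops improving. Below is a rough standalone sketch of that trigger logic, assuming a `nonmono` patience window and plain torch.optim objects; it is not the original trial's code, and a real Determined trial would also need to re-wrap the new optimizer through its context.

import torch

def maybe_switch_to_asgd(model, optimizer, eval_history, val_loss, lr, nonmono=5):
    """Return an ASGD optimizer once validation loss stops improving; otherwise keep the current one."""
    if len(eval_history) > nonmono and val_loss > min(eval_history[:-nonmono]):
        # Validation loss has not beaten the best value seen `nonmono` evals ago: switch to ASGD.
        optimizer = torch.optim.ASGD(model.parameters(), lr=lr, t0=0, lambd=0.0)
    eval_history.append(val_loss)
    return optimizer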
Example #4
    def __init__(self, context: det.TrialContext) -> None:
        self.context = context
        self.hparams = context.get_hparams()
        self.data_config = context.get_data_config()
        self.cfg = Config.fromfile(self.hparams["config_file"])

        self.cfg.data.train.ann_file = self.data_config["train_ann_file"]
        self.cfg.data.val.ann_file = self.data_config["val_ann_file"]
        self.cfg.data.val.test_mode = True
        self.cfg.data.workers_per_gpu = self.data_config["workers_per_gpu"]

        if self.data_config["backend"] in ["gcs", "fake"]:
            sub_backend(self.data_config["backend"], self.cfg)

        print(self.cfg)

        self.model = self.context.wrap_model(
            build_detector(self.cfg.model,
                           train_cfg=self.cfg.train_cfg,
                           test_cfg=self.cfg.test_cfg))

        self.optimizer = self.context.wrap_optimizer(
            build_optimizer(self.model, self.cfg.optimizer))

        scheduler_cls = WarmupWrapper(MultiStepLR)
        scheduler = scheduler_cls(
            self.hparams["warmup"],  # warmup schedule
            self.hparams["warmup_iters"],  # warmup_iters
            self.hparams["warmup_ratio"],  # warmup_ratio
            self.optimizer,
            [self.hparams["step1"], self.hparams["step2"]],  # milestones
            self.hparams["gamma"],  # gamma
        )
        self.scheduler = self.context.wrap_lr_scheduler(
            scheduler, step_mode=LRScheduler.StepMode.MANUAL_STEP)

        # Only build a clipping function when clipping is enabled; otherwise pass None.
        # (Without the extra parentheses the conditional becomes part of the lambda body,
        # so clip_grads_fn would never actually be None.)
        self.clip_grads_fn = (
            (lambda x: torch.nn.utils.clip_grad_norm_(x, self.hparams["clip_grads_norm"]))
            if self.hparams["clip_grads"]
            else None
        )
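
For completeness, here is a hedged sketch of how a trial like this typically consumes clip_grads_fn and the manually-stepped scheduler in its train_batch; the model.train_step call is an assumption about the mmdet-style model interface, not code from the original example.

    def train_batch(self, batch, epoch_idx, batch_idx):
        # Forward pass; mmdet-style models return a dict of losses (assumption about the interface).
        losses = self.model.train_step(batch, None)
        loss = losses["loss"]
        self.context.backward(loss)
        # clip_grads_fn (or None) is applied to the parameters just before the optimizer step.
        self.context.step_optimizer(self.optimizer, clip_grads=self.clip_grads_fn)
        # Because the scheduler was wrapped with StepMode.MANUAL_STEP, the trial advances it itself.
        self.scheduler.step()
        return {"loss": loss}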