Example #1
    def train(self):
        # self.writer.write("===== Model =====")
        # self.writer.write(self.model)
        if self.run_type == "all_in_one":
            self._all_in_one()
        if self.run_type == "train_viz":
            self._inference_run("train")
            return
        if "train" not in self.run_type:
            self.inference()
            return

        should_break = False

        if self.max_epochs is None:
            self.max_epochs = math.inf
        else:
            self.max_iterations = math.inf

        self.model.train()
        self.train_timer = Timer()
        self.snapshot_timer = Timer()

        self.profile("Setup Time")

        torch.autograd.set_detect_anomaly(True)

        self.writer.write("Starting training...")
        while self.current_iteration < self.max_iterations and not should_break:
            self.current_epoch += 1
            registry.register("current_epoch", self.current_epoch)

            # Seed the sampler in case it is distributed
            self.task_loader.seed_sampler("train", self.current_epoch)

            if self.current_epoch > self.max_epochs:
                break

            for batch in self.train_loader:
                self.profile("Batch load time")
                self.current_iteration += 1
                self.writer.write(self.current_iteration, "debug")

                registry.register("current_iteration", self.current_iteration)

                if self.current_iteration > self.max_iterations:
                    break

                self._run_scheduler()
                report, _ = self._forward_pass(batch)
                self._update_meter(report, self.meter)
                loss = self._extract_loss(report)
                self._backward(loss)
                should_break = self._logistics(report)

                if should_break:
                    break

        self.finalize()
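
Note: train() above calls self.profile(...), but none of these examples show its body; Examples #6 and #8 below only initialize a self.profiler Timer. A minimal sketch of how such a helper could work, assuming it logs the elapsed phase time and then resets the profiler — the body is an assumption, only the names profile, profiler, and writer.write appear in the examples:

    def profile(self, text):
        # Hypothetical body: report how long the phase since the last
        # profile() call took, then restart the profiler Timer (set up as
        # in Examples #6/#8) so each call measures only its own phase.
        self.writer.write(text + ": " + self.profiler.get_time_since_start(),
                          "debug")
        self.profiler.reset()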
Example #2
    def __init__(self, multi_task_instance):
        self.test_task = multi_task_instance
        self.task_type = multi_task_instance.dataset_type
        self.config = registry.get("config")
        self.writer = registry.get("writer")
        self.report = []
        self.timer = Timer()
        self.training_parameters = self.config["training_parameters"]
        self.num_workers = self.training_parameters["num_workers"]
        self.batch_size = self.training_parameters["batch_size"]
        self.report_folder_arg = self.config.get("report_folder", None)
        self.experiment_name = self.training_parameters.get("experiment_name", "")

        self.datasets = []

        for task in self.test_task.get_tasks():
            for dataset in task.get_datasets():
                self.datasets.append(dataset)

        self.current_dataset_idx = -1
        self.current_dataset = self.datasets[self.current_dataset_idx]

        self.save_dir = self.config.get("save_dir", "./save")
        self.report_folder = ckpt_name_from_core_args(self.config)
        self.report_folder += foldername_from_config_override(self.config)

        self.report_folder = os.path.join(self.save_dir, self.report_folder)
        self.report_folder = os.path.join(self.report_folder, "reports")

        if self.report_folder_arg is not None:
            self.report_folder = self.report_folder_arg

        if not os.path.exists(self.report_folder):
            os.makedirs(self.report_folder)
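
Note: the constructor above starts current_dataset_idx at -1, which suggests an advancing helper that callers invoke in a loop. A hedged sketch of such a method — the name next_dataset and its return-value protocol are assumptions; only the two attributes appear in the example:

    def next_dataset(self):
        # Move to the next dataset collected in __init__; the -1 start
        # value makes the first call select self.datasets[0].
        self.current_dataset_idx += 1

        if self.current_dataset_idx >= len(self.datasets):
            # All datasets have been reported on.
            return False

        self.current_dataset = self.datasets[self.current_dataset_idx]
        return True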
Example #3
    def test_reset(self):
        timer = Timer()
        time.sleep(2)
        timer.reset()
        expected = "00:00:00"

        self.assertEqual(timer.get_current(), expected)
Example #4
    def test_get_current(self):
        timer = Timer()
        expected = "00:00:00"

        self.assertEqual(timer.get_current(), expected)
Example #5
    def test_get_time_since_start(self):
        timer = Timer()
        time.sleep(2)
        expected = "00:00:02"

        self.assertEqual(timer.get_time_since_start(), expected)
Example #6
    def __init__(self, config):
        self.config = config
        self.profiler = Timer()
Example #7
    def test_get_time_since_start(self):
        timer = Timer()
        time.sleep(2)
        expected = "02s "

        self.assertTrue(expected in timer.get_time_since_start())
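
Note: the Timer class itself does not appear in these examples. Below is a minimal sketch consistent with the calls used above — reset() and get_current() from Examples #3/#4, get_time_since_start() from Example #5, and get_time_hhmmss(None, format=...) from Example #9. It is an assumption, not the actual implementation; in particular, Example #7 expects a "02s " style string, so at least one Timer variant formats durations differently than sketched here.

import time


class Timer:
    def __init__(self):
        self.start = time.time()

    def reset(self):
        # Restart the measurement window, so get_current() reads
        # "00:00:00" immediately afterwards (see Example #3).
        self.start = time.time()

    def get_time_hhmmss(self, start=None, format="%Y-%m-%dT%H:%M:%S"):
        # With start=None, format the current wall-clock time, matching
        # the log-filename timestamp in Example #9; `format` keeps the
        # keyword name used there even though it shadows the builtin.
        if start is None:
            return time.strftime(format, time.localtime())
        seconds = int(time.time() - start)
        minutes, seconds = divmod(seconds, 60)
        hours, minutes = divmod(minutes, 60)
        return "%02d:%02d:%02d" % (hours, minutes, seconds)

    def get_current(self):
        # The examples never distinguish get_current() from
        # get_time_since_start(), so both are aliased here.
        return self.get_time_hhmmss(self.start)

    def get_time_since_start(self):
        return self.get_time_hhmmss(self.start)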
Example #8
    def __init__(self, args, *rest, **kwargs):
        self.args = args
        self.profiler = Timer()
Example #9
    def __init__(self, config):
        self.logger = None
        self.summary_writer = None

        if not is_main_process():
            return

        self.timer = Timer()
        self.config = config
        self.save_dir = config.training_parameters.save_dir
        self.log_folder = ckpt_name_from_core_args(config)
        self.log_folder += foldername_from_config_override(config)
        time_format = "%Y-%m-%dT%H:%M:%S"
        self.log_filename = ckpt_name_from_core_args(config) + "_"
        self.log_filename += self.timer.get_time_hhmmss(None,
                                                        format=time_format)
        self.log_filename += ".log"

        self.log_folder = os.path.join(self.save_dir, self.log_folder, "logs")

        arg_log_dir = self.config.get("log_dir", None)
        if arg_log_dir:
            self.log_folder = arg_log_dir

        if not os.path.exists(self.log_folder):
            os.makedirs(self.log_folder)

        tensorboard_folder = os.path.join(self.log_folder, "tensorboard")
        self.summary_writer = SummaryWriter(tensorboard_folder)

        self.log_filename = os.path.join(self.log_folder, self.log_filename)

        print("Logging to:", self.log_filename)

        logging.captureWarnings(True)

        self.logger = logging.getLogger(__name__)
        self._file_only_logger = logging.getLogger(__name__)
        warnings_logger = logging.getLogger("py.warnings")

        # Set level
        level = config["training_parameters"].get("logger_level", "info")
        self.logger.setLevel(getattr(logging, level.upper()))
        self._file_only_logger.setLevel(getattr(logging, level.upper()))

        formatter = logging.Formatter("%(asctime)s %(levelname)s: %(message)s",
                                      datefmt="%Y-%m-%dT%H:%M:%S")

        # Add handler to file
        channel = logging.FileHandler(filename=self.log_filename, mode="a")
        channel.setFormatter(formatter)

        self.logger.addHandler(channel)
        self._file_only_logger.addHandler(channel)
        warnings_logger.addHandler(channel)

        # Add handler to stdout
        channel = logging.StreamHandler(sys.stdout)
        channel.setFormatter(formatter)

        self.logger.addHandler(channel)
        warnings_logger.addHandler(channel)

        should_not_log = self.config["training_parameters"]["should_not_log"]
        self.should_log = not should_not_log

        # Single log wrapper map
        self._single_log_map = set()
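
Note: Example #9 shows only the constructor, while the training loops in Examples #1 and #10 call writer.write(x) and writer.write(x, "debug") on this object. A hedged sketch of a matching write() method — the body is an assumption; the logger-is-None early return and should_log gate mirror the constructor above:

    def write(self, x, level="info"):
        # Non-main processes returned early from __init__, so they carry
        # no logger and stay silent here.
        if self.logger is None:
            return

        if not self.should_log:
            return

        # Map the level string ("info", "debug", ...) onto the matching
        # logging method configured in __init__.
        getattr(self.logger, level)(str(x))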
Example #10
    def train(self):
        self.writer.write("===== Model =====")
        self.writer.write(self.model)

        # TODO: integrate this code more effectively later
        # self.gradcam(self.test_loader)
        # return

        if "train" not in self.run_type:
            self.inference()
            return

        should_break = False

        if self.max_epochs is None:
            self.max_epochs = math.inf
        else:
            self.max_iterations = math.inf

        self.model.train()
        self.train_timer = Timer()
        self.snapshot_timer = Timer()

        self.profile("Setup Time")

        torch.autograd.set_detect_anomaly(True)

        self.writer.write("Starting training...")
        while self.current_iteration < self.max_iterations and not should_break:
            self.current_epoch += 1
            registry.register("current_epoch", self.current_epoch)

            # Seed the sampler in case it is distributed
            self.task_loader.seed_sampler("train", self.current_epoch)

            if self.current_epoch > self.max_epochs:
                break
            
            print("Train loader length")
            print(len(self.train_loader))
            for batch in self.train_loader:
                print("hi")
                self.profile("Batch load time")
                self.current_iteration += 1
                self.writer.write(self.current_iteration, "debug")

                registry.register("current_iteration", self.current_iteration)

                if self.current_iteration > self.max_iterations:
                    break

                self._run_scheduler()
                report = self._forward_pass(batch)
                self._update_meter(report, self.meter)
                loss = self._extract_loss(report)
                self._backward(loss)
                should_break = self._logistics(report)

                if should_break:
                    break

        self.finalize()