Пример #1
0
    def __init__(self, multi_task_instance):
        """Collect the datasets to report on and prepare the report folder.

        Args:
            multi_task_instance: Object exposing ``dataset_type`` and
                ``get_tasks()``; every returned task must expose
                ``get_datasets()``.

        ``config`` and ``writer`` are read from the global registry. The
        report folder defaults to ``<save_dir>/<generated name>/reports`` and
        is replaced by the ``report_folder`` config entry when that is set.

        Raises:
            IndexError: If the tasks provide no datasets at all.
        """
        self.test_task = multi_task_instance
        self.task_type = multi_task_instance.dataset_type
        self.config = registry.get("config")
        self.writer = registry.get("writer")
        self.report = []
        self.timer = Timer()
        self.training_parameters = self.config["training_parameters"]
        self.num_workers = self.training_parameters["num_workers"]
        self.batch_size = self.training_parameters["batch_size"]
        self.report_folder_arg = self.config.get("report_folder", None)
        self.experiment_name = self.training_parameters.get("experiment_name", "")

        # Flatten task -> datasets into a single ordered list.
        self.datasets = [
            dataset
            for task in self.test_task.get_tasks()
            for dataset in task.get_datasets()
        ]

        # The index starts at -1, so ``current_dataset`` initially points at
        # the last dataset in the list.
        self.current_dataset_idx = -1
        self.current_dataset = self.datasets[self.current_dataset_idx]

        self.save_dir = self.config.get("save_dir", "./save")
        self.report_folder = ckpt_name_from_core_args(self.config)
        self.report_folder += foldername_from_config_override(self.config)

        self.report_folder = os.path.join(self.save_dir, self.report_folder)
        self.report_folder = os.path.join(self.report_folder, "reports")

        # An explicit report folder overrides the generated one.
        if self.report_folder_arg is not None:
            self.report_folder = self.report_folder_arg

        # ``exist_ok`` avoids the check-then-create race when several
        # processes construct the reporter at the same time.
        os.makedirs(self.report_folder, exist_ok=True)
Пример #2
0
    def test_reset(self):
        """After ``reset()`` the timer must report zero elapsed time."""
        stopwatch = Timer()
        # Let some time pass so that the reset has something to clear.
        time.sleep(2)
        stopwatch.reset()

        self.assertEqual(stopwatch.get_current(), "00:00:00")
Пример #3
0
    def train(self):
        """Dispatch on ``self.run_type`` and, for training runs, run the loop.

        * ``"all_in_one"`` runs ``_all_in_one()`` and returns.
        * ``"train_viz"`` runs ``_inference_run("train")`` and returns.
        * Any other run type without ``"train"`` in it runs ``inference()``
          and returns.
        * Otherwise the model is trained until ``max_iterations`` or
          ``max_epochs`` is reached, or ``_logistics`` requests a stop;
          ``finalize()`` is called afterwards.
        """
        if self.run_type == "all_in_one":
            self._all_in_one()
            # Explicit exit: previously this fell through to the generic
            # non-training branch below.
            return
        if self.run_type == "train_viz":
            self._inference_run("train")
            return
        if "train" not in self.run_type:
            self.inference()
            return

        should_break = False

        # Exactly one of the two limits stays finite: iterations when no
        # epoch limit is configured, epochs otherwise.
        if self.max_epochs is None:
            self.max_epochs = math.inf
        else:
            self.max_iterations = math.inf

        self.model.train()
        self.train_timer = Timer()
        self.snapshot_timer = Timer()

        self.profile("Setup Time")

        # NOTE(review): anomaly detection stays enabled for the whole run;
        # confirm this is intended outside of debugging.
        torch.autograd.set_detect_anomaly(True)

        self.writer.write("Starting training...")
        while self.current_iteration < self.max_iterations and not should_break:
            self.current_epoch += 1
            registry.register("current_epoch", self.current_epoch)

            # Seed the sampler in case it is distributed
            self.task_loader.seed_sampler("train", self.current_epoch)

            if self.current_epoch > self.max_epochs:
                break

            for batch in self.train_loader:
                self.profile("Batch load time")
                self.current_iteration += 1
                self.writer.write(self.current_iteration, "debug")

                registry.register("current_iteration", self.current_iteration)

                if self.current_iteration > self.max_iterations:
                    break

                self._run_scheduler()
                report, _ = self._forward_pass(batch)
                self._update_meter(report, self.meter)
                loss = self._extract_loss(report)
                self._backward(loss)
                should_break = self._logistics(report)

                if should_break:
                    break

        self.finalize()
Пример #4
0
class TestReporter(Dataset):
    """Dataset wrapper that walks over datasets and writes prediction reports.

    The reporter exposes the currently selected dataset through the
    ``Dataset`` protocol (``__len__`` / ``__getitem__``), accumulates formatted
    predictions in ``self.report`` and dumps them to a JSON file whenever
    iteration moves past a dataset.
    """

    def __init__(self, multi_task_instance):
        """Collect the datasets to report on and prepare the report folder.

        Args:
            multi_task_instance: Object exposing ``dataset_type`` and
                ``get_tasks()``; every returned task must expose
                ``get_datasets()``.

        ``config`` and ``writer`` are read from the global registry. The
        report folder defaults to ``<save_dir>/<generated name>/reports`` and
        is replaced by the ``report_folder`` config entry when that is set.

        Raises:
            IndexError: If the tasks provide no datasets at all.
        """
        self.test_task = multi_task_instance
        self.task_type = multi_task_instance.dataset_type
        self.config = registry.get("config")
        self.writer = registry.get("writer")
        self.report = []
        self.timer = Timer()
        self.training_parameters = self.config["training_parameters"]
        self.num_workers = self.training_parameters["num_workers"]
        self.batch_size = self.training_parameters["batch_size"]
        self.report_folder_arg = self.config.get("report_folder", None)
        self.experiment_name = self.training_parameters.get("experiment_name", "")

        # Flatten task -> datasets into a single ordered list.
        self.datasets = [
            dataset
            for task in self.test_task.get_tasks()
            for dataset in task.get_datasets()
        ]

        # -1 means iteration has not started; ``next_dataset()`` advances it.
        self.current_dataset_idx = -1
        self.current_dataset = self.datasets[self.current_dataset_idx]

        self.save_dir = self.config.get("save_dir", "./save")
        self.report_folder = ckpt_name_from_core_args(self.config)
        self.report_folder += foldername_from_config_override(self.config)

        self.report_folder = os.path.join(self.save_dir, self.report_folder)
        self.report_folder = os.path.join(self.report_folder, "reports")

        # An explicit report folder overrides the generated one.
        if self.report_folder_arg is not None:
            self.report_folder = self.report_folder_arg

        # ``exist_ok`` avoids the check-then-create race when several
        # processes construct the reporter at the same time.
        os.makedirs(self.report_folder, exist_ok=True)

    def next_dataset(self):
        """Advance to the next dataset, flushing the report of the previous one.

        Returns:
            bool: ``True`` if a dataset was selected, ``False`` once all
            datasets have been consumed.
        """
        # Nothing to flush before the first dataset has been selected.
        if self.current_dataset_idx >= 0:
            self.flush_report()

        self.current_dataset_idx += 1

        if self.current_dataset_idx == len(self.datasets):
            return False

        self.current_dataset = self.datasets[self.current_dataset_idx]
        self.writer.write("Predicting for " + self.current_dataset._name)
        return True

    def flush_report(self):
        """Write the accumulated report to a JSON file and clear it.

        Only the main process writes; other processes return immediately and
        keep their ``self.report`` untouched. The file name is built from the
        dataset name, the optional experiment name, the task type and a
        timestamp obtained from ``self.timer``.
        """
        if not is_main_process():
            return

        name = self.current_dataset._name
        time_format = "%Y-%m-%dT%H:%M:%S"
        # Named ``timestamp`` so the ``time`` module name is not shadowed.
        timestamp = self.timer.get_time_hhmmss(None, format=time_format)

        filename = name + "_"

        if len(self.experiment_name) > 0:
            filename += self.experiment_name + "_"

        filename += self.task_type + "_"

        filename += timestamp + ".json"
        filepath = os.path.join(self.report_folder, filename)

        with open(filepath, "w") as f:
            json.dump(self.report, f)

        self.writer.write(
            "Wrote evalai predictions for %s to %s" % (name, os.path.abspath(filepath))
        )
        self.report = []

    def get_dataloader(self):
        """Build a ``DataLoader`` over the currently selected dataset."""
        other_args = self._add_extra_args_for_dataloader()
        return DataLoader(
            dataset=self.current_dataset,
            collate_fn=BatchCollator(),
            num_workers=self.num_workers,
            pin_memory=self.config.training_parameters.pin_memory,
            **other_args
        )

    def _add_extra_args_for_dataloader(self, other_args=None):
        """Fill in sampler/shuffle and per-process batch size for the loader.

        Args:
            other_args: Optional dict to extend in place; a fresh dict is
                created when omitted, so calls never share state.

        Returns:
            dict: ``other_args`` with either ``"sampler"`` (distributed run
            with a local rank) or ``"shuffle"`` set, plus ``"batch_size"``.

        Raises:
            RuntimeError: If the configured batch size is not divisible by
                the world size.
        """
        if other_args is None:
            other_args = {}

        training_parameters = self.config.training_parameters

        if (
            training_parameters.local_rank is not None
            and training_parameters.distributed
        ):
            other_args["sampler"] = DistributedSampler(self.current_dataset)
        else:
            other_args["shuffle"] = True

        batch_size = training_parameters.batch_size

        world_size = get_world_size()

        if batch_size % world_size != 0:
            raise RuntimeError(
                "Batch size {} must be divisible by number "
                "of GPUs {} used.".format(batch_size, world_size)
            )

        # Each process handles an equal share of the global batch.
        other_args["batch_size"] = batch_size // world_size

        return other_args

    def prepare_batch(self, batch):
        """Delegate batch preparation to the current dataset."""
        return self.current_dataset.prepare_batch(batch)

    def __len__(self):
        """Return the length of the current dataset."""
        return len(self.current_dataset)

    def __getitem__(self, idx):
        """Return item ``idx`` of the current dataset."""
        return self.current_dataset[idx]

    def add_to_report(self, report):
        """Gather predictions across processes and append them to the report.

        ``report.scores`` and ``report.question_id`` are replaced by their
        gathered, flattened versions on every process; only the main process
        formats them and extends ``self.report``.
        """
        # TODO: Later gather whole report for no opinions
        report.scores = gather_tensor(report.scores).view(-1, report.scores.size(-1))
        report.question_id = gather_tensor(report.question_id).view(-1)

        if not is_main_process():
            return

        results = self.current_dataset.format_for_evalai(report)

        # Extend in place instead of rebuilding the whole list per batch.
        self.report.extend(results)
Пример #5
0
    def test_get_current(self):
        """A freshly created timer must report zero elapsed time."""
        fresh_timer = Timer()

        self.assertEqual(fresh_timer.get_current(), "00:00:00")
Пример #6
0
    def test_get_time_since_start(self):
        """Two seconds after creation the timer must report ``00:00:02``."""
        stopwatch = Timer()
        time.sleep(2)

        self.assertEqual(stopwatch.get_time_since_start(), "00:00:02")
Пример #7
0
 def __init__(self, config):
     """Keep the configuration and start the timer stored as ``profiler``."""
     self.profiler = Timer()
     self.config = config
Пример #8
0
class BaseTrainer:
    """Training / evaluation driver configured by a single config object.

    ``load()`` wires up the task loader, logger, model, optimizer and the
    bookkeeping helpers; ``train()`` then dispatches on ``run_type`` to the
    training loop or to one of the inference paths.
    """

    def __init__(self, config):
        """Store ``config`` and start the timer used by ``profile()``."""
        self.config = config
        self.profiler = Timer()

    def load(self):
        """Initialise every component that ``train()`` relies on.

        Sets up the process group, task loader, logger, seeds, data, model,
        optimizer and extras (in that order), logs the shape of every model
        parameter and finally creates a TensorBoard ``SummaryWriter``.
        """
        self._init_process_group()

        self.run_type = self.config.training_parameters.get("run_type", "train")
        self.task_loader = TaskLoader(self.config)

        self.writer = Logger(self.config)
        registry.register("writer", self.writer)

        self.configuration = registry.get("configuration")
        self.configuration.pretty_print()

        self.config_based_setup()

        self.load_task()
        self.load_model()
        self.load_optimizer()
        self.load_extras()

        # a survey on model size
        self.writer.write("----------MODEL SIZE----------")
        total = 0
        for p in self.model.named_parameters():
            self.writer.write(p[0] + str(p[1].shape))
            total += torch.numel(p[1])
        self.writer.write("total parameters to train: {}".format(total))

        # init a TensorBoard writer
        self.tb_writer = SummaryWriter(
            os.path.join("save/tb", getattr(self.config.model_attributes, self.config.model).code_name))

    def _init_process_group(self):
        """Initialise distributed training (if requested) and pick the device.

        Registers the resulting device in the registry under
        ``"current_device"``.

        Raises:
            RuntimeError: If a distributed run is requested but NCCL is not
                available.
        """
        training_parameters = self.config.training_parameters
        self.local_rank = training_parameters.local_rank
        self.device = training_parameters.device

        if self.local_rank is not None and training_parameters.distributed:
            if not torch.distributed.is_nccl_available():
                raise RuntimeError(
                    "Unable to initialize process group: NCCL is not available"
                )
            torch.distributed.init_process_group(backend="nccl")
            synchronize()

        # ``str()`` keeps this check consistent with the other device checks
        # in this class and also works when the device is a ``torch.device``.
        if (
            "cuda" in str(self.device)
            and training_parameters.distributed
            and self.local_rank is not None
        ):
            self.device = torch.device("cuda", self.local_rank)

        registry.register("current_device", self.device)

    def load_task(self):
        """Load tasks, build the dataloaders and cache them on the trainer."""
        self.writer.write("Loading tasks and data", "info")
        self.task_loader.load_task()

        self.task_loader.make_dataloaders()

        self.train_loader = self.task_loader.train_loader
        self.val_loader = self.task_loader.val_loader
        self.test_loader = self.task_loader.test_loader
        self.train_task = self.task_loader.train_task
        self.val_task = self.task_loader.val_task

        # Total iterations for snapshot
        self.snapshot_iterations = len(self.val_task)
        self.snapshot_iterations //= self.config.training_parameters.batch_size

        self.test_task = self.task_loader.test_task

    def load_model(self):
        """Build the model, move it to the device and wrap it if parallel.

        The model attributes may be a string naming another entry of
        ``config.model_attributes``, in which case that entry is used. The
        model is wrapped in ``DataParallel`` and/or
        ``DistributedDataParallel`` depending on the training parameters.
        """
        attributes = self.config.model_attributes[self.config.model]
        # Easy way to point to config for other model
        if isinstance(attributes, str):
            attributes = self.config.model_attributes[attributes]

        attributes["model"] = self.config.model

        self.task_loader.update_registry_for_model(attributes)
        self.model = build_model(attributes)
        self.task_loader.clean_config(attributes)
        training_parameters = self.config.training_parameters

        data_parallel = training_parameters.data_parallel
        distributed = training_parameters.distributed

        registry.register("data_parallel", data_parallel)
        registry.register("distributed", distributed)

        if "cuda" in str(self.config.training_parameters.device):
            rank = self.local_rank if self.local_rank is not None else 0
            device_info = "CUDA Device {} is: {}".format(
                rank, torch.cuda.get_device_name(self.local_rank)
            )

            self.writer.write(device_info, log_all=True)

        self.model = self.model.to(self.device)

        self.writer.write("Torch version is: " + torch.__version__)

        if (
            "cuda" in str(self.device)
            and torch.cuda.device_count() > 1
            and data_parallel is True
        ):
            self.model = torch.nn.DataParallel(self.model)

        if (
            "cuda" in str(self.device)
            and self.local_rank is not None
            and distributed is True
        ):
            torch.cuda.set_device(self.local_rank)
            self.model = torch.nn.parallel.DistributedDataParallel(
                self.model, device_ids=[self.local_rank]
            )

    def load_optimizer(self):
        """Build the optimizer for ``self.model`` from the config."""
        self.optimizer = build_optimizer(self.model, self.config)

    def load_extras(self):
        """Create the checkpoint, meter, early stopping and LR scheduler.

        Also copies the loop-related training parameters onto the trainer,
        resets the epoch/iteration counters and then loads the checkpoint
        state.
        """
        self.checkpoint = Checkpoint(self)
        self.meter = Meter()

        self.training_parameters = self.config.training_parameters

        monitored_metric = self.training_parameters.monitored_metric
        metric_minimize = self.training_parameters.metric_minimize
        should_early_stop = self.training_parameters.should_early_stop
        patience = self.training_parameters.patience

        self.log_interval = self.training_parameters.log_interval
        self.snapshot_interval = self.training_parameters.snapshot_interval
        self.max_iterations = self.training_parameters.max_iterations
        self.should_clip_gradients = self.training_parameters.clip_gradients
        self.max_epochs = self.training_parameters.max_epochs

        self.early_stopping = EarlyStopping(
            self.model,
            self.checkpoint,
            monitored_metric,
            patience=patience,
            minimize=metric_minimize,
            should_stop=should_early_stop,
        )
        self.current_epoch = 0
        self.current_iteration = 0

        self.checkpoint.load_state_dict()

        # ``profile()`` is a no-op unless the logger level is "debug".
        self.not_debug = self.training_parameters.logger_level != "debug"

        self.lr_scheduler = None

        # TODO: Allow custom scheduler
        if self.training_parameters.lr_scheduler is True:
            scheduler_class = optim.lr_scheduler.LambdaLR
            scheduler_func = lambda x: lr_lambda_update(x, self.config)
            self.lr_scheduler = scheduler_class(
                self.optimizer, lr_lambda=scheduler_func
            )

    def config_based_setup(self):
        """Seed the RNGs and force deterministic cuDNN when a seed is set.

        Does nothing when ``training_parameters.seed`` is ``None``.
        """
        seed = self.config.training_parameters.seed
        if seed is None:
            return

        random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    def train(self):
        """Dispatch on ``self.run_type`` and, for training runs, run the loop.

        * ``"all_in_one"`` runs ``_all_in_one()`` and returns.
        * ``"train_viz"`` runs ``_inference_run("train")`` and returns.
        * Any other run type without ``"train"`` in it runs ``inference()``
          and returns.
        * Otherwise the model is trained until ``max_iterations`` or
          ``max_epochs`` is reached, or ``_logistics`` requests a stop;
          ``finalize()`` is called afterwards.
        """
        if self.run_type == "all_in_one":
            self._all_in_one()
            # Explicit exit: previously this fell through to the generic
            # non-training branch below.
            return
        if self.run_type == "train_viz":
            self._inference_run("train")
            return
        if "train" not in self.run_type:
            self.inference()
            return

        should_break = False

        # Exactly one of the two limits stays finite: iterations when no
        # epoch limit is configured, epochs otherwise.
        if self.max_epochs is None:
            self.max_epochs = math.inf
        else:
            self.max_iterations = math.inf

        self.model.train()
        self.train_timer = Timer()
        self.snapshot_timer = Timer()

        self.profile("Setup Time")

        # NOTE(review): anomaly detection stays enabled for the whole run;
        # confirm this is intended outside of debugging.
        torch.autograd.set_detect_anomaly(True)

        self.writer.write("Starting training...")
        while self.current_iteration < self.max_iterations and not should_break:
            self.current_epoch += 1
            registry.register("current_epoch", self.current_epoch)

            # Seed the sampler in case it is distributed
            self.task_loader.seed_sampler("train", self.current_epoch)

            if self.current_epoch > self.max_epochs:
                break

            for batch in self.train_loader:
                self.profile("Batch load time")
                self.current_iteration += 1
                self.writer.write(self.current_iteration, "debug")

                registry.register("current_iteration", self.current_iteration)

                if self.current_iteration > self.max_iterations:
                    break

                self._run_scheduler()
                report, _ = self._forward_pass(batch)
                self._update_meter(report, self.meter)
                loss = self._extract_loss(report)
                self._backward(loss)
                should_break = self._logistics(report)

                if should_break:
                    break

        self.finalize()

    def _run_scheduler(self):
        """Step the LR scheduler (if any) to the current iteration."""
        if self.lr_scheduler is not None:
            self.lr_scheduler.step(self.current_iteration)

    def _forward_pass(self, batch):
        """Prepare ``batch``, run the model and wrap the output in a ``Report``.

        Returns:
            tuple: ``(report, att)`` where ``att`` is ``model_output["att"]``
            when the model output has that key and ``None`` otherwise.
        """
        prepared_batch = self.task_loader.prepare_batch(batch)
        self.profile("Batch prepare time")

        # Arguments should be a dict at this point
        model_output = self.model(prepared_batch)  # a dict of losses, metrics, scores, and att
        report = Report(prepared_batch, model_output)
        self.profile("Forward time")

        return report, (model_output["att"] if "att" in model_output.keys() else None)

    def _backward(self, loss):
        """Back-propagate ``loss`` and take one optimizer step.

        When gradient clipping is enabled the gradients are clipped and, every
        100 iterations, histograms of parameter values and gradients are
        written to TensorBoard, grouped by the name fragments in ``secs``.
        """
        self.optimizer.zero_grad()
        loss.backward()

        if self.should_clip_gradients:
            clip_gradients(self.model, self.current_iteration, self.writer, self.config)

            # visualization of parameters and their grads' distribution
            # NOTE(review): ``secs`` is not defined in this class; presumably
            # a module-level collection of parameter-name fragments - confirm.
            if self.current_iteration % 100 == 0 and hasattr(self, "tb_writer"):
                data = {key: [] for key in secs}
                grad = {key: [] for key in secs}
                for p in self.model.named_parameters():
                    if p[1].data is not None and p[1].grad is not None and p[1].data.shape != torch.Size([]):
                        for sec in secs:
                            if sec in p[0]:
                                data[sec].append(p[1].data.flatten())
                                grad[sec].append(p[1].grad.flatten())
                for sec in secs:
                    if len(data[sec]) != 0 and len(grad[sec]) != 0:
                        self.tb_writer.add_histogram(sec + "_data_dis", torch.cat(data[sec], dim=0),
                                                     global_step=self.current_iteration)
                        self.tb_writer.add_histogram(sec + "_grad_dis", torch.cat(grad[sec], dim=0),
                                                     global_step=self.current_iteration)

        self.optimizer.step()
        self.profile("Backward time")

    def _extract_loss(self, report):
        """Return the sum of the means of every loss in ``report.losses``."""
        return sum(loss.mean() for loss in report.losses.values())

    def finalize(self):
        """Run a forced full validation, restore the checkpoint and infer."""
        self.writer.write("Stepping into final validation check")
        self._try_full_validation(force=True)
        self.checkpoint.restore()
        self.checkpoint.finalize()
        self.inference()

    def _update_meter(self, report, meter=None, eval_mode=False):
        """Reduce the losses/metrics of ``report`` and feed them to ``meter``.

        Args:
            report: Object with ``losses``, ``metrics`` and ``dataset_type``.
            meter: Meter to update; defaults to ``self.meter``.
            eval_mode: Accepted for interface compatibility; not used here.

        The total reduced loss is also stored in the registry under
        ``"<dataset_type>/total_loss"``.
        """
        if meter is None:
            meter = self.meter

        loss_dict = report.losses
        metrics_dict = report.metrics

        reduced_loss_dict = reduce_dict(loss_dict)
        reduced_metrics_dict = reduce_dict(metrics_dict)

        loss_key = report.dataset_type + "/total_loss"

        with torch.no_grad():
            reduced_loss = sum([loss.mean() for loss in reduced_loss_dict.values()])
            # An empty loss dict sums to the plain int 0, which has no item().
            if hasattr(reduced_loss, "item"):
                reduced_loss = reduced_loss.item()

            registry.register(loss_key, reduced_loss)

            meter_update_dict = {loss_key: reduced_loss}
            meter_update_dict.update(reduced_loss_dict)
            meter_update_dict.update(reduced_metrics_dict)
            meter.update(meter_update_dict)

    def _logistics(self, report):
        """Log progress on log intervals and trigger snapshot validation.

        On every ``log_interval``-th iteration the validation loader is
        evaluated, the result is merged into ``self.meter`` and scalars are
        written to TensorBoard when a ``tb_writer`` exists.

        Returns:
            bool: ``True`` when ``_try_full_validation`` asks to stop.
        """
        should_print = self.current_iteration % self.log_interval == 0
        extra = {}

        if should_print is True:
            if "cuda" in str(self.device):
                # Two divisions by 1024 turn the byte count into MiB.
                extra["max mem"] = torch.cuda.max_memory_allocated() / 1024
                extra["max mem"] //= 1024

            extra.update(
                {
                    "lr": "{:.5f}".format(self.optimizer.param_groups[0]["lr"]).rstrip(
                        "0"
                    ),
                    "time": self.train_timer.get_time_since_start(),
                    "eta": self._calculate_time_left(),
                }
            )

            self.train_timer.reset()

            _, meter = self.evaluate(self.val_loader, single_batch=False)
            self.meter.update_from_meter(meter)

            # meter.get_scalar_dict() or meter.get_useful_dict() is a dict containing:
            # ['train/total_loss', 'train/vqa_accuracy',
            # 'val/total_loss', 'val/vqa_accuracy']
            if hasattr(self, "tb_writer"):
                self.tb_writer.add_scalar("lr", self.optimizer.param_groups[0]["lr"],
                                          global_step=self.current_iteration)
                useful = self.meter.get_useful_dict()
                self.tb_writer.add_scalar("train_loss", useful["loss"]["train"],
                                          global_step=self.current_iteration)
                self.tb_writer.add_scalar("val_loss", useful["loss"]["val"],
                                          global_step=self.current_iteration)
                self.tb_writer.add_scalar("train_acc", useful["accuracy"]["train"],
                                          global_step=self.current_iteration)
                self.tb_writer.add_scalar("val_acc", useful["accuracy"]["val"],
                                          global_step=self.current_iteration)

        # Don't print train metrics if it is not log interval
        # so as to escape clutter
        self._summarize_report(
            self.meter,
            should_print=should_print,
            extra=extra,
            prefix=report.dataset_name,
        )

        return self._try_full_validation()

    def _try_full_validation(self, force=False):
        """Run full validation on snapshot intervals (or when forced).

        Args:
            force: Validate regardless of the current iteration.

        Returns:
            bool: ``True`` when early stopping was triggered.
        """
        should_break = False

        if self.current_iteration % self.snapshot_interval == 0 or force:
            self.writer.write("Evaluation time. Running on full validation set...")
            # Validation and Early stopping
            # Create a new meter for this case
            report, meter = self.evaluate(self.val_loader)

            extra = {"validation time": self.snapshot_timer.get_time_since_start()}

            stop = self.early_stopping(self.current_iteration, meter)
            # The decision is broadcast from source 0 to every process.
            stop = bool(broadcast_scalar(stop, src=0, device=self.device))

            extra.update(self.early_stopping.get_info())

            prefix = "{}: full val".format(report.dataset_name)

            self._summarize_report(meter, prefix=prefix, extra=extra)
            self.snapshot_timer.reset()
            gc.collect()

            if "cuda" in str(self.device):
                torch.cuda.empty_cache()

            if stop is True:
                self.writer.write("Early stopping activated")
                should_break = True

        return should_break

    def evaluate(self, loader, use_tqdm=False, single_batch=False):
        """Evaluate the model on ``loader`` without gradients.

        Args:
            loader: Iterable of batches. It must yield at least one batch,
                otherwise ``report`` is never assigned and the return fails.
            use_tqdm: Show a progress bar.
            single_batch: Stop after the first batch.

        Returns:
            tuple: ``(report, meter)`` with the report of the last processed
            batch and a fresh meter updated with every processed batch.
        """
        meter = Meter()

        with torch.no_grad():
            self.model.eval()
            for batch in tqdm(loader, disable=not use_tqdm):
                report, _ = self._forward_pass(batch)
                self._update_meter(report, meter, eval_mode=True)

                if single_batch is True:
                    break
            self.model.train()

        return report, meter

    def evaluate_full_report(self, loader, use_tqdm=False):
        """Collect question ids, scores and attention outputs over ``loader``.

        Returns:
            dict: Per-key results. All keys except ``"combine_att"`` are
            concatenated along dim 0, detached and moved to the CPU.
        """
        report = {"question_id": [], "scores": [], "si_att": [], "s_att": [], "combine_att": [], "b2s": []}

        with torch.no_grad():
            self.model.eval()
            for batch in tqdm(loader, disable=not use_tqdm):
                rep, att = self._forward_pass(batch)
                report["question_id"] += [rep["question_id"]]
                report["scores"] += [rep["scores"]]
                report["si_att"] += [att["si_att"]]
                report["s_att"] += [att["s_att"]]
                report["combine_att"] += [att["combine_att"]]
                report["b2s"] += [att["b2s"]]
            report["question_id"] = torch.cat(report["question_id"], dim=0).detach().cpu()
            report["scores"] = torch.cat(report["scores"], dim=0).detach().cpu()
            report["si_att"] = torch.cat(report["si_att"], dim=0).detach().cpu()
            report["s_att"] = torch.cat(report["s_att"], dim=0).detach().cpu()
            # NOTE(review): "combine_att" stays a list of per-batch values,
            # unlike the other keys - confirm this is intended.
            report["b2s"] = torch.cat(report["b2s"], dim=0).detach().cpu()
            self.model.train()

        return report

    def _summarize_report(self, meter, prefix="", should_print=True, extra=None):
        """Log the scalars of ``meter`` and optionally write a summary line.

        Only the main process does anything.

        Args:
            meter: Meter whose scalars are logged and printed.
            prefix: Text put in front of the summary line.
            should_print: When ``False`` only the scalars are logged.
            extra: Optional mapping of additional ``key: value`` pairs that
                are appended to the summary line.
        """
        if not is_main_process():
            return

        # ``None`` replaces the former shared mutable ``{}`` default.
        if extra is None:
            extra = {}

        scalar_dict = meter.get_scalar_dict()
        self.writer.add_scalars(scalar_dict, registry.get("current_iteration"))

        if not should_print:
            return

        print_str = []

        if len(prefix):
            print_str += [prefix + ":"]

        print_str += ["{}/{}".format(self.current_iteration, self.max_iterations)]
        print_str += [str(meter)]
        print_str += ["{}: {}".format(key, value) for key, value in extra.items()]

        self.writer.write(meter.delimiter.join(print_str))

    def _evaluate_and_dump(self, dataset_type):
        """Evaluate ``<dataset_type>_loader``, log it and pickle the full report.

        The pickle is written next to the resume file, with the last four
        characters of its name replaced by ``_<dataset_type>.p``.
        """
        self.writer.write("Starting inference on {} set".format(dataset_type))

        loader = getattr(self, "{}_loader".format(dataset_type))
        report, meter = self.evaluate(loader, use_tqdm=True)
        prefix = "{}: full {}".format(report.dataset_name, dataset_type)
        self._summarize_report(meter, prefix)

        # store information to process in jupyter
        report = self.evaluate_full_report(loader, use_tqdm=True)
        # NOTE(review): ``self.args`` is not assigned anywhere in this class;
        # presumably it is set by the caller or a subclass - confirm.
        dump_path = self.args.resume_file[:-4] + "_" + dataset_type + ".p"
        with open(dump_path, "wb") as f:
            pickle.dump(report, f, protocol=-1)

    def _all_in_one(self):
        """Evaluate and dump the val set, then predict for val and test."""
        dataset_type = "val"
        self._evaluate_and_dump(dataset_type)

        self.predict_for_evalai(dataset_type)
        self.predict_for_evalai("test")

    def inference(self):
        """Run inference on the sets selected by ``self.run_type``."""
        if "val" in self.run_type:
            self._inference_run("val")

        if "inference" in self.run_type or "predict" in self.run_type:
            self._inference_run("test")

    def _inference_run(self, dataset_type):
        """Run evalai prediction or a local evaluation for ``dataset_type``.

        When ``training_parameters.evalai_inference`` is ``True`` only the
        evalai predictions are produced; otherwise the set is evaluated and
        its full report is pickled.
        """
        if self.config.training_parameters.evalai_inference is True:
            self.predict_for_evalai(dataset_type)
            return

        self._evaluate_and_dump(dataset_type)

    def _calculate_time_left(self):
        """Estimate the remaining time from the duration of the last log window.

        Returns:
            The value of ``train_timer.get_time_hhmmss`` for the estimate.
        """
        # NOTE(review): assumes ``train_timer.start`` is in milliseconds,
        # matching ``time.time() * 1000`` - confirm.
        time_taken_for_log = time.time() * 1000 - self.train_timer.start
        iterations_left = self.max_iterations - self.current_iteration
        num_logs_left = iterations_left / self.log_interval
        time_left = num_logs_left * time_taken_for_log

        # Add the time expected to be spent in the remaining snapshots.
        snapshot_iteration = self.snapshot_iterations / self.log_interval
        snapshot_iteration *= iterations_left / self.snapshot_interval
        time_left += snapshot_iteration * time_taken_for_log

        return self.train_timer.get_time_hhmmss(gap=time_left)

    def profile(self, text):
        """Log the time since the last profile call at debug level.

        Does nothing unless the logger level is ``"debug"``.
        """
        if self.not_debug:
            return
        self.writer.write(text + ": " + self.profiler.get_time_since_start(), "debug")
        self.profiler.reset()

    def predict_for_evalai(self, dataset_type):
        """Run the model over every reporter dataset and collect predictions.

        The reporter obtained from the task loader is advanced dataset by
        dataset; each batch's output is handed to ``reporter.add_to_report``.
        The model is put back into training mode afterwards.
        """
        reporter = self.task_loader.get_test_reporter(dataset_type)
        with torch.no_grad():
            self.model.eval()
            message = "Starting {} inference for evalai".format(dataset_type)
            self.writer.write(message)

            while reporter.next_dataset():
                dataloader = reporter.get_dataloader()

                for batch in tqdm(dataloader):
                    prepared_batch = reporter.prepare_batch(batch)
                    model_output = self.model(prepared_batch)
                    report = Report(prepared_batch, model_output)
                    reporter.add_to_report(report)

            self.writer.write("Finished predicting")
            self.model.train()
Пример #9
0
class BaseTrainer:
    def __init__(self, config):
        """Keep the configuration and create the profiling timer.

        Args:
            config: Configuration object; stored as ``self.config``.
        """
        self.profiler = Timer()
        self.config = config

    def load(self):
        """Set up everything the trainer needs, in dependency order.

        Initializes the process group, creates the dataset loader and the
        logger, prints the configuration, applies seeding and then loads the
        datasets, model, optimizer and the remaining training helpers.
        """
        self._init_process_group()

        self.run_type = self.config.training_parameters.get(
            "run_type", "train")
        self.dataset_loader = DatasetLoader(self.config)
        self._datasets = self.config.datasets

        # The logger is published in the registry under the key "writer"
        self.writer = Logger(self.config)
        registry.register("writer", self.writer)

        self.configuration = registry.get("configuration")
        self.configuration.pretty_print()

        self.config_based_setup()

        # Order matters: the model needs the datasets, the optimizer needs
        # the model, and load_extras() needs the optimizer
        self.load_task()
        self.load_model()
        self.load_optimizer()
        self.load_extras()

    def _init_process_group(self):
        """Initialize distributed training and pick the device.

        Reads ``local_rank``, ``device`` and ``distributed`` from the
        training parameters. When a local rank is set and distributed
        training is enabled, the NCCL process group is initialized and, for
        CUDA devices, ``self.device`` becomes the CUDA device of this rank.
        The resulting device is registered as ``"current_device"``.

        Raises:
            RuntimeError: If distributed training is requested but NCCL is
                not available.
        """
        training_parameters = self.config.training_parameters
        self.local_rank = training_parameters.local_rank
        self.device = training_parameters.device

        use_distributed = (
            self.local_rank is not None and training_parameters.distributed
        )

        if use_distributed:
            if not torch.distributed.is_nccl_available():
                raise RuntimeError(
                    "Unable to initialize process group: NCCL is not available"
                )
            torch.distributed.init_process_group(backend="nccl")
            synchronize()

        # str() keeps this working when the configured device is a
        # torch.device rather than a plain string (a membership test on a
        # torch.device raises TypeError); the other methods of this class
        # test the device the same way.
        if use_distributed and "cuda" in str(self.device):
            self.device = torch.device("cuda", self.local_rank)

        registry.register("current_device", self.device)

    def load_task(self):
        self.writer.write("Loading datasets", "info")
        self.dataset_loader.load_datasets()

        self.train_dataset = self.dataset_loader.train_dataset
        self.val_dataset = self.dataset_loader.val_dataset

        # Total iterations for snapshot
        self.snapshot_iterations = len(self.val_dataset)
        self.snapshot_iterations //= self.config.training_parameters.batch_size

        self.test_dataset = self.dataset_loader.test_dataset

        self.train_loader = self.dataset_loader.train_loader
        self.val_loader = self.dataset_loader.val_loader
        self.test_loader = self.dataset_loader.test_loader

    def load_model(self):
        """Build the model, move it to the device and wrap it if requested.

        The model attributes come from ``config.model_attributes``; a string
        entry is treated as the name of another entry to use instead. The
        ``data_parallel`` and ``distributed`` flags are published in the
        registry, and the model is wrapped in ``DataParallel`` and/or
        ``DistributedDataParallel`` according to them.
        """
        model_name = self.config.model
        attributes = self.config.model_attributes[model_name]
        # Easy way to point to config for other model
        if isinstance(attributes, str):
            attributes = self.config.model_attributes[attributes]

        attributes["model"] = model_name

        self.dataset_loader.update_registry_for_model(attributes)
        self.model = build_model(attributes)
        self.dataset_loader.clean_config(attributes)

        training_parameters = self.config.training_parameters
        data_parallel = training_parameters.data_parallel
        distributed = training_parameters.distributed
        registry.register("data_parallel", data_parallel)
        registry.register("distributed", distributed)

        if "cuda" in str(training_parameters.device):
            # Rank 0 is reported when no local rank is configured
            rank = 0 if self.local_rank is None else self.local_rank
            device_name = torch.cuda.get_device_name(self.local_rank)
            self.writer.write(
                "CUDA Device {} is: {}".format(rank, device_name),
                log_all=True,
            )

        self.model = self.model.to(self.device)

        self.writer.write("Torch version is: " + torch.__version__)

        on_cuda = "cuda" in str(self.device)

        if on_cuda and torch.cuda.device_count() > 1 and data_parallel is True:
            self.model = torch.nn.DataParallel(self.model)

        if on_cuda and self.local_rank is not None and distributed is True:
            torch.cuda.set_device(self.local_rank)
            ddp_kwargs = {
                "device_ids": [self.local_rank],
                "output_device": self.local_rank,
                "check_reduction": True,
                "find_unused_parameters": True,
            }
            self.model = torch.nn.parallel.DistributedDataParallel(
                self.model, **ddp_kwargs)

    def load_optimizer(self):
        """Build the optimizer for ``self.model`` from the configuration."""
        optimizer = build_optimizer(self.model, self.config)
        self.optimizer = optimizer

    def load_extras(self):
        """Create the checkpoint, meter, early stopping and LR scheduler.

        Copies the loop settings (intervals, iteration/epoch limits, gradient
        clipping flag) from the training parameters onto ``self``, resets the
        epoch/iteration counters and then lets the checkpoint load its state.
        """
        self.checkpoint = Checkpoint(self)
        self.meter = Meter()

        params = self.config.training_parameters
        self.training_parameters = params

        monitored_metric = params.monitored_metric
        metric_minimize = params.metric_minimize
        should_early_stop = params.should_early_stop
        patience = params.patience

        self.log_interval = params.log_interval
        self.snapshot_interval = params.snapshot_interval
        self.max_iterations = params.max_iterations
        self.should_clip_gradients = params.clip_gradients
        self.max_epochs = params.max_epochs

        self.early_stopping = EarlyStopping(
            self.model,
            self.checkpoint,
            monitored_metric,
            patience=patience,
            minimize=metric_minimize,
            should_stop=should_early_stop,
        )

        # Counters are reset before the checkpoint state is loaded
        self.current_epoch = 0
        self.current_iteration = 0
        self.checkpoint.load_state_dict()

        self.not_debug = params.logger_level != "debug"

        self.lr_scheduler = None

        # TODO: Allow custom scheduler
        if params.lr_scheduler is True:

            def scheduler_func(step):
                return lr_lambda_update(step, self.config)

            self.lr_scheduler = optim.lr_scheduler.LambdaLR(
                self.optimizer, lr_lambda=scheduler_func)

    def config_based_setup(self):
        seed = self.config.training_parameters.seed
        if seed is None:
            return

        random.seed(seed)
        torch.manual_seed(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    def train(self):
        """Run the training loop, then finalize.

        When ``"train"`` is not part of ``self.run_type`` only inference is
        run. Otherwise batches are processed epoch by epoch until the
        iteration limit or the epoch limit is exceeded, or ``_logistics``
        asks to stop. ``finalize()`` is called at the end.
        """
        self.writer.write("===== Model =====")
        self.writer.write(self.model)

        if "train" not in self.run_type:
            self.inference()
            return

        # Only one limit stays finite: iterations when no epoch limit is
        # configured, epochs otherwise
        if self.max_epochs is None:
            self.max_epochs = math.inf
        else:
            self.max_iterations = math.inf

        self.model.train()
        self.train_timer = Timer()
        self.snapshot_timer = Timer()

        self.profile("Setup Time")

        torch.autograd.set_detect_anomaly(True)
        self.writer.write("Starting training...")

        stop_requested = False
        while not stop_requested and self.current_iteration < self.max_iterations:
            self.current_epoch += 1
            registry.register("current_epoch", self.current_epoch)

            # Seed the sampler in case if it is distributed
            self.dataset_loader.seed_sampler("train", self.current_epoch)

            if self.current_epoch > self.max_epochs:
                break

            for batch in self.train_loader:
                self.profile("Batch load time")
                self.current_iteration += 1
                self.writer.write(self.current_iteration, "debug")

                registry.register("current_iteration", self.current_iteration)

                if self.current_iteration > self.max_iterations:
                    break

                report = self._forward_pass(batch)
                self._update_meter(report, self.meter)
                self._backward(self._extract_loss(report))

                stop_requested = self._logistics(report)
                if stop_requested:
                    break

        self.finalize()

    def _run_scheduler(self):
        if self.lr_scheduler is not None:
            self.lr_scheduler.step(self.current_iteration)

    def _forward_pass(self, batch):
        """Prepare a batch, run the model on it and wrap the result.

        Args:
            batch: A batch as yielded by one of the data loaders.

        Returns:
            A ``Report`` built from the prepared batch and the model output.
        """
        prepared_batch = self.dataset_loader.prepare_batch(batch)
        # profile() calls split the timing into preparation and forward pass
        self.profile("Batch prepare time")
        # Arguments should be a dict at this point
        model_output = self.model(prepared_batch)
        report = Report(prepared_batch, model_output)
        self.profile("Forward time")

        return report

    def _backward(self, loss):
        """Backpropagate ``loss`` and take one optimizer/scheduler step.

        Args:
            loss: Object whose ``backward()`` is called; produced by
                ``_extract_loss`` in the training loop.
        """
        # Clear gradients left over from the previous iteration
        self.optimizer.zero_grad()
        loss.backward()

        # Clipping is optional and controlled by the training parameters
        if self.should_clip_gradients:
            clip_gradients(self.model, self.current_iteration, self.writer,
                           self.config)

        self.optimizer.step()
        # The scheduler is stepped after the optimizer update
        self._run_scheduler()

        self.profile("Backward time")

    def _extract_loss(self, report):
        loss_dict = report.losses
        loss = sum([loss.mean() for loss in loss_dict.values()])
        return loss

    def finalize(self):
        self.writer.write("Stepping into final validation check")

        # Only do when run_type has train as it shouldn't happen on validation and inference runs
        # Inference will take care of this anyways. Also, don't run if current iteration
        # is divisble by snapshot interval as it will just be a repeat
        if "train" in self.run_type and self.current_iteration % self.snapshot_interval != 0:
            self._try_full_validation(force=True)

        self.checkpoint.restore()
        self.checkpoint.finalize()
        self.inference()

    def _update_meter(self, report, meter=None, eval_mode=False):
        """Feed the reduced losses and metrics of ``report`` into a meter.

        The total loss (sum of the means of the reduced losses) is also
        registered in the registry under ``"<dataset_type>/total_loss"``.

        Args:
            report: Object providing ``losses``, ``metrics`` and
                ``dataset_type``.
            meter: Meter to update; ``self.meter`` when ``None``.
            eval_mode: Accepted but not used by this method.
        """
        target = self.meter if meter is None else meter

        raw_losses = report.losses
        raw_metrics = report.metrics
        losses = reduce_dict(raw_losses)
        metrics = reduce_dict(raw_metrics)

        total_key = report.dataset_type + "/total_loss"

        with torch.no_grad():
            total = sum([value.mean() for value in losses.values()])
            # Tensors are converted to plain Python numbers
            if hasattr(total, "item"):
                total = total.item()

            registry.register(total_key, total)

            update = {total_key: total}
            update.update(losses)
            update.update(metrics)
            target.update(update)

    def _logistics(self, report):
        should_print = self.current_iteration % self.log_interval == 0
        should_break = False
        extra = {}

        if should_print is True:
            if "cuda" in str(self.device):
                extra["max mem"] = torch.cuda.max_memory_allocated() / 1024
                extra["max mem"] //= 1024

            extra.update({
                "lr":
                "{:.5f}".format(
                    self.optimizer.param_groups[0]["lr"]).rstrip("0"),
                "time":
                self.train_timer.get_time_since_start(),
                "eta":
                self._calculate_time_left(),
            })

            self.train_timer.reset()

            _, meter = self.evaluate(self.val_loader, single_batch=True)
            self.meter.update_from_meter(meter)

        # Don't print train metrics if it is not log interval
        # so as to escape clutter
        self._summarize_report(
            self.meter,
            should_print=should_print,
            extra=extra,
            prefix=report.dataset_name,
        )

        should_break = self._try_full_validation()

        return should_break

    def _try_full_validation(self, force=False):
        should_break = False

        if self.current_iteration % self.snapshot_interval == 0 or force:
            self.writer.write(
                "Evaluation time. Running on full validation set...")
            # Validation and Early stopping
            # Create a new meter for this case
            report, meter = self.evaluate(self.val_loader)

            extra = {
                "validation time": self.snapshot_timer.get_time_since_start()
            }

            stop = self.early_stopping(self.current_iteration, meter)
            stop = bool(broadcast_scalar(stop, src=0, device=self.device))

            extra.update(self.early_stopping.get_info())

            prefix = "{}: full val".format(report.dataset_name)

            self._summarize_report(meter, prefix=prefix, extra=extra)
            self.snapshot_timer.reset()
            gc.collect()

            if "cuda" in str(self.device):
                torch.cuda.empty_cache()

            if stop is True:
                self.writer.write("Early stopping activated")
                should_break = True

        return should_break

    def evaluate(self, loader, use_tqdm=False, single_batch=False):
        """Run the model in eval mode over ``loader`` and collect metrics.

        Args:
            loader: Iterable of batches.
            use_tqdm: Accepted for interface compatibility.
                NOTE(review): the progress bar is currently always enabled
                and this flag is not consulted -- confirm whether that is
                intended.
            single_batch: Stop after the first batch when ``True``.

        Returns:
            Tuple ``(report, meter)`` where ``report`` is the report of the
            last processed batch, or ``None`` if ``loader`` yielded nothing,
            and ``meter`` is a fresh ``Meter`` updated with every batch.
        """
        meter = Meter()
        # Defined up front so an empty loader does not end in an
        # UnboundLocalError at the return statement
        report = None

        with torch.no_grad():
            self.model.eval()
            try:
                # disable_tqdm = not use_tqdm or not is_main_process()
                disable_tqdm = False
                for batch in tqdm(loader, disable=disable_tqdm):
                    report = self._forward_pass(batch)
                    self._update_meter(report, meter, eval_mode=True)

                    if single_batch is True:
                        break
            finally:
                # Put the model back in train mode even if a batch fails
                self.model.train()

        return report, meter

    def _summarize_report(self, meter, prefix="", should_print=True, extra=None):
        """Send the meter's scalars to the writer and optionally log a line.

        Does nothing on processes other than the main one.

        Args:
            meter: Meter providing ``get_scalar_dict()``, ``delimiter`` and
                a string representation.
            prefix: Text placed at the start of the log line when non-empty.
            should_print: When false, only the scalars are recorded.
            extra: Optional mapping of additional ``key: value`` entries
                appended to the log line.
        """
        if not is_main_process():
            return

        # None instead of a shared mutable {} default
        if extra is None:
            extra = {}

        scalar_dict = meter.get_scalar_dict()
        self.writer.add_scalars(scalar_dict, registry.get("current_iteration"))

        if not should_print:
            return

        print_str = []

        if prefix:
            print_str.append(prefix + ":")

        print_str.append(
            "{}/{}".format(self.current_iteration, self.max_iterations)
        )
        print_str.append(str(meter))
        print_str.extend(
            "{}: {}".format(key, value) for key, value in extra.items()
        )

        self.writer.write(meter.delimiter.join(print_str))

    def inference(self):
        if "val" in self.run_type:
            self._inference_run("val")

        if "inference" in self.run_type or "predict" in self.run_type:
            self._inference_run("test")

    def _inference_run(self, dataset_type):
        if self.config.training_parameters.evalai_inference is True:
            self.predict_for_evalai(dataset_type)
            return

        self.writer.write("Starting inference on {} set".format(dataset_type))

        report, meter = self.evaluate(getattr(
            self, "{}_loader".format(dataset_type)),
                                      use_tqdm=True)
        prefix = "{}: full {}".format(report.dataset_name, dataset_type)
        self._summarize_report(meter, prefix)

    def _calculate_time_left(self):
        time_taken_for_log = time.time() * 1000 - self.train_timer.start
        iterations_left = self.max_iterations - self.current_iteration
        num_logs_left = iterations_left / self.log_interval
        time_left = num_logs_left * time_taken_for_log

        snapshot_iteration = self.snapshot_iterations / self.log_interval
        snapshot_iteration *= iterations_left / self.snapshot_interval
        time_left += snapshot_iteration * time_taken_for_log

        return self.train_timer.get_time_hhmmss(gap=time_left)

    def profile(self, text):
        if self.not_debug:
            return
        self.writer.write(text + ": " + self.profiler.get_time_since_start(),
                          "debug")
        self.profiler.reset()

    def predict_for_evalai(self, dataset_type):
        reporter = self.dataset_loader.get_test_reporter(dataset_type)
        with torch.no_grad():
            self.model.eval()
            message = "Starting {} inference for evalai".format(dataset_type)
            self.writer.write(message)

            while reporter.next_dataset():
                dataloader = reporter.get_dataloader()

                for batch in tqdm(dataloader):
                    prepared_batch = reporter.prepare_batch(batch)
                    model_output = self.model(prepared_batch)
                    report = Report(prepared_batch, model_output)
                    reporter.add_to_report(report)

            self.writer.write("Finished predicting")
            self.model.train()
Пример #10
0
 def __init__(self, args, *rest, **kwargs):
     """Store ``args`` and create the profiling timer.

     Extra positional and keyword arguments are accepted and ignored.
     """
     self.profiler = Timer()
     self.args = args
Пример #11
0
    def test_get_time_since_start(self):
        """A timer read about two seconds after creation reports "02s "."""
        timer = Timer()
        time.sleep(2)
        expected = "02s "

        # assertIn shows both operands when the check fails
        self.assertIn(expected, timer.get_time_since_start())
Пример #12
0
    def __init__(self, config):
        """Set up file, stdout and TensorBoard logging on the main process.

        On any other process (per ``is_main_process()``) ``logger`` and
        ``summary_writer`` are left as ``None`` and nothing else is set up.

        Args:
            config: Configuration providing ``training_parameters`` (with
                ``save_dir``, ``should_not_log`` and optionally
                ``logger_level``) and optionally ``log_dir``.
        """
        self.logger = None
        self.summary_writer = None
        # Created before the early return so the attribute always exists
        self._single_log_map = set()

        if not is_main_process():
            return

        self.timer = Timer()
        self.config = config
        self.save_dir = config.training_parameters.save_dir
        self.log_folder = ckpt_name_from_core_args(config)
        self.log_folder += foldername_from_config_override(config)
        time_format = "%Y-%m-%dT%H:%M:%S"
        self.log_filename = ckpt_name_from_core_args(config) + "_"
        self.log_filename += self.timer.get_time_hhmmss(None,
                                                        format=time_format)
        self.log_filename += ".log"

        self.log_folder = os.path.join(self.save_dir, self.log_folder, "logs")

        # An explicit log_dir in the config replaces the derived folder
        arg_log_dir = self.config.get("log_dir", None)
        if arg_log_dir:
            self.log_folder = arg_log_dir

        # exist_ok avoids the race between checking and creating the folder
        os.makedirs(self.log_folder, exist_ok=True)

        tensorboard_folder = os.path.join(self.log_folder, "tensorboard")
        self.summary_writer = SummaryWriter(tensorboard_folder)

        self.log_filename = os.path.join(self.log_folder, self.log_filename)

        print("Logging to:", self.log_filename)

        logging.captureWarnings(True)

        self.logger = logging.getLogger(__name__)
        # A separate, non-propagating logger: looking up the same name twice
        # returns the same object, which would also give the "file only"
        # logger the stdout handler added below.
        self._file_only_logger = logging.getLogger(__name__ + ".file_only")
        self._file_only_logger.propagate = False
        warnings_logger = logging.getLogger("py.warnings")

        # Set level
        level = config["training_parameters"].get("logger_level", "info")
        self.logger.setLevel(getattr(logging, level.upper()))
        self._file_only_logger.setLevel(getattr(logging, level.upper()))

        formatter = logging.Formatter("%(asctime)s %(levelname)s: %(message)s",
                                      datefmt="%Y-%m-%dT%H:%M:%S")

        # Add handler to file
        channel = logging.FileHandler(filename=self.log_filename, mode="a")
        channel.setFormatter(formatter)

        self.logger.addHandler(channel)
        self._file_only_logger.addHandler(channel)
        warnings_logger.addHandler(channel)

        # Add handler to stdout
        channel = logging.StreamHandler(sys.stdout)
        channel.setFormatter(formatter)

        self.logger.addHandler(channel)
        warnings_logger.addHandler(channel)

        should_not_log = self.config["training_parameters"]["should_not_log"]
        self.should_log = not should_not_log
Пример #13
0
class Logger:
    """Run logger that writes to a log file, stdout and TensorBoard.

    Handlers are only set up on the main process (per
    ``is_main_process()``). On any other process ``logger`` and
    ``summary_writer`` stay ``None`` and the methods below return without
    doing anything.
    """

    def __init__(self, config):
        """Create the log folder, the TensorBoard writer and the loggers.

        Args:
            config: Configuration providing ``training_parameters`` (with
                ``save_dir``, ``should_not_log`` and optionally
                ``logger_level``) and optionally ``log_dir``.
        """
        self.logger = None
        self.summary_writer = None
        # Created before the early return so single_write() works everywhere
        self._single_log_map = set()

        if not is_main_process():
            return

        self.timer = Timer()
        self.config = config
        self.save_dir = config.training_parameters.save_dir
        self.log_folder = ckpt_name_from_core_args(config)
        self.log_folder += foldername_from_config_override(config)
        time_format = "%Y-%m-%dT%H:%M:%S"
        self.log_filename = ckpt_name_from_core_args(config) + "_"
        self.log_filename += self.timer.get_time_hhmmss(None,
                                                        format=time_format)
        self.log_filename += ".log"

        self.log_folder = os.path.join(self.save_dir, self.log_folder, "logs")

        # An explicit log_dir in the config replaces the derived folder
        arg_log_dir = self.config.get("log_dir", None)
        if arg_log_dir:
            self.log_folder = arg_log_dir

        # exist_ok avoids the race between checking and creating the folder
        os.makedirs(self.log_folder, exist_ok=True)

        tensorboard_folder = os.path.join(self.log_folder, "tensorboard")
        self.summary_writer = SummaryWriter(tensorboard_folder)

        self.log_filename = os.path.join(self.log_folder, self.log_filename)

        print("Logging to:", self.log_filename)

        logging.captureWarnings(True)

        self.logger = logging.getLogger(__name__)
        # A separate, non-propagating logger: looking up the same name twice
        # returns the same object, which would also give the "file only"
        # logger the stdout handler added below.
        self._file_only_logger = logging.getLogger(__name__ + ".file_only")
        self._file_only_logger.propagate = False
        warnings_logger = logging.getLogger("py.warnings")

        # Set level
        level = config["training_parameters"].get("logger_level", "info")
        self.logger.setLevel(getattr(logging, level.upper()))
        self._file_only_logger.setLevel(getattr(logging, level.upper()))

        formatter = logging.Formatter("%(asctime)s %(levelname)s: %(message)s",
                                      datefmt="%Y-%m-%dT%H:%M:%S")

        # Add handler to file
        channel = logging.FileHandler(filename=self.log_filename, mode="a")
        channel.setFormatter(formatter)

        self.logger.addHandler(channel)
        self._file_only_logger.addHandler(channel)
        warnings_logger.addHandler(channel)

        # Add handler to stdout
        channel = logging.StreamHandler(sys.stdout)
        channel.setFormatter(formatter)

        self.logger.addHandler(channel)
        warnings_logger.addHandler(channel)

        should_not_log = self.config["training_parameters"]["should_not_log"]
        self.should_log = not should_not_log

    def __del__(self):
        """Close the TensorBoard writer if one was created."""
        if getattr(self, "summary_writer", None) is not None:
            self.summary_writer.close()

    def write(self, x, level="info", donot_print=False):
        """Log ``str(x)`` at the given level.

        Args:
            x: Object to log.
            level: Name of a logger method such as ``"info"`` or ``"debug"``.
            donot_print: Send the message to the log file only.
        """
        if self.logger is None:
            return
        # if it should not log then just print it
        if self.should_log:
            if hasattr(self.logger, level):
                if donot_print:
                    getattr(self._file_only_logger, level)(str(x))
                else:
                    getattr(self.logger, level)(str(x))
            else:
                self.logger.error("Unknown log level type: %s" % level)
        else:
            print(str(x) + "\n")

    def single_write(self, x, level="info"):
        """Log ``x`` at ``level`` only the first time this pair is seen.

        Args:
            x: Message string.
            level: Log level name, as for ``write``.
        """
        key = x + "_" + level
        if key in self._single_log_map:
            return
        # Remember the pair; without this the message is never de-duplicated
        self._single_log_map.add(key)
        self.write(x, level)

    def add_scalar(self, key, value, iteration):
        """Record one scalar in TensorBoard, if a writer exists."""
        if self.summary_writer is None:
            return
        self.summary_writer.add_scalar(key, value, iteration)

    def add_scalars(self, scalar_dict, iteration):
        """Record every ``key: value`` pair of ``scalar_dict`` as a scalar."""
        if self.summary_writer is None:
            return
        for key, val in scalar_dict.items():
            self.summary_writer.add_scalar(key, val, iteration)

    def add_histogram_for_model(self, model, iteration):
        """Record a histogram of every named parameter of ``model``.

        Args:
            model: Module providing ``named_parameters()``.
            iteration: Step value passed on to the summary writer.
        """
        if self.summary_writer is None:
            return
        # named_parameters() yields (name, tensor) pairs; parameters() yields
        # bare tensors and cannot be unpacked this way
        for name, param in model.named_parameters():
            np_param = param.detach().clone().cpu().numpy()
            self.summary_writer.add_histogram(name, np_param, iteration)
Пример #14
0
class BaseTrainer:
    def __init__(self, config):
        """Keep the configuration and create the profiling timer.

        Args:
            config: Configuration object; stored as ``self.config``.
        """
        self.profiler = Timer()
        self.config = config

    def load(self):
        """Set up everything the trainer needs, in dependency order.

        Initializes the process group, creates the task loader and the
        logger, prints the configuration, applies seeding and then loads the
        tasks, model, optimizer and the remaining training helpers.
        """
        self._init_process_group()

        self.run_type = self.config.training_parameters.get("run_type", "train")
        self.task_loader = TaskLoader(self.config)

        # The logger is published in the registry under the key "writer"
        self.writer = Logger(self.config)
        registry.register("writer", self.writer)

        self.configuration = registry.get("configuration")
        self.configuration.pretty_print()

        self.config_based_setup()

        # Order matters: the model needs the tasks, the optimizer needs the
        # model, and load_extras() needs the optimizer
        self.load_task()
        self.load_model()
        self.load_optimizer()
        self.load_extras()

    def _init_process_group(self):
        """Initialize distributed training and pick the device.

        Reads ``local_rank``, ``device`` and ``distributed`` from the
        training parameters. When a local rank is set and distributed
        training is enabled, the NCCL process group is initialized and, for
        CUDA devices, ``self.device`` becomes the CUDA device of this rank.
        The resulting device is registered as ``"current_device"``.

        Raises:
            RuntimeError: If distributed training is requested but NCCL is
                not available.
        """
        training_parameters = self.config.training_parameters
        self.local_rank = training_parameters.local_rank
        self.device = training_parameters.device

        use_distributed = (
            self.local_rank is not None and training_parameters.distributed
        )

        if use_distributed:
            if not torch.distributed.is_nccl_available():
                raise RuntimeError(
                    "Unable to initialize process group: NCCL is not available"
                )
            torch.distributed.init_process_group(backend="nccl")
            synchronize()

        # str() keeps this working when the configured device is a
        # torch.device rather than a plain string (a membership test on a
        # torch.device raises TypeError); load_model() tests the device the
        # same way.
        if use_distributed and "cuda" in str(self.device):
            self.device = torch.device("cuda", self.local_rank)

        registry.register("current_device", self.device)

    def load_task(self):
        self.writer.write("Loading tasks and data", "info")
        self.task_loader.load_task()

        self.task_loader.make_dataloaders()

        self.train_loader = self.task_loader.train_loader
        self.val_loader = self.task_loader.val_loader
        self.test_loader = self.task_loader.test_loader
        self.train_task = self.task_loader.train_task
        self.val_task = self.task_loader.val_task

        # Total iterations for snapshot
        self.snapshot_iterations = len(self.val_task)
        self.snapshot_iterations //= self.config.training_parameters.batch_size

        self.test_task = self.task_loader.test_task

    def load_model(self):
        """Build the model, move it to the device and wrap it if requested.

        The model attributes come from ``config.model_attributes``; a string
        entry is treated as the name of another entry to use instead. The
        ``data_parallel`` and ``distributed`` flags are published in the
        registry, and the model is wrapped in ``DataParallel`` and/or
        ``DistributedDataParallel`` according to them.
        """
        attributes = self.config.model_attributes[self.config.model]
        # Easy way to point to config for other model
        if isinstance(attributes, str):
            attributes = self.config.model_attributes[attributes]

        attributes["model"] = self.config.model

        self.task_loader.update_registry_for_model(attributes)
        self.model = build_model(attributes)
        self.task_loader.clean_config(attributes)
        training_parameters = self.config.training_parameters

        data_parallel = training_parameters.data_parallel
        distributed = training_parameters.distributed

        registry.register("data_parallel", data_parallel)
        registry.register("distributed", distributed)

        if "cuda" in str(self.config.training_parameters.device):
            # Rank 0 is reported when no local rank is configured
            rank = self.local_rank if self.local_rank is not None else 0
            device_info = "CUDA Device {} is: {}".format(
                rank, torch.cuda.get_device_name(self.local_rank)
            )

            self.writer.write(device_info, log_all=True)

        self.model = self.model.to(self.device)

        self.writer.write("Torch version is: " + torch.__version__)

        # The stray debug print that used to announce this branch on stdout
        # has been removed.
        if (
            "cuda" in str(self.device)
            and torch.cuda.device_count() > 1
            and data_parallel is True
        ):
            self.model = torch.nn.DataParallel(self.model)

        if (
            "cuda" in str(self.device)
            and self.local_rank is not None
            and distributed is True
        ):
            torch.cuda.set_device(self.local_rank)
            self.model = torch.nn.parallel.DistributedDataParallel(
                self.model, device_ids=[self.local_rank]
            )

    def load_optimizer(self):
        """Build the optimizer for ``self.model`` from the configuration."""
        optimizer = build_optimizer(self.model, self.config)
        self.optimizer = optimizer

    def load_extras(self):
        """Create the checkpoint, meter, early stopping and LR scheduler.

        Copies the loop settings (intervals, iteration/epoch limits, gradient
        clipping flag) from the training parameters onto ``self``, resets the
        epoch/iteration counters and then lets the checkpoint load its state.
        """
        self.checkpoint = Checkpoint(self)
        self.meter = Meter()

        params = self.config.training_parameters
        self.training_parameters = params

        monitored_metric = params.monitored_metric
        metric_minimize = params.metric_minimize
        should_early_stop = params.should_early_stop
        patience = params.patience

        self.log_interval = params.log_interval
        self.snapshot_interval = params.snapshot_interval
        self.max_iterations = params.max_iterations
        self.should_clip_gradients = params.clip_gradients
        self.max_epochs = params.max_epochs

        self.early_stopping = EarlyStopping(
            self.model,
            self.checkpoint,
            monitored_metric,
            patience=patience,
            minimize=metric_minimize,
            should_stop=should_early_stop,
        )

        # Counters are reset before the checkpoint state is loaded
        self.current_epoch = 0
        self.current_iteration = 0
        self.checkpoint.load_state_dict()

        self.not_debug = params.logger_level != "debug"

        self.lr_scheduler = None

        # TODO: Allow custom scheduler
        if params.lr_scheduler is True:

            def scheduler_func(step):
                return lr_lambda_update(step, self.config)

            self.lr_scheduler = optim.lr_scheduler.LambdaLR(
                self.optimizer, lr_lambda=scheduler_func
            )

    def config_based_setup(self):
        seed = self.config.training_parameters.seed
        if seed is None:
            return

        random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    def train(self):
        """Run the training loop, then finalize.

        When ``"train"`` is not part of ``self.run_type`` only inference is
        run. Otherwise batches are processed epoch by epoch until the
        iteration limit or the epoch limit is exceeded, or ``_logistics``
        asks to stop. ``finalize()`` is called at the end.
        """
        self.writer.write("===== Model =====")
        self.writer.write(self.model)

        # TODO: integrate this code more effectively later
        # self.gradcam(self.test_loader)
        # return

        if "train" not in self.run_type:
            self.inference()
            return

        should_break = False

        # Only one limit stays finite: iterations when no epoch limit is
        # configured, epochs otherwise
        if self.max_epochs is None:
            self.max_epochs = math.inf
        else:
            self.max_iterations = math.inf

        self.model.train()
        self.train_timer = Timer()
        self.snapshot_timer = Timer()

        self.profile("Setup Time")

        torch.autograd.set_detect_anomaly(True)

        self.writer.write("Starting training...")
        while self.current_iteration < self.max_iterations and not should_break:
            self.current_epoch += 1
            registry.register("current_epoch", self.current_epoch)

            # Seed the sampler in case if it is distributed
            self.task_loader.seed_sampler("train", self.current_epoch)

            if self.current_epoch > self.max_epochs:
                break

            # Leftover debug output (loader length once per epoch and a
            # print on every batch) has been removed; per-iteration
            # information still goes to the writer at debug level.
            for batch in self.train_loader:
                self.profile("Batch load time")
                self.current_iteration += 1
                self.writer.write(self.current_iteration, "debug")

                registry.register("current_iteration", self.current_iteration)

                if self.current_iteration > self.max_iterations:
                    break

                # In this trainer the scheduler is stepped before the
                # forward/backward pass
                self._run_scheduler()
                report = self._forward_pass(batch)
                self._update_meter(report, self.meter)
                loss = self._extract_loss(report)
                self._backward(loss)
                should_break = self._logistics(report)

                if should_break:
                    break

        self.finalize()

    def _run_scheduler(self):
        if self.lr_scheduler is not None:
            self.lr_scheduler.step(self.current_iteration)

    def _forward_pass(self, batch):
        """Prepare a batch, run the model on it and wrap the result.

        Args:
            batch: A batch as yielded by one of the data loaders.

        Returns:
            A ``Report`` built from the prepared batch and the model output.
        """
        prepared_batch = self.task_loader.prepare_batch(batch)
        # profile() calls split the timing into preparation and forward pass
        self.profile("Batch prepare time")

        # Arguments should be a dict at this point
        model_output = self.model(prepared_batch)
        report = Report(prepared_batch, model_output)
        self.profile("Forward time")

        return report

    def _backward(self, loss):
        """Backpropagate ``loss`` and take one optimizer step.

        The LR scheduler is not stepped here; in this trainer ``train()``
        calls ``_run_scheduler()`` itself.

        Args:
            loss: Object whose ``backward()`` is called; produced by
                ``_extract_loss`` in the training loop.
        """
        # Clear gradients left over from the previous iteration
        self.optimizer.zero_grad()

        # NOTE: commented-out statements that inspected the weights and
        # gradients of one attention layer around this call were removed.
        loss.backward()

        # Clipping is optional and controlled by the training parameters
        if self.should_clip_gradients:
            clip_gradients(self.model, self.current_iteration, self.writer, self.config)

        self.optimizer.step()
        self.profile("Backward time")

    def _extract_loss(self, report):
        loss_dict = report.losses
        loss = sum([loss.mean() for loss in loss_dict.values()])
        return loss

    def finalize(self):
        """Restore and finalize the checkpoint, then run inference.

        A forced full-validation pass used to run first; it is currently
        disabled.
        """
        checkpoint = self.checkpoint
        checkpoint.restore()
        checkpoint.finalize()
        self.inference()

    def _update_meter(self, report, meter=None, eval_mode=False):
        """Push the losses and metrics of ``report`` into a meter.

        Args:
            report: Object exposing ``losses``, ``metrics`` and ``dataset_type``.
            meter: Meter to update; ``self.meter`` is used when ``None``.
            eval_mode: Accepted for interface compatibility; not used here.
        """
        target_meter = self.meter if meter is None else meter

        raw_losses = report.losses
        raw_metrics = report.metrics
        losses = reduce_dict(raw_losses)
        metrics = reduce_dict(raw_metrics)

        total_key = report.dataset_type + "/total_loss"

        with torch.no_grad():
            total = sum([value.mean() for value in losses.values()])
            # Tensors are converted to plain Python numbers before registering.
            if hasattr(total, "item"):
                total = total.item()

            registry.register(total_key, total)

            payload = {total_key: total}
            payload.update(losses)
            payload.update(metrics)
            target_meter.update(payload)

    def _logistics(self, report):
        should_print = self.current_iteration % self.log_interval == 0
        should_break = False
        extra = {}

        if should_print is True:
            if "cuda" in str(self.device):
                extra["max mem"] = torch.cuda.max_memory_allocated() / 1024
                extra["max mem"] //= 1024

            extra.update(
            {
                "lr": "{:.5f}".format(self.optimizer.param_groups[0]["lr"]).rstrip(
                "0"
                ),
                "time": self.train_timer.get_time_since_start(),
                "eta": self._calculate_time_left(),
            }
            )

            self.train_timer.reset()
        '''
        # Evaluate on the whole validation set
        _, meter = self.evaluate(self.val_loader) # , single_batch=True)
        self.meter.update_from_meter(meter)
        
    # Don't print train metrics if it is not log interval
    # so as to escape clutter
    
    self._summarize_report(
        self.meter,
        should_print=should_print,
        extra=extra,
        prefix=report.dataset_name,
    )
    '''

    # should_break = self._try_full_validation()
    # Evaluate on full validation set
        should_val = self.current_iteration % self.snapshot_interval == 0
        should_break = False
        if should_val:
            print("Full validation")
            should_break = self.evaluate_full(self.val_loader) # hardcode the metric for now
        
        return should_break

    def _try_full_validation(self, force=False):
        should_break = False

        if self.current_iteration % self.snapshot_interval == 0 or force:
            self.writer.write("Evaluation time. Running on full validation set...")
            # Validation and Early stopping
            # Create a new meter for this case
            report, meter = self.evaluate(self.val_loader)

            extra = {"validation time": self.snapshot_timer.get_time_since_start()}

            stop = self.early_stopping(self.current_iteration, meter)
            stop = bool(broadcast_scalar(stop, src=0, device=self.device))

            extra.update(self.early_stopping.get_info())

            prefix = "{}: full val".format(report.dataset_name)

            self._summarize_report(meter, prefix=prefix, extra=extra)
            self.snapshot_timer.reset()
            gc.collect()

            if "cuda" in str(self.device):
                torch.cuda.empty_cache()

            if stop is True:
                self.writer.write("Early stopping activated")
                should_break = True

        return should_break

    def evaluate(self, loader, use_tqdm=False, single_batch=False):
        """Evaluate the model on ``loader`` without gradients.

        The model is put in eval mode for the pass and switched back to
        training mode afterwards.

        Args:
            loader: Iterable of batches.
            use_tqdm: Show a progress bar while iterating.
            single_batch: Stop after the first batch.

        Returns:
            Tuple ``(report, meter)``: the report of the last processed batch
            and a fresh ``Meter`` updated with every processed batch.
        """
        eval_meter = Meter()

        with torch.no_grad():
            self.model.eval()
            progress = tqdm(loader, disable=not use_tqdm)
            for batch in progress:
                report = self._forward_pass(batch)
                self._update_meter(report, eval_meter, eval_mode=True)

                if single_batch is True:
                    break
            self.model.train()

        return report, eval_meter

    def evaluate_full(self, loader, use_tqdm=False):
        """Score the model on all of ``loader`` and compute metrics in one call.

        Scores and targets of every batch are concatenated along dim 0 and the
        metrics are computed over the whole set at once. The result is written
        through ``self.writer``. For non-test data the metrics are also fed to
        ``self.early_stopping``.

        Args:
            loader: Iterable of batches; must support ``len``.
            use_tqdm: Show a progress bar while iterating.

        Returns:
            ``True`` when early stopping was triggered, otherwise ``False``
            (always ``False`` for a ``'test'`` dataset, where early stopping
            is not consulted).
        """
        meter = Meter()

        # The metric list is hardcoded for now.
        metrics = ['accuracy']

        print(len(loader))

        tot_preds = []
        tot_targets = []

        with torch.no_grad():
            self.model.eval()
            try:
                for batch in tqdm(loader, disable=not use_tqdm):
                    report = self._forward_pass(batch)
                    tot_preds.append(report.scores)
                    tot_targets.append(report.targets)

                tot_preds = torch.cat(tot_preds, dim=0)
                tot_targets = torch.cat(tot_targets, dim=0)

                model_output = {"scores": tot_preds}
                sample_list = SampleList([Sample({"targets": tot_targets})])
                # Dataset fields are taken from the last batch's report.
                sample_list.add_field('dataset_type', report.dataset_type)
                sample_list.add_field('dataset_name', report.dataset_name)

                metric_fn = Metrics(metrics)
                full_met = metric_fn(sample_list, model_output)
                self.writer.write(full_met)

                if report.dataset_type == 'test':
                    return False

                meter.update(full_met)
                stop = self.early_stopping(self.current_iteration, meter)

                if stop is True:
                    self.writer.write("Early stopping activated")
                    return True
                return False
            finally:
                # Hand the model back in training mode on every exit path,
                # including the early return for test data.
                self.model.train()

    def _summarize_report(self, meter, prefix="", should_print=True, extra={}):
        """Log the scalars of ``meter`` and optionally write a one-line summary.

        Does nothing unless ``is_main_process()`` is true.

        Args:
            meter: Meter exposing ``get_scalar_dict()``, ``delimiter`` and ``str()``.
            prefix: Text placed (followed by a colon) at the start of the line.
            should_print: When false, only the scalars are logged.
            extra: Additional ``key: value`` pairs appended to the line.
        """
        if not is_main_process():
            return

        scalars = meter.get_scalar_dict()
        self.writer.add_scalars(scalars, registry.get("current_iteration"))

        if not should_print:
            return

        parts = [prefix + ":"] if len(prefix) else []
        parts.append("{}/{}".format(self.current_iteration, self.max_iterations))
        parts.append(str(meter))
        parts.extend("{}: {}".format(key, value) for key, value in extra.items())

        self.writer.write(meter.delimiter.join(parts))

    def inference(self):
        """Run the inference passes selected by ``self.run_type``.

        A run type containing ``"val"`` triggers a run on ``"val"``; one
        containing ``"inference"`` or ``"predict"`` triggers a run on ``"test"``.
        """
        mode = self.run_type

        if "val" in mode:
            self._inference_run("val")

        wants_test = any(token in mode for token in ("inference", "predict"))
        if wants_test:
            self._inference_run("test")

    def _inference_run(self, dataset_type):
        if self.config.training_parameters.evalai_inference is True:
            self.predict_for_evalai(dataset_type)
            return

        self.writer.write("Starting inference on {} set".format(dataset_type))

        # Evaluate on full validation/test set
        self.evaluate_full(
            getattr(self, "{}_loader".format(dataset_type)), use_tqdm=True
        )

    '''
    report, meter = self.evaluate(
        getattr(self, "{}_loader".format(dataset_type)), use_tqdm=True
    )
    prefix = "{}: full {}".format(report.dataset_name, dataset_type)
    self._summarize_report(meter, prefix)
    '''

    def _calculate_time_left(self):
        time_taken_for_log = time.time() * 1000 - self.train_timer.start
        if self.max_iterations == math.inf:
            iterations_left = self.max_epochs*len(self.train_loader) - self.current_iteration
        else:
            iterations_left = self.max_iterations - self.current_iteration
        num_logs_left = iterations_left / self.log_interval
        time_left = num_logs_left * time_taken_for_log

        snapshot_iteration = self.snapshot_iterations / self.log_interval
        snapshot_iteration *= iterations_left / self.snapshot_interval
        time_left += snapshot_iteration * time_taken_for_log
        return self.train_timer.get_time_hhmmss(gap=time_left)

    def profile(self, text):
        """Write the time since the last profile call at debug level, then reset.

        Does nothing when ``self.not_debug`` is truthy.

        Args:
            text: Label placed before the elapsed time.
        """
        if not self.not_debug:
            elapsed = self.profiler.get_time_since_start()
            self.writer.write(text + ": " + elapsed, "debug")
            self.profiler.reset()

    def predict_for_evalai(self, dataset_type):
        """Run the model over every dataset of the test reporter and collect reports.

        The model is evaluated without gradients in eval mode and switched
        back to training mode at the end.

        Args:
            dataset_type: Passed to ``self.task_loader.get_test_reporter``.
        """
        reporter = self.task_loader.get_test_reporter(dataset_type)

        with torch.no_grad():
            self.model.eval()
            self.writer.write("Starting {} inference for evalai".format(dataset_type))

            while reporter.next_dataset():
                for batch in tqdm(reporter.get_dataloader()):
                    prepared = reporter.prepare_batch(batch)
                    output = self.model(prepared)
                    reporter.add_to_report(Report(prepared, output))

            self.writer.write("Finished predicting")
            self.model.train()

    def gradcam(self, loader, use_tqdm=False):
        """Print the running accuracy of Grad-CAM answer maps over ``loader``.

        For every batch one heatmap is generated per answer index, the maps
        are stacked, and ``max_inference`` counts how many annotated points of
        the first sample are answered correctly. The cumulative ratio is
        printed after each batch.

        Args:
            loader: Iterable of batches.
            use_tqdm: Show a progress bar while iterating.
        """
        vqa_gradcam = GradCAM(self.model, target_layer='resnet152_model.7')
        correct_total = 0
        points_total = 0

        for batch in tqdm(loader, disable=not use_tqdm):
            prepared = self.task_loader.prepare_batch(batch)

            # Image size is hardcoded to the first sample's size for now.
            image_size = (prepared.img_h[0], prepared.img_w[0])
            probs = vqa_gradcam.forward(prepared, image_size)

            per_answer = []
            for answer in range(probs.shape[1]):
                answer_inds = torch.LongTensor([[answer] * probs.shape[0]]).to(self.device)
                vqa_gradcam.backward(answer_inds)
                per_answer.append(vqa_gradcam.generate())

            stacked = torch.stack(per_answer, dim=0)
            num_corr, num_pts = self.max_inference(stacked, prepared.points[0])
            correct_total += num_corr
            points_total += num_pts
            print(correct_total / points_total)
            

    def max_inference(self, heatmaps, points):
        """Count the points whose strongest heatmap matches the annotated answer.

        Args:
            heatmaps: Tensor whose dim 0 indexes answers; after the argmax over
                dim 0 it is indexed as ``[y][x]``.
            points: Iterable of mappings with keys ``'x'``, ``'y'`` and
                ``'ans'`` (an entry of the fixed answer list below).

        Returns:
            Tuple ``(num_correct, num_points)``.
        """
        answer_list = [
            'blue', 'black', 'white', 'red', 'yellow', 'pink', 'gray',
            'orange', 'brown', 'square', 'round', 'green', 'purple',
            'walking', 'standing', 'sitting', 'playing', 'rectangular',
            'looking', 'hanging',
        ]
        winning_answer = heatmaps.argmax(dim=0)

        correct = 0
        seen = 0
        for point in points:
            col, row = point['x'], point['y']
            expected = answer_list.index(point['ans'])
            correct += (winning_answer[row][col] == expected).item()
            seen += 1

        return correct, seen
            
                


            
            
Пример #15
0
    def train(self):
        """Run the training loop, or only inference when training is not requested.

        The model is written to the log first. If ``"train"`` is not part of
        ``self.run_type`` only ``self.inference()`` runs. Otherwise batches are
        consumed epoch after epoch until the iteration or epoch limit is hit
        or ``_logistics`` asks to stop, and ``finalize`` is called at the end.
        """
        self.writer.write("===== Model =====")
        self.writer.write(self.model)

        # TODO: integrate the gradcam pass (self.gradcam(self.test_loader))
        # more effectively later.

        if "train" not in self.run_type:
            self.inference()
            return

        should_break = False

        # Only one limit is active: a configured epoch limit disables the
        # iteration limit, and vice versa.
        if self.max_epochs is None:
            self.max_epochs = math.inf
        else:
            self.max_iterations = math.inf

        self.model.train()
        self.train_timer = Timer()
        self.snapshot_timer = Timer()

        self.profile("Setup Time")

        # NOTE(review): anomaly detection is a debugging aid that slows down
        # autograd; confirm it should stay enabled for regular runs.
        torch.autograd.set_detect_anomaly(True)

        self.writer.write("Starting training...")
        while self.current_iteration < self.max_iterations and not should_break:
            self.current_epoch += 1
            registry.register("current_epoch", self.current_epoch)

            # Seed the sampler in case if it is distributed
            self.task_loader.seed_sampler("train", self.current_epoch)

            if self.current_epoch > self.max_epochs:
                break

            # Reported at debug level instead of printing on every epoch.
            self.writer.write(
                "Train loader length: {}".format(len(self.train_loader)), "debug"
            )
            for batch in self.train_loader:
                self.profile("Batch load time")
                self.current_iteration += 1
                self.writer.write(self.current_iteration, "debug")

                registry.register("current_iteration", self.current_iteration)

                if self.current_iteration > self.max_iterations:
                    break

                self._run_scheduler()
                report = self._forward_pass(batch)
                self._update_meter(report, self.meter)
                loss = self._extract_loss(report)
                self._backward(loss)
                should_break = self._logistics(report)

                if should_break:
                    break

        self.finalize()
Пример #16
0
class BaseTrainer:
    def __init__(self, config):
        """Store ``config`` and create the timer that ``profile`` reads from."""
        self.config = config
        self.profiler = Timer()

    def load(self):
        """Set up everything the trainer needs before ``train`` is called.

        In order: process group, run type, task loader, logger (also
        registered in the registry as ``"writer"``), configuration printout,
        seeding, then tasks, model, optimizer and extras.
        """
        self._init_process_group()

        training_parameters = self.config.training_parameters
        self.run_type = training_parameters.get("run_type", "train")
        self.task_loader = TaskLoader(self.config)

        logger = Logger(self.config)
        self.writer = logger
        registry.register("writer", logger)

        self.configuration = registry.get("configuration")
        self.configuration.pretty_print()

        self.config_based_setup()

        loading_steps = (
            self.load_task,
            self.load_model,
            self.load_optimizer,
            self.load_extras,
        )
        for step in loading_steps:
            step()

    def _init_process_group(self):
        """Initialize the distributed process group if requested and pick the device.

        Reads ``local_rank``, ``device`` and ``distributed`` from
        ``config.training_parameters``. When a local rank is set and
        distributed training is enabled, the NCCL process group is initialized
        and, for CUDA devices, ``self.device`` becomes that rank's
        ``torch.device``. The final device is registered in the registry as
        ``"current_device"``.

        Raises:
            RuntimeError: If distributed training is requested but NCCL is not
                available.
        """
        training_parameters = self.config.training_parameters
        self.local_rank = training_parameters.local_rank
        self.device = training_parameters.device

        use_distributed = (
            self.local_rank is not None and training_parameters.distributed
        )

        if use_distributed:
            if not torch.distributed.is_nccl_available():
                raise RuntimeError(
                    "Unable to initialize process group: NCCL is not available"
                )
            torch.distributed.init_process_group(backend="nccl")
            synchronize()

        # str() keeps the membership test valid when the configured device is
        # a torch.device rather than a plain string.
        if use_distributed and "cuda" in str(self.device):
            self.device = torch.device("cuda", self.local_rank)

        registry.register("current_device", self.device)

    def load_task(self):
        self.writer.write("Loading tasks and data", "info")
        self.task_loader.load_task()

        self.task_loader.make_dataloaders()

        self.train_loader = self.task_loader.train_loader
        self.val_loader = self.task_loader.val_loader
        self.test_loader = self.task_loader.test_loader
        self.train_task = self.task_loader.train_task
        self.val_task = self.task_loader.val_task

        # Total iterations for snapshot
        self.snapshot_iterations = len(self.val_task)
        self.snapshot_iterations //= self.config.training_parameters.batch_size

        self.test_task = self.task_loader.test_task

    def load_model(self):
        """Build the configured model, move it to the device and wrap it if needed.

        The model attributes are looked up under ``config.model``; a string
        entry is followed once as an alias. ``data_parallel`` and
        ``distributed`` are published in the registry. On CUDA the model is
        wrapped in ``DataParallel`` and/or ``DistributedDataParallel``
        according to those flags.
        """
        model_name = self.config.model
        attributes = self.config.model_attributes[model_name]
        # Easy way to point to the config of another model
        if isinstance(attributes, str):
            attributes = self.config.model_attributes[attributes]

        attributes["model"] = model_name

        self.task_loader.update_registry_for_model(attributes)
        self.model = build_model(attributes)
        self.task_loader.clean_config(attributes)

        training_parameters = self.config.training_parameters
        data_parallel = training_parameters.data_parallel
        distributed = training_parameters.distributed

        registry.register("data_parallel", data_parallel)
        registry.register("distributed", distributed)

        if "cuda" in str(training_parameters.device):
            shown_rank = 0 if self.local_rank is None else self.local_rank
            device_name = torch.cuda.get_device_name(self.local_rank)
            self.writer.write(
                "CUDA Device {} is: {}".format(shown_rank, device_name),
                log_all=True,
            )

        self.model = self.model.to(self.device)

        self.writer.write("Torch version is: " + torch.__version__)

        on_cuda = "cuda" in str(self.device)

        if on_cuda and torch.cuda.device_count() > 1 and data_parallel is True:
            self.model = torch.nn.DataParallel(self.model)

        if on_cuda and self.local_rank is not None and distributed is True:
            torch.cuda.set_device(self.local_rank)
            self.model = torch.nn.parallel.DistributedDataParallel(
                self.model, device_ids=[self.local_rank]
            )

    def load_optimizer(self):
        """Build the optimizer for ``self.model`` from the configuration."""
        optimizer = build_optimizer(self.model, self.config)
        self.optimizer = optimizer

    def load_extras(self):
        """Create the checkpoint, meter, early stopping and optional LR scheduler.

        Interval and limit settings are copied from
        ``config.training_parameters`` onto the trainer, the epoch and
        iteration counters are zeroed, and ``self.checkpoint.load_state_dict()``
        is called before the scheduler is set up.
        """
        self.checkpoint = Checkpoint(self)
        self.meter = Meter()

        params = self.config.training_parameters
        self.training_parameters = params

        self.log_interval = params.log_interval
        self.snapshot_interval = params.snapshot_interval
        self.test_interval = params.test_interval
        self.max_iterations = params.max_iterations
        self.should_clip_gradients = params.clip_gradients
        self.max_epochs = params.max_epochs

        self.early_stopping = EarlyStopping(
            self.model,
            self.checkpoint,
            params.monitored_metric,
            patience=params.patience,
            minimize=params.metric_minimize,
            should_stop=params.should_early_stop,
        )
        self.current_epoch = 0
        self.current_iteration = 0

        self.checkpoint.load_state_dict()

        self.not_debug = params.logger_level != "debug"

        self.lr_scheduler = None

        # TODO: Allow custom scheduler
        if params.lr_scheduler is True:

            def lr_factor(step):
                return lr_lambda_update(step, self.config)

            self.lr_scheduler = optim.lr_scheduler.LambdaLR(
                self.optimizer, lr_lambda=lr_factor
            )

    def config_based_setup(self):
        seed = self.config.training_parameters.seed
        if seed is None:
            return

        random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    def train(self):
        self.writer.write("===== Model =====")
        self.writer.write(self.model)

        if "train" not in self.run_type:
            self.inference()
            return

        should_break = False

        if self.max_epochs is None:
            self.max_epochs = math.inf
        else:
            self.max_iterations = math.inf

        self.model.train()
        self.train_timer = Timer()
        self.snapshot_timer = Timer()

        self.profile("Setup Time")

        torch.autograd.set_detect_anomaly(True)

        self.writer.write("Starting training...")
        while self.current_iteration < self.max_iterations and not should_break:
            self.current_epoch += 1
            registry.register("current_epoch", self.current_epoch)

            # Seed the sampler in case if it is distributed
            self.task_loader.seed_sampler("train", self.current_epoch)

            if self.current_epoch > self.max_epochs:
                break

            for batch in self.train_loader:
                self.profile("Batch load time")
                self.current_iteration += 1
                self.writer.write(self.current_iteration, "debug")

                registry.register("current_iteration", self.current_iteration)

                if self.current_iteration > self.max_iterations:
                    break

                self._run_scheduler()
                report = self._forward_pass(batch)
                #pdb.set_trace()
                self._update_meter(report, self.meter)
                loss = self._extract_loss(report)
                self._backward(loss)
                should_break = self._logistics(report)

                if should_break:
                    break

        self.finalize()

    def _run_scheduler(self):
        if self.lr_scheduler is not None:
            self.lr_scheduler.step(self.current_iteration)
    
    def compute_grad_cam(self, report, model_output):
        importance_vectors = []
        scores = model_output['scores']
        classes = report['gt_answer_index']
        classes_one_hot = torch.zeros_like(scores)
        classes_one_hot[range(classes_one_hot.shape[0]), classes] = 1
        grads = torch.autograd.grad(outputs = scores, inputs = self.model.joint_embedding, grad_outputs = classes_one_hot, create_graph=True)[0].to(self.device)
        importance_vectors_cam = grads * self.model.joint_embedding
        importance_vectors.append(self.model.question_embedding)
        importance_vectors.append(importance_vectors_cam)
        importance_vectors.append(torch.cat((importance_vectors_cam, self.model.question_embedding), 1))

        return importance_vectors
    
    def store_importance_vectors(self, report, importance_vectors):
        with open("importance_scores_other_questions_2.json", "r+") as file:
            data = json.load(file)
            data_df = defaultdict(list, data)
            for idx in range(len(report['image_id'])):
                data_df[str(report['image_id'][idx].item())].append({str(report['question_id'][idx].item()):[report['question_text'][idx],importance_vectors[idx].tolist()]})
            file.seek(0)
            json.dump(dict(data_df), file)
        
    def store_importance_vectors_csv(self, report, importance_vectors):
        """Append one CSV row per sample with its answers and importance vectors.

        Rows go to ``/srv/share/sameer/pythia_results/clean_val_reas.csv``
        (opened in append mode). The predicted answer is the argmax of
        ``report['scores']`` along dim 1, mapped to a word by the
        ``"vqa_introspect_answer_processor"`` from the registry.

        Args:
            report: Mapping with per-sample ``'image_id'``, ``'image_url'``,
                ``'question_id'``, ``'reasoning_question'``,
                ``'reasoning_answer'``, ``'question_text'``, ``'answers'`` and
                the ``'scores'`` tensor.
            importance_vectors: Sequence of three per-sample vector
                collections, each item exposing ``.tolist()``.
        """
        predicted_ids = report['scores'].argmax(dim=1)
        with open("/srv/share/sameer/pythia_results/clean_val_reas.csv", "a+", newline='') as file:
            answer_processor = registry.get("vqa_introspect_answer_processor")
            # presumably ``writer`` is ``csv.writer`` — confirm against the imports
            csv_writer = writer(file)
            for idx in range(len(report['image_id'])):
                predicted_answer = answer_processor.idx2word(predicted_ids[idx])
                row = [
                    str(report['image_id'][idx].item()),
                    report['image_url'][idx],
                    report['question_id'][idx].item(),
                    report['reasoning_question'][idx],
                    report['reasoning_answer'][idx],
                    report['question_text'][idx],
                    predicted_answer,
                    report['answers'][idx][0],
                ]
                row.extend(vectors[idx].tolist() for vectors in importance_vectors[:3])
                csv_writer.writerow(row)
    
    def _forward_pass(self, batch):
        """Prepare ``batch``, run the model on it and wrap the result in a ``Report``.

        Args:
            batch: Raw batch as yielded by a data loader.

        Returns:
            ``Report`` built from the prepared batch and the model output.
        """
        model_input = self.task_loader.prepare_batch(batch)
        self.profile("Batch prepare time")

        output = self.model(model_input)
        forward_report = Report(model_input, output)
        self.profile("Forward time")
        return forward_report

    def _backward(self, loss):
        self.optimizer.zero_grad()
        loss.backward()

        if self.should_clip_gradients:
            clip_gradients(self.model, self.current_iteration, self.writer, self.config)

        self.optimizer.step()
        self.profile("Backward time")

    def _extract_loss(self, report):
        loss_dict = report.losses
        loss = sum([loss.mean() for loss in loss_dict.values()])
        return loss

    def finalize(self):
        self.writer.write("Stepping into final validation check")
        self._try_full_validation(force=True)
        self.checkpoint.restore()
        self.checkpoint.finalize()
        self.inference()

    def _update_meter(self, report, meter=None, eval_mode=False):
        """Push the losses and metrics of ``report`` into a meter.

        Args:
            report: Object exposing ``losses``, ``metrics`` and ``dataset_type``.
            meter: Meter to update; ``self.meter`` is used when ``None``.
            eval_mode: Accepted for interface compatibility; not used here.
        """
        target_meter = self.meter if meter is None else meter

        raw_losses = report.losses
        raw_metrics = report.metrics
        losses = reduce_dict(raw_losses)
        metrics = reduce_dict(raw_metrics)

        total_key = report.dataset_type + "/total_loss"

        with torch.no_grad():
            total = sum([value.mean() for value in losses.values()])
            # Tensors are converted to plain Python numbers before registering.
            if hasattr(total, "item"):
                total = total.item()

            registry.register(total_key, total)

            payload = {total_key: total}
            payload.update(losses)
            payload.update(metrics)
            target_meter.update(payload)

    def _logistics(self, report):
        should_print = self.current_iteration % self.log_interval == 0
        should_break = False
        extra = {}

        if should_print is True:
            if "cuda" in str(self.device):
                extra["max mem"] = torch.cuda.max_memory_allocated() / 1024
                extra["max mem"] //= 1024

            extra.update(
                {
                    "lr": "{:.5f}".format(self.optimizer.param_groups[0]["lr"]).rstrip(
                        "0"
                    ),
                    "time": self.train_timer.get_time_since_start(),
                    "eta": self._calculate_time_left(),
                }
            )

            self.train_timer.reset()

            _, meter = self.evaluate(self.val_loader, single_batch=True)
            self.meter.update_from_meter(meter)

        # Don't print train metrics if it is not log interval
        # so as to escape clutter
        self._summarize_report(
            self.meter,
            should_print=should_print,
            extra=extra,
            prefix=report.dataset_name,
        )

        should_break = self._try_full_validation()

        return should_break

    def _try_full_validation(self, force=False):
        should_break = False
        

        if self.current_iteration % self.snapshot_interval == 0 or force:
            self.writer.write("Evaluation time. Running on full validation set...")
            # Validation and Early stopping
            # Create a new meter for this case
            report, meter = self.evaluate(self.val_loader)

            extra = {"validation time": self.snapshot_timer.get_time_since_start()}

            stop = self.early_stopping(self.current_iteration, meter)
            stop = bool(broadcast_scalar(stop, src=0, device=self.device))

            extra.update(self.early_stopping.get_info())

            prefix = "{}: full val".format(report.dataset_name)

            self._summarize_report(meter, prefix=prefix, extra=extra)
            self.snapshot_timer.reset()
            gc.collect()

            if "cuda" in str(self.device):
                torch.cuda.empty_cache()

            if stop is True:
                self.writer.write("Early stopping activated")
                should_break = True
        
        if self.current_iteration == 22001: 
            self.writer.write("Testing time. Running on full test set...") 
            # test evaluation:
            report_test, meter_test = self.evaluate(self.test_loader)
            extra = {"test time": self.snapshot_timer.get_time_since_start()}
            prefix = "{}: full test".format(report_test.dataset_name)
            self._summarize_report(meter_test, prefix=prefix, extra=extra)

        return should_break

    def evaluate(self, loader, use_tqdm=False, single_batch=False):
        """Run ``loader`` through the model and collect the results in a fresh meter.

        NOTE(review): this neither calls ``self.model.eval()`` nor disables
        gradients; only ``self.model.train()`` is called afterwards. Confirm
        that evaluating in training mode is intentional.

        Args:
            loader: Iterable of batches.
            use_tqdm: Show a progress bar while iterating.
            single_batch: Stop after the first batch.

        Returns:
            Tuple ``(report, meter)``: the report of the last processed batch
            and a fresh ``Meter`` updated with every processed batch.
        """
        eval_meter = Meter()

        progress = tqdm(loader, disable=not use_tqdm)
        for batch in progress:
            report = self._forward_pass(batch)
            self._update_meter(report, eval_meter, eval_mode=True)

            if single_batch is True:
                break
        self.model.train()

        return report, eval_meter

    def _summarize_report(self, meter, prefix="", should_print=True, extra={}):
        """Log the scalars of ``meter`` and optionally write a one-line summary.

        Does nothing unless ``is_main_process()`` is true.

        Args:
            meter: Meter exposing ``get_scalar_dict()``, ``delimiter`` and ``str()``.
            prefix: Text placed (followed by a colon) at the start of the line.
            should_print: When false, only the scalars are logged.
            extra: Additional ``key: value`` pairs appended to the line.
        """
        if not is_main_process():
            return

        scalars = meter.get_scalar_dict()
        self.writer.add_scalars(scalars, registry.get("current_iteration"))

        if not should_print:
            return

        parts = [prefix + ":"] if len(prefix) else []
        parts.append("{}/{}".format(self.current_iteration, self.max_iterations))
        parts.append(str(meter))
        parts.extend("{}: {}".format(key, value) for key, value in extra.items())

        self.writer.write(meter.delimiter.join(parts))

    def inference(self):
        if "val" in self.run_type:
            self._inference_run("val")

        if "inference" in self.run_type or "predict" in self.run_type:
            self._inference_run("test")

    def _inference_run(self, dataset_type):
        if self.config.training_parameters.evalai_inference is True:
            self.predict_for_evalai(dataset_type)
            return

        self.writer.write("Starting inference on {} set".format(dataset_type))

        report, meter = self.evaluate(
            getattr(self, "{}_loader".format(dataset_type)), use_tqdm=True
        )
        prefix = "{}: full {}".format(report.dataset_name, dataset_type)
        self._summarize_report(meter, prefix)

    def _calculate_time_left(self):
        time_taken_for_log = time.time() * 1000 - self.train_timer.start
        iterations_left = self.max_iterations - self.current_iteration
        num_logs_left = iterations_left / self.log_interval
        time_left = num_logs_left * time_taken_for_log

        snapshot_iteration = self.snapshot_iterations / self.log_interval
        snapshot_iteration *= iterations_left / self.snapshot_interval
        time_left += snapshot_iteration * time_taken_for_log

        return self.train_timer.get_time_hhmmss(gap=time_left)

    def profile(self, text):
        """Write a debug-level timing line and restart the profiler.

        Does nothing when ``self.not_debug`` is truthy.

        Args:
            text: Label placed before the elapsed time in the message.
        """
        if not self.not_debug:
            label = text + ": "
            elapsed = self.profiler.get_time_since_start()
            self.writer.write(label + elapsed, "debug")
            # Restart so the next call measures from this point.
            self.profiler.reset()

    def predict_for_evalai(self, dataset_type):
        """Run the model over every dataset of a test reporter, without gradients.

        The reporter is obtained from
        ``self.task_loader.get_test_reporter(dataset_type)``. For each batch
        of each dataset it exposes, the prepared batch and the model output
        are wrapped in a ``Report`` and passed to the reporter. The model is
        put in eval mode first and switched back to train mode at the end.

        Args:
            dataset_type: Split name forwarded to the task loader and used in
                the start message.
        """
        reporter = self.task_loader.get_test_reporter(dataset_type)

        with torch.no_grad():
            self.model.eval()
            self.writer.write(
                "Starting {} inference for evalai".format(dataset_type)
            )

            # next_dataset() drives the iteration: loop until it returns a
            # falsy value.
            while reporter.next_dataset():
                for raw_batch in tqdm(reporter.get_dataloader()):
                    model_input = reporter.prepare_batch(raw_batch)
                    predictions = self.model(model_input)
                    reporter.add_to_report(Report(model_input, predictions))

            self.writer.write("Finished predicting")
            self.model.train()