示例#1
0
    def run_epoch(self, dataloader, phase: str, epoch: int):
        """
        Run a single epoch in training or validation.

        (This function is called in fit and it is NOT RECOMMENDED to use this function from outside.)

        The weights are only updated when phase is "train". For every other phase the forward pass and the
        loss are still computed and reported to the callbacks, but no backward pass or optimizer step is done.

        :param dataloader: The dataloader created from a dataset. It must yield (network_inputs, targets) pairs.
        :param phase: The phase (train/dev/test) which is used for running.
        :param epoch: The epoch number.
        :return: Nothing. Results are reported through the callbacks and the tensorboard logger.
        :raises RuntimeError: If no model has been set on the trainer.
        :raises ValueError: If the network output or the loss contains NaN.
        """
        if self.model is None:
            raise RuntimeError("You must compile the trainer first!")

        # torch.has_cuda is a deprecated build-time flag; ask the runtime whether CUDA is usable instead.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        for callback in self.callbacks:
            callback.on_epoch_begin(dataloader, phase, epoch)

        # Loop over the dataset_class and update weights.
        for step, (network_inputs, targets) in enumerate(dataloader):
            for callback in self.callbacks:
                callback.on_iter_begin(step, network_inputs, targets)

            # Forward pass, computing gradients and applying them.
            # Gradients are reset in every phase, not only during training.
            self.optimizer.zero_grad()
            network_output = self.model(*network_inputs)

            # Fail fast on NaN outputs. Anything that is not a plain Tensor must provide _asdict()
            # (e.g. a namedtuple) so that every field can be checked by name.
            if isinstance(network_output, Tensor):
                if network_output.isnan().any():
                    print()
                    error("NaN NetworkOutput: {}".format(network_output))
                    raise ValueError("NetworkOutput got nan.")
            else:
                for name, p in network_output._asdict().items():
                    if p.isnan().any():
                        print()
                        error("NaN NetworkOutput {}: {}".format(name, p))
                        raise ValueError("NetworkOutput {} got nan.".format(name))
            loss_result = self.loss(y_true=targets, y_pred=network_output)
            tensorboard.log_scalar("loss/total", loss_result)

            if loss_result == 0:
                # A loss of exactly zero is only reported; the weight update is skipped for this step.
                print()
                warn("Loss is exactly 0, is this a bug?")
            else:
                if loss_result.isnan().any():
                    error("NaN Loss")
                    raise ValueError("Loss got nan.")

                if phase == "train":
                    loss_result.backward()
                    self.optimizer.step()

            for callback in self.callbacks:
                callback.on_iter_end(network_output, loss_result)

        for callback in self.callbacks:
            callback.on_epoch_end()
    def on_fit_start(self, model, train_dataloader, dev_dataloader, loss,
                     optimizer, start_epoch: int, epochs: int) -> int:
        """
        Prepare the summary writers before the fit starts.

        Calls the parent implementation, creates the checkpoint structure, opens one SummaryWriter each for
        the "train" and "val" sub-folders of the log path and tries to log the model graph to both.
        Failing to log the graph is not fatal, it only results in a warning.

        :param model: The model that is going to be trained.
        :param train_dataloader: The dataloader for training; its first batch is used to trace the graph.
        :param dev_dataloader: The dataloader for validation; its first batch is used to trace the graph.
        :param loss: The loss, forwarded to the parent implementation.
        :param optimizer: The optimizer, forwarded to the parent implementation.
        :param start_epoch: The epoch at which the fit starts.
        :param epochs: The number of epochs, forwarded to the parent implementation.
        :return: The start epoch as returned by the parent implementation.
        :raises RuntimeError: If no log path has been set up.
        """
        start_epoch = super().on_fit_start(model, train_dataloader,
                                           dev_dataloader, loss, optimizer,
                                           start_epoch, epochs)
        log_path = get_log_path()
        if log_path is None:
            raise RuntimeError(
                "You must setup logger before calling the fit method. See babilim.core.logging.set_logger"
            )
        create_checkpoint_structure()

        self.train_summary_writer = SummaryWriter(
            os.path.join(log_path, "train"))
        self.train_summary_txt = os.path.join(log_path, "train", "log.txt")
        self.dev_summary_writer = SummaryWriter(os.path.join(log_path, "val"))
        self.dev_summary_txt = os.path.join(log_path, "val", "log.txt")

        try:
            self.train_summary_writer.add_graph(model,
                                                input_to_model=next(
                                                    iter(train_dataloader))[0])
            self.dev_summary_writer.add_graph(model,
                                              input_to_model=next(
                                                  iter(dev_dataloader))[0])
        except Exception:
            # Graph logging is best effort. Only regular errors are swallowed here, so that
            # KeyboardInterrupt and SystemExit still propagate to the caller.
            warn(
                "Cannot log model. Does it use **kwargs instead of *args somewhere?"
            )
        return start_epoch
示例#3
0
def set_main_config(config):
    """
    Register the configuration object that inject_kwargs reads from.

    A warning is emitted when a main config has already been registered, since replacing it might cause bugs.

    :param config: A configuration object. It must have the parameters as instance attributes.
    """
    global _config
    is_overwrite = _config is not None
    if is_overwrite:
        warn("You are overwriting the main config. This might cause bugs!")
    _config = config
示例#4
0
def load_weights(checkpoint_path: str, model):
    """
    Restore the weights of a model from a checkpoint file.

    If the checkpoint has no "model" entry, a warning is logged and the model is left untouched.

    :param checkpoint_path: The path to the file in which the checkpoint is stored.
    :param model: The model for which to set the state from the checkpoint.
    """
    state = load_state(checkpoint_path)

    # Nothing to restore without a model entry.
    if "model" not in state:
        logging.warn("Could not find model_state in checkpoint.")
        return

    if logging.DEBUG_VERBOSITY:
        logging.info("Load Model...")
    model.load_state_dict(state["model"])
示例#5
0
    def restore(self, state_dict_path):
        """
        Restore the epoch, model, optimizer and loss from a checkpoint file.

        The epoch is set to the one after the epoch stored in the checkpoint.
        Every component that is missing in the checkpoint is skipped with a warning.

        :param state_dict_path: The path to the file in which the checkpoint is stored.
        """
        logging.info("Loading checkpoint: {}".format(state_dict_path))
        checkpoint = load_state(state_dict_path)
        self.epoch = checkpoint["epoch"] + 1

        # (checkpoint key and attribute name, message when loading, message when missing)
        components = (
            ("model", "Load Model...",
             "Could not find model_state in checkpoint."),
            ("optimizer", "Load Optimizer...",
             "Could not find optimizer_state in checkpoint."),
            ("loss", "Load Loss...",
             "Could not find loss_state in checkpoint."),
        )
        for key, load_message, missing_message in components:
            if key not in checkpoint:
                logging.warn(missing_message)
                continue
            if logging.DEBUG_VERBOSITY:
                logging.info(load_message)
            # The attribute is only looked up when its state is actually present in the checkpoint.
            getattr(self, key).load_state_dict(checkpoint[key])

        if logging.DEBUG_VERBOSITY:
            logging.info("Trainable Variables:")
            # TODO
            logging.info("Untrainable Variables:")
示例#6
0
 def __init__(self, args, returns, layers, spec, _local_variables):
     """
     Build the module from a list of layer descriptions.

     Each layer description is copied, string values starting with "spec:" are replaced by the matching
     entry of spec, disabled layers are skipped and every remaining layer is created via Module.create
     and registered as a submodule.

     :param args: The names of the arguments of the module.
     :param returns: The names of the return values. If there is more than one and all of them are strings, a namedtuple return type is created.
     :param layers: The layer descriptions. Each must have a "type" and may have a "name" and a "disabled" entry; all other entries are passed to Module.create.
     :param spec: The values that "spec:" references are resolved against.
     :param _local_variables: Passed through to Module.create and stored on the instance.
     """
     super().__init__()
     self._args = args
     self._returns = returns

     # A namedtuple output needs more than one return value and only string names.
     only_str_names = all(isinstance(ret_name, str) for ret_name in returns)
     if only_str_names and len(returns) > 1:
         self._return_type = namedtuple("ReturnType", self._returns)
     else:
         self._return_type = None

     self.submodules = []
     self._inputs = []
     self._outputs = []
     self._spec = spec
     for idx, layer_description in enumerate(layers):
         # Work on a copy so that the description of the caller is never modified.
         layer_kwargs = layer_description.copy()
         for key, val in layer_kwargs.items():
             if isinstance(val, str) and val.startswith("spec:"):
                 layer_kwargs[key] = self._spec[val.replace("spec:", "")]

         if layer_kwargs.get("disabled"):
             if WARN_DISABLED_LAYERS:
                 warn("Disabled Layer: {}".format(layer_kwargs))
             continue

         # "type" and "name" are consumed here, everything else is forwarded to the layer.
         layer_type = layer_kwargs.pop("type")
         name = layer_kwargs.pop("name", "layer_{}".format(idx))
         module = Module.create(layer_type,
                                _local_variables=_local_variables,
                                **layer_kwargs)
         self.submodules.append(module)
         self.add_module(name, module)
     self._local_variables = _local_variables
示例#7
0
 def on_fit_interruted(self, exception) -> None:
     """
     Forward the interruption to the parent implementation and warn about it.

     :param exception: The exception that interrupted the fit.
     """
     super().on_fit_interruted(exception)
     interruption_message = "Fit interrupted by user!"
     warn(interruption_message)
示例#8
0
    def __init__(self,
                 split,
                 DatasetInput=InputType,
                 DatasetOutput=OutputType,
                 data_version=None,
                 data_path="None",
                 model_categories=[],
                 data_image_size=None) -> None:
        """
        Index the images and instance annotations of one dataset split.

        The images are read from <data_path>/images/<split><data_version> and the annotations from
        <data_path>/annotations/instances_<split><data_version>.json (this looks like the COCO layout).
        Only images that have an entry in the annotation file are added to the sample tokens.

        :param split: The name of the split; it is part of the folder and the annotation file name.
        :param DatasetInput: The input type, forwarded to the parent constructor.
        :param DatasetOutput: The output type, forwarded to the parent constructor.
        :param data_version: The version of the dataset (e.g. 2014, 2017).
        :param data_path: The root folder of the dataset.
        :param model_categories: The category names the model knows; the position in the list is the class id and index 0 is the background. If empty, "background" plus all categories of the annotation file are used. Categories of the annotation file that are not in this list are ignored. The default list is never modified.
        :param data_image_size: Stored on the instance as data_image_size.
        """
        super().__init__(split, DatasetInput, DatasetOutput)
        version = data_version  # 2014, 2017
        self.data_image_size = data_image_size
        self.image_folder = os.path.join(data_path, "images",
                                         f"{split}{version}")
        all_sample_tokens = os.listdir(self.image_folder)

        annotation_file = os.path.join(data_path, "annotations",
                                       f"instances_{split}{version}.json")
        with open(annotation_file, "r") as f:
            instances = json.load(f)

        self.class_id_to_category = {0: 0}  # Background
        category_id_to_class_id = {0: 0}  # Background
        if not model_categories:
            # No categories requested: use all of the dataset, with the background at index 0.
            model_categories = ["background"] + [
                category["name"] for category in instances["categories"]
            ]
        for category in instances["categories"]:
            # list.index raises ValueError for unknown names, so categories the model
            # does not know must be skipped before looking up their class id.
            if category["name"] not in model_categories:
                continue
            idx = model_categories.index(category["name"])
            if idx > 0:
                # category is e.g. {'supercategory': 'indoor', 'id': 88, 'name': 'teddy bear'}
                self.class_id_to_category[idx] = category
                category_id_to_class_id[category["id"]] = idx
        get_main_config().model_categories = model_categories

        images = {image["id"]: image["file_name"] for image in instances["images"]}

        # Group the annotations by image file name. An image keeps its (possibly empty) entry even if
        # all of its annotations belong to categories that the model does not know.
        # An annotation is e.g. {'segmentation': [[312.29, 562.89, ...]], 'area': 54652.9556, 'iscrowd': 0,
        # 'image_id': 480023, 'bbox': [116.95, 305.86, 285.3, 266.03], 'category_id': 58, 'id': 86}
        self.annotations = {}
        for anno in instances["annotations"]:
            filename = images[anno["image_id"]]
            annotations_of_image = self.annotations.setdefault(filename, [])
            if anno["category_id"] in category_id_to_class_id:
                # Replace the category id of the dataset by the class id of the model.
                anno["category_id"] = category_id_to_class_id[
                    anno["category_id"]]
                annotations_of_image.append(anno)

        # NOTE(review): self.all_sample_tokens is presumably created by the parent constructor - confirm.
        self.images_with_no_anno = []
        for sample_token in all_sample_tokens:
            if sample_token in self.annotations:
                self.all_sample_tokens.append(sample_token)
            else:
                self.images_with_no_anno.append(sample_token)
        if self.images_with_no_anno:
            warn("Images with no anno: {} (split={})".format(
                len(self.images_with_no_anno), split))
        self.cheap_cache_sample_token = ""
        self.cheap_cache_image = None