コード例 #1
0
    def __init__(self, loss_config: AttrDict):
        """
        Set up the loss module from its configuration.

        The config is stored as-is and an empty module container is
        created for the per-output losses.
        NOTE(review): presumably a single tensor gets one loss and a list
        of tensors gets the sum of the per-tensor losses against the same
        target -- confirm against the forward pass, which is not visible
        from this initializer.

        Config params:
            reduction: reduction to apply to the output. Optional,
                defaults to "none".
            normalize_output: whether to L2 normalize the outputs.
                Optional, defaults to False.
            world_size: total number of gpus in training (automatically
                inferred by vissl). Required; read by direct indexing.
        """
        super(BCELogitsMultipleOutputSingleTargetLoss, self).__init__()

        # Keep the raw config alongside the values unpacked from it below.
        self.loss_config = loss_config
        # Starts empty; nothing is registered in this initializer.
        self._losses = torch.nn.ModuleList()

        # Optional settings fall back to their defaults when absent.
        self._reduction = loss_config.get("reduction", "none")
        self._normalize_output = loss_config.get("normalize_output", False)

        # Mandatory setting: no default is supplied for it.
        self._world_size = loss_config["world_size"]
コード例 #2
0
ファイル: fastmri_dataset.py プロジェクト: sbelenki/vissl
 def __init__(
     self,
     cfg: AttrDict,
     path: str,
     split: str,
     dataset_name="fastmri_dataset",
     data_source="fastmri",
 ):
     """
     Set up the dataset rooted at ``path`` and load its data.

     Args:
         cfg: configuration; its optional "DATA" section may provide
             "KEY" (default "reconstruction_esc") and "INDEX"
             (default 12).
         path: directory holding the data. Must exist.
         split: split name; stored lower-cased.
         dataset_name: accepted but not used (see the note below).
         data_source: accepted but not used (see the note below).

     Raises:
         AssertionError: if ``path`` is not an existing directory.
     """
     super(FastMRIDataSet, self).__init__()

     # Fail early when the data directory is missing.
     assert PathManager.isdir(path), f"Directory {path} does not exist"

     self.path = path
     self.split = split.lower()

     # NOTE(review): the dataset_name / data_source arguments are ignored
     # here and fixed values are stored instead -- confirm this is intended.
     self.dataset_name = "singlecoil"
     self.data_source = "fastmri"

     # A missing "DATA" section is treated as empty, so defaults apply.
     data_cfg = cfg.get("DATA", AttrDict({}))
     self.key = data_cfg.get("KEY", "reconstruction_esc")
     self.index = data_cfg.get("INDEX", 12)

     # Load last, once every attribute above has been set.
     self.dataset = self._load_data()
コード例 #3
0
ファイル: mean_ap_meter.py プロジェクト: zlapp/vissl
 def __init__(self, meters_config: AttrDict):
     """
     Create the meter from its configuration.

     Args:
         meters_config: configuration; "num_classes" is read from it and
             is None when the key is absent.
     """
     # Declare both counters before handing over to reset().
     self._total_sample_count = self._curr_sample_count = None
     self.num_classes = meters_config.get("num_classes")
     self.reset()