Example #1
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('-modelType',
                        default=4,
                        type=int,
                        help='Refer train_utils.py ')
    parser.add_argument('-numSpkrs',
                        default=7323,
                        type=int,
                        help='Number of output labels for model')
    parser.add_argument('modelDirectory',
                        help='Directory containing the model checkpoints')
    parser.add_argument(
        'featDir', help='Directory containing features ready for extraction')
    parser.add_argument('embeddingDir', help='Output directory')
    args = parser.parse_args()

    modelFile = max(glob.glob(args.modelDirectory + '/*'),
                    key=os.path.getctime)
    # Load model definition
    if args.modelType == 3:
        net = simpleTDNN(args.numSpkrs, p_dropout=0)
    else:
        net = xvecTDNN(args.numSpkrs, p_dropout=0)

    checkpoint = torch.load(modelFile, map_location=torch.device('cuda'))
    new_state_dict = OrderedDict()
    for k, v in checkpoint['model_state_dict'].items():
        if k.startswith('module.'):
            new_state_dict[k[7:]] = v  # ugly fix to remove 'module' from key
        else:
            new_state_dict[k] = v

    # load trained weights
    net.load_state_dict(new_state_dict)
    net = net.cuda()
    net.eval()

    # Parallel Processing
    try:
        nSplits = int(
            sorted(glob.glob(args.featDir + '/split*'),
                   key=getSplitNum)[-1].split('/')[-1].lstrip('split'))
    except (IndexError, ValueError):
        print('Cannot find %s/splitN directory' % args.featDir)
        sys.exit(1)

    if not os.path.isdir(args.embeddingDir):
        os.makedirs(args.embeddingDir)
    nProcs = nSplits
    L = [('%s/split%d/%d/feats.scp' % (args.featDir, nSplits, i),
          '%s/xvector.%d.ark' % (args.embeddingDir, i),
          '%s/xvector.%d.scp' % (args.embeddingDir, i), net, 'fc1')
         for i in range(1, nSplits + 1)]
    pool2 = Pool(processes=nProcs)
    result = pool2.starmap(par_core_extractXvectors, L)
    pool2.terminate()

    os.system('cat %s/xvector.*.scp > %s/xvector.scp' %
              (args.embeddingDir, args.embeddingDir))
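
The core of this example is the Pool.starmap call: one argument tuple is built per
feature split, and starmap unpacks each tuple into the worker's positional
parameters. A minimal, self-contained sketch of that pattern, with a hypothetical
stand-in for par_core_extractXvectors:

from multiprocessing import Pool

def extract_stub(feats_scp, out_ark, out_scp, net, layer_name):
    # hypothetical stand-in: a real worker would read feats_scp, run net,
    # and write the embeddings to out_ark / out_scp
    return (feats_scp, out_ark, out_scp, layer_name)

if __name__ == '__main__':
    jobs = [('split%d/feats.scp' % i, 'xvector.%d.ark' % i,
             'xvector.%d.scp' % i, None, 'fc1') for i in range(1, 4)]
    with Pool(processes=len(jobs)) as pool:
        results = pool.starmap(extract_stub, jobs)
    print(results)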
Example #2
def meta_ars(env_name,
             policy,
             meta_epochs,
             meta_seed,
             n_seeds=4,
             n_top_seeds=1,
             n_workers=4,
             mean_lookback=10,
             ars_epochs=10,
             env_config=None,
             step_size=.02,
             n_delta=32,
             n_top=16,
             exp_noise=0.03):

    n_children = n_seeds // n_top_seeds
    np.random.seed(meta_seed)

    W = torch.nn.utils.parameters_to_vector(policy.parameters())
    W = torch.zeros_like(W)
    torch.nn.utils.vector_to_parameters(W, policy.parameters())

    pool = Pool(processes=n_seeds)
    ars_partial = partial(ars, env_name, ars_epochs, env_config, step_size,
                          n_delta, n_top, exp_noise, n_workers)
    #root = Node(meta_seed)
    reward_log = []

    top_policies = []
    for _ in range(n_top_seeds):
        top_policies.append(copy.deepcopy(policy))

    for epoch in range(meta_epochs):
        pols_and_seeds = []
        for pol in top_policies:
            for _ in range(n_children):
                pols_and_seeds.append(
                    (pol, int(np.random.randint(0, 2**32 - 1, 1))))

        results = pool.starmap(ars_partial, pols_and_seeds)

        p_list = []
        r_list = []
        for result in results:
            policy, rews = result
            p_list.append(policy)
            r = torch.stack(rews[-mean_lookback:])
            r_list.append(r.mean())

        top_idx = sorted(range(len(r_list)),
                         key=lambda k: r_list[k],
                         reverse=True)[:n_top_seeds]
        for i in top_idx:
            top_policies.append(p_list[i])

        reward_log.append(max(r_list))

    return top_policies, reward_log
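
meta_ars binds the fixed ARS hyperparameters with functools.partial and lets
starmap supply only the varying (policy, seed) pairs. A minimal sketch of that
binding pattern; train_stub and its arguments are placeholders, not the real ars
function:

from functools import partial
from multiprocessing import Pool

def train_stub(env_name, epochs, policy, seed):
    # hypothetical stand-in for ars: the first two arguments are bound by partial,
    # the last two come from each starmap tuple
    return '%s policy=%s seed=%d epochs=%d' % (env_name, policy, seed, epochs)

if __name__ == '__main__':
    bound = partial(train_stub, 'SomeEnv-v0', 10)
    with Pool(processes=4) as pool:
        out = pool.starmap(bound, [('pi_a', 1), ('pi_a', 2), ('pi_b', 3), ('pi_b', 4)])
    print(out)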
Example #3
class Predictor:
    def __init__(self, cfg_path: str, num_workers: int = 4) -> None:
        cfg = get_cfg()
        cfg.merge_from_file(model_zoo.get_config_file(cfg_path))
        # NOTE: you may customize cfg settings
        # cfg.MODEL.DEVICE="cuda" # use gpu by default
        cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
        # Find a model from detectron2's model zoo. You can use the https://dl.fbaipublicfiles... url as well
        # you can also give a path to your checkpoint
        cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(cfg_path)

        self.cfg = cfg.clone()
        self.model = build_model(cfg)
        self.model.eval()
        checkpointer = DetectionCheckpointer(self.model)
        checkpointer.load(cfg.MODEL.WEIGHTS)

        self.aug = T.ResizeShortestEdge(
            [cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MIN_SIZE_TEST],
            cfg.INPUT.MAX_SIZE_TEST)
        self.pool = Pool(num_workers)

    @staticmethod
    def url_to_img(aug: object, device: str, url: str) -> Dict:
        img_response = requests.get(url)
        img = np.array(Image.open(BytesIO(img_response.content)))
        height, width = img.shape[:2]
        img = aug.get_transform(img).apply_image(img)
        img = torch.as_tensor(img.astype("float32").transpose(2, 0, 1))
        if device == "cuda":
            img = img.pin_memory()
            img = img.cuda(non_blocking=True)
        return {"image": img, "height": height, "width": width}

    def load_inputs(self, urls: List[str]) -> List[Dict]:
        # return [self.url_to_img(url) for url in urls]
        n = len(urls)
        return list(
            self.pool.starmap(
                self.url_to_img,
                zip([self.aug] * n, [self.cfg.MODEL.DEVICE] * n, urls)))

    def __call__(self, items: Dict) -> List:
        urls = [item["url"] for item in items]
        inputs = self.load_inputs(urls)
        with torch.no_grad():
            predictions = self.model(inputs)
            return predictions
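
A hypothetical usage sketch of the Predictor above. It assumes detectron2 is
installed; the config path is one of the model-zoo configs, and the image URLs
are placeholders:

if __name__ == "__main__":
    predictor = Predictor("COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml", num_workers=2)
    items = [{"url": "https://example.com/cat.jpg"},
             {"url": "https://example.com/dog.jpg"}]
    print(predictor(items))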
Example #4
    def create_episodes(
        self,
        n_episodes: int,
        n_processes: int,
        mcts_samples: int,
        mcts_temp: float,
        mcts_cpuct: int,
        mcts_observation_weight: float,
        model: Model,
    ) -> List[Tuple[List[ObservationType], List[np.ndarray], int, Summary]]:
        pool = Pool(n_processes)
        res = pool.starmap(
            self._generator.perform_episode,
            [[mcts_samples, mcts_temp, mcts_cpuct, mcts_observation_weight, model]]
            * n_episodes,
        )
        pool.close()
        pool.terminate()
        pool.join()
        return res
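
Because starmap blocks until every episode result has been collected, the
close()/terminate()/join() sequence can also be folded into a with block. A
self-contained sketch of the same repeat-identical-arguments pattern;
episode_stub is a hypothetical stand-in for perform_episode:

from multiprocessing import Pool

def episode_stub(samples, temp, cpuct, obs_weight, model):
    # hypothetical stand-in for the generator's perform_episode
    return (samples, temp, cpuct, obs_weight, model is None)

if __name__ == '__main__':
    n_episodes = 8
    args = [[100, 1.0, 4, 0.5, None]] * n_episodes  # same argument list for every episode
    with Pool(4) as pool:  # __exit__ terminates the pool after results are collected
        res = pool.starmap(episode_stub, args)
    print(len(res))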
Example #5
def distribute_action_among_workers(
    pool: mp.Pool,
    num_workers: int,
    config: ConfigSchema,
    action: Action,
    model: MultiRelationEmbedder,
    epoch_idx: int,
    lhs: EntityList,
    rhs: EntityList,
    rel: LongTensorType,
    edge_perm: LongTensorType,
    optimizers: Optional[List[Optimizer]] = None
) -> Union[TrainStats, EvalStats]:
    all_stats = pool.starmap(perform_action_one_thread, [
        (Rank(i), config, action, model, epoch_idx, lhs, rhs, rel, edge_perm[s], optimizers)
        for i, s in enumerate(split_almost_equally(edge_perm.size(0), num_parts=num_workers))
    ])

    if action is action.TRAIN:
        return TrainStats.sum(all_stats).average()
    elif action is action.EVAL:
        return EvalStats.sum(all_stats).average()
    else:
        raise NotImplementedError("Unknown action: %s" % action)
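
The pattern here is to slice one permutation into nearly equal parts and hand
each worker its rank plus its own slice. A self-contained sketch; work_on_slice
and np.array_split stand in for perform_action_one_thread and
split_almost_equally:

import numpy as np
from multiprocessing import Pool

def work_on_slice(rank, perm_slice):
    # hypothetical stand-in: each worker processes only its own slice of edges
    return (rank, int(perm_slice.sum()))

if __name__ == '__main__':
    num_workers = 4
    edge_perm = np.random.permutation(10)
    slices = np.array_split(edge_perm, num_workers)  # rough analogue of split_almost_equally
    with Pool(num_workers) as pool:
        stats = pool.starmap(work_on_slice, [(i, s) for i, s in enumerate(slices)])
    print(stats)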
Example #6
    def start_multiprocessing(self, embeddings, naming_list, naming_dict, dataset):
        n_threads = 4
        chunked = chunkIt(naming_list, n_threads)
        # multiprocess the grid search and run a separate process for the progress bar.
        pool1 = Pool(processes=n_threads)
        m = Manager()
        q = m.Queue()
        p = Process(target=progressBar, args=(len(naming_list), q,))
        p.start()

        results = pool1.starmap(self.determine_triplets,
                                zip(n_threads * [q],
                                    chunked,
                                    n_threads * [naming_list],
                                    n_threads * [embeddings],
                                    n_threads * [naming_dict],
                                    n_threads * [dataset]))
        final_results = []
        for r in results:
            final_results += r

        p.join()
        pool1.close()
        return final_results
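
The Manager queue is what lets the pool workers report progress to a separate
process: a plain multiprocessing.Queue cannot be passed through Pool, but a
managed queue proxy can. A self-contained sketch of that progress-bar pattern
with hypothetical stand-ins for progressBar and determine_triplets:

from multiprocessing import Pool, Process, Manager

def progress_stub(total, q):
    # hypothetical stand-in for progressBar: drain the queue until all ticks arrive
    done = 0
    while done < total:
        q.get()
        done += 1
        print('progress: %d/%d' % (done, total))

def worker_stub(q, chunk):
    # hypothetical stand-in for determine_triplets: report one tick per item
    out = []
    for item in chunk:
        out.append(item * 2)
        q.put(1)
    return out

if __name__ == '__main__':
    items = list(range(8))
    chunks = [items[:4], items[4:]]
    m = Manager()
    q = m.Queue()
    p = Process(target=progress_stub, args=(len(items), q))
    p.start()
    with Pool(processes=2) as pool:
        results = pool.starmap(worker_stub, [(q, c) for c in chunks])
    p.join()
    print(sum(results, []))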
Example #7
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('-modelType',
                        default='xvecTDNN',
                        help='Refer train_utils.py ')
    parser.add_argument('-numSpkrs',
                        default=7323,
                        type=int,
                        help='Number of output labels for model')
    parser.add_argument('-layerName',
                        default='fc1',
                        help="DNN layer for embeddings")
    parser.add_argument('modelDirectory',
                        help='Directory containing the model checkpoints')
    parser.add_argument(
        'featDir', help='Directory containing features ready for extraction')
    parser.add_argument('embeddingDir', help='Output directory')
    args = parser.parse_args()

    try:
        modelFile = max(glob.glob(args.modelDirectory + '/*'),
                        key=os.path.getctime)
    except ValueError:
        print("[ERROR] No trained model has been found in {}.".format(
            args.modelDirectory))
        sys.exit(1)

    # Load model definition
    net = eval('{}({}, p_dropout=0)'.format(args.modelType, args.numSpkrs))

    checkpoint = torch.load(modelFile, map_location=torch.device('cuda'))
    new_state_dict = OrderedDict()
    if 'relation' in args.modelType:
        checkpoint_dict = checkpoint['encoder_state_dict']
    else:
        checkpoint_dict = checkpoint['model_state_dict']
    for k, v in checkpoint_dict.items():
        if k.startswith('module.'):
            new_state_dict[k[7:]] = v  # ugly fix to remove 'module' from key
        else:
            new_state_dict[k] = v

    # load trained weights
    net.load_state_dict(new_state_dict)
    net = net.cuda()
    net.eval()

    # Parallel Processing
    try:
        nSplits = int(
            sorted(glob.glob(args.featDir + '/split*'),
                   key=getSplitNum)[-1].split('/')[-1].lstrip('split'))
    except (IndexError, ValueError):
        print('[ERROR] Cannot find %s/splitN directory' % args.featDir)
        sys.exit(1)

    if not os.path.isdir(args.embeddingDir):
        os.makedirs(args.embeddingDir)

    print('Extracting xvectors by distributing jobs to pool workers... ')
    nProcs = nSplits
    L = [('%s/split%d/%d/feats.scp' % (args.featDir, nSplits, i),
          '%s/xvector.%d.ark' % (args.embeddingDir, i),
          '%s/xvector.%d.scp' % (args.embeddingDir, i), net, args.layerName)
         for i in range(1, nSplits + 1)]
    pool2 = Pool(processes=nProcs)
    result = pool2.starmap(par_core_extractXvectors, L)
    pool2.terminate()
    print('Multiprocessing jobs have finished.')

    print('Writing xvectors to {}'.format(args.embeddingDir))
    os.system('cat %s/xvector.*.scp > %s/xvector.scp' %
              (args.embeddingDir, args.embeddingDir))
Example #8
    def eval_(self, X, *args):
        """
        Evaluate a number of DARTS architectures in parallel. X should be a list of Genotypes defined by the DARTS API.
        """
        from math import ceil
        n_parallel = min(len(X), self.n_gpu)
        res = []
        diag_stats = []

        if n_parallel == 0:
            raise ValueError("No GPUs available!")
        elif n_parallel == 1:
            for i, genotype in enumerate(X):
                t = DARTSTrainer(self.data_path,
                                 self.save_path,
                                 genotype,
                                 self.dataset,
                                 cutout=self.cutout,
                                 auxiliary_tower=self.auxiliary,
                                 epochs=self.epochs,
                                 eval_policy=self.query_policy)
                print('Start training: ', i + 1, "/ ", len(X))
                try:
                    t.train()  # bottleneck
                    result = t.retrieve()
                    res.append(1. - result[0] / 100.)  # Turn into error
                    diag_stats.append(result[1])
                except Exception as e:
                    logging.error(
                        "An error occurred in the current architecture. Assigning nan to the arch. The error is:"
                    )
                    logging.error(e)
                    res.append(np.nan)
                    diag_stats.append(None)

        else:
            gpu_ids = range(n_parallel)
            num_reps = ceil(len(X) / float(n_parallel))
            for i in range(num_reps):
                x = X[i * n_parallel:min((i + 1) * n_parallel, len(
                    X))]  # select the number of parallel archs to evaluate
                selected_gpus = gpu_ids[:len(x)]
                other_arg = [
                    self.data_path, self.save_path, self.dataset, self.cutout,
                    self.epochs, self.query_policy
                ]
                args = list(map(
                    list,
                    zip(
                        x,
                        selected_gpus,
                    ),
                ))
                args = [a + other_arg for a in args]
                pool = Pool(processes=len(x))
                current_res = pool.starmap(parallel_eval, args)
                pool.close()
                pool.join()
                res.extend([i for i in current_res if i >= 0
                            ])  # Filter out the negative results due to errors
        res = np.array(res).flatten()
        if self.log_scale:
            res = np.log(res)
        if self.negative:
            res = -res
        return res, diag_stats
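
When there are more architectures than GPUs, the loop above dispatches them in
GPU-sized batches and creates a fresh pool per batch. A self-contained sketch of
that batched dispatch; eval_stub is a hypothetical stand-in for parallel_eval:

from math import ceil
from multiprocessing import Pool

def eval_stub(arch, gpu_id, epochs):
    # hypothetical stand-in for parallel_eval: one architecture per GPU slot
    return '%s on gpu %d for %d epochs' % (arch, gpu_id, epochs)

if __name__ == '__main__':
    X = ['arch%d' % i for i in range(5)]
    n_parallel = 2  # pretend two GPUs are available
    results = []
    for i in range(ceil(len(X) / n_parallel)):
        batch = X[i * n_parallel:(i + 1) * n_parallel]
        args = [(arch, gpu, 10) for gpu, arch in enumerate(batch)]
        with Pool(processes=len(batch)) as pool:
            results.extend(pool.starmap(eval_stub, args))
    print(results)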
Example #9
                # print(probabilities)
                break
    return (features, payoffs, probabilities)


game = GameState(4)
features = []
payoffs = []
probabilities = []
policy = None
step = 128
milestone = step
mcts_pool = Pool(32)
for x in range(100):
    results: List[Tuple[Any, Any,
                        Any]] = mcts_pool.starmap(one_mcts,
                                                  [(game, policy)] * 64)
    # f, pa, pr = one_mcts(game, policy)
    for f, pa, pr in results:
        features.extend(f)
        payoffs.extend(pa)
        probabilities.extend(pr)

    while len(features) >= milestone:
        milestone += step
        if policy is None:
            policy = Policy(game.feature_dim(), game.action_dim(),
                            game.num_players)
            policy.eval()
        policy.train()
        policy.fit(torch.stack(features), torch.stack(payoffs),
                   torch.stack(probabilities))
Example #10
class MetricTester:
    """Class used for efficiently run alot of parametrized tests in ddp mode.
    Makes sure that ddp is only setup once and that pool of processes are
    used for all tests.

    All tests should subclass from this and implement a new method called
        `test_metric_name`
    where the method `self.run_metric_test` is called inside.
    """

    atol = 1e-8

    def setup_class(self):
        """Setup the metric class. This will spawn the pool of workers that are
        used for metric testing and setup_ddp
        """

        self.poolSize = NUM_PROCESSES
        self.pool = Pool(processes=self.poolSize)
        self.pool.starmap(setup_ddp, [(rank, self.poolSize)
                                      for rank in range(self.poolSize)])

    def teardown_class(self):
        """ Close pool of workers """
        self.pool.close()
        self.pool.join()

    def run_functional_metric_test(
        self,
        preds: Tensor,
        target: Tensor,
        metric_functional: Callable,
        sk_metric: Callable,
        metric_args: dict = None,
        **kwargs_update,
    ):
        """Main method that should be used for testing functions. Call this inside
        testing method

        Args:
            preds: torch tensor with predictions
            target: torch tensor with targets
            metric_functional: lightning metric class that should be tested
            sk_metric: callable function that is used for comparison
            metric_args: dict with additional arguments used for class initialization
            kwargs_update: Additional keyword arguments that will be passed with preds and
                target when running update on the metric.
        """
        _functional_test(
            preds=preds,
            target=target,
            metric_functional=metric_functional,
            sk_metric=sk_metric,
            metric_args=metric_args,
            atol=self.atol,
            **kwargs_update,
        )

    def run_class_metric_test(
        self,
        ddp: bool,
        preds: Tensor,
        target: Tensor,
        metric_class: Metric,
        sk_metric: Callable,
        dist_sync_on_step: bool,
        metric_args: dict = None,
        check_dist_sync_on_step: bool = True,
        check_batch: bool = True,
        **kwargs_update,
    ):
        """Main method that should be used for testing class. Call this inside testing
        methods.

        Args:
            ddp: bool, if running in ddp mode or not
            preds: torch tensor with predictions
            target: torch tensor with targets
            metric_class: lightning metric class that should be tested
            sk_metric: callable function that is used for comparison
            dist_sync_on_step: bool, if true will synchronize metric state across
                processes at each ``forward()``
            metric_args: dict with additional arguments used for class initialization
            check_dist_sync_on_step: bool, if true will check if the metric is also correctly
                calculated per batch per device (and not just at the end)
            check_batch: bool, if true will check if the metric is also correctly
                calculated across devices for each batch (and not just at the end)
            kwargs_update: Additional keyword arguments that will be passed with preds and
                target when running update on the metric.
        """
        if not metric_args:
            metric_args = {}
        if ddp:
            if sys.platform == "win32":
                pytest.skip("DDP not supported on windows")

            self.pool.starmap(
                partial(
                    _class_test,
                    preds=preds,
                    target=target,
                    metric_class=metric_class,
                    sk_metric=sk_metric,
                    dist_sync_on_step=dist_sync_on_step,
                    metric_args=metric_args,
                    check_dist_sync_on_step=check_dist_sync_on_step,
                    check_batch=check_batch,
                    atol=self.atol,
                    **kwargs_update,
                ),
                [(rank, self.poolSize) for rank in range(self.poolSize)],
            )
        else:
            _class_test(
                0,
                1,
                preds=preds,
                target=target,
                metric_class=metric_class,
                sk_metric=sk_metric,
                dist_sync_on_step=dist_sync_on_step,
                metric_args=metric_args,
                check_dist_sync_on_step=check_dist_sync_on_step,
                check_batch=check_batch,
                atol=self.atol,
                **kwargs_update,
            )

    def run_precision_test_cpu(
        self,
        preds: torch.Tensor,
        target: torch.Tensor,
        metric_module: Metric,
        metric_functional: Callable,
        metric_args: dict = None,
    ):
        """Test if an metric can be used with half precision tensors on cpu
        Args:
            preds: torch tensor with predictions
            target: torch tensor with targets
            metric_module: the metric module to test
            metric_functional: the metric functional to test
            metric_args: dict with additional arguments used for class initialization
        """
        metric_args = metric_args or {}
        _assert_half_support(metric_module(**metric_args),
                             partial(metric_functional, **metric_args),
                             preds,
                             target,
                             device="cpu")

    def run_precision_test_gpu(
        self,
        preds: torch.Tensor,
        target: torch.Tensor,
        metric_module: Metric,
        metric_functional: Callable,
        metric_args: dict = None,
    ):
        """Test if an metric can be used with half precision tensors on gpu
        Args:
            preds: torch tensor with predictions
            target: torch tensor with targets
            metric_module: the metric module to test
            metric_functional: the metric functional to test
            metric_args: dict with additional arguments used for class initialization
        """
        metric_args = metric_args or {}
        _assert_half_support(metric_module(**metric_args),
                             partial(metric_functional, **metric_args),
                             preds,
                             target,
                             device="cuda")
Example #11
class ESOptimizer:
    """
    An optimizer class that implements Evolution Strategies (ES)
    """
    def __init__(self,
                 model: Model,
                 sgd_optimizer: Optimizer,
                 objective_fn: Objective,
                 obj_weights: List[float],
                 sigma: float,
                 n_samples: int,
                 devices: List,
                 n_workers=4):
        self.model = model
        self._optimizer = sgd_optimizer
        self.sigma = sigma
        self.n_samples = n_samples
        self.objective_fn = objective_fn
        self.obj_weights = torch.Tensor(obj_weights)
        self.n_objectives = len(obj_weights)
        self.devices = devices
        self.pool = Pool(processes=n_workers)

        # evaluator
        self.evaluator = ModelEvaluator(self.model, self.objective_fn, None,
                                        None)

    def _compute_gradient(self, obj_value: List[float], delta: torch.Tensor):
        """
        Computes the gradient for one sample
        """
        obj_value = torch.Tensor(obj_value)
        weighted_sum = torch.dot(obj_value, self.obj_weights)
        grad = delta * weighted_sum
        return grad

    def _fit_scalers_for_objectives(
            self, objectives: List[List[float]]) -> List[RankScaler]:
        """
        Fits a rank scaler for each objective
        """
        rank_scalers = []
        n_values = len(objectives)
        for obj_ix in range(self.n_objectives):
            values = [objectives[i][obj_ix] for i in range(n_values)]
            scaler = RankScaler(values)
            rank_scalers.append(scaler)
        return rank_scalers

    def _gradients_from_objectives(
            self, current_value: torch.Tensor, obj_values: List[List[float]],
            perturbations: List[torch.Tensor]) -> torch.Tensor:
        """
        Computes average gradients using multi-objective values
        """
        total_gradients = torch.zeros(current_value.shape)
        rank_scalers = self._fit_scalers_for_objectives(obj_values)

        for obj, delta in zip(obj_values, perturbations):
            # rank scale them
            obj = [
                rank_scalers[ix].transform(value)
                for ix, value in enumerate(obj)
            ]

            # compute gradient
            gradient = self._compute_gradient(obj, delta)
            total_gradients += gradient

        # average the gradients
        grad = total_gradients / (self.n_samples * self.sigma)

        return grad

    def _generate_perturbations(self, current_value):
        # create mirrored sampled perturbations
        perturbs = [
            torch.randn_like(current_value)
            for i in range(int(self.n_samples / 2))
        ]
        mirrored = [-i for i in perturbs]
        perturbs += mirrored
        return perturbs

    def gradient_step(self, samples: Samples):
        """
        Performs a gradient ascent step

        Args:
            samples (Samples): samples
        """

        # sample some parameters here
        parameter_name, parameter_value = self.model.sample()

        # generate unit gaussian perturbations
        unit_perturbations = self._generate_perturbations(parameter_value)

        # apply user selected deviation
        perturbations = [
            parameter_value + perturb * self.sigma
            for perturb in unit_perturbations
        ]

        # sample devices
        devices = random.choices(self.devices, k=len(unit_perturbations))

        # get the objective values
        self.evaluator.current_parameter_name = parameter_name
        self.evaluator.samples = samples
        obj_values = self.pool.starmap(self.evaluator,
                                       zip(perturbations, devices))

        # compute gradients
        gradients = self._gradients_from_objectives(parameter_value,
                                                    obj_values,
                                                    unit_perturbations)

        # update the model parameters and take a gradient step
        self.model.set_gradients(parameter_name, -gradients)
        self._optimizer.step()

        return gradients
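
gradient_step passes the ModelEvaluator instance itself as the starmap callable,
zipped with one (perturbation, device) pair per sample. A self-contained sketch
of using a picklable callable object this way; EvaluatorStub is a hypothetical
stand-in:

import torch
from multiprocessing import Pool

class EvaluatorStub:
    # hypothetical stand-in for ModelEvaluator: a module-level class is picklable,
    # so instances can be shipped to pool workers
    def __init__(self, scale):
        self.scale = scale

    def __call__(self, perturbation, device):
        # a real evaluator would run the model on `device`; here we just score the tensor
        return [float(perturbation.sum()) * self.scale]

if __name__ == '__main__':
    evaluator = EvaluatorStub(scale=0.1)
    perturbations = [torch.randn(3) for _ in range(4)]
    devices = ['cpu'] * 4
    with Pool(processes=2) as pool:
        obj_values = pool.starmap(evaluator, zip(perturbations, devices))
    print(obj_values)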
Example #12
class MetricTester:
    """Class used for efficiently run alot of parametrized tests in ddp mode.
    Makes sure that ddp is only setup once and that pool of processes are
    used for all tests.

    All tests should subclass from this and implement a new method called
        `test_metric_name`
    where the method `self.run_metric_test` is called inside.
    """

    atol = 1e-8

    def setup_class(self):
        """Setup the metric class. This will spawn the pool of workers that are
        used for metric testing and setup_ddp
        """

        self.poolSize = NUM_PROCESSES
        self.pool = Pool(processes=self.poolSize)
        self.pool.starmap(setup_ddp, [(rank, self.poolSize)
                                      for rank in range(self.poolSize)])

    def teardown_class(self):
        """ Close pool of workers """
        self.pool.close()
        self.pool.join()

    def run_functional_metric_test(
        self,
        preds: Tensor,
        target: Tensor,
        metric_functional: Callable,
        sk_metric: Callable,
        metric_args: dict = None,
        fragment_kwargs: bool = False,
        **kwargs_update,
    ):
        """Main method that should be used for testing functions. Call this inside
        testing method

        Args:
            preds: torch tensor with predictions
            target: torch tensor with targets
            metric_functional: lightning metric class that should be tested
            sk_metric: callable function that is used for comparison
            metric_args: dict with additional arguments used for class initialization
            fragment_kwargs: whether tensors in kwargs should be divided as `preds` and `target` among processes
            kwargs_update: Additional keyword arguments that will be passed with preds and
                target when running update on the metric.
        """
        device = 'cuda' if (torch.cuda.is_available()
                            and torch.cuda.device_count() > 0) else 'cpu'

        _functional_test(
            preds=preds,
            target=target,
            metric_functional=metric_functional,
            sk_metric=sk_metric,
            metric_args=metric_args,
            atol=self.atol,
            device=device,
            fragment_kwargs=fragment_kwargs,
            **kwargs_update,
        )

    def run_class_metric_test(
        self,
        ddp: bool,
        preds: Tensor,
        target: Tensor,
        metric_class: Metric,
        sk_metric: Callable,
        dist_sync_on_step: bool,
        metric_args: dict = None,
        check_dist_sync_on_step: bool = True,
        check_batch: bool = True,
        fragment_kwargs: bool = False,
        check_scriptable: bool = True,
        **kwargs_update,
    ):
        """Main method that should be used for testing class. Call this inside testing
        methods.

        Args:
            ddp: bool, if running in ddp mode or not
            preds: torch tensor with predictions
            target: torch tensor with targets
            metric_class: lightning metric class that should be tested
            sk_metric: callable function that is used for comparison
            dist_sync_on_step: bool, if true will synchronize metric state across
                processes at each ``forward()``
            metric_args: dict with additional arguments used for class initialization
            check_dist_sync_on_step: bool, if true will check if the metric is also correctly
                calculated per batch per device (and not just at the end)
            check_batch: bool, if true will check if the metric is also correctly
                calculated across devices for each batch (and not just at the end)
            fragment_kwargs: whether tensors in kwargs should be divided as `preds` and `target` among processes
            kwargs_update: Additional keyword arguments that will be passed with preds and
                target when running update on the metric.
        """
        if not metric_args:
            metric_args = {}
        if ddp:
            if sys.platform == "win32":
                pytest.skip("DDP not supported on windows")

            self.pool.starmap(
                partial(
                    _class_test,
                    preds=preds,
                    target=target,
                    metric_class=metric_class,
                    sk_metric=sk_metric,
                    dist_sync_on_step=dist_sync_on_step,
                    metric_args=metric_args,
                    check_dist_sync_on_step=check_dist_sync_on_step,
                    check_batch=check_batch,
                    atol=self.atol,
                    fragment_kwargs=fragment_kwargs,
                    check_scriptable=check_scriptable,
                    **kwargs_update,
                ),
                [(rank, self.poolSize) for rank in range(self.poolSize)],
            )
        else:
            device = 'cuda' if (torch.cuda.is_available()
                                and torch.cuda.device_count() > 0) else 'cpu'

            _class_test(
                rank=0,
                worldsize=1,
                preds=preds,
                target=target,
                metric_class=metric_class,
                sk_metric=sk_metric,
                dist_sync_on_step=dist_sync_on_step,
                metric_args=metric_args,
                check_dist_sync_on_step=check_dist_sync_on_step,
                check_batch=check_batch,
                atol=self.atol,
                device=device,
                fragment_kwargs=fragment_kwargs,
                check_scriptable=check_scriptable,
                **kwargs_update,
            )

    def run_precision_test_cpu(
        self,
        preds: Tensor,
        target: Tensor,
        metric_module: Metric,
        metric_functional: Callable,
        metric_args: dict = None,
        **kwargs_update,
    ):
        """Test if a metric can be used with half precision tensors on cpu
        Args:
            preds: torch tensor with predictions
            target: torch tensor with targets
            metric_module: the metric module to test
            metric_functional: the metric functional to test
            metric_args: dict with additional arguments used for class initialization
            kwargs_update: Additional keyword arguments that will be passed with preds and
                target when running update on the metric.
        """
        metric_args = metric_args or {}
        _assert_half_support(metric_module(**metric_args),
                             metric_functional,
                             preds,
                             target,
                             device="cpu",
                             **kwargs_update)

    def run_precision_test_gpu(
        self,
        preds: Tensor,
        target: Tensor,
        metric_module: Metric,
        metric_functional: Callable,
        metric_args: dict = None,
        **kwargs_update,
    ):
        """Test if a metric can be used with half precision tensors on gpu
        Args:
            preds: torch tensor with predictions
            target: torch tensor with targets
            metric_module: the metric module to test
            metric_functional: the metric functional to test
            metric_args: dict with additional arguments used for class initialization
            kwargs_update: Additional keyword arguments that will be passed with preds and
                target when running update on the metric.
        """
        metric_args = metric_args or {}
        _assert_half_support(metric_module(**metric_args),
                             metric_functional,
                             preds,
                             target,
                             device="cuda",
                             **kwargs_update)

    def run_differentiability_test(
        self,
        preds: Tensor,
        target: Tensor,
        metric_module: Metric,
        metric_functional: Callable,
        metric_args: dict = None,
    ):
        """Test if a metric is differentiable or not

        Args:
            preds: torch tensor with predictions
            target: torch tensor with targets
            metric_module: the metric module to test
            metric_args: dict with additional arguments used for class initialization
        """
        metric_args = metric_args or {}
        # only floating point tensors can require grad
        metric = metric_module(**metric_args)
        if preds.is_floating_point():
            preds.requires_grad = True
            out = metric(preds[0], target[0])
            # metrics can return list of values
            if isinstance(out, list):
                assert all(metric.is_differentiable == o.requires_grad
                           for o in out)
            else:
                assert metric.is_differentiable == out.requires_grad

            if metric.is_differentiable:
                # check for numerical correctness
                assert torch.autograd.gradcheck(
                    partial(metric_functional, **metric_args),
                    (preds[0].double(), target[0]))

            # reset as else it will carry over to other tests
            preds.requires_grad = False
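
As the docstring says, concrete tests subclass MetricTester and call
run_class_metric_test from a test_* method. A hypothetical sketch of such a
subclass; MyMetric, _ref_metric, and the tensor shapes are assumptions, not part
of the original test suite:

import pytest
import torch

def _ref_metric(preds, target):
    # hypothetical reference implementation used for comparison
    return (preds.argmax(dim=-1) == target).float().mean()

class TestMyMetric(MetricTester):
    atol = 1e-5

    @pytest.mark.parametrize("ddp", [False, True])
    def test_my_metric(self, ddp):
        self.run_class_metric_test(
            ddp=ddp,
            preds=torch.rand(4, 32, 10),        # (batches, samples, classes)
            target=torch.randint(10, (4, 32)),  # (batches, samples)
            metric_class=MyMetric,              # hypothetical metric class under test
            sk_metric=_ref_metric,
            dist_sync_on_step=False,
        )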
Example #13
class MetricTester:
    """ Class used for efficiently run alot of parametrized tests in ddp mode.
        Makes sure that ddp is only setup once and that pool of processes are
        used for all tests.

        All tests should subclass from this and implement a new method called
            `test_metric_name`
        where the method `self.run_metric_test` is called inside.
    """
    def setup_class(self):
        """ Setup the metric class. This will spawn the pool of workers that are
            used for metric testing and setup_ddp
        """
        try:
            set_start_method('spawn')
        except RuntimeError:
            pass
        self.poolSize = NUM_PROCESSES
        self.pool = Pool(processes=self.poolSize)
        self.pool.starmap(setup_ddp, [(rank, self.poolSize)
                                      for rank in range(self.poolSize)])

    def teardown_class(self):
        """ Close pool of workers """
        self.pool.close()
        self.pool.join()

    def run_metric_test(
        self,
        ddp: bool,
        preds: torch.Tensor,
        target: torch.Tensor,
        metric_class: Metric,
        sk_metric: Callable,
        dist_sync_on_step: bool,
        metric_args: dict = None,
        check_dist_sync_on_step: bool = True,
        check_batch: bool = True,
    ):
        """ Main method that should be used for testing. Call this inside testing
            methods.

            Args:
                ddp: bool, if running in ddp mode or not
                preds: torch tensor with predictions
                target: torch tensor with targets
                metric_class: lightning metric class that should be tested
                sk_metric: callable function that is used for comparison
                dist_sync_on_step: bool, if true will synchronize metric state across
                    processes at each ``forward()``
                metric_args: dict with additional arguments used for class initialization
                check_dist_sync_on_step: bool, if true will check if the metric is also correctly
                    calculated per batch per device (and not just at the end)
                check_batch: bool, if true will check if the metric is also correctly
                    calculated across devices for each batch (and not just at the end)
        """
        metric_args = metric_args or {}
        if ddp:
            if sys.platform == "win32":
                pytest.skip("DDP not supported on windows")

            self.pool.starmap(
                partial(
                    _compute_batch,
                    preds=preds,
                    target=target,
                    metric_class=metric_class,
                    sk_metric=sk_metric,
                    dist_sync_on_step=dist_sync_on_step,
                    metric_args=metric_args,
                    check_dist_sync_on_step=check_dist_sync_on_step,
                    check_batch=check_batch,
                ),
                [(rank, self.poolSize) for rank in range(self.poolSize)],
            )
        else:
            _compute_batch(
                0,
                1,
                preds=preds,
                target=target,
                metric_class=metric_class,
                sk_metric=sk_metric,
                dist_sync_on_step=dist_sync_on_step,
                metric_args=metric_args,
                check_dist_sync_on_step=check_dist_sync_on_step,
                check_batch=check_batch,
            )
Example #14
class Boundaryloss(object):
    def __init__(self,
                 dense_crf: DenseCRF,
                 eps: float = 1e-5,
                 crf_num_workers: int = 4):
        """Compute boundary loss

        Args:
            dense_crf: DenseCRF functor
            eps: minimum probability allowed when clamping probs
            crf_num_workers: number of workers for the parallel CRF computation
        """
        self.dense_crf = dense_crf
        self.eps = eps
        self.crf_num_workers = crf_num_workers

        self.crf_pool = Pool(self.crf_num_workers)

    def crf(self, imgs, probs):
        np_imgs = BGR2RGB(imgs.cpu().numpy().astype(np.uint8).transpose(
            0, 2, 3, 1))  # (N, H, W, C)
        np_probs = probs.detach().cpu().numpy()  # (N, C, H, W)

        # Scale imgs to the spatial shape of probs
        scaled_imgs = nd.zoom(np_imgs,
                              (1.0, np_probs.shape[2] / np_imgs.shape[1],
                               np_probs.shape[3] / np_imgs.shape[2], 1.0),
                              order=1)

        # CRF
        crf_probs = self.crf_pool.starmap(self.dense_crf,
                                          zip(scaled_imgs, np_probs))
        crf_prob = np.stack(crf_probs, axis=0)

        # Clamp smoothed probs
        # TODO: Can be removed?
        crf_prob[crf_prob < self.eps] = self.eps
        crf_prob = crf_prob / np.sum(crf_prob, axis=1, keepdims=True)

        # to Tensor
        return torch.from_numpy(crf_prob).float().cuda(probs.get_device())

    def clamp_softmax(self, score, dim=1):
        probs = torch.clamp(F.softmax(score, dim), self.eps, 1)
        probs = probs / torch.sum(probs, dim=dim, keepdim=True)
        return probs

    def __call__(
        self,
        images,
        score_map,
        out_prob=False
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Compute the constrain-to-boundary loss

        Args:
            images: (N, 3, H, W) RGB img
            score_map: (N, C, H, W) score map
            out_prob: If true, return smoothed predict_probs. (Default: False)

        Returns:
            constrain-to-boundary loss
        """
        probs = self.clamp_softmax(score_map)
        smooth_probs = self.crf(images, probs)
        # Compute KL-Div
        # TODO: clamp is not needed?
        loss = torch.mean(
            torch.sum(smooth_probs *
                      torch.log(torch.clamp(smooth_probs / probs, 0.05, 20)),
                      dim=1))

        if out_prob:
            return loss, smooth_probs

        return loss

    def __del__(self):
        self.crf_pool.close()
        self.crf_pool.join()
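
The crf call above uses zip(scaled_imgs, np_probs) as the starmap iterable, so
each worker receives one (image, probability map) pair. A self-contained sketch
of that zipped-pairs pattern; smooth_stub is a hypothetical stand-in for the
DenseCRF functor:

import numpy as np
from multiprocessing import Pool

def smooth_stub(img, prob):
    # hypothetical stand-in for dense_crf(img, prob): return per-image smoothed probs
    return prob * 0.9 + 0.1 / prob.shape[0]

if __name__ == '__main__':
    imgs = [np.zeros((8, 8, 3), dtype=np.uint8) for _ in range(4)]
    probs = [np.full((5, 8, 8), 0.2, dtype=np.float32) for _ in range(4)]
    with Pool(processes=2) as pool:
        out = pool.starmap(smooth_stub, zip(imgs, probs))
    print(np.stack(out, axis=0).shape)  # (4, 5, 8, 8)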