Example #1
def launch(config_file: str, option: Optional[str], tmpdir: str,
           rank: Optional[int], n_process: Optional[int]):
    logging.set_level(L.INFO)

    logger.info(f"Launch config file: {config_file}")
    configs = load_config(config_file)
    if option == "test":
        logger.info("Modify configs for testing")
        configs = modify_config_for_test(configs, tmpdir)
    elif option == "profile":
        logger.info("Modify configs for profiling")
        output_dir = configs["output_dir"]  # TODO
        configs = modify_config_for_profile(configs, tmpdir)

    distributed.initialize(tmpdir, rank, n_process)

    rank = distributed.rank()
    seed = random.randint(0, 2**31) + rank
    logger.info(f"Fix seed={seed}")
    rng = np.random.RandomState(seed)
    torch.manual_seed(rng.randint(0, 2**32 - 1))
    np.random.seed(rng.randint(0, 2**32 - 1))
    random.seed(rng.randint(0, 2**32 - 1))

    logger.info("Run main")
    if option == "profile":
        cprofile = cProfile.Profile()
        cprofile.enable()
        with profile() as torch_prof:
            parse_config(configs)["/main"]
        cprofile.disable()
        torch.save(torch_prof,
                   os.path.join(output_dir, f"torch_profiler-{rank}.pt"))
        cprofile.dump_stats(os.path.join(output_dir, f"cprofile-{rank}.pt"))
    else:
        parse_config(configs)["/main"]
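
The seeding block above derives a per-rank seed and then re-seeds torch, NumPy, and the stdlib random module from a single RandomState, so one number drives all three generators. A minimal self-contained sketch of that pattern (seed_everything is a hypothetical name, and distributed.rank() is replaced by an explicit rank argument):

import random

import numpy as np
import torch


def seed_everything(base_seed: int, rank: int) -> None:
    # Offset the base seed by the process rank so each worker gets a
    # different but reproducible stream, then draw the per-library seeds
    # from one RandomState.
    rng = np.random.RandomState(base_seed + rank)
    torch.manual_seed(rng.randint(0, 2**32 - 1))
    np.random.seed(rng.randint(0, 2**32 - 1))
    random.seed(rng.randint(0, 2**32 - 1))


seed_everything(base_seed=42, rank=0)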
Example #2
    def __call__(
        self,
        group: Optional[torch.distributed.group] = None,
    ) -> EvaluationResult[Code, GroundTruth]:
        total = {}
        generated = []
        times = []
        for n in self.top_n:
            t = {}
            for name in self.metrics.keys():
                t[name] = 0.0
            total[n] = t
        evaluate_sample: EvaluateSample[Code] = \
            EvaluateSample(self.synthesizer, self.metrics, self.top_n)

        results: List[Result[Code, GroundTruth]] = []
        rank = distributed.rank(group=group)
        size = distributed.size(group=group)
        n_sample = (len(self.dataset) + size - 1) // size
        logger.info(
            f"Evalute with {len(self.dataset)} samples w/ {size} processes " +
            f"({n_sample} per process)")

        samples = self.dataset[rank * n_sample:(rank + 1) * n_sample]
        results = [
            evaluate_sample(elem)
            for elem in tqdm(total=len(samples),
                             iterable=logger.iterable_block(
                                 "evaluate_sample", enumerate(samples)))
        ]
        gathered_results = distributed.all_gather(results)
        results = []
        for r in gathered_results:
            results.extend(r)

        logger.info("Summarize results")
        for result in results:
            generated.append(1.0 if result.generated else 0.0)
            if result.generated:
                times.append(result.time)
            for n in self.top_n:
                m = result.metrics[n]
                for name in self.metrics.keys():
                    total[n][name] += \
                        m[name] if m[name] is not None else 0

        total = {
            n: {
                name: value / len(self.dataset)
                for name, value in metric.items()
            }
            for n, metric in total.items()
        }
        r = EvaluationResult(results, total, np.mean(generated),
                             np.mean(times))
        # report
        for n, metric in total.items():
            for name, value in metric.items():
                report({f"{name}@{n}": value})
        report({"generation_rate": r.generation_rate})
        report({"generation_time": r.generation_time})
        # logging
        logger.info(f"{r.metrics}")
        logger.info(f"generation rate: {r.generation_rate}")
        logger.info(f"generation time: {r.generation_time}")
        return r
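
The slicing above shards the dataset across processes with ceiling division, so every rank evaluates at most n_sample contiguous samples and the trailing ranks may receive fewer; the per-rank results are then merged with all_gather. A small stand-alone sketch of just the sharding step (shard is a hypothetical helper, independent of the project's distributed module):

from typing import List, Sequence, TypeVar

T = TypeVar("T")


def shard(dataset: Sequence[T], rank: int, size: int) -> List[T]:
    # Ceiling division: at most n_sample items per rank; the last ranks
    # receive fewer (possibly zero) when len(dataset) % size != 0.
    n_sample = (len(dataset) + size - 1) // size
    return list(dataset[rank * n_sample:(rank + 1) * n_sample])


# 10 samples over 4 processes -> 3, 3, 3, 1 items per rank.
assert [len(shard(range(10), r, 4)) for r in range(4)] == [3, 3, 3, 1]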
Example #3
def create_extensions_manager(n_iter: int,
                              evaluation_interval_iter: int,
                              snapshot_interval_iter: int,
                              iter_per_epoch: int,
                              model: nn.Module,
                              optimizer: torch.optim.Optimizer,
                              evaluate: Optional[Callable[[], None]],
                              metric: str,
                              maximize: bool,
                              threshold: Optional[float],
                              output_dir: str,
                              report_metrics: Optional[List[str]] = None):
    model_dir = os.path.join(output_dir, "model")

    logger.info("Prepare pytorch-pfn-extras")
    manager = ppe.training.ExtensionsManager(
        model,
        optimizer,
        n_iter / iter_per_epoch,
        out_dir=os.path.join(output_dir),
        extensions=[],
        iters_per_epoch=iter_per_epoch,
    )
    manager.extend(extensions.FailOnNonNumber(),
                   trigger=Trigger(evaluation_interval_iter, n_iter))
    if evaluate is not None:
        manager.extend(
            Call(evaluate),
            trigger=Trigger(evaluation_interval_iter, n_iter),
        )
    if distributed.is_main_process():
        manager.extend(
            extensions.LogReport(
                trigger=Trigger(100, n_iter),
                filename="log.json",
            ))
        manager.extend(extensions.ProgressBar())
        manager.extend(
            SaveTopKModel(model_dir, 1, metric, model, maximize=maximize),
            trigger=Trigger(evaluation_interval_iter, n_iter),
        )
        metrics = report_metrics or []
        manager.extend(
            extensions.PrintReport(entries=[
                "loss", *metrics, "iteration", "epoch", "time.iteration",
                "gpu.time.iteration", "elapsed_time"
            ]),
            trigger=Trigger(100, n_iter),
        )
    if threshold is not None:
        manager.extend(
            StopByThreshold(metric, threshold, maximize=maximize),
            trigger=Trigger(evaluation_interval_iter, n_iter),
        )
    if distributed.is_initialized():
        snapshot = extensions.snapshot(autoload=True,
                                       n_retains=1,
                                       saver_rank=0)
        snapshot._rank = distributed.rank()
        snapshot._size = distributed.size()
        snapshot._local_rank = distributed.rank()
    else:
        snapshot = extensions.snapshot(autoload=True, n_retains=1)
    manager.extend(snapshot, trigger=Trigger(snapshot_interval_iter, n_iter))
    return manager
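
Trigger, Call, SaveTopKModel, and StopByThreshold come from the surrounding project, not from pytorch-pfn-extras itself. Judging from how Trigger(interval, n_iter) is used here, it is presumably a callable trigger in the pytorch-pfn-extras sense (it receives the manager and returns a bool). A hypothetical sketch of such a trigger that fires every interval iterations and once more at the final iteration:

class IntervalOrLastTrigger:
    # Hypothetical stand-in for the project's Trigger; the real class may
    # behave differently.
    def __init__(self, interval: int, n_iter: int):
        self.interval = interval
        self.n_iter = n_iter

    def __call__(self, manager) -> bool:
        it = manager.iteration
        return it % self.interval == 0 or it == self.n_iter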
Example #4
def device(type_str: str, index: Union[int, str] = 0):
    if index == "rank":
        index = distributed.rank()
    return torch.device(type_str, index)
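
Passing "rank" as the index ties each process to the device whose index equals its rank, the usual one-process-per-GPU layout; an integer index pins every process to the same device. A rough stand-alone equivalent using torch.distributed directly (the project wraps this in its own distributed module, so this is only an approximation):

from typing import Union

import torch
import torch.distributed as dist


def rank_device(type_str: str, index: Union[int, str] = 0) -> torch.device:
    # With one process per GPU, using the process rank as the device index
    # gives every worker its own GPU.
    if index == "rank":
        index = dist.get_rank() if dist.is_initialized() else 0
    return torch.device(type_str, index)


print(rank_device("cuda", "rank"))  # e.g. cuda:2 on rank 2; cuda:0 if not distributed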