Code example #1
File: train.py  Project: shaojiaxue/PyTorch-BigGraph
import argparse
from itertools import chain

# ConfigFileLoader, ConfigSchema, SubprocessInitializer, Rank, train and
# the logging helpers are imported from the torchbiggraph package in the
# full file.
def main():
    setup_logging()
    config_help = '\n\nConfig parameters:\n\n' + '\n'.join(ConfigSchema.help())
    parser = argparse.ArgumentParser(
        epilog=config_help,
        # Needed to preserve line wraps in epilog.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument('config', help="Path to config file")
    parser.add_argument('-p', '--param', action='append', nargs='*')
    parser.add_argument('--rank',
                        type=int,
                        default=0,
                        help="For multi-machine, this machine's rank")
    opt = parser.parse_args()

    if opt.param is not None:
        overrides = chain.from_iterable(opt.param)  # flatten
    else:
        overrides = None
    loader = ConfigFileLoader()
    config = loader.load_config(opt.config, overrides)
    set_logging_verbosity(config.verbose)
    subprocess_init = SubprocessInitializer()
    subprocess_init.register(setup_logging, config.verbose)
    subprocess_init.register(add_to_sys_path, loader.config_dir.name)

    train(config, rank=Rank(opt.rank), subprocess_init=subprocess_init)
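Given the arguments defined above, a typical command line for this entry point looks like the following; the config path and the override key/value are illustrative placeholders, not taken from the source:

    python train.py my_config.py -p num_epochs=10 --rank 0

Each -p flag may carry several key=value tokens; they are flattened into a single override list before being handed to the config loader.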
Code example #2
File: train.py  Project: yueyedeai/PyTorch-BigGraph
import argparse
from itertools import chain

# ConfigSchema, parse_config, Rank and train are imported from the
# torchbiggraph package in the full file.
def main():
    config_help = '\n\nConfig parameters:\n\n' + '\n'.join(ConfigSchema.help())
    parser = argparse.ArgumentParser(
        epilog=config_help,
        # Needed to preserve line wraps in epilog.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument('config', help="Path to config file")
    parser.add_argument('-p', '--param', action='append', nargs='*')
    parser.add_argument('--rank', type=int, default=0,
                        help="For multi-machine, this machine's rank")
    opt = parser.parse_args()

    if opt.param is not None:
        overrides = chain.from_iterable(opt.param)  # flatten
    else:
        overrides = None
    config = parse_config(opt.config, overrides)

    train(config, rank=Rank(opt.rank))
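This variant loads the config through parse_config and registers no SubprocessInitializer, but the override flattening works the same way as in example #1. A standalone sketch with illustrative values:

    from itertools import chain

    # `-p a=1 b=2 -p c=3` with action='append' and nargs='*' yields a
    # list of lists, one inner list per -p occurrence:
    param = [['a=1', 'b=2'], ['c=3']]
    overrides = list(chain.from_iterable(param))  # ['a=1', 'b=2', 'c=3']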
Code example #3
File: train.py  Project: shaojiaxue/PyTorch-BigGraph
from abc import ABC, abstractmethod
from typing import Tuple

# EntityName, FloatTensorType, Rank, fast_approx_rand, the module-level
# logger and the td (torch.distributed) alias come from the imports in
# the full file.
def init_embs(
    entity: EntityName,
    entity_count: int,
    dim: int,
    scale: float,
) -> Tuple[FloatTensorType, None]:
    """Initialize embeddings of size entity_count x dim.
    """
    # FIXME: Use multi-threaded instead of fast_approx_rand
    logger.debug(f"Initializing {entity}")
    return fast_approx_rand(entity_count * dim).view(entity_count,
                                                     dim).mul_(scale), None


RANK_ZERO = Rank(0)


class AbstractSynchronizer(ABC):
    @abstractmethod
    def barrier(self) -> None:
        pass


class DummySynchronizer(AbstractSynchronizer):
    def barrier(self):
        pass


class DistributedSynchronizer(AbstractSynchronizer):
    def __init__(self, group: 'td.ProcessGroup') -> None:
        ...  # body omitted in this excerpt
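A minimal standalone check of what init_embs produces, substituting torch.randn for the project's fast_approx_rand helper; the sizes are illustrative:

    import torch

    entity_count, dim, scale = 1000, 16, 0.001
    embs = torch.randn(entity_count * dim).view(entity_count, dim).mul_(scale)
    assert embs.shape == (entity_count, dim)  # one row of size dim per entity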
Code example #4
    # Nested helper (excerpted from inside an enclosing function that owns
    # `world_size`); allocates the next `group_size` contiguous ranks.
    def add_group(group_size: int) -> List[Rank]:
        nonlocal world_size
        group = [Rank(world_size + r) for r in range(group_size)]
        world_size += group_size
        return group
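Because of the nonlocal declaration, add_group only works inside an enclosing function that defines the world_size counter. A self-contained sketch of that pattern; the wrapper name and group sizes are hypothetical:

    from typing import List, NewType, Tuple

    Rank = NewType('Rank', int)

    def allocate_ranks() -> Tuple[List[Rank], List[Rank]]:  # hypothetical wrapper
        world_size = 0

        def add_group(group_size: int) -> List[Rank]:
            nonlocal world_size
            group = [Rank(world_size + r) for r in range(group_size)]
            world_size += group_size
            return group

        # Each call hands out the next contiguous block of ranks.
        trainers = add_group(2)  # ranks [0, 1]
        servers = add_group(2)   # ranks [2, 3]
        return trainers, servers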