def main():
    """Command-line entry point: parse arguments, load the config, and train.

    Positional ``config`` is the path to the config file. ``-p/--param`` may
    be repeated, each occurrence taking zero or more values; all values are
    flattened into a single stream of overrides handed to the config loader.
    ``--rank`` is this machine's rank for multi-machine runs (default 0).

    The help epilog lists the parameters reported by ``ConfigSchema.help()``.
    """
    setup_logging()

    schema_help = '\n'.join(ConfigSchema.help())
    parser = argparse.ArgumentParser(
        epilog='\n\nConfig parameters:\n\n' + schema_help,
        # The raw formatter keeps the epilog's own line breaks intact.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument('config', help="Path to config file")
    parser.add_argument('-p', '--param', action='append', nargs='*')
    parser.add_argument(
        '--rank',
        type=int,
        default=0,
        help="For multi-machine, this machine's rank",
    )
    args = parser.parse_args()

    # Every -p occurrence yields its own list; chain them into one iterable.
    overrides = None if args.param is None else chain.from_iterable(args.param)

    loader = ConfigFileLoader()
    config = loader.load_config(args.config, overrides)
    set_logging_verbosity(config.verbose)

    # Callbacks registered here are passed on to train() via subprocess_init.
    initializer = SubprocessInitializer()
    initializer.register(setup_logging, config.verbose)
    initializer.register(add_to_sys_path, loader.config_dir.name)

    train(config, rank=Rank(args.rank), subprocess_init=initializer)
def main():
    """Parse the command line, build the config, and start training.

    Arguments:
      * ``config`` -- path to the config file.
      * ``-p/--param`` -- repeatable; each occurrence takes zero or more
        values, and all of them are flattened into one stream of overrides.
      * ``--rank`` -- this machine's rank for multi-machine runs (default 0).

    The help epilog lists the parameters reported by ``ConfigSchema.help()``.
    """
    epilog_text = '\n\nConfig parameters:\n\n' + '\n'.join(ConfigSchema.help())
    arg_parser = argparse.ArgumentParser(
        epilog=epilog_text,
        # Raw formatting is required so the epilog keeps its line breaks.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    arg_parser.add_argument('config', help="Path to config file")
    arg_parser.add_argument('-p', '--param', action='append', nargs='*')
    arg_parser.add_argument(
        '--rank',
        type=int,
        default=0,
        help="For multi-machine, this machine's rank",
    )
    cli = arg_parser.parse_args()

    # argparse collects one list per -p occurrence; flatten them lazily.
    overrides = None if cli.param is None else chain.from_iterable(cli.param)

    train(parse_config(cli.config, overrides), rank=Rank(cli.rank))
def init_embs( entity: EntityName, entity_count: int, dim: int, scale: float, ) -> Tuple[FloatTensorType, None]: """Initialize embeddings of size entity_count x dim. """ # FIXME: Use multi-threaded instead of fast_approx_rand logger.debug(f"Initializing {entity}") return fast_approx_rand(entity_count * dim).view(entity_count, dim).mul_(scale), None RANK_ZERO = Rank(0) class AbstractSynchronizer(ABC): @abstractmethod def barrier(self) -> None: pass class DummySynchronizer(AbstractSynchronizer): def barrier(self): pass class DistributedSynchronizer(AbstractSynchronizer): def __init__(self, group: 'td.ProcessGroup') -> None:
def add_group(group_size: int) -> List[Rank]: nonlocal world_size group = [Rank(world_size + r) for r in range(group_size)] world_size += group_size return group