コード例 #1
0
ファイル: main.py プロジェクト: yyf8989/pytext
def run_single(
    rank: int,
    config_json: str,
    world_size: int,
    dist_init_method: str,
    summary_writer: SummaryWriter,
):
    """Train one worker process of a PyText training job.

    Args:
        rank: Index of this worker. It is forwarded to ``train_model`` in two
            positional slots — presumably device id and process rank
            (TODO confirm against ``train_model``'s signature).
        config_json: Serialized ``PyTextConfig``, decoded with
            ``config_from_json``.
        world_size: Total number of workers, forwarded unchanged.
        dist_init_method: Distributed init method, forwarded unchanged.
        summary_writer: Forwarded to ``train_model`` only when ``rank == 0``;
            every other rank passes ``None`` instead.
    """
    parsed_config = config_from_json(PyTextConfig, config_json)
    # Only the rank-0 worker keeps the writer; other ranks log nothing.
    writer_for_rank = summary_writer if rank == 0 else None

    train_model(
        parsed_config,
        dist_init_method,
        rank,
        rank,
        world_size,
        writer_for_rank,
    )
コード例 #2
0
def train(context):
    """Train a model, then run testing on the snapshot at the configured path.

    The config is loaded from ``context.obj`` and parsed with ``parse_config``.
    A world size of exactly 1 uses ``train_model``; any other value uses
    ``train_model_distributed``. Testing reads the snapshot from
    ``config.save_snapshot_path`` (presumably written during training) and
    the data from ``config.task.data_handler.test_path``.
    """
    config = parse_config(context.obj.load_config())

    print("\n===Starting training...")
    # Pick the trainer first; the conditional expression only evaluates the
    # branch that is taken, just like the original if/else.
    run_training = (
        train_model
        if config.distributed_world_size == 1
        else train_model_distributed
    )
    run_training(config)

    print("\n=== Starting testing...")
    snapshot_path = config.save_snapshot_path
    use_cuda = config.use_cuda_if_available
    test_data_path = config.task.data_handler.test_path
    test_model_from_snapshot_path(snapshot_path, use_cuda, test_data_path)
コード例 #3
0
def train(context):
    """Train a model, then test the snapshot it is configured to save.

    The raw config from ``context.obj.load_config()`` is parsed in
    ``Mode.TRAIN``. A world size other than 1 dispatches to
    ``train_model_distributed``; otherwise ``train_model`` is used.
    Testing is driven by a ``TestConfig`` built from the training config's
    snapshot path, test data path, and CUDA preference.
    """
    config_json = context.obj.load_config()
    config = parse_config(Mode.TRAIN, config_json)

    print("\n===Starting training...")
    if config.distributed_world_size != 1:
        train_model_distributed(config)
    else:
        train_model(config)

    print("\n=== Starting testing...")
    test_model(
        TestConfig(
            load_snapshot_path=config.save_snapshot_path,
            test_path=config.task.data_handler.test_path,
            use_cuda_if_available=config.use_cuda_if_available,
        )
    )
コード例 #4
0
def run_single(
    rank: int,
    config_json: str,
    world_size: int,
    dist_init_method: str,
    metadata: CommonMetadata,
):
    """Train one worker process, with TensorBoard logging on rank 0 only.

    Args:
        rank: Index of this worker. It is forwarded to ``train_model`` in two
            positional slots — presumably device id and process rank
            (TODO confirm against ``train_model``'s signature).
        config_json: Serialized ``PyTextConfig``, decoded with
            ``config_from_json``.
        world_size: Total number of workers, forwarded unchanged.
        dist_init_method: Distributed init method, forwarded unchanged.
        metadata: Forwarded unchanged to ``train_model``.

    A ``SummaryWriter`` is created only when ``rank == 0`` and
    ``config.use_tensorboard`` is set. It is closed in a ``finally`` so it is
    released even if training raises.
    """
    config = config_from_json(PyTextConfig, config_json)
    # Only rank 0 gets a writer. The other run_single variants follow the same
    # convention: they drop the writer/metric channels when rank != 0. Testing
    # `rank != 0` here would instead give every worker *except* rank 0 a
    # writer, and none at all in a single-process run.
    summary_writer = (
        SummaryWriter() if rank == 0 and config.use_tensorboard else None
    )
    try:
        train_model(
            config,
            dist_init_method,
            rank,
            rank,
            world_size,
            summary_writer,
            metadata,
        )
    finally:
        if summary_writer is not None:
            summary_writer.close()
コード例 #5
0
ファイル: main.py プロジェクト: parety/pytext
def run_single(
    rank: int,
    config_json: str,
    world_size: int,
    dist_init_method: Optional[str],
    metadata: Optional[CommonMetadata],
    metric_channels: Optional[List[Channel]],
):
    """Train one worker process; only rank 0 reports to metric channels.

    Args:
        rank: Index of this worker, passed to ``train_model`` as both
            ``device_id`` and ``rank``.
        config_json: Serialized ``PyTextConfig``, decoded with
            ``config_from_json``.
        world_size: Total number of workers.
        dist_init_method: Forwarded to ``train_model`` as ``dist_init_url``.
        metadata: Forwarded unchanged.
        metric_channels: Forwarded as-is on rank 0; every other rank forwards
            an empty list instead.
    """
    config = config_from_json(PyTextConfig, config_json)
    channels_for_rank = metric_channels if rank == 0 else []

    train_model(
        config=config,
        dist_init_url=dist_init_method,
        device_id=rank,
        rank=rank,
        world_size=world_size,
        metric_channels=channels_for_rank,
        metadata=metadata,
    )
コード例 #6
0
ファイル: main.py プロジェクト: yunchaosuper/pytext
def train(context):
    """Train a model, then test the saved snapshot, sharing one summary writer.

    When ``config.use_tensorboard`` is set, a single ``SummaryWriter`` is
    created after the training banner is printed. It is passed to both the
    training call and ``test_model_from_snapshot_path``, and is closed in a
    ``finally`` block whether or not either step raises.
    """
    config = context.obj.load_config()
    print("\n===Starting training...")

    writer = None
    if config.use_tensorboard:
        writer = SummaryWriter()

    try:
        if config.distributed_world_size != 1:
            train_model_distributed(config, writer)
        else:
            train_model(config, summary_writer=writer)

        print("\n=== Starting testing...")
        test_model_from_snapshot_path(
            config.save_snapshot_path,
            config.use_cuda_if_available,
            config.task.data_handler.test_path,
            writer,
        )
    finally:
        if writer is not None:
            writer.close()
コード例 #7
0
def train(context):
    """Train a model, then test the saved snapshot, reporting to metric channels.

    When ``config.use_tensorboard`` is set, the channel list contains one
    ``TensorBoardChannel``; otherwise it is empty. The same list is passed to
    training and to ``test_model_from_snapshot_path``. ``test_path=None`` is
    passed explicitly; what the callee does with ``None`` is not shown here
    (NOTE(review): presumably it falls back to a path from the snapshot —
    confirm). Every channel is closed in a ``finally`` block.
    """
    config = context.obj.load_config()
    print("\n===Starting training...")
    metric_channels = [TensorBoardChannel()] if config.use_tensorboard else []

    try:
        if config.distributed_world_size != 1:
            train_model_distributed(config, metric_channels)
        else:
            train_model(config, metric_channels=metric_channels)

        print("\n=== Starting testing...")
        test_model_from_snapshot_path(
            config.save_snapshot_path,
            config.use_cuda_if_available,
            test_path=None,
            metric_channels=metric_channels,
        )
    finally:
        for channel in metric_channels:
            channel.close()
コード例 #8
0
def run_single(rank, config_json: str, world_size: int, dist_init_method: str):
    """Decode the JSON config and train this worker's share of the job.

    ``rank`` is forwarded to ``train_model`` in two positional slots —
    presumably device id and process rank (TODO confirm against
    ``train_model``'s signature). ``dist_init_method`` and ``world_size`` are
    forwarded unchanged.
    """
    parsed_config = config_from_json(PyTextConfig, config_json)
    train_model(
        parsed_config,
        dist_init_method,
        rank,
        rank,
        world_size,
    )