示例#1
0
def train_func(config: Dict):
    """Train and validate the model on one worker, reporting loss per epoch.

    Args:
        config: Hyperparameter mapping. Must contain ``"batch_size"`` (the
            global batch size, divided evenly among the workers), ``"lr"``
            and ``"epochs"``.
    """
    # Read every hyperparameter up front so a missing key fails immediately.
    global_batch_size = config["batch_size"]
    learning_rate = config["lr"]
    num_epochs = config["epochs"]

    # Each worker processes an equal share of the global batch.
    per_worker_batch = global_batch_size // session.get_world_size()

    # Build the plain loaders first, then hand them to the train utilities.
    raw_train_loader = DataLoader(training_data, batch_size=per_worker_batch)
    raw_test_loader = DataLoader(test_data, batch_size=per_worker_batch)
    train_loader = train.torch.prepare_data_loader(raw_train_loader)
    test_loader = train.torch.prepare_data_loader(raw_test_loader)

    # Model, loss and optimizer.
    network = train.torch.prepare_model(NeuralNetwork())
    criterion = nn.CrossEntropyLoss()
    sgd = torch.optim.SGD(network.parameters(), lr=learning_rate)

    for _epoch in range(num_epochs):
        train_epoch(train_loader, network, criterion, sgd)
        epoch_loss = validate_epoch(test_loader, network, criterion)
        session.report({"loss": epoch_loss})
示例#2
0
def train_func(config: Dict):
    """Train and validate the model on one worker and collect the losses.

    Args:
        config: Hyperparameter mapping. Must contain ``"batch_size"`` (the
            global batch size, divided evenly among the workers), ``"lr"``
            and ``"epochs"``.

    Returns:
        A list with the validation loss of every epoch, in order. The same
        values are also reported through ``session.report``.
    """
    global_batch_size = config["batch_size"]
    learning_rate = config["lr"]
    num_epochs = config["epochs"]

    # Each worker processes an equal share of the global batch.
    per_worker_batch = global_batch_size // session.get_world_size()

    # Build the plain loaders first, then hand them to the train utilities.
    raw_train_loader = DataLoader(training_data, batch_size=per_worker_batch)
    raw_test_loader = DataLoader(test_data, batch_size=per_worker_batch)
    train_loader = train.torch.prepare_data_loader(raw_train_loader)
    test_loader = train.torch.prepare_data_loader(raw_test_loader)

    # Model, loss and optimizer.
    network = train.torch.prepare_model(NeuralNetwork())
    criterion = nn.CrossEntropyLoss()
    sgd = torch.optim.SGD(network.parameters(), lr=learning_rate)

    epoch_losses = []
    for _epoch in range(num_epochs):
        train_epoch(train_loader, network, criterion, sgd)
        epoch_loss = validate_epoch(test_loader, network, criterion)
        epoch_losses.append(epoch_loss)
        session.report({"loss": epoch_loss})

    # The return value is kept for backwards compatibility with the old API.
    # TODO(team-ml) clean up and remove return
    return epoch_losses
示例#3
0
def validate_epoch(dataloader, model, loss_fn):
    """Evaluate ``model`` on ``dataloader`` and return the mean batch loss.

    Accuracy is computed against this worker's share of the dataset
    (dataset length divided by the world size) and printed together with
    the average loss.

    Args:
        dataloader: Loader yielding ``(inputs, targets)`` batches.
        model: Model to evaluate; it is switched to eval mode.
        loss_fn: Callable computing the loss from predictions and targets.

    Returns:
        The summed batch loss divided by the number of batches.
    """
    samples_per_worker = len(dataloader.dataset) // session.get_world_size()
    batch_count = len(dataloader)
    model.eval()

    total_loss = 0
    hits = 0
    with torch.no_grad():
        for inputs, targets in dataloader:
            outputs = model(inputs)
            total_loss += loss_fn(outputs, targets).item()
            matches = (outputs.argmax(1) == targets).type(torch.float)
            hits += matches.sum().item()

    mean_loss = total_loss / batch_count
    accuracy = hits / samples_per_worker
    print(f"Test Error: \n "
          f"Accuracy: {(100 * accuracy):>0.1f}%, "
          f"Avg loss: {mean_loss:>8f} \n")
    return mean_loss
示例#4
0
def train_epoch(dataloader, model, loss_fn, optimizer):
    """Run one optimization pass over ``dataloader``.

    Progress (current loss and number of samples seen out of this worker's
    share of the dataset) is printed every 100 batches.

    Args:
        dataloader: Loader yielding ``(inputs, targets)`` batches.
        model: Model to train; it is switched to train mode.
        loss_fn: Callable computing the loss from predictions and targets.
        optimizer: Optimizer updating the model parameters.
    """
    samples_per_worker = len(dataloader.dataset) // session.get_world_size()
    model.train()

    for step, (inputs, targets) in enumerate(dataloader):
        # Forward pass and loss.
        outputs = model(inputs)
        batch_loss = loss_fn(outputs, targets)

        # Backward pass and parameter update.
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        # Only log on every 100th batch.
        if step % 100 != 0:
            continue
        loss_value = batch_loss.item()
        seen = step * len(inputs)
        print(f"loss: {loss_value:>7f}  [{seen:>5d}/{samples_per_worker:>5d}]")
示例#5
0
def train_func(use_ray: bool, config: Dict):
    """Train FashionMNIST either through Ray Train or with vanilla PyTorch DDP.

    Args:
        use_ray: If True, the Ray session/train utilities provide the world
            size, the distributed sampler and the DDP wrapping. If False,
            ``torch.distributed`` is used directly.
        config: Hyperparameter mapping. Requires ``"batch_size"`` (global
            batch size, split evenly among workers), ``"lr"`` and
            ``"epochs"``. Optional keys: ``"shuffle"`` (default False) and,
            for the vanilla path only, ``"use_gpu"`` (default False) and
            ``"gpu_id"`` (default 0).

    In Ray mode the loss is reported via ``session.report``. In vanilla mode
    it is printed and rank 0 writes it to ``VANILLA_RESULT_JSON`` after every
    epoch.
    """
    if use_ray:
        from ray.air import session
        import ray.train as train

    batch_size = config["batch_size"]
    lr = config["lr"]
    epochs = config["epochs"]
    shuffle = config.get("shuffle", False)
    use_gpu = config.get("use_gpu", False)

    if use_ray:
        world_size = session.get_world_size()
    else:
        world_size = distributed.get_world_size()
    local_rank = distributed.get_rank()

    worker_batch_size = batch_size // world_size

    # Load datasets. Use download=False to catch errors in preparation, as the
    # data should have already been downloaded.
    training_data = datasets.FashionMNIST(
        root="/tmp/data_fashion_mnist",
        train=True,
        download=False,
        transform=ToTensor(),
    )

    test_data = datasets.FashionMNIST(
        root="/tmp/data_fashion_mnist",
        train=False,
        download=False,
        transform=ToTensor(),
    )

    if use_ray:
        # Ray adds DistributedSampler in train.torch.prepare_data_loader below
        training_sampler = None
        test_sampler = None
    else:
        # In vanilla PyTorch we create the distributed sampler here
        training_sampler = DistributedSampler(training_data, shuffle=shuffle)
        test_sampler = DistributedSampler(test_data, shuffle=shuffle)

    if not use_ray and use_gpu:
        assert torch.cuda.is_available(), "No GPUs available"
        gpu_id = config.get("gpu_id", 0)
        vanilla_device = torch.device(f"cuda:{gpu_id}")
        torch.cuda.set_device(vanilla_device)

        print(
            "Setting GPU ID to",
            gpu_id,
            "with visible devices",
            os.environ.get("CUDA_VISIBLE_DEVICES"),
        )

        def collate_fn(x):
            # Collate on the host, then move every tensor of the batch to
            # the selected GPU.
            return tuple(x_.to(vanilla_device) for x_ in default_collate(x))

    else:
        vanilla_device = torch.device("cpu")
        collate_fn = None

    # Create data loaders and potentially pass distributed sampler.
    # DataLoader raises ValueError when ``shuffle=True`` is combined with an
    # explicit sampler, so the flag is only forwarded when no sampler is
    # given; otherwise shuffling is controlled by the DistributedSampler.
    train_dataloader = DataLoader(
        training_data,
        shuffle=shuffle if training_sampler is None else False,
        batch_size=worker_batch_size,
        sampler=training_sampler,
        collate_fn=collate_fn,
    )
    test_dataloader = DataLoader(
        test_data,
        shuffle=shuffle if test_sampler is None else False,
        batch_size=worker_batch_size,
        sampler=test_sampler,
        collate_fn=collate_fn,
    )

    if use_ray:
        # In Ray, we now retrofit the DistributedSampler
        train_dataloader = train.torch.prepare_data_loader(train_dataloader)
        test_dataloader = train.torch.prepare_data_loader(test_dataloader)

    # Create model.
    model = NeuralNetwork()

    # Prepare model
    if use_ray:
        model = train.torch.prepare_model(model)
    else:
        model = model.to(vanilla_device)

        if use_gpu:
            model = nn.parallel.DistributedDataParallel(
                model, device_ids=[gpu_id], output_device=gpu_id
            )
        else:
            model = nn.parallel.DistributedDataParallel(model)

    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)

    for _ in range(epochs):
        train_epoch(
            train_dataloader,
            model,
            loss_fn,
            optimizer,
            world_size=world_size,
            local_rank=local_rank,
        )
        loss = validate_epoch(
            test_dataloader,
            model,
            loss_fn,
            world_size=world_size,
            local_rank=local_rank,
        )
        if use_ray:
            session.report(dict(loss=loss))
        else:
            print(f"Reporting loss: {loss:.4f}")
            # Only one process writes the result file to avoid clobbering.
            if local_rank == 0:
                with open(VANILLA_RESULT_JSON, "w") as f:
                    json.dump({"loss": loss}, f)
示例#6
0
def train_loop_per_worker(train_loop_config):
    """Train a GraphSAGE model on this worker's share of the training nodes.

    Args:
        train_loop_config: Mapping with ``"dataset_fn"`` (zero-argument
            callable returning the dataset), ``"batch_size"`` and
            ``"num_epochs"``.

    Every epoch each worker reports ``train_accuracy``,
    ``validation_accuracy`` and ``test_accuracy``. Only the rank 0 worker
    runs validation; all other workers report ``None`` for the three values.
    """
    dataset = train_loop_config["dataset_fn"]()
    batch_size = train_loop_config["batch_size"]
    num_epochs = train_loop_config["num_epochs"]

    # The rank does not change during the run, so look it up once.
    world_size = session.get_world_size()
    world_rank = session.get_world_rank()
    is_rank_zero = world_rank == 0

    data = dataset[0]
    # Split the training node ids into equally sized chunks and keep the
    # chunk that belongs to this worker.
    train_idx = data.train_mask.nonzero(as_tuple=False).view(-1)
    train_idx = train_idx.split(train_idx.size(0) // world_size)[world_rank]

    train_loader = NeighborSampler(
        data.edge_index,
        node_idx=train_idx,
        sizes=[25, 10],
        batch_size=batch_size,
        shuffle=True,
    )

    # Disable distributed sampler since the train_loader has already been split above.
    train_loader = train.torch.prepare_data_loader(train_loader,
                                                   add_dist_sampler=False)

    # Do validation on rank 0 worker only.
    if is_rank_zero:
        subgraph_loader = NeighborSampler(data.edge_index,
                                          node_idx=None,
                                          sizes=[-1],
                                          batch_size=2048,
                                          shuffle=False)
        subgraph_loader = train.torch.prepare_data_loader(
            subgraph_loader, add_dist_sampler=False)

    model = SAGE(dataset.num_features, 256, dataset.num_classes)
    model = train.torch.prepare_model(model)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    device = train.torch.get_device()
    x, y = data.x.to(device), data.y.to(device)

    for epoch in range(num_epochs):
        model.train()

        # ``batch_size`` is the number of samples in the current batch.
        # ``n_id`` are the ids of all the nodes used in the computation. This is
        # needed to pull in the necessary features just for the current batch that is
        # being trained on.
        # ``adjs`` is a list of 3 element tuple consisting of ``(edge_index, e_id,
        # size)`` for each sample in the batch, where ``edge_index`` represent the
        # edges of the sampled subgraph, ``e_id`` are the ids of the edges in the
        # sample, and ``size`` holds the shape of the subgraph.
        # See ``torch_geometric.loader.neighbor_sampler.NeighborSampler`` for more info.
        for batch_size, n_id, adjs in train_loader:
            optimizer.zero_grad()
            out = model(x[n_id], adjs)
            loss = F.nll_loss(out, y[n_id[:batch_size]])
            loss.backward()
            optimizer.step()

        if is_rank_zero:
            print(f"Epoch: {epoch:03d}, Loss: {loss:.4f}")

        train_accuracy = validation_accuracy = test_accuracy = None

        # Do validation on rank 0 worker only.
        if is_rank_zero:
            model.eval()
            # ``prepare_model`` only wraps the model in DistributedDataParallel
            # when more than one worker is used, so ``.module`` is absent in a
            # single-worker run; fall back to the model itself in that case.
            eval_model = getattr(model, "module", model)
            with torch.no_grad():
                out = eval_model.test(x, subgraph_loader)
            res = out.argmax(dim=-1) == data.y
            train_accuracy = int(res[data.train_mask].sum()) / int(
                data.train_mask.sum())
            validation_accuracy = int(res[data.val_mask].sum()) / int(
                data.val_mask.sum())
            test_accuracy = int(res[data.test_mask].sum()) / int(
                data.test_mask.sum())

        session.report(
            dict(
                train_accuracy=train_accuracy,
                validation_accuracy=validation_accuracy,
                test_accuracy=test_accuracy,
            ))
示例#7
0
    def prepare_data_loader(
        self,
        data_loader: torch.utils.data.DataLoader,
        add_dist_sampler: bool = True,
        move_to_device: bool = True,
        auto_transfer: bool = True,
    ) -> torch.utils.data.DataLoader:
        """Prepares DataLoader for distributed execution.

        This allows you to use the same exact code regardless of number of
        workers or the device type being used (CPU, GPU).

        Args:
            data_loader (torch.utils.data.DataLoader): The DataLoader to
                prepare.
            add_dist_sampler: Whether to add a DistributedSampler to
                the provided DataLoader.
            move_to_device: If set, automatically move the data
                returned by the data loader to the correct device.
            auto_transfer: If set and device is GPU, another CUDA stream
                is created to automatically copy data from host (CPU) memory
                to device (GPU) memory (the default CUDA stream still runs the
                training procedure). If device is CPU, it will be disabled
                regardless of the setting. This configuration will be ignored
                if ``move_to_device`` is False.

        Returns:
            The prepared loader: a new ``DataLoader`` using a
            ``DistributedSampler`` when one had to be added, wrapped in a
            ``_WrappedDataLoader`` when ``move_to_device`` is True; otherwise
            the loader that was passed in.
        """

        # Backwards compatibility
        try:
            world_size = session.get_world_size()
            world_rank = session.get_world_rank()
        except Exception:
            world_size = train.world_size()
            world_rank = train.world_rank()

        # Only add Distributed Sampler if the following conditions hold:
        # 1. More than one training worker is being used.
        # 2. A DistributedSampler has not already been added by the user.
        # 3. The dataset is not an IterableDataset. Samplers do not work with
        # IterableDatasets.
        if (world_size > 1
                and not isinstance(data_loader.sampler, DistributedSampler)
                and not (hasattr(data_loader, "dataset")
                         and isinstance(data_loader.dataset, IterableDataset))
                and add_dist_sampler):

            def with_sampler(loader):
                # Automatically set the DistributedSampler

                # If you're using a sampler, the DataLoader shuffle flag must be set to
                # False. Shuffling is instead determined by the shuffle argument passed
                # to the DistributedSampler constructor.

                # If no sampler is passed to the DataLoader constructor, Torch
                # constructs a default sampler. The default sampler is a RandomSampler
                # if shuffling is enabled and a SequentialSampler otherwise. DataLoader
                # does not have a shuffle attribute, so we instead identify whether
                # shuffling is enabled by checking the default sampler type.
                shuffle = not isinstance(loader.sampler, SequentialSampler)

                def seeded_worker_init_fn(worker_init_fn):
                    def wrapper(worker_id):
                        worker_seed = torch.initial_seed() % 2**32
                        np.random.seed(worker_seed)
                        random.seed(worker_seed)
                        # ``DataLoader.worker_init_fn`` defaults to None, so
                        # only chain to it when the user supplied one.
                        if worker_init_fn is not None:
                            worker_init_fn(worker_id)

                    return wrapper

                worker_init_fn = loader.worker_init_fn
                generator = loader.generator
                if self._seed is not None:
                    worker_init_fn = seeded_worker_init_fn(
                        loader.worker_init_fn)
                    generator = torch.Generator()
                    generator.manual_seed(self._seed)

                using_default_sampler = isinstance(
                    loader.sampler, (SequentialSampler, RandomSampler))
                # Warn once (on rank 0) that a user-provided sampler is
                # about to be replaced.
                if not using_default_sampler and world_rank == 0:
                    logger.warning(
                        f"The {loader.sampler.__class__.__name__} will be overwritten "
                        "with a DistributedSampler. You can disable this by setting "
                        "`add_dist_sampler` to False in `prepare_data_loader`.")

                data_loader_args = {
                    "dataset": loader.dataset,
                    "batch_size": loader.batch_size,
                    "shuffle": False,
                    "num_workers": loader.num_workers,
                    "collate_fn": loader.collate_fn,
                    "pin_memory": loader.pin_memory,
                    "drop_last": loader.drop_last,
                    "timeout": loader.timeout,
                    "worker_init_fn": worker_init_fn,
                    "generator": generator,
                    "sampler": DistributedSampler(loader.dataset,
                                                  shuffle=shuffle),
                }
                return DataLoader(**data_loader_args)

            data_loader = with_sampler(data_loader)

        if move_to_device:
            device = self.get_device()
            data_loader = _WrappedDataLoader(data_loader, device,
                                             auto_transfer)

        return data_loader
示例#8
0
    def prepare_model(
        self,
        model: torch.nn.Module,
        move_to_device: bool = True,
        wrap_ddp: bool = True,
        ddp_kwargs: Optional[Dict[str, Any]] = None,
    ) -> torch.nn.Module:
        """Prepares the model for distributed execution.

        This allows you to use the same exact code regardless of number of
        workers or the device type being used (CPU, GPU).

        Args:
            model (torch.nn.Module): A torch model to prepare.
            move_to_device: Whether to move the model to the correct
                device. If set to False, the model needs to manually be moved
                to the correct device.
            wrap_ddp: Whether to wrap models in
                ``DistributedDataParallel``.
            ddp_kwargs (Dict[str, Any]): Args to pass into
                ``DistributedDataParallel`` initialization if ``wrap_ddp`` is
                set to True.

        Returns:
            The prepared model. It is wrapped in ``DistributedDataParallel``
            only when ``wrap_ddp`` is True and the world size is greater
            than one; otherwise the (possibly moved / patched) input model
            is returned.
        """
        # Treat a missing ``ddp_kwargs`` as "no extra arguments".
        ddp_kwargs = ddp_kwargs or {}

        # Backwards compatibility: fall back to the older ``train`` API when
        # the session lookup fails.
        try:
            rank = session.get_local_rank()
        except Exception:
            rank = train.local_rank()

        device = self.get_device()

        # Make ``device`` the current CUDA device whenever CUDA is available.
        if torch.cuda.is_available():
            torch.cuda.set_device(device)

        if move_to_device:
            logger.info(f"Moving model to device: {device}")
            model = model.to(device)

        def model_get_state(self):
            # `__getstate__` is a special method that informs pickle which attributes
            # to serialize. This custom implementation ensures that the wrapped forward
            # method and custom `__getstate__` method aren't serialized.
            if hasattr(self, "_original_get_state"):
                # The model had its own `__getstate__`: start from its state and
                # put the original method back under the `__getstate__` key.
                state = self._original_get_state()
                state["__getstate__"] = state["_original_get_state"]
                del state["_original_get_state"]
            else:
                # If model does not have a `__getstate__` already defined, use default
                # implementation.
                state = self.__dict__.copy()
                del state["__getstate__"]
            # Restore the unwrapped forward so the autocast wrapper is not pickled.
            state["forward"] = state["_unwrapped_forward"]
            del state["_unwrapped_forward"]

            return state

        if self.amp_is_enabled:
            # Pickle cannot serialize the wrapped forward method. As a workaround,
            # define a custom `__getstate__` method that unwraps the forward method.
            model._unwrapped_forward = model.forward
            model.forward = autocast()(model.forward)

            # TODO(amogkam): Replace below logic with a generic "unpack model" method.
            # Replacing the `model.forward` method makes the model no longer
            # serializable. When serializing the model, we have to override the
            # `__getstate__` method to set back the original forward method.
            if hasattr(model, "__getstate__"):
                model._original_get_state = model.__getstate__
            # `__getstate__` must be a bound method rather than a callable attribute.
            # See https://stackoverflow.com/questions/972/adding-a-method-to-an-existing-object-instance.  # noqa: E501
            model.__getstate__ = types.MethodType(model_get_state, model)

        # Backwards compatibility: same fallback as for the local rank above.
        try:
            world_size = session.get_world_size()
        except Exception:
            world_size = train.world_size()

        # DDP is only applied when there is more than one worker.
        if wrap_ddp and world_size > 1:
            logger.info("Wrapping provided model in DDP.")
            if torch.cuda.is_available():
                # With CUDA, the local rank is passed as the DDP device id.
                model = DistributedDataParallel(model,
                                                device_ids=[rank],
                                                output_device=rank,
                                                **ddp_kwargs)
            else:
                model = DistributedDataParallel(model, **ddp_kwargs)

        return model
示例#9
0
def _huggingface_train_loop_per_worker(config):
    """Per-worker training loop for HuggingFace Transformers.

    Args:
        config: Training configuration. The ``"_trainer_init_per_worker"``
            entry is popped and called with the train/eval datasets; every
            remaining entry is forwarded to it as a keyword argument.

    Raises:
        ValueError: If ``evaluation_strategy``, ``save_strategy`` or
            ``logging_strategy`` of the trainer arguments is ``"steps"``.
    """
    trainer_init_per_worker = config.pop("_trainer_init_per_worker")

    # Env vars necessary for HF to setup DDP
    os.environ["RANK"] = str(session.get_world_rank())
    os.environ["WORLD_SIZE"] = str(session.get_world_size())
    os.environ["LOCAL_RANK"] = str(session.get_local_rank())

    train_dataset = session.get_dataset_shard(TRAIN_DATASET_KEY)
    eval_dataset = session.get_dataset_shard(EVALUATION_DATASET_KEY)

    train_torch_dataset, eval_torch_dataset = process_datasets(
        train_dataset,
        eval_dataset,
    )

    trainer: transformers.trainer.Trainer = trainer_init_per_worker(
        train_torch_dataset, eval_torch_dataset, **config)

    if trainer.args.push_to_hub and not trainer.args.hub_token:
        warnings.warn(
            "You have set `push_to_hub=True` but didn't specify `hub_token`. "
            "Pushing to hub will most likely fail, as the credentials will not "
            "be automatically propagated from the local enviroment to the Ray Actors. "
            "If that happens, specify `hub_token` in `TrainingArguments`.")

    if (trainer.args.evaluation_strategy == "steps"
            or trainer.args.save_strategy == "steps"
            or trainer.args.logging_strategy == "steps"):
        raise ValueError(
            "'steps' value for `evaluation_strategy`, `logging_strategy` "
            "or `save_strategy` is not yet supported.")

    trainer = wrap_transformers_trainer(trainer)

    # ensure no HF logging callbacks are added
    # aside from doubling functionality with our callbacks,
    # the Wandb callbacks causes training to freeze
    integration_callbacks = transformers.trainer.get_reporting_integration_callbacks(
        trainer.args.report_to)
    for callback in integration_callbacks:
        trainer.pop_callback(callback)

    trainer.add_callback(TrainReportCallback)

    checkpoint = session.get_checkpoint()
    checkpoint_path = None
    remove_checkpoint_path = False
    # The temporary checkpoint copy must be removed even if syncing or
    # training fails, so everything that runs after it may have been created
    # is guarded by ``finally``.
    try:
        if checkpoint:
            assert isinstance(checkpoint, Checkpoint)
            checkpoint_dict = checkpoint.to_dict()
            source_ip = checkpoint_dict[NODE_IP_KEY]
            source_path = checkpoint_dict[CHECKPOINT_PATH_ON_NODE_KEY]
            target_ip = get_node_ip_address()
            if source_ip == target_ip:
                # The checkpoint already lives on this node; use it in place.
                checkpoint_path = source_path
            else:
                # The checkpoint is on another node; copy it into a local
                # temporary directory first.
                checkpoint_path = tempfile.mkdtemp(
                    suffix=Path(trainer.args.output_dir).name)
                remove_checkpoint_path = True
                sync_dir_between_nodes(
                    source_ip=source_ip,
                    source_path=source_path,
                    target_ip=target_ip,
                    target_path=checkpoint_path,
                    return_futures=False,
                    max_size_bytes=None,
                )
        trainer.train(resume_from_checkpoint=checkpoint_path)
    finally:
        if remove_checkpoint_path:
            shutil.rmtree(checkpoint_path, ignore_errors=True)
示例#10
0
def train_func(config):
    """Train a ResNet18 on CIFAR10 on one worker and report per-epoch results.

    Args:
        config: Training configuration. ``"epochs"`` (default 3) is popped
            from the mapping before the rest is passed to ``ResNet18``.
            ``"batch_size"`` is required and is split evenly among workers;
            ``"lr"`` (default 0.1), ``"momentum"`` (default 0.9) and
            ``"test_mode"`` are optional.

    Returns:
        The list of per-epoch validation results, in order.
    """
    num_epochs = config.pop("epochs", 3)
    net = train.torch.prepare_model(ResNet18(config))

    # Create optimizer.
    sgd = torch.optim.SGD(
        net.parameters(),
        lr=config.get("lr", 0.1),
        momentum=config.get("momentum", 0.9),
    )

    # Per-channel statistics used for the meanstd transformation.
    channel_mean = (0.4914, 0.4822, 0.4465)
    channel_std = (0.2023, 0.1994, 0.2010)

    # Training data is augmented; validation data is only normalized.
    augment_and_normalize = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])
    normalize_only = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])

    # Hold the file lock while the datasets are created (and downloaded).
    with FileLock(".ray.lock"):
        train_set = CIFAR10(root="~/data",
                            train=True,
                            download=True,
                            transform=augment_and_normalize)
        val_set = CIFAR10(root="~/data",
                          train=False,
                          download=False,
                          transform=normalize_only)

    # In test mode only the first 64 samples of each split are used.
    if config.get("test_mode"):
        first_samples = list(range(64))
        train_set = Subset(train_set, first_samples)
        val_set = Subset(val_set, list(range(64)))

    per_worker_batch = config["batch_size"] // session.get_world_size()

    raw_train_loader = DataLoader(train_set, batch_size=per_worker_batch)
    raw_val_loader = DataLoader(val_set, batch_size=per_worker_batch)
    train_loader = train.torch.prepare_data_loader(raw_train_loader)
    val_loader = train.torch.prepare_data_loader(raw_val_loader)

    # Create loss.
    loss_fn = nn.CrossEntropyLoss()

    history = []
    for _epoch in range(num_epochs):
        train_epoch(train_loader, net, loss_fn, sgd)
        epoch_result = validate_epoch(val_loader, net, loss_fn)
        session.report(epoch_result)
        history.append(epoch_result)

    # The return value is kept for backwards compatibility with the old API.
    # TODO(team-ml) clean up and remove return
    return history