Example #1
    def train(self, task):
        """
        Runs training phases, phases are generated from the config.
        """

        assert isinstance(task, ClassyTask)
        pin_memory = self.use_gpu and torch.cuda.device_count() > 1

        task.prepare(
            num_dataloader_workers=self.num_dataloader_workers,
            pin_memory=pin_memory,
            use_gpu=self.use_gpu,
            dataloader_mp_context=self.dataloader_mp_context,
        )
        state = self._ClassyElasticState(task, self.input_args)

        state.advance_to_next_phase = True

        def elastic_train_step(orig_state):
            if state.run_start_hooks:
                # need this to ensure we don't run the on_start hooks every time
                # a trainer starts
                state.task.on_start()
                state.run_start_hooks = False
                return state, self._ClassyWorkerStats(None)

            return self._run_step(orig_state, self.use_gpu)

        torchelastic.train(self.elastic_coordinator, elastic_train_step, state)

        task.on_end()
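The example above adapts the task's `_run_step` to the step-function contract visible in `elastic_train_step`: the callable receives the current state and returns a tuple of the updated state and a worker-stats object. Below is a minimal sketch of a custom step function honoring that same contract; the `data_iter`, `train_on`, and `num_batches` names on `state` are illustrative assumptions, not part of any real API.

def my_train_step(state):
    # Hypothetical state fields: a batch iterator and a per-step helper.
    batch = next(state.data_iter)
    loss = state.task.train_on(batch)  # hypothetical helper that runs one step
    state.num_batches += 1
    worker_stats = None  # or a stats object the coordinator can aggregate
    return state, worker_stats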
Example #2
    def train(self, task):
        """
        Runs training phases, phases are generated from the config.
        """

        assert isinstance(task, ClassyTask)
        pin_memory = self.use_gpu and torch.cuda.device_count() > 1

        task.prepare(
            num_dataloader_workers=self.num_dataloader_workers,
            pin_memory=pin_memory,
            use_gpu=self.use_gpu,
            dataloader_mp_context=self.dataloader_mp_context,
        )
        state = self._ClassyElasticState(task, self.input_args)

        local_variables = {}

        state.advance_to_next_phase = True

        def elastic_train_step(orig_state):
            return self._run_step(orig_state, local_variables, self.use_gpu)

        task.run_hooks(local_variables, ClassyHookFunctions.on_start.name)

        torchelastic.train(self.elastic_coordinator, elastic_train_step, state)

        task.run_hooks(local_variables, ClassyHookFunctions.on_end.name)
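For orientation, here is a rough, simplified pseudocode of the loop that `torchelastic.train(coordinator, train_step, state)` drives in both examples. The coordinator method names are approximations of the torchelastic Coordinator interface and may not match your torchelastic version exactly; `state_is_done` is a hypothetical stop check.

# Rough pseudocode of torchelastic.train (simplified; the real loop also
# handles checkpoints, worker re-ranking, and error recovery). Coordinator
# method names are assumptions and may differ across versions.
def train(coordinator, train_step, state):
    while not state_is_done(state):            # hypothetical stop check
        if coordinator.should_rendezvous(state):
            coordinator.rendezvous_barrier()   # wait for workers to join/leave
        state, worker_stats = train_step(state)
        coordinator.monitor_progress(state, worker_stats)
    coordinator.signal_training_done()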
Example #3
import logging

import torch
import torch.nn as nn
import torchelastic
from torch.nn import Parameter
from torchvision import datasets, models, transforms
from torchvision.models.resnet import BasicBlock, Bottleneck

# Assumed import path for CoordinatorP2P; it may differ across
# torchelastic versions. ImagenetState and train_step are defined
# elsewhere in the original example.
from torchelastic.p2p import CoordinatorP2P

log = logging.getLogger(__name__)


def single_trainer(
    local_rank,
    max_world_size,
    c10d_backend,
    rdzv_init_url,
    model_arch,
    training_params,
    input_path,
):
    """
    Single GPU trainer that will only train on the GPU specified by local_rank

    """

    log.info(f"Loading data from: {input_path}")
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_dataset = datasets.ImageFolder(
        input_path,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]),
    )

    log.info(f"Loading model: {model_arch}")
    model = models.__dict__[model_arch]()
    # Apply the tricks from "Training ImageNet in 1 Hour" to the model
    # itself to maintain accuracy
    for m in model.modules():
        # Trick 1: the last BatchNorm layer in each residual block needs
        # to be initialized as zero gamma
        if isinstance(m, BasicBlock):
            num_features = m.bn2.num_features
            m.bn2.weight = Parameter(torch.zeros(num_features))
        if isinstance(m, Bottleneck):
            num_features = m.bn3.num_features
            m.bn3.weight = Parameter(torch.zeros(num_features))
        # Trick 2: linear layers are initialized by drawing weights from
        # a zero-mean Gaussian with a standard deviation of 0.01. In the
        # paper this was only the fc layer, but in practice we found
        # this better for accuracy.
        if isinstance(m, nn.Linear):
            m.weight.data.normal_(0, 0.01)

    model.train()

    torch.cuda.set_device(local_rank)
    device = torch.cuda.current_device()
    model.cuda()
    log.info(f"Rank [{local_rank}] running on GPU [{device}]")

    coordinator = CoordinatorP2P(
        c10d_backend=c10d_backend,
        init_method=rdzv_init_url,
        max_num_trainers=max_world_size,
        process_group_timeout=60000,
    )

    state = ImagenetState(
        model=model,
        params=training_params,
        dataset=train_dataset,
        num_epochs=training_params.num_epochs,
    )

    log.info(f"Entering torchelastic train_loop")
    torchelastic.train(coordinator, train_step, state)
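The `train_step` passed to `torchelastic.train` above is defined elsewhere in the original example. A hypothetical version following the same (state) -> (state, worker_stats) contract might look like the sketch below; the `model`, `optimizer`, and `data_iter` attributes on `state` are illustrative assumptions, not the actual `ImagenetState` fields.

# Hypothetical train_step for the snippet above. Attribute names on
# `state` are illustrative assumptions, not the real ImagenetState API.
def train_step(state):
    inputs, target = next(state.data_iter)
    inputs, target = inputs.cuda(), target.cuda()
    output = state.model(inputs)
    loss = nn.functional.cross_entropy(output, target)
    state.optimizer.zero_grad()
    loss.backward()
    state.optimizer.step()
    return state, None  # second element: optional worker stats for monitoring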