예제 #1
0
    def testGivens(self):
        """Check ``given_models`` / ``given_optimizers`` on two runners.

        A runner built from creators that return three models and three
        optimizers must report three of each. A runner built from
        ``model_creator`` / ``optimizer_creator`` must report ``given_*``
        values that differ from its ``models`` / ``optimizers``.
        """
        class MockOperator(TrainingOperator):
            def setup(self, config):
                # ``return_value`` sets what a MagicMock returns when called;
                # any other keyword (such as ``returns``) only creates a plain
                # attribute and the call would still yield a fresh mock.
                self.train_epoch = MagicMock(
                    return_value=dict(mean_accuracy=10))
                self.validate = MagicMock(
                    return_value=dict(mean_accuracy=10))

        def three_model_creator(config):
            return nn.Linear(1, 1), nn.Linear(1, 1), nn.Linear(1, 1)

        def three_optimizer_creator(models, config):
            opts = [
                torch.optim.SGD(model.parameters(), lr=0.1) for model in models
            ]
            return opts[0], opts[1], opts[2]

        runner = TorchRunner(three_model_creator,
                             single_loader,
                             three_optimizer_creator,
                             loss_creator,
                             training_operator_cls=MockOperator)
        runner.setup()

        self.assertEqual(len(runner.given_models), 3)
        self.assertEqual(len(runner.given_optimizers), 3)

        runner2 = TorchRunner(model_creator, single_loader, optimizer_creator,
                              loss_creator)
        runner2.setup()

        self.assertNotEqual(runner2.given_models, runner2.models)
        self.assertNotEqual(runner2.given_optimizers, runner2.optimizers)
예제 #2
0
 def testSingleLoader(self):
     """Train with ``single_loader``; ``validate()`` must raise ValueError."""
     single_runner = TorchRunner(
         model_creator, single_loader, optimizer_creator, loss_creator)
     single_runner.setup()
     single_runner.train_epoch()
     # Callable form of assertRaises: equivalent to the context-manager form.
     self.assertRaises(ValueError, single_runner.validate)
예제 #3
0
    def testGivens(self):
        """Check ``given_models`` / ``given_optimizers`` via the operator API.

        An operator that registers three models and three optimizers must
        yield ``given_*`` collections of length 3. For ``self.Operator`` the
        ``given_*`` values must differ from ``models`` / ``optimizers``.
        """
        def three_model_creator(config):
            return nn.Linear(1, 1), nn.Linear(1, 1), nn.Linear(1, 1)

        def three_optimizer_creator(models, config):
            opts = [
                torch.optim.SGD(model.parameters(), lr=0.1) for model in models
            ]
            return opts[0], opts[1], opts[2]

        class MockOperator(TrainingOperator):
            def setup(self, config):
                models = three_model_creator(config)
                optimizers = three_optimizer_creator(models, config)
                loader = single_loader(config)
                loss = loss_creator(config)
                self.models, self.optimizers, self.criterion = self.register(
                    models=models, optimizers=optimizers, criterion=loss)
                self.register_data(train_loader=loader, validation_loader=None)
                # ``return_value`` (not ``returns``) is the MagicMock keyword
                # that configures the value returned by a call.
                self.train_epoch = MagicMock(
                    return_value=dict(mean_accuracy=10))
                self.validate = MagicMock(
                    return_value=dict(mean_accuracy=10))

        runner = TorchRunner(training_operator_cls=MockOperator)
        runner.setup_operator()

        self.assertEqual(len(runner.given_models), 3)
        self.assertEqual(len(runner.given_optimizers), 3)

        runner2 = TorchRunner(training_operator_cls=self.Operator)
        runner2.setup_operator()

        self.assertNotEqual(runner2.given_models, runner2.models)
        self.assertNotEqual(runner2.given_optimizers, runner2.optimizers)
예제 #4
0
    def testMultiModel(self):
        """Three models/optimizers with no custom operator must fail ``setup``.

        ``setup()`` is expected to raise ``ValueError``.
        """
        def multi_model_creator(config):
            return tuple(nn.Linear(1, 1) for _ in range(3))

        def multi_optimizer_creator(models, config):
            optimizers = [
                torch.optim.SGD(m.parameters(), lr=0.1) for m in models
            ]
            # Exactly the first three optimizers, as a tuple.
            return tuple(optimizers[k] for k in range(3))

        multi_runner = TorchRunner(multi_model_creator, single_loader,
                                   multi_optimizer_creator, loss_creator)
        self.assertRaises(ValueError, multi_runner.setup)
예제 #5
0
파일: torch_trainer.py 프로젝트: yosagi/ray
    def _start_workers(self, num_workers):
        """Create the local worker and, when distributed, the remote workers.

        Args:
            num_workers (int): Total number of workers including the local
                one. With exactly one worker a plain ``TorchRunner`` is
                created and set up locally. Otherwise the local worker is a
                ``LocalDistributedRunner`` set up as rank 0, and
                ``num_workers - 1`` ``DistributedTorchRunner`` actors are set
                up with ranks ``1 .. num_workers - 1``.
        """
        # Lazy %-style arguments: the message is only formatted when DEBUG
        # logging is enabled (the old f-string prefix had no placeholders).
        logger.debug("start_workers: Setting %d workers.", num_workers)
        worker_config = self.config.copy()
        batch_size_per_worker = self._configure_and_split_batch(num_workers)
        if batch_size_per_worker:
            worker_config[BATCH_SIZE] = batch_size_per_worker

        params = dict(model_creator=self.model_creator,
                      data_creator=self.data_creator,
                      optimizer_creator=self.optimizer_creator,
                      loss_creator=self.loss_creator,
                      scheduler_creator=self.scheduler_creator,
                      training_operator_cls=self.training_operator_cls,
                      config=worker_config,
                      use_fp16=self.use_fp16,
                      use_gpu=self.use_gpu,
                      use_tqdm=self.use_tqdm,
                      apex_args=self.apex_args,
                      scheduler_step_freq=self.scheduler_step_freq)

        if num_workers == 1:
            # Start local worker
            self.local_worker = TorchRunner(**params)
            if self.initialization_hook:
                self.apply_all_workers(self.initialization_hook)
            self.local_worker.setup()
        else:
            params.update(backend=self.backend,
                          add_dist_sampler=self.add_dist_sampler,
                          wrap_ddp=self.wrap_ddp)

            # Start local worker
            self.local_worker = LocalDistributedRunner(num_cpus=1,
                                                       num_gpus=int(
                                                           self.use_gpu),
                                                       **params)

            # Generate actor class
            RemoteRunner = ray.remote(num_cpus=1, num_gpus=int(
                self.use_gpu))(DistributedTorchRunner)
            # Start workers
            self.remote_workers = [
                RemoteRunner.remote(**params) for _ in range(num_workers - 1)
            ]
            if self.initialization_hook:
                self.apply_all_workers(self.initialization_hook)

            # Compute URL for initializing distributed PyTorch
            ip = ray.services.get_node_ip_address()
            port = self.local_worker.find_free_port()

            address = "tcp://{ip}:{port}".format(ip=ip, port=port)

            # Remote workers take ranks 1..N-1; the local worker is rank 0.
            remote_setups = [
                worker.setup.remote(address, i + 1, num_workers)
                for i, worker in enumerate(self.remote_workers)
            ]
            self.local_worker.setup(address, 0, num_workers)
            # Get setup tasks in order to throw errors on failure
            ray.get(remote_setups)
예제 #6
0
 def testSingleLoader(self):
     """Train via a ``from_creators`` operator; ``validate()`` must raise."""
     operator_cls = TrainingOperator.from_creators(
         model_creator, optimizer_creator, single_loader,
         loss_creator=loss_creator)
     torch_runner = TorchRunner(training_operator_cls=operator_cls)
     torch_runner.setup_operator()
     torch_runner.train_epoch()
     # Only a training loader was provided, so validation should fail.
     self.assertRaises(ValueError, torch_runner.validate)
예제 #7
0
    def testValidate(self):
        """``validate`` runs only on request; the epoch counter advances.

        Three ``train_epoch`` calls must not invoke the operator's
        ``validate``. An explicit ``runner.validate()`` must invoke it, and
        the last training result must report ``epoch == 3``.
        """
        class MockOperator(self.Operator):
            def setup(self, config):
                super(MockOperator, self).setup(config)
                # ``return_value`` (not ``returns``) configures the value a
                # MagicMock call returns.
                self.train_epoch = MagicMock(
                    return_value=dict(mean_accuracy=10))
                self.validate = MagicMock(
                    return_value=dict(mean_accuracy=10))

        runner = TorchRunner(training_operator_cls=MockOperator)
        runner.setup_operator()
        runner.train_epoch()
        runner.train_epoch()
        result = runner.train_epoch()
        self.assertEqual(runner.training_operator.validate.call_count, 0)
        runner.validate()
        self.assertTrue(runner.training_operator.validate.called)
        self.assertEqual(result["epoch"], 3)
예제 #8
0
 def testNativeLoss(self):
     """Set up and train one epoch with ``nn.MSELoss`` as ``loss_creator``."""
     native_runner = TorchRunner(
         model_creator, single_loader, optimizer_creator,
         loss_creator=nn.MSELoss)
     native_runner.setup()
     native_runner.train_epoch()
예제 #9
0
    def testMultiLoaders(self):
        """A data creator returning three datasets must make ``setup`` raise.

        The check is repeated on two freshly constructed runners.
        """
        def three_data_loader(config):
            datasets = [LinearDataset(2, 5)]
            datasets += [LinearDataset(2, 5, size=400) for _ in range(2)]
            return tuple(datasets)

        for _ in range(2):
            fresh_runner = TorchRunner(model_creator, three_data_loader,
                                       optimizer_creator, loss_creator)
            with self.assertRaises(ValueError):
                fresh_runner.setup()
예제 #10
0
    def testValidate(self):
        """``validate`` runs only on request; the epoch counter advances.

        Three ``train_epoch`` calls must not invoke the operator's
        ``validate``. An explicit ``runner.validate()`` must invoke it, and
        the last training result must report ``epoch == 3``.
        """
        class MockOperator(TrainingOperator):
            def setup(self, config):
                # ``return_value`` (not ``returns``) configures the value a
                # MagicMock call returns.
                self.train_epoch = MagicMock(
                    return_value=dict(mean_accuracy=10))
                self.validate = MagicMock(
                    return_value=dict(mean_accuracy=10))

        runner = TorchRunner(model_creator,
                             create_dataloaders,
                             optimizer_creator,
                             loss_creator,
                             training_operator_cls=MockOperator)
        runner.setup()
        runner.train_epoch()
        runner.train_epoch()
        result = runner.train_epoch()
        self.assertEqual(runner.training_operator.validate.call_count, 0)
        runner.validate()
        self.assertTrue(runner.training_operator.validate.called)
        self.assertEqual(result["epoch"], 3)
예제 #11
0
    def start_workers(self, num_workers):
        """Start the local worker and, when distributed, the remote workers.

        Args:
            num_workers (int): Total number of workers including the local
                one. ``1`` creates a single local ``TorchRunner``. Any other
                value creates a ``LocalDistributedRunner`` as rank 0 and asks
                ``remote_worker_group`` to initialize ``num_workers - 1``
                remote workers.

        Returns:
            bool: ``True`` once setup completes. ``False`` if a
            ``RayActorError`` is raised during the multi-worker setup. The
            single-worker path does not catch it.
        """
        logger.debug(f"start_workers: Setting {num_workers} workers.")

        if num_workers == 1:
            self.local_worker = TorchRunner(**self._params)
            if self._initialization_hook:
                self.apply_all_workers(self._initialization_hook)
            self.local_worker.setup_operator()
            return True
        else:
            try:
                # Start local worker
                # ``_dist_params`` is unpacked last, so its keys override any
                # identical keys in ``_params``.
                self.local_worker = LocalDistributedRunner(
                    num_cpus=self._num_cpus_per_worker,
                    num_gpus=int(self._use_gpu),
                    **{
                        **self._params,
                        **self._dist_params
                    },
                )
                self.remote_worker_group._init_dist_workers(num_workers - 1)
                if self._initialization_hook:
                    self.apply_all_workers(self._initialization_hook)

                # Compute URL for initializing distributed PyTorch.
                address = setup_address()

                # Remote process-group setup is launched first (ranks start at
                # 1) and awaited only after the local rank-0 setup below.
                remote_pgs = self.remote_worker_group._setup_process_group(
                    address=address, world_size=num_workers, starting_rank=1)
                # Use the local worker as rank 0. Helps with debugging.
                self.local_worker.setup_process_group(
                    url=address,
                    world_rank=0,
                    world_size=num_workers,
                    timeout=timedelta(seconds=self._timeout_s),
                )
                ray.get(remote_pgs)

                # The local worker takes local rank 0, and its node's count is
                # pre-seeded to 1 before the remote group assigns local ranks.
                # Presumably this makes remote workers on the same node start
                # at local rank 1 (confirm in ``_setup_local_rank``).
                local_node_ip = ray.util.get_node_ip_address()
                rank_dict = defaultdict(int)
                self.local_worker.set_local_rank(local_rank=0)
                rank_dict[local_node_ip] += 1
                self.remote_worker_group._setup_local_rank(rank_dict)

                remote_operators = self.remote_worker_group._setup_operator()
                self.local_worker.setup_operator()
                ray.get(remote_operators)
                return True
            except RayActorError:
                # A remote actor failed during setup; report failure to the
                # caller instead of raising.
                return False
예제 #12
0
 def testNativeLoss(self):
     """Set up and train one epoch with ``nn.MSELoss`` as ``loss_creator``."""
     operator_cls = TrainingOperator.from_creators(
         model_creator, optimizer_creator, single_loader,
         loss_creator=nn.MSELoss)
     native_runner = TorchRunner(training_operator_cls=operator_cls)
     native_runner.setup_operator()
     native_runner.train_epoch()
예제 #13
0
    def testMultiLoaders(self):
        """An operator whose data creator returns three datasets must fail.

        ``setup_operator`` is expected to raise ``ValueError``. The check is
        repeated on two freshly constructed runners.
        """
        def three_data_loader(config):
            first = LinearDataset(2, 5)
            rest = tuple(LinearDataset(2, 5, size=400) for _ in range(2))
            return (first, ) + rest

        operator_cls = TrainingOperator.from_creators(
            model_creator,
            optimizer_creator,
            three_data_loader,
            loss_creator=loss_creator)

        for _ in range(2):
            fresh_runner = TorchRunner(training_operator_cls=operator_cls)
            with self.assertRaises(ValueError):
                fresh_runner.setup_operator()
예제 #14
0
    def testtrain_epoch(self):
        """Each runner ``train_epoch`` call reaches the operator exactly once.

        After three calls the operator count, the returned ``count`` and the
        runner's ``stats()["epoch"]`` must all equal 3.
        """
        class MockOperator(TrainingOperator):
            def setup(self, config):
                self.count = 0

            def train_epoch(self, *args, **kwargs):
                self.count = self.count + 1
                return {"count": self.count}

        runner = TorchRunner(model_creator,
                             create_dataloaders,
                             optimizer_creator,
                             loss_creator,
                             training_operator_cls=MockOperator)
        runner.setup()
        for _ in range(2):
            runner.train_epoch(num_steps=1)
        result = runner.train_epoch()

        operator = runner.training_operator
        self.assertEqual(operator.count, 3)
        self.assertEqual(result["count"], 3)
        self.assertEqual(runner.stats()["epoch"], 3)
예제 #15
0
    def testtrain_epoch(self):
        """Each runner ``train_epoch`` call reaches the operator exactly once.

        After three calls the operator count, the returned ``count`` and the
        returned ``epoch`` must all equal 3.
        """
        class MockOperator(self.Operator):
            def setup(self, config):
                super().setup(config)
                self.count = 0

            def train_epoch(self, *args, **kwargs):
                self.count = self.count + 1
                return {"count": self.count}

        runner = TorchRunner(training_operator_cls=MockOperator)
        runner.setup_operator()
        for _ in range(2):
            runner.train_epoch(num_steps=1)
        result = runner.train_epoch()

        operator = runner.training_operator
        self.assertEqual(operator.count, 3)
        self.assertEqual(result["count"], 3)
        self.assertEqual(result["epoch"], 3)
예제 #16
0
    def start_workers(self, num_workers):
        """Start the local worker and, when distributed, the remote workers.

        Args:
            num_workers (int): Total number of workers including the local
                one. ``1`` creates a single local ``TorchRunner``. Any other
                value creates a ``LocalDistributedRunner`` as rank 0 and asks
                ``remote_worker_group`` to initialize ``num_workers - 1``
                remote workers, then sets up the process group and operators.
        """
        # Lazy %-style arguments instead of an f-string with no placeholders.
        logger.debug("start_workers: Setting %d workers.", num_workers)

        if num_workers == 1:
            self.local_worker = TorchRunner(**self._params)
            if self._initialization_hook:
                self.apply_all_workers(self._initialization_hook)
            self.local_worker.setup_operator()
        else:

            # Start local worker
            # ``_dist_params`` is unpacked last, so it overrides ``_params``.
            self.local_worker = LocalDistributedRunner(
                num_cpus=self._num_cpus_per_worker,
                num_gpus=int(self._use_gpu),
                **{
                    **self._params,
                    **self._dist_params
                })
            self.remote_worker_group._init_dist_workers(num_workers - 1)
            if self._initialization_hook:
                self.apply_all_workers(self._initialization_hook)

            # Compute URL for initializing distributed PyTorch.
            address = setup_address()

            remote_pgs = self.remote_worker_group._setup_process_group(
                address=address, world_size=num_workers, starting_rank=1)
            # Use the local worker as rank 0. This will help with debugging.
            # ``timedelta``'s first positional argument is *days*; the timeout
            # attribute is in seconds, so it must be passed by keyword.
            self.local_worker.setup_process_group(
                url=address,
                world_rank=0,
                world_size=num_workers,
                timeout=timedelta(seconds=self._timeout_s))
            ray.get(remote_pgs)

            remote_operators = self.remote_worker_group._setup_operator()
            self.local_worker.setup_operator()
            ray.get(remote_operators)
예제 #17
0
    def _start_workers(self, num_workers):
        """Start the local worker and, when distributed, the remote workers.

        Each remote worker is placed on one ``(resource_name, amount)`` pair
        from ``self.extra_assigned_worker_res``.

        Args:
            num_workers (int): Total number of workers including the local
                one. It is also used as the process-group world size.
        """
        logger.debug(f"start_workers: Setting {num_workers} workers.")
        worker_config = self.config.copy()
        batch_size_per_worker = self._configure_and_split_batch(num_workers)
        if batch_size_per_worker:
            worker_config[BATCH_SIZE] = batch_size_per_worker

        params = dict(
            model_creator=self.model_creator,
            data_creator=self.data_creator,
            optimizer_creator=self.optimizer_creator,
            loss_creator=self.loss_creator,
            scheduler_creator=self.scheduler_creator,
            training_operator_cls=self.training_operator_cls,
            config=worker_config,
            use_fp16=self.use_fp16,
            # NOTE(review): hard-coded to True rather than ``self.use_gpu``;
            # presumably intended for fluid GPU placement, confirm.
            use_gpu=True,
            use_tqdm=self.use_tqdm,
            apex_args=self.apex_args,
            scheduler_step_freq=self.scheduler_step_freq,
        )

        if num_workers == 1:
            # Start local worker
            self.local_worker = TorchRunner(**params)
            self.apply_all_workers(_set_device_from_fluid_res)
            if self.initialization_hook:
                self.apply_all_workers(self.initialization_hook)
            self.local_worker.setup()
        else:
            params.update(
                backend=self.backend,
                add_dist_sampler=self.add_dist_sampler,
                wrap_ddp=self.wrap_ddp,
            )

            # Start local worker
            self.local_worker = LocalDistributedRunner(**params)

            # Start remote workers.
            # NOTE: the process group below uses world size ``num_workers``
            # with remote ranks 1..N-1, so this assumes
            # ``len(self.extra_assigned_worker_res) == num_workers - 1``.
            # That invariant is not checked here.
            self.remote_workers = []
            for res_name, res_val in self.extra_assigned_worker_res:
                # Generate actor class
                RemoteRunner = ray.remote(num_cpus=1,
                                          num_gpus=res_val,
                                          resources={res_name: res_val
                                                     })(DistributedTorchRunner)
                self.remote_workers.append(RemoteRunner.remote(**params))

            self.apply_all_workers(_set_device_from_fluid_res)
            if self.initialization_hook:
                self.apply_all_workers(self.initialization_hook)

            # Compute URL for initializing distributed PyTorch
            ip = ray.services.get_node_ip_address()
            port = self.local_worker.find_free_port()

            address = "tcp://{ip}:{port}".format(ip=ip, port=port)

            # Runs the creator functions.
            remote_component_setup = [
                worker.setup_components.remote()
                for worker in self.remote_workers
            ]
            self.local_worker.setup_components()
            # Get setup tasks in order to throw errors on failure
            ray.get(remote_component_setup)

            # Setup the process group among all workers.
            remote_pgroup_setups = [
                worker.setup_process_group.remote(address, i + 1, num_workers)
                for i, worker in enumerate(self.remote_workers)
            ]
            self.local_worker.setup_process_group(address, 0, num_workers)
            # Get setup tasks in order to throw errors on failure
            ray.get(remote_pgroup_setups)

            # Runs code that requires all creator functions to have run.
            remote_operator_setups = [
                worker.setup_ddp_and_operator.remote()
                for worker in self.remote_workers
            ]
            self.local_worker.setup_ddp_and_operator()
            # Get setup tasks in order to throw errors on failure
            ray.get(remote_operator_setups)
예제 #18
0
    def _start_workers(self, num_workers):
        """Start the local worker and, when distributed, the remote workers.

        Args:
            num_workers (int): Total number of workers including the local
                one. With exactly one worker a plain ``TorchRunner`` is set
                up locally. Otherwise a ``LocalDistributedRunner`` (rank 0)
                and ``num_workers - 1`` ``DistributedTorchRunner`` actors
                (ranks ``1 .. num_workers - 1``) join one process group
                before their operators are set up.
        """
        # Lazy %-style arguments instead of an f-string with no placeholders.
        logger.debug("start_workers: Setting %d workers.", num_workers)
        worker_config = self.config.copy()
        batch_size_per_worker = self._configure_and_split_batch(num_workers)
        if batch_size_per_worker:
            worker_config[BATCH_SIZE] = batch_size_per_worker

        params = dict(
            training_operator_cls=self.training_operator_cls,
            config=worker_config,
            serialize_data_creation=self.serialize_data_creation,
            use_fp16=self.use_fp16,
            use_gpu=self.use_gpu,
            use_tqdm=self.use_tqdm,
            apex_args=self.apex_args,
            scheduler_step_freq=self.scheduler_step_freq)

        if num_workers == 1:
            # Start local worker
            self.local_worker = TorchRunner(**params)
            if self.initialization_hook:
                self.apply_all_workers(self.initialization_hook)
            self.local_worker.setup_operator()
        else:
            params.update(
                backend=self.backend,
                add_dist_sampler=self.add_dist_sampler,
                wrap_ddp=self.wrap_ddp)

            # Start local worker
            self.local_worker = LocalDistributedRunner(
                num_cpus=self.num_cpus_per_worker,
                num_gpus=int(self.use_gpu),
                **params)

            # Generate actor class
            RemoteRunner = ray.remote(
                num_cpus=self.num_cpus_per_worker,
                num_gpus=int(self.use_gpu))(DistributedTorchRunner)
            # Start workers
            self.remote_workers = [
                RemoteRunner.remote(**params) for _ in range(num_workers - 1)
            ]
            if self.initialization_hook:
                self.apply_all_workers(self.initialization_hook)

            # Compute URL for initializing distributed PyTorch
            address = setup_address()

            # ``timedelta``'s first positional argument is *days*; the
            # ``timeout_s`` attribute is in seconds, so pass it by keyword.
            timeout = timedelta(seconds=self.timeout_s)

            # Setup the process group among all workers.
            remote_pgroup_setups = [
                worker.setup_process_group.remote(address, i + 1, num_workers,
                                                  timeout)
                for i, worker in enumerate(self.remote_workers)
            ]
            self.local_worker.setup_process_group(address, 0, num_workers,
                                                  timeout)
            # Get setup tasks in order to throw errors on failure
            ray.get(remote_pgroup_setups)

            # Runs code that requires all creator functions to have run.
            remote_operator_setups = [
                worker.setup_operator.remote()
                for worker in self.remote_workers
            ]
            self.local_worker.setup_operator()
            # Get setup tasks in order to throw errors on failure
            ray.get(remote_operator_setups)