Esempio n. 1
0
    def test_sync_strategy(self):
        """Check the strategy built by ``StrategyFactory.create_sync_strategy``.

        Verifies the program/build/execute settings the factory produces, and
        that ``set_program_config`` accepts either a
        ``DistributeTranspilerConfig`` or a plain dict, while raising on a
        dict with an unknown key or on an illegal (``None``) argument.
        """
        from unittest import mock

        # Pin CPU_NUM for this test only. patch.dict restores os.environ when
        # the test's cleanups run, so the value does not leak into other tests.
        env_patcher = mock.patch.dict(os.environ, {'CPU_NUM': "2"})
        env_patcher.start()
        self.addCleanup(env_patcher.stop)

        strategy = StrategyFactory.create_sync_strategy()
        self.assertEqual(strategy._program_config.sync_mode, False)
        self.assertEqual(strategy._program_config.runtime_split_send_recv,
                         True)
        self.assertEqual(strategy._build_strategy.async_mode, True)
        # Presumably derived from the CPU_NUM value pinned above.
        self.assertEqual(strategy._execute_strategy.num_threads, 2)

        # test set_program_config using DistributeTranspilerConfig()
        program_config_class = DistributeTranspilerConfig()
        program_config_class.min_block_size = 81920
        strategy.set_program_config(program_config_class)
        program_config = strategy.get_program_config()
        self.assertEqual(program_config.min_block_size, 81920)

        # test set_program_config using dict
        program_config_dict = dict()
        program_config_dict['min_block_size'] = 8192
        strategy.set_program_config(program_config_dict)
        program_config = strategy.get_program_config()
        self.assertEqual(program_config.min_block_size, 8192)

        # test set_program_config exception
        # The dict still holds the valid 'min_block_size' key; only the added
        # 'unknown' key is expected to trigger the error.
        program_config_dict['unknown'] = None
        self.assertRaises(Exception, strategy.set_program_config,
                          program_config_dict)
        program_config_illegal = None
        self.assertRaises(Exception, strategy.set_program_config,
                          program_config_illegal)
Esempio n. 2
0
    def test_sync_strategy(self):
        """Check the strategy built by ``StrategyFactory.create_sync_strategy``.

        Verifies the program/build/execute settings the factory produces; that
        ``set_program_config`` accepts either a ``DistributeTranspilerConfig``
        or a plain dict while raising on an unknown key or an illegal
        (``None``) argument; and which communicator flags the trainer runtime
        config reports for this strategy.
        """
        from unittest import mock

        # Pin CPU_NUM for this test only. patch.dict restores os.environ when
        # the test's cleanups run, so the value does not leak into other tests.
        # The patch stays active for the whole test, including the
        # get_communicator_flags() call below.
        env_patcher = mock.patch.dict(os.environ, {'CPU_NUM': "2"})
        env_patcher.start()
        self.addCleanup(env_patcher.stop)

        strategy = StrategyFactory.create_sync_strategy()
        self.assertEqual(strategy._program_config.sync_mode, False)
        self.assertEqual(strategy._program_config.runtime_split_send_recv,
                         True)
        self.assertEqual(strategy._build_strategy.async_mode, True)
        # Presumably derived from the CPU_NUM value pinned above.
        self.assertEqual(strategy._execute_strategy.num_threads, 2)

        # test set_program_config using DistributeTranspilerConfig()
        program_config_class = DistributeTranspilerConfig()
        program_config_class.min_block_size = 81920
        strategy.set_program_config(program_config_class)
        program_config = strategy.get_program_config()
        self.assertEqual(program_config.min_block_size, 81920)

        # test set_program_config using dict
        program_config_dict = dict()
        program_config_dict['min_block_size'] = 8192
        strategy.set_program_config(program_config_dict)
        program_config = strategy.get_program_config()
        self.assertEqual(program_config.min_block_size, 8192)

        # test set_program_config exception
        # The dict still holds the valid 'min_block_size' key; only the added
        # 'unknown' key is expected to trigger the error.
        program_config_dict['unknown'] = None
        self.assertRaises(Exception, strategy.set_program_config,
                          program_config_dict)
        program_config_illegal = None
        self.assertRaises(Exception, strategy.set_program_config,
                          program_config_illegal)

        # test trainer runtime communicator flags
        trainer_runtime_config = strategy.get_trainer_runtime_config()
        trainer_runtime_config.runtime_configs[
            'communicator_send_queue_size'] = '50'
        runtime_configs = trainer_runtime_config.get_communicator_flags()
        self.assertIn('communicator_send_queue_size', runtime_configs)
        self.assertNotIn('communicator_independent_recv_thread',
                         runtime_configs)
        # The user-set '50' is expected to be replaced; the expected '2'
        # appears to mirror the pinned CPU_NUM -- NOTE(review): confirm.
        self.assertEqual(runtime_configs['communicator_send_queue_size'], '2')