コード例 #1
0
    def test_src_render(self):
        """Render the ``src`` templates once per optimizer choice and check the config text.

        The same template manager is rendered with ``optimizer='Momentum'``,
        with ``optimizer='SGD'`` and with no options at all. Each time, the
        first rendered file must equal the expected config literal.
        """
        manager = TemplateManager(os.path.join(self.template_dir, 'src'))

        momentum_expected = textwrap.dedent("""\
            {
                'num_classes': 10,
                'lr': 0.01,
                "momentum": 0.9,
                'epoch_size': 1
            }
            """)
        sgd_expected = textwrap.dedent("""\
            {
                'num_classes': 10,
                'lr': 0.1,
                'epoch_size': 1
            }
            """)
        default_expected = textwrap.dedent("""\
            {
                'num_classes': 10,
                'lr': 0.001,
                'epoch_size': 1
            }
            """)

        # (keyword arguments for render(), expected content of the first file)
        cases = (
            ({'optimizer': 'Momentum'}, momentum_expected),
            ({'optimizer': 'SGD'}, sgd_expected),
            ({}, default_expected),
        )
        for render_kwargs, expected in cases:
            rendered = manager.render(**render_kwargs)
            assert rendered[0].content == expected
コード例 #2
0
ファイル: imagenet.py プロジェクト: rock4you/mindinsight
 def set_network(self, network_maker):
     """Attach a network maker and build the template manager for this dataset.

     Templates are read from
     ``TEMPLATES_BASE_DIR/network/<network name>/dataset/<dataset name>``,
     with both names lower-cased.
     """
     self._network = network_maker
     sub_path = ('network', network_maker.name.lower(),
                 'dataset', self.name.lower())
     self.template_manager = TemplateManager(
         os.path.join(TEMPLATES_BASE_DIR, *sub_path))
コード例 #3
0
 def test_dataset_render(self):
     """Render the ``dataset`` templates and check the first file's content and paths."""
     dataset_dir = os.path.join(self.template_dir, 'dataset')
     rendered = TemplateManager(dataset_dir).render()
     first = rendered[0]

     expected_content = textwrap.dedent("""\
         import mindspore.dataset as ds
         import mindspore.dataset.transforms.vision.c_transforms as CV
         """)
     assert first.content == expected_content
     # The rendered file drops the '-tpl' suffix of its template.
     assert first.file_relative_path == 'mnist/dataset.py'
     expected_template = os.path.join(self.template_dir, 'dataset', 'mnist/dataset.py-tpl')
     assert first.template_file_path == expected_template
コード例 #4
0
ファイル: imagenet.py プロジェクト: rock4you/mindinsight
class Dataset(BaseDataset):
    """ImageNet dataset code generator.

    The template manager is created by ``set_network``, because the template
    directory depends on which network this dataset is paired with.
    """
    name = 'ImageNet'

    def __init__(self):
        super().__init__()
        # Both are bound later by set_network(); generate() relies on
        # template_manager having been set there.
        self._network = None
        self.template_manager = None

    def set_network(self, network_maker):
        """Attach a network maker and point the template manager at its dataset templates.

        Templates are read from
        ``TEMPLATES_BASE_DIR/network/<network name>/dataset/<dataset name>``,
        with both names lower-cased.

        Args:
            network_maker: The network code generator; its ``name`` attribute
                selects the template sub-directory.
        """
        self._network = network_maker
        template_dir = os.path.join(TEMPLATES_BASE_DIR,
                                    'network',
                                    network_maker.name.lower(),
                                    'dataset',
                                    self.name.lower())
        self.template_manager = TemplateManager(template_dir)

    def configure(self):
        """Configure the dataset options.

        Returns:
            The ``settings`` attribute of this dataset (presumably a dict
            provided by BaseDataset -- confirm against the base class).
        """
        return self.settings

    def generate(self, **options):
        """Render this dataset's templates.

        ``set_network`` must have been called first; until then
        ``template_manager`` is None.

        Args:
            **options: Template variables passed unchanged to
                ``TemplateManager.render``.

        Returns:
            The rendered source files returned by ``TemplateManager.render``.
        """
        source_files = self.template_manager.render(**options)
        return source_files
コード例 #5
0
    def test_assemble_render(self):
        """Render the top-level (assemble) templates and check every produced file.

        ``src`` and ``dataset`` are excluded, so only the standalone training
        script and ``train.py`` may be produced; any other rendered file
        makes the test fail.
        """
        manager = TemplateManager(self.template_dir,
                                  exclude_dirs=['src', 'dataset'])
        rendered_files = manager.render(loss='SoftmaxCrossEntropyWithLogits')

        # (template path, expected content, expected relative output path)
        expectations = (
            (
                os.path.join(self.template_dir, 'scripts/run_standalone_train.sh-tpl'),
                textwrap.dedent("""\
                    python train.py --dataset_path=$PATH1 --pre_trained=$PATH2 &> log &
                    """),
                'scripts/run_standalone_train.sh',
            ),
            (
                os.path.join(self.template_dir, 'train.py-tpl'),
                textwrap.dedent("""\
                    net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
                    """),
                'train.py',
            ),
        )

        leftovers = []
        for rendered in rendered_files:
            for tpl_path, content, relative_path in expectations:
                if rendered.template_file_path == tpl_path:
                    assert rendered.content == content
                    assert rendered.file_relative_path == relative_path
                    break
            else:
                leftovers.append(rendered)

        assert not leftovers
コード例 #6
0
    def test_template_files(self):
        """Check ``get_template_files`` with various directory and file exclusions."""
        src_file_num = 1
        dataset_file_num = 1

        def files_of(*args, **kwargs):
            # Build a fresh manager and list the template files it sees.
            return TemplateManager(*args, **kwargs).get_template_files()

        assert set(files_of(self.template_dir)) == set(self.all_template_files)
        total = len(self.all_template_files)

        assert len(files_of(os.path.join(self.template_dir, 'src'))) == src_file_num
        assert len(files_of(os.path.join(self.template_dir, 'dataset'))) == dataset_file_num

        without_src = files_of(self.template_dir, exclude_dirs=['src'])
        assert len(without_src) == total - src_file_num

        without_src_dataset = files_of(self.template_dir,
                                       exclude_dirs=['src', 'dataset'])
        assert len(without_src_dataset) == total - src_file_num - dataset_file_num

        without_train = files_of(self.template_dir,
                                 exclude_dirs=['src', 'dataset'],
                                 exclude_files=['train.py-tpl'])
        assert len(without_train) == total - src_file_num - dataset_file_num - 1
コード例 #7
0
 def __init__(self):
     """Create the template managers for the network-specific and shared templates."""
     self._dataset_maker = None
     network_root = os.path.join(TEMPLATES_BASE_DIR, 'network', self.name.lower())
     src_dir = os.path.join(network_root, 'src')
     # NOTE(review): the second positional argument is presumably the list of
     # excluded sub-directories (exclude_dirs) -- confirm against TemplateManager.
     excluded = ['src', 'dataset']
     self.network_template_manager = TemplateManager(src_dir)
     self.common_template_manager = TemplateManager(network_root, excluded)
コード例 #8
0
class GenericNetwork(BaseNetwork):
    """Generic network code generator.

    Subclasses set ``name`` and the ``supported_*`` choice lists. Templates
    are read from ``TEMPLATES_BASE_DIR/network/<name lower-cased>``.
    """
    name = 'GenericNetwork'
    supported_datasets = []
    supported_loss_functions = []
    supported_optimizers = []

    def __init__(self):
        self._dataset_maker = None
        template_dir = os.path.join(TEMPLATES_BASE_DIR, 'network', self.name.lower())
        # Network definition templates live in 'src'. The common manager gets
        # ['src', 'dataset'] as its second argument -- presumably the excluded
        # sub-directories, leaving only the assemble files (confirm).
        self.network_template_manager = TemplateManager(os.path.join(template_dir, 'src'))
        self.common_template_manager = TemplateManager(template_dir, ['src', 'dataset'])

    def configure(self, settings=None):
        """
        Configure the network options.

        If settings is non-empty, use the input settings to configure the network.
        Otherwise, ask the user interactively for the loss function, optimizer and
        dataset, and merge in the selected dataset maker's own configuration.
        In both cases the dataset maker is bound to this network.

        Args:
            settings (dict): Settings to configure, format is {'options': value}.
                Must contain a 'dataset' key.
                Example:
                    {
                        "loss": "SoftmaxCrossEntropyWithLogits",
                        "optimizer": "Momentum",
                        "dataset": "Cifar10"
                    }

        Returns:
            dict, configuration value to network.

        Raises:
            KeyError: If ``settings`` is given without a 'dataset' key.
        """
        if settings:
            config = dict(settings)
            dataset_name = settings['dataset']
            self._dataset_maker = load_dataset_maker(dataset_name)
        else:
            loss = self.ask_loss_function()
            optimizer = self.ask_optimizer()
            dataset_name = self.ask_dataset()
            self._dataset_maker = load_dataset_maker(dataset_name)
            dataset_config = self._dataset_maker.configure()

            config = {'loss': loss,
                      'optimizer': optimizer,
                      'dataset': dataset_name}
            config.update(dataset_config)
        self._dataset_maker.set_network(self)
        self.settings.update(config)
        return config

    @staticmethod
    def ask_choice(prompt_head, content_list, default_value=None):
        """Prompt the user to pick one entry of ``content_list``.

        Choices are displayed sorted and numbered starting from 1.

        Args:
            prompt_head (str): Heading shown above the numbered choices.
            content_list (Sequence): Available choices; it is not modified.
            default_value: Choice selected when the user just presses Enter.
                Defaults to the first element of ``content_list`` (before sorting).

        Returns:
            The selected element of ``content_list``.

        Raises:
            ValueError: If ``content_list`` is empty, or ``default_value`` is
                not one of its elements.
        """
        # Fail with a clear message instead of an IndexError on content_list[0];
        # an empty list is reachable, e.g. GenericNetwork's own supported_* lists.
        if not content_list:
            raise ValueError('%s: no choices are available.' % prompt_head)
        if default_value is None:
            default_value = content_list[0]

        choice_contents = sorted(content_list)
        # Prompt numbers start from 1.
        default_choice = choice_contents.index(default_value) + 1

        prompt_msg = '{}:\n{}\n'.format(
            prompt_head,
            '\n'.join(f'{idx: >4}: {choice}' for idx, choice in enumerate(choice_contents, start=1))
        )
        prompt_type = click.IntRange(min=1, max=len(choice_contents))
        choice = click.prompt(prompt_msg, type=prompt_type, hide_input=False, show_choices=False,
                              confirmation_prompt=False, default=default_choice,
                              value_proc=lambda x: process_prompt_choice(x, prompt_type))
        selected = choice_contents[choice - 1]
        click.secho('Your choice is %s.' % selected, fg='yellow')
        return selected

    def ask_loss_function(self):
        """Select loss function by user."""
        return self.ask_choice('%sPlease select a loss function' % QUESTION_START, self.supported_loss_functions)

    def ask_optimizer(self):
        """Select optimizer by user."""
        return self.ask_choice('%sPlease select an optimizer' % QUESTION_START, self.supported_optimizers)

    def ask_dataset(self):
        """Select dataset by user."""
        return self.ask_choice('%sPlease select a dataset' % QUESTION_START, self.supported_datasets)

    def generate(self, **options):
        """Generate network definition scripts.

        Renders three groups of files:
        - the network templates, using the merged generate context;
        - the dataset templates, using ``options`` only;
        - the common/assemble templates.
        The network and dataset files are placed under ``src/``.

        ``configure`` must be called first, because it creates the dataset maker.

        Returns:
            The network, dataset and assemble source files concatenated in that order.
        """
        context = self.get_generate_context(**options)
        network_source_files = self.network_template_manager.render(**context)
        for source_file in network_source_files:
            source_file.file_relative_path = os.path.join('src', source_file.file_relative_path)
        dataset_source_files = self._dataset_maker.generate(**options)
        for source_file in dataset_source_files:
            source_file.file_relative_path = os.path.join('src', source_file.file_relative_path)

        assemble_files = self._assemble(**options)
        source_files = network_source_files + dataset_source_files + assemble_files
        return source_files

    def get_generate_context(self, **options):
        """Get detailed info based on settings to network files.

        Returns a new dict of ``options`` overlaid with ``self.settings``;
        settings win on key conflicts.
        """
        context = dict(options)
        context.update(self.settings)
        return context

    def get_assemble_context(self, **options):
        """Get detailed info based on settings to assemble files.

        Returns a new dict of ``options`` overlaid with ``self.settings``;
        settings win on key conflicts.
        """
        context = dict(options)
        context.update(self.settings)
        return context

    def _assemble(self, **options):
        """Render the common templates (train.py, eval.py and assemble scripts)."""
        assemble_files = []
        context = self.get_assemble_context(**options)
        common_source_files = self.common_template_manager.render(**context)
        assemble_files.extend(common_source_files)
        return assemble_files