def test_src_render(self):
    """Render the templates in the ``src`` directory for each optimizer choice.

    The first rendered file must match the expected config text for
    'Momentum', 'SGD', and for rendering without any optimizer option.
    """
    template_mgr = TemplateManager(os.path.join(self.template_dir, 'src'))
    momentum_config = textwrap.dedent("""\
        {
            'num_classes': 10,
            'lr': 0.01,
            "momentum": 0.9,
            'epoch_size': 1
        }
        """)
    sgd_config = textwrap.dedent("""\
        {
            'num_classes': 10,
            'lr': 0.1,
            'epoch_size': 1
        }
        """)
    default_config = textwrap.dedent("""\
        {
            'num_classes': 10,
            'lr': 0.001,
            'epoch_size': 1
        }
        """)
    # An empty options dict means render() is called with no keyword arguments.
    cases = (
        ({'optimizer': 'Momentum'}, momentum_config),
        ({'optimizer': 'SGD'}, sgd_config),
        ({}, default_config),
    )
    for render_options, expected in cases:
        rendered = template_mgr.render(**render_options)
        assert rendered[0].content == expected
def set_network(self, network_maker):
    """Attach ``network_maker`` and build the template manager for this dataset.

    Templates are read from
    ``TEMPLATES_BASE_DIR/network/<network_maker.name>/dataset/<self.name>``,
    with both names lower-cased.
    """
    self._network = network_maker
    network_name = network_maker.name.lower()
    dataset_name = self.name.lower()
    self.template_manager = TemplateManager(
        os.path.join(TEMPLATES_BASE_DIR, 'network', network_name, 'dataset', dataset_name)
    )
def test_dataset_render(self):
    """Render the ``dataset`` templates and check content and both file paths."""
    dataset_dir = os.path.join(self.template_dir, 'dataset')
    rendered = TemplateManager(dataset_dir).render()
    first = rendered[0]
    assert first.content == textwrap.dedent("""\
        import mindspore.dataset as ds
        import mindspore.dataset.transforms.vision.c_transforms as CV
        """)
    assert first.file_relative_path == 'mnist/dataset.py'
    # The template path keeps the '-tpl' suffix; the output path drops it.
    assert first.template_file_path == os.path.join(dataset_dir, 'mnist/dataset.py-tpl')
class Dataset(BaseDataset):
    """Code generator for the ImageNet dataset."""

    name = 'ImageNet'

    def __init__(self):
        super().__init__()
        # Both attributes are filled in by set_network(); generate() needs
        # template_manager to be set first.
        self._network = None
        self.template_manager = None

    def set_network(self, network_maker):
        """Attach ``network_maker`` and build the template manager for this dataset.

        Templates are read from
        ``TEMPLATES_BASE_DIR/network/<network_maker.name>/dataset/<self.name>``,
        with both names lower-cased.
        """
        self._network = network_maker
        network_name = network_maker.name.lower()
        dataset_name = self.name.lower()
        self.template_manager = TemplateManager(
            os.path.join(TEMPLATES_BASE_DIR, 'network', network_name, 'dataset', dataset_name)
        )

    def configure(self):
        """Return this dataset's settings.

        NOTE(review): ``self.settings`` is not assigned in this class; it is
        presumably initialised by BaseDataset.__init__ — confirm.
        """
        return self.settings

    def generate(self, **options):
        """Render this dataset's templates with ``options`` and return the source files."""
        return self.template_manager.render(**options)
def test_assemble_render(self):
    """Render the top-level (assemble) templates, skipping 'src' and 'dataset'.

    Every rendered file must be one of the two expected templates, with the
    expected content and output path.
    """
    template_mgr = TemplateManager(self.template_dir, exclude_dirs=['src', 'dataset'])
    source_files = template_mgr.render(
        loss='SoftmaxCrossEntropyWithLogits')
    # (template path, expected rendered content, expected output path),
    # checked in this order for each rendered file.
    expected_files = [
        (
            os.path.join(self.template_dir, 'scripts/run_standalone_train.sh-tpl'),
            textwrap.dedent("""\
                python train.py --dataset_path=$PATH1 --pre_trained=$PATH2 &> log &
                """),
            'scripts/run_standalone_train.sh',
        ),
        (
            os.path.join(self.template_dir, 'train.py-tpl'),
            textwrap.dedent("""\
                net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
                """),
            'train.py',
        ),
    ]
    unmatched_files = []
    for source_file in source_files:
        for template_path, expected_content, expected_path in expected_files:
            if source_file.template_file_path == template_path:
                assert source_file.content == expected_content
                assert source_file.file_relative_path == expected_path
                break
        else:
            # No expectation matched this template.
            unmatched_files.append(source_file)
    assert not unmatched_files
def test_template_files(self):
    """Check get_template_files() for a sub-directory and for the exclusion options."""
    src_file_num = 1
    dataset_file_num = 1

    def count_files(template_dir, **kwargs):
        # Build a fresh manager for every case, as each assertion expects.
        return len(TemplateManager(template_dir, **kwargs).get_template_files())

    everything = TemplateManager(self.template_dir).get_template_files()
    assert set(everything) == set(self.all_template_files)

    assert count_files(os.path.join(self.template_dir, 'src')) == src_file_num
    assert count_files(os.path.join(self.template_dir, 'dataset')) == dataset_file_num

    total_num = len(self.all_template_files)
    assert count_files(self.template_dir, exclude_dirs=['src']) == total_num - src_file_num
    assert (count_files(self.template_dir, exclude_dirs=['src', 'dataset'])
            == total_num - src_file_num - dataset_file_num)
    # Excluding one more file ('train.py-tpl') drops the count by exactly one.
    assert (count_files(self.template_dir,
                        exclude_dirs=['src', 'dataset'],
                        exclude_files=['train.py-tpl'])
            == total_num - src_file_num - dataset_file_num - 1)
def __init__(self):
    # Filled in later by configure().
    self._dataset_maker = None
    network_dir = os.path.join(TEMPLATES_BASE_DIR, 'network', self.name.lower())
    src_dir = os.path.join(network_dir, 'src')
    # Templates under 'src' make up the network definition.
    self.network_template_manager = TemplateManager(src_dir)
    # The remaining templates; ['src', 'dataset'] is passed as the second
    # argument (presumably exclude_dirs).
    self.common_template_manager = TemplateManager(network_dir, ['src', 'dataset'])
class GenericNetwork(BaseNetwork):
    """Code generator for a network, its dataset files and the shared scripts.

    Templates are looked up under ``TEMPLATES_BASE_DIR/network/<name>``
    (lower-cased).

    NOTE(review): the ``supported_*`` lists are empty here, so the interactive
    prompts would fail with IndexError on this class. Subclasses presumably
    override ``name`` and these lists — confirm.
    """

    name = 'GenericNetwork'
    supported_datasets = []
    supported_loss_functions = []
    supported_optimizers = []

    def __init__(self):
        # NOTE(review): BaseNetwork.__init__ is not called; configure() relies
        # on ``self.settings`` being provided by BaseNetwork — confirm.
        # Filled in later by configure().
        self._dataset_maker = None
        network_dir = os.path.join(TEMPLATES_BASE_DIR, 'network', self.name.lower())
        # Templates under 'src' make up the network definition.
        self.network_template_manager = TemplateManager(os.path.join(network_dir, 'src'))
        # The remaining templates; ['src', 'dataset'] is passed as the second
        # argument (presumably exclude_dirs).
        self.common_template_manager = TemplateManager(network_dir, ['src', 'dataset'])

    def configure(self, settings=None):
        """
        Configure the network options.

        With non-empty settings, those values are used as they are. Otherwise
        the user is prompted for a loss function, an optimizer and a dataset.
        The chosen dataset maker is then asked for its own options, which are
        merged in.

        Args:
            settings (dict): Settings to configure, format is {'options': value}.

                Example:
                    {
                        "loss": "SoftmaxCrossEntropyWithLogits",
                        "optimizer": "Momentum",
                        "dataset": "Cifar10"
                    }

        Returns:
            dict, configuration value to network. It is also merged into
            ``self.settings``.
        """
        if not settings:
            # The dict literal evaluates left to right, so the prompts appear
            # in the order loss, optimizer, dataset.
            config = {
                'loss': self.ask_loss_function(),
                'optimizer': self.ask_optimizer(),
                'dataset': self.ask_dataset(),
            }
            self._dataset_maker = load_dataset_maker(config['dataset'])
            config.update(self._dataset_maker.configure())
        else:
            config = dict(settings)
            self._dataset_maker = load_dataset_maker(settings['dataset'])
        self._dataset_maker.set_network(self)
        self.settings.update(config)
        return config

    @staticmethod
    def ask_choice(prompt_head, content_list, default_value=None):
        """Prompt the user to pick one entry of ``content_list`` and return it.

        Entries are shown sorted and numbered from 1. When ``default_value``
        is None, the default is the first entry of ``content_list`` in its
        given order, not the sorted one.
        """
        if default_value is None:
            default_value = content_list[0]
        options = content_list[:]
        options.sort()
        # Prompt numbers are 1-based.
        default_number = options.index(default_value) + 1
        menu = '\n'.join(f'{number: >4}: {option}'
                         for number, option in enumerate(options, start=1))
        prompt_msg = f'{prompt_head}:\n{menu}\n'
        number_range = click.IntRange(min=1, max=len(options))
        picked = click.prompt(
            prompt_msg,
            type=number_range,
            hide_input=False,
            show_choices=False,
            confirmation_prompt=False,
            default=default_number,
            value_proc=lambda raw: process_prompt_choice(raw, number_range),
        )
        selection = options[picked - 1]
        click.secho(textwrap.dedent("Your choice is %s." % selection), fg='yellow')
        return selection

    def ask_loss_function(self):
        """Prompt the user for one of ``supported_loss_functions``."""
        return self.ask_choice('%sPlease select a loss function' % QUESTION_START,
                               self.supported_loss_functions)

    def ask_optimizer(self):
        """Prompt the user for one of ``supported_optimizers``."""
        return self.ask_choice('%sPlease select an optimizer' % QUESTION_START,
                               self.supported_optimizers)

    def ask_dataset(self):
        """Prompt the user for one of ``supported_datasets``."""
        return self.ask_choice('%sPlease select a dataset' % QUESTION_START,
                               self.supported_datasets)

    def generate(self, **options):
        """Generate network definition scripts.

        Returns the rendered network files and the dataset maker's files (both
        re-rooted under 'src/'), followed by the common assembled files.
        """
        def move_under_src(files):
            # Prefix every file's output path with 'src' in place.
            for item in files:
                item.file_relative_path = os.path.join('src', item.file_relative_path)
            return files

        context = self.get_generate_context(**options)
        network_files = move_under_src(self.network_template_manager.render(**context))
        # The dataset maker receives the raw options, not the merged context.
        dataset_files = move_under_src(self._dataset_maker.generate(**options))
        return network_files + dataset_files + self._assemble(**options)

    def get_generate_context(self, **options):
        """Return ``options`` overlaid with ``self.settings``; settings win on key clashes."""
        return {**options, **self.settings}

    def get_assemble_context(self, **options):
        """Return ``options`` overlaid with ``self.settings``; settings win on key clashes."""
        return {**options, **self.settings}

    def _assemble(self, **options):
        """Render the common templates (train.py, eval.py and assemble scripts) as a list."""
        context = self.get_assemble_context(**options)
        return list(self.common_template_manager.render(**context))