def __init__(self, resnet_size_choices=None):
        """Build the ResNet command-line parser.

        Combines the shared base, performance, image-model, export and
        benchmark parsers and then registers the ResNet-specific options
        ``--version``, ``--resnet_size`` and ``--no_lmk``.

        Args:
            resnet_size_choices: Optional collection of allowed values for
                ``--resnet_size``. When ``None`` any integer is accepted and
                the option is displayed with the ``<RS>`` metavar.
        """
        def _parse_bool(text):
            """Convert a command-line token into a real boolean.

            ``type=bool`` must not be used with argparse: ``bool('False')``
            is ``True`` because every non-empty string is truthy, so the
            flag could never be switched off explicitly.
            """
            token = text.strip().lower()
            if token in ('true', 't', 'yes', 'y', '1', 'on'):
                return True
            # The empty string is kept falsy, matching what bool('') returned.
            if token in ('false', 'f', 'no', 'n', '0', 'off', ''):
                return False
            # argparse turns a ValueError raised by a ``type`` callable into
            # a regular usage error for the offending argument.
            raise ValueError('expected a boolean value, got %r' % (text,))

        super(ResnetArgParser, self).__init__(parents=[
            parsers.BaseParser(),
            parsers.PerformanceParser(),
            parsers.ImageModelParser(),
            parsers.ExportParser(),
            parsers.BenchmarkParser(),
        ])

        self.add_argument(
            '--version', '-v', type=int, choices=[1, 2],
            default=resnet_model.DEFAULT_VERSION,
            help='Version of ResNet. (1 or 2) See README.md for details.'
        )

        self.add_argument(
            '--resnet_size', '-rs', type=int, default=50,
            choices=resnet_size_choices,
            help='[default: %(default)s] The size of the ResNet model to use.',
            # With explicit choices argparse lists them itself, so the
            # placeholder metavar is only needed for the unrestricted case.
            metavar='<RS>' if resnet_size_choices is None else None
        )

        self.add_argument(
            '--no_lmk', '-nolmk', type=_parse_bool, default=False,
            help='[default: %(default)s] Do not use landmark'
        )
Esempio n. 2
0
    def __init__(self):
        """Build the MNIST command-line parser.

        Combines the shared base, image-model and export parsers, then
        overrides the defaults for the data/model directories, the batch
        size and the number of training epochs.
        """
        parent_parsers = [
            parsers.BaseParser(),
            parsers.ImageModelParser(),
            parsers.ExportParser(),
        ]
        super(MNISTArgParser, self).__init__(parents=parent_parsers)

        # Defaults for options that are presumably declared by the parent
        # parsers above (they are not added in this method).
        mnist_defaults = {
            'data_dir': 'resources/mnist_data',
            'model_dir': 'resources/mnist_model',
            'batch_size': 100,
            'train_epochs': 40,
        }
        self.set_defaults(**mnist_defaults)
Esempio n. 3
0
 def __init__(self):
   """Build the wide/deep command-line parser.

   Extends the shared base parser with a ``--model_type`` option and
   overrides the defaults for the data, model and export directories,
   the number of training epochs and the batch size.
   """
   # Call the parent initializer, inheriting the base parser's options.
   super(WideDeepArgParser, self).__init__(parents=[parsers.BaseParser()])

   # Single source of truth for the accepted model types, so the help text
   # cannot drift from ``choices`` (it previously omitted 'BoostedTrees').
   model_types = ['wide', 'deep', 'wide_deep', "BoostedTrees"]
   self.add_argument(
       '--model_type', '-mt', type=str,
       default='BoostedTrees',  # default model type is BoostedTrees
       choices=model_types,     # allowed values for this option
       help=('[default %(default)s] Valid model types: '
             + ', '.join(model_types) + '.'),
       metavar='<MT>')

   # Override the defaults of the remaining options.
   self.set_defaults(
       data_dir='income_data',         # path of the data samples
       model_dir='income_model',       # where the model is stored
       export_dir='income_model_exp',  # where the exported model is stored
       train_epochs=5,                 # number of training epochs
       batch_size=40)                  # batch size