def __init__(self, resnet_size_choices=None):
    """Build the command-line parser for ResNet runs.

    Combines the option groups of the shared ``parsers`` parent parsers and
    adds the ResNet-specific ``--version``, ``--resnet_size`` and
    ``--no_lmk`` options.

    Args:
      resnet_size_choices: Optional collection of permitted ``--resnet_size``
        values. When ``None`` any integer is accepted and the help output
        shows the ``<RS>`` placeholder instead of an enumerated choice list.
    """
    import argparse  # local import: only needed for ArgumentTypeError below

    def _str2bool(value):
        """Convert a command-line token such as 'true'/'false' to a bool.

        ``type=bool`` cannot be used for this: argparse would call
        ``bool(value)`` on the raw string, and every non-empty string,
        including 'False', is truthy.
        """
        lowered = value.strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError(
            'expected a boolean value (true/false), got %r' % value)

    super(ResnetArgParser, self).__init__(parents=[
        parsers.BaseParser(),
        parsers.PerformanceParser(),
        parsers.ImageModelParser(),
        parsers.ExportParser(),
        parsers.BenchmarkParser(),
    ])

    self.add_argument(
        '--version', '-v', type=int, choices=[1, 2],
        default=resnet_model.DEFAULT_VERSION,
        help='[default: %(default)s] Version of ResNet. (1 or 2) '
             'See README.md for details.'
    )

    self.add_argument(
        '--resnet_size', '-rs', type=int, default=50,
        choices=resnet_size_choices,
        help='[default: %(default)s] The size of the ResNet model to use.',
        # With explicit choices argparse lists them itself; otherwise show
        # a short placeholder in the usage line.
        metavar='<RS>' if resnet_size_choices is None else None
    )

    # nargs='?' + const=True: the bare flag means True, while an explicit
    # value ("--no_lmk False", "--no_lmk True") is parsed by _str2bool.
    self.add_argument(
        '--no_lmk', '-nolmk', type=_str2bool, nargs='?', const=True,
        default=False,
        help='[default: %(default)s] Do not use landmark. Accepts '
             'true/false (yes/no, 1/0); the bare flag means true.'
    )
def __init__(self):
    """Set up the MNIST command-line parser.

    Pulls in the base, image-model and export option groups as parent
    parsers, then sets MNIST-specific defaults for the data and model
    directories, batch size and number of training epochs.
    """
    shared_groups = [
        parsers.BaseParser(),
        parsers.ImageModelParser(),
        parsers.ExportParser(),
    ]
    super(MNISTArgParser, self).__init__(parents=shared_groups)

    mnist_defaults = dict(
        data_dir='resources/mnist_data',
        model_dir='resources/mnist_model',
        batch_size=100,
        train_epochs=40,
    )
    self.set_defaults(**mnist_defaults)
def __init__(self):
    """Build the command-line parser for the wide & deep income example.

    Inherits the shared options of ``parsers.BaseParser``, adds
    ``--model_type``, and sets example-specific defaults for the data,
    model and export directories, epoch count and batch size.
    """
    # Initialise the parent parser with the shared base options.
    super(WideDeepArgParser, self).__init__(parents=[parsers.BaseParser()])

    # Single source of truth for the permitted model types, so the help
    # text below can never drift out of sync with the accepted choices.
    model_types = ['wide', 'deep', 'wide_deep', 'BoostedTrees']
    # --model_type selects which model type to build; default 'BoostedTrees'.
    self.add_argument(
        '--model_type', '-mt', type=str, default='BoostedTrees',
        choices=model_types,
        help='[default %(default)s] Valid model types: '
             + ', '.join(model_types) + '.',
        metavar='<MT>')

    # Defaults for options presumably declared by BaseParser — verify there.
    self.set_defaults(
        data_dir='income_data',          # path of the input data samples
        model_dir='income_model',        # where the trained model is stored
        export_dir='income_model_exp',   # where the exported model is written
        train_epochs=5,                  # number of training epochs
        batch_size=40)                   # batch size