def load_run_config(self, print_info=False, dataset='C10+'):
    """Load the run configuration and wrap it in a ``RunConfig``.

    If ``self.run_config_path`` names an existing file, the configuration is
    read from it as JSON. Otherwise the default configuration for ``dataset``
    is obtained from ``RunConfig.get_default_run_config``.

    Args:
        print_info: When true, print every key/value pair of the loaded
            configuration before building the ``RunConfig``.
        dataset: Dataset name used to select the default configuration when
            no saved config file exists (e.g. ``'C10+'``). Ignored when the
            config file is present.

    Returns:
        A ``RunConfig`` constructed from the loaded mapping, whose keys are
        passed as keyword arguments.
    """
    if os.path.isfile(self.run_config_path):
        # Use a context manager so the handle is closed deterministically
        # rather than whenever the file object happens to be collected.
        # JSON text is UTF-8 by spec, so don't depend on the locale default.
        with open(self.run_config_path, 'r', encoding='utf-8') as config_file:
            run_config = json.load(config_file)
    else:
        print('Use Default Run Config for %s' % dataset)
        run_config = RunConfig.get_default_run_config(dataset)
    if print_info:
        print('Run config:')
        for k, v in run_config.items():
            print('\t%s: %s' % (k, v))
    return RunConfig(**run_config)
type=str, default='C10+', choices=['C10', 'C10+', 'C100', 'C100+'], ) parser.add_argument('--path', type=str, default='') parser.add_argument('--save_config', action='store_true', help='Whether to save config in the path') parser.add_argument('--save_init', action='store_true') parser.add_argument('--load_model', action='store_true') args = parser.parse_args() if args.dataset in ['C10', 'C100', 'C10+', 'C100+']: run_config_cifar['dataset'] = args.dataset run_config = RunConfig(**run_config_cifar) net_config = standard_net_config_cifar else: raise ValueError if len(args.path) == 0: args.path = '../trained_nets/DenseNet/vs=%s_%s_%s_L=%d_K=%d_%s' % \ (run_config.validation_size, os.uname()[1], net_config['model_type'], net_config['depth'], net_config['growth_rate'], run_config.dataset) if run_config.dataset in ['C10+', 'C100+']: net_config['keep_prob'] = 1.0 if standard_net_config_cifar['model_type'] == 'DenseNet': net_config['reduction'] = 1.0 if args.test: args.load_model = True # print configurations
'--dataset', type=str, default='C10+', choices=['C10', 'C10+', 'C100', 'C100+', 'SVHN', 'MNIST'], ) parser.add_argument('--path', type=str, default='') parser.add_argument('--save_config', action='store_true', help='Whether to save config in the path') parser.add_argument('--save_init', action='store_true') parser.add_argument('--load_model', action='store_true') args = parser.parse_args() if args.dataset in ['C10', 'C100', 'C10+', 'C100+']: run_config_cifar['dataset'] = args.dataset run_config = RunConfig(**run_config_cifar) elif args.dataset in ['SVHN']: run_config = RunConfig(**run_config_svhn) elif args.dataset in ['MNIST']: run_config = RunConfig(**run_config_mnist) else: raise ValueError if len(args.path) == 0: args.path = '../trained_nets/Convnet/vs=%s_Convnet_%s_%s_%s' % \ (run_config.validation_size, os.uname()[1], run_str, run_config.dataset) if args.test: args.load_model = True # print configurations print('Run config:') for k, v in run_config.get_config().items(): print('\t%s: %s' % (k, v))