def load_optimizer_instance_from_meta_data(self):
    """Rebuild an Optimizer instance from the saved YAML meta-data file.

    Reads the file at ``self.optimizer_meta_data_path``, parses it with
    ``yaml.safe_load`` into a plain dictionary, and delegates construction
    to the ``Optimizer.initialize_from_meta_data`` factory.
    """
    # Parse the meta-data YAML into a dictionary.
    with open(self.optimizer_meta_data_path) as meta_file:
        meta_data = yaml.safe_load(meta_file)

    # Hand the parsed dictionary to the Optimizer factory and return the result.
    return Optimizer.initialize_from_meta_data(meta_data)
def create_experiment_from_meta_data(self, key, model_meta_data, optimizer_meta_data, *, batch_size=64):
    """Create and initialize a fresh experiment (model + optimizer) from meta data.

    Args:
        key: PRNG key; split into the model-init key, the optimizer key, and
            the data-dependent-init key.
        model_meta_data: Dict describing the model. Must contain a 'model'
            entry naming a class registered in MODEL_LIST.
        optimizer_meta_data: Dict describing the optimizer, passed to
            Optimizer.initialize_from_meta_data.
        batch_size: Batch size used for data-dependent initialization
            (keyword-only; default 64 preserves the previous hard-coded value).

    Side effects:
        Sets self.current_iteration, self.model_init_key, self.opt_key,
        self.model, self.optimizer, and saves an initial (dummy) checkpoint.
    """
    self.current_iteration = 0

    # Create the keys
    self.model_init_key, self.opt_key, data_dependent_init_key = random.split(
        key, 3)

    # Create the model from its registered class
    model_name = model_meta_data['model']
    ModelClass = MODEL_LIST[model_name]
    model = ModelClass.initialize_from_meta_data(model_meta_data)

    # Get the data loader and check that the dataset matches the model's
    # expected input shape.
    # NOTE(review): assert is stripped under `python -O`; consider raising
    # ValueError instead if this check must always run.
    x_shape = self.get_data_loader(model.dataset_name)
    assert x_shape == model.x_shape, 'The dataset has the wrong dimensions! Has %s, expected %s' % (
        str(x_shape), str(model.x_shape))

    # Initialize the model. Use a fixed key so the build step is deterministic;
    # the actual parameter initialization uses self.model_init_key.
    init_key = random.PRNGKey(0)
    model.build_model(self.quantize_level_bits, init_key=init_key)
    model.initialize_model(self.model_init_key)

    # Do data dependent initialization
    model.data_dependent_init(data_dependent_init_key, self.data_loader, batch_size=batch_size)

    # Initialize the optimizer against the freshly built model
    optimizer = Optimizer.initialize_from_meta_data(optimizer_meta_data)
    optimizer.initialize(model)

    self.model = model
    self.optimizer = optimizer

    # Save a dummy first checkpoint so iteration 0 is always recoverable
    self.checkpoint_experiment(0, self.opt_key, np.array([]))