def _eval_model_on_split(self,
                         split: str,
                         num_examples: int,
                         global_batch_size: int,
                         params: spec.ParameterContainer,
                         model_state: spec.ModelAuxiliaryState,
                         rng: spec.RandomState,
                         data_dir: str) -> Dict[str, float]:
  """Run a full evaluation of the model."""
  data_rng, model_rng = prng.split(rng, 2)
  if split not in self._eval_iters:
    eval_iter = self.build_input_queue(
        data_rng, split, data_dir, global_batch_size=global_batch_size)
    # Note that this stores the entire eval dataset in memory.
    self._eval_iters[split] = itertools.cycle(eval_iter)
  total_metrics = {
      'accuracy': 0.,
      'loss': 0.,
  }
  num_data = 0
  num_batches = int(math.ceil(num_examples / global_batch_size))
  for bi, batch in enumerate(self._eval_iters[split]):
    # Stop after exactly `num_batches` batches so that no example from the
    # cycled iterator is counted twice.
    if bi >= num_batches:
      break
    per_device_model_rngs = prng.split(model_rng, jax.local_device_count())
    batch_metrics = self._eval_model(params, batch, model_state,
                                     per_device_model_rngs)
    total_metrics = {
        k: v + batch_metrics[k] for k, v in total_metrics.items()
    }
    num_data += batch_metrics['num_data']
  return {k: float(v / num_data) for k, v in total_metrics.items()}
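The aggregation pattern above relies on each per-batch helper (`self._eval_model` here, `self._eval_metric` in the variants below) returning summed metrics together with the number of examples in the batch, so that dividing by `num_data` at the end yields dataset-wide averages even when the final batch is smaller. A minimal sketch of such a helper, written as a standalone JAX function with hypothetical names rather than the workload classes' actual methods:

import jax
import jax.numpy as jnp


def eval_metric_sketch(logits: jnp.ndarray, labels: jnp.ndarray) -> dict:
  """Hypothetical per-batch metrics: summed, not averaged, over the batch."""
  # Cross-entropy summed over the batch; `labels` holds integer class ids.
  log_probs = jax.nn.log_softmax(logits)
  per_example_loss = -jnp.take_along_axis(
      log_probs, labels[:, None], axis=-1)[:, 0]
  # Number of correct predictions in the batch.
  num_correct = jnp.sum(jnp.argmax(logits, axis=-1) == labels)
  return {
      'accuracy': num_correct,          # summed; divided by num_data later
      'loss': jnp.sum(per_example_loss),
      'num_data': logits.shape[0],
  }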
def eval_model(self,
               params: spec.ParameterContainer,
               model_state: spec.ModelAuxiliaryState,
               rng: spec.RandomState,
               data_dir: str):
  """Run a full evaluation of the model."""
  data_rng, model_rng = prng.split(rng, 2)
  eval_batch_size = 128
  if self._eval_ds is None:
    self._eval_ds = self._build_dataset(
        data_rng, 'test', data_dir, batch_size=eval_batch_size)
  total_metrics = {
      'accuracy': 0.,
      'loss': 0.,
  }
  n_data = 0
  for (images, labels) in self._eval_ds:
    images = images.float().to(DEVICE)
    labels = labels.float().to(DEVICE)
    logits, _ = self.model_fn(
        params,
        images,
        model_state,
        spec.ForwardPassMode.EVAL,
        model_rng,
        update_batch_norm=False)
    batch_metrics = self._eval_metric(logits, labels)
    total_metrics = {
        k: v + batch_metrics[k] for k, v in total_metrics.items()
    }
    n_data += batch_metrics['n_data']
  return {k: float(v / n_data) for k, v in total_metrics.items()}
def eval_model(self,
               params: spec.ParameterContainer,
               model_state: spec.ModelAuxiliaryState,
               rng: spec.RandomState,
               data_dir: str):
  """Run a full evaluation of the model."""
  data_rng, model_rng = prng.split(rng, 2)
  eval_batch_size = 2000
  self._eval_ds = self.build_input_queue(
      data_rng, 'test', data_dir, batch_size=eval_batch_size)
  total_metrics = {
      'accuracy': 0.,
      'loss': 0.,
  }
  n_data = 0
  for (images, labels) in self._eval_ds:
    images, labels = self.preprocess_for_eval(images, labels, None, None)
    logits, _ = self.model_fn(
        params,
        images,
        model_state,
        spec.ForwardPassMode.EVAL,
        model_rng,
        update_batch_norm=False)
    # TODO(znado): add additional eval metrics?
    batch_metrics = self._eval_metric(logits, labels)
    total_metrics = {
        k: v + batch_metrics[k] for k, v in total_metrics.items()
    }
    n_data += batch_metrics['n_data']
  return {k: float(v / n_data) for k, v in total_metrics.items()}
def _eval_model_on_split(self,
                         split: str,
                         num_examples: int,
                         global_batch_size: int,
                         params: spec.ParameterContainer,
                         model_state: spec.ModelAuxiliaryState,
                         rng: spec.RandomState,
                         data_dir: str) -> Dict[str, float]:
  """Run a full evaluation of the model."""
  data_rng, model_rng = prng.split(rng, 2)
  if split not in self._eval_iters:
    eval_iter = self.build_input_queue(
        data_rng, split, data_dir, global_batch_size=global_batch_size)
    # Note that this stores the entire val dataset in memory.
    self._eval_iters[split] = itertools.cycle(eval_iter)
  total_metrics = None
  num_eval_steps = int(math.ceil(float(num_examples) / global_batch_size))
  # Loop over graph batches in eval dataset.
  for _ in range(num_eval_steps):
    batch = next(self._eval_iters[split])
    batch_metrics = self._eval_batch(params, batch, model_state, model_rng)
    total_metrics = (
        batch_metrics
        if total_metrics is None else total_metrics.merge(batch_metrics))
  if total_metrics is None:
    return {}
  return {k: float(v) for k, v in total_metrics.reduce().compute().items()}
def _eval_model_on_split(self,
                         split: str,
                         num_examples: int,
                         global_batch_size: int,
                         params: spec.ParameterContainer,
                         model_state: spec.ModelAuxiliaryState,
                         rng: spec.RandomState,
                         data_dir: str):
  """Run a full evaluation of the model."""
  data_rng, model_rng = prng.split(rng, 2)
  if split not in self._eval_iters:
    # These iterators repeat indefinitely.
    self._eval_iters[split] = self.build_input_queue(
        data_rng, split, data_dir, global_batch_size=global_batch_size)
  total_metrics = {
      'accuracy': 0.,
      'loss': 0.,
  }
  num_data = 0
  num_batches = int(math.ceil(num_examples / global_batch_size))
  for _ in range(num_batches):
    batch = next(self._eval_iters[split])
    logits, _ = self.model_fn(
        params,
        batch,
        model_state,
        spec.ForwardPassMode.EVAL,
        model_rng,
        update_batch_norm=False)
    batch_metrics = self._eval_metric(logits, batch['targets'])
    total_metrics = {
        k: v + batch_metrics[k] for k, v in total_metrics.items()
    }
    num_data += batch_metrics['num_data']
  return {k: float(v / num_data) for k, v in total_metrics.items()}
def eval_model(self,
               params: spec.ParameterContainer,
               model_state: spec.ModelAuxiliaryState,
               rng: spec.RandomState,
               data_dir: str):
  """Run a full evaluation of the model."""
  data_rng, model_rng = prng.split(rng, 2)
  total_eval_batch_size = 8192
  if self._eval_iterator is None:
    self._eval_iterator = self._build_iterator(
        data_rng, 'validation', data_dir, batch_size=total_eval_batch_size)
    # Note that this effectively stores the entire val dataset in memory.
    self._eval_iterator = itertools.cycle(self._eval_iterator)
  total_metrics = None
  # Both val and test have the same (prime) number of examples.
  num_val_examples = 43793
  num_val_steps = num_val_examples // total_eval_batch_size + 1
  # Loop over graph batches in eval dataset.
  for _ in range(num_val_steps):
    graphs, labels, masks = next(self._eval_iterator)
    batch_metrics = self._eval_batch(params, graphs, labels, masks,
                                     model_state, model_rng)
    total_metrics = (
        batch_metrics
        if total_metrics is None else total_metrics.merge(batch_metrics))
  if total_metrics is None:
    return {}
  return {k: float(v) for k, v in total_metrics.reduce().compute().items()}
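The two graph-workload evaluators above do not accumulate plain dicts: `_eval_batch` is assumed to return a metrics bundle that supports `merge`, `reduce`, and `compute`, which matches the interface of a `clu.metrics.Collection`. A hedged sketch of what such a bundle could look like, with illustrative metric names that are not taken from the repository:

import flax
from clu import metrics


@flax.struct.dataclass
class EvalMetrics(metrics.Collection):
  """Illustrative metrics bundle with merge()/reduce()/compute() support."""
  accuracy: metrics.Average.from_output('accuracy')
  loss: metrics.Average.from_output('loss')


# Hypothetical usage inside an _eval_batch-style helper:
# batch_metrics = EvalMetrics.single_from_model_output(
#     accuracy=per_example_correct, loss=per_example_loss)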
def score_submission_on_workload(workload: spec.Workload,
                                 workload_name: str,
                                 submission_path: str,
                                 data_dir: str,
                                 tuning_ruleset: str,
                                 tuning_search_space: Optional[str] = None,
                                 num_tuning_trials: Optional[int] = None):
  # Remove the trailing '.py' and convert the filepath to a Python module.
  submission_module_path = _convert_filepath_to_module(submission_path)
  submission_module = importlib.import_module(submission_module_path)

  init_optimizer_state = submission_module.init_optimizer_state
  update_params = submission_module.update_params
  data_selection = submission_module.data_selection
  get_batch_size = submission_module.get_batch_size
  batch_size = get_batch_size(workload_name)

  if tuning_ruleset == 'external':
    # If the submission runner is responsible for hyperparameter tuning, load
    # the search space and generate a list of randomly selected hyperparameter
    # settings from it.
    if tuning_search_space is None:
      raise ValueError(
          'Must provide a tuning search space JSON file when using external '
          'tuning.')
    with open(tuning_search_space, 'r') as search_space_file:
      tuning_search_space = halton.generate_search(
          json.load(search_space_file), num_tuning_trials)
    all_timings = []
    all_metrics = []
    for hi, hyperparameters in enumerate(tuning_search_space):
      # Generate a new seed from hardware sources of randomness for each trial.
      rng_seed = struct.unpack('I', os.urandom(4))[0]
      rng = prng.PRNGKey(rng_seed)
      # Because we initialize the PRNGKey with only a single 32 bit int, in the
      # Jax implementation rng[0] is all zeros, which could lead to
      # unintentionally reusing the same seed if only rng[0] were ever used. By
      # splitting the rng into 2, we mix the lower and upper 32 bit ints,
      # ensuring we can safely use either rng[0] or rng[1] as a random number.
      rng, _ = prng.split(rng, 2)
      logging.info(f'--- Tuning run {hi + 1}/{num_tuning_trials} ---')
      timing, metrics = train_once(workload, batch_size, data_dir,
                                   init_optimizer_state, update_params,
                                   data_selection, hyperparameters, rng)
      all_timings.append(timing)
      all_metrics.append(metrics)
    score = min(all_timings)
    for ti in range(num_tuning_trials):
      logging.info('Tuning trial %d/%d', ti + 1, num_tuning_trials)
      logging.info('Hyperparameters: %s', tuning_search_space[ti])
      logging.info('Metrics: %s', all_metrics[ti])
      logging.info('Timing: %s', all_timings[ti])
      logging.info('=' * 20)
  else:
    rng_seed = struct.unpack('q', os.urandom(8))[0]
    rng = prng.PRNGKey(rng_seed)
    # If the submission is responsible for tuning itself, we only need to run
    # it once and return the total time.
    score, _ = train_once(workload, batch_size, data_dir, init_optimizer_state,
                          update_params, data_selection, None, rng)
  # TODO(znado): record and return other information (number of steps).
  return score
def train_once(
    workload: spec.Workload,
    batch_size: int,
    data_dir: str,
    init_optimizer_state: spec.InitOptimizerFn,
    update_params: spec.UpdateParamsFn,
    data_selection: spec.DataSelectionFn,
    hyperparameters: Optional[spec.Hyperparamters],
    rng: spec.RandomState) -> Tuple[spec.Timing, spec.Steps]:
  data_rng, opt_init_rng, model_init_rng, rng = prng.split(rng, 4)

  # Workload setup.
  logging.info('Initializing dataset.')
  input_queue = workload.build_input_queue(
      data_rng, 'train', data_dir=data_dir, batch_size=batch_size)
  logging.info('Initializing model.')
  model_params, model_state = workload.init_model_fn(model_init_rng)
  logging.info('Initializing optimizer.')
  optimizer_state = init_optimizer_state(workload, model_params, model_state,
                                         hyperparameters, opt_init_rng)

  # Bookkeeping.
  goal_reached = False
  is_time_remaining = True
  last_eval_time = 0
  accumulated_submission_time = 0
  eval_results = []
  global_step = 0
  training_complete = False
  global_start_time = time.time()

  logging.info('Starting training loop.')
  while is_time_remaining and not goal_reached and not training_complete:
    step_rng = prng.fold_in(rng, global_step)
    data_select_rng, update_rng, eval_rng = prng.split(step_rng, 3)
    start_time = time.time()
    selected_train_input_batch, selected_train_label_batch = data_selection(
        workload, input_queue, optimizer_state, model_params, hyperparameters,
        global_step, data_select_rng)
    try:
      optimizer_state, model_params, model_state = update_params(
          workload=workload,
          current_param_container=model_params,
          current_params_types=workload.model_params_types(),
          model_state=model_state,
          hyperparameters=hyperparameters,
          input_batch=selected_train_input_batch,
          label_batch=selected_train_label_batch,
          loss_type=workload.loss_type,
          optimizer_state=optimizer_state,
          eval_results=eval_results,
          global_step=global_step,
          rng=update_rng)
    except spec.TrainingCompleteError:
      training_complete = True
    global_step += 1
    current_time = time.time()
    accumulated_submission_time += current_time - start_time
    is_time_remaining = (
        accumulated_submission_time < workload.max_allowed_runtime_sec)
    # Check if submission is eligible for an untimed eval.
    if (current_time - last_eval_time >= workload.eval_period_time_sec or
        training_complete):
      latest_eval_result = workload.eval_model(model_params, model_state,
                                               eval_rng, data_dir)
      logging.info(f'{current_time - global_start_time:.2f}s\t{global_step}'
                   f'\t{latest_eval_result}')
      last_eval_time = current_time
      eval_results.append((global_step, latest_eval_result))
      goal_reached = workload.has_reached_goal(latest_eval_result)
  metrics = {'eval_results': eval_results, 'global_step': global_step}
  return accumulated_submission_time, metrics
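For reference, the simplest `data_selection` a submission can provide is one that forwards batches from the input queue untouched. The sketch below is only an illustration of the call signature used by `train_once` above, assuming the queue yields `(inputs, labels)` tuples; a real submission may filter, reorder, or reweight examples instead.

def data_selection(workload,
                   input_queue,
                   optimizer_state,
                   current_param_container,
                   hyperparameters,
                   global_step,
                   rng):
  """Trivial selection: consume the next (inputs, labels) batch unchanged."""
  # A submission could instead subsample or mix examples here, using `rng`
  # for any randomized choice so that runs remain reproducible.
  del workload, optimizer_state, current_param_container
  del hyperparameters, global_step, rng
  return next(input_queue)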
def _test_submission(workload_name,
                     framework,
                     submission_path,
                     search_space_path,
                     data_dir,
                     use_fake_input_queue):
  FLAGS.framework = framework
  workload_metadata = copy.deepcopy(submission_runner.WORKLOADS[workload_name])
  workload_metadata['workload_path'] = os.path.join(
      submission_runner.BASE_WORKLOADS_DIR,
      workload_metadata['workload_path'] + '_' + framework,
      'workload.py')
  workload_class = submission_runner.import_workload(
      workload_path=workload_metadata['workload_path'],
      workload_class_name=workload_metadata['workload_class_name'],
      return_class=True)

  submission_module_path = submission_runner.convert_filepath_to_module(
      submission_path)
  submission_module = importlib.import_module(submission_module_path)

  init_optimizer_state = submission_module.init_optimizer_state
  update_params = submission_module.update_params
  data_selection = submission_module.data_selection
  get_batch_size = submission_module.get_batch_size
  # Check that the submission provides a batch size for this workload, then
  # override it with a tiny batch size to keep the test fast.
  global_batch_size = get_batch_size(workload_name)
  global_batch_size = 2
  workload = _make_one_batch_workload(workload_class, workload_name, framework,
                                      global_batch_size, use_fake_input_queue)

  # Get a sample hyperparameter setting.
  with open(search_space_path, 'r', encoding='UTF-8') as search_space_file:
    hyperparameters = halton.generate_search(
        json.load(search_space_file), num_trials=1)[0]

  rng = prng.PRNGKey(0)
  data_rng, opt_init_rng, model_init_rng, rng = prng.split(rng, 4)
  input_queue = workload.build_input_queue(
      data_rng, 'train', data_dir=data_dir,
      global_batch_size=global_batch_size)
  model_params, model_state = workload.init_model_fn(model_init_rng)
  optimizer_state = init_optimizer_state(workload, model_params, model_state,
                                         hyperparameters, opt_init_rng)

  global_step = 0
  data_select_rng, update_rng, eval_rng = prng.split(rng, 3)
  batch = data_selection(workload, input_queue, optimizer_state, model_params,
                         hyperparameters, global_step, data_select_rng)
  _, model_params, model_state = update_params(
      workload=workload,
      current_param_container=model_params,
      current_params_types=workload.model_params_types,
      model_state=model_state,
      hyperparameters=hyperparameters,
      batch=batch,
      loss_type=workload.loss_type,
      optimizer_state=optimizer_state,
      eval_results=[],
      global_step=global_step,
      rng=update_rng)
  eval_result = workload.eval_model(global_batch_size, model_params,
                                    model_state, eval_rng, data_dir)
  # Run evaluation a second time to check it can be called repeatedly
  # (e.g. with cached eval iterators).
  _ = workload.eval_model(global_batch_size, model_params, model_state,
                          eval_rng, data_dir)
  return eval_result