Example No. 1
    def _eval_model_on_split(self, split: str, num_examples: int,
                             global_batch_size: int,
                             params: spec.ParameterContainer,
                             model_state: spec.ModelAuxiliaryState,
                             rng: spec.RandomState,
                             data_dir: str) -> Dict[str, float]:
        """Run a full evaluation of the model."""
        data_rng, model_rng = prng.split(rng, 2)
        if split not in self._eval_iters:
            eval_iter = self.build_input_queue(
                data_rng, split, data_dir, global_batch_size=global_batch_size)
            # Note that this stores the entire eval dataset in memory.
            self._eval_iters[split] = itertools.cycle(eval_iter)

        total_metrics = {
            'accuracy': 0.,
            'loss': 0.,
        }
        num_data = 0
        num_batches = int(math.ceil(num_examples / global_batch_size))
        for bi, batch in enumerate(self._eval_iters[split]):
            if bi >= num_batches:
                break
            per_device_model_rngs = prng.split(model_rng,
                                               jax.local_device_count())
            batch_metrics = self._eval_model(params, batch, model_state,
                                             per_device_model_rngs)
            total_metrics = {
                k: v + batch_metrics[k]
                for k, v in total_metrics.items()
            }
            num_data += batch_metrics['num_data']
        return {k: float(v / num_data) for k, v in total_metrics.items()}
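The eval loops in these examples all share the same accumulation pattern: sum the per-batch metric dicts, count examples, and normalize at the end. A minimal, self-contained sketch of that pattern (plain Python; average_metrics and the literal batch dicts are hypothetical stand-ins for the _eval_model results):

def average_metrics(batch_results, metric_keys=('accuracy', 'loss')):
    """Sum per-batch metric dicts and normalize by the total example count."""
    totals = {k: 0. for k in metric_keys}
    num_data = 0
    for batch_metrics in batch_results:
        totals = {k: v + batch_metrics[k] for k, v in totals.items()}
        num_data += batch_metrics['num_data']
    return {k: float(v / num_data) for k, v in totals.items()}

# Two batches of 4 examples each, with metrics already summed per batch.
print(average_metrics([
    {'accuracy': 3., 'loss': 1.2, 'num_data': 4},
    {'accuracy': 4., 'loss': 0.8, 'num_data': 4},
]))  # {'accuracy': 0.875, 'loss': 0.25}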
Example No. 2
    def eval_model(self, params: spec.ParameterContainer,
                   model_state: spec.ModelAuxiliaryState,
                   rng: spec.RandomState, data_dir: str):
        """Run a full evaluation of the model."""
        data_rng, model_rng = prng.split(rng, 2)
        eval_batch_size = 128
        if self._eval_ds is None:
            self._eval_ds = self._build_dataset(data_rng,
                                                'test',
                                                data_dir,
                                                batch_size=eval_batch_size)

        total_metrics = {
            'accuracy': 0.,
            'loss': 0.,
        }
        n_data = 0
        for (images, labels) in self._eval_ds:
            images = images.float().to(DEVICE)
            labels = labels.float().to(DEVICE)
            logits, _ = self.model_fn(params,
                                      images,
                                      model_state,
                                      spec.ForwardPassMode.EVAL,
                                      model_rng,
                                      update_batch_norm=False)
            batch_metrics = self._eval_metric(logits, labels)
            total_metrics = {
                k: v + batch_metrics[k]
                for k, v in total_metrics.items()
            }
            n_data += batch_metrics['n_data']
        return {k: float(v / n_data) for k, v in total_metrics.items()}
Example No. 3
    def eval_model(self, params: spec.ParameterContainer,
                   model_state: spec.ModelAuxiliaryState,
                   rng: spec.RandomState, data_dir: str):
        """Run a full evaluation of the model."""
        data_rng, model_rng = prng.split(rng, 2)
        eval_batch_size = 2000
        self._eval_ds = self.build_input_queue(data_rng,
                                               'test',
                                               data_dir,
                                               batch_size=eval_batch_size)

        total_metrics = {
            'accuracy': 0.,
            'loss': 0.,
        }
        n_data = 0
        for (images, labels) in self._eval_ds:
            images, labels = self.preprocess_for_eval(images, labels, None,
                                                      None)
            logits, _ = self.model_fn(params,
                                      images,
                                      model_state,
                                      spec.ForwardPassMode.EVAL,
                                      model_rng,
                                      update_batch_norm=False)
            # TODO(znado): add additional eval metrics?
            batch_metrics = self._eval_metric(logits, labels)
            total_metrics = {
                k: v + batch_metrics[k]
                for k, v in total_metrics.items()
            }
            n_data += batch_metrics['n_data']
        return {k: float(v / n_data) for k, v in total_metrics.items()}
Example No. 4
    def _eval_model_on_split(self, split: str, num_examples: int,
                             global_batch_size: int,
                             params: spec.ParameterContainer,
                             model_state: spec.ModelAuxiliaryState,
                             rng: spec.RandomState,
                             data_dir: str) -> Dict[str, float]:
        """Run a full evaluation of the model."""
        data_rng, model_rng = prng.split(rng, 2)
        if split not in self._eval_iters:
            eval_iter = self.build_input_queue(
                data_rng, split, data_dir, global_batch_size=global_batch_size)
            # Note that this stores the entire val dataset in memory.
            self._eval_iters[split] = itertools.cycle(eval_iter)

        total_metrics = None
        num_eval_steps = int(math.ceil(
            float(num_examples) / global_batch_size))
        # Loop over graph batches in eval dataset.
        for _ in range(num_eval_steps):
            batch = next(self._eval_iters[split])
            batch_metrics = self._eval_batch(params, batch, model_state,
                                             model_rng)
            total_metrics = (batch_metrics if total_metrics is None else
                             total_metrics.merge(batch_metrics))
        if total_metrics is None:
            return {}
        return {
            k: float(v)
            for k, v in total_metrics.reduce().compute().items()
        }
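Example No. 4 accumulates metrics through an object exposing merge, reduce and compute rather than a plain dict; in the original codebase this is presumably a clu.metrics-style collection. A minimal sketch of such a mergeable metric, assuming only that compute() returns a dict of floats:

import dataclasses

@dataclasses.dataclass
class MeanMetrics:
    """Mergeable running mean, mirroring the merge()/reduce()/compute() usage above."""
    totals: dict  # summed metric values, keyed by name
    count: int    # number of examples seen so far

    def merge(self, other: 'MeanMetrics') -> 'MeanMetrics':
        merged = {k: v + other.totals[k] for k, v in self.totals.items()}
        return MeanMetrics(merged, self.count + other.count)

    def reduce(self) -> 'MeanMetrics':
        # No cross-device reduction in this single-process sketch.
        return self

    def compute(self) -> dict:
        return {k: v / self.count for k, v in self.totals.items()}

# Usage mirroring the eval loop: merge per-batch metrics, then compute means.
total = MeanMetrics({'accuracy': 3., 'loss': 1.2}, 4)
total = total.merge(MeanMetrics({'accuracy': 4., 'loss': 0.8}, 4))
print(total.reduce().compute())  # {'accuracy': 0.875, 'loss': 0.25}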
Example No. 5
    def _eval_model_on_split(self, split: str, num_examples: int,
                             global_batch_size: int,
                             params: spec.ParameterContainer,
                             model_state: spec.ModelAuxiliaryState,
                             rng: spec.RandomState, data_dir: str):
        """Run a full evaluation of the model."""
        data_rng, model_rng = prng.split(rng, 2)
        if split not in self._eval_iters:
            # These iterators repeat indefinitely.
            self._eval_iters[split] = self.build_input_queue(
                data_rng, split, data_dir, global_batch_size=global_batch_size)

        total_metrics = {
            'accuracy': 0.,
            'loss': 0.,
        }
        num_data = 0
        num_batches = int(math.ceil(num_examples / global_batch_size))
        for _ in range(num_batches):
            batch = next(self._eval_iters[split])
            logits, _ = self.model_fn(params,
                                      batch,
                                      model_state,
                                      spec.ForwardPassMode.EVAL,
                                      model_rng,
                                      update_batch_norm=False)
            batch_metrics = self._eval_metric(logits, batch['targets'])
            total_metrics = {
                k: v + batch_metrics[k]
                for k, v in total_metrics.items()
            }
            num_data += batch_metrics['num_data']
        return {k: float(v / num_data) for k, v in total_metrics.items()}
Example No. 6
    def eval_model(self, params: spec.ParameterContainer,
                   model_state: spec.ModelAuxiliaryState,
                   rng: spec.RandomState, data_dir: str):
        """Run a full evaluation of the model."""
        data_rng, model_rng = prng.split(rng, 2)
        total_eval_batch_size = 8192
        if self._eval_iterator is None:
            self._eval_iterator = self._build_iterator(
                data_rng,
                'validation',
                data_dir,
                batch_size=total_eval_batch_size)
            # Note that this effectively stores the entire val dataset in memory.
            self._eval_iterator = itertools.cycle(self._eval_iterator)

        total_metrics = None
        # Both val and test have the same (prime) number of examples.
        num_val_examples = 43793
        num_val_steps = num_val_examples // total_eval_batch_size + 1
        # Loop over graph batches in eval dataset.
        for _ in range(num_val_steps):
            graphs, labels, masks = next(self._eval_iterator)
            batch_metrics = self._eval_batch(params, graphs, labels, masks,
                                             model_state, model_rng)
            total_metrics = (batch_metrics if total_metrics is None else
                             total_metrics.merge(batch_metrics))
        if total_metrics is None:
            return {}
        return {
            k: float(v)
            for k, v in total_metrics.reduce().compute().items()
        }
Example No. 7
def score_submission_on_workload(workload: spec.Workload,
                                 workload_name: str,
                                 submission_path: str,
                                 data_dir: str,
                                 tuning_ruleset: str,
                                 tuning_search_space: Optional[str] = None,
                                 num_tuning_trials: Optional[int] = None):
    # Remove the trailing '.py' and convert the filepath to a Python module.
    submission_module_path = _convert_filepath_to_module(submission_path)
    submission_module = importlib.import_module(submission_module_path)

    init_optimizer_state = submission_module.init_optimizer_state
    update_params = submission_module.update_params
    data_selection = submission_module.data_selection
    get_batch_size = submission_module.get_batch_size
    batch_size = get_batch_size(workload_name)

    if tuning_ruleset == 'external':
        # If the submission runner is responsible for hyperparameter tuning, load in
        # the search space and generate a list of randomly selected hyperparameter
        # settings from it.
        if tuning_search_space is None:
            raise ValueError(
                'Must provide a tuning search space JSON file when using external '
                'tuning.')
        with open(tuning_search_space, 'r') as search_space_file:
            tuning_search_space = halton.generate_search(
                json.load(search_space_file), num_tuning_trials)
        all_timings = []
        all_metrics = []
        for hi, hyperparameters in enumerate(tuning_search_space):
            # Generate a new seed from hardware sources of randomness for each trial.
            rng_seed = struct.unpack('I', os.urandom(4))[0]
            rng = prng.PRNGKey(rng_seed)
            # Because we initialize the PRNGKey with only a single 32-bit int, the
            # Jax implementation leaves rng[0] all zeros, which could lead to
            # unintentionally reusing the same seed if only rng[0] were ever used.
            # By splitting the rng into 2, we mix the lower and upper 32-bit ints,
            # ensuring we can safely use either rng[0] or rng[1] as a random
            # number.
            rng, _ = prng.split(rng, 2)
            logging.info(f'--- Tuning run {hi + 1}/{num_tuning_trials} ---')
            timing, metrics = train_once(workload, batch_size, data_dir,
                                         init_optimizer_state, update_params,
                                         data_selection, hyperparameters, rng)
            all_timings.append(timing)
            all_metrics.append(metrics)
        score = min(all_timings)
        for ti in range(num_tuning_trials):
            logging.info('Tuning trial %d/%d', ti + 1, num_tuning_trials)
            logging.info('Hyperparameters: %s', tuning_search_space[ti])
            logging.info('Metrics: %s', all_metrics[ti])
            logging.info('Timing: %s', all_timings[ti])
            logging.info('=' * 20)
    else:
        rng_seed = struct.unpack('q', os.urandom(8))[0]
        rng = prng.PRNGKey(rng_seed)
        # If the submission is responsible for tuning itself, we only need to run it
        # once and return the total time.
        score, _ = train_once(workload, batch_size, data_dir,
                              init_optimizer_state, update_params,
                              data_selection, None, rng)
    # TODO(znado): record and return other information (number of steps).
    return score
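The comment inside the tuning loop explains why the freshly seeded key is split once before use. A short sketch of that seeding-and-splitting step, assuming the prng module here wraps jax.random (as the Jax-specific comment suggests):

import os
import struct

import jax

# Seed from a hardware source of randomness, as in the tuning loop above.
rng_seed = struct.unpack('I', os.urandom(4))[0]
rng = jax.random.PRNGKey(rng_seed)  # rng[0] is all zeros for a 32-bit seed.

# Splitting mixes the two 32-bit halves, so either resulting key (or half)
# can safely be used on its own afterwards.
rng, _ = jax.random.split(rng, 2)
trial_keys = jax.random.split(rng, 5)  # e.g. one independent key per tuning trial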
Example No. 8
def train_once(workload: spec.Workload, batch_size: int, data_dir: str,
               init_optimizer_state: spec.InitOptimizerFn,
               update_params: spec.UpdateParamsFn,
               data_selection: spec.DataSelectionFn,
               hyperparameters: Optional[spec.Hyperparamters],
               rng: spec.RandomState) -> Tuple[spec.Timing, spec.Steps]:
    data_rng, opt_init_rng, model_init_rng, rng = prng.split(rng, 4)

    # Workload setup.
    logging.info('Initializing dataset.')
    input_queue = workload.build_input_queue(data_rng,
                                             'train',
                                             data_dir=data_dir,
                                             batch_size=batch_size)
    logging.info('Initializing model.')
    model_params, model_state = workload.init_model_fn(model_init_rng)
    logging.info('Initializing optimizer.')
    optimizer_state = init_optimizer_state(workload, model_params, model_state,
                                           hyperparameters, opt_init_rng)

    # Bookkeeping.
    goal_reached = False
    is_time_remaining = True
    last_eval_time = 0
    accumulated_submission_time = 0
    eval_results = []
    global_step = 0
    training_complete = False
    global_start_time = time.time()

    logging.info('Starting training loop.')
    while (is_time_remaining and not goal_reached and not training_complete):
        step_rng = prng.fold_in(rng, global_step)
        data_select_rng, update_rng, eval_rng = prng.split(step_rng, 3)
        start_time = time.time()
        selected_train_input_batch, selected_train_label_batch = data_selection(
            workload, input_queue, optimizer_state, model_params,
            hyperparameters, global_step, data_select_rng)
        try:
            optimizer_state, model_params, model_state = update_params(
                workload=workload,
                current_param_container=model_params,
                current_params_types=workload.model_params_types(),
                model_state=model_state,
                hyperparameters=hyperparameters,
                input_batch=selected_train_input_batch,
                label_batch=selected_train_label_batch,
                loss_type=workload.loss_type,
                optimizer_state=optimizer_state,
                eval_results=eval_results,
                global_step=global_step,
                rng=update_rng)
        except spec.TrainingCompleteError:
            training_complete = True
        global_step += 1
        current_time = time.time()
        accumulated_submission_time += current_time - start_time
        is_time_remaining = (accumulated_submission_time <
                             workload.max_allowed_runtime_sec)
        # Check if submission is eligible for an untimed eval.
        if (current_time - last_eval_time >= workload.eval_period_time_sec
                or training_complete):
            latest_eval_result = workload.eval_model(model_params, model_state,
                                                     eval_rng, data_dir)
            logging.info(
                f'{current_time - global_start_time:.2f}s\t{global_step}'
                f'\t{latest_eval_result}')
            last_eval_time = current_time
            eval_results.append((global_step, latest_eval_result))
            goal_reached = workload.has_reached_goal(latest_eval_result)
    metrics = {'eval_results': eval_results, 'global_step': global_step}
    return accumulated_submission_time, metrics
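train_once charges only the submission's own work (data selection plus the update step) against the time budget and leaves evaluations untimed. A stripped-down sketch of that bookkeeping, where run_step and run_eval are hypothetical stand-ins for the submission and workload calls:

import time

def timed_training_loop(run_step, run_eval, max_runtime_sec, eval_period_sec):
    """Accumulate only step time against the budget; evals are untimed."""
    accumulated, last_eval_time, step = 0.0, 0.0, 0
    eval_results = []
    while accumulated < max_runtime_sec:
        start = time.time()
        run_step(step)  # timed: data selection + parameter update
        now = time.time()
        accumulated += now - start
        step += 1
        if now - last_eval_time >= eval_period_sec:
            eval_results.append((step, run_eval()))  # untimed evaluation
            last_eval_time = now
    return accumulated, eval_results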
Example No. 9
def _test_submission(workload_name, framework, submission_path,
                     search_space_path, data_dir, use_fake_input_queue):
    FLAGS.framework = framework
    workload_metadata = copy.deepcopy(
        submission_runner.WORKLOADS[workload_name])
    workload_metadata['workload_path'] = os.path.join(
        submission_runner.BASE_WORKLOADS_DIR,
        workload_metadata['workload_path'] + '_' + framework, 'workload.py')
    workload_class = submission_runner.import_workload(
        workload_path=workload_metadata['workload_path'],
        workload_class_name=workload_metadata['workload_class_name'],
        return_class=True)

    submission_module_path = submission_runner.convert_filepath_to_module(
        submission_path)
    submission_module = importlib.import_module(submission_module_path)

    init_optimizer_state = submission_module.init_optimizer_state
    update_params = submission_module.update_params
    data_selection = submission_module.data_selection
    get_batch_size = submission_module.get_batch_size
    global_batch_size = get_batch_size(workload_name)
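    # Override the submission's batch size with a tiny value, presumably to keep the test fast.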
    global_batch_size = 2
    workload = _make_one_batch_workload(workload_class, workload_name,
                                        framework, global_batch_size,
                                        use_fake_input_queue)

    # Get a sample hyperparameter setting.
    with open(search_space_path, 'r', encoding='UTF-8') as search_space_file:
        hyperparameters = halton.generate_search(json.load(search_space_file),
                                                 num_trials=1)[0]

    rng = prng.PRNGKey(0)
    data_rng, opt_init_rng, model_init_rng, rng = prng.split(rng, 4)
    input_queue = workload.build_input_queue(
        data_rng,
        'train',
        data_dir=data_dir,
        global_batch_size=global_batch_size)
    model_params, model_state = workload.init_model_fn(model_init_rng)
    optimizer_state = init_optimizer_state(workload, model_params, model_state,
                                           hyperparameters, opt_init_rng)

    global_step = 0
    data_select_rng, update_rng, eval_rng = prng.split(rng, 3)
    batch = data_selection(workload, input_queue, optimizer_state,
                           model_params, hyperparameters, global_step,
                           data_select_rng)
    _, model_params, model_state = update_params(
        workload=workload,
        current_param_container=model_params,
        current_params_types=workload.model_params_types,
        model_state=model_state,
        hyperparameters=hyperparameters,
        batch=batch,
        loss_type=workload.loss_type,
        optimizer_state=optimizer_state,
        eval_results=[],
        global_step=global_step,
        rng=update_rng)
    eval_result = workload.eval_model(global_batch_size, model_params,
                                      model_state, eval_rng, data_dir)
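    # Evaluate a second time (result discarded), presumably to exercise repeated evaluation with any cached eval state.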
    _ = workload.eval_model(global_batch_size, model_params, model_state,
                            eval_rng, data_dir)
    return eval_result