Example #1
def monitor_and_sample(config, work_dir):
    """Monitors `work_dir` for new checkpoints and run sampling on them.

  Args:
    config: Hyperparameter configuration for training and evaluation.
    work_dir: Directory where the tensorboard summaries are written to.
  """
    # Init rng key.
    rng = jax.random.PRNGKey(config.seed)
    data_rng, rng = jax.random.split(rng)
    is_first_host = jax.process_index() == 0

    # TODO(agritsenko): We are loading the datasets just to get the metadata.
    #  Can we be smarter about this?
    if config.dataset.name.endswith('speech_commands09'):
        _, ds_metadata = input_pipeline_sc09.get_dataset(data_rng, config)
    else:
        raise ValueError(f'Unknown dataset {config.dataset.name}.')

    # TODO(agritsenko): Can we fix the ugly nested dicts?
    config.data_shape = ds_metadata['train']['shape']['inputs'][2:]
    config.num_classes = ds_metadata['train']['num_classes']
    config.sample_rate = ds_metadata['train']['sample_rate']

    writer = metric_writers.create_default_writer(
        work_dir, just_logging=jax.process_index() > 0)
    rng, init_rng = jax.random.split(rng)

    model, variables = model_setup(init_rng, config)

    # From now on we want different rng across hosts:
    rng = jax.random.fold_in(rng, jax.process_index())
    rng, rng_sample = jax.random.split(rng)

    def tx_fn(lr):
        return optax.adamw(lr,
                           b1=0.9,
                           b2=config.beta2,
                           eps=1e-08,
                           eps_root=0.0,
                           weight_decay=config.weight_decay)

    state = language_train_state.TrainState.create(params=variables['params'],
                                                   tx_fn=tx_fn)

    # Wait for checkpoints in a loop.
    ckpt_path_iterator = checkpoint.checkpoints_iterator(work_dir, target=None)

    with metric_writers.ensure_flushes(writer):
        for _ in ckpt_path_iterator:
            state, step = checkpoint.restore_from_path(work_dir, state)
            is_last_step = step == config.num_train_steps - 1
            logging.info('Loaded checkpoint for step: %d', step)

            # Replicate the state
            state = flax.jax_utils.replicate(state)

            ######################### Run sampling ###############################
            chain = model.sample(jax.random.fold_in(rng_sample, step),
                                 state.ema_params,
                                 config.sample_batch_size,
                                 chain_out_size=config.get(
                                     'chain_out_size', model.num_stages))

            if is_first_host:
                chain = jax.device_get(chain)
                long_sample = np.reshape(chain[-1],
                                         (1, -1, 1)).astype(np.float32)
                long_sample = (2. * long_sample) / config.num_classes - 1.
                long_sample = long_sample.astype(np.float32)
                writer.write_audios(step, {'samples': long_sample},
                                    sample_rate=config.sample_rate)

            if is_last_step:
                break
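
The sampling branch above turns the last element of the chain into an audio waveform by mapping integer class indices onto roughly [-1, 1). A minimal standalone sketch of that rescaling step; the value of num_classes and the chain contents are made up for illustration, in the example they come from `config.num_classes` and from `chain[-1]` returned by `model.sample`:

import numpy as np

# Assumed values for illustration only.
num_classes = 256
chain_last = np.array([0, 64, 128, 192, 255])

# Reshape to (batch=1, time, channels=1) and rescale class indices in
# [0, num_classes) to floats in roughly [-1, 1) before writing audio.
long_sample = np.reshape(chain_last, (1, -1, 1)).astype(np.float32)
long_sample = (2. * long_sample) / num_classes - 1.
print(long_sample.ravel())  # approximately [-1., -0.5, 0., 0.5, 0.992]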
Example #2
def train_and_evaluate(config, work_dir, try_checkpoint=True):
    """Execute model training and evaluation loop.

  Args:
    config: Hyperparameter configuration for training and evaluation.
    work_dir: Directory where the tensorboard summaries are written to.
    try_checkpoint: Whether to try to load a checkpoint (usually enabled;
        disabling is practical for debugging).

  Returns:
    The train state (which includes the `.params`).
  """
    # Init rng key.
    msg = f'Running with seed {config.seed}.'
    logging.info(msg)
    rng = jax.random.PRNGKey(config.seed)
    data_rng, rng = jax.random.split(rng)
    is_first_host = jax.process_index() == 0

    train_ds, test_ds, shape, num_classes = datasets.get_dataset(
        config, data_rng)

    # config.mask_shape = mask_shape
    config.data_shape = shape
    config.num_classes = num_classes

    writer = metric_writers.create_default_writer(
        work_dir, just_logging=jax.process_index() > 0)
    rng, init_rng = jax.random.split(rng)

    # Create output directory for saving samples.
    output_path = work_dir
    tf.io.gfile.makedirs(output_path)

    model, variables = model_setup(init_rng, config)

    # From now on we want different rng across hosts:
    rng = jax.random.fold_in(rng, jax.process_index())

    tx = optax.adam(config.learning_rate,
                    b1=0.9,
                    b2=config.beta2,
                    eps=1e-08,
                    eps_root=0.0)
    state = custom_train_state.TrainState.create(params=variables['params'],
                                                 tx=tx)

    if try_checkpoint:
        state, start_epoch = checkpoint.restore_from_path(work_dir, state)
        if start_epoch is None:
            start_epoch = 1
    else:
        # For debugging we start at zero, so we immediately do detailed eval.
        start_epoch = 0

    if is_first_host and start_epoch == 1:
        config_dict = dict(config)
        writer.write_hparams(config_dict)

    if is_first_host and start_epoch in (0, 1):
        # Dump config file to work dir for easy model loading.
        config_path = os.path.join(work_dir, 'config')
        with tf.io.gfile.GFile(config_path, 'wb') as fp:
            pickle.dump(config, fp)

    test_rng, train_rng = jax.random.split(rng)

    kl_tracker_train = util_fns.KLTracker(num_steps=model.num_steps)
    kl_history = []

    p_train_step = jax.pmap(functools.partial(train_step,
                                              model=model,
                                              config=config),
                            axis_name='batch',
                            in_axes=(None, 0, 0),
                            out_axes=(0, 0, None),
                            donate_argnums=(2, ))

    # The only broadcast axes are the input and output rng keys: the rng is
    # the first argument and the last return value.
    p_eval_step = jax.pmap(functools.partial(eval_step, model=model),
                           axis_name='batch',
                           in_axes=(None, 0, 0),
                           out_axes=(0, None))

    # Replicate state.
    state = flax.jax_utils.replicate(state)

    with metric_writers.ensure_flushes(writer):
        for epoch in range(start_epoch, config.num_epochs + 1):
            # Train part.
            state, train_metrics, train_rng = train_epoch(
                p_train_step, state, train_ds, config.batch_size, epoch,
                train_rng, kl_tracker_train)

            # Val part.
            eval_metrics, test_rng = eval_model(p_eval_step, test_rng, state,
                                                test_ds, epoch)

            # Metric logging.
            if is_first_host:
                log_standard_metrics(writer, train_metrics, eval_metrics,
                                     epoch)

            kl_values = kl_tracker_train.get_kl_per_t()
            kl_history.append(np.array(kl_values))

            # Prune to avoid too much memory consumption.
            kl_history = kl_history[-50:]

            if epoch == 15 or epoch % config.detailed_eval_every == 0:
                if is_first_host:
                    loss_components_path = os.path.join(
                        work_dir, 'loss_components')
                    with tf.io.gfile.GFile(loss_components_path, 'wb') as fp:
                        pickle.dump(kl_history[-1], fp)

                test_rng = extensive_eval(config, test_rng, writer,
                                          output_path, model, state,
                                          kl_history, test_ds, epoch)

            # Save to checkpoint.
            if is_first_host and epoch % config.save_every == 0:
                # Save to epoch + 1 since current epoch has just been completed.
                logging.info('saving checkpoint')
                checkpoint.save_checkpoint(
                    work_dir,
                    state=flax.jax_utils.unreplicate(state),
                    step=epoch + 1,
                    keep=2)
                logging.info('finished saving checkpoint')

        return state
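
The `jax.pmap` calls above broadcast the rng key to every device (the `None` entries of `in_axes`/`out_axes`) while sharding the batch and state along the leading device axis. A minimal sketch of that convention with a toy step function (not the example's `train_step`), assuming one shard per local device:

import jax
import jax.numpy as jnp


def toy_step(rng, batch):
    # rng is identical on every device (in_axes=None); batch is sharded
    # along its leading device axis (in_axes=0).
    step_rng, new_rng = jax.random.split(rng)
    noise = jax.random.normal(step_rng, batch.shape)
    loss = jnp.mean((batch + 0.01 * noise) ** 2)
    loss = jax.lax.pmean(loss, axis_name='batch')  # average across devices
    # loss keeps a device axis in the output (out_axes=0); new_rng is the
    # same on every device, so it is returned unreplicated (out_axes=None).
    return loss, new_rng


p_toy_step = jax.pmap(toy_step, axis_name='batch',
                      in_axes=(None, 0), out_axes=(0, None))

n_dev = jax.local_device_count()
batch = jnp.ones((n_dev, 8))  # one shard per device
losses, rng = p_toy_step(jax.random.PRNGKey(0), batch)
# losses.shape == (n_dev,); rng is a single key usable for the next step.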
Example #3
def train_and_evaluate(config, work_dir, try_checkpoint=True):
    """Execute model training and evaluation loop.

  Args:
    config: Hyperparameter configuration for training and evaluation.
    work_dir: Directory where the tensorboard summaries are written to.
    try_checkpoint: Whether to try to load a checkpoint (usually enabled;
        disabling is practical for debugging).

  Returns:
    The train state (which includes the `.params`).
  """
    # Init rng key.
    rng = jax.random.PRNGKey(config.seed)
    data_rng, rng = jax.random.split(rng)
    is_first_host = jax.process_index() == 0

    if config.dataset.name.endswith('speech_commands09'):
        ds, ds_metadata = input_pipeline_sc09.get_dataset(data_rng, config)
    else:
        raise ValueError(f'Unknown dataset {config.dataset.name}.')

    # Immediately create infinite iterators.
    it = jax.tree_map(util_fns.get_iterator, ds)

    # TODO(agritsenko): Can we fix the ugly nested dicts?
    config.data_shape = ds_metadata['train']['shape']['inputs'][2:]
    config.num_classes = ds_metadata['train']['num_classes']
    config.sample_rate = ds_metadata['train']['sample_rate']

    writer = metric_writers.create_default_writer(
        work_dir, just_logging=jax.process_index() > 0)
    rng, init_rng = jax.random.split(rng)

    model, variables = model_setup(init_rng, config)

    # From now on we want different rng across hosts:
    rng = jax.random.fold_in(rng, jax.process_index())

    def tx_fn(lr):
        return optax.adamw(lr,
                           b1=0.9,
                           b2=config.beta2,
                           eps=1e-08,
                           eps_root=0.0,
                           weight_decay=config.weight_decay)

    state = language_train_state.TrainState.create(params=variables['params'],
                                                   tx_fn=tx_fn)

    start_step = None
    if try_checkpoint:
        state, start_step = checkpoint.restore_from_path(work_dir, state)
    start_step = start_step or 0

    # Use different rngs for train & eval.
    rng_train, rng_eval, rng_sample = jax.random.split(rng, 3)

    kl_tracker = util_fns.KLTracker(num_steps=model.num_steps)
    kl_history = []

    learning_rate_fn = train_utils.create_learning_rate_scheduler(
        **config.learning_rate)
    p_train_step = jax.pmap(functools.partial(
        train_step,
        config=config,
        learning_rate_fn=learning_rate_fn,
        model=model),
                            axis_name='batch',
                            in_axes=(None, 0, 0),
                            out_axes=(0, 0, None),
                            donate_argnums=(2, ))

    # The only broadcast axes are the input and output rng keys: the rng is
    # the first argument and the last return value.
    p_eval_step = jax.pmap(functools.partial(eval_step, model=model),
                           axis_name='batch',
                           in_axes=(None, 0, 0),
                           out_axes=(0, 0, None))

    # Training length.
    logging.info('Training will start from step %d', start_step)

    # Replicate state.
    state = flax.jax_utils.replicate(state)

    # Setup hooks.
    hooks = []
    report_progress = periodic_actions.ReportProgress(
        num_train_steps=config.num_train_steps, writer=writer)
    if is_first_host:
        hooks += [
            report_progress,
            periodic_actions.Profile(logdir=work_dir, num_profile_steps=5)
        ]

    with metric_writers.ensure_flushes(writer):
        batch_metrics = []
        for step in range(start_step, config.num_train_steps):
            logging.log_first_n(logging.INFO, f'Train step: {step}', 5)
            with jax.profiler.StepTraceAnnotation('train', step_num=step):
                state, metrics, rng_train = p_train_step(
                    rng_train, next(it['train']), state)
            batch_metrics.append(metrics)

            # Cycle through hooks.
            for h in hooks:
                h(step)

            is_last_step = step == config.num_train_steps - 1

            if (step % config.log_every_steps == 0) or is_last_step:
                with report_progress.timed('training_metrics'):
                    ################### Process batch metrics ############################
                    batch_metrics = jax.device_get(
                        flax.jax_utils.unreplicate(batch_metrics))

                    if 't_batch' in metrics:
                        # TODO(agritsenko): Factor out into a separate function.
                        # This processes the loss per t. Although there are
                        # two nested for-loops (counting the one inside
                        # kl_tracker), it does not meaningfully hurt timing
                        # performance.
                        batch_t = [
                            metrics['t_batch'].reshape(-1)
                            for metrics in batch_metrics
                        ]
                        batch_nelbo_per_t = [
                            metrics['nelbo_per_t_batch'].reshape(-1)
                            for metrics in batch_metrics
                        ]
                        for t, nelbo_per_t in zip(batch_t, batch_nelbo_per_t):
                            kl_tracker.update(t, nelbo_per_t)

                    ################# Aggregate scalar metrics ###########################
                    metrics = {
                        key:
                        np.mean([metrics[key] for metrics in batch_metrics])
                        for key in batch_metrics[0] if 'batch' not in key
                    }

                    # Metric logging.
                    if is_first_host:
                        log_standard_metrics(writer,
                                             step,
                                             train_metrics=metrics)
                    batch_metrics = []

            if config.eval_every_steps and (
                (step % config.eval_every_steps == 0) or is_last_step):
                with report_progress.timed('eval'):
                    ####################### Run evaluation ###############################
                    metrics, rng_eval = eval_model(
                        p_eval_step, rng_eval, state, it['eval'],
                        (ds_metadata['eval']['num_batches'] *
                         config.get('num_eval_passes', 1)))

                    # Metric logging.
                    if is_first_host:
                        log_standard_metrics(writer,
                                             step,
                                             eval_metrics=metrics)

                # Track KL (unrelated to the eval, but nice to not do every step).
                kl_values = kl_tracker.get_kl_per_t()
                kl_history.append(np.array(kl_values))
                kl_history = kl_history[-50:]

            if config.sample_every_steps and (
                (step % config.sample_every_steps == 0) or is_last_step):
                with report_progress.timed('sample'):
                    ######################### Run sampling ###############################
                    chain = model.sample(jax.random.fold_in(rng_sample, step),
                                         state.ema_params,
                                         config.sample_batch_size,
                                         chain_out_size=config.get(
                                             'chain_out_size',
                                             model.num_stages))

                    if is_first_host:
                        chain = jax.device_get(chain)
                        long_sample = np.reshape(chain[-1],
                                                 (1, -1, 1)).astype(np.float32)
                        long_sample = (2. *
                                       long_sample) / config.num_classes - 1.
                        writer.write_audios(step, {'samples': long_sample},
                                            sample_rate=config.sample_rate)

            ######################### Checkpointing #################################
            if is_first_host and config.checkpoint_every_steps and (
                (step % config.checkpoint_every_steps == 0) or is_last_step):
                logging.info('Saving checkpoint: step %d', step)
                with report_progress.timed('checkpoint'):
                    checkpoint.save_checkpoint(
                        work_dir,
                        state=flax.jax_utils.unreplicate(state),
                        step=step)
                logging.info('Finished saving checkpoint: step %d', step)

        return state
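
In the logging branch above, the accumulated per-step metric dicts are reduced to a single summary: values are averaged across steps, while per-example entries (keys containing 'batch') are dropped. A small standalone sketch of that aggregation with made-up values:

import numpy as np

# Two accumulated per-step metric dicts (illustrative values only).
batch_metrics = [
    {'loss': 1.25, 'nelbo': 3.5, 't_batch': np.array([1, 5, 9])},
    {'loss': 0.75, 'nelbo': 3.0, 't_batch': np.array([2, 6, 8])},
]

# Average scalar metrics across the accumulated steps and drop the
# per-example 'batch' entries from the logged summary.
summary = {
    key: np.mean([m[key] for m in batch_metrics])
    for key in batch_metrics[0] if 'batch' not in key
}
print(summary)  # loss averages to 1.0, nelbo to 3.25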
def evaluate_compression(work_dir, budget=50):
    """Execute model training and evaluation loop.

  Args:
    work_dir: Directory where the saved files are located.
    budget: Budget for the policy.
  """
    # Loading config file.
    config_path = os.path.join(work_dir, 'config')
    with tf.io.gfile.GFile(config_path, 'rb') as fp:
        config = pickle.load(fp)
        logging.info('Loaded config')

    # Loading loss components
    with tf.io.gfile.GFile(os.path.join(work_dir, 'loss_components'),
                           'rb') as fp:
        loss_components = pickle.load(fp)
        logging.info('Loaded loss components')

    config.test_batch_size = 80
    config.num_eval_passes = 1

    rng = jax.random.PRNGKey(config.seed)
    data_rng, rng = jax.random.split(rng)

    train_ds, test_ds, shape, num_classes = datasets.get_dataset(
        config, data_rng)

    config.data_shape = shape
    config.num_classes = num_classes

    rng, init_rng = jax.random.split(rng)

    model, variables = train.model_setup(init_rng, config)

    # From now on we want different rng across hosts:
    rng = jax.random.fold_in(rng, jax.process_index())

    tx = optax.adam(config.learning_rate,
                    b1=0.9,
                    b2=config.beta2,
                    eps=1e-08,
                    eps_root=0.0)
    state = custom_train_state.TrainState.create(params=variables['params'],
                                                 tx=tx)

    state, start_epoch = checkpoint.restore_from_path(work_dir, state)

    logging.info('Loaded checkpoint at epoch %d', start_epoch)

    test_rng, train_rng = jax.random.split(rng)
    del test_rng

    # Replicate state.
    state = flax.jax_utils.replicate(state)

    # Find optimal policy.
    policies, costs = model.compute_policies_and_costs(loss_components,
                                                       budgets=[budget])
    policy, expected_cost = policies[0], costs[0]
    logging.info('Using policy %s\n with expected cost %.2f', str(policy),
                 expected_cost)

    # Find optimal sigma given policy using train data.
    sigma, _ = search_sigma_given_policy(policy, train_rng, state, model,
                                         train_ds)

    # Compress the test data.
    compress_dataset(state, model, test_ds, sigma=sigma, policy=policy)
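
A sketch of how `evaluate_compression` might be wired up as a script entry point; the flag names and the absl wrapper are assumptions for illustration, not part of the examples above:

from absl import app
from absl import flags

FLAGS = flags.FLAGS
flags.DEFINE_string('work_dir', None,
                    'Directory with the saved config, loss components and checkpoints.')
flags.DEFINE_integer('budget', 50, 'Budget for the policy.')


def main(_):
    # `evaluate_compression` is the function defined above.
    evaluate_compression(FLAGS.work_dir, budget=FLAGS.budget)


if __name__ == '__main__':
    app.run(main)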