Пример #1
0
def train_mem(ctx, path, vae_path):
    """Train MDN-RNN model as specified in .json config with data at `PATH`.

    Args:
        ctx: Context object passed to `obtain_config` to get the configuration.
        path: Dataset location handed to `MDNDataset`.
        vae_path: Path to a VAE checkpoint. Only used when rendering is
            allowed, to build the decoder given to `MemoryVisualization`.

    Raises:
        ValueError: If rendering is allowed and `vae_path` is None.
    """

    from third_party.torchtrainer import EarlyStopping, LambdaCallback, ModelCheckpoint, CSVLogger, RandomBatchSampler
    from torch.utils.data import DataLoader
    config = obtain_config(ctx)

    env = hrl.create_gym(config.general['game_name'])

    # Create checkpoint directory, if it doesn't exist
    create_directory(os.path.dirname(config.rnn['ckpt_path']))

    # Create training dataset
    dataset = MDNDataset(path, config.rnn['sequence_len'],
                         config.rnn['terminal_prob'],
                         config.rnn['dataset_fraction'])

    # The dataset must be closed whatever happens below (invalid arguments,
    # model building errors, interrupted training), hence the try/finally.
    try:
        # Create training DataLoader
        data_loader = DataLoader(dataset,
                                 batch_sampler=RandomBatchSampler(
                                     dataset, config.rnn['batch_size']),
                                 pin_memory=True)

        # Build model
        rnn = build_rnn_model(config.rnn, config.vae['latent_space_dim'],
                              env.action_space)

        # Create callbacks
        callbacks = [
            EarlyStopping(metric='loss',
                          patience=config.rnn['patience'],
                          verbose=1),
            # Call the model's `init_hidden` with the batch size at the
            # beginning of every batch
            LambdaCallback(on_batch_begin=lambda _, batch_size: rnn.model.
                           init_hidden(batch_size)),
            ModelCheckpoint(config.rnn['ckpt_path'], metric='loss',
                            save_best=True),
            CSVLogger(
                filename=os.path.join(config.rnn['logs_dir'], 'train_mem.csv'))
        ]

        # Evaluate and visualize memory progress if render allowed
        if config.allow_render:
            if vae_path is None:
                raise ValueError("To render provide valid path to VAE checkpoint!")

            # Build VAE model and load checkpoint, only the decoder is needed
            _, _, decoder = build_vae_model(config.vae,
                                            config.general['state_shape'],
                                            vae_path)

            callbacks += [
                MemoryVisualization(config, decoder, rnn.model, dataset,
                                    'mdn_plots')
            ]

        # Fit MDN-RNN model!
        rnn.fit_loader(data_loader,
                       epochs=config.rnn['epochs'],
                       callbacks=callbacks)
    finally:
        dataset.close()
Пример #2
0
def train_ctrl(ctx, vae_path, mdn_path):
    """Plays chosen game and trains Controller on preprocessed states with VAE and MDN-RNN
    (loaded from `vae_path` or `mdn_path`).

    Every epoch a new population is requested from the solver, evaluated with
    `hrl.pool`, reported back to the solver and the solver is checkpointed.

    Args:
        ctx: Context object passed to `obtain_config` to get the configuration.
        vae_path: Path to the VAE checkpoint, forwarded to `Evaluator`.
        mdn_path: Path to the MDN-RNN checkpoint, forwarded to `Evaluator`.
    """

    # We will spawn multiple workers, we don't want them to access GPU
    config = obtain_config(ctx, use_gpu=False)

    # Get number of workers to run; non-positive values are passed on as None
    # (presumably letting `hrl.pool` choose a default - confirm)
    processes = config.es['processes']
    processes = processes if processes > 0 else None

    # Get action space size, the environment itself isn't needed afterwards
    env = hrl.create_gym(config.general['game_name'])
    action_space = env.action_space
    del env

    # Controller sees the VAE latent vector together with the RNN hidden state
    input_dim = config.vae['latent_space_dim'] + config.rnn['hidden_units']
    out_dim = action_space.num
    # NOTE(review): the "+ 1" looks like one bias term per output - confirm
    n_params = (input_dim + 1) * out_dim

    # Build CMA-ES solver and start book keeping from its best score so far
    solver = build_es_model(config.es, n_params=n_params)
    best_return = solver.best_score

    # Train for N epochs
    pbar = tqdm(range(config.es['epochs']), ascii=True)
    pbar.set_postfix(current=best_return)
    for _ in pbar:
        # Get new population
        population = solver.ask()

        # Evaluate population in parallel
        hists = hrl.pool(Evaluator(config, input_dim, action_space, vae_path,
                                   mdn_path),
                         jobs=population,
                         processes=processes,
                         n_episodes=config.es['n_episodes'],
                         render_mode=config.allow_render,
                         verbose=0)
        returns = [np.mean(hist['return']) for hist in hists]
        generation_best = max(returns)

        # Print logs (best return from before this generation) and then
        # update best return
        pbar.set_postfix(best=best_return, current=generation_best)
        best_return = max(best_return, generation_best)

        # Update solver
        solver.tell(returns)

        # Save solver in given path
        solver.save_es_ckpt_and_mind_weights(config.es['ckpt_path'],
                                             config.es['mind_path'],
                                             score=best_return)
Пример #3
0
def convert_data(ctx, path_in, path_out, vae_path):
    """Preprocess recorded states with a trained VAE for Memory training.

    Transitions (as produced by record_data) are loaded from `PATH_IN`,
    passed through the VAE encoder loaded from `vae_path` and saved to
    `PATH_OUT`.
    """

    config = obtain_config(ctx)
    state_shape = config.general['state_shape']

    # Build VAE model, only its encoder is used here
    vae_parts = build_vae_model(config.vae, state_shape, vae_path)
    _, encoder, _ = vae_parts

    latent_dim = config.vae['latent_space_dim']
    convert_data_with_vae(encoder, path_in, path_out, latent_dim)
Пример #4
0
def record_data(ctx, path, n_games, chunk_size, state_dtype):
    """Play the chosen game with the generating agent and record transitions to hdf5 file in `PATH`.

    If the dataset under `PATH` already holds some games, only the missing
    ones are played; if it already holds at least `n_games`, nothing is
    recorded and a warning is logged.
    """

    config = obtain_config(ctx)
    general = config.general

    # Gym environment, generating agent and the callback storing to hdf5
    env = hrl.create_gym(general['game_name'])
    mind, agent_callbacks = create_generating_agent(
        general['generating_agent'], env)
    store_callback = StoreTransitions(
        path,
        general['state_shape'],
        env.action_space,
        chunk_size=chunk_size,
        state_dtype=state_dtype,
        reward_dtype=np.float32,
    )
    callbacks = agent_callbacks + [store_callback]

    # Number of games the store callback reports as already recorded
    recorded = store_callback.game_count

    # Nothing to do when the dataset is already complete
    if recorded >= n_games:
        log.warning(
            "Data is already fully present in dataset you specified! If you wish to create"
            " a new dataset, please remove the one under this path or specify a different"
            " path. If you wish to gather more data, increase the number of games to "
            " record with --n-games parameter.")
        return

    # Partially filled dataset: play only the games that are still missing
    if recorded > 0:
        remaining = n_games - recorded
        log.info(
            "{}/{} games were already recorded in specified dataset. {} more game will be"
            " added!".format(recorded, n_games, remaining))
        n_games = remaining

    # Resizes states to `state_shape` with cropping
    interpreter = BasicInterpreter(
        state_shape=general['state_shape'],
        crop_range=general['crop_range'],
        scale=255,
    )

    # Play the games and gather data as it goes
    hrl.loop(
        env,
        mind,
        interpreter,
        n_episodes=n_games,
        verbose=1,
        callbacks=callbacks,
        render_mode=config.allow_render,
    )
Пример #5
0
def eval(ctx, controller_path, vae_path, mdn_path, n_games):
    """Test the whole pipeline, VAE -> MDN-RNN -> Controller, on the chosen game.

    Models are loaded from `vae_path`, `mdn_path` and `controller_path`,
    `n_games` episodes are played and the returns together with their mean
    are printed.
    """

    config = obtain_config(ctx)

    # Environment, its action space is needed to build the models
    env = hrl.create_gym(config.general['game_name'])

    state_shape = config.general['state_shape']
    latent_dim = config.vae['latent_space_dim']

    # VAE encoder and MDN-RNN used by the interpreter
    _, encoder, _ = build_vae_model(config.vae, state_shape, vae_path)
    rnn = build_rnn_model(config.rnn, latent_dim, env.action_space, mdn_path)

    # Interpreters: basic preprocessing first, then VAE + MDN-RNN
    basic_interpreter = BasicInterpreter(
        state_shape=state_shape,
        crop_range=config.general['crop_range'],
    )
    mdn_interpreter = MDNInterpreter(encoder, rnn.model, latent_dim)

    # Controller fed with the latent vector and the RNN hidden state
    controller_input_dim = latent_dim + config.rnn['hidden_units']
    mind = build_mind(config.es, controller_input_dim, env.action_space,
                      controller_path)

    interpreter = ChainInterpreter(basic_interpreter, mdn_interpreter)
    # The MDN interpreter is also registered as a callback of the loop
    callbacks = [ReturnTracker(), mdn_interpreter]

    hist = hrl.loop(
        env,
        mind,
        interpreter,
        n_episodes=n_games,
        render_mode=config.allow_render,
        verbose=1,
        callbacks=callbacks,
    )

    returns = hist['return']
    print("Returns:", *returns)
    print("Avg. return:", np.mean(returns))
Пример #6
0
def train_vae(ctx, path):
    """Train VAE model as specified in .json config with data at `PATH`.

    The 'states' dataset of the HDF5 file is split at 80% of its transitions
    into a training and a validation generator, with states divided by 255.
    When rendering is allowed, the first eight states and their
    reconstructions are plotted at the beginning of every epoch.

    Args:
        ctx: Context object passed to `obtain_config` to get the configuration.
        path: Path to the HDF5 file with a 'states' dataset and an
            'N_TRANSITIONS' attribute.
    """

    from keras.callbacks import EarlyStopping, LambdaCallback, ModelCheckpoint, CSVLogger
    config = obtain_config(ctx)

    # Get dataset length and eight examples to evaluate VAE on
    with h5.File(path, 'r') as hfile:
        n_transitions = hfile.attrs['N_TRANSITIONS']
        X_eval = hfile['states'][:8] / 255.

    # Index splitting the data into training and validation parts
    split_idx = int(n_transitions * 0.8)

    def scale_states(X, y):
        # Inputs and targets are both states, divide both by 255
        return X / 255., y / 255.

    # Get training data
    train_gen = HDF5DataGenerator(path,
                                  'states',
                                  'states',
                                  batch_size=config.vae['batch_size'],
                                  end=split_idx,
                                  preprocess_fn=scale_states)
    val_gen = HDF5DataGenerator(path,
                                'states',
                                'states',
                                batch_size=config.vae['batch_size'],
                                start=split_idx,
                                preprocess_fn=scale_states)

    # Build VAE model
    vae, _, _ = build_vae_model(config.vae, config.general['state_shape'])

    # If render features enabled...
    if config.allow_render:
        # ...plot first eight training examples with VAE reconstructions
        # at the beginning of every epoch
        import matplotlib
        matplotlib.use("Agg")
        import matplotlib.gridspec as gridspec
        import matplotlib.pyplot as plt

        # Make sure destination dir exists
        plots_dir = os.path.join(config.vae['logs_dir'], "plots_vae")
        os.makedirs(plots_dir, exist_ok=True)

        # Evaluate VAE at the beginning of epoch
        def plot_samples(epoch, logs):
            pred = vae.predict(X_eval)

            # Interleave originals (even slots) with reconstructions (odd slots)
            samples = np.empty_like(np.concatenate((X_eval, pred)))
            samples[0::2] = X_eval
            samples[1::2] = pred

            _ = plt.figure(figsize=(4, 4))
            gs = gridspec.GridSpec(4, 4)
            gs.update(wspace=0.05, hspace=0.05)

            for i, sample in enumerate(samples):
                ax = plt.subplot(gs[i])
                plt.axis('off')
                ax.set_xticklabels([])
                ax.set_yticklabels([])
                ax.set_aspect('equal')
                plt.imshow(sample.reshape(*config.general['state_shape']))

            # Save figure to logs dir
            plt.savefig(
                os.path.join(
                    plots_dir, "interpreter_sample_{}".format(
                        dt.datetime.now().strftime("%d-%mT%H:%M"))))
            plt.close()
    else:

        def plot_samples(epoch, logs):
            pass

    # Create checkpoint directory, if it doesn't exist
    create_directory(os.path.dirname(config.vae['ckpt_path']))
    # Create the logging directory itself (not only its parent), CSVLogger
    # below writes its file directly into it
    os.makedirs(config.vae['logs_dir'], exist_ok=True)

    # Initialize callbacks
    callbacks = [
        EarlyStopping(patience=config.vae['patience']),
        LambdaCallback(on_epoch_begin=plot_samples),
        ModelCheckpoint(config.vae['ckpt_path'],
                        verbose=1,
                        save_best_only=True,
                        save_weights_only=True),
        CSVLogger(filename=os.path.join(config.vae['logs_dir'],
                                        'train_vae.csv'),
                  append=True)
    ]

    # Fit VAE model!
    vae.fit_generator(
        generator=train_gen,
        validation_data=val_gen,
        epochs=config.vae['epochs'],
        use_multiprocessing=False,
        # NOTE:  There is no need for more than one worker, we are disk IO bound (I suppose ...)
        # NOTE2: h5py from conda should be threadsafe... but it apparently isn't and raises
        #        `OSError: Can't read data (wrong B-tree signature)` sporadically if `workers` = 1
        #        and always if `workers` > 1. That's why this generator needs to run in main thread
        #        (`workers` = 0).
        workers=0,
        max_queue_size=100,
        shuffle=True,  # It shuffles whole batches, not items in batches
        callbacks=callbacks)