示例#1
0
def get_data_from_offline_batch(params,
                                env,
                                normalization_scope=None,
                                model='dynamics',
                                split_ratio=0.666667):
    """
    Build train/validation collections from a pre-recorded (offline) batch.

    Paths are loaded with ``RolloutSampler.generate_offline_data`` using
    ``params['data_file']`` and ``params['n_train']``. They are handed to
    ``add_path_data_to_collection_and_update_normalization`` with
    ``normalization=None``, so a fresh normalization is requested.

    :param params: configuration dict; reads ``params[model]['batch_size']``,
        ``'max_train_data'``, ``'max_val_data'``, ``'data_file'`` and ``'n_train'``
    :param env: environment; ``env.observation_space.shape[0]`` is used as obs_dim
    :param normalization_scope: scope forwarded to the normalization helper
    :param model: key selecting which section of ``params`` holds the batch size
    :param split_ratio: split ratio between training and validation data,
        forwarded to the helper
    :return: (train_collection, val_collection, normalization,
              path_collection, rollout_sampler)
    """
    batch_size = params[model]['batch_size']
    # The training collection is built with shuffle=True, validation with shuffle=False.
    training_set = DataCollection(batch_size=batch_size,
                                  max_size=params['max_train_data'],
                                  shuffle=True)
    validation_set = DataCollection(batch_size=batch_size,
                                    max_size=params['max_val_data'],
                                    shuffle=False)

    sampler = RolloutSampler(env)
    offline_paths = sampler.generate_offline_data(data_file=params['data_file'],
                                                  n_train=params["n_train"])

    paths_store = PathCollection()
    observation_dim = env.observation_space.shape[0]
    norm = add_path_data_to_collection_and_update_normalization(
        offline_paths,
        paths_store,
        training_set,
        validation_set,
        normalization=None,
        split_ratio=split_ratio,
        obs_dim=observation_dim,
        normalization_scope=normalization_scope,
    )
    return training_set, validation_set, norm, paths_store, sampler
示例#2
0
def replace_path_data_to_collection_and_update_normalization(
        paths,
        train_collection,
        val_collection,
        normalization=None,
        split_ratio=0.666667,
        obs_dim=None,
        normalization_scope=None):
    """
    Replace the contents of the train/val collections with data from paths.
    Create or update normalization stats.
    :param paths: paths to convert; split via ``PathCollection.to_data_collections``
    :param train_collection: a "data_collection" object; its data is replaced
    :param val_collection: a "data_collection" object; its data is replaced
    :param normalization: a "Normalization" object to update, or None to create one
    :param split_ratio: real number in [0, 1]. The split ratio between training and validation set
    :param obs_dim: actual observation dimension that will be normalized
    :param normalization_scope: scope passed to Normalization when a new one is created
    :return: the newly created or updated normalization
    """
    # data
    train_data, val_data = PathCollection.to_data_collections(
        paths, split_ratio=split_ratio)
    train_collection.replace_data(train_data)
    val_collection.replace_data(val_data)

    # normalization
    # Test for None explicitly: an existing Normalization must never be
    # discarded and rebuilt merely because the object evaluates as falsy.
    if normalization is None:
        logger.log("Creating normalization for training data.")
        normalization = Normalization(train_data,
                                      obs_dim=obs_dim,
                                      scope=normalization_scope)
        logger.log("Done creating normalization for training data.")
    else:
        logger.log("Updating normalization.")
        # Updates from the collection's stored data rather than train_data.
        # NOTE(review): presumably identical after replace_data unless the
        # collection truncates to its max_size — confirm.
        normalization.update(train_collection.get_data())
        logger.log("Done updating normalization.")
    return normalization
示例#3
0
def add_path_data_to_collection_and_update_normalization(
        paths,
        path_collection,
        train_collection,
        val_collection,
        normalization=None,
        split_ratio=0.666667,
        train_discard_ratio=0.0,
        obs_dim=None,
        normalization_scope=None):
    """
    Add new data from paths to the train/val collections. Create or update
    normalization stats.
    :param paths: new paths; split via ``PathCollection.to_data_collections``
    :param path_collection: PathCollection object. NOTE(review): this function
        neither reads nor modifies it; it is kept for call-site compatibility
    :param train_collection: a "data_collection" object; receives the training split
    :param val_collection: a "data_collection" object; receives the validation split
    :param normalization: a "Normalization" object to update, or None to create one
    :param split_ratio: real number in [0, 1]. The split ratio between training and validation set
    :param train_discard_ratio: real number in [0, 1). The ratio to discard training data
    :param obs_dim: actual observation dimension that will be normalized
    :param normalization_scope: scope passed to Normalization when a new one is created
    :return: the newly created or updated normalization
    """
    # data
    train_data, val_data = PathCollection.to_data_collections(
        paths, split_ratio=split_ratio)
    train_collection.add_data(train_data, discard_ratio=train_discard_ratio)
    val_collection.add_data(val_data)

    # normalization
    # Test for None explicitly: an existing Normalization must never be
    # discarded and rebuilt merely because the object evaluates as falsy.
    if normalization is None:
        logger.log("Creating normalization for training data.")
        # Built from the newly split train_data only, not from anything that
        # was already stored in train_collection.
        normalization = Normalization(train_data,
                                      obs_dim=obs_dim,
                                      scope=normalization_scope)
        logger.log("Done creating normalization for training data.")
    else:
        logger.log("Updating normalization.")
        # Updated from everything currently held by train_collection.
        normalization.update(train_collection.get_data())
        logger.log("Done updating normalization.")
    return normalization
示例#4
0
def get_data_from_random_rollouts(params,
                                  env,
                                  normalization_scope=None,
                                  model='dynamics',
                                  split_ratio=0.666667):
    """
    Generate random rollouts and build train/validation collections from them.

    Rollouts come from ``RolloutSampler.generate_random_rollouts`` with
    ``params['num_path_random']`` paths of horizon ``params['env_horizon']``.
    They are handed to ``add_path_data_to_collection_and_update_normalization``
    with ``normalization=None``, so a fresh normalization is requested.

    :param params: configuration dict; reads ``params[model]['batch_size']``,
        ``'max_train_data'``, ``'max_val_data'``, ``'num_path_random'`` and
        ``'env_horizon'``
    :param env: environment; ``env.observation_space.shape[0]`` is used as obs_dim
    :param normalization_scope: scope forwarded to the normalization helper
    :param model: key selecting which section of ``params`` holds the batch
        size (defaults to the previously hard-coded ``'dynamics'``)
    :param split_ratio: split ratio between training and validation data
        (defaults to the helper's previous default, 0.666667)
    :return: (train_collection, val_collection, normalization,
              path_collection, rollout_sampler)
    """
    train_collection = DataCollection(
        batch_size=params[model]['batch_size'],
        max_size=params['max_train_data'],
        shuffle=True)
    val_collection = DataCollection(
        batch_size=params[model]['batch_size'],
        max_size=params['max_val_data'],
        shuffle=False)
    rollout_sampler = RolloutSampler(env)
    random_paths = rollout_sampler.generate_random_rollouts(
        num_paths=params['num_path_random'], horizon=params['env_horizon'])
    path_collection = PathCollection()
    obs_dim = env.observation_space.shape[0]
    normalization = add_path_data_to_collection_and_update_normalization(
        random_paths,
        path_collection,
        train_collection,
        val_collection,
        normalization=None,
        split_ratio=split_ratio,
        obs_dim=obs_dim,
        normalization_scope=normalization_scope)
    return train_collection, val_collection, normalization, path_collection, rollout_sampler
示例#5
0
def get_data_from_random_rollouts(params,
                                  env,
                                  random_paths,
                                  normalization_scope=None,
                                  model='dynamics',
                                  split_ratio=0.666667):
    """
    Build train/validation collections from already-generated random paths.

    The supplied paths are handed to
    ``add_path_data_to_collection_and_update_normalization`` with
    ``normalization=None``, so a fresh normalization is requested.

    :param params: configuration dict; reads ``params[model]['batch_size']``,
        ``'max_train_data'`` and ``'max_val_data'``
    :param env: environment; ``env.observation_space.shape[0]`` is used as obs_dim
    :param random_paths: paths to split between the two collections
    :param normalization_scope: scope forwarded to the normalization helper
    :param model: key selecting which section of ``params`` holds the batch size
    :param split_ratio: split ratio between training and validation data,
        forwarded to the helper
    :return: (train_collection, val_collection, normalization, path_collection)
    """
    batch_size = params[model]['batch_size']
    # Built in order: training (shuffle=True) first, then validation (shuffle=False).
    train_set, val_set = (
        DataCollection(batch_size=batch_size,
                       max_size=params[size_key],
                       shuffle=shuffle)
        for size_key, shuffle in (('max_train_data', True),
                                  ('max_val_data', False))
    )

    store = PathCollection()
    normalization = add_path_data_to_collection_and_update_normalization(
        random_paths,
        store,
        train_set,
        val_set,
        normalization=None,
        split_ratio=split_ratio,
        obs_dim=env.observation_space.shape[0],
        normalization_scope=normalization_scope,
    )
    return train_set, val_set, normalization, store