Exemple #1
0
def load_or_create_samples(params, config):
    """Load or create the samples used as the sampler inputs.

    Cached arrays are looked up as ``sample_z.npy`` and ``sample_x.npy``
    inside ``config['model_dir']``. A cached array is reused only when its
    shape agrees with the current parameters; otherwise new samples are drawn
    and written back to the same path.

    Args:
        params: Mapping read for 'latent_dim', the optional
            'is_accompaniment' flag and, when that flag is truthy,
            'data_shape'.
        config: Mapping read for 'model_dir' and 'sample_grid' and, when a
            new sample_x has to be drawn, 'data_source', 'data_filename' and
            'use_random_transpose'.

    Returns:
        A tuple ``(sample_x, None, sample_z)``. ``sample_x`` is None unless
        ``params['is_accompaniment']`` is truthy.
    """
    # greg: when not conditioning, only samples of z are drawn.

    # Load sample_z
    LOGGER.info("Loading sample_z.")
    sample_z_path = os.path.join(config['model_dir'], 'sample_z.npy')
    if os.path.exists(sample_z_path):
        sample_z = np.load(sample_z_path)
        # Check ndim first so that a malformed (0-D or 1-D) cache is redrawn
        # instead of raising IndexError on shape[1].
        resample = (sample_z.ndim < 2
                    or sample_z.shape[1] != params['latent_dim'])
        if resample:
            LOGGER.info("Loaded sample_z has wrong shape")
    else:
        LOGGER.info("File for sample_z not found")
        resample = True

    # Draw new sample_z
    if resample:
        LOGGER.info("Drawing new sample_z.")
        sample_z = scipy.stats.truncnorm.rvs(
            -2, 2, size=(np.prod(config['sample_grid']), params['latent_dim']))
        make_sure_path_exists(config['model_dir'])
        np.save(sample_z_path, sample_z)

    # greg: sample_x is only drawn when conditioning.
    if params.get('is_accompaniment'):
        # Load sample_x
        LOGGER.info("Loading sample_x.")
        sample_x_path = os.path.join(config['model_dir'], 'sample_x.npy')
        if os.path.exists(sample_x_path):
            sample_x = np.load(sample_x_path)
            # ndarray.shape is a tuple and a tuple never compares equal to a
            # list, so normalise data_shape to a tuple; otherwise a
            # list-valued setting would reject every cached file.
            resample = sample_x.shape[1:] != tuple(params['data_shape'])
            if resample:
                LOGGER.info("Loaded sample_x has wrong shape")
        else:
            LOGGER.info("File for sample_x not found")
            resample = True

        # Draw new sample_x
        if resample:
            LOGGER.info("Drawing new sample_x.")
            data = load_data(config['data_source'], config['data_filename'])
            sample_x = get_samples(
                np.prod(config['sample_grid']),
                data,
                use_random_transpose=config['use_random_transpose'])
            make_sure_path_exists(config['model_dir'])
            np.save(sample_x_path, sample_x)
    else:
        sample_x = None

    return sample_x, None, sample_z
Exemple #2
0
def load_or_create_samples(params, config):
    """Load or create the samples used as the sampler inputs.

    Cached arrays are looked up as ``sample_z.npy`` and ``sample_x.npy``
    inside ``config['model_dir']``. A cached array is reused only when its
    shape agrees with the current parameters; otherwise new samples are drawn
    and written back to the same path.

    Args:
        params: Mapping read for 'latent_dim', 'is_accompaniment' and, when
            that flag is truthy, 'data_shape'.
        config: Mapping read for 'model_dir' and 'sample_grid' and, when a
            new sample_x has to be drawn, 'data_source', 'data_filename' and
            'use_random_transpose'.

    Returns:
        A tuple ``(sample_x, None, sample_z)``. ``sample_x`` is None when
        ``params['is_accompaniment']`` is falsy.
    """
    # Load sample_z
    LOGGER.info("Loading sample_z.")
    sample_z_path = os.path.join(config['model_dir'], 'sample_z.npy')
    if os.path.exists(sample_z_path):
        sample_z = np.load(sample_z_path)
        # Check ndim first so that a malformed (0-D or 1-D) cache is redrawn
        # instead of raising IndexError on shape[1].
        resample = (sample_z.ndim < 2
                    or sample_z.shape[1] != params['latent_dim'])
        if resample:
            LOGGER.info("Loaded sample_z has wrong shape")
    else:
        LOGGER.info("File for sample_z not found")
        resample = True

    # Draw new sample_z
    if resample:
        LOGGER.info("Drawing new sample_z.")
        sample_z = scipy.stats.truncnorm.rvs(
            -2, 2, size=(np.prod(config['sample_grid']), params['latent_dim']))
        make_sure_path_exists(config['model_dir'])
        np.save(sample_z_path, sample_z)

    if params['is_accompaniment']:
        # Load sample_x
        LOGGER.info("Loading sample_x.")
        sample_x_path = os.path.join(config['model_dir'], 'sample_x.npy')
        if os.path.exists(sample_x_path):
            sample_x = np.load(sample_x_path)
            # ndarray.shape is a tuple and a tuple never compares equal to a
            # list, so normalise data_shape to a tuple; otherwise a
            # list-valued setting would reject every cached file.
            resample = sample_x.shape[1:] != tuple(params['data_shape'])
            if resample:
                LOGGER.info("Loaded sample_x has wrong shape")
        else:
            LOGGER.info("File for sample_x not found")
            resample = True

        # Draw new sample_x
        if resample:
            LOGGER.info("Drawing new sample_x.")
            data = load_data(config['data_source'], config['data_filename'])
            sample_x = get_samples(
                np.prod(config['sample_grid']), data,
                use_random_transpose=config['use_random_transpose'])
            make_sure_path_exists(config['model_dir'])
            np.save(sample_x_path, sample_x)
    else:
        sample_x = None

    return sample_x, None, sample_z
Exemple #3
0
def main():
    """Build the model graph, restore a checkpoint and run the sampler.

    The graph is built once in 'train' mode and once in 'predict' mode. The
    checkpoint named on the first line of the 'checkpoint' file inside
    ``config['checkpoint_dir']`` is restored, and the grouped save ops are
    then run ``config['runs']`` times with freshly drawn inputs.
    """
    # Setup
    logging.basicConfig(level=LOGLEVEL, format=LOG_FORMAT)
    params, config = setup()
    LOGGER.info("Using parameters:\n%s", pformat(params))
    LOGGER.info("Using configurations:\n%s", pformat(config))

    # ============================== Placeholders ==============================
    x_shape = [None] + params['data_shape']
    z_shape = (None, params['latent_dim'])
    # Same as x_shape except that the last axis has size 1.
    c_shape = [None] + params['data_shape'][:-1] + [1]
    placeholder_x = tf.placeholder(tf.float32, shape=x_shape)
    placeholder_z = tf.placeholder(tf.float32, shape=z_shape)
    placeholder_c = tf.placeholder(tf.float32, shape=c_shape)
    placeholder_suffix = tf.placeholder(tf.string)

    # ================================= Model ==================================
    # Configuration handed to the model in predict mode
    sampler_config = dict(
        result_dir=config['result_dir'],
        image_grid=(config['rows'], config['columns']),
        suffix=placeholder_suffix,
        midi=config['midi'],
        colormap=np.array(config['colormap']).T,
        collect_save_arrays_op=config['save_array_samples'],
        collect_save_images_op=config['save_image_samples'],
        collect_save_pianorolls_op=config['save_pianoroll_samples'],
    )

    # Inputs shared by both modes; the condition is only fed when the model
    # is an accompaniment model.
    model = Model(params)
    shared_inputs = {'z': placeholder_z}
    if params['is_accompaniment']:
        shared_inputs['c'] = placeholder_c

    # The train-mode call comes first; only the predict-mode nodes are kept.
    model(x=placeholder_x, mode='train', params=params, config=config,
          **shared_inputs)
    predict_nodes = model(mode='predict', params=params,
                          config=sampler_config, **shared_inputs)

    # Group whichever save ops the predict-mode call returned
    op_keys = ('save_arrays_op', 'save_images_op', 'save_pianorolls_op')
    available_ops = [predict_nodes[k] for k in op_keys if k in predict_nodes]
    sampler_op = tf.group(available_ops)

    # ================================== Data ==================================
    # Data is only needed to draw conditions.
    if params['is_accompaniment']:
        data = load_data(config['data_source'], config['data_filename'])

    # ========================== Session Preparation ===========================
    tf_config = tf.ConfigProto()
    tf_config.gpu_options.allow_growth = True

    # Saver used to restore the variables
    saver = tf.train.Saver()

    # =========================== Tensorflow Session ===========================
    with tf.Session(config=tf_config) as sess:

        # Restore the latest checkpoint
        LOGGER.info("Restoring the latest checkpoint.")
        index_file = os.path.join(config['checkpoint_dir'], 'checkpoint')
        with open(index_file) as f:
            first_line = f.readline()
        # The second whitespace-separated token is the quoted checkpoint
        # path; only its basename is used, resolved against checkpoint_dir.
        quoted_path = first_line.split()[1]
        checkpoint_name = os.path.basename(quoted_path.strip('"'))
        checkpoint_path = os.path.realpath(
            os.path.join(config['checkpoint_dir'], checkpoint_name))
        saver.restore(sess, checkpoint_path)

        # Run sampler op
        for run_idx in range(config['runs']):
            n_samples = config['rows'] * config['columns']
            latent = scipy.stats.truncnorm.rvs(
                config['lower'], config['upper'],
                size=(n_samples, params['latent_dim']))
            feed_dict_sampler = {
                placeholder_z: latent,
                placeholder_suffix: str(run_idx),
            }
            if params['is_accompaniment']:
                # One condition is drawn per generated sample.
                sample_x = get_samples(
                    config['rows'] * config['columns'], data,
                    use_random_transpose=config['use_random_transpose'])
                condition = sample_x[..., params['condition_track_idx']]
                feed_dict_sampler[placeholder_c] = np.expand_dims(
                    condition, -1)
            sess.run(sampler_op, feed_dict=feed_dict_sampler)
def main():
    """Build the model graph, restore a checkpoint and run the sampler.

    The graph is built once in 'train' mode and once in 'predict' mode. The
    checkpoint named on the first line of the 'checkpoint' file inside
    ``config['checkpoint_dir']`` is restored, and the grouped save ops are
    then run ``config['runs']`` times. For an accompaniment model each run
    draws a single sample and repeats its condition across the whole grid.
    """
    # Setup
    logging.basicConfig(level=LOGLEVEL, format=LOG_FORMAT)
    params, config = setup()
    LOGGER.info("Using parameters:\n%s", pformat(params))
    LOGGER.info("Using configurations:\n%s", pformat(config))

    # ============================== Placeholders ==============================
    x_shape = [None] + params['data_shape']
    z_shape = (None, params['latent_dim'])
    # Same as x_shape except that the last axis has size 1.
    c_shape = [None] + params['data_shape'][:-1] + [1]
    placeholder_x = tf.placeholder(tf.float32, shape=x_shape)
    placeholder_z = tf.placeholder(tf.float32, shape=z_shape)
    placeholder_c = tf.placeholder(tf.float32, shape=c_shape)
    placeholder_suffix = tf.placeholder(tf.string)

    # ================================= Model ==================================
    # Configuration handed to the model in predict mode
    sampler_config = dict(
        result_dir=config['result_dir'],
        image_grid=(config['rows'], config['columns']),
        suffix=placeholder_suffix,
        midi=config['midi'],
        colormap=np.array(config['colormap']).T,
        collect_save_arrays_op=config['save_array_samples'],
        collect_save_images_op=config['save_image_samples'],
        collect_save_pianorolls_op=config['save_pianoroll_samples'],
    )

    # Inputs shared by both modes; the condition is only fed when the model
    # is an accompaniment model.
    model = Model(params)
    shared_inputs = {'z': placeholder_z}
    if params['is_accompaniment']:
        shared_inputs['c'] = placeholder_c

    # The train-mode call comes first; only the predict-mode nodes are kept.
    model(x=placeholder_x, mode='train', params=params, config=config,
          **shared_inputs)
    predict_nodes = model(mode='predict', params=params,
                          config=sampler_config, **shared_inputs)

    # Group whichever save ops the predict-mode call returned
    op_keys = ('save_arrays_op', 'save_images_op', 'save_pianorolls_op')
    available_ops = [predict_nodes[k] for k in op_keys if k in predict_nodes]
    sampler_op = tf.group(available_ops)

    # ================================== Data ==================================
    # Data is only needed to draw conditions.
    if params['is_accompaniment']:
        data = load_data(config['data_source'], config['data_filename'])

    # ========================== Session Preparation ===========================
    tf_config = tf.ConfigProto()
    tf_config.gpu_options.allow_growth = True

    # Saver used to restore the variables
    saver = tf.train.Saver()

    # =========================== Tensorflow Session ===========================
    with tf.Session(config=tf_config) as sess:

        # Restore the latest checkpoint
        LOGGER.info("Restoring the latest checkpoint.")
        index_file = os.path.join(config['checkpoint_dir'], 'checkpoint')
        with open(index_file) as f:
            first_line = f.readline()
        # The second whitespace-separated token is the quoted checkpoint
        # path; only its basename is used, resolved against checkpoint_dir.
        quoted_path = first_line.split()[1]
        checkpoint_name = os.path.basename(quoted_path.strip('"'))
        checkpoint_path = os.path.realpath(
            os.path.join(config['checkpoint_dir'], checkpoint_name))
        saver.restore(sess, checkpoint_path)

        # Run sampler op
        for run_idx in range(config['runs']):
            latent = get_input_z(config, params)
            feed_dict_sampler = {
                placeholder_z: latent,
                placeholder_suffix: str(run_idx),
            }
            if params['is_accompaniment']:
                # Draw one sample and tile its condition along axis 0 so
                # every cell of the grid shares the same condition.
                sample_x = get_samples(
                    1, data,
                    use_random_transpose=config['use_random_transpose'])
                condition = sample_x[..., params['condition_track_idx']]
                sample_c = np.expand_dims(condition, -1)
                grid_size = config['rows'] * config['columns']
                feed_dict_sampler[placeholder_c] = np.repeat(
                    sample_c, grid_size, axis=0)
            sess.run(sampler_op, feed_dict=feed_dict_sampler)