コード例 #1
0
        n_params = np.prod(shape, dtype=np.int32)
        scope_n_params += n_params
        print '\t', name, shape
    print


def get_session():
    """Build and return a ``tf.Session`` with on-demand GPU memory growth.

    Setting ``gpu_options.allow_growth`` makes TensorFlow allocate GPU memory
    as it is needed instead of reserving it all when the session starts.

    Returns:
        A new ``tf.Session`` configured with ``allow_growth=True``.
    """
    session_config = tf.ConfigProto()
    session_config.gpu_options.allow_growth = True
    return tf.Session(config=session_config)


if __name__ == '__main__':
    # Manual check of `init_checkpoint`: register two demo flags, point it at
    # a fixed pair of config files, and print what it returns.
    tf_flags.DEFINE_integer('int_flag', -2, 'some int')
    tf_flags.DEFINE_string('string_flag', 'abc', 'some string')

    # NOTE(review): these paths are relative, so they presumably resolve
    # against the current working directory; run from the expected folder.
    checkpoint_dir = '../checkpoints/setup'
    data_config = 'configs/static_mnist_data.py'
    model_config = 'configs/imp_weighted_nvil.py'

    experiment_folder, loaded_flags, checkpoint_dir = init_checkpoint(
        checkpoint_dir, data_config, model_config, resume=False)

    # Single-argument print(...) prints the same under Python 2 and 3.
    print(experiment_folder)
    print(loaded_flags)
コード例 #2
0
ファイル: mnist_tools.py プロジェクト: lqiang2003cn/sqair
# GNU General Public License for more details.
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
########################################################################################

import numpy as np
import tensorflow as tf
from attrdict import AttrDict

from sqair.data.data import load_data as _load_data, tensors_from_data as _tensors
from sqair import tf_flags as flags
from sqair.index import dynamic_truncate

# Number of time steps of each loaded sequence to use (0 means the maximum
# length, per the help text below).
flags.DEFINE_integer(
    'seq_len', 0,
    'Length of loaded data sequences. If 0, it defaults to the maximum length.'
)
# Optional curriculum over `seq_len`; disabled at the default value of 0.
flags.DEFINE_integer(
    'stage_itr', 0,
    'If > 0 it setups a curriculum learning where `seq_len` starts as given and '
    'increases by one every `stage_itr` until it gets to the maximum value.')

# NOTE(review): per-key axis indices for entries of the data dict. These look
# like the batch axis: 'imgs', 'nums' and 'coords' appear to be time-major
# (`truncate` slices them along axis 0), while 'labels' is never truncated
# and uses axis 0. TODO confirm against the data loader.
axes = {'imgs': 1, 'labels': 0, 'nums': 1, 'coords': 1}


def truncate(data_dict, n_timesteps):
    """Keep only the first `n_timesteps` entries of the sequence fields.

    The 'imgs', 'coords' and 'nums' values of `data_dict` are sliced along
    their leading axis and written back into the same dict. Every other key
    (for example 'labels') is left untouched.

    Args:
        data_dict: mapping that holds at least 'imgs', 'coords' and 'nums';
            each of those values must support slicing.
        n_timesteps: number of leading entries to keep.

    Returns:
        The same `data_dict` object, modified in place.
    """
    for field in ('imgs', 'coords', 'nums'):
        data_dict[field] = data_dict[field][:n_timesteps]
    return data_dict
コード例 #3
0
########################################################################################

"""Common flags used by model configurations.
"""

from attrdict import AttrDict

from sqair import tf_flags as flags


# Settings for the Gaussian `where` distributions and the glimpse decoder.
flags.DEFINE_float('transform_var_bias', -3., 'Bias added to the variance logit of Gaussian `where` distributions.')

flags.DEFINE_float('output_scale', .25, 'It\'s used to scale the output mean of the glimpse decoder.')
# Stored as a string because it may hold either one float or four
# comma-separated floats; presumably parsed by the code that consumes the flag.
flags.DEFINE_string('scale_prior', '-2', 'A single float or four comma-separated floats representing the mean of the '
                                         'Gaussian prior for scale logit.')
flags.DEFINE_integer('glimpse_size', 20, 'Glimpse size.')

# Propagation settings.
# NOTE(review): `prop_prior_step_bias` has no help text and its meaning is not
# visible here. TODO document.
flags.DEFINE_float('prop_prior_step_bias', 10., '')
flags.DEFINE_string('prop_prior_type', 'rnn', 'Choose from {rnn, rw_rnn} for a recurrent prior and a random-walk '
                                              'recurrent prior.')
flags.DEFINE_boolean('masked_glimpse', True, 'Masks glimpses based on what_tm1 in propagation if True')


# Inference: number of IWAE particles and inference steps per frame.
flags.DEFINE_integer('k_particles', 5, 'Number of particles used for the IWAE bound computation')
flags.DEFINE_integer('n_steps_per_image', 3, 'Number of inference steps per frame.')

# Names of the RNN cores used by discovery/propagation and by the prior.
flags.DEFINE_string('transition', 'VanillaRNN', 'RNNCore from Sonnet to use in discovery and propagation cores.')
flags.DEFINE_string('time_transition', 'GRU', 'RNNCore used for temporal rnn in propagation core.')
flags.DEFINE_string('prior_transition', 'GRU', 'RNNCore used by the propagation prior.')

# Likelihood p(x|z).
flags.DEFINE_float('output_std', .3, 'Standard deviation of Gaussian p(x|z)')
コード例 #4
0
ファイル: eval.py プロジェクト: lqiang2003cn/sqair
from os import path as osp

import numpy as np
import tensorflow as tf

import sys
sys.path.append('../')

from sqair.experiment_tools import load, get_session, parse_flags, assert_all_flags_parsed, _load_flags, FLAG_FILE, json_load, _restore_flags
from sqair import tf_flags as flags

# Config files and checkpoint directory. The help strings are empty; these
# presumably point at the run being evaluated. TODO confirm and document.
flags.DEFINE_string('data_config', 'configs/seq_mnist_data.py', '')
flags.DEFINE_string('model_config', 'configs/apdr.py', '')
flags.DEFINE_string('checkpoint_dir', '../checkpoints', '')

flags.DEFINE_integer('batch_size', 5, '')

# Checkpoint selection: every n-th checkpoint (-1 means only the last one),
# restricted to checkpoints past `from_itr` training iterations.
flags.DEFINE_integer(
    'every_nth_checkpoint', 1,
    'takes 1 in nth checkpoints to evaluate; takes only the last checkpoint if -1'
)
flags.DEFINE_integer(
    'from_itr', 0,
    'Evaluates only checkpoints with training iteration greater than `from_itr`'
)

# Data split to evaluate on ('test' or 'valid').
flags.DEFINE_string('dataset', 'valid', 'test or valid')

# NOTE(review): boolean switches with no help text. Judging by their names,
# they presumably toggle individual evaluation metrics. TODO confirm and document.
flags.DEFINE_boolean('logp', True, '')
flags.DEFINE_boolean('vae', True, '')
flags.DEFINE_boolean('num_step_acc', True, '')
コード例 #5
0
ファイル: experiment.py プロジェクト: cvoelcker/sqair
                                    print_variables_by_scope)
from sqair import tf_flags as flags

# Command-line flags for the training run.

# Config files and where results are written.
flags.DEFINE_string('data_config', 'configs/orig_seq_mnist.py',
                    'Path to a data config file.')
flags.DEFINE_string('model_config', 'configs/mlp_mnist_model.py',
                    'Path to a model config file.')
flags.DEFINE_string('results_dir', '../checkpoints',
                    'Top directory for all experimental results.')
flags.DEFINE_string(
    'run_name', 'test_run',
    'Name of this job. Results will be stored in a corresponding folder.')

flags.DEFINE_integer('batch_size', 32, '')

# Logging / checkpointing cadence and the training budget, all in iterations.
flags.DEFINE_integer('log_itr', int(1e4),
                     'Number of iterations between storing tensorboard logs.')
flags.DEFINE_integer(
    'report_loss_every', int(1e3),
    'Number of iterations between reporting minibatch loss - heartbeat.')
flags.DEFINE_integer('save_itr', int(1e5),
                     'Number of iterations between snapshotting the model.')
flags.DEFINE_integer('fig_itr', 10000,
                     'Number of iterations between creating results figures.')
flags.DEFINE_integer('train_itr', int(2e6),
                     'Maximum number of training iterations.')
flags.DEFINE_boolean('resume', False,
                     'Tries to resume the previous run if True.')
flags.DEFINE_boolean(