def run(config):
  """Entry point to run training, one stage at a time.

  For each stage id to train, this resets the TensorFlow default graph. It
  then builds a `lib_model.Model` under a replica device setter, prints the
  global variables, and hands the model to `train_util.train`.

  Args:
    config: Mapping of training options. It is unpacked as keyword arguments
      into the `train_util` helpers and passed whole to `init_data_normalizer`
      and `lib_model.Model`. This function itself reads the
      'train_progressive' and 'ps_tasks' keys.
  """
  init_data_normalizer(config)

  all_stage_ids = train_util.get_stage_ids(**config)
  # When 'train_progressive' is falsy, only the final stage id is trained.
  if config['train_progressive']:
    stages_to_train = all_stage_ids
  else:
    stages_to_train = all_stage_ids[-1:]

  for current_stage in stages_to_train:
    stage_batch_size = train_util.get_batch_size(current_stage, **config)

    # Each stage builds its model into a freshly reset default graph.
    tf.reset_default_graph()
    placement = tf.train.replica_device_setter(config['ps_tasks'])
    with tf.device(placement):
      model = lib_model.Model(current_stage, stage_batch_size, config)
      model.add_summaries()

      # Print each global variable's name and shape (as a list) to stdout.
      print('Variables:')
      for variable in tf.global_variables():
        print('\t', variable.name, variable.get_shape().as_list())

      logging.info('Calling train.train')
      train_util.train(model, **config)