def add_phase(
        self, name, steps, score, summary, batch_size=1,
        report_every=None, log_every=None, checkpoint_every=None,
        restore_every=None, feed=None):
    score = tf.convert_to_tensor(score, tf.float32)
    summary = tf.convert_to_tensor(summary, tf.string)
    feed = feed or {}
    if not score.shape.ndims:
        score = score[None]
    writer = self._logdir and tf.summary.FileWriter(
        os.path.join(self._logdir, name),
        tf.get_default_graph(), flush_secs=30)
    op = self._define_step(name, batch_size, score, summary)
    tmp_phase = tools.AttrDict()
    tmp_phase.name = name
    tmp_phase.writer = writer
    tmp_phase.op = op
    tmp_phase.batch_size = batch_size
    tmp_phase.steps = int(steps)
    tmp_phase.feed = feed
    tmp_phase.report_every = report_every
    tmp_phase.log_every = log_every
    tmp_phase.checkpoint_every = checkpoint_every
    tmp_phase.restore_every = restore_every
    self._phases.append(tmp_phase)
def _initial_collection(config, params):
    num_seed_episodes = int(params.get('num_seed_episodes', 5))
    num_seed_steps = int(params.get('num_seed_steps', 2500))
    sims = tools.AttrDict()
    for task in config.train_tasks:
        sims['train-' + task.name] = tools.AttrDict(
            task=task,
            mode='train',
            save_episode_dir=config.train_dir,
            num_episodes=num_seed_episodes,
            num_steps=num_seed_steps,
            give_rewards=params.get('seed_episode_rewards', True))
    for task in config.test_tasks:
        sims['test-' + task.name] = tools.AttrDict(
            task=task,
            mode='test',
            save_episode_dir=config.test_dir,
            num_episodes=num_seed_episodes,
            num_steps=num_seed_steps,
            give_rewards=True)
    return sims
def _active_collection(tasks, collects, defaults, config, params):
    sims = tools.AttrDict()
    for task in tasks:
        for user_collect in collects:
            for key in user_collect:
                if key not in defaults:
                    message = 'Invalid key {} in active collection config.'
                    raise KeyError(message.format(key))
            collect = tools.AttrDict(defaults, _unlocked=True)
            collect.update(user_collect)
            collect.planner = _define_planner(
                collect.planner, collect.horizon, config, params)
            collect.objective = tools.bind(
                getattr(objectives_lib, collect.objective), params=params)
            adaptation_condition = (
                (collect.prefix == 'train' and config.curious_run and config.adaptation) or
                (collect.prefix == 'train' and config.random_run and config.adaptation) or
                (collect.prefix == 'train' and config.vanilla_curious_run and config.adaptation))
            if adaptation_condition:
                collect.secondary_planner = _define_planner(
                    collect.secondary_planner, collect.horizon, config, params)
                collect.secondary_objective = tools.bind(
                    getattr(objectives_lib, collect.secondary_objective), params=params)
            if collect.give_rewards:
                collect.task = task
            else:
                env_ctor = tools.bind(
                    lambda ctor: control.wrappers.NoRewardHint(ctor()),
                    task.env_ctor)
                collect.task = tasks_lib.Task(task.name, env_ctor)
            collect.exploration = tools.AttrDict(
                scale=collect.action_noise_scale,
                type=collect.action_noise_type,
                schedule=tools.bind(
                    tools.schedule.linear,
                    ramp=collect.action_noise_ramp,
                    min=collect.action_noise_min),
                factors=collect.action_noise_factors)
            name = '{}_{}_{}'.format(collect.prefix, collect.name, task.name)
            assert name not in sims, (set(sims.keys()), name)
            sims[name] = collect
    return sims
def make_config(params):
    config = tools.AttrDict()
    config.debug = params.get('debug', False)
    with params.unlocked:
        for name in params.get('defaults', ['dreamer']):
            for key, value in DEFAULTS[name].items():
                if key not in params:
                    params[key] = value
    config.loss_scales = tools.AttrDict()
    config = _data_processing(config, params)
    config = _model_components(config, params)
    config = _tasks(config, params)
    config = _loss_functions(config, params)
    config = _training_schedule(config, params)
    # Mark params as used which are only accessed at run time.
    run_time_keys = [
        'planner_discount', 'planner_lambda', 'objective_entropy_scale',
        'normalize_actions', 'max_length', 'render_size', 'atari_lifes',
        'atari_noops', 'atari_sticky', 'atari_train_max_length',
        'atari_grayscale']
    for key in run_time_keys:
        params.get(key, None)
    return config
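# Illustrative sketch (not part of the original code): the defaults-merging
# behaviour used by make_config() above, reproduced with plain dicts. Values
# already present in params take precedence over the named defaults, which is
# what the `if key not in params` guard achieves. The names below are
# hypothetical and only serve the example.
def _merge_defaults_example():
    example_defaults = {'dreamer': {'model_size': 400, 'state_size': 60}}
    params = {'defaults': ['dreamer'], 'state_size': 30}
    for name in params.get('defaults', ['dreamer']):
        for key, value in example_defaults[name].items():
            if key not in params:
                params[key] = value
    # -> {'defaults': ['dreamer'], 'state_size': 30, 'model_size': 400}
    return params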
def __init__(self, logdir, config=None):
    self._logdir = logdir
    self._global_step = tf.train.get_or_create_global_step()
    self._step = tf.placeholder(tf.int32, name='step')
    self._phase = tf.placeholder(tf.string, name='phase')
    self._log = tf.placeholder(tf.bool, name='log')
    self._report = tf.placeholder(tf.bool, name='report')
    self._reset = tf.placeholder(tf.bool, name='reset')
    self._phases = []
    self._epoch_store = 0
    self._globalstepsubtract = 0
    self._mod_phase_step_train = 0
    self._trainstepprev = 0
    self._mod_phase_step_test = 0
    self._teststepprev = 0
    self._epoch_delta = 0
    # Checkpointing.
    self._loaders = []
    self._savers = []
    self._logdirs = []
    self._checkpoints = []
    self._config = config or tools.AttrDict()
def simulate(metrics, config, params, graph, cleanups, gif_summary, name):

    def env_ctor():
        env = params.task.env_ctor()
        if params.save_episode_dir:
            if config.curious_run:
                env = control.wrappers.CollectDataset(
                    env, params.save_episode_dir,
                    adaptation=config.adaptation,
                    exploration_episodes=config.exploration_episodes)
            else:
                env = control.wrappers.CollectDataset(env, params.save_episode_dir)
        return env

    bind_or_none = lambda x, **kw: x and functools.partial(x, **kw)
    cell = graph.cell
    agent_adaptation_condition = (
        (config.curious_run and config.adaptation and params.prefix == 'train') or
        (config.vanilla_curious_run and config.adaptation and params.prefix == 'train') or
        (config.random_run and config.adaptation and params.prefix == 'train'))
    if agent_adaptation_condition:
        agent_config = tools.AttrDict(
            cell=cell,
            encoder=graph.encoder,
            prefix='train',
            planner=functools.partial(params.planner, graph=graph),
            secondary_planner=functools.partial(params.secondary_planner, graph=graph),
            objective=bind_or_none(params.objective, graph=graph),
            secondary_objective=bind_or_none(params.secondary_objective, graph=graph),
            adaptation_step=config.adaptation_step,
            exploration=params.exploration,
            preprocess_fn=config.preprocess_fn,
            postprocess_fn=config.postprocess_fn)
    else:
        agent_config = tools.AttrDict(
            cell=cell,
            encoder=graph.encoder,
            planner=functools.partial(params.planner, graph=graph),
            objective=bind_or_none(params.objective, graph=graph),
            exploration=params.exploration,
            preprocess_fn=config.preprocess_fn,
            postprocess_fn=config.postprocess_fn)
    params = params.copy()
    with params.unlocked:
        params.update(agent_config)
    with agent_config.unlocked:
        agent_config.update(params)
    with tf.variable_scope(name):
        summaries = []
        env = control.create_batch_env(
            env_ctor, params.num_envs, config.isolate_envs)
        adaptation = (
            (config.curious_run and config.adaptation) or
            (config.vanilla_curious_run and config.adaptation) or
            (config.random_run and config.adaptation))
        agent = control.MPCAgent(
            env, graph.step, False, False, agent_config, adaptation,
            graph.phase, graph.global_step)
        cleanup = lambda: env.close()
        scores, lengths, data = control.simulate(
            agent, env, params.num_episodes, params.num_steps)
        summaries.append(tf.summary.scalar('return', scores[0]))
        summaries.append(tf.summary.scalar('length', lengths[0]))
        if gif_summary:
            summaries.append(tools.gif_summary(
                'gif', data['image'], max_outputs=1, fps=20))
        write_metrics = [
            metrics.add_scalars(name + '/return', scores),
            metrics.add_scalars(name + '/length', lengths),
            # metrics.add_tensor(name + '/frames', data['image']),
        ]
        with tf.control_dependencies(write_metrics):
            summary = tf.summary.merge(summaries)
    cleanups.append(cleanup)  # Work around tf.cond() tensor return type.
    return summary, tf.reduce_mean(scores)
    type=pathlib.Path, default='./rolloutdir/')
parser.add_argument('--params', default='{}')
parser.add_argument('--num_runs', type=int, default=1)
parser.add_argument('--expID', type=str, required=True)
parser.add_argument('--ping_every', type=int, default=0)
parser.add_argument('--resume_runs', type=boolean, default=True)
parser.add_argument('--dmlab_runfiles_path', default=None)
args_, remaining = parser.parse_known_args()
args_.params += ' '
for tmp in remaining:
    args_.params += tmp + ' '
params_ = args_.params.replace('#', ',').replace('\\', '')
args_.params = tools.AttrDict(yaml.safe_load(params_))
if args_.dmlab_runfiles_path:
    with args_.params.unlocked:
        args_.params.dmlab_runfiles_path = args_.dmlab_runfiles_path
    assert args_.params.dmlab_runfiles_path  # Mark as accessed.
args_.logdir = args_.logdir and os.path.expanduser(args_.logdir)
args_.rolloutdir = args_.rolloutdir and os.path.expanduser(args_.rolloutdir)
expid = args_.expID.split('_')
num, comm = int(expid[0]), expid[1:]
comment = ''
for com in comm:
    comment += '_' + com
args_.logdir = os.path.join(
    args_.logdir, '{:05}_expID'.format(num) + comment)
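# Illustrative sketch (not part of the original script): how the --params
# string above is decoded. '#' stands in for ',' so the YAML mapping can be
# passed as a single shell token, and backslashes are stripped before parsing.
def _parse_params_example():
    import yaml
    raw = '{defaults: [dreamer]# max_steps: 5000000}'
    parsed = yaml.safe_load(raw.replace('#', ',').replace('\\', ''))
    # -> {'defaults': ['dreamer'], 'max_steps': 5000000}
    return parsed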
def _training_schedule(config, params):
    config.train_steps = int(params.get('train_steps', 50000))
    config.test_steps = int(params.get('test_steps', config.batch_shape[0]))
    config.max_steps = int(params.get('max_steps', 5e7))
    config.train_log_every = params.get('train_log_every', config.train_steps)
    config.train_checkpoint_every = None
    config.test_checkpoint_every = int(
        params.get('checkpoint_every', 10 * config.test_steps))
    config.checkpoint_to_load = None
    config.savers = [tools.AttrDict(exclude=(r'.*_temporary.*',))]
    config.print_metrics_every = config.train_steps // 10
    config.slow_model_train_by = params.get('slow_model_train_by', 1)
    if params.get('use_separate_rolloutdir', True):
        config.train_dir = os.path.join(params.rolloutdir, 'train_episodes')
        config.test_dir = os.path.join(params.rolloutdir, 'test_episodes')
    else:
        config.train_dir = os.path.join(params.logdir, 'train_episodes')
        config.test_dir = os.path.join(params.logdir, 'test_episodes')
    config.random_collects = _initial_collection(config, params)
    defaults = tools.AttrDict()
    defaults.name = 'main'
    defaults.give_rewards = True
    defaults.horizon = params.get('planner_horizon', 12)
    defaults.objective = params.get('planner_objective', 'reward_value')
    defaults.num_envs = params.get('num_envs', 1)
    defaults.num_episodes = params.get('collect_episodes', defaults.num_envs)
    defaults.num_steps = params.get('collect_steps', 500)
    # defaults.num_steps = params.get('collect_steps', 50)
    defaults.steps_after = params.get('collect_every', 5000)
    # defaults.steps_after = params.get('collect_every', 500)
    defaults.steps_every = params.get('collect_every', 5000)
    # defaults.steps_every = params.get('collect_every', 500)
    defaults.steps_until = -1
    defaults.action_noise_type = params.get(
        'action_noise_type', 'additive_normal')
    train_defaults = defaults.copy(_unlocked=True)
    train_defaults.prefix = 'train'
    train_defaults.mode = 'train'
    train_defaults.save_episode_dir = config.train_dir
    train_defaults.planner = params.get('train_planner', 'policy_sample')
    if config.curious_run:
        train_objective = 'curious_reward_value'
    elif config.vanilla_curious_run:
        train_objective = 'icm_reward_value'
    else:
        train_objective = defaults.objective
    train_defaults.objective = params.get(
        'train_planner_objective', train_objective)
    if config.curious_run and config.adaptation:
        train_defaults.secondary_planner = 'policy_sample'
        train_defaults.secondary_objective = 'reward_value'
        config.exploration_episodes = params.get(
            'exploration_episodes',
            int(config.adaptation_step / params.get('collect_every', 5000)))
        config.adaptation_data_ratio = params.get('adaptation_data_ratio', 0.7)
        config.use_scheduler = params.get('use_scheduler', False)
        config.schedule_limit = params.get('schedule_limit', 500)
    elif config.vanilla_curious_run and config.adaptation:
        train_defaults.secondary_planner = 'policy_sample'
        train_defaults.secondary_objective = 'reward_value'
        config.exploration_episodes = params.get(
            'exploration_episodes',
            int(config.adaptation_step / params.get('collect_every', 5000)))
        config.adaptation_data_ratio = params.get('adaptation_data_ratio', 0.7)
        config.use_scheduler = params.get('use_scheduler', False)
        config.schedule_limit = params.get('schedule_limit', 500)
    elif config.random_run and config.adaptation:
        train_defaults.secondary_planner = 'policy_sample'
        train_defaults.secondary_objective = 'reward_value'
        config.exploration_episodes = params.get(
            'exploration_episodes',
            int(config.adaptation_step / params.get('collect_every', 5000)))
        config.adaptation_data_ratio = params.get('adaptation_data_ratio', 0.7)
        config.use_scheduler = params.get('use_scheduler', False)
        config.schedule_limit = params.get('schedule_limit', 500)
    train_defaults.action_noise_scale = params.get('train_action_noise', 0.3)
    train_defaults.action_noise_ramp = params.get('train_action_noise_ramp', 0)
    train_defaults.action_noise_min = params.get('train_action_noise_min', 0.0)
    train_defaults.action_noise_factors = params.get(
        'train_action_noise_factors', [])
    config.train_collects = _active_collection(
        config.train_tasks, params.get('train_collects', [{}]),
        train_defaults, config, params)
    test_defaults = defaults.copy(_unlocked=True)
    test_defaults.prefix = 'test'
    test_defaults.mode = 'test'
    test_defaults.save_episode_dir = config.test_dir
    test_defaults.planner = params.get('test_planner', 'policy_mode')
    test_defaults.objective = params.get(
        'test_planner_objective', defaults.objective)
    test_defaults.action_noise_scale = params.get('test_action_noise', 0.0)
    test_defaults.action_noise_ramp = 0
    test_defaults.action_noise_min = 0.0
    test_defaults.action_noise_factors = params.get(
        'train_action_noise_factors', None)
    config.test_collects = _active_collection(
        config.test_tasks, params.get('test_collects', [{}]),
        test_defaults, config, params)
    return config
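# Illustrative sketch (not part of the original code): the default number of
# exploration episodes derived in _training_schedule() above is the adaptation
# step budget divided by the collection interval, e.g. an adaptation_step of
# 5e6 with collect_every=5000 yields 1000 exploration episodes.
def _default_exploration_episodes(adaptation_step=5e6, collect_every=5000):
    return int(adaptation_step / collect_every)  # -> 1000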
def _loss_functions(config, params):
    for head in config.gradient_heads:
        assert head in config.heads, head
    config.imagination_horizon = params.get('imagination_horizon', 15)
    if config.curious_run or config.vanilla_curious_run:
        config.exploration_imagination_horizon = params.get(
            'exploration_imagination_horizon', 15)
    config.imagination_skip_last = params.get('imagination_skip_last', None)
    config.imagination_include_initial = params.get(
        'imagination_include_initial', True)
    config.action_source = params.get('action_source', 'model')
    config.action_model_horizon = params.get('action_model_horizon', None)
    config.action_bootstrap = params.get('action_bootstrap', True)
    if config.curious_run or config.vanilla_curious_run:
        config.curious_action_bootstrap = params.get(
            'curious_action_bootstrap', True)
    config.action_discount = params.get('action_discount', 0.99)
    config.action_lambda = params.get('action_lambda', 0.95)
    config.action_target_update = params.get('action_target_update', 1)
    config.action_target_period = params.get('action_target_period', 50000)
    config.action_loss_pcont = params.get('action_loss_pcont', False)
    config.action_pcont_stop_grad = params.get('action_pcont_stop_grad', False)
    config.action_pcont_weight = params.get('action_pcont_weight', True)
    config.action_lr = params.get('action_lr', 8e-5)
    config.value_source = params.get('value_source', 'model')
    config.value_model_horizon = params.get('value_model_horizon', None)
    config.value_discount = params.get('value_discount', 0.99)
    config.value_lambda = params.get('value_lambda', 0.95)
    config.value_bootstrap = params.get('value_bootstrap', True)
    if config.curious_run or config.vanilla_curious_run:
        config.curious_value_bootstrap = params.get(
            'curious_value_bootstrap', True)
    config.value_target_update = params.get('value_target_update', 1)
    config.value_target_period = params.get('value_target_period', 50000)
    config.value_loss_pcont = params.get('value_loss_pcont', False)
    config.value_pcont_weight = params.get('value_pcont_weight', True)
    config.value_maxent = params.get('value_maxent', False)
    config.value_lr = params.get('value_lr', 8e-5)
    config.action_beta = params.get('action_beta', 0.0)
    config.action_beta_dims_value = params.get('action_beta_dims_value', None)
    config.state_beta = params.get('state_beta', 0.0)
    config.stop_grad_pre_action = params.get('stop_grad_pre_action', True)
    config.pcont_label_weight = params.get('pcont_label_weight', None)
    config.loss_scales.divergence = params.get('divergence_scale', 1.0)
    config.loss_scales.global_divergence = params.get('global_div_scale', 0.0)
    config.loss_scales.overshooting = params.get('overshooting_scale', 0.0)
    for head in config.heads:
        if head in ('value_target', 'action_target'):  # Untrained.
            continue
        config.loss_scales[head] = params.get(head + '_loss_scale', 1.0)
    if config.curious_run or config.combination_run:
        for mdl in range(config.num_models):
            config.loss_scales['one_step_model_' + str(mdl)] = params.get(
                'one_step_model_scale', 1.0)
    config.free_nats = params.get('free_nats', 3.0)
    config.overshooting_distance = params.get('overshooting_distance', 0)
    config.os_stop_posterior_grad = params.get('os_stop_posterior_grad', True)
    config.cpc_contrast = params.get('cpc_contrast', 'window')
    config.cpc_batch_amount = params.get('cpc_batch_amount', 10)
    config.cpc_time_amount = params.get('cpc_time_amount', 30)
    optimizer_cls = tools.bind(
        tf.train.AdamOptimizer,
        epsilon=params.get('optimizer_epsilon', 1e-4))
    config.optimizers = tools.AttrDict()
    config.optimizers.default = tools.bind(
        tools.CustomOptimizer,
        optimizer_cls=optimizer_cls,
        # schedule=tools.bind(tools.schedule.linear, ramp=0),
        learning_rate=params.get('default_lr', 1e-3),
        clipping=params.get('default_gradient_clipping', 1000.0))
    config.optimizers.model = config.optimizers.default.copy(
        learning_rate=params.get('model_lr', 6e-4),
        clipping=params.get('model_gradient_clipping', 100.0))
    config.optimizers.value = config.optimizers.default.copy(
        learning_rate=params.get('value_lr', 8e-5),
        clipping=params.get('value_gradient_clipping', 100.0))
    if params.get('curious_value_head', False):
        config.optimizers.curious_value = config.optimizers.default.copy(
            learning_rate=params.get('curious_value_lr', 8e-5),
            clipping=params.get('value_gradient_clipping', 100.0))
    config.optimizers.action = config.optimizers.default.copy(
        learning_rate=params.get('action_lr', 8e-5),
        clipping=params.get('action_gradient_clipping', 100.0))
    if params.get('curious_action_head', False):
        config.optimizers.curious_action = config.optimizers.default.copy(
            learning_rate=params.get('curious_action_lr', 8e-5),
            clipping=params.get('action_gradient_clipping', 100.0))
    return config
def _model_components(config, params):
    config.gradient_heads = params.get('gradient_heads', ['image', 'reward'])
    if 'prediction_error' in params.get('defaults', ['dreamer']):
        config.gradient_heads.append('reward_int')
    config.activation = ACTIVATIONS[params.get('activation', 'elu')]
    config.num_layers = params.get('num_layers', 3)
    config.num_units = params.get('num_units', 400)
    encoder = params.get('encoder', 'conv')
    if encoder == 'conv':
        config.encoder = networks.conv.encoder
    elif encoder == 'proprio':
        config.encoder = tools.bind(
            networks.proprio.encoder,
            keys=params.get('proprio_encoder_keys'),
            num_layers=params.get('proprio_encoder_num_layers', 3),
            units=params.get('proprio_encoder_units', 300))
    else:
        raise NotImplementedError(encoder)
    config.head_network = tools.bind(
        networks.feed_forward,
        num_layers=config.num_layers,
        units=config.num_units,
        activation=config.activation)
    config.heads = tools.AttrDict()
    config.vanilla_curious_run = False
    config.curious_run = False
    config.combination_run = False
    config.adaptation = params.get('adaptation', False)
    config.encoder_feature_shape = params.get('encoder_feature_shape', 1024)
    if 'prediction_error' in params.get('defaults', ['dreamer']):
        config.vanilla_curious_run = True
        config.freeze_extrinsic_heads = params.get('freeze_extrinsic_heads', True)
        config.adaptation = params.get('adaptation', False)
        config.use_data_ratio = params.get('use_data_ratio', False)
        config.encoder_feature_shape = params.get('encoder_feature_shape', 1024)
        config.exploration_episodes = params.get('exploration_episodes', 0)
        config.use_max_objective = params.get('use_max_objective', False)
        config.use_reinforce = params.get('use_reinforce', False)
        if config.adaptation:
            config.adaptation_step = params.get('adaptation_step', 5e6)
            config.secondary_horizon = params.get('secondary_horizon', 15)
            config.secondary_action_lr = params.get('secondary_action_lr', 8e-5)
            config.secondary_value_lr = params.get('secondary_value_lr', 8e-5)
            config.secondary_train_step = params.get('secondary_train_step', 50000)
    if 'disagree' in params.get('defaults', ['dreamer']):
        config.one_step_model = getattr(networks, 'one_step_model').one_step_model
        config.num_models = params.get('num_models', 5)
        config.curious_run = True
        config.bootstrap = params.get('bootstrap', True)
        config.intrinsic_reward_scale = params.get('intrinsic_reward_scale', 10000)
        config.ensemble_loss_scale = params.get('ensemble_loss_scale', 1)
        config.freeze_extrinsic_heads = params.get('freeze_extrinsic_heads', True)
        config.ensemble_model_type = params.get('ensemble_model_type', 1)
        config.model_width_factor = params.get('model_width_factor', 1)
        config.extrinsic_train_frequency = params.get('extrinsic_train_frequency', 1)
        config.slower_extrinsic_train = config.extrinsic_train_frequency > 1
        config.adaptation = params.get('adaptation', False)
        config.use_data_ratio = params.get('use_data_ratio', False)
        config.exploration_episodes = params.get('exploration_episodes', 0)
        config.encoder_feature_shape = params.get('encoder_feature_shape', 1024)
        config.normalize_reward = params.get('normalize_reward', False)
        config.use_max_objective = params.get('use_max_objective', False)
        config.use_reinforce = params.get('use_reinforce', False)
        if not config.use_max_objective:
            assert config.use_reinforce == False
        if config.normalize_reward:
            config.moving_average_stats = params.get('moving_average_stats', False)
        if config.adaptation:
            config.adaptation_step = params.get('adaptation_step', 5e6)
            config.secondary_horizon = params.get('secondary_horizon', 15)
            config.secondary_action_lr = params.get('secondary_action_lr', 8e-5)
            config.secondary_value_lr = params.get('secondary_value_lr', 8e-5)
            config.secondary_train_step = params.get('secondary_train_step', 50000)
    if 'combination' in params.get('defaults', ['dreamer']):
        config.combination_run = True
        config.extrinsic_coeff = params.get('extrinsic_coeff', 1.0)
        config.intrinsic_coeff = params.get('intrinsic_coeff', 0.01)
        config.one_step_model = getattr(networks, 'one_step_model').one_step_model
        config.num_models = params.get('num_models', 5)
        config.bootstrap = params.get('bootstrap', True)
        config.intrinsic_reward_scale = params.get('intrinsic_reward_scale', 10000)
        config.ensemble_loss_scale = params.get('ensemble_loss_scale', 1)
        config.freeze_extrinsic_heads = params.get('freeze_extrinsic_heads', True)
        config.ensemble_model_type = params.get('ensemble_model_type', 1)
        config.model_width_factor = params.get('model_width_factor', 1)
        config.adaptation = params.get('adaptation', False)
        config.encoder_feature_shape = params.get('encoder_feature_shape', 1024)
        config.normalize_reward = params.get('normalize_reward', False)
        config.use_max_objective = params.get('use_max_objective', False)
        config.use_reinforce = params.get('use_reinforce', False)
        if not config.use_max_objective:
            assert config.use_reinforce == False
    config.random_run = False
    if 'random' in params.get('defaults', ['dreamer']):
        config.random_run = True
        config.freeze_extrinsic_heads = params.get('freeze_extrinsic_heads', True)
        config.encoder_feature_shape = params.get('encoder_feature_shape', 1024)
        config.adaptation = params.get('adaptation', False)
        config.use_data_ratio = params.get('use_data_ratio', False)
        config.exploration_episodes = params.get('exploration_episodes', 0)
        if config.adaptation:
            config.adaptation_step = params.get('adaptation_step', 5e6)
            config.secondary_horizon = params.get('secondary_horizon', 15)
            config.secondary_action_lr = params.get('secondary_action_lr', 8e-5)
            config.secondary_value_lr = params.get('secondary_value_lr', 8e-5)
            config.secondary_train_step = params.get('secondary_train_step', 50000)
    if params.get('value_head', True):
        config.heads.value = tools.bind(
            config.head_network,
            num_layers=params.get('value_layers', 3),
            data_shape=[],
            dist=params.get('value_dist', 'normal'))
    if params.get('curious_value_head', False):
        config.heads.curious_value = tools.bind(
            config.head_network,
            num_layers=params.get('value_layers', 3),
            data_shape=[],
            dist=params.get('value_dist', 'normal'))
    if params.get('value_target_head', False):
        config.heads.value_target = tools.bind(
            config.head_network,
            num_layers=params.get('value_layers', 3),
            data_shape=[],
            stop_gradient=True,
            dist=params.get('value_dist', 'normal'))
    if params.get('return_head', False):
        config.heads['return'] = tools.bind(
            config.head_network, activation=config.activation)
    if params.get('action_head', True):
        config.heads.action = tools.bind(
            config.head_network,
            num_layers=params.get('action_layers', 4),
            mean_activation=ACTIVATIONS[
                params.get('action_mean_activation', 'none')],
            dist=params.get('action_head_dist', 'tanh_normal_tanh'),
            std=params.get('action_head_std', 'learned'),
            min_std=params.get('action_head_min_std', 1e-4),
            init_std=params.get('action_head_init_std', 5.0))
    if params.get('curious_action_head', False):
        config.heads.curious_action = tools.bind(
            config.head_network,
            num_layers=params.get('action_layers', 4),
            mean_activation=ACTIVATIONS[
                params.get('action_mean_activation', 'none')],
            dist=params.get('action_head_dist', 'tanh_normal_tanh'),
            std=params.get('action_head_std', 'learned'),
            min_std=params.get('action_head_min_std', 1e-4),
            init_std=params.get('action_head_init_std', 5.0))
    if params.get('action_target_head', False):
        config.heads.action_target = tools.bind(
            config.head_network,
            num_layers=params.get('action_layers', 4),
            stop_gradient=True,
            mean_activation=ACTIVATIONS[
                params.get('action_mean_activation', 'none')],
            dist=params.get('action_head_dist', 'tanh_normal_tanh'),
            std=params.get('action_head_std', 'learned'),
            min_std=params.get('action_head_min_std', 1e-4),
            init_std=params.get('action_head_init_std', 5.0))
    if params.get('cpc_head', False):
        config.heads.cpc = config.head_network.copy(
            dist=params.get('cpc_head_dist', 'normal'),
            std=params.get('cpc_head_std', 'learned'),
            num_layers=params.get('cpc_head_layers', 3))
    image_head = params.get('image_head', 'conv')
    if image_head == 'conv':
        config.heads.image = tools.bind(
            networks.conv.decoder, std=params.get('image_head_std', 1.0))
    else:
        raise NotImplementedError(image_head)
    hidden_size = params.get('model_size', 400)
    state_size = params.get('state_size', 60)
    model = params.get('model', 'rssm')
    if model == 'rssm':
        config.cell = tools.bind(
            models.RSSM, state_size, hidden_size, hidden_size,
            params.get('future_rnn', True),
            params.get('mean_only', False),
            params.get('min_stddev', 1e-1),
            config.activation,
            params.get('model_layers', 1))
    else:
        raise NotImplementedError(model)
    return config
def define_model(logdir, metrics, data, trainer, config):
    print('Build TensorFlow compute graph.')
    dependencies = []
    cleanups = []
    step = trainer.step
    global_step = trainer.global_step
    phase = trainer.phase
    timestamp = tf.py_func(
        lambda: datetime.datetime.utcnow().strftime('%Y%m%dT%H%M%S'),
        [], tf.string)
    dependencies.append(metrics.set_tags(
        global_step=global_step, step=step, phase=phase, time=timestamp))

    # Instantiate network blocks. Note, this initialization would be expensive
    # when using tf.function since it would run at every step.
    try:
        cell = config.cell()
    except TypeError:
        cell = config.cell(action_size=data['action'].shape[-1].value)
    one_step_models = []
    kwargs = dict(create_scope_now_=True)
    kwargs['encoder_feature_shape'] = config.encoder_feature_shape
    encoder = tf.make_template('encoder', config.encoder, **kwargs)
    heads = tools.AttrDict(_unlocked=True)
    raw_dummy_features = cell.features_from_state(
        cell.zero_state(1, tf.float32))[:, None]
    for key, head in config.heads.items():
        name = 'head_{}'.format(key)
        kwargs = dict(create_scope_now_=True)
        if key in data:
            kwargs['data_shape'] = data[key].shape[2:].as_list()
        if key == 'reward_int':
            kwargs['data_shape'] = data['reward'].shape[2:].as_list()
        if key == 'action_target':
            kwargs['data_shape'] = data['action'].shape[2:].as_list()
        if key == 'curious_action':
            kwargs['data_shape'] = data['action'].shape[2:].as_list()
        if key == 'cpc':
            kwargs['data_shape'] = [cell.feature_size]
            dummy_features = encoder(data)[:1, :1]
        else:
            dummy_features = raw_dummy_features
        heads[key] = tf.make_template(name, head, **kwargs)
        heads[key](dummy_features)  # Initialize weights.
    if config.curious_run or config.combination_run:
        for mdl in range(config.num_models):
            with tf.variable_scope('one_step_model_' + str(mdl)):
                name = 'one_step_model_' + str(mdl)
                kwargs = dict(create_scope_now_=True)
                kwargs['max_objective'] = config.use_max_objective
                if config.ensemble_model_type == 1:
                    kwargs['data_shape'] = [config.encoder_feature_shape]
                elif config.ensemble_model_type == 2:
                    kwargs['data_shape'] = [
                        tools.shape(cell.zero_state(1, tf.float32)['belief'])[-1]]
                elif config.ensemble_model_type == 3:
                    kwargs['data_shape'] = [
                        tools.shape(cell.zero_state(1, tf.float32)['sample'])[-1]]
                elif config.ensemble_model_type == 4:
                    kwargs['data_shape'] = [tools.shape(dummy_features)[-1]]
                kwargs['model_width_factor'] = config.model_width_factor
                one_step_models.append(
                    tf.make_template(name, config.one_step_model, **kwargs))

    # Update target networks.
    if 'value_target' in heads:
        dependencies.append(tools.track_network(
            trainer, config.batch_shape[0],
            r'.*/head_value/.*', r'.*/head_value_target/.*',
            config.value_target_period, config.value_target_update))
    if 'value_target_2' in heads:
        dependencies.append(tools.track_network(
            trainer, config.batch_shape[0],
            r'.*/head_value/.*', r'.*/head_value_target_2/.*',
            config.value_target_period, config.value_target_update))
    if 'action_target' in heads:
        dependencies.append(tools.track_network(
            trainer, config.batch_shape[0],
            r'.*/head_action/.*', r'.*/head_action_target/.*',
            config.action_target_period, config.action_target_update))

    # Apply and optimize model.
    embedded = encoder(data)
    with tf.control_dependencies(dependencies):
        embedded = tf.identity(embedded)
    graph = tools.AttrDict(locals())
    prior, posterior = tools.unroll.closed_loop(
        cell, embedded, data['action'], config.debug)
    objectives = utility.compute_objectives(
        posterior, prior, data, graph, config)
    summaries, grad_norms = utility.apply_optimizers(
        objectives, trainer, config)
    dependencies += summaries

    # Active data collection.
    with tf.variable_scope('collection'):
        with tf.control_dependencies(dependencies):  # Make sure to train first.
            for name, params in config.train_collects.items():
                schedule = tools.schedule.binary(
                    step, config.batch_shape[0],
                    params.steps_after, params.steps_every, params.steps_until)
                summary, _ = tf.cond(
                    tf.logical_and(tf.equal(trainer.phase, 'train'), schedule),
                    functools.partial(
                        utility.simulate, metrics, config, params, graph,
                        cleanups, gif_summary=False, name=name),
                    lambda: (tf.constant(''), tf.constant(0.0)),
                    name='should_collect_' + name)
                summaries.append(summary)
                dependencies.append(summary)

    # Compute summaries.
    graph = tools.AttrDict(locals())
    summary, score = tf.cond(
        trainer.log,
        lambda: define_summaries.define_summaries(graph, config, cleanups),
        lambda: (tf.constant(''), tf.zeros((0,), tf.float32)),
        name='summaries')
    summaries = tf.summary.merge([summaries, summary])
    dependencies.append(utility.print_metrics(
        {ob.name: ob.value for ob in objectives},
        step, config.print_metrics_every, 2, 'objectives'))
    dependencies.append(utility.print_metrics(
        grad_norms, step, config.print_metrics_every, 2, 'grad_norms'))
    dependencies.append(tf.cond(trainer.log, metrics.flush, tf.no_op))
    with tf.control_dependencies(dependencies):
        score = tf.identity(score)
    return score, summaries, cleanups