def compute_objectives(posterior, prior, target, graph, config):
    """Create the training objectives requested by `config.loss_scales`.

    Every loss name whose scale is non-zero yields one `Objective`:

    * 'divergence'   -- divergence between `posterior` and `prior` states,
                        floored by `config.free_nats` when set; minimized.
    * 'overshooting' -- divergence between the open-loop priors and the
                        posteriors produced by `tools.overshooting` (the
                        first and last overshooting entries are dropped),
                        floored by `config.free_nats`; minimized.
    * anything else  -- log-likelihood of `target[name]` under the output of
                        `graph.heads[name]`; maximized. For 'image' only the
                        last three entries of the fifth target axis are used.

    Heads listed in `config.heads` but not in `config.gradient_heads` see
    stop-gradient features, and their objective is restricted to variables
    matching `.*/head_<name>/.*`.

    Returns:
      List of `Objective` tuples whose `value` has been reduced with
      `tf.reduce_mean`.
    """
    belief_features = graph.cell.features_from_state(posterior)
    heads = graph.heads

    def _apply_free_nats(divergence):
        # Divergence below the free-nats budget is not penalized.
        if config.free_nats is None:
            return divergence
        return tf.maximum(0.0, divergence - float(config.free_nats))

    results = []
    for loss_name, loss_scale in config.loss_scales.items():
        if loss_scale == 0.0:
            continue

        head_only = (loss_name in config.heads
                     and loss_name not in config.gradient_heads)
        if head_only:
            # Only the head's own variables are trained by this loss.
            inputs = tf.stop_gradient(belief_features)
            include = r'.*/head_{}/.*'.format(loss_name)
        else:
            inputs = belief_features
            include = r'.*'
        exclude = None

        if loss_name == 'divergence':
            value = _apply_free_nats(
                graph.cell.divergence_from_states(posterior, prior))
            results.append(
                Objective('divergence', value, min, include, exclude))
            continue

        if loss_name == 'overshooting':
            action_shape = tools.shape(graph.data['action'])
            full_length = tf.tile(
                tf.constant(action_shape[1])[None], [action_shape[0]])
            _, os_priors, os_posteriors, os_mask = tools.overshooting(
                graph.cell, {}, graph.embedded, graph.data['action'],
                full_length, config.overshooting_distance, posterior)
            os_posteriors, os_priors, os_mask = tools.nested.map(
                lambda tensor: tensor[:, :, 1:-1],
                (os_posteriors, os_priors, os_mask))
            if config.os_stop_posterior_grad:
                os_posteriors = tools.nested.map(
                    tf.stop_gradient, os_posteriors)
            value = _apply_free_nats(
                graph.cell.divergence_from_states(os_posteriors, os_priors))
            results.append(
                Objective('overshooting', value, min, include, exclude))
            continue

        distribution = heads[loss_name](inputs)
        observed = target[loss_name]
        if loss_name == 'image':
            observed = observed[:, :, :, :, -3:]
        log_likelihood = distribution.log_prob(observed)
        results.append(
            Objective(loss_name, log_likelihood, max, include, exclude))

    return [obj._replace(value=tf.reduce_mean(obj.value)) for obj in results]
def get_overshoot_preds(graph, embedding, actions, length, predict_terms,
                        posterior):
    """Return open-loop prior samples to be used as prediction context.

    Runs `tools.overshooting` for `predict_terms` steps starting from
    `posterior` and drops the zero-step entry on axis 2 of every output.
    It then returns the prior 'sample' tensor with the last `predict_terms`
    entries of axis 1 removed.

    NOTE(review): the trailing slice is `[:, :-predict_terms]`, which yields
    an empty tensor along axis 1 when `predict_terms == 0`.
    """
    _, open_loop_priors, open_loop_posteriors, os_mask = tools.overshooting(
        graph.cell, {}, embedding, actions, length, predict_terms, posterior)
    open_loop_posteriors, open_loop_priors, os_mask = tools.nested.map(
        lambda tensor: tensor[:, :, 1:],
        (open_loop_posteriors, open_loop_priors, os_mask))
    # Presumably (batch_size, effective_horizon, predict_terms, sample_size);
    # TODO confirm. Open question from the original author: is 'sample' the
    # right feature to use?
    sample_priors = open_loop_priors['sample']
    return sample_priors[:, :-predict_terms]
def test_example(self):
    """Check mask, prior and posterior of `overshooting` on a tiny batch.

    The batch holds one full-length sequence (length 6) and one short
    sequence (length 2), unrolled with an overshooting distance of 3.
    """
    observations = tf.constant([
        [10, 20, 30, 40, 50, 60],
        [70, 80, 0, 0, 0, 0],
    ], dtype=tf.float32)[:, :, None]
    actions = tf.constant([
        [0.0, 0.1, 0.2, 0.3, 0.4, 0.5],
        [9.0, 0.7, 0, 0, 0, 0],
    ], dtype=tf.float32)[:, :, None]
    lengths = tf.constant([6, 2], dtype=tf.int32)
    _, prior_states, posterior_states, valid = overshooting(
        _MockCell(1), observations, observations, actions, lengths, 3)
    prior_obs = tf.squeeze(prior_states['obs'], 3)
    posterior_obs = tf.squeeze(posterior_states['obs'], 3)
    valid = tf.to_int32(valid)
    with self.test_session():
        # After transposing, each column is a starting state step and each
        # row an overshooting distance (0..3) from that step.
        valid_value = valid.eval()
        self.assertAllEqual([
            [1, 1, 1, 1, 1, 1],
            [1, 1, 1, 1, 1, 0],
            [1, 1, 1, 1, 0, 0],
            [1, 1, 1, 0, 0, 0],
        ], valid_value[0].T)
        self.assertAllEqual([
            [1, 1, 0, 0, 0, 0],
            [1, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0],
        ], valid_value[1].T)
        prior_value = prior_obs.eval()
        posterior_value = posterior_obs.eval()
        self.assertAllClose([
            [0.0, 10.1, 20.2, 30.3, 40.4, 50.5],
            [0.1, 10.3, 20.5, 30.7, 40.9, 0],
            [0.3, 10.6, 20.9, 31.2, 0, 0],
            [0.6, 11.0, 21.4, 0, 0, 0],
        ], prior_value[0].T)
        self.assertAllClose([
            [10, 20, 30, 40, 50, 60],
            [20, 30, 40, 50, 60, 0],
            [30, 40, 50, 60, 0, 0],
            [40, 50, 60, 0, 0, 0],
        ], posterior_value[0].T)
        self.assertAllClose([
            [9.0, 70.7, 0, 0, 0, 0],
            [9.7, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0],
        ], prior_value[1].T)
        self.assertAllClose([
            [70, 80, 0, 0, 0, 0],
            [80, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0],
        ], posterior_value[1].T)
def test_nested(self):
    """Smoke test: `overshooting` accepts tuple-nested observations/actions.

    Uses a batch of 3 sequences of 50 steps with mixed valid lengths.
    """
    observation_dims = (1, 2, 3)
    action_dims = (1, 2)
    obs = tuple(tf.ones((3, 50, dim)) for dim in observation_dims)
    prev_action = tuple(tf.ones((3, 50, dim)) for dim in action_dims)
    lengths = tf.constant([49, 50, 3], dtype=tf.int32)
    overshooting(_MockCell(1), obs, obs, prev_action, lengths, 3)
def define_model(data, trainer, config):
    """Build the TF1 training graph and return `(score, summaries)`.

    Steps, in order:
      1. Optionally add Gaussian noise (stddev `config.dynamic_action_noise`)
         to `data['action']`, and build `prev_action` by shifting actions one
         step to the right with a zero first step.
      2. Instantiate `config.cell()`, an 'encoder' template, and one
         'head_<key>' template per entry of `config.heads`.
      3. Unroll `tools.overshooting` with distance `config.overshooting + 1`
         and compute zero-step losses and, if `config.overshooting > 1`,
         multi-step losses via `utility.compute_losses`.
      4. Apply `config.optimizers` to the sum of the scaled losses.
      5. Conditionally run simulated episode collection and summaries under
         `tf.cond`, and make the returned score depend on metric printing.

    Side effects:
      When `config.dynamic_action_noise` is truthy, the caller's
      `data['action']` entry is rebound to the noisy tensor.

    NOTE: `graph = tools.AttrDict(locals())` snapshots every local variable
    (twice). Those snapshots are what `utility.simulate_episodes` and
    `define_summaries.define_summaries` receive. Renaming, adding or removing
    locals above those lines changes their input.
    """
    tf.logging.info('Build TensorFlow compute graph.')
    dependencies = []
    step = trainer.step
    global_step = trainer.global_step
    phase = trainer.phase
    should_summarize = trainer.log

    # Preprocess data.
    with tf.device('/cpu:0'):
        if config.dynamic_action_noise:
            data['action'] += tf.random_normal(
                tf.shape(data['action']), 0.0, config.dynamic_action_noise)
        # Action that led to each step: zeros for the first step, then
        # actions shifted right by one.
        prev_action = tf.concat(
            [0 * data['action'][:, :1], data['action'][:, :-1]], 1)
        obs = data.copy()
        del obs['length']

    # Instantiate network blocks.
    cell = config.cell()
    kwargs = dict()
    encoder = tf.make_template(
        'encoder', config.encoder, create_scope_now_=True, **kwargs)
    heads = {}
    for key, head in config.heads.items():
        name = 'head_{}'.format(key)
        # Each head is told the per-step shape of the data it models.
        kwargs = dict(data_shape=obs[key].shape[2:].as_list())
        heads[key] = tf.make_template(name, head, create_scope_now_=True, **kwargs)

    # Embed observations and unroll model.
    embedded = encoder(obs)
    # Separate overshooting and zero step observations because computing
    # overshooting targets for images would be expensive.
    zero_step_obs = {}
    overshooting_obs = {}
    for key, value in obs.items():
        if config.zero_step_losses.get(key):
            zero_step_obs[key] = value
        if config.overshooting_losses.get(key):
            overshooting_obs[key] = value
    assert config.overshooting <= config.batch_shape[1]
    target, prior, posterior, mask = tools.overshooting(
        cell, overshooting_obs, embedded, prev_action, data['length'],
        config.overshooting + 1)
    losses = []

    # Zero step losses.
    # Index 0 on axis 2 is the zero-step (no overshooting) entry.
    _, zs_prior, zs_posterior, zs_mask = tools.nested.map(
        lambda tensor: tensor[:, :, :1], (target, prior, posterior, mask))
    zs_target = {key: value[:, :, None] for key, value in zero_step_obs.items()}
    zero_step_losses = utility.compute_losses(
        config.zero_step_losses, cell, heads, step, zs_target, zs_prior,
        zs_posterior, zs_mask, config.free_nats, debug=config.debug)
    losses += [
        loss * config.zero_step_losses[name] for name, loss in
        zero_step_losses.items()]
    # Guarantee a 'divergence' entry for the metrics printed below.
    if 'divergence' not in zero_step_losses:
        zero_step_losses['divergence'] = tf.zeros((), dtype=tf.float32)

    # Overshooting losses.
    if config.overshooting > 1:
        # Drop the zero-step entry and the last overshooting entry on axis 2.
        os_target, os_prior, os_posterior, os_mask = tools.nested.map(
            lambda tensor: tensor[:, :, 1:-1], (target, prior, posterior, mask))
        if config.stop_os_posterior_gradient:
            os_posterior = tools.nested.map(tf.stop_gradient, os_posterior)
        overshooting_losses = utility.compute_losses(
            config.overshooting_losses, cell, heads, step, os_target, os_prior,
            os_posterior, os_mask, config.free_nats, debug=config.debug)
        losses += [
            loss * config.overshooting_losses[name] for name, loss in
            overshooting_losses.items()]
    else:
        overshooting_losses = {}
    if 'divergence' not in overshooting_losses:
        overshooting_losses['divergence'] = tf.zeros((), dtype=tf.float32)

    # Workaround for TensorFlow deadlock bug.
    # Outside the 'train' phase the optimized loss is 0 * a dummy variable.
    loss = sum(losses)
    train_loss = tf.cond(
        tf.equal(phase, 'train'),
        lambda: loss,
        lambda: 0 * tf.get_variable('dummy_loss', (), tf.float32))
    train_summary = utility.apply_optimizers(
        train_loss, step, should_summarize, config.optimizers)
    # train_summary = tf.cond(
    #     tf.equal(phase, 'train'),
    #     lambda: utility.apply_optimizers(
    #         loss, step, should_summarize, config.optimizers),
    #     str, name='optimizers')

    # Active data collection.
    collect_summaries = []
    # Snapshot of all locals so far; handed to utility.simulate_episodes.
    graph = tools.AttrDict(locals())
    with tf.variable_scope('collection'):
        should_collects = []
        for name, params in config.sim_collects.items():
            after, every = params.steps_after, params.steps_every
            # Collect only in the 'train' phase and when the step schedule fires.
            should_collect = tf.logical_and(
                tf.equal(phase, 'train'),
                tools.schedule.binary(step, config.batch_shape[0], after, every))
            collect_summary, _ = tf.cond(
                should_collect,
                functools.partial(
                    utility.simulate_episodes, config, params, graph, name),
                lambda: (tf.constant(''), tf.constant(0.0)),
                name='should_collect_' + params.task.name)
            should_collects.append(should_collect)
            collect_summaries.append(collect_summary)

    # Compute summaries.
    # Second snapshot of all locals; handed to define_summaries.
    graph = tools.AttrDict(locals())
    with tf.control_dependencies(collect_summaries):
        summaries, score = tf.cond(
            should_summarize,
            lambda: define_summaries.define_summaries(graph, config),
            lambda: (tf.constant(''), tf.zeros((0,), tf.float32)),
            name='summaries')
    with tf.device('/cpu:0'):
        summaries = tf.summary.merge([summaries, train_summary])
        # summaries = tf.summary.merge(
        #     [summaries, train_summary] + collect_summaries)
        # Mean entropy of the zero-step posterior over valid (masked) steps.
        zs_entropy = (tf.reduce_sum(tools.mask(
            cell.dist_from_state(zs_posterior, zs_mask).entropy(), zs_mask)) /
            tf.reduce_sum(tf.to_float(zs_mask)))
        dependencies.append(utility.print_metrics((
            ('score', score),
            ('loss', loss),
            ('zs_entropy', zs_entropy),
            ('zs_divergence', zero_step_losses['divergence']),
        ), step, config.mean_metrics_every))
    # Evaluating the returned score forces the metric ops to run.
    with tf.control_dependencies(dependencies):
        score = tf.identity(score)
    return score, summaries
def define_model(data, trainer, config):
    """Build the (optionally multi-GPU) TF1 training graph.

    Returns `(score, summaries)`.

    With `NUM_GPU == 1` this builds the model once on '/gpu:0':
      - the cell, encoder and heads;
      - an unroll of `tools.overshooting`;
      - the sum of the scaled zero-step and overshooting losses, to which
        `config.optimizers` is applied via `utility.apply_optimizers`.

    With `NUM_GPU > 1`:
      - every tensor in `data` is split along axis 0 into `NUM_GPU` shards
        (`tf.split`, so the leading dimension must divide evenly);
      - the model is built once per shard inside variable scope 'shared_vars'
        with `tf.AUTO_REUSE`;
      - per-optimizer gradients come from `utility.get_grads`, are averaged
        on the CPU with `average_gradients`, and are applied once with
        `utility.apply_grads`.

    After the GPU loop, locals such as `data`, `cell`, `heads`, `loss` and
    the zero-step tensors refer to the LAST shard. The `graph` snapshots
    passed to data collection and summaries therefore describe that shard
    only.

    The printed 'score' metric is the second output of the last configured
    simulated collection, or 0.0 when `config.sim_collects` is empty.

    NOTE: `graph = tools.AttrDict(locals())` snapshots every local variable,
    so renaming or removing locals changes what downstream code receives.
    """
    tf.logging.info('Build TensorFlow compute graph.')
    dependencies = []
    step = trainer.step
    global_step = trainer.global_step  # tf.train.get_or_create_global_step()
    phase = trainer.phase
    should_summarize = trainer.log

    num_gpu = NUM_GPU

    # Multi-GPU bookkeeping: per-optimizer gradient lists (one entry per GPU),
    # variables to update (taken from GPU 0) and the per-GPU data shards.
    if num_gpu > 1:
        var_for_trainop = {}
        grads_dict = {}

        # data split for multi-gpu
        data_dict = {}
        for loss_head, optimizer_cls in config.optimizers.items():
            grads_dict[loss_head] = []
            var_for_trainop[loss_head] = []

        for gpu_i in range(num_gpu):
            data_dict[gpu_i] = {}

        for data_item in list(data.keys()):
            data_split = tf.split(data[data_item], num_gpu)
            for gpu_j in range(num_gpu):
                data_dict[gpu_j][data_item] = data_split[gpu_j]

    for gpu_k in range(num_gpu):
        with tf.device('/gpu:%s' % gpu_k):
            # Regex restricting gradient computation to the shared variables.
            scope_name = r'.+shared_vars'
            with tf.name_scope('%s_%d' % ("GPU", gpu_k)):
                # All towers share variables via AUTO_REUSE.
                with tf.variable_scope(name_or_scope='shared_vars',
                                       reuse=tf.AUTO_REUSE):
                    if num_gpu > 1:
                        data = data_dict[gpu_k]

                    # Preprocess data.
                    if config.dynamic_action_noise:
                        data['action'] += tf.random_normal(
                            tf.shape(data['action']), 0.0,
                            config.dynamic_action_noise)
                    # i.e.: (0 * a1, a1, a2, ..., a49)
                    prev_action = tf.concat(
                        [0 * data['action'][:, :1], data['action'][:, :-1]],
                        1)
                    obs = data.copy()
                    del obs['length']

                    # Instantiate network blocks.
                    cell = config.cell()
                    kwargs = dict()
                    encoder = tf.make_template(
                        'encoder', config.encoder, create_scope_now_=True,
                        **kwargs)
                    heads = {}
                    for key, head in config.heads.items():
                        name = 'head_{}'.format(key)
                        kwargs = dict(
                            data_shape=obs[key].shape[2:].as_list())
                        heads[key] = tf.make_template(
                            name, head, create_scope_now_=True, **kwargs)

                    # Embed observations and unroll model.
                    embedded = encoder(obs)
                    # Separate overshooting and zero step observations
                    # because computing overshooting targets for images
                    # would be expensive.
                    zero_step_obs = {}
                    overshooting_obs = {}
                    for key, value in obs.items():
                        if config.zero_step_losses.get(key):
                            zero_step_obs[key] = value
                        if config.overshooting_losses.get(key):
                            overshooting_obs[key] = value
                    assert config.overshooting <= config.batch_shape[1]
                    target, prior, posterior, mask = tools.overshooting(
                        cell, overshooting_obs, embedded, prev_action,
                        data['length'], config.overshooting + 1)
                    losses = []

                    # Zero step losses (index 0 on the overshooting axis).
                    _, zs_prior, zs_posterior, zs_mask = tools.nested.map(
                        lambda tensor: tensor[:, :, :1],
                        (target, prior, posterior, mask))
                    zs_target = {
                        key: value[:, :, None]
                        for key, value in zero_step_obs.items()
                    }
                    zero_step_losses = utility.compute_losses(
                        config.zero_step_losses, cell, heads, step,
                        zs_target, zs_prior, zs_posterior, zs_mask,
                        config.free_nats, debug=config.debug)
                    losses += [
                        loss * config.zero_step_losses[name]
                        for name, loss in zero_step_losses.items()
                    ]
                    if 'divergence' not in zero_step_losses:
                        zero_step_losses['divergence'] = tf.zeros(
                            (), dtype=tf.float32)

                    # Overshooting losses (drop zero-step and last entry).
                    if config.overshooting > 1:
                        os_target, os_prior, os_posterior, os_mask = (
                            tools.nested.map(
                                lambda tensor: tensor[:, :, 1:-1],
                                (target, prior, posterior, mask)))
                        if config.stop_os_posterior_gradient:
                            os_posterior = tools.nested.map(
                                tf.stop_gradient, os_posterior)
                        overshooting_losses = utility.compute_losses(
                            config.overshooting_losses, cell, heads, step,
                            os_target, os_prior, os_posterior, os_mask,
                            config.free_nats, debug=config.debug)
                        losses += [
                            loss * config.overshooting_losses[name]
                            for name, loss in overshooting_losses.items()
                        ]
                    else:
                        overshooting_losses = {}
                    if 'divergence' not in overshooting_losses:
                        overshooting_losses['divergence'] = tf.zeros(
                            (), dtype=tf.float32)

                    # Workaround for TensorFlow deadlock bug.
                    loss = sum(losses)
                    train_loss = tf.cond(
                        tf.equal(phase, 'train'), lambda: loss,
                        lambda: 0 * tf.get_variable(
                            'dummy_loss', (), tf.float32))
                    if num_gpu == 1:
                        train_summary = utility.apply_optimizers(
                            train_loss, step, should_summarize,
                            config.optimizers)
                    else:
                        training_grad_dict = utility.get_grads(
                            train_loss, step, should_summarize,
                            config.optimizers, include_var=(scope_name, ))
                        for a in grads_dict.keys():
                            grads_dict[a].append(
                                training_grad_dict[a]["grad"])
                            # Variables are shared, so take them once.
                            if gpu_k == 0:
                                var_for_trainop[a].append(
                                    training_grad_dict[a]["var"])

    # Multi-GPU: average the per-tower gradients and apply them once.
    if num_gpu > 1:
        averaged_gradients = {}
        with tf.device('/cpu:0'):
            for a in grads_dict.keys():
                averaged_gradients[a] = average_gradients(grads_dict[a])
        train_summary = utility.apply_grads(
            averaged_gradients, var_for_trainop, step, should_summarize,
            config.optimizers)

    # Active data collection.
    collect_summaries = []
    graph = tools.AttrDict(locals())
    # Default for the printed 'score' metric. It is bound after the snapshot
    # above so that snapshot is unchanged. Without it an empty
    # `config.sim_collects` made the print_metrics call raise NameError.
    score_train = tf.constant(0.0)
    with tf.variable_scope('collection'):
        should_collects = []
        for name, params in config.sim_collects.items():
            after, every = params.steps_after, params.steps_every
            should_collect = tf.logical_and(
                tf.equal(phase, 'train'),
                tools.schedule.binary(step, config.batch_shape[0], after,
                                      every))
            # With several collects, the last one's score is reported.
            collect_summary, score_train = tf.cond(
                should_collect,
                functools.partial(utility.simulate_episodes, config, params,
                                  graph, name),
                lambda: (tf.constant(''), tf.constant(0.0)),
                name='should_collect_' + params.task.name)
            should_collects.append(should_collect)
            collect_summaries.append(collect_summary)

    # Compute summaries.
    graph = tools.AttrDict(locals())
    with tf.control_dependencies(collect_summaries):
        summaries, score = tf.cond(
            should_summarize,
            lambda: define_summaries.define_summaries(graph, config),
            lambda: (tf.constant(''), tf.zeros((0, ), tf.float32)),
            name='summaries')
    with tf.device('/cpu:0'):
        summaries = tf.summary.merge([summaries, train_summary])
        # Mean zero-step posterior entropy over valid steps (last shard).
        zs_entropy = (tf.reduce_sum(
            tools.mask(
                cell.dist_from_state(zs_posterior, zs_mask).entropy(),
                zs_mask)) / tf.reduce_sum(tf.to_float(zs_mask)))
        dependencies.append(
            utility.print_metrics((
                ('score', score_train),
                ('loss', loss),
                ('zs_entropy', zs_entropy),
                ('zs_divergence', zero_step_losses['divergence']),
            ), step, config.mean_metrics_every))
    with tf.control_dependencies(dependencies):
        score = tf.identity(score)
    return score, summaries
def compute_objectives(posterior, prior, target, graph, config, trainer):
    """Create training objectives, with optional contrastive or L2 reward loss.

    For every loss name with a non-zero scale in `config.loss_scales`:

    * 'divergence' / 'overshooting' -- posterior/prior divergences (see the
      branches below), floored by `config.free_nats` when set; minimized.
    * 'reward' with `config.r_loss == 'contra'` -- one of the contrastive
      trajectory losses selected by `config.contra_unit`; minimized.
    * 'reward' with `config.r_loss == 'l2'` -- mean squared error between the
      reward head output and `target['reward']`; minimized.
    * anything else -- log-likelihood of `target[name]` under its head. When
      `config.aug` is set and `config.aug_same` is not, the head input is the
      features concatenated with `target['aug']`; maximized.

    Args:
      trainer: Unused; kept for interface compatibility.

    Returns:
      (objectives, cstr_pct). `objectives` are `Objective` tuples reduced with
      `tf.reduce_mean`. `cstr_pct` is the second value returned by the
      selected contrastive loss, or 0.0 when none was built.

    Raises:
      ValueError: if a contrastive reward loss is requested with an unknown
        `config.contra_unit`.
    """
    raw_features = graph.cell.features_from_state(posterior)
    heads = graph.heads
    objectives = []
    cstr_pct = 0.0
    for name, scale in config.loss_scales.items():
        if scale == 0.0:
            continue
        if name in config.heads and name not in config.gradient_heads:
            # Only the head's own variables are trained by this loss.
            features = tf.stop_gradient(raw_features)
            include = r'.*/head_{}/.*'.format(name)
            exclude = None
        else:
            features = raw_features
            include = r'.*'
            exclude = None

        if name == 'divergence':
            loss = graph.cell.divergence_from_states(posterior, prior)
            if config.free_nats is not None:
                loss = tf.maximum(0.0, loss - float(config.free_nats))
            objectives.append(
                Objective('divergence', loss, min, include, exclude))

        elif name == 'overshooting':
            shape = tools.shape(graph.data['action'])
            length = tf.tile(tf.constant(shape[1])[None], [shape[0]])
            _, priors, posteriors, mask = tools.overshooting(
                graph.cell, {}, graph.embedded, graph.data['action'], length,
                config.overshooting_distance, posterior)
            posteriors, priors, mask = tools.nested.map(
                lambda x: x[:, :, 1:-1], (posteriors, priors, mask))
            if config.os_stop_posterior_grad:
                posteriors = tools.nested.map(tf.stop_gradient, posteriors)
            loss = graph.cell.divergence_from_states(posteriors, priors)
            if config.free_nats is not None:
                loss = tf.maximum(0.0, loss - float(config.free_nats))
            objectives.append(
                Objective('overshooting', loss, min, include, exclude))

        elif name == 'reward' and config.r_loss == 'contra':
            # NOTE(review): the head output is passed straight to the loss
            # functions; whether it is a tensor or a distribution depends on
            # the head definition -- confirm.
            pred = heads[name](features)
            if config.contra_unit == 'traj':
                print('Using traj loss')
                contra_loss, cstr_pct = contra_traj_lossV6(
                    pred, target[name], horizon=config.contra_horizon)
            elif config.contra_unit == 'weighted':
                print('Using weighted trajectory loss ', config.contra_horizon)
                contra_loss, cstr_pct = contra_traj_lossV7(
                    pred, target[name], horizon=config.contra_horizon,
                    temp=config.temp)
            elif config.contra_unit == 'simclr':
                print('Using simclr trajectory loss ', config.contra_horizon)
                contra_loss, cstr_pct = contra_traj_lossV8(
                    pred, target[name], horizon=config.contra_horizon)
            elif config.contra_unit == 'rank':
                print('Using ranking trajectory loss ', config.contra_horizon)
                contra_loss, cstr_pct = contra_traj_lossV9(
                    pred, target[name], horizon=config.contra_horizon,
                    margin=config.margin)
            else:
                # Previously fell through and failed with NameError on
                # `contra_loss`.
                raise ValueError(
                    'Unknown contra_unit {!r}; expected one of '
                    "'traj', 'weighted', 'simclr', 'rank'.".format(
                        config.contra_unit))
            objectives.append(
                Objective(name, contra_loss, min, include, exclude))

        elif name == 'reward' and config.r_loss == 'l2':
            pred = heads[name](features)
            l2_loss = tf.compat.v1.losses.mean_squared_error(
                target[name], pred)
            objectives.append(Objective(name, l2_loss, min, include, exclude))

        else:
            if not config.aug_same and config.aug:
                # Condition the reconstruction head on the augmentation input.
                recon_feat = tf.concat([features, target['aug']], -1)
                print('Use recon feature ', name, recon_feat)
                logprob = heads[name](recon_feat).log_prob(target[name])
            else:
                logprob = heads[name](features).log_prob(target[name])
            objectives.append(Objective(name, logprob, max, include, exclude))

    objectives = [
        o._replace(value=tf.reduce_mean(o.value)) for o in objectives
    ]
    return objectives, cstr_pct
def define_model(data, trainer, config):
    """Build the TF1 training graph and return `(score, summaries)`.

    Steps, in order:
      1. Optionally add Gaussian noise (stddev `config.dynamic_action_noise`)
         to `data['action']`, and build `prev_action` by shifting actions one
         step to the right with a zero first step.
      2. Instantiate `config.cell()`, an 'encoder' template, and one
         'head_<key>' template per entry of `config.heads`.
      3. Unroll `tools.overshooting` with distance `config.overshooting + 1`
         and compute zero-step losses and, if `config.overshooting > 1`,
         multi-step losses via `utility.compute_losses`.
      4. Apply `config.optimizers` to the sum of the scaled losses.
      5. Conditionally run simulated episode collection and summaries under
         `tf.cond`, and make the returned score depend on metric printing.

    Side effects:
      When `config.dynamic_action_noise` is truthy, the caller's
      `data['action']` entry is rebound to the noisy tensor.

    NOTE: `graph = tools.AttrDict(locals())` snapshots every local variable
    (twice). Those snapshots are what `utility.simulate_episodes` and
    `define_summaries.define_summaries` receive. Renaming, adding or removing
    locals above those lines changes their input.
    """
    tf.logging.info('Build TensorFlow compute graph.')
    dependencies = []
    step = trainer.step
    global_step = trainer.global_step
    phase = trainer.phase
    should_summarize = trainer.log

    # Preprocess data.
    with tf.device('/cpu:0'):
        if config.dynamic_action_noise:
            data['action'] += tf.random_normal(
                tf.shape(data['action']), 0.0, config.dynamic_action_noise)
        # Action that led to each step: zeros for the first step, then
        # actions shifted right by one.
        prev_action = tf.concat(
            [0 * data['action'][:, :1], data['action'][:, :-1]], 1)
        obs = data.copy()
        del obs['length']

    # Instantiate network blocks.
    cell = config.cell()
    kwargs = dict()
    encoder = tf.make_template(
        'encoder', config.encoder, create_scope_now_=True, **kwargs)
    heads = {}
    for key, head in config.heads.items():
        name = 'head_{}'.format(key)
        # Each head is told the per-step shape of the data it models.
        kwargs = dict(data_shape=obs[key].shape[2:].as_list())
        heads[key] = tf.make_template(name, head, create_scope_now_=True, **kwargs)

    # Embed observations and unroll model.
    embedded = encoder(obs)
    # Separate overshooting and zero step observations because computing
    # overshooting targets for images would be expensive.
    zero_step_obs = {}
    overshooting_obs = {}
    for key, value in obs.items():
        if config.zero_step_losses.get(key):
            zero_step_obs[key] = value
        if config.overshooting_losses.get(key):
            overshooting_obs[key] = value
    assert config.overshooting <= config.batch_shape[1]
    target, prior, posterior, mask = tools.overshooting(
        cell, overshooting_obs, embedded, prev_action, data['length'],
        config.overshooting + 1)
    losses = []

    # Zero step losses.
    # Index 0 on axis 2 is the zero-step (no overshooting) entry.
    _, zs_prior, zs_posterior, zs_mask = tools.nested.map(
        lambda tensor: tensor[:, :, :1], (target, prior, posterior, mask))
    zs_target = {key: value[:, :, None] for key, value in zero_step_obs.items()}
    zero_step_losses = utility.compute_losses(
        config.zero_step_losses, cell, heads, step, zs_target, zs_prior,
        zs_posterior, zs_mask, config.free_nats, debug=config.debug)
    losses += [
        loss * config.zero_step_losses[name] for name, loss in
        zero_step_losses.items()]
    # Guarantee a 'divergence' entry for the metrics printed below.
    if 'divergence' not in zero_step_losses:
        zero_step_losses['divergence'] = tf.zeros((), dtype=tf.float32)

    # Overshooting losses.
    if config.overshooting > 1:
        # Drop the zero-step entry and the last overshooting entry on axis 2.
        os_target, os_prior, os_posterior, os_mask = tools.nested.map(
            lambda tensor: tensor[:, :, 1:-1], (target, prior, posterior, mask))
        if config.stop_os_posterior_gradient:
            os_posterior = tools.nested.map(tf.stop_gradient, os_posterior)
        overshooting_losses = utility.compute_losses(
            config.overshooting_losses, cell, heads, step, os_target, os_prior,
            os_posterior, os_mask, config.free_nats, debug=config.debug)
        losses += [
            loss * config.overshooting_losses[name] for name, loss in
            overshooting_losses.items()]
    else:
        overshooting_losses = {}
    if 'divergence' not in overshooting_losses:
        overshooting_losses['divergence'] = tf.zeros((), dtype=tf.float32)

    # Workaround for TensorFlow deadlock bug.
    # Outside the 'train' phase the optimized loss is 0 * a dummy variable.
    loss = sum(losses)
    train_loss = tf.cond(
        tf.equal(phase, 'train'),
        lambda: loss,
        lambda: 0 * tf.get_variable('dummy_loss', (), tf.float32))
    train_summary = utility.apply_optimizers(
        train_loss, step, should_summarize, config.optimizers)
    # train_summary = tf.cond(
    #     tf.equal(phase, 'train'),
    #     lambda: utility.apply_optimizers(
    #         loss, step, should_summarize, config.optimizers),
    #     str, name='optimizers')

    # Active data collection.
    collect_summaries = []
    # Snapshot of all locals so far; handed to utility.simulate_episodes.
    graph = tools.AttrDict(locals())
    with tf.variable_scope('collection'):
        should_collects = []
        for name, params in config.sim_collects.items():
            after, every = params.steps_after, params.steps_every
            # Collect only in the 'train' phase and when the step schedule fires.
            should_collect = tf.logical_and(
                tf.equal(phase, 'train'),
                tools.schedule.binary(step, config.batch_shape[0], after, every))
            collect_summary, _ = tf.cond(
                should_collect,
                functools.partial(
                    utility.simulate_episodes, config, params, graph, name),
                lambda: (tf.constant(''), tf.constant(0.0)),
                name='should_collect_' + params.task.name)
            should_collects.append(should_collect)
            collect_summaries.append(collect_summary)

    # Compute summaries.
    # Second snapshot of all locals; handed to define_summaries.
    graph = tools.AttrDict(locals())
    with tf.control_dependencies(collect_summaries):
        summaries, score = tf.cond(
            should_summarize,
            lambda: define_summaries.define_summaries(graph, config),
            lambda: (tf.constant(''), tf.zeros((0,), tf.float32)),
            name='summaries')
    with tf.device('/cpu:0'):
        summaries = tf.summary.merge([summaries, train_summary])
        # summaries = tf.summary.merge([summaries, train_summary] + collect_summaries)
        # Mean entropy of the zero-step posterior over valid (masked) steps.
        zs_entropy = (tf.reduce_sum(tools.mask(
            cell.dist_from_state(zs_posterior, zs_mask).entropy(), zs_mask)) /
            tf.reduce_sum(tf.to_float(zs_mask)))
        dependencies.append(utility.print_metrics((
            ('score', score),
            ('loss', loss),
            ('zs_entropy', zs_entropy),
            ('zs_divergence', zero_step_losses['divergence']),
        ), step, config.mean_metrics_every))
    # Evaluating the returned score forces the metric ops to run.
    with tf.control_dependencies(dependencies):
        score = tf.identity(score)
    return score, summaries
def compute_objectives(posterior, prior, target, graph, config):
    """Create objectives, including latent-prior, CPC and inverse-model losses.

    Loss names handled (each only when its scale in `config.loss_scales` is
    non-zero):

    * 'divergence'    -- posterior/prior divergence, floored by free nats.
    * 'latent_prior'  -- steps the cell from each flattened posterior state
                         with `num_actions` uniform random actions in [-1, 1]
                         and no observation. If `config.latent_prior_marginal`
                         is false, the loss is the divergence between current
                         and next states. Otherwise it is the squared distance
                         between each current sample and the mean next-state
                         sample over its random actions.
    * 'embedding_l2'  -- mean squared norm of `graph.embedded`.
    * 'overshooting'  -- multi-step open-loop divergence.
    * 'cpc'           -- `networks.cpc` loss plus the scaled reward loss and
                         gradient penalty.
    * 'inverse_model' -- `networks.inverse_model` loss.
    * anything else   -- log-likelihood of `target[name]` under its head.

    Returns:
      (objectives, cpc_logs). `objectives` are `Objective` tuples reduced
      with `tf.reduce_mean`. `cpc_logs` is a dict of auxiliary tensors for
      logging.
    """
    raw_features = graph.cell.features_from_state(posterior)
    heads = graph.heads
    objectives = []
    cpc_logs = {}
    for name, scale in config.loss_scales.items():
        if scale == 0.0:
            continue
        if name in config.heads and name not in config.gradient_heads:
            # Only the head's own variables are trained by this loss.
            features = tf.stop_gradient(raw_features)
            include = r'.*/head_{}/.*'.format(name)
            exclude = None
        else:
            features = raw_features
            include = r'.*'
            exclude = None

        if name == 'divergence':
            loss = graph.cell.divergence_from_states(posterior, prior)
            if config.free_nats is not None:
                loss = tf.maximum(0.0, loss - float(config.free_nats))
            objectives.append(
                Objective('divergence', loss, min, include, exclude))

        elif name == 'latent_prior':
            num_actions = 10  # random actions sampled per posterior state
            prev_states_flattened = tools.nested.map(
                lambda x: tf.reshape(x, (-1, x.shape[-1].value)), posterior)
            # Tiling along axis 0 stacks `num_actions` copies of the whole
            # flattened batch of N states: row k * N + i is a copy of state i.
            prev_states = tools.nested.map(
                lambda x: tf.tile(x, multiples=(num_actions, 1)),
                prev_states_flattened)
            batch_size = prev_states['sample'].shape[0].value
            prev_action = tf.random.uniform(
                (batch_size, graph.data['action'].shape[-1].value),
                minval=-1, maxval=1)
            obs = tf.zeros(
                shape=[batch_size] + graph.embedded.shape[2:].as_list())
            use_obs = tf.zeros((batch_size, 1), tf.bool)
            (next_states, _), _ = graph.cell(
                (obs, prev_action, use_obs), prev_states)
            if not config.latent_prior_marginal:
                loss = graph.cell.divergence_from_states(
                    prev_states, next_states)
            else:
                # Undo the tiling: axis 0 indexes the random action and
                # axis 1 the original state, so averaging over axis 0 gives
                # each state's mean next state. The previous reshape to
                # (N, num_actions, -1) averaged rows of different states.
                samples_next_state = tf.reshape(
                    next_states['sample'],
                    shape=(num_actions, batch_size // num_actions, -1))
                samples_next_state_mean = tf.reduce_mean(
                    samples_next_state, axis=0)
                samples_current_state = tf.stop_gradient(
                    prev_states_flattened['sample'])
                loss = tf.reduce_mean(tf.reduce_sum(
                    tf.square(samples_next_state_mean - samples_current_state),
                    axis=-1))
            objectives.append(
                Objective('latent_prior', loss, min, include, exclude))

        elif name == 'embedding_l2':
            loss = tf.reduce_mean(
                tf.reduce_sum(tf.square(graph.embedded), axis=-1))
            objectives.append(
                Objective('embedding_l2', loss, min, include, exclude))

        elif name == 'overshooting':
            shape = tools.shape(graph.data['action'])
            length = tf.tile(tf.constant(shape[1])[None], [shape[0]])
            _, priors, posteriors, mask = tools.overshooting(
                graph.cell, {}, graph.embedded, graph.data['action'], length,
                config.overshooting_distance, posterior)
            posteriors, priors, mask = tools.nested.map(
                lambda x: x[:, :, 1:-1], (posteriors, priors, mask))
            if config.os_stop_posterior_grad:
                posteriors = tools.nested.map(tf.stop_gradient, posteriors)
            loss = graph.cell.divergence_from_states(posteriors, priors)
            if config.free_nats is not None:
                loss = tf.maximum(0.0, loss - float(config.free_nats))
            objectives.append(
                Objective('overshooting', loss, min, include, exclude))

        elif name == 'cpc':
            cpc_input = (features if config.include_belief
                         else posterior['sample'])
            loss, acc, reward_loss, reward_acc, gpenalty, kernels = (
                networks.cpc(
                    cpc_input,
                    graph,
                    posterior,
                    predict_terms=config.future,
                    negative_samples=config.negatives,
                    hard_negative_samples=config.hard_negatives,
                    stack_actions=config.stack_actions,
                    negative_actions=config.negative_actions,
                    cpc_openloop=config.cpc_openloop,
                    gradient_penalty=config.cpc_gpenalty_scale > 0,
                    gpenalty_mode=config.gpenalty_mode))
            loss += reward_loss * config.cpc_reward_scale
            loss += gpenalty * config.cpc_gpenalty_scale
            objectives.append(Objective('cpc', loss, min, include, exclude))
            cpc_logs['acc'] = acc
            cpc_logs['reward_acc'] = reward_acc
            cpc_logs['gpenalty'] = gpenalty
            if kernels:
                for i in range(config.future):
                    cpc_logs['W_mag%d' % i] = tf.reduce_mean(
                        tf.square(kernels[i]))

        elif name == 'inverse_model':
            loss, acc = networks.inverse_model(
                features, graph, contrastive=config.action_contrastive,
                negative_samples=config.negatives)
            objectives.append(
                Objective('inverse_model', loss, min, include, exclude))
            if config.action_contrastive:
                cpc_logs['inverse_model_acc'] = acc

        else:
            logprob = heads[name](features).log_prob(target[name])
            objectives.append(Objective(name, logprob, max, include, exclude))

    objectives = [
        o._replace(value=tf.reduce_mean(o.value)) for o in objectives
    ]
    return objectives, cpc_logs
def compute_objectives(posterior, prior, target, graph, config):
    """Create training objectives averaged over an ensemble of models.

    `posterior`, `prior` and `graph.cell` are indexed per ensemble member:
    member `i` uses `graph.cell[i]`, `posterior[i]` and `prior[i]`. The heads
    in `graph.heads` are shared by all members.

    * 'divergence' -- mean over members of the posterior/prior divergence,
      then floored by `config.free_nats` when set; minimized.
    * anything else -- mean over members of the head log-likelihood of a
      bootstrap resample of `target[name]`. Member `i` gathers batch rows
      `graph.sample_with_replacement[i, :]`. Maximized.

    Returns:
      List of `Objective` tuples reduced with `tf.reduce_mean`.

    Raises:
      NotImplementedError: if 'overshooting' has a non-zero scale. It has not
        been adapted to ensembles. This replaces an `assert`, which is
        stripped under `python -O` and then ran broken code.
    """
    heads = graph.heads
    ensemble_size = len(posterior)
    objectives = []
    for name, scale in config.loss_scales.items():
        if scale == 0.0:
            continue
        if name == 'overshooting':
            raise NotImplementedError(
                'The overshooting loss does not support model ensembles.')

        head_only = name in config.heads and name not in config.gradient_heads
        # Only the head's own variables are trained by head-only losses.
        include = r'.*/head_{}/.*'.format(name) if head_only else r'.*'
        exclude = None

        if name == 'divergence':
            loss = graph.cell[0].divergence_from_states(posterior[0], prior[0])
            for mdl in range(1, ensemble_size):
                loss = tf.math.add(
                    loss,
                    graph.cell[mdl].divergence_from_states(
                        posterior[mdl], prior[mdl]))
            loss = tf.math.scalar_mul((1 / ensemble_size), loss)
            if config.free_nats is not None:
                loss = tf.maximum(0.0, loss - float(config.free_nats))
            objectives.append(
                Objective('divergence', loss, min, include, exclude))
            continue

        features = []
        for mdl in range(ensemble_size):
            member_features = graph.cell[mdl].features_from_state(
                posterior[mdl])
            if head_only:
                member_features = tf.stop_gradient(member_features)
            features.append(member_features)

        # Each member is scored on its own bootstrap resample of the batch.
        bootstrap_target = tf.gather(
            target[name], graph.sample_with_replacement[0, :], axis=0)
        logprob = heads[name](features[0]).log_prob(bootstrap_target)
        for mdl in range(1, ensemble_size):
            bootstrap_target = tf.gather(
                target[name], graph.sample_with_replacement[mdl, :], axis=0)
            logprob = tf.math.add(
                logprob,
                heads[name](features[mdl]).log_prob(bootstrap_target))
        logprob = tf.math.scalar_mul((1 / ensemble_size), logprob)
        objectives.append(Objective(name, logprob, max, include, exclude))

    objectives = [
        o._replace(value=tf.reduce_mean(o.value)) for o in objectives
    ]
    return objectives