예제 #1
0
def compute_objectives(posterior, prior, target, graph, config):
    """Assemble one mean-reduced ``Objective`` per enabled loss.

    Every entry of ``config.loss_scales`` whose scale is not ``0.0`` yields
    one objective:

    * ``'divergence'``: divergence between ``posterior`` and ``prior``,
      minimized.
    * ``'overshooting'``: divergence between the posteriors and priors
      returned by ``tools.overshooting``, minimized.
    * any other name: log-probability of ``target[name]`` under the output
      of ``graph.heads[name]``, maximized. For ``'image'`` only the last
      three entries of the target's final axis are scored.

    Args:
        posterior: Posterior state structure given to the cell.
        prior: Prior state structure used by the ``'divergence'`` loss.
        target: Mapping from loss name to the value passed to ``log_prob``.
        graph: Object exposing ``cell``, ``heads``, ``data`` and
            ``embedded``.
        config: Object exposing ``loss_scales``, ``heads``,
            ``gradient_heads``, ``free_nats``, ``overshooting_distance`` and
            ``os_stop_posterior_grad``.

    Returns:
        List of ``Objective`` tuples whose ``value`` field has been reduced
        with ``tf.reduce_mean``.
    """

    def _subtract_free_nats(divergence):
        # Without a free-nats budget the divergence is used unchanged.
        if config.free_nats is None:
            return divergence
        return tf.maximum(0.0, divergence - float(config.free_nats))

    state_features = graph.cell.features_from_state(posterior)
    head_fns = graph.heads
    collected = []
    for name, scale in config.loss_scales.items():
        if scale == 0.0:
            continue

        # Heads not listed in `gradient_heads` see gradient-stopped features
        # and get an include pattern limited to their own scope (presumably
        # used for variable selection downstream -- confirm in the consumer
        # of Objective).
        head_only = (
            name in config.heads and name not in config.gradient_heads)
        if head_only:
            features = tf.stop_gradient(state_features)
            include = r'.*/head_{}/.*'.format(name)
        else:
            features = state_features
            include = r'.*'
        exclude = None

        if name == 'divergence':
            divergence = graph.cell.divergence_from_states(posterior, prior)
            collected.append(Objective(
                'divergence', _subtract_free_nats(divergence), min, include,
                exclude))
            continue

        if name == 'overshooting':
            action_shape = tools.shape(graph.data['action'])
            # Every sequence is treated as spanning the full second axis.
            full_length = tf.tile(
                tf.constant(action_shape[1])[None], [action_shape[0]])
            _, os_priors, os_posteriors, os_mask = tools.overshooting(
                graph.cell, {}, graph.embedded, graph.data['action'],
                full_length, config.overshooting_distance, posterior)
            # Drop the first and last entries along the third axis.
            os_posteriors, os_priors, os_mask = tools.nested.map(
                lambda x: x[:, :, 1:-1], (os_posteriors, os_priors, os_mask))
            if config.os_stop_posterior_grad:
                os_posteriors = tools.nested.map(
                    tf.stop_gradient, os_posteriors)
            divergence = graph.cell.divergence_from_states(
                os_posteriors, os_priors)
            collected.append(Objective(
                'overshooting', _subtract_free_nats(divergence), min, include,
                exclude))
            continue

        # All remaining names are head log-likelihoods to be maximized.
        distribution = head_fns[name](features)
        observed = target[name]
        if name == 'image':
            observed = observed[:, :, :, :, -3:]
        collected.append(Objective(
            name, distribution.log_prob(observed), max, include, exclude))

    return [
        objective._replace(value=tf.reduce_mean(objective.value))
        for objective in collected
    ]
예제 #2
0
def get_overshoot_preds(graph, embedding, actions, length, predict_terms,
                        posterior):
    """Return the prior ``'sample'`` entries of an overshooting rollout.

    Runs ``tools.overshooting`` on ``graph.cell`` with an empty target
    structure, drops index 0 along the third axis of the returned priors,
    posteriors and mask, and returns the priors' ``'sample'`` entry with the
    last ``predict_terms`` positions of the second axis removed.

    Args:
        graph: Object exposing ``cell``.
        embedding: Embedded observations passed through to the rollout.
        actions: Actions passed through to the rollout.
        length: Sequence lengths passed through to the rollout.
        predict_terms: Amount handed to ``tools.overshooting`` and also the
            number of trailing second-axis positions cut from the result.
        posterior: Posterior state structure handed to the rollout.

    Returns:
        The sliced ``'sample'`` tensor. The original author's note gives its
        layout as batch_size x effective_horizon x predict_terms x
        sample_size -- not verifiable from this function alone.
    """
    _, os_priors, os_posteriors, os_mask = tools.overshooting(
        graph.cell, {}, embedding, actions, length, predict_terms, posterior)

    def _drop_first(tensor):
        return tensor[:, :, 1:]

    os_posteriors, os_priors, os_mask = tools.nested.map(
        _drop_first, (os_posteriors, os_priors, os_mask))

    # TODO: is sample the right feature to use?
    prior_samples = os_priors['sample']
    return prior_samples[:, :-predict_terms]
예제 #3
0
 def test_example(self):
     """Check overshooting masks, priors and posteriors on a small batch.

     The batch holds two sequences with lengths 6 and 2, run through a
     ``_MockCell`` with an overshooting amount of 3.
     """
     observations = tf.constant(
         [
             [10, 20, 30, 40, 50, 60],
             [70, 80, 0, 0, 0, 0],
         ],
         dtype=tf.float32)[:, :, None]
     previous_actions = tf.constant(
         [
             [0.0, 0.1, 0.2, 0.3, 0.4, 0.5],
             [9.0, 0.7, 0, 0, 0, 0],
         ],
         dtype=tf.float32)[:, :, None]
     lengths = tf.constant([6, 2], dtype=tf.int32)
     mock_cell = _MockCell(1)
     _, prior_state, posterior_state, raw_mask = overshooting(
         mock_cell, observations, observations, previous_actions, lengths, 3)
     prior_obs = tf.squeeze(prior_state['obs'], 3)
     posterior_obs = tf.squeeze(posterior_state['obs'], 3)
     int_mask = tf.to_int32(raw_mask)

     # Each column corresponds to a different state step, and each row
     # corresponds to a different overshooting distance from there.
     expected_mask_first = [
         [1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1, 0],
         [1, 1, 1, 1, 0, 0],
         [1, 1, 1, 0, 0, 0],
     ]
     expected_mask_second = [
         [1, 1, 0, 0, 0, 0],
         [1, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0],
     ]
     expected_prior_first = [
         [0.0, 10.1, 20.2, 30.3, 40.4, 50.5],
         [0.1, 10.3, 20.5, 30.7, 40.9, 0],
         [0.3, 10.6, 20.9, 31.2, 0, 0],
         [0.6, 11.0, 21.4, 0, 0, 0],
     ]
     expected_posterior_first = [
         [10, 20, 30, 40, 50, 60],
         [20, 30, 40, 50, 60, 0],
         [30, 40, 50, 60, 0, 0],
         [40, 50, 60, 0, 0, 0],
     ]
     expected_prior_second = [
         [9.0, 70.7, 0, 0, 0, 0],
         [9.7, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0],
     ]
     expected_posterior_second = [
         [70, 80, 0, 0, 0, 0],
         [80, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0],
     ]

     with self.test_session():
         self.assertAllEqual(expected_mask_first, int_mask.eval()[0].T)
         self.assertAllEqual(expected_mask_second, int_mask.eval()[1].T)
         self.assertAllClose(expected_prior_first, prior_obs.eval()[0].T)
         self.assertAllClose(
             expected_posterior_first, posterior_obs.eval()[0].T)
         self.assertAllClose(expected_prior_second, prior_obs.eval()[1].T)
         self.assertAllClose(
             expected_posterior_second, posterior_obs.eval()[1].T)
예제 #4
0
 def test_nested(self):
     """Smoke test: overshooting accepts tuple-structured inputs."""
     batch, steps = 3, 50
     observations = tuple(
         tf.ones((batch, steps, width)) for width in (1, 2, 3))
     previous_actions = tuple(
         tf.ones((batch, steps, width)) for width in (1, 2))
     lengths = tf.constant([49, 50, 3], dtype=tf.int32)
     mock_cell = _MockCell(1)
     # Only graph construction is exercised; nothing is evaluated.
     overshooting(
         mock_cell, observations, observations, previous_actions, lengths, 3)
예제 #5
0
def define_model(data, trainer, config):
  """Build the training graph and return its score and merged summaries.

  The function preprocesses the batch, instantiates the cell, encoder and
  heads, unrolls the model with `tools.overshooting`, builds zero-step and
  overshooting losses, applies the optimizers, sets up conditional data
  collection and finally assembles summaries and printed metrics.

  Args:
    data: Dict of batch tensors. Must contain 'action' and 'length'; a copy
      without 'length' is passed to the encoder as observations. The
      'action' entry of this dict is replaced when
      `config.dynamic_action_noise` is truthy.
    trainer: Object providing `step`, `global_step`, `phase` and `log`.
    config: Configuration object; its attributes are read throughout the
      body (cell, encoder, heads, loss weights, optimizers, ...).

  Returns:
    Tuple `(score, summaries)`. `score` is the value produced by the
    summaries branch (an empty float tensor when `trainer.log` is false),
    wrapped in `tf.identity` so that fetching it also runs the metric
    printing op. `summaries` merges the model summaries with the optimizer
    summary.
  """
  tf.logging.info('Build TensorFlow compute graph.')
  dependencies = []
  step = trainer.step
  # Not used directly below; it is only exposed through the locals()
  # snapshots stored in `graph`.
  global_step = trainer.global_step
  phase = trainer.phase
  should_summarize = trainer.log

  # Preprocess data.
  with tf.device('/cpu:0'):
    if config.dynamic_action_noise:
      # Adds normal noise with mean 0.0 and stddev
      # `config.dynamic_action_noise`; this rebinds the entry in the
      # caller's dict.
      data['action'] += tf.random_normal(
          tf.shape(data['action']), 0.0, config.dynamic_action_noise)
    # Shift actions one step along axis 1, padding the front with zeros.
    prev_action = tf.concat(
        [0 * data['action'][:, :1], data['action'][:, :-1]], 1)
    obs = data.copy()
    del obs['length']

  # Instantiate network blocks.
  cell = config.cell()
  kwargs = dict()
  encoder = tf.make_template(
      'encoder', config.encoder, create_scope_now_=True, **kwargs)
  heads = {}
  for key, head in config.heads.items():
    name = 'head_{}'.format(key)
    # Each head is told the shape of its observation past the first two
    # axes.
    kwargs = dict(data_shape=obs[key].shape[2:].as_list())
    heads[key] = tf.make_template(name, head, create_scope_now_=True, **kwargs)

  # Embed observations and unroll model.
  embedded = encoder(obs)
  # Separate overshooting and zero step observations because computing
  # overshooting targets for images would be expensive.
  zero_step_obs = {}
  overshooting_obs = {}
  for key, value in obs.items():
    if config.zero_step_losses.get(key):
      zero_step_obs[key] = value
    if config.overshooting_losses.get(key):
      overshooting_obs[key] = value
  assert config.overshooting <= config.batch_shape[1]
  # One more than `config.overshooting` is requested; the slices below use
  # index 0 of axis 2 for the zero-step terms and 1:-1 for overshooting.
  target, prior, posterior, mask = tools.overshooting(
      cell, overshooting_obs, embedded, prev_action, data['length'],
      config.overshooting + 1)
  losses = []

  # Zero step losses.
  _, zs_prior, zs_posterior, zs_mask = tools.nested.map(
      lambda tensor: tensor[:, :, :1], (target, prior, posterior, mask))
  # Targets come straight from the observations, with a new axis inserted
  # at position 2.
  zs_target = {key: value[:, :, None] for key, value in zero_step_obs.items()}
  zero_step_losses = utility.compute_losses(
      config.zero_step_losses, cell, heads, step, zs_target, zs_prior,
      zs_posterior, zs_mask, config.free_nats, debug=config.debug)
  losses += [
      loss * config.zero_step_losses[name] for name, loss in
      zero_step_losses.items()]
  # Added after weighting so the placeholder does not enter `losses`; it is
  # only read by the printed metrics below.
  if 'divergence' not in zero_step_losses:
    zero_step_losses['divergence'] = tf.zeros((), dtype=tf.float32)

  # Overshooting losses.
  if config.overshooting > 1:
    os_target, os_prior, os_posterior, os_mask = tools.nested.map(
        lambda tensor: tensor[:, :, 1:-1], (target, prior, posterior, mask))
    if config.stop_os_posterior_gradient:
      os_posterior = tools.nested.map(tf.stop_gradient, os_posterior)
    overshooting_losses = utility.compute_losses(
        config.overshooting_losses, cell, heads, step, os_target, os_prior,
        os_posterior, os_mask, config.free_nats, debug=config.debug)
    losses += [
        loss * config.overshooting_losses[name] for name, loss in
        overshooting_losses.items()]
  else:
    overshooting_losses = {}
  if 'divergence' not in overshooting_losses:
    overshooting_losses['divergence'] = tf.zeros((), dtype=tf.float32)

  # Workaround for TensorFlow deadlock bug.
  loss = sum(losses)
  # Outside the 'train' phase the optimizers receive zero times a dummy
  # variable instead of the real loss.
  train_loss = tf.cond(
      tf.equal(phase, 'train'),
      lambda: loss,
      lambda: 0 * tf.get_variable('dummy_loss', (), tf.float32))
  train_summary = utility.apply_optimizers(
      train_loss, step, should_summarize, config.optimizers)
  # train_summary = tf.cond(
  #     tf.equal(phase, 'train'),
  #     lambda: utility.apply_optimizers(
  #         loss, step, should_summarize, config.optimizers),
  #     str, name='optimizers')

  # Active data collection.
  collect_summaries = []
  # Snapshot of every local defined so far, handed to
  # utility.simulate_episodes. NOTE(review): that helper may look locals up
  # by name, so renaming variables above could break it -- confirm.
  graph = tools.AttrDict(locals())
  with tf.variable_scope('collection'):
    should_collects = []
    for name, params in config.sim_collects.items():
      after, every = params.steps_after, params.steps_every
      should_collect = tf.logical_and(
          tf.equal(phase, 'train'),
          tools.schedule.binary(step, config.batch_shape[0], after, every))
      # The second output of the collection branch is discarded.
      collect_summary, _ = tf.cond(
          should_collect,
          functools.partial(
              utility.simulate_episodes, config, params, graph, name),
          lambda: (tf.constant(''), tf.constant(0.0)),
          name='should_collect_' + params.task.name)
      should_collects.append(should_collect)
      collect_summaries.append(collect_summary)

  # Compute summaries.
  # Second snapshot; it additionally contains the collection variables.
  graph = tools.AttrDict(locals())
  # Summaries are computed only after the collection ops have run.
  with tf.control_dependencies(collect_summaries):
    summaries, score = tf.cond(
        should_summarize,
        lambda: define_summaries.define_summaries(graph, config),
        lambda: (tf.constant(''), tf.zeros((0,), tf.float32)),
        name='summaries')
  with tf.device('/cpu:0'):
    summaries = tf.summary.merge([summaries, train_summary])
    # summaries = tf.summary.merge(
    #     [summaries, train_summary] + collect_summaries)
    # Masked sum of the zero-step posterior entropy divided by the number
    # of unmasked entries.
    zs_entropy = (tf.reduce_sum(tools.mask(
        cell.dist_from_state(zs_posterior, zs_mask).entropy(), zs_mask)) /
        tf.reduce_sum(tf.to_float(zs_mask)))
    dependencies.append(utility.print_metrics((
        ('score', score),
        ('loss', loss),
        ('zs_entropy', zs_entropy),
        ('zs_divergence', zero_step_losses['divergence']),
    ), step, config.mean_metrics_every))
  # Fetching the returned score forces the metric printing op to run.
  with tf.control_dependencies(dependencies):
    score = tf.identity(score)
  return score, summaries
예제 #6
0
def define_model(data, trainer, config):
    """Build the (optionally multi-GPU) training graph.

    For each of `NUM_GPU` devices a model tower is built inside the shared
    variable scope 'shared_vars' (with `tf.AUTO_REUSE`). With a single GPU
    the optimizers are applied directly to the tower loss; with several
    GPUs the batch is split with `tf.split`, per-tower gradients are
    collected, averaged on the CPU and applied once.

    Args:
        data: Dict of batch tensors. Must contain 'action' and 'length'.
            In the multi-GPU case every entry is split into `NUM_GPU`
            pieces with `tf.split`.
        trainer: Object providing `step`, `global_step`, `phase` and `log`.
        config: Configuration object; its attributes are read throughout
            the body.

    Returns:
        Tuple `(score, summaries)`; `score` is wrapped in `tf.identity` so
        that fetching it also runs the metric printing op.

    NOTE(review): `NUM_GPU` and `average_gradients` are not defined in this
    function and must come from module scope.
    """
    tf.logging.info('Build TensorFlow compute graph.')
    dependencies = []
    step = trainer.step
    global_step = trainer.global_step  # tf.train.get_or_create_global_step()
    phase = trainer.phase
    should_summarize = trainer.log

    num_gpu = NUM_GPU

    #  for multi-gpu
    if num_gpu > 1:
        # Per-optimizer lists of variables and of per-tower gradients.
        var_for_trainop = {}
        grads_dict = {}

        # data split for multi-gpu
        data_dict = {}
        for loss_head, optimizer_cls in config.optimizers.items():
            grads_dict[loss_head] = []
            var_for_trainop[loss_head] = []

        for gpu_i in range(num_gpu):
            data_dict[gpu_i] = {}

        # Split every batch entry into one piece per GPU.
        for data_item in list(data.keys()):
            data_split = tf.split(data[data_item], num_gpu)
            for gpu_j in range(num_gpu):
                data_dict[gpu_j][data_item] = data_split[gpu_j]

    for gpu_k in range(num_gpu):
        with tf.device('/gpu:%s' % gpu_k):
            # Pattern passed to utility.get_grads as `include_var` below.
            scope_name = r'.+shared_vars'
            with tf.name_scope('%s_%d' % ("GPU", gpu_k)):  # 'GPU'
                with tf.variable_scope(name_or_scope='shared_vars',
                                       reuse=tf.AUTO_REUSE):

                    #  for multi-gpu
                    if num_gpu > 1:
                        # Rebinds `data` to this tower's slice; after the
                        # loop `data` refers to the last tower's slice.
                        data = data_dict[gpu_k]

                    # Preprocess data.
                    # with tf.device('/cpu:0'):
                    if config.dynamic_action_noise:
                        data['action'] += tf.random_normal(
                            tf.shape(data['action']), 0.0,
                            config.dynamic_action_noise)
                    prev_action = tf.concat(
                        [0 * data['action'][:, :1], data['action'][:, :-1]],
                        1)  # i.e.: (0 * a1, a1, a2, ..., a49)
                    obs = data.copy()
                    del obs['length']

                    # Instantiate network blocks.
                    cell = config.cell()
                    kwargs = dict()
                    encoder = tf.make_template('encoder',
                                               config.encoder,
                                               create_scope_now_=True,
                                               **kwargs)
                    heads = {}
                    for key, head in config.heads.items(
                    ):  # heads: network of 'image', 'reward', 'state'
                        name = 'head_{}'.format(key)
                        kwargs = dict(data_shape=obs[key].shape[2:].as_list())
                        heads[key] = tf.make_template(name,
                                                      head,
                                                      create_scope_now_=True,
                                                      **kwargs)

                    # Embed observations and unroll model.
                    embedded = encoder(obs)  # encode obs['image']
                    # Separate overshooting and zero step observations because computing
                    # overshooting targets for images would be expensive.
                    zero_step_obs = {}
                    overshooting_obs = {}
                    for key, value in obs.items():
                        if config.zero_step_losses.get(key):
                            zero_step_obs[key] = value
                        if config.overshooting_losses.get(key):
                            overshooting_obs[key] = value
                    assert config.overshooting <= config.batch_shape[1]
                    target, prior, posterior, mask = tools.overshooting(  # prior:{'mean':shape(40,50,51,30), ...}; posterior:{'mean':shape(40,50,51,30), ...}
                        cell,
                        overshooting_obs,
                        embedded,
                        prev_action,
                        data[
                            'length'],  # target:{'reward':shape(40,50,51), ...}; mask:shape(40,50,51)
                        config.overshooting + 1)
                    losses = []

                    # Zero step losses.
                    _, zs_prior, zs_posterior, zs_mask = tools.nested.map(
                        lambda tensor: tensor[:, :, :1],
                        (target, prior, posterior, mask))
                    zs_target = {
                        key: value[:, :, None]
                        for key, value in zero_step_obs.items()
                    }
                    zero_step_losses = utility.compute_losses(
                        config.zero_step_losses,
                        cell,
                        heads,
                        step,
                        zs_target,
                        zs_prior,
                        zs_posterior,
                        zs_mask,
                        config.free_nats,
                        debug=config.debug)
                    losses += [
                        loss * config.zero_step_losses[name]
                        for name, loss in zero_step_losses.items()
                    ]
                    # Placeholder added after weighting; it is only read
                    # by the printed metrics below.
                    if 'divergence' not in zero_step_losses:
                        zero_step_losses['divergence'] = tf.zeros(
                            (), dtype=tf.float32)

                    # Overshooting losses.
                    if config.overshooting > 1:
                        os_target, os_prior, os_posterior, os_mask = tools.nested.map(
                            lambda tensor: tensor[:, :, 1:-1],
                            (target, prior, posterior, mask))
                        if config.stop_os_posterior_gradient:
                            os_posterior = tools.nested.map(
                                tf.stop_gradient, os_posterior)
                        overshooting_losses = utility.compute_losses(
                            config.overshooting_losses,
                            cell,
                            heads,
                            step,
                            os_target,
                            os_prior,
                            os_posterior,
                            os_mask,
                            config.free_nats,
                            debug=config.debug)
                        losses += [
                            loss * config.overshooting_losses[name]
                            for name, loss in overshooting_losses.items()
                        ]
                    else:
                        overshooting_losses = {}
                    if 'divergence' not in overshooting_losses:
                        overshooting_losses['divergence'] = tf.zeros(
                            (), dtype=tf.float32)

                    # Workaround for TensorFlow deadlock bug.
                    loss = sum(losses)
                    train_loss = tf.cond(
                        tf.equal(phase, 'train'), lambda: loss,
                        lambda: 0 * tf.get_variable('dummy_loss',
                                                    (), tf.float32))

                    #  for multi-gpu
                    if num_gpu == 1:
                        train_summary = utility.apply_optimizers(
                            train_loss, step, should_summarize,
                            config.optimizers)
                    else:
                        training_grad_dict = utility.get_grads(
                            train_loss,
                            step,
                            should_summarize,
                            config.optimizers,
                            include_var=(scope_name, ))
                        for a in grads_dict.keys():
                            grads_dict[a].append(training_grad_dict[a]["grad"])
                            # Variables are recorded from the first tower
                            # only.
                            if gpu_k == 0:
                                var_for_trainop[a].append(
                                    training_grad_dict[a]["var"])
                        # train_summary = tf.cond(
                        #     tf.equal(phase, 'train'),
                        #     lambda: utility.apply_optimizers(
                        #         loss, step, should_summarize, config.optimizers),
                        #     str, name='optimizers')

    #  for multi-gpu
    if num_gpu > 1:
        averaged_gradients = {}
        with tf.device('/cpu:0'):
            for a in grads_dict.keys():
                averaged_gradients[a] = average_gradients(grads_dict[a])
            train_summary = utility.apply_grads(averaged_gradients,
                                                var_for_trainop, step,
                                                should_summarize,
                                                config.optimizers)

    # NOTE(review): everything below that uses per-tower names (`cell`,
    # `loss`, `zs_posterior`, `zs_mask`, `zero_step_losses`, ...) sees the
    # values left by the last loop iteration, i.e. the last GPU tower.

    # Active data collection.
    collect_summaries = []
    # Snapshot of every local defined so far, handed to
    # utility.simulate_episodes. NOTE(review): that helper may look locals
    # up by name, so renaming variables above could break it -- confirm.
    graph = tools.AttrDict(locals())
    with tf.variable_scope('collection'):
        should_collects = []
        for name, params in config.sim_collects.items():
            after, every = params.steps_after, params.steps_every
            should_collect = tf.logical_and(
                tf.equal(phase, 'train'),
                tools.schedule.binary(step, config.batch_shape[0], after,
                                      every))
            # NOTE(review): `score_train` is bound only inside this loop.
            # With an empty `config.sim_collects` the print_metrics call
            # below raises NameError, and with several entries only the
            # last one is reported -- confirm this is intended.
            collect_summary, score_train = tf.cond(
                should_collect,
                functools.partial(utility.simulate_episodes, config, params,
                                  graph, name),
                lambda: (tf.constant(''), tf.constant(0.0)),
                name='should_collect_' + params.task.name)
            should_collects.append(should_collect)
            collect_summaries.append(collect_summary)

    # Compute summaries.
    # Second snapshot; it additionally contains the collection variables.
    graph = tools.AttrDict(locals())
    # Summaries are computed only after the collection ops have run.
    with tf.control_dependencies(collect_summaries):
        summaries, score = tf.cond(
            should_summarize,
            lambda: define_summaries.define_summaries(graph, config),
            lambda: (tf.constant(''), tf.zeros((0, ), tf.float32)),
            name='summaries')
    with tf.device('/cpu:0'):
        summaries = tf.summary.merge([summaries, train_summary])
        # summaries = tf.summary.merge([summaries, train_summary] + collect_summaries)
        # Masked sum of the zero-step posterior entropy divided by the
        # number of unmasked entries.
        zs_entropy = (tf.reduce_sum(
            tools.mask(
                cell.dist_from_state(zs_posterior, zs_mask).entropy(),
                zs_mask)) / tf.reduce_sum(tf.to_float(zs_mask)))
        dependencies.append(
            utility.print_metrics((
                ('score', score_train),
                ('loss', loss),
                ('zs_entropy', zs_entropy),
                ('zs_divergence', zero_step_losses['divergence']),
            ), step, config.mean_metrics_every))
    # Fetching the returned score forces the metric printing op to run.
    with tf.control_dependencies(dependencies):
        score = tf.identity(score)
    return score, summaries
예제 #7
0
파일: utility.py 프로젝트: zxhuang97/planet
def compute_objectives(posterior, prior, target, graph, config, trainer):
    """Assemble mean-reduced objectives, with optional contrastive reward loss.

    Every entry of ``config.loss_scales`` whose scale is not ``0.0`` yields
    one objective:

    * ``'divergence'``: divergence between ``posterior`` and ``prior``,
      minimized.
    * ``'overshooting'``: divergence between the posteriors and priors
      returned by ``tools.overshooting``, minimized.
    * ``'reward'`` with ``config.r_loss == 'contra'``: one of the
      ``contra_traj_loss*`` functions chosen by ``config.contra_unit``
      (``'traj'``, ``'weighted'``, ``'simclr'`` or ``'rank'``), minimized.
    * ``'reward'`` with ``config.r_loss == 'l2'``: mean squared error
      between target and head output, minimized.
    * any other case: log-probability of ``target[name]`` under the head
      output, maximized. When ``config.aug`` is set and ``config.aug_same``
      is not, ``target['aug']`` is concatenated to the features first.

    Args:
        posterior: Posterior state structure given to the cell.
        prior: Prior state structure used by the ``'divergence'`` loss.
        target: Mapping from loss name to its target value.
        graph: Object exposing ``cell``, ``heads``, ``data`` and
            ``embedded``.
        config: Configuration object; see the attribute reads below.
        trainer: Unused; kept so existing callers keep working.

    Returns:
        Tuple ``(objectives, cstr_pct)``. ``objectives`` is a list of
        ``Objective`` tuples whose ``value`` has been reduced with
        ``tf.reduce_mean``. ``cstr_pct`` is the second value returned by
        the contrastive loss function, or ``0.0`` when no contrastive loss
        was built.

    Raises:
        ValueError: If the contrastive reward loss is enabled and
            ``config.contra_unit`` is not one of the supported values.
    """
    raw_features = graph.cell.features_from_state(posterior)
    heads = graph.heads
    objectives = []
    cstr_pct = 0.0
    for name, scale in config.loss_scales.items():
        if scale == 0.0:
            continue
        # Heads not listed in `gradient_heads` see gradient-stopped features
        # and get an include pattern limited to their own scope.
        if name in config.heads and name not in config.gradient_heads:
            features = tf.stop_gradient(raw_features)
            include = r'.*/head_{}/.*'.format(name)
        else:
            features = raw_features
            include = r'.*'
        exclude = None

        if name == 'divergence':
            loss = graph.cell.divergence_from_states(posterior, prior)
            if config.free_nats is not None:
                loss = tf.maximum(0.0, loss - float(config.free_nats))
            objectives.append(
                Objective('divergence', loss, min, include, exclude))

        elif name == 'overshooting':
            shape = tools.shape(graph.data['action'])
            # Every sequence is treated as spanning the full second axis.
            length = tf.tile(tf.constant(shape[1])[None], [shape[0]])
            _, priors, posteriors, mask = tools.overshooting(
                graph.cell, {}, graph.embedded, graph.data['action'], length,
                config.overshooting_distance, posterior)
            posteriors, priors, mask = tools.nested.map(
                lambda x: x[:, :, 1:-1], (posteriors, priors, mask))
            if config.os_stop_posterior_grad:
                posteriors = tools.nested.map(tf.stop_gradient, posteriors)
            loss = graph.cell.divergence_from_states(posteriors, priors)
            if config.free_nats is not None:
                loss = tf.maximum(0.0, loss - float(config.free_nats))
            objectives.append(
                Objective('overshooting', loss, min, include, exclude))

        elif name == 'reward' and config.r_loss == 'contra':
            pred = heads[name](features)
            if config.contra_unit == 'traj':
                print('Using traj loss')
                contra_loss, cstr_pct = contra_traj_lossV6(
                    pred, target[name], horizon=config.contra_horizon)
            elif config.contra_unit == 'weighted':
                print('Using weighted trajectory loss ', config.contra_horizon)
                contra_loss, cstr_pct = contra_traj_lossV7(
                    pred,
                    target[name],
                    horizon=config.contra_horizon,
                    temp=config.temp)
            elif config.contra_unit == 'simclr':
                print('Using simclr trajectory loss ', config.contra_horizon)
                contra_loss, cstr_pct = contra_traj_lossV8(
                    pred, target[name], horizon=config.contra_horizon)
            elif config.contra_unit == 'rank':
                print('Using ranking trajectory loss ', config.contra_horizon)
                contra_loss, cstr_pct = contra_traj_lossV9(
                    pred,
                    target[name],
                    horizon=config.contra_horizon,
                    margin=config.margin)
            else:
                # Previously this fell through and failed later with an
                # unbound `contra_loss`; fail early with a clear message.
                raise ValueError(
                    'Unknown config.contra_unit: {!r}'.format(
                        config.contra_unit))
            objectives.append(
                Objective(name, contra_loss, min, include, exclude))

        elif name == 'reward' and config.r_loss == 'l2':
            pred = heads[name](features)
            l2_loss = tf.compat.v1.losses.mean_squared_error(
                target[name], pred)
            objectives.append(Objective(name, l2_loss, min, include, exclude))

        else:
            if not config.aug_same and config.aug:
                # Condition the head on the augmentation target as well.
                recon_feat = tf.concat([features, target['aug']], -1)
                print('Use recon feature ', name, recon_feat)
                logprob = heads[name](recon_feat).log_prob(target[name])
            else:
                logprob = heads[name](features).log_prob(target[name])
            objectives.append(Objective(name, logprob, max, include, exclude))

    objectives = [
        o._replace(value=tf.reduce_mean(o.value)) for o in objectives
    ]

    return objectives, cstr_pct
예제 #8
0
파일: PlaNet.py 프로젝트: Aurametrix/Alg
def define_model(data, trainer, config):
  """Build the training graph and return its score and merged summaries.

  Steps, in order: preprocess the batch, instantiate cell/encoder/heads,
  unroll the model with `tools.overshooting`, build zero-step and
  overshooting losses, apply the optimizers, set up conditional data
  collection, then assemble summaries and printed metrics.

  Args:
    data: Dict of batch tensors. Must contain 'action' and 'length'; a copy
      without 'length' is passed to the encoder as observations. The
      'action' entry of this dict is replaced when
      `config.dynamic_action_noise` is truthy.
    trainer: Object providing `step`, `global_step`, `phase` and `log`.
    config: Configuration object; its attributes are read throughout the
      body.

  Returns:
    Tuple `(score, summaries)`. `score` comes from the summaries branch (an
    empty float tensor when `trainer.log` is false) and is wrapped in
    `tf.identity` so that fetching it also runs the metric printing op.
    `summaries` merges the model summaries with the optimizer summary.
  """
  tf.logging.info('Build TensorFlow compute graph.')
  dependencies = []
  step = trainer.step
  # Only exposed through the locals() snapshots stored in `graph`.
  global_step = trainer.global_step
  phase = trainer.phase
  should_summarize = trainer.log

  # Preprocess data.
  with tf.device('/cpu:0'):
    if config.dynamic_action_noise:
      # Normal noise with mean 0.0 and stddev `config.dynamic_action_noise`;
      # this rebinds the entry in the caller's dict.
      data['action'] += tf.random_normal(
          tf.shape(data['action']), 0.0, config.dynamic_action_noise)
    # Shift actions one step along axis 1, padding the front with zeros.
    prev_action = tf.concat(
        [0 * data['action'][:, :1], data['action'][:, :-1]], 1)
    obs = data.copy()
    del obs['length']

  # Instantiate network blocks.
  cell = config.cell()
  kwargs = dict()
  encoder = tf.make_template(
      'encoder', config.encoder, create_scope_now_=True, **kwargs)
  heads = {}
  for key, head in config.heads.items():
    name = 'head_{}'.format(key)
    # Each head receives the shape of its observation past the first two
    # axes.
    kwargs = dict(data_shape=obs[key].shape[2:].as_list())
    heads[key] = tf.make_template(name, head, create_scope_now_=True, **kwargs)

  # Embed observations and unroll model.
  embedded = encoder(obs)
  # Separate overshooting and zero step observations because computing
  # overshooting targets for images would be expensive.
  zero_step_obs = {}
  overshooting_obs = {}
  for key, value in obs.items():
    if config.zero_step_losses.get(key):
      zero_step_obs[key] = value
    if config.overshooting_losses.get(key):
      overshooting_obs[key] = value
  assert config.overshooting <= config.batch_shape[1]
  # One more than `config.overshooting` is requested; the slices below use
  # index 0 of axis 2 for the zero-step terms and 1:-1 for overshooting.
  target, prior, posterior, mask = tools.overshooting(
      cell, overshooting_obs, embedded, prev_action, data['length'],
      config.overshooting + 1)
  losses = []

  # Zero step losses.
  _, zs_prior, zs_posterior, zs_mask = tools.nested.map(
      lambda tensor: tensor[:, :, :1], (target, prior, posterior, mask))
  # Targets are taken from the observations with a new axis at position 2.
  zs_target = {key: value[:, :, None] for key, value in zero_step_obs.items()}
  zero_step_losses = utility.compute_losses(
      config.zero_step_losses, cell, heads, step, zs_target, zs_prior,
      zs_posterior, zs_mask, config.free_nats, debug=config.debug)
  losses += [
      loss * config.zero_step_losses[name] for name, loss in
      zero_step_losses.items()]
  # Placeholder added after weighting; it is only read by the printed
  # metrics below.
  if 'divergence' not in zero_step_losses:
    zero_step_losses['divergence'] = tf.zeros((), dtype=tf.float32)

  # Overshooting losses.
  if config.overshooting > 1:
    os_target, os_prior, os_posterior, os_mask = tools.nested.map(
        lambda tensor: tensor[:, :, 1:-1], (target, prior, posterior, mask))
    if config.stop_os_posterior_gradient:
      os_posterior = tools.nested.map(tf.stop_gradient, os_posterior)
    overshooting_losses = utility.compute_losses(
        config.overshooting_losses, cell, heads, step, os_target, os_prior,
        os_posterior, os_mask, config.free_nats, debug=config.debug)
    losses += [
        loss * config.overshooting_losses[name] for name, loss in
        overshooting_losses.items()]
  else:
    overshooting_losses = {}
  if 'divergence' not in overshooting_losses:
    overshooting_losses['divergence'] = tf.zeros((), dtype=tf.float32)

  # Workaround for TensorFlow deadlock bug.
  loss = sum(losses)
  # Outside the 'train' phase the optimizers receive zero times a dummy
  # variable instead of the real loss.
  train_loss = tf.cond(
      tf.equal(phase, 'train'),
      lambda: loss,
      lambda: 0 * tf.get_variable('dummy_loss', (), tf.float32))
  train_summary = utility.apply_optimizers(
      train_loss, step, should_summarize, config.optimizers)
  # train_summary = tf.cond(
  #     tf.equal(phase, 'train'),
  #     lambda: utility.apply_optimizers(
  #         loss, step, should_summarize, config.optimizers),
  #     str, name='optimizers')

  # Active data collection.
  collect_summaries = []
  # Snapshot of every local defined so far, handed to
  # utility.simulate_episodes. NOTE(review): that helper may look locals up
  # by name, so renaming variables above could break it -- confirm.
  graph = tools.AttrDict(locals())
  with tf.variable_scope('collection'):
    should_collects = []
    for name, params in config.sim_collects.items():
      after, every = params.steps_after, params.steps_every
      should_collect = tf.logical_and(
          tf.equal(phase, 'train'),
          tools.schedule.binary(step, config.batch_shape[0], after, every))
      # The second output of the collection branch is discarded.
      collect_summary, _ = tf.cond(
          should_collect,
          functools.partial(
              utility.simulate_episodes, config, params, graph, name),
          lambda: (tf.constant(''), tf.constant(0.0)),
          name='should_collect_' + params.task.name)
      should_collects.append(should_collect)
      collect_summaries.append(collect_summary)

  # Compute summaries.
  # Second snapshot; it additionally contains the collection variables.
  graph = tools.AttrDict(locals())
  # Summaries are computed only after the collection ops have run.
  with tf.control_dependencies(collect_summaries):
    summaries, score = tf.cond(
        should_summarize,
        lambda: define_summaries.define_summaries(graph, config),
        lambda: (tf.constant(''), tf.zeros((0,), tf.float32)),
        name='summaries')
  with tf.device('/cpu:0'):
    summaries = tf.summary.merge([summaries, train_summary])
    # summaries = tf.summary.merge([summaries, train_summary] + collect_summaries)
    # Masked sum of the zero-step posterior entropy divided by the number
    # of unmasked entries.
    zs_entropy = (tf.reduce_sum(tools.mask(
        cell.dist_from_state(zs_posterior, zs_mask).entropy(), zs_mask)) /
        tf.reduce_sum(tf.to_float(zs_mask)))
    dependencies.append(utility.print_metrics((
        ('score', score),
        ('loss', loss),
        ('zs_entropy', zs_entropy),
        ('zs_divergence', zero_step_losses['divergence']),
    ), step, config.mean_metrics_every))
  # Fetching the returned score forces the metric printing op to run.
  with tf.control_dependencies(dependencies):
    score = tf.identity(score)
  return score, summaries
예제 #9
0
def compute_objectives(posterior, prior, target, graph, config):
    """Build the training objectives enabled in ``config.loss_scales``.

    Every entry of ``config.loss_scales`` whose scale is non-zero yields one
    ``Objective``. Dedicated branches handle ``'divergence'``,
    ``'latent_prior'``, ``'embedding_l2'``, ``'overshooting'``, ``'cpc'`` and
    ``'inverse_model'``; any other name is looked up in ``graph.heads`` and
    scored by the log-probability that head assigns to ``target[name]``.

    Args:
      posterior: Posterior state(s) handed to ``graph.cell``; the
        ``'latent_prior'`` and ``'cpc'`` branches index it with ``'sample'``.
      prior: Prior state(s); only read by the ``'divergence'`` objective.
      target: Mapping from head name to the value scored by that head.
      graph: Object exposing ``cell``, ``heads``, ``data`` and ``embedded``.
      config: Configuration object providing ``loss_scales``, ``heads``,
        ``gradient_heads``, ``free_nats`` and the per-objective options read
        in the branches below.

    Returns:
      A pair ``(objectives, cpc_logs)``. ``objectives`` is a list of
      ``Objective`` tuples whose ``value`` has been reduced to its mean.
      ``cpc_logs`` is a dict of auxiliary tensors filled in by the ``'cpc'``
      and ``'inverse_model'`` branches; it is empty when neither is enabled.
    """
    raw_features = graph.cell.features_from_state(posterior)
    heads = graph.heads
    objectives = []
    cpc_logs = {}

    def apply_free_nats(divergence):
        # Subtract the configured allowance and clip at zero; leaves the
        # divergence untouched when ``config.free_nats`` is None.
        if config.free_nats is None:
            return divergence
        return tf.maximum(0.0, divergence - float(config.free_nats))

    for name, scale in config.loss_scales.items():
        if scale == 0.0:
            continue
        if name in config.heads and name not in config.gradient_heads:
            # Block gradients into the shared features and pass a regex
            # naming this head's scope as ``include``.
            features = tf.stop_gradient(raw_features)
            include = r'.*/head_{}/.*'.format(name)
        else:
            features = raw_features
            include = r'.*'
        exclude = None

        if name == 'divergence':
            loss = apply_free_nats(
                graph.cell.divergence_from_states(posterior, prior))
            objectives.append(
                Objective('divergence', loss, min, include, exclude))

        elif name == 'latent_prior':
            # Step every flattened posterior state once through the cell
            # under ``num_actions`` uniformly random actions, feeding zero
            # observations together with an all-False ``use_obs`` flag.
            num_actions = 10
            prev_states_flattened = tools.nested.map(
                lambda x: tf.reshape(x, (-1, x.shape[-1].value)), posterior)
            prev_states = tools.nested.map(
                lambda x: tf.tile(x, multiples=(num_actions, 1)),
                prev_states_flattened)
            batch_size = prev_states['sample'].shape[0].value
            prev_action = tf.random.uniform(
                (batch_size, graph.data['action'].shape[-1].value),
                minval=-1,
                maxval=1)
            obs = tf.zeros(
                shape=[batch_size] + graph.embedded.shape[2:].as_list())
            use_obs = tf.zeros((batch_size, 1), tf.bool)
            (next_states, _), _ = graph.cell((obs, prev_action, use_obs),
                                             prev_states)
            if config.latent_prior_marginal:
                # tf.tile stacks ``num_actions`` whole copies of the
                # flattened batch along axis 0, so row ``k`` belongs to state
                # ``k % num_states``. The copy index therefore has to be the
                # leading axis of the reshape; averaging over it gives, for
                # each state, the mean next-state sample across its random
                # actions. (Reshaping to (num_states, num_actions, -1) would
                # instead average over unrelated states.)
                num_states = batch_size // num_actions
                samples_next_state = tf.reshape(
                    next_states['sample'],
                    shape=(num_actions, num_states, -1))
                samples_next_state_mean = tf.reduce_mean(
                    samples_next_state, axis=0)
                samples_current_state = tf.stop_gradient(
                    prev_states_flattened['sample'])
                loss = tf.reduce_mean(
                    tf.reduce_sum(
                        tf.square(samples_next_state_mean -
                                  samples_current_state),
                        axis=-1))
            else:
                loss = graph.cell.divergence_from_states(
                    prev_states, next_states)
            objectives.append(
                Objective('latent_prior', loss, min, include, exclude))

        elif name == 'embedding_l2':
            # Mean squared L2 norm of the embeddings along the last axis.
            loss = tf.reduce_mean(
                tf.reduce_sum(tf.square(graph.embedded), axis=-1))
            objectives.append(
                Objective('embedding_l2', loss, min, include, exclude))

        elif name == 'overshooting':
            shape = tools.shape(graph.data['action'])
            # Every sequence in the batch is given the full length shape[1].
            length = tf.tile(tf.constant(shape[1])[None], [shape[0]])
            _, priors, posteriors, mask = tools.overshooting(
                graph.cell, {}, graph.embedded, graph.data['action'], length,
                config.overshooting_distance, posterior)
            # Drop the first and last entry along the third axis.
            posteriors, priors, mask = tools.nested.map(
                lambda x: x[:, :, 1:-1], (posteriors, priors, mask))
            if config.os_stop_posterior_grad:
                posteriors = tools.nested.map(tf.stop_gradient, posteriors)
            loss = apply_free_nats(
                graph.cell.divergence_from_states(posteriors, priors))
            objectives.append(
                Objective('overshooting', loss, min, include, exclude))

        elif name == 'cpc':
            cpc_input = (
                features if config.include_belief else posterior['sample'])
            loss, acc, reward_loss, reward_acc, gpenalty, kernels = (
                networks.cpc(
                    cpc_input, graph, posterior,
                    predict_terms=config.future,
                    negative_samples=config.negatives,
                    hard_negative_samples=config.hard_negatives,
                    stack_actions=config.stack_actions,
                    negative_actions=config.negative_actions,
                    cpc_openloop=config.cpc_openloop,
                    gradient_penalty=config.cpc_gpenalty_scale > 0,
                    gpenalty_mode=config.gpenalty_mode))
            loss += reward_loss * config.cpc_reward_scale
            loss += gpenalty * config.cpc_gpenalty_scale
            objectives.append(Objective('cpc', loss, min, include, exclude))
            cpc_logs['acc'] = acc
            cpc_logs['reward_acc'] = reward_acc
            cpc_logs['gpenalty'] = gpenalty
            if kernels:
                # Log the mean squared entry of each of the first
                # ``config.future`` kernels.
                for i in range(config.future):
                    cpc_logs['W_mag%d' % i] = tf.reduce_mean(
                        tf.square(kernels[i]))

        elif name == 'inverse_model':
            loss, acc = networks.inverse_model(
                features,
                graph,
                contrastive=config.action_contrastive,
                negative_samples=config.negatives)
            objectives.append(
                Objective('inverse_model', loss, min, include, exclude))
            if config.action_contrastive:
                cpc_logs['inverse_model_acc'] = acc

        else:
            # Generic head: score the target under the head's distribution.
            logprob = heads[name](features).log_prob(target[name])
            objectives.append(Objective(name, logprob, max, include, exclude))

    # Collapse every objective value to a scalar mean.
    objectives = [
        o._replace(value=tf.reduce_mean(o.value)) for o in objectives
    ]
    return objectives, cpc_logs
예제 #10
0
def compute_objectives(posterior, prior, target, graph, config):
    """Build ensemble-averaged training objectives from ``config.loss_scales``.

    ``posterior``, ``prior`` and ``graph.cell`` are indexed per ensemble
    member (``0 .. len(posterior) - 1``). Every entry of
    ``config.loss_scales`` with a non-zero scale yields one ``Objective``:

    * ``'divergence'``: the average over members of
      ``cell.divergence_from_states(posterior, prior)``, optionally reduced
      by ``config.free_nats`` and clipped at zero.
    * any other name: the average over members of the log-probability that
      ``graph.heads[name]`` assigns to the rows of ``target[name]`` selected
      by ``graph.sample_with_replacement[member, :]``.

    Args:
      posterior: Sequence of per-member posterior states.
      prior: Sequence of per-member prior states.
      target: Mapping from head name to the value scored by that head.
      graph: Object exposing ``cell`` (indexable per member), ``heads`` and
        ``sample_with_replacement``.
      config: Configuration object providing ``loss_scales``, ``heads``,
        ``gradient_heads`` and ``free_nats``.

    Returns:
      List of ``Objective`` tuples whose ``value`` has been reduced to its
      mean.

    Raises:
      NotImplementedError: If ``'overshooting'`` is enabled. That objective
        was never adapted to ensembles; the previous guard was an ``assert``
        that disappears under ``python -O``.
    """
    heads = graph.heads
    num_models = len(posterior)
    objectives = []

    for name, scale in config.loss_scales.items():
        if scale == 0.0:
            continue
        if name == 'overshooting':
            raise NotImplementedError(
                "The 'overshooting' objective has not been adapted to "
                "ensembles.")

        # Heads listed in ``config.heads`` but not in
        # ``config.gradient_heads`` see gradient-stopped features and pass a
        # regex naming their own scope as ``include``.
        detach = name in config.heads and name not in config.gradient_heads
        features = []
        for mdl in range(num_models):
            raw_features = graph.cell[mdl].features_from_state(posterior[mdl])
            features.append(
                tf.stop_gradient(raw_features) if detach else raw_features)
        include = r'.*/head_{}/.*'.format(name) if detach else r'.*'
        exclude = None

        if name == 'divergence':
            loss = graph.cell[0].divergence_from_states(posterior[0], prior[0])
            for mdl in range(1, num_models):
                loss = tf.math.add(
                    loss, graph.cell[mdl].divergence_from_states(
                        posterior[mdl], prior[mdl]))
            # Float literal keeps the average correct even where ``/`` is
            # integer division.
            loss = tf.math.scalar_mul(1.0 / num_models, loss)
            if config.free_nats is not None:
                loss = tf.maximum(0.0, loss - float(config.free_nats))
            objectives.append(
                Objective('divergence', loss, min, include, exclude))

        else:
            # Each member scores its own resampled view of the target.
            bootstrap_target = tf.gather(
                target[name], graph.sample_with_replacement[0, :], axis=0)
            logprob = heads[name](features[0]).log_prob(bootstrap_target)
            for mdl in range(1, num_models):
                bootstrap_target = tf.gather(
                    target[name],
                    graph.sample_with_replacement[mdl, :],
                    axis=0)
                logprob = tf.math.add(
                    logprob,
                    heads[name](features[mdl]).log_prob(bootstrap_target))
            logprob = tf.math.scalar_mul(1.0 / num_models, logprob)
            objectives.append(Objective(name, logprob, max, include, exclude))

    # Collapse every objective value to a scalar mean.
    objectives = [
        o._replace(value=tf.reduce_mean(o.value)) for o in objectives
    ]
    return objectives