Example #1
def define_model(data, trainer, config):
    tf.logging.info('Build TensorFlow compute graph.')
    dependencies = []
    cleanups = []
    step = trainer.step
    global_step = trainer.global_step
    phase = trainer.phase

    # Instantiate network blocks.
    cell = config.cell()
    kwargs = dict(create_scope_now_=True)
    encoder = tf.make_template('encoder', config.encoder, **kwargs)
    heads = tools.AttrDict(_unlocked=True)
    dummy_features = cell.features_from_state(cell.zero_state(1, tf.float32))
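    # Calling each head once on these dummy features (below) builds its
    # variables eagerly.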
    for key, head in config.heads.items():
        name = 'head_{}'.format(key)
        kwargs = dict(create_scope_now_=True)
        if key in data:
            kwargs['data_shape'] = data[key].shape[2:].as_list()
        elif key == 'action_target':
            kwargs['data_shape'] = data['action'].shape[2:].as_list()
        heads[key] = tf.make_template(name, head, **kwargs)
        heads[key](dummy_features)  # Initialize weights.

    # Apply and optimize model.
    embedded = encoder(data)
    with tf.control_dependencies(dependencies):
        embedded = tf.identity(embedded)
    graph = tools.AttrDict(locals())
    prior, posterior = tools.unroll.closed_loop(cell, embedded, data['action'],
                                                config.debug)
    objectives = utility.compute_objectives(posterior, prior, data, graph,
                                            config)
    summaries, grad_norms = utility.apply_optimizers(objectives, trainer,
                                                     config)

    # Active data collection.
    with tf.variable_scope('collection'):
        with tf.control_dependencies(summaries):  # Make sure to train first.
            for name, params in config.train_collects.items():
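                # tools.schedule.binary presumably yields a boolean that turns
                # collection on every `steps_every` steps, between `steps_after`
                # and `steps_until`.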
                schedule = tools.schedule.binary(step, config.batch_shape[0],
                                                 params.steps_after,
                                                 params.steps_every,
                                                 params.steps_until)
                summary, _ = tf.cond(
                    tf.logical_and(tf.equal(trainer.phase, 'train'), schedule),
                    functools.partial(utility.simulate_episodes,
                                      config,
                                      params,
                                      graph,
                                      cleanups,
                                      expensive_summaries=False,
                                      gif_summary=False,
                                      name=name),
                    lambda: (tf.constant(''), tf.constant(0.0)),
                    name='should_collect_' + name)
                summaries.append(summary)

    # Compute summaries.
    graph = tools.AttrDict(locals())
    summaries, score, prediction, truth = define_summaries.define_summaries(
        graph, config, cleanups)
    # tf.cond(
    #     trainer.log,
    #     lambda: define_summaries.define_summaries(graph, config, cleanups),
    #     lambda: (tf.constant(''), tf.zeros((0,), tf.float32),
    #              tf.zeros((8,), tf.float32)),
    #     name='summaries')
    summaries = tf.summary.merge([summaries, summary])
    dependencies.append(
        utility.print_metrics({ob.name: ob.value
                               for ob in objectives}, step,
                              config.print_metrics_every, 'objectives'))
    dependencies.append(
        utility.print_metrics(grad_norms, step, config.print_metrics_every,
                              'grad_norms'))
    with tf.control_dependencies(dependencies):
        score = tf.identity(score)
    return score, summaries, cleanups, prediction, truth
Example #2
def define_model(data, trainer, config):
    tf.logging.info('Build TensorFlow compute graph.')
    dependencies = []
    cleanups = []
    step = trainer.step
    global_step = trainer.global_step
    phase = trainer.phase

    # Disagreement additions: build an ensemble of dynamics models
    # (one cell per ensemble member).

    cell = []
    for mdl in range(config.num_models):
        with tf.variable_scope('model_no' + str(mdl)):
            cell.append(config.cell())
            kwargs = dict(create_scope_now_=True)

    encoder = tf.make_template('encoder', config.encoder, **kwargs)
    heads = tools.AttrDict(_unlocked=True)
    dummy_features = cell[0].features_from_state(cell[0].zero_state(
        1, tf.float32))

    for key, head in config.heads.items():
        name = 'head_{}'.format(key)
        kwargs = dict(create_scope_now_=True)
        if key in data:
            kwargs['data_shape'] = data[key].shape[2:].as_list()
        elif key == 'action_target':
            kwargs['data_shape'] = data['action'].shape[2:].as_list()
        heads[key] = tf.make_template(name, head, **kwargs)
        heads[key](dummy_features)  # Initialize weights.

    embedded = encoder(data)
    with tf.control_dependencies(dependencies):
        embedded = tf.identity(embedded)

    graph = tools.AttrDict(locals())
    posterior = []
    prior = []

    bagging_size = int(config.batch_shape[0])
    sample_with_replacement = tf.random.uniform(
        [config.num_models, bagging_size],
        minval=0,
        maxval=config.batch_shape[0],
        dtype=tf.int32)
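    # Bootstrap (bagging): each ensemble member indexes the batch with its own
    # with-replacement sample of indices, so the models train on different data.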

    for mdl in range(config.num_models):
        with tf.variable_scope('model_no' + str(mdl)):
            bootstrap_action_data = tf.gather(data['action'],
                                              sample_with_replacement[mdl, :],
                                              axis=0)
            bootstrap_embedded = tf.gather(embedded,
                                           sample_with_replacement[mdl, :],
                                           axis=0)
            tmp_prior, tmp_posterior = tools.unroll.closed_loop(
                cell[mdl], bootstrap_embedded, bootstrap_action_data,
                config.debug)
            prior.append(tmp_prior)
            posterior.append(tmp_posterior)

    graph = tools.AttrDict(locals())
    objectives = utility.compute_objectives(posterior, prior, data, graph,
                                            config)

    summaries, grad_norms = utility.apply_optimizers(objectives, trainer,
                                                     config)

    graph = tools.AttrDict(locals())
    # Active data collection.
    with tf.variable_scope('collection'):
        with tf.control_dependencies(summaries):  # Make sure to train first.
            for name, params in config.train_collects.items():
                schedule = tools.schedule.binary(step, config.batch_shape[0],
                                                 params.steps_after,
                                                 params.steps_every,
                                                 params.steps_until)
                summary, _ = tf.cond(
                    tf.logical_and(tf.equal(trainer.phase, 'train'), schedule),
                    functools.partial(utility.simulate_episodes,
                                      config,
                                      params,
                                      graph,
                                      cleanups,
                                      expensive_summaries=False,
                                      gif_summary=False,
                                      name=name),
                    lambda: (tf.constant(''), tf.constant(0.0)),
                    name='should_collect_' + name)
                summaries.append(summary)
    # Compute summaries.
    graph = tools.AttrDict(locals())
    # TODO: Determine if a summary from one model is enough.
    summary, score = tf.cond(
        trainer.log,
        lambda: define_summaries.define_summaries(graph, config, cleanups),
        lambda: (tf.constant(''), tf.zeros((0, ), tf.float32)),
        name='summaries')
    summaries = tf.summary.merge([summaries, summary])
    # TODO: Determine if printing objectives and grad norms from only one model is enough.
    # Objectives
    dependencies.append(
        utility.print_metrics({ob.name: ob.value
                               for ob in objectives}, step,
                              config.print_metrics_every, 'objectives'))
    dependencies.append(
        utility.print_metrics(grad_norms, step, config.print_metrics_every,
                              'grad_norms'))
    with tf.control_dependencies(dependencies):
        score = tf.identity(score)
    return score, summaries, cleanups
Example #3
def define_model(data, trainer, config):
    tf.logging.info('Build TensorFlow compute graph.')
    dependencies = []
    step = trainer.step
    global_step = trainer.global_step  # tf.train.get_or_create_global_step()
    phase = trainer.phase
    should_summarize = trainer.log

    num_gpu = NUM_GPU  # NUM_GPU is assumed to be a module-level constant (number of GPUs).

    #  for multi-gpu
    if num_gpu > 1:
        var_for_trainop = {}
        grads_dict = {}

        # data split for multi-gpu
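        # Each batch tensor is split along the batch axis into one shard per GPU.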
        data_dict = {}
        for loss_head, optimizer_cls in config.optimizers.items():
            grads_dict[loss_head] = []
            var_for_trainop[loss_head] = []

        for gpu_i in range(num_gpu):
            data_dict[gpu_i] = {}

        for data_item in list(data.keys()):
            data_split = tf.split(data[data_item], num_gpu)
            for gpu_j in range(num_gpu):
                data_dict[gpu_j][data_item] = data_split[gpu_j]

    for gpu_k in range(num_gpu):
        with tf.device('/gpu:%s' % gpu_k):
            scope_name = r'.+shared_vars'
            with tf.name_scope('%s_%d' % ("GPU", gpu_k)):
                with tf.variable_scope(name_or_scope='shared_vars',
                                       reuse=tf.AUTO_REUSE):

                    #  for multi-gpu
                    if num_gpu > 1:
                        data = data_dict[gpu_k]

                    # Preprocess data.
                    # with tf.device('/cpu:0'):
                    if config.dynamic_action_noise:
                        data['action'] += tf.random_normal(
                            tf.shape(data['action']), 0.0,
                            config.dynamic_action_noise)
                    prev_action = tf.concat(
                        [0 * data['action'][:, :1], data['action'][:, :-1]],
                        1)  # i.e.: (0 * a1, a1, a2, ..., a49)
                    obs = data.copy()
                    del obs['length']

                    # Instantiate network blocks.
                    cell = config.cell()
                    kwargs = dict()
                    encoder = tf.make_template('encoder',
                                               config.encoder,
                                               create_scope_now_=True,
                                               **kwargs)
                    heads = {}
                    # heads: networks for 'image', 'reward', 'state'.
                    for key, head in config.heads.items():
                        name = 'head_{}'.format(key)
                        kwargs = dict(data_shape=obs[key].shape[2:].as_list())
                        heads[key] = tf.make_template(name,
                                                      head,
                                                      create_scope_now_=True,
                                                      **kwargs)

                    # Embed observations and unroll model.
                    embedded = encoder(obs)  # encode obs['image']
                    # Separate overshooting and zero step observations because computing
                    # overshooting targets for images would be expensive.
                    zero_step_obs = {}
                    overshooting_obs = {}
                    for key, value in obs.items():
                        if config.zero_step_losses.get(key):
                            zero_step_obs[key] = value
                        if config.overshooting_losses.get(key):
                            overshooting_obs[key] = value
                    assert config.overshooting <= config.batch_shape[1]
                    # Shapes (e.g. batch=40, length=50, overshooting+1=51):
                    #   prior/posterior: {'mean': (40, 50, 51, 30), ...}
                    #   target: {'reward': (40, 50, 51), ...}; mask: (40, 50, 51)
                    target, prior, posterior, mask = tools.overshooting(
                        cell, overshooting_obs, embedded, prev_action,
                        data['length'], config.overshooting + 1)
                    losses = []

                    # Zero step losses.
                    _, zs_prior, zs_posterior, zs_mask = tools.nested.map(
                        lambda tensor: tensor[:, :, :1],
                        (target, prior, posterior, mask))
                    zs_target = {
                        key: value[:, :, None]
                        for key, value in zero_step_obs.items()
                    }
                    zero_step_losses = utility.compute_losses(
                        config.zero_step_losses,
                        cell,
                        heads,
                        step,
                        zs_target,
                        zs_prior,
                        zs_posterior,
                        zs_mask,
                        config.free_nats,
                        debug=config.debug)
                    losses += [
                        loss * config.zero_step_losses[name]
                        for name, loss in zero_step_losses.items()
                    ]
                    if 'divergence' not in zero_step_losses:
                        zero_step_losses['divergence'] = tf.zeros(
                            (), dtype=tf.float32)

                    # Overshooting losses.
                    if config.overshooting > 1:
                        os_target, os_prior, os_posterior, os_mask = tools.nested.map(
                            lambda tensor: tensor[:, :, 1:-1],
                            (target, prior, posterior, mask))
                        if config.stop_os_posterior_gradient:
                            os_posterior = tools.nested.map(
                                tf.stop_gradient, os_posterior)
                        overshooting_losses = utility.compute_losses(
                            config.overshooting_losses,
                            cell,
                            heads,
                            step,
                            os_target,
                            os_prior,
                            os_posterior,
                            os_mask,
                            config.free_nats,
                            debug=config.debug)
                        losses += [
                            loss * config.overshooting_losses[name]
                            for name, loss in overshooting_losses.items()
                        ]
                    else:
                        overshooting_losses = {}
                    if 'divergence' not in overshooting_losses:
                        overshooting_losses['divergence'] = tf.zeros(
                            (), dtype=tf.float32)

                    # Workaround for TensorFlow deadlock bug.
                    loss = sum(losses)
                    train_loss = tf.cond(
                        tf.equal(phase, 'train'), lambda: loss,
                        lambda: 0 * tf.get_variable('dummy_loss',
                                                    (), tf.float32))

                    #  for multi-gpu
                    if num_gpu == 1:
                        train_summary = utility.apply_optimizers(
                            train_loss, step, should_summarize,
                            config.optimizers)
                    else:
                        training_grad_dict = utility.get_grads(
                            train_loss,
                            step,
                            should_summarize,
                            config.optimizers,
                            include_var=(scope_name, ))
                        for a in grads_dict.keys():
                            grads_dict[a].append(training_grad_dict[a]["grad"])
                            if gpu_k == 0:
                                var_for_trainop[a].append(
                                    training_grad_dict[a]["var"])
                        # train_summary = tf.cond(
                        #     tf.equal(phase, 'train'),
                        #     lambda: utility.apply_optimizers(
                        #         loss, step, should_summarize, config.optimizers),
                        #     str, name='optimizers')

    #  for multi-gpu
    if num_gpu > 1:
        averaged_gradients = {}
        with tf.device('/cpu:0'):
            for a in grads_dict.keys():
                averaged_gradients[a] = average_gradients(grads_dict[a])
            train_summary = utility.apply_grads(averaged_gradients,
                                                var_for_trainop, step,
                                                should_summarize,
                                                config.optimizers)

    # Active data collection.
    collect_summaries = []
    graph = tools.AttrDict(locals())
    with tf.variable_scope('collection'):
        should_collects = []
        for name, params in config.sim_collects.items():
            after, every = params.steps_after, params.steps_every
            should_collect = tf.logical_and(
                tf.equal(phase, 'train'),
                tools.schedule.binary(step, config.batch_shape[0], after,
                                      every))
            collect_summary, score_train = tf.cond(
                should_collect,
                functools.partial(utility.simulate_episodes, config, params,
                                  graph, name),
                lambda: (tf.constant(''), tf.constant(0.0)),
                name='should_collect_' + params.task.name)
            should_collects.append(should_collect)
            collect_summaries.append(collect_summary)

    # Compute summaries.
    graph = tools.AttrDict(locals())
    with tf.control_dependencies(collect_summaries):
        summaries, score = tf.cond(
            should_summarize,
            lambda: define_summaries.define_summaries(graph, config),
            lambda: (tf.constant(''), tf.zeros((0, ), tf.float32)),
            name='summaries')
    with tf.device('/cpu:0'):
        summaries = tf.summary.merge([summaries, train_summary])
        # summaries = tf.summary.merge([summaries, train_summary] + collect_summaries)
        zs_entropy = (tf.reduce_sum(
            tools.mask(
                cell.dist_from_state(zs_posterior, zs_mask).entropy(),
                zs_mask)) / tf.reduce_sum(tf.to_float(zs_mask)))
        dependencies.append(
            utility.print_metrics((
                ('score', score_train),
                ('loss', loss),
                ('zs_entropy', zs_entropy),
                ('zs_divergence', zero_step_losses['divergence']),
            ), step, config.mean_metrics_every))
    with tf.control_dependencies(dependencies):
        score = tf.identity(score)
    return score, summaries
Example #4
def define_model(data, trainer, config):
  tf.logging.info('Build TensorFlow compute graph.')
  dependencies = []
  step = trainer.step
  global_step = trainer.global_step
  phase = trainer.phase
  should_summarize = trainer.log

  # Preprocess data.
  with tf.device('/cpu:0'):
    if config.dynamic_action_noise:
      data['action'] += tf.random_normal(
          tf.shape(data['action']), 0.0, config.dynamic_action_noise)
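    # prev_action[:, t] equals action[:, t-1], with a zeroed action at t = 0
    # (i.e. 0 * a1, a1, a2, ..., a_{T-1}).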
    prev_action = tf.concat(
        [0 * data['action'][:, :1], data['action'][:, :-1]], 1)
    obs = data.copy()
    del obs['length']

  # Instantiate network blocks.
  cell = config.cell()
  kwargs = dict()
  encoder = tf.make_template(
      'encoder', config.encoder, create_scope_now_=True, **kwargs)
  heads = {}
  for key, head in config.heads.items():
    name = 'head_{}'.format(key)
    kwargs = dict(data_shape=obs[key].shape[2:].as_list())
    heads[key] = tf.make_template(name, head, create_scope_now_=True, **kwargs)

  # Embed observations and unroll model.
  embedded = encoder(obs)
  # Separate overshooting and zero step observations because computing
  # overshooting targets for images would be expensive.
  zero_step_obs = {}
  overshooting_obs = {}
  for key, value in obs.items():
    if config.zero_step_losses.get(key):
      zero_step_obs[key] = value
    if config.overshooting_losses.get(key):
      overshooting_obs[key] = value
  assert config.overshooting <= config.batch_shape[1]
  target, prior, posterior, mask = tools.overshooting(
      cell, overshooting_obs, embedded, prev_action, data['length'],
      config.overshooting + 1)
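  # As the shape comments in Example #3 suggest, these tensors carry an extra
  # time axis of length config.overshooting + 1 holding multi-step predictions.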
  losses = []

  # Zero step losses.
  _, zs_prior, zs_posterior, zs_mask = tools.nested.map(
      lambda tensor: tensor[:, :, :1], (target, prior, posterior, mask))
  zs_target = {key: value[:, :, None] for key, value in zero_step_obs.items()}
  zero_step_losses = utility.compute_losses(
      config.zero_step_losses, cell, heads, step, zs_target, zs_prior,
      zs_posterior, zs_mask, config.free_nats, debug=config.debug)
  losses += [
      loss * config.zero_step_losses[name] for name, loss in
      zero_step_losses.items()]
  if 'divergence' not in zero_step_losses:
    zero_step_losses['divergence'] = tf.zeros((), dtype=tf.float32)

  # Overshooting losses.
  if config.overshooting > 1:
    os_target, os_prior, os_posterior, os_mask = tools.nested.map(
        lambda tensor: tensor[:, :, 1:-1], (target, prior, posterior, mask))
    if config.stop_os_posterior_gradient:
      os_posterior = tools.nested.map(tf.stop_gradient, os_posterior)
    overshooting_losses = utility.compute_losses(
        config.overshooting_losses, cell, heads, step, os_target, os_prior,
        os_posterior, os_mask, config.free_nats, debug=config.debug)
    losses += [
        loss * config.overshooting_losses[name] for name, loss in
        overshooting_losses.items()]
  else:
    overshooting_losses = {}
  if 'divergence' not in overshooting_losses:
    overshooting_losses['divergence'] = tf.zeros((), dtype=tf.float32)

  # Workaround for TensorFlow deadlock bug.
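  # Presumably: the optimizer op is built unconditionally, but outside the
  # 'train' phase the loss becomes 0 * a dummy variable, so updates are no-ops.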
  loss = sum(losses)
  train_loss = tf.cond(
      tf.equal(phase, 'train'),
      lambda: loss,
      lambda: 0 * tf.get_variable('dummy_loss', (), tf.float32))
  train_summary = utility.apply_optimizers(
      train_loss, step, should_summarize, config.optimizers)
  # train_summary = tf.cond(
  #     tf.equal(phase, 'train'),
  #     lambda: utility.apply_optimizers(
  #         loss, step, should_summarize, config.optimizers),
  #     str, name='optimizers')

  # Active data collection.
  collect_summaries = []
  graph = tools.AttrDict(locals())
  with tf.variable_scope('collection'):
    should_collects = []
    for name, params in config.sim_collects.items():
      after, every = params.steps_after, params.steps_every
      should_collect = tf.logical_and(
          tf.equal(phase, 'train'),
          tools.schedule.binary(step, config.batch_shape[0], after, every))
      collect_summary, _ = tf.cond(
          should_collect,
          functools.partial(
              utility.simulate_episodes, config, params, graph, name),
          lambda: (tf.constant(''), tf.constant(0.0)),
          name='should_collect_' + params.task.name)
      should_collects.append(should_collect)
      collect_summaries.append(collect_summary)

  # Compute summaries.
  graph = tools.AttrDict(locals())
  with tf.control_dependencies(collect_summaries):
    summaries, score = tf.cond(
        should_summarize,
        lambda: define_summaries.define_summaries(graph, config),
        lambda: (tf.constant(''), tf.zeros((0,), tf.float32)),
        name='summaries')
  with tf.device('/cpu:0'):
    summaries = tf.summary.merge([summaries, train_summary])
    # summaries = tf.summary.merge(
    #     [summaries, train_summary] + collect_summaries)
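    # Masked average entropy of the zero-step posterior over valid time steps.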
    zs_entropy = (tf.reduce_sum(tools.mask(
        cell.dist_from_state(zs_posterior, zs_mask).entropy(), zs_mask)) /
        tf.reduce_sum(tf.to_float(zs_mask)))
    dependencies.append(utility.print_metrics((
        ('score', score),
        ('loss', loss),
        ('zs_entropy', zs_entropy),
        ('zs_divergence', zero_step_losses['divergence']),
    ), step, config.mean_metrics_every))
  with tf.control_dependencies(dependencies):
    score = tf.identity(score)
  return score, summaries