def get_optimized_mols(model_dir, ckpt=80000):
  """Get optimized Molecules.

  Args:
    model_dir: String. model directory.
    ckpt: the checkpoint to load.

  Returns:
    List of 800 optimized molecules
  """
  hparams_file = os.path.join(model_dir, 'config.json')
  with gfile.Open(hparams_file, 'r') as f:
    hp_dict = json.load(f)
    hparams = deep_q_networks.get_hparams(**hp_dict)

  dqn = deep_q_networks.DeepQNetwork(
      input_shape=(hparams.batch_size, hparams.fingerprint_length + 1),
      q_fn=functools.partial(
          deep_q_networks.multi_layer_model, hparams=hparams),
      optimizer=hparams.optimizer,
      grad_clipping=hparams.grad_clipping,
      num_bootstrap_heads=hparams.num_bootstrap_heads,
      gamma=hparams.gamma,
      epsilon=0.0)

  tf.reset_default_graph()
  optimized_mol = []
  with tf.Session() as sess:
    dqn.build()
    model_saver = tf.train.Saver(max_to_keep=hparams.max_num_checkpoints)
    model_saver.restore(sess, os.path.join(model_dir, 'ckpt-%i' % ckpt))
    for mol in all_mols:  # all_mols is assumed to be defined at module level (the list of starting molecules)
      logging.info('Eval: %s', mol)
      environment = molecules_mdp.Molecule(
          atom_types=set(hparams.atom_types),
          init_mol=mol,
          allow_removal=hparams.allow_removal,
          allow_no_modification=hparams.allow_no_modification,
          allow_bonds_between_rings=hparams.allow_bonds_between_rings,
          allowed_ring_sizes=set(hparams.allowed_ring_sizes),
          max_steps=hparams.max_steps_per_episode,
          record_path=True)
      environment.initialize()
      if hparams.num_bootstrap_heads:
        head = np.random.randint(hparams.num_bootstrap_heads)
      else:
        head = 0
      for _ in range(hparams.max_steps_per_episode):
        steps_left = hparams.max_steps_per_episode - environment.num_steps_taken
        valid_actions = list(environment.get_valid_actions())
        observations = np.vstack([
            np.append(
                deep_q_networks.get_fingerprint(act, hparams), steps_left)
            for act in valid_actions
        ])
        action = valid_actions[dqn.get_action(
            observations, head=head, update_epsilon=0.0)]
        environment.step(action)
      optimized_mol.append(environment.get_path())
  return optimized_mol
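# A minimal usage sketch (not part of the original source; the model directory is
# a placeholder and the module-level all_mols list must already be populated):
if __name__ == '__main__':
  optimized = get_optimized_mols('/tmp/dqn_model_dir', ckpt=80000)
  for path in optimized[:3]:
    # record_path=True above makes each entry the sequence of molecules visited
    # during one episode; the last element is the optimized molecule.
    print(path[-1])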
Example #2
def main():
    tf.reset_default_graph()

    # the model graph must be rebuilt here before creating the Saver,
    # otherwise there are no variables to restore into
    saver = tf.train.Saver()

    with tf.Session() as sess:
        saver.restore(sess, "checkpoint/model.ckpt")
def run_training(hparams, environment, dqn):
    """Runs the training procedure.

    Briefly, the agent runs the action network to get an action to take in
    the environment. The state transition and reward are stored in the memory.
    Periodically the agent samples a batch of transitions from the memory to
    update (train) its Q network. Note that the Q network and the action
    network share the same set of parameters, so the action network is also
    updated by the sampled (state, action, next_state, reward) batches.

    Args:
      hparams: tf.HParams. The hyperparameters of the model.
      environment: molecules.Molecule. The environment to run on.
      dqn: An instance of the DeepQNetwork class.

    Returns:
      None
    """
    summary_writer = tf.summary.FileWriter(FLAGS.model_dir)
    tf.reset_default_graph()
    with tf.Session() as sess:
        dqn.build()
        model_saver = tf.train.Saver(max_to_keep=hparams.max_num_checkpoints)
        # The schedule for the epsilon in epsilon greedy policy.
        exploration = schedules.PiecewiseSchedule(
            [(0, 1.0), (int(hparams.num_episodes / 2), 0.1),
             (hparams.num_episodes, 0.01)],
            outside_value=0.01)
        if hparams.prioritized:
            memory = replay_buffer.PrioritizedReplayBuffer(
                hparams.replay_buffer_size, hparams.prioritized_alpha)
            beta_schedule = schedules.LinearSchedule(
                hparams.num_episodes,
                initial_p=hparams.prioritized_beta,
                final_p=0)
        else:
            memory = replay_buffer.ReplayBuffer(hparams.replay_buffer_size)
            beta_schedule = None
        sess.run(tf.global_variables_initializer())
        sess.run(dqn.update_op)
        global_step = 0
        for episode in range(hparams.num_episodes):
            global_step = _episode(environment=environment,
                                   dqn=dqn,
                                   memory=memory,
                                   episode=episode,
                                   global_step=global_step,
                                   hparams=hparams,
                                   summary_writer=summary_writer,
                                   exploration=exploration,
                                   beta_schedule=beta_schedule)
            if (episode + 1) % hparams.update_frequency == 0:
                sess.run(dqn.update_op)
            if (episode + 1) % hparams.save_frequency == 0:
                model_saver.save(sess,
                                 os.path.join(FLAGS.model_dir, 'ckpt'),
                                 global_step=global_step)
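# Illustrative sketch (not in the original code): the baselines-style
# schedules.PiecewiseSchedule built in run_training exposes a value(t) method
# that linearly interpolates between the listed endpoints, which is how the
# per-episode epsilon is typically obtained inside _episode.
def _example_epsilon_schedule(num_episodes=1000):
    exploration = schedules.PiecewiseSchedule(
        [(0, 1.0), (int(num_episodes / 2), 0.1), (num_episodes, 0.01)],
        outside_value=0.01)
    for episode in (0, num_episodes // 4, num_episodes // 2, num_episodes):
        # epsilon anneals from 1.0 to 0.1 over the first half of training and
        # down to 0.01 by the end.
        print(episode, exploration.value(episode))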
Example #4
    def train(self, session, dataset, train_dir):
        """
        Implement main training loop

        TIPS:
        You should also implement learning rate annealing (look into tf.train.exponential_decay)
        Considering the long time to train, you should save your model per epoch.

        A more ambitious approach could include implementing early stopping, or reloading
        previous models if they have higher performance than the current one

        As suggested in the document, you should evaluate your training progress by
        printing out information every fixed number of iterations.

        We recommend you evaluate your model performance on F1 and EM instead of just
        looking at the cost.

        :param session: it should be passed in from train.py
        :param dataset: a representation of our data, in some implementations, you can
                        pass in multiple components (arguments) of one dataset to this function
        :param train_dir: path to the directory where you should save the model checkpoint
        :return:
        """

        # some free code to print out the number of parameters in your model
        # it's always good to check!
        # you will also want to save your model parameters in train_dir
        # so that you can use your trained model to make predictions, or
        # even continue training

        
        tic = time.time()
        params = tf.trainable_variables()
        num_params = sum(map(lambda t: np.prod(tf.shape(t.value()).eval()), params))
        toc = time.time()
        logging.info("Number of params: %d (retreival took %f secs)" % (num_params, toc - tic))

        for e in range(self.FLAGS.epochs):
            for p, q, a in util.load_dataset("data/squad/train.ids.context", "data/squad/train.ids.question", "data/squad/train.span", self.FLAGS.batch_size, in_batches=True):

                a_s, a_e = self.one_hot_func(a)
                q, q_mask = self.mask_and_pad(q)
                p, p_mask = self.mask_and_pad(p)
                
                updates, loss = self.optimize(session, q, q_mask, p, p_mask, a_s, a_e)
                print(loss)
            # save the model after each epoch
            saver = tf.train.Saver()
            saver.save(session, train_dir + "/model.ckpt")


            val_loss = self.validate(p_val, q_val, a_val)

            self.evaluate_answer(session, p_val, q_val)
            self.evaluate_answer(session, q, p, sample = 100)
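# Sketch of the learning-rate annealing suggested in the TIPS docstring above
# (names and hyperparameters here are illustrative, not the original model's):
def _annealed_train_op(loss, start_lr=1e-3):
    # global_step is incremented once per optimizer step; the learning rate
    # decays by a factor of 0.96 every 1000 steps.
    global_step = tf.Variable(0, trainable=False, name="global_step")
    learning_rate = tf.train.exponential_decay(
        start_lr, global_step, decay_steps=1000, decay_rate=0.96, staircase=True)
    train_op = tf.train.AdamOptimizer(learning_rate).minimize(
        loss, global_step=global_step)
    return train_op, learning_rate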
def prediction():
    pred, _ = lstm(1)
    saver = tf.train.Saver(tf.global_variables())
    with tf.Session() as sess:
        # restore the model
        model_file = tf.train.latest_checkpoint(modelpath)
        saver.restore(sess, model_file)
        # take one sample as the test sample
        pre_seq = train_x[5000, :, :]
        predict = []
        # roll the window forward to generate the subsequent predictions
        for i in range(1000):
            next_seq = sess.run(pred, feed_dict={batch_x: pre_seq})
            predict.append(next_seq[-1])
            pre_seq = np.vstack((pre_seq[1:], next_seq[-1]))
        plt.figure()
        plt.plot(list(range(len(normalize_data))), normalize_data, color='b')
        plt.plot(list(range(len(normalize_data), len(normalize_data) + len(predict))), predict, color='r')
        plt.show()
Example #6
  def build_graph_from_config(self, model_config, checkpoint_path):
    """Builds the inference graph from a configuration object.

    Args:
      model_config: Object containing configuration for building the model.
      checkpoint_path: Checkpoint file or a directory containing a checkpoint
        file.

    Returns:
      restore_fn: A function such that restore_fn(sess) loads model variables
        from the checkpoint file.
    """
    tf.logging.info("Building model.")
    model = self.build_model(model_config)
    saver = model.saver
    if not saver:
      saver = tf.train.Saver()

    return self._create_restore_fn(checkpoint_path, saver)
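# Hedged usage sketch (not from the original file): how the returned restore_fn
# is typically consumed; 'inference_wrapper' stands for an instance of the class
# defining build_graph_from_config above.
def _restore_for_inference(inference_wrapper, model_config, checkpoint_path):
  restore_fn = inference_wrapper.build_graph_from_config(model_config,
                                                         checkpoint_path)
  sess = tf.Session()
  restore_fn(sess)  # loads model variables from the checkpoint
  return sess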
Example #7
def run_wall_clock_test(optimizer,
                        problem,
                        num_steps,
                        dataset=datasets.EMPTY_DATASET,
                        seed=None,
                        logdir=None,
                        batch_size=None):
    """Runs optimization with the given parameters and return average iter time.

  Args:
    optimizer: The tf.train.Optimizer instance
    problem: The problem to optimize (a problem_generator.Problem)
    num_steps: The number of steps to run optimization for
    dataset: The dataset to train the problem against
    seed: The seed used for drawing the initial parameters, or a list of
      numpy arrays used to explicitly initialize the parameters
    logdir: A directory containing model checkpoints. If given, then the
            parameters of the optimizer are loaded from the latest checkpoint
            in this folder.
    batch_size: The number of samples per batch.

  Returns:
    The median time in seconds for a single optimization iteration.
  """
    if dataset is None:
        dataset = datasets.EMPTY_DATASET
        batch_size = dataset.size
    else:
        # default batch size is the entire dataset
        batch_size = dataset.size if batch_size is None else batch_size

    # define the parameters of the optimization problem
    if isinstance(seed, (list, tuple)):
        # seed is a list of arrays
        params = problem_generator.init_fixed_variables(seed)
    else:
        # seed is an int or None
        params = problem.init_variables(seed)

    data_placeholder = tf.placeholder(tf.float32)
    labels_placeholder = tf.placeholder(tf.int32)

    obj = problem.objective(params, data_placeholder, labels_placeholder)
    gradients = problem.gradients(obj, params)
    vars_to_preinitialize = params

    with tf.Session(graph=tf.get_default_graph()) as sess:
        # initialize the parameter scope variables; necessary for apply_gradients
        sess.run(tf.variables_initializer(vars_to_preinitialize))
        train_op = optimizer.apply_gradients(zip(gradients, params))
        if isinstance(train_op, tuple) or isinstance(train_op, list):
            # LOL apply_gradients returns a tuple. Regular optimizers do not.
            train_op = train_op[0]
        vars_to_restore = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                            scope=OPTIMIZER_SCOPE)
        vars_to_initialize = list(
            set(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)) -
            set(vars_to_restore) - set(vars_to_preinitialize))
        # load or initialize optimizer variables
        if logdir is not None:
            restorer = tf.train.Saver(var_list=vars_to_restore)
            ckpt = tf.train.latest_checkpoint(logdir)
            restorer.restore(sess, ckpt)
        else:
            sess.run(tf.variables_initializer(vars_to_restore))
        # initialize all the other variables
        sess.run(tf.variables_initializer(vars_to_initialize))

        problem.init_fn(sess)

        # generate the minibatch indices
        batch_inds = dataset.batch_indices(num_steps, batch_size)

        avg_iter_time = []
        for batch in batch_inds:
            # data to feed in
            feed = {
                data_placeholder: dataset.data[batch],
                labels_placeholder: dataset.labels[batch]
            }

            # run the optimization train operation
            start = time.time()
            sess.run([train_op], feed_dict=feed)
            avg_iter_time.append(time.time() - start)

    return np.median(np.array(avg_iter_time))
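# Illustrative helper (not in the original code): times plain SGD on a given
# problem/dataset using run_wall_clock_test above. 'problem' and 'dataset' are
# assumed to come from the problem_generator / datasets modules.
def _time_sgd(problem, dataset, num_steps=100, batch_size=32):
    sgd = tf.train.GradientDescentOptimizer(learning_rate=0.1)
    median_step_time = run_wall_clock_test(
        sgd, problem, num_steps, dataset=dataset, batch_size=batch_size)
    print('median seconds per optimization step: %f' % median_step_time)
    return median_step_time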
Example #8
def test_optimizer(optimizer,
                   problem,
                   num_iter,
                   dataset=datasets.EMPTY_DATASET,
                   batch_size=None,
                   seed=None,
                   graph=None,
                   logdir=None,
                   record_every=None):
    """Tests an optimization algorithm on a given problem.

  Args:
    optimizer: Either a tf.train.Optimizer instance, or an Optimizer instance
               inheriting from trainable_optimizer.py
    problem: A Problem instance that defines an optimization problem to solve
    num_iter: The number of iterations of the optimizer to run
    dataset: The dataset to train the problem against
    batch_size: The number of samples per batch. If None (default), the
      batch size is set to the full batch (dataset.size)
    seed: A random seed used for drawing the initial parameters, or a list of
      numpy arrays used to explicitly initialize the parameters.
    graph: The tensorflow graph to execute (if None, uses the default graph)
    logdir: A directory containing model checkpoints. If given, then the
            parameters of the optimizer are loaded from the latest checkpoint
            in this folder.
    record_every: if an integer, stores the parameters, objective, and gradient
                  every record_every iterations. If None, nothing is stored

  Returns:
    objective_values: A list of the objective values during optimization
    parameters: The parameters obtained after training
    records: A dictionary containing lists of the parameters and gradients
             during optimization saved every record_every iterations (empty if
             record_every is set to None)
  """

    if dataset is None:
        dataset = datasets.EMPTY_DATASET
        batch_size = dataset.size
    else:
        # default batch size is the entire dataset
        batch_size = dataset.size if batch_size is None else batch_size

    graph = tf.get_default_graph() if graph is None else graph
    with graph.as_default():

        # define the parameters of the optimization problem
        if isinstance(seed, (list, tuple)):
            # seed is a list of arrays
            params = problem_generator.init_fixed_variables(seed)
        else:
            # seed is an int or None
            params = problem.init_variables(seed)

        data_placeholder = tf.placeholder(tf.float32)
        labels_placeholder = tf.placeholder(tf.int32)

        # get the problem objective and gradient(s)
        obj = problem.objective(params, data_placeholder, labels_placeholder)
        gradients = problem.gradients(obj, params)

        vars_to_preinitialize = params

    with tf.Session(graph=graph) as sess:
        # initialize the parameter scope variables; necessary for apply_gradients
        sess.run(tf.variables_initializer(vars_to_preinitialize))
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        # create the train operation and training variables
        try:
            train_op, real_params = optimizer.apply_gradients(
                zip(gradients, params))
            obj = problem.objective(real_params, data_placeholder,
                                    labels_placeholder)
        except TypeError:
            # If all goes well, this exception should only be thrown when we are using
            # a non-hrnn optimizer.
            train_op = optimizer.apply_gradients(zip(gradients, params))

        vars_to_restore = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                            scope=OPTIMIZER_SCOPE)
        vars_to_initialize = list(
            set(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)) -
            set(vars_to_restore) - set(vars_to_preinitialize))
        # load or initialize optimizer variables
        if logdir is not None:
            restorer = tf.train.Saver(var_list=vars_to_restore)
            ckpt = tf.train.latest_checkpoint(logdir)
            restorer.restore(sess, ckpt)
        else:
            sess.run(tf.variables_initializer(vars_to_restore))
        # initialize all the other variables
        sess.run(tf.variables_initializer(vars_to_initialize))

        problem.init_fn(sess)

        # generate the minibatch indices
        batch_inds = dataset.batch_indices(num_iter, batch_size)

        # run the train operation for n iterations and save the objectives
        records = defaultdict(list)
        objective_values = []
        for itr, batch in enumerate(batch_inds):

            # data to feed in
            feed = {
                data_placeholder: dataset.data[batch],
                labels_placeholder: dataset.labels[batch]
            }
            full_feed = {
                data_placeholder: dataset.data,
                labels_placeholder: dataset.labels
            }

            # record stuff
            if record_every is not None and (itr % record_every) == 0:

                def grad_value(g):
                    if isinstance(g, tf.IndexedSlices):
                        return g.values
                    else:
                        return g

                records_fetch = {}
                for p in params:
                    for key in optimizer.get_slot_names():
                        v = optimizer.get_slot(p, key)
                        records_fetch[p.name + "_" + key] = v
                gav_fetch = [(grad_value(g), v)
                             for g, v in zip(gradients, params)]

                _, gav_eval, records_eval = sess.run(
                    (obj, gav_fetch, records_fetch), feed_dict=feed)
                full_obj_eval = sess.run([obj], feed_dict=full_feed)

                records["objective"].append(full_obj_eval)
                records["grad_norm"].append(
                    [np.linalg.norm(g.ravel()) for g, _ in gav_eval])
                records["param_norm"].append(
                    [np.linalg.norm(v.ravel()) for _, v in gav_eval])
                records["grad"].append([g for g, _ in gav_eval])
                records["param"].append([v for _, v in gav_eval])
                records["iter"].append(itr)

                for k, v in records_eval.items():
                    records[k].append(v)

            # run the optimization train operation
            objective_values.append(
                sess.run([train_op, obj], feed_dict=feed)[1])

        # final parameters
        parameters = [sess.run(p) for p in params]
        coord.request_stop()
        coord.join(threads)

    return objective_values, parameters, records
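# Illustrative helper (not in the original code): shows how the record_every
# machinery above is typically consumed; the optimizer and its hyperparameters
# are placeholders.
def _run_recorded_test(problem, dataset, num_iter=500):
    adam = tf.train.AdamOptimizer(learning_rate=1e-3)
    objective_values, parameters, records = test_optimizer(
        adam, problem, num_iter, dataset=dataset, record_every=50)
    print('final objective: %f' % objective_values[-1])
    print('recorded at iterations: %s' % records['iter'])
    return parameters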
tf.flags.DEFINE_string('img_lateral_path',
                       './data/experiments/CXR1900_IM-0584-2001.png',
                       'The lateral image path')
tf.flags.DEFINE_string('model_path', './data/model/my-test-1000',
                       'The test model path')

img_frontal_path = FLAGS.img_frontal_path
img_lateral_path = FLAGS.img_lateral_path
model_path = FLAGS.model_path

config = Config()
mt = Model(is_training=False, batch_size=1)

img_frontal, img_lateral, sentence, mask = get_test_data(
    img_frontal_path, img_lateral_path, config)

saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.restore(sess, model_path)
    feed_dict = {
        mt.images_frontal: img_frontal,
        mt.images_lateral: img_lateral,
        mt.sentences: sentence,
        mt.masks: mask
    }
    predicts_list = sess.run([mt.predicts], feed_dict=feed_dict)

sentence_list = get_sentences(predicts_list, config)

print('The generated report:')
for sentence in sentence_list:
    print(sentence)
Example #10

# pretrain
pretrain = False
if pretrain:
    with tf.Session() as sess:
        init.run()
        # restore all of the parameters again
        pretrain_saver.restore(sess, "")
        # reuse only a subset of the parameters
        reuse_vars = tf.get_collection(
            tf.GraphKeys.TRAINABLE_VARIABLES, scope="hidden[123]")
        # map each variable's checkpoint name to the variable object
        reuse_vars_dict = dict([(var.op.name, var) for var in reuse_vars])
        original_saver = tf.train.Saver(reuse_vars_dict)
        for epoch in range(n_epochs):
            n_batched = n_labeled_instances // batch_size
            for iteration in range(n_batched):
                x_batch, y_batch = mnist.train.next_batch(100)
                sess.run(training_op, feed_dict={X: x_batch, y: y_batch})
            accuracy_val = accuracy.eval(feed_dict={X: x_batch, y: y_batch})
            saver.save(sess, "")
            accuracy_val = accuracy.eval(feed_dict={
                X: mnist.test.images,
                y: mnist.test.labels
            })
Example #11
def run_test():
    """Estimates the homography between two input images.
  """
    image1 = cv2.imread(FLAGS.image1)
    image2 = cv2.imread(FLAGS.image2)
    image_list = [image1, image2]
    image_norm_list = []
    for i in range(2):
        if FLAGS.network_id == 'fmask_sem':
            image_scale = cv2.resize(image_list[i],
                                     (FLAGS.train_width, FLAGS.train_height),
                                     interpolation=cv2.INTER_LANCZOS4)
        else:
            image_gray = cv2.cvtColor(image_list[i], cv2.COLOR_BGR2GRAY)
            image_scale = cv2.resize(image_gray,
                                     (FLAGS.train_width, FLAGS.train_height),
                                     interpolation=cv2.INTER_LANCZOS4)

        image_norm = image_scale / 256.0 - 0.5
        image_norm_list.append(image_norm)
    if FLAGS.network_id == 'fmask_sem':
        norm_image_pair = np.expand_dims(np.concatenate(image_norm_list, 2),
                                         axis=0)
        num_channel = 3
    else:
        norm_image_pair = np.expand_dims(np.stack(image_norm_list, -1), axis=0)
        num_channel = 1

    batch_pairs = tf.placeholder(
        tf.float32,
        [1, FLAGS.train_height, FLAGS.train_width, 2 * num_channel])
    with slim.arg_scope(models.homography_arg_scope()):
        if FLAGS.network_id == 'fmask_sem':
            batch_hmg_prediction, _ = models.hier_homography_fmask_estimator(
                batch_pairs,
                num_param=8,
                num_layer=FLAGS.num_layer,
                num_level=FLAGS.num_level,
                is_training=False)
        else:
            batch_hmg_prediction, _ = models.hier_homography_estimator(
                batch_pairs,
                num_param=8,
                num_layer=FLAGS.num_layer,
                num_level=FLAGS.num_level,
                is_training=False)

    batch_warped_result, _ = hmg_util.homography_warp_per_batch(
        batch_pairs[Ellipsis, 0:num_channel],
        batch_hmg_prediction[FLAGS.num_level - 1])

    saver = tf.train.Saver()
    with tf.Session() as sess:
        saver.restore(sess, FLAGS.model_path)
        image_warp, homography_list = sess.run(
            [batch_warped_result, batch_hmg_prediction],
            feed_dict={batch_pairs: norm_image_pair})
        for i in range(8):
            logging.info('%f ', homography_list[FLAGS.num_level - 1][0][i])
        cv2.imwrite('%s/input0.jpg' % FLAGS.out_dir,
                    (image_norm_list[0] + 0.5) * 256)
        cv2.imwrite('%s/input1.jpg' % FLAGS.out_dir,
                    (image_norm_list[1] + 0.5) * 256)
        cv2.imwrite('%s/result.jpg' % FLAGS.out_dir,
                    (image_warp[0] + 0.5) * 256)
# TRANSFER LEARNING: reusing pretrained model
# [...] # <- construct the original model
with tf.Session() as sess:
    saver.restore(sess, "./my_original_model.ckpt")
    # [...] # <- train it on the new task
      
# but often you want to reuse only part of the trained model
# -> configure the Saver to restore only a subset of variables from the original model; e.g. to restore only hidden layers 1,2,3
# [...] # <- build new model with same definition as before for layers 1-3
init = tf.global_variables_initializer()
# get the list of all trainable variables just created with trainable=True (the default),
# keeping only those whose name matches the regular expression "hidden[123]"
reuse_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="hidden[123]")
# create a dict mapping each variable's name in the original model to the variable in the new model (the names are generally the same)
reuse_vars_dict = dict([(var.op.name, var) for var in reuse_vars])
original_saver = tf.train.Saver(reuse_vars_dict)  # saver to restore the original model
new_saver = tf.train.Saver()  # saver to save the new model
with tf.Session() as sess:
    sess.run(init)
    original_saver.restore("./my_original_model.ckpt") # rest1ore hidden layers 1,2,3
    # [...] -> train the new model
    new_saver.save("./my_new_model.ckpt")# save the new model    
# NB: if model from another framework, weights and biases can be assigned manually (create nodes and assign the arbitrary values)
    
# FREEZING LOW-LEVEL LAYERS: makes the higher-level layers easier to train
# give the optimizer the list of variables to train (exclude variables from lower layers):
train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="hidden[34]|outputs")    
training_op = optimizer.minimize(loss, var_list=train_vars)  # -> hidden layers 1 and 2 are now frozen

# the more data available, the more layers can be unfrozen
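# Minimal end-to-end sketch of the freezing idea above (not from the original
# notes; layer sizes, names and the random data are illustrative; only hidden2
# and outputs get trained):
import numpy as np
import tensorflow as tf

X = tf.placeholder(tf.float32, shape=(None, 10), name="X")
y = tf.placeholder(tf.int32, shape=(None,), name="y")
hidden1 = tf.layers.dense(X, 32, activation=tf.nn.relu, name="hidden1")  # frozen
hidden2 = tf.layers.dense(hidden1, 16, activation=tf.nn.relu, name="hidden2")
logits = tf.layers.dense(hidden2, 3, name="outputs")
loss = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=logits))

# only the variables of hidden2 and outputs are handed to the optimizer,
# so hidden1's weights are never updated
train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                               scope="hidden2|outputs")
training_op = tf.train.GradientDescentOptimizer(0.01).minimize(
    loss, var_list=train_vars)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    x_batch = np.random.rand(8, 10).astype(np.float32)
    y_batch = np.random.randint(0, 3, size=8)
    sess.run(training_op, feed_dict={X: x_batch, y: y_batch})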
Example #13
def train(hparams):
    """Run training loop."""
    data_iterator, clause_metadata = load_data(random_start=True)

    with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
        # The following three lines prevent hangs during distributed training.
        vs = tf.get_variable_scope()
        if vs.caching_device is None:
            vs.set_caching_device(lambda op: op.device)

        # Build the graph.
        global_step = slim.variables.get_or_create_global_step()
        if FLAGS.model_type == 'tree':
            m = cnf_model.CNFTreeModel(data_iterator, hparams, clause_metadata)
        else:
            m = cnf_model.CNFSequenceModel(data_iterator, hparams,
                                           clause_metadata)

        variables = tf.trainable_variables()

        learning_rate = tf.train.exponential_decay(
            hparams.learning_rate,
            global_step,
            hparams.decay_steps,
            hparams.learning_rate_decay_factor,
            staircase=True)

        if hparams.optimizer == 'sgd':
            optimizer = tf.train.GradientDescentOptimizer(learning_rate)
        elif hparams.optimizer == 'adam':
            optimizer = tf.train.AdamOptimizer(learning_rate)
        elif hparams.optimizer == 'rmsprop':
            optimizer = tf.train.RMSPropOptimizer(learning_rate=learning_rate,
                                                  decay=0.9,
                                                  momentum=0.9,
                                                  epsilon=1e-5)
        else:
            raise RuntimeError('Unknown optimizer %s' % hparams.optimizer)

        if FLAGS.master not in ('', 'local') and FLAGS.sync_replicas:
            replica_id = tf.constant(FLAGS.task, tf.int32, shape=())
            optimizer = tf.LegacySyncReplicasOptimizer(
                opt=optimizer,
                replicas_to_aggregate=FLAGS.replicas_to_aggregate,
                replica_id=replica_id,
                total_num_replicas=FLAGS.worker_replicas)

        tf.contrib.deprecated.scalar_summary('lr', learning_rate)
        tf.contrib.deprecated.scalar_summary('loss', m.loss)
        for metric_name, metric_value in m.metrics.items():
            tf.contrib.deprecated.scalar_summary('metric/' + metric_name,
                                                 metric_value)

        grads_and_vars = optimizer.compute_gradients(m.loss, variables)
        if hparams.grad_max_norm > 0:
            g, v = zip(*grads_and_vars)
            g, global_norm = tf.clip_by_global_norm(g, hparams.grad_max_norm)
            tf.contrib.deprecated.scalar_summary('global_norm', global_norm)
            grads_and_vars = zip(g, v)
        train_op = optimizer.apply_gradients(grads_and_vars, global_step)
        summary_op = tf.summary.merge_all()

        if FLAGS.master not in ('', 'local') and FLAGS.sync_replicas:
            init_token_op = optimizer.get_init_tokens_op()
            chief_queue_runner = optimizer.get_chief_queue_runner()

        saver = tf.train.Saver(keep_checkpoint_every_n_hours=1.0)

        supervisor = tf.train.Supervisor(
            is_chief=(FLAGS.task == 0),
            logdir=FLAGS.tf_log_dir,
            global_step=global_step,
            saver=saver,
            # We are going to compute summaries ourselves.
            summary_op=None,
            save_model_secs=FLAGS.save_model_secs,
            # But we set this so that this computes global_step/sec.
            save_summaries_secs=FLAGS.save_summaries_secs)
        sess = supervisor.prepare_or_wait_for_session(FLAGS.master)

        # TODO(ricshin):
        # Rewrite this to use supervisor.managed_session().
        # Look at how slim/learning.py handles SyncReplicas, in particular
        # init_token_op.  Use normal text summaries once they exist.
        # Use supervisor.should_stop().
        if FLAGS.task == 0:
            if FLAGS.master not in ('', 'local') and FLAGS.sync_replicas:
                supervisor.start_queue_runners(sess, [chief_queue_runner])
                sess.run(init_token_op)

            sampling_temps = [
                float(x) for x in FLAGS.sampling_temps.split(',')
            ]

            def summarize():
                try:
                    summary_strs, global_step_val = sess.run(
                        [summary_op, global_step])
                    summaries = tf.Summary.FromString(summary_strs)

                    for i, temp in itertools.product(
                            range(FLAGS.num_summary_samples), sampling_temps):
                        cnf = textwrap.wrap(
                            cnf_utils.unparse_cnf(m.sample(sess)))
                        summaries.value.add(
                            tag='formula_temp%g_%d' % (temp, i),
                            tensor=make_tensor_proto('\n'.join(cnf)))

                    supervisor.summary_writer.add_summary(
                        summaries.SerializeToString(), global_step_val)
                    status_str = ', '.join('%s=%f' %
                                           (value.tag, value.simple_value)
                                           for value in summaries.value
                                           if value.HasField('simple_value'))
                    tf.logging.info('step=%d: %s', global_step_val, status_str)
                except:
                    # The supervisor eats the backtrace, so print it here.
                    traceback.print_exc()
                    raise

            supervisor.loop(FLAGS.save_summaries_secs, summarize)

        # Run the trainer.
        for unused_i in range(hparams.max_steps):
            sess.run(train_op)