Ejemplo n.º 1
0
    def train_setup(self):
        """Build the training graph: total loss, optimizer, train op,
        savers, and scalar/histogram summary ops."""
        # Total loss = task loss + regularization collected from the graph.
        task_loss = mu.get_loss(self._hp['loss'], self.data.gt, self.output)
        regularization = tf.losses.get_regularization_loss()
        loss = task_loss + regularization

        # Streaming training accuracy over predicted class indices.
        predictions = tf.argmax(self.output, axis=1)
        (train_accuracy, train_accuracy_update,
         train_accuracy_initializer) = mu.get_acc(
             self.data.gt, predictions, name='train_accuracy')

        global_step = tf.train.get_or_create_global_step()

        # Decay schedule measured in steps; one "decay_epoch" worth of batches.
        steps_per_decay = int(self._hp['decay_epoch'] * len(self.data))
        learning_rate = mu.get_lr(self._hp['decay_type'],
                                  self._hp['learning_rate'], global_step,
                                  steps_per_decay, self._hp['decay_rate'])

        optimizer = mu.get_opt(self._hp['opt'], learning_rate, steps_per_decay)

        grads_and_vars = optimizer.compute_gradients(loss)
        gradients, variables = zip(*grads_and_vars)
        # Run collected UPDATE_OPS (e.g. batch-norm moving averages) before
        # the weight update; group the accuracy update into the same op.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = tf.group(
                optimizer.apply_gradients(grads_and_vars, global_step),
                train_accuracy_update)

        self.saver = tf.train.Saver(tf.global_variables())
        self.best_saver = tf.train.Saver(tf.global_variables())

        # Histogram summaries for each variable that received a gradient.
        train_gv_summaries = []
        for var, grad in zip(variables, gradients):
            if grad is None:
                continue
            train_gv_summaries.append(
                tf.summary.histogram('gradients/' + var.name, grad))
            train_gv_summaries.append(tf.summary.histogram(var.name, var))

        train_summaries = [
            tf.summary.scalar('train_loss', loss),
            tf.summary.scalar('train_accuracy', train_accuracy),
            tf.summary.scalar('learning_rate', learning_rate),
        ]
        self.train_summaries_op = tf.summary.merge(train_summaries)
        self.train_gv_summaries_op = tf.summary.merge(
            train_gv_summaries + train_summaries)

        self.train_init_op_list = [
            self.data.initializer, train_accuracy_initializer
        ]

        self.train_op_list = [train_op, loss, train_accuracy, global_step]
Ejemplo n.º 2
0
def train_one_epoch(args, loader, optimizer, logger, epoch):
    """Train the global ``model`` for one pass over ``loader``.

    Note: only complete point clouds are loaded during training, so data.x is
    both the input and label for the point cloud completion task. While partial
    point clouds (data.y) are loaded at testing.
    """
    global i  # running global step, shared across epochs

    model.train()
    loss_summary = {}

    for j, data in enumerate(loader, 0):
        data = data.to(device)
        pos, batch = data.pos, data.batch
        # For completion the clean cloud is its own label.
        label = pos if args.task == 'completion' else data.y
        category = data.category if args.task == 'segmentation' else None

        # Forward / backward pass; the model returns its own loss here.
        model.zero_grad()
        pred, loss = model(None, pos, batch, category, label)
        loss = loss.mean()

        # Record the loss under a task-specific summary key.
        task_to_key = {
            'completion': 'loss_chamfer',
            'classification': 'loss_cls',
            'segmentation': 'loss_seg',
        }
        key = task_to_key.get(args.task)
        if key is not None:
            loss_summary[key] = loss

        loss.backward()
        optimizer.step()

        # Periodic scalar logging every 100 global steps.
        if i % 100 == 0:
            for name, value in loss_summary.items():
                logger.add_scalar(name, value, i)
            logger.add_scalar('lr', get_lr(optimizer), i)
            print(''.join([
                '{}: {:.4f}, '.format(k, v) for k, v in loss_summary.items()
            ]))
        i += 1
Ejemplo n.º 3
0
def train_one_epoch(args, loader, optimizer, logger, epoch):
    """Train the global ``model`` for one pass over ``loader``.

    Computes a task-specific loss (chamfer for completion, NLL for
    classification/segmentation), steps the optimizer, and logs scalars to
    ``logger`` every 100 global steps (global counter ``i``).
    """
    model.train()
    loss_summary = {}
    global i

    for j, data in enumerate(loader, 0):
        data = data.to(device)
        pos, batch, label = data.pos, data.batch, data.y
        category = data.category if args.task == 'segmentation' else None

        # Forward pass.
        model.zero_grad()
        pred = model(None, pos, batch, category)

        if args.task == 'completion':
            # Reshape flattened positions to (B, num_pts, 3) for chamfer.
            loss_summary['loss_chamfer'] = chamfer_loss(
                pred, pos.view(-1, args.num_pts, 3)).mean()
            loss = loss_summary['loss_chamfer']
        elif args.task == 'classification':
            loss_summary['loss_cls'] = F.nll_loss(pred, label)
            loss = loss_summary['loss_cls']
        elif args.task == 'segmentation':
            loss_summary['loss_seg'] = F.nll_loss(pred, label)
            loss = loss_summary['loss_seg']
        else:
            # `assert False` would be stripped under `python -O`, silently
            # leaving `loss` unbound; raise explicitly instead.
            raise ValueError('unknown task: {}'.format(args.task))

        loss.backward()
        optimizer.step()

        # Write summary every 100 global steps.
        if i % 100 == 0:
            for item in loss_summary:
                logger.add_scalar(item, loss_summary[item], i)
            logger.add_scalar('lr', get_lr(optimizer), i)
            print(''.join([
                '{}: {:.4f}, '.format(k, v) for k, v in loss_summary.items()
            ]))
        i = i + 1
Ejemplo n.º 4
0
    def __init__(self):
        """Build the train/val graph with separate language and visual train
        ops: data pipelines, shared-weight models, loss, streaming accuracy
        metrics, optimizer, savers and the fetch lists used by the loop."""

        # Optionally wipe all artifacts of a previous run before re-creating
        # the directories.
        if reset:
            if os.path.exists(self._checkpoint_dir):
                os.system('rm -rf %s' % self._checkpoint_dir)
            if os.path.exists(self._log_dir):
                os.system('rm -rf %s' % self._log_dir)
            if os.path.exists(self._attn_dir):
                os.system('rm -rf %s' % self._attn_dir)

        fu.make_dirs(os.path.join(self._checkpoint_dir, 'best'))
        fu.make_dirs(self._log_dir)
        fu.make_dirs(self._attn_dir)

        self.train_data = Input(split='train', mode=args.mode)
        self.val_data = Input(split='val', mode=args.mode)
        # self.test_data = TestInput()

        self.train_model = mod.Model(self.train_data,
                                     scale=hp['reg'],
                                     training=True)
        # Validation model reuses the training model's variables.
        with tf.variable_scope(tf.get_variable_scope(), reuse=True):
            self.val_model = mod.Model(self.val_data)

        # with tf.variable_scope(tf.get_variable_scope(), reuse=True):
        #     self.test_model = mod.Model(self.test_data, training=False)

        for v in tf.trainable_variables():
            print(v)

        self.main_loss = mu.get_loss(hp['loss'], self.train_data.gt,
                                     self.train_model.output)
        self.regu_loss = tf.losses.get_regularization_loss()

        self.loss = 0
        self.loss += self.main_loss
        self.loss += self.regu_loss

        self.train_answer = tf.argmax(self.train_model.output, axis=1)
        self.train_accuracy, self.train_accuracy_update, self.train_accuracy_initializer \
            = mu.get_acc(self.train_data.gt, self.train_answer, name='train_accuracy')

        self.val_answer = tf.argmax(self.val_model.output, axis=1)
        # Reuse self.val_answer rather than rebuilding an identical argmax op.
        self.val_accuracy, self.val_accuracy_update, self.val_accuracy_initializer \
            = mu.get_acc(self.val_data.gt, self.val_answer, name='val_accuracy')

        self.global_step = tf.train.get_or_create_global_step()

        # Decay schedule in steps: one "decay_epoch" worth of batches.
        decay_step = int(hp['decay_epoch'] * len(self.train_data))
        self.learning_rate = mu.get_lr(hp['decay_type'], hp['learning_rate'],
                                       self.global_step, decay_step,
                                       hp['decay_rate'])

        self.optimizer = mu.get_opt(hp['opt'], self.learning_rate, decay_step)

        grads_and_vars = self.optimizer.compute_gradients(self.loss)

        # Partition gradients by subnetwork: variables whose name contains
        # 'Visual' are trained by vis_train_op, all others by lang_train_op.
        lang_grads_and_vars = [(grad, var) for grad, var in grads_and_vars
                               if 'Visual' not in var.name]
        vis_grads_and_vars = [(grad, var) for grad, var in grads_and_vars
                              if 'Visual' in var.name]

        # Run UPDATE_OPS (e.g. batch-norm moving averages) before each update.
        self.update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(self.update_ops):
            self.lang_train_op = tf.group(
                self.optimizer.apply_gradients(lang_grads_and_vars,
                                               self.global_step),
                self.train_accuracy_update)
            self.vis_train_op = tf.group(
                self.optimizer.apply_gradients(vis_grads_and_vars,
                                               self.global_step),
                self.train_accuracy_update)

        self.saver = tf.train.Saver(tf.global_variables())
        self.best_saver = tf.train.Saver(tf.global_variables())

        # Resume from an explicit checkpoint if given, else the latest one.
        if args.checkpoint:
            self.checkpoint_file = args.checkpoint
        else:
            self.checkpoint_file = tf.train.latest_checkpoint(
                self._checkpoint_dir)

        self.train_init_op_list = [
            self.train_data.initializer, self.train_accuracy_initializer
        ]

        self.val_init_op_list = [
            self.val_data.initializer, self.val_accuracy_initializer
        ]

        self.lang_train_op_list = [
            self.lang_train_op, self.loss, self.train_accuracy,
            self.global_step
        ]
        self.vis_train_op_list = [
            self.vis_train_op, self.loss, self.train_accuracy, self.global_step
        ]

        self.val_op_list = [self.val_accuracy, self.val_accuracy_update]

        # Optionally also fetch attention/belief tensors for visualization.
        if attn:
            self.lang_train_op_list += [
                self.train_model.attn, self.train_model.belief,
                self.train_data.gt, self.train_answer
            ]
            self.vis_train_op_list += [
                self.train_model.attn, self.train_model.belief,
                self.train_data.gt, self.train_answer
            ]
            self.val_op_list += [
                self.val_model.attn, self.val_model.belief, self.val_data.gt,
                self.val_answer
            ]
Ejemplo n.º 5
0
def main(_):
    """
    Builds the model and runs.

    Fine-tunes a BERT classifier with optional Horovod data-parallel
    training: builds feedable train/eval/test iterators, the model, the
    warmup+decay LR schedule and train op, then runs the train/eval/test
    phases selected by FLAGS inside a single tf.Session.
    """
    if FLAGS.distributed:
        import horovod.tensorflow as hvd
        hvd.init()

    tf.logging.set_verbosity(tf.logging.INFO)

    tx.utils.maybe_create_dir(FLAGS.output_dir)

    # Loads data
    num_train_data = config_data.num_train_data

    # Configures distributed mode: each rank reads its own shard, and the
    # per-rank batch size is divided so the global batch size is preserved.
    if FLAGS.distributed:
        config_data.train_hparam["dataset"]["num_shards"] = hvd.size()
        config_data.train_hparam["dataset"]["shard_id"] = hvd.rank()
        config_data.train_hparam["batch_size"] //= hvd.size()

    train_dataset = tx.data.TFRecordData(hparams=config_data.train_hparam)
    eval_dataset = tx.data.TFRecordData(hparams=config_data.eval_hparam)
    test_dataset = tx.data.TFRecordData(hparams=config_data.test_hparam)

    # One feedable iterator switches between the three splits via its handle.
    iterator = tx.data.FeedableDataIterator({
        'train': train_dataset,
        'eval': eval_dataset,
        'test': test_dataset
    })
    batch = iterator.get_next()
    input_ids = batch["input_ids"]
    segment_ids = batch["segment_ids"]
    batch_size = tf.shape(input_ids)[0]
    # Sequence length = count of non-zero token ids per example.
    # NOTE(review): assumes the PAD token id is 0 — confirm with the tokenizer.
    input_length = tf.reduce_sum(1 - tf.cast(tf.equal(input_ids, 0), tf.int32),
                                 axis=1)
    # Builds BERT
    hparams = {'clas_strategy': 'cls_time'}
    model = tx.modules.BERTClassifier(
        pretrained_model_name=FLAGS.pretrained_model_name, hparams=hparams)
    logits, preds = model(input_ids, input_length, segment_ids)

    accu = tx.evals.accuracy(batch['label_ids'], preds)

    # Optimization
    loss = tf.losses.sparse_softmax_cross_entropy(labels=batch["label_ids"],
                                                  logits=logits)
    global_step = tf.Variable(0, trainable=False)

    # Builds learning rate decay scheduler (warmup then decay, BERT-style).
    static_lr = config_downstream.lr['static_lr']
    # NOTE(review): step count uses the undivided train_batch_size even in
    # distributed mode — verify this matches the intended schedule length.
    num_train_steps = int(num_train_data / config_data.train_batch_size *
                          config_data.max_train_epoch)
    num_warmup_steps = int(num_train_steps * config_data.warmup_proportion)
    lr = model_utils.get_lr(
        global_step,
        num_train_steps,  # lr is a Tensor
        num_warmup_steps,
        static_lr)

    opt = tx.core.get_optimizer(global_step=global_step,
                                learning_rate=lr,
                                hparams=config_downstream.opt)

    if FLAGS.distributed:
        # Wrap so gradients are averaged across ranks each step.
        opt = hvd.DistributedOptimizer(opt)

    train_op = tf.contrib.layers.optimize_loss(loss=loss,
                                               global_step=global_step,
                                               learning_rate=None,
                                               optimizer=opt)

    # Train/eval/test routine

    def _is_head():
        # Rank 0 is the only process that logs, evaluates and saves.
        if not FLAGS.distributed:
            return True
        return hvd.rank() == 0

    def _train_epoch(sess):
        """Trains on the training set, and evaluates on the dev set
        periodically.
        """
        iterator.restart_dataset(sess, 'train')

        fetches = {
            'train_op': train_op,
            'loss': loss,
            'batch_size': batch_size,
            'step': global_step
        }

        while True:
            try:
                feed_dict = {
                    iterator.handle: iterator.get_handle(sess, 'train'),
                    # NOTE(review): tx.global_mode() here vs
                    # tx.context.global_mode() in eval/test — presumably the
                    # same placeholder; verify against the texar API.
                    tx.global_mode(): tf.estimator.ModeKeys.TRAIN,
                }
                rets = sess.run(fetches, feed_dict)
                step = rets['step']

                dis_steps = config_data.display_steps
                if _is_head() and dis_steps > 0 and step % dis_steps == 0:
                    tf.logging.info('step:%d; loss:%f;' % (step, rets['loss']))

                eval_steps = config_data.eval_steps
                if _is_head() and eval_steps > 0 and step % eval_steps == 0:
                    _eval_epoch(sess)

            except tf.errors.OutOfRangeError:
                # Dataset exhausted: end of epoch.
                break

    def _eval_epoch(sess):
        """Evaluates on the dev set.
        """
        iterator.restart_dataset(sess, 'eval')

        # Example-weighted sums so the averages are exact despite a ragged
        # final batch.
        cum_acc = 0.0
        cum_loss = 0.0
        nsamples = 0
        fetches = {
            'accu': accu,
            'loss': loss,
            'batch_size': batch_size,
        }
        while True:
            try:
                feed_dict = {
                    iterator.handle: iterator.get_handle(sess, 'eval'),
                    tx.context.global_mode(): tf.estimator.ModeKeys.EVAL,
                }
                rets = sess.run(fetches, feed_dict)

                cum_acc += rets['accu'] * rets['batch_size']
                cum_loss += rets['loss'] * rets['batch_size']
                nsamples += rets['batch_size']
            except tf.errors.OutOfRangeError:
                break

        # NOTE(review): raises ZeroDivisionError if the eval set is empty.
        tf.logging.info('eval accu: {}; loss: {}; nsamples: {}'.format(
            cum_acc / nsamples, cum_loss / nsamples, nsamples))

    def _test_epoch(sess):
        """Does predictions on the test set.
        """
        iterator.restart_dataset(sess, 'test')

        _all_preds = []
        while True:
            try:
                feed_dict = {
                    iterator.handle: iterator.get_handle(sess, 'test'),
                    tx.context.global_mode(): tf.estimator.ModeKeys.PREDICT,
                }
                _preds = sess.run(preds, feed_dict=feed_dict)
                _all_preds.extend(_preds.tolist())
            except tf.errors.OutOfRangeError:
                break

        # One predicted label id per line.
        output_file = os.path.join(FLAGS.output_dir, "test_results.tsv")
        with tf.gfile.GFile(output_file, "w") as writer:
            writer.write('\n'.join(str(p) for p in _all_preds))

    # Broadcasts global variables from rank-0 process so all ranks start
    # from identical weights.
    if FLAGS.distributed:
        bcast = hvd.broadcast_global_variables(0)

    session_config = tf.ConfigProto()
    if FLAGS.distributed:
        # Pin each rank to its own local GPU.
        session_config.gpu_options.visible_device_list = str(hvd.local_rank())

    with tf.Session(config=session_config) as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        sess.run(tf.tables_initializer())

        if FLAGS.distributed:
            bcast.run()

        # Restores trained model if specified
        saver = tf.train.Saver()
        if FLAGS.checkpoint:
            saver.restore(sess, FLAGS.checkpoint)

        iterator.initialize_dataset(sess)

        if FLAGS.do_train:
            for i in range(config_data.max_train_epoch):
                _train_epoch(sess)
            saver.save(sess, FLAGS.output_dir + '/model.ckpt')

        if FLAGS.do_eval:
            _eval_epoch(sess)

        if FLAGS.do_test:
            _test_epoch(sess)
Ejemplo n.º 6
0
    def __init__(self):
        """Build an adversarially-regularized train/val graph: QA model plus
        a discriminator, joint generator/QA loss, separate discriminator
        optimizer, savers, summaries and fetch lists."""

        # Optionally wipe artifacts from a previous run.
        if reset:
            if os.path.exists(self._checkpoint_dir):
                os.system('rm -rf %s' % self._checkpoint_dir)
            if os.path.exists(self._log_dir):
                os.system('rm -rf %s' % self._log_dir)

        fu.make_dirs(os.path.join(self._checkpoint_dir, 'best'))
        fu.make_dirs(self._log_dir)

        self.train_data = Input(split='train', mode=args.mode)
        self.val_data = Input(split='val', mode=args.mode)
        # self.test_data = TestInput()

        self.train_model = mod.Model(self.train_data, beta=hp['reg'], training=True)
        # Validation model shares variables with the training model.
        with tf.variable_scope(tf.get_variable_scope(), reuse=True):
            self.val_model = mod.Model(self.val_data)

        # with tf.variable_scope(tf.get_variable_scope(), reuse=True):
        #     self.test_model = mod.Model(self.test_data, training=False)

        for v in tf.trainable_variables():
            print(v)

        # Standard GAN losses: discriminator pushes real logits to 1 and fake
        # to 0; generator pushes fake logits to 1.
        self.main_loss = mu.get_loss(hp['loss'], self.train_data.gt, self.train_model.output)
        self.real_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
            logits=self.train_model.real_logit, labels=tf.ones_like(self.train_model.real_logit)))
        self.fake_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
            logits=self.train_model.fake_logit, labels=tf.zeros_like(self.train_model.fake_logit)))
        self.d_loss = self.real_loss + self.fake_loss
        self.g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
            logits=self.train_model.fake_logit, labels=tf.ones_like(self.train_model.fake_logit)))
        self.regu_loss = tf.losses.get_regularization_loss()

        # Generator/QA objective: adversarial + task + regularization.
        self.loss = self.g_loss
        self.loss += self.main_loss
        self.loss += self.regu_loss

        # NOTE(review): regularization is added to BOTH self.loss and
        # self.d_loss — presumably intentional (shared weight decay); confirm.
        self.d_loss += self.regu_loss

        self.train_accuracy, self.train_accuracy_update, self.train_accuracy_initializer \
            = mu.get_acc(self.train_data.gt, tf.argmax(self.train_model.output, axis=1), name='train_accuracy')

        self.val_accuracy, self.val_accuracy_update, self.val_accuracy_initializer \
            = mu.get_acc(self.val_data.gt, tf.argmax(self.val_model.output, axis=1), name='val_accuracy')

        self.global_step = tf.train.get_or_create_global_step()

        # Decay schedule in steps: one "decay_epoch" worth of batches.
        decay_step = int(hp['decay_epoch'] * len(self.train_data))
        self.learning_rate = mu.get_lr(hp['decay_type'], hp['learning_rate'], self.global_step,
                                       decay_step, hp['decay_rate'])

        # Separate optimizer instances for the QA/generator and discriminator.
        self.optimizer = mu.get_opt(hp['opt'], self.learning_rate, decay_step)
        self.d_optimizer = mu.get_opt(hp['opt'], self.learning_rate, decay_step)

        # QA/generator gradients are restricted to variables under 'QA'.
        grads_and_vars = self.optimizer.compute_gradients(self.loss, var_list=tf.trainable_variables('QA'))
        gradients, variables = list(zip(*grads_and_vars))
        self.update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(self.update_ops):
            # Only d_optimizer.minimize receives global_step — presumably so
            # the step counter increments once per joint update, not twice.
            self.train_op = tf.group(self.optimizer.apply_gradients(grads_and_vars),
                                     self.d_optimizer.minimize(self.d_loss, self.global_step,
                                                               tf.trainable_variables('Discriminator')),
                                     self.train_accuracy_update)

        self.saver = tf.train.Saver(tf.global_variables())
        self.best_saver = tf.train.Saver(tf.global_variables())

        # Summary
        # Histograms for QA/generator gradients and their variables.
        train_gv_summaries = []
        for idx, grad in enumerate(gradients):
            if grad is not None:
                train_gv_summaries.append(tf.summary.histogram('gradients/' + variables[idx].name, grad))
                train_gv_summaries.append(tf.summary.histogram(variables[idx].name, variables[idx]))

        train_summaries = [
            tf.summary.scalar('train_loss', self.loss),
            tf.summary.scalar('train_accuracy', self.train_accuracy),
            tf.summary.scalar('learning_rate', self.learning_rate)
        ]
        self.train_summaries_op = tf.summary.merge(train_summaries)
        self.train_gv_summaries_op = tf.summary.merge(train_gv_summaries + train_summaries)

        self.val_summaries_op = tf.summary.scalar('val_accuracy', self.val_accuracy)

        # Resume from an explicit checkpoint if given, else the latest one.
        if args.checkpoint:
            self.checkpoint_file = args.checkpoint
        else:
            self.checkpoint_file = tf.train.latest_checkpoint(self._checkpoint_dir)

        self.train_init_op_list = [self.train_data.initializer, self.train_accuracy_initializer]

        self.val_init_op_list = [self.val_data.initializer, self.val_accuracy_initializer]

        self.train_op_list = [self.train_op, self.loss, self.train_accuracy, self.global_step]
        # self.run_metadata = tf.RunMetadata()
        self.val_op_list = [self.val_accuracy, self.val_accuracy_update, self.val_summaries_op]
0
    def __init__(self):
        """Build the train/val graph: data pipelines, shared-weight models,
        loss, streaming accuracy metrics, optimizer, savers, summaries and
        the fetch lists consumed by the training loop."""

        # Optionally wipe all artifacts of a previous run before re-creating
        # the directories.
        if reset:
            if os.path.exists(self._checkpoint_dir):
                os.system('rm -rf %s' % self._checkpoint_dir)
            if os.path.exists(self._log_dir):
                os.system('rm -rf %s' % self._log_dir)
            if os.path.exists(self._attn_dir):
                os.system('rm -rf %s' % self._attn_dir)

        fu.make_dirs(os.path.join(self._checkpoint_dir, 'best'))
        fu.make_dirs(self._log_dir)
        fu.make_dirs(self._attn_dir)

        self.train_data = Input(split='train', mode=args.mode)
        self.val_data = Input(split='val', mode=args.mode)
        # self.test_data = TestInput()

        self.train_model = mod.Model(self.train_data,
                                     scale=hp['reg'],
                                     training=True)
        # Validation model reuses the training model's variables.
        with tf.variable_scope(tf.get_variable_scope(), reuse=True):
            self.val_model = mod.Model(self.val_data)

        # with tf.variable_scope(tf.get_variable_scope(), reuse=True):
        #     self.test_model = mod.Model(self.test_data, training=False)

        for v in tf.trainable_variables():
            print(v)

        self.main_loss = mu.get_loss(hp['loss'], self.train_data.gt,
                                     self.train_model.output)
        self.regu_loss = tf.losses.get_regularization_loss()

        self.loss = 0
        self.loss += self.main_loss
        self.loss += self.regu_loss

        self.train_answer = tf.argmax(self.train_model.output, axis=1)
        self.train_accuracy, self.train_accuracy_update, self.train_accuracy_initializer \
            = mu.get_acc(self.train_data.gt, self.train_answer, name='train_accuracy')

        self.val_answer = tf.argmax(self.val_model.output, axis=1)
        # Reuse self.val_answer rather than rebuilding an identical argmax op.
        self.val_accuracy, self.val_accuracy_update, self.val_accuracy_initializer \
            = mu.get_acc(self.val_data.gt, self.val_answer, name='val_accuracy')

        self.global_step = tf.train.get_or_create_global_step()

        # Decay schedule in steps: one "decay_epoch" worth of batches.
        decay_step = int(hp['decay_epoch'] * len(self.train_data))
        self.learning_rate = mu.get_lr(hp['decay_type'], hp['learning_rate'],
                                       self.global_step, decay_step,
                                       hp['decay_rate'])

        self.optimizer = mu.get_opt(hp['opt'], self.learning_rate, decay_step)

        grads_and_vars = self.optimizer.compute_gradients(self.loss)
        gradients, variables = list(zip(*grads_and_vars))
        # Run UPDATE_OPS (e.g. batch-norm moving averages) before the update;
        # group the accuracy update into the same train op.
        self.update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(self.update_ops):
            self.train_op = tf.group(
                self.optimizer.apply_gradients(grads_and_vars,
                                               self.global_step),
                self.train_accuracy_update)

        self.saver = tf.train.Saver(tf.global_variables())
        self.best_saver = tf.train.Saver(tf.global_variables())

        # Summary: histograms for each variable that received a gradient.
        train_gv_summaries = []
        for idx, grad in enumerate(gradients):
            if grad is not None:
                train_gv_summaries.append(
                    tf.summary.histogram('gradients/' + variables[idx].name,
                                         grad))
                train_gv_summaries.append(
                    tf.summary.histogram(variables[idx].name, variables[idx]))

        train_summaries = [
            tf.summary.scalar('train_loss', self.loss),
            tf.summary.scalar('train_accuracy', self.train_accuracy),
            tf.summary.scalar('learning_rate', self.learning_rate)
        ]
        self.train_summaries_op = tf.summary.merge(train_summaries)
        self.train_gv_summaries_op = tf.summary.merge(train_gv_summaries +
                                                      train_summaries)

        self.val_summaries_op = tf.summary.scalar('val_accuracy',
                                                  self.val_accuracy)

        # Resume from an explicit checkpoint if given, else the latest one.
        if args.checkpoint:
            self.checkpoint_file = args.checkpoint
        else:
            self.checkpoint_file = tf.train.latest_checkpoint(
                self._checkpoint_dir)

        self.train_init_op_list = [
            self.train_data.initializer, self.train_accuracy_initializer
        ]

        self.val_init_op_list = [
            self.val_data.initializer, self.val_accuracy_initializer
        ]

        self.train_op_list = [
            self.train_op, self.loss, self.train_accuracy, self.global_step
        ]

        self.val_op_list = [
            self.val_accuracy, self.val_accuracy_update, self.val_summaries_op
        ]

        # Optionally also fetch attention/belief tensors for visualization.
        if attn:
            self.train_op_list += [
                self.train_model.attn, self.train_model.belief,
                self.train_data.gt, self.train_answer
            ]
            self.val_op_list += [
                self.val_model.attn, self.val_model.belief, self.val_data.gt,
                self.val_answer
            ]
Ejemplo n.º 8
0
    def __init__(self):
        """Build a train/val graph with an additional gradient 'roll back'
        op: after the normal update, a second model pass on step_data decides
        per-example whether to re-apply (scaled) or reverse the gradients."""

        # Optionally wipe artifacts from a previous run.
        if reset:
            if os.path.exists(self._checkpoint_dir):
                os.system('rm -rf %s' % self._checkpoint_dir)
            if os.path.exists(self._log_dir):
                os.system('rm -rf %s' % self._log_dir)

        fu.make_dirs(os.path.join(self._checkpoint_dir, 'best'))
        fu.make_dirs(self._log_dir)

        self.train_data = Input(split='train', mode=args.mode)
        # Second training-split pipeline used only by the roll-back step model.
        self.step_data = Input(split='train', mode=args.mode)
        self.val_data = Input(split='val', mode=args.mode)
        # self.test_data = TestInput()

        self.train_model = mod.Model(self.train_data, training=True)

        # Validation model shares variables with the training model.
        with tf.variable_scope(tf.get_variable_scope(), reuse=True):
            self.val_model = mod.Model(self.val_data)
        # with tf.variable_scope(tf.get_variable_scope(), reuse=True):
        #     self.test_model = mod.Model(self.test_data, training=False)

        for v in tf.trainable_variables():
            print(v)

        self.loss = tf.losses.sparse_softmax_cross_entropy(
            self.train_data.gt, self.train_model.output)

        # Regularization is opt-in here (unlike the sibling trainers above).
        if args.reg:
            self.loss += tf.losses.get_regularization_loss()

        # Streaming train accuracy with an explicit initializer over its
        # local variables so it can be reset each epoch.
        self.train_accuracy, self.train_accuracy_update = tf.metrics.accuracy(
            self.train_data.gt,
            tf.argmax(self.train_model.output, axis=1),
            name='train_accuracy')
        self.train_accuracy_initializer = tf.variables_initializer(
            tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES,
                              scope='train_accuracy'))

        self.val_accuracy, self.val_accuracy_update = tf.metrics.accuracy(
            self.val_data.gt,
            tf.argmax(self.val_model.output, axis=1),
            name='val_accuracy')
        self.val_accuracy_initializer = tf.variables_initializer(
            tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES,
                              scope='val_accuracy'))

        self.global_step = tf.train.get_or_create_global_step()

        self.learning_rate = mu.get_lr(
            hp['decay_type'], hp['learning_rate'], self.global_step,
            hp['decay_epoch'] * len(self.train_data), hp['decay_rate'])

        self.optimizer = mu.get_opt(hp['opt'], self.learning_rate)

        grads_and_vars = self.optimizer.compute_gradients(self.loss)
        gradients, variables = list(zip(*grads_and_vars))
        self.update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(self.update_ops):
            self.train_op = tf.group(
                self.optimizer.apply_gradients(grads_and_vars,
                                               self.global_step),
                self.train_accuracy_update)
        # Roll-back op: runs strictly AFTER train_op. A fresh forward pass on
        # step_data marks which examples are now predicted correctly (gamma);
        # per element, gradients are re-applied scaled by train_accuracy when
        # correct, or reversed with a -(1 + acc/2) factor when wrong.
        # NOTE(review): tf.logical_and(gamma, ones_like(g)) appears to rely on
        # broadcasting gamma across each gradient tensor — confirm shapes.
        with tf.control_dependencies([self.train_op]):
            with tf.variable_scope(tf.get_variable_scope(), reuse=True):
                self.step_model = mod.Model(self.step_data)
            gamma = tf.equal(self.step_data.gt,
                             tf.argmax(self.step_model.output, axis=1))
            neg_grads_and_vars = [(tf.where(
                tf.logical_and(gamma, tf.ones_like(g, dtype=tf.bool)),
                self.train_accuracy * g,
                -(1 + self.train_accuracy / 2) * g), v)
                                  for g, v in grads_and_vars]
            self.roll_back = self.optimizer.apply_gradients(neg_grads_and_vars)

        self.saver = tf.train.Saver(tf.global_variables())
        self.best_saver = tf.train.Saver(tf.global_variables())

        # Summary: histograms for each variable that received a gradient.
        # NOTE(review): gradient histograms are tagged grad.name here, not
        # 'gradients/' + var.name as in the sibling trainers.
        train_gv_summaries = []
        for idx, grad in enumerate(gradients):
            if grad is not None:
                train_gv_summaries.append(tf.summary.histogram(
                    grad.name, grad))
                train_gv_summaries.append(
                    tf.summary.histogram(variables[idx].name, variables[idx]))

        train_summaries = [
            tf.summary.scalar('train_loss', self.loss),
            tf.summary.scalar('train_accuracy', self.train_accuracy),
            tf.summary.scalar('learning_rate', self.learning_rate)
        ]
        self.train_summaries_op = tf.summary.merge(train_summaries)
        self.train_gv_summaries_op = tf.summary.merge(train_gv_summaries +
                                                      train_summaries)

        self.val_summaries_op = tf.summary.scalar('val_accuracy',
                                                  self.val_accuracy)

        # Resume from an explicit checkpoint if given, else the latest one.
        if args.checkpoint:
            self.checkpoint_file = args.checkpoint
        else:
            self.checkpoint_file = tf.train.latest_checkpoint(
                self._checkpoint_dir)
Ejemplo n.º 9
0
    def __init__(self):
        """Build the complete training/validation graph.

        Sets up the input pipelines, the train/val models (the validation
        model reuses the training model's variables), losses, streaming
        accuracy metrics, learning-rate schedule, optimizer, summaries,
        savers, and the fetch lists consumed by the training loop.

        NOTE(review): relies on module-level names (`reset`, `args`, `hp`,
        `target`, `fu`, `mu`, `mod`, `Input`) and on `self._checkpoint_dir`,
        `self._log_dir`, `self._attn_dir` being set before this runs —
        confirm against the enclosing module.
        """
        # Optionally wipe all artifacts of a previous run for a fresh start.
        if reset:
            if os.path.exists(self._checkpoint_dir):
                os.system('rm -rf %s' % self._checkpoint_dir)
            if os.path.exists(self._log_dir):
                os.system('rm -rf %s' % self._log_dir)
            if os.path.exists(self._attn_dir):
                os.system('rm -rf %s' % self._attn_dir)

        # The 'best' subdirectory receives checkpoints saved via best_saver.
        fu.make_dirs(os.path.join(self._checkpoint_dir, 'best'))
        fu.make_dirs(self._log_dir)
        fu.make_dirs(self._attn_dir)

        self.train_data = Input(split='train', mode=args.mode)
        self.val_data = Input(split='val', mode=args.mode)
        # self.test_data = TestInput()

        # Validation model shares (reuses) the training model's variables.
        self.train_model = mod.Model(self.train_data, beta=hp['reg'], training=True)
        with tf.variable_scope(tf.get_variable_scope(), reuse=True):
            self.val_model = mod.Model(self.val_data)

        # with tf.variable_scope(tf.get_variable_scope(), reuse=True):
        #     self.test_model = mod.Model(self.test_data, training=False)

        # Debug aid: list every trainable variable at construction time.
        for v in tf.trainable_variables():
            print(v)

        # Attention vector selected by the first ground-truth index.
        # NOTE(review): indexing attn with gt[0] looks like it picks the
        # row for the ground-truth answer — confirm against mod.Model.
        self.train_attn = tf.squeeze(self.train_model.attn[self.train_data.gt[0]])
        self.val_attn = tf.squeeze(self.val_model.attn[self.val_data.gt[0]])
        self.main_loss = mu.get_loss(hp['loss'], self.train_data.gt, self.train_model.output)
        self.attn_loss = mu.get_loss('hinge', tf.to_float(self.train_data.spec), self.train_attn)
        self.regu_loss = tf.losses.get_regularization_loss()

        # Compose the total loss from the configured targets.
        # NOTE(review): because of `elif`, the attention loss is skipped
        # whenever 'main' is also present in `target` — confirm the two
        # targets are meant to be mutually exclusive, not additive.
        self.loss = 0
        if 'main' in target:
            self.loss += self.main_loss
        elif 'attn' in target:
            self.loss += self.attn_loss
        self.loss += self.regu_loss

        # Streaming metrics: each returns (value_op, update_op, initializer).
        self.train_acc, self.train_acc_update, self.train_acc_init = \
            mu.get_acc(self.train_data.gt, tf.argmax(self.train_model.output, axis=1), name='train_accuracy')

        # Attention "accuracy": thresholds the attention scores at 0.5 and
        # compares against the binary spec labels.
        self.train_attn_acc, self.train_attn_acc_update, self.train_attn_acc_init = \
            mu.get_acc(self.train_data.spec, tf.to_int32(self.train_attn > 0.5), name='train_attention_accuracy')

        # self.train_q_attn_acc, self.train_q_attn_acc_update, self.train_q_attn_acc_init = \
        #     tf.metrics.accuracy(self.train_data.spec, self.train_model.output, name='train_q_attention_accuracy')
        #
        # self.train_a_attn_acc, self.train_a_attn_acc_update, self.train_a_attn_acc_init = \
        #     tf.metrics.accuracy(self.train_data.spec, self.train_model.output, name='train_a_attention_accuracy')

        self.val_acc, self.val_acc_update, self.val_acc_init = \
            mu.get_acc(self.val_data.gt, tf.argmax(self.val_model.output, axis=1), name='val_accuracy')

        self.val_attn_acc, self.val_attn_acc_update, self.val_attn_acc_init = \
            mu.get_acc(self.val_data.spec, tf.to_int32(self.val_attn > 0.5), name='val_attention_accuracy')

        # self.val_q_attn_acc, self.val_q_attn_acc_update, self.val_q_attn_acc_init = \
        #     tf.metrics.accuracy(self.train_data.spec, self.train_model.output, name='val_q_attention_accuracy')
        #
        # self.val_a_attn_acc, self.val_a_attn_acc_update, self.val_a_attn_acc_init = \
        #     tf.metrics.accuracy(self.train_data.spec, self.train_model.output, name='val_a_attention_accuracy')

        self.global_step = tf.train.get_or_create_global_step()

        # Decay schedule measured in steps: decay_epoch epochs worth of batches.
        decay_step = int(hp['decay_epoch'] * len(self.train_data))
        self.learning_rate = mu.get_lr(hp['decay_type'], hp['learning_rate'], self.global_step,
                                       decay_step, hp['decay_rate'])

        self.optimizer = mu.get_opt(hp['opt'], self.learning_rate, decay_step)

        grads_and_vars = self.optimizer.compute_gradients(self.loss)
        gradients, variables = list(zip(*grads_and_vars))
        # Run UPDATE_OPS (e.g. batch-norm moving averages) before the step,
        # and fold the metric update ops into the single train_op.
        self.update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(self.update_ops):
            self.train_op = tf.group(self.optimizer.apply_gradients(grads_and_vars, self.global_step),
                                     self.train_acc_update,
                                     self.train_attn_acc_update)  # self.train_a_attn_acc_update, self.train_q_attn_acc_update)

        # Two savers over identical variable sets: periodic vs. best-model.
        self.saver = tf.train.Saver(tf.global_variables())
        self.best_saver = tf.train.Saver(tf.global_variables())

        # Summary
        # Histogram summaries for every variable that received a gradient.
        train_gv_summaries = []
        for idx, grad in enumerate(gradients):
            if grad is not None:
                train_gv_summaries.append(tf.summary.histogram('gradients/' + variables[idx].name, grad))
                train_gv_summaries.append(tf.summary.histogram(variables[idx].name, variables[idx]))

        train_summaries = [
            tf.summary.scalar('train_loss', self.loss),
            tf.summary.scalar('train_accuracy', self.train_acc),
            # tf.summary.scalar('train_a_attn_accuracy', self.train_a_attn_acc),
            # tf.summary.scalar('train_q_attn_accuracy', self.train_q_attn_acc),
            tf.summary.scalar('train_attn_accuracy', self.train_attn_acc),
            tf.summary.scalar('learning_rate', self.learning_rate)
        ]
        # Cheap scalar-only summary vs. the full (scalars + histograms) one.
        self.train_summaries_op = tf.summary.merge(train_summaries)
        self.train_gv_summaries_op = tf.summary.merge(train_gv_summaries + train_summaries)

        val_summaries = [
            tf.summary.scalar('val_accuracy', self.val_acc),
            tf.summary.scalar('val_attn_accuracy', self.val_attn_acc),
            # tf.summary.scalar('val_a_attn_accuracy', self.val_a_attn_acc),
            # tf.summary.scalar('val_q_attn_accuracy', self.val_q_attn_acc),
        ]
        self.val_summaries_op = tf.summary.merge(val_summaries)

        # Explicit checkpoint wins; otherwise resume from the latest one
        # (tf.train.latest_checkpoint returns None if none exists).
        if args.checkpoint:
            self.checkpoint_file = args.checkpoint
        else:
            self.checkpoint_file = tf.train.latest_checkpoint(self._checkpoint_dir)

        # Fetch lists used by the training loop: initializers reset the
        # dataset iterators and the streaming-metric local variables.
        self.train_init_op_list = [self.train_data.initializer, self.train_acc_init,
                                   # self.train_q_attn_acc_init, self.train_a_attn_acc_init,
                                   self.train_attn_acc_init]

        self.val_init_op_list = [self.val_data.initializer, self.val_acc_init,
                                 # self.val_q_attn_acc_init, self.val_a_attn_acc_init,
                                 self.val_attn_acc_init]

        self.train_op_list = [self.train_op, self.loss, self.attn_loss, self.train_acc,
                              self.train_attn_acc,  # self.val_q_attn_acc, self.val_a_attn_acc,
                              self.global_step, self.train_data.spec, self.train_attn]
        self.val_op_list = [self.val_acc, self.val_attn_acc,  # self.val_q_attn_acc, self.val_a_attn_acc,
                            tf.group(self.val_acc_update, self.val_attn_acc_update
                                     # self.val_q_attn_acc_update, self.val_a_attn_acc_update
                                     ),
                            self.val_summaries_op, self.val_data.spec, self.val_attn]
Ejemplo n.º 10
0
def main(_):
    """
    Builds a BERT sentence classifier and runs train/eval/test.

    Loads a pre-trained BERT configuration and checkpoint, builds the
    Transformer encoder plus a pooler/classification head, and then runs
    whichever of the train/eval/test phases the FLAGS request within a
    single session.

    NOTE(review): relies on module-level names (`FLAGS`, `config_data`,
    `config_downstream`, `model_utils`, `data_utils`, `tokenization`,
    `tx`) — confirm against the enclosing module.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    tx.utils.maybe_create_dir(FLAGS.output_dir)
    bert_pretrain_dir = 'bert_pretrained_models/%s' % FLAGS.config_bert_pretrain

    # Loads BERT model configuration
    if FLAGS.config_format_bert == "json":
        bert_config = model_utils.transform_bert_to_texar_config(
            os.path.join(bert_pretrain_dir, 'bert_config.json'))
    elif FLAGS.config_format_bert == 'texar':
        bert_config = importlib.import_module(
            'bert_config_lib.config_model_%s' % FLAGS.config_bert_pretrain)
    else:
        raise ValueError('Unknown config_format_bert.')

    # Loads data
    # Task name -> processor class; KeyError on an unknown task.
    processors = {
        "cola": data_utils.ColaProcessor,
        "mnli": data_utils.MnliProcessor,
        "mrpc": data_utils.MrpcProcessor,
        "xnli": data_utils.XnliProcessor,
        'sst': data_utils.SSTProcessor
    }

    processor = processors[FLAGS.task.lower()]()

    num_classes = len(processor.get_labels())
    num_train_data = len(processor.get_train_examples(config_data.data_dir))

    tokenizer = tokenization.FullTokenizer(vocab_file=os.path.join(
        bert_pretrain_dir, 'vocab.txt'),
                                           do_lower_case=FLAGS.do_lower_case)

    train_dataset = data_utils.get_dataset(processor,
                                           tokenizer,
                                           config_data.data_dir,
                                           config_data.max_seq_length,
                                           config_data.train_batch_size,
                                           mode='train',
                                           output_dir=FLAGS.output_dir)
    eval_dataset = data_utils.get_dataset(processor,
                                          tokenizer,
                                          config_data.data_dir,
                                          config_data.max_seq_length,
                                          config_data.eval_batch_size,
                                          mode='eval',
                                          output_dir=FLAGS.output_dir)
    test_dataset = data_utils.get_dataset(processor,
                                          tokenizer,
                                          config_data.data_dir,
                                          config_data.max_seq_length,
                                          config_data.test_batch_size,
                                          mode='test',
                                          output_dir=FLAGS.output_dir)

    # One feedable iterator switches between the three splits via handles.
    iterator = tx.data.FeedableDataIterator({
        'train': train_dataset,
        'eval': eval_dataset,
        'test': test_dataset
    })
    batch = iterator.get_next()
    input_ids = batch["input_ids"]
    segment_ids = batch["segment_ids"]
    batch_size = tf.shape(input_ids)[0]
    # Sequence length = count of non-zero token ids (id 0 treated as padding).
    input_length = tf.reduce_sum(1 - tf.to_int32(tf.equal(input_ids, 0)),
                                 axis=1)

    # Builds BERT
    with tf.variable_scope('bert'):
        embedder = tx.modules.WordEmbedder(vocab_size=bert_config.vocab_size,
                                           hparams=bert_config.embed)
        word_embeds = embedder(input_ids)

        # Creates segment embeddings for each type of tokens.
        segment_embedder = tx.modules.WordEmbedder(
            vocab_size=bert_config.type_vocab_size,
            hparams=bert_config.segment_embed)
        segment_embeds = segment_embedder(segment_ids)

        input_embeds = word_embeds + segment_embeds

        # The BERT model (a TransformerEncoder)
        encoder = tx.modules.TransformerEncoder(hparams=bert_config.encoder)
        output = encoder(input_embeds, input_length)

        # Builds layers for downstream classification, which is also initialized
        # with BERT pre-trained checkpoint.
        with tf.variable_scope("pooler"):
            # Uses the projection of the 1st-step hidden vector of BERT output
            # as the representation of the sentence
            bert_sent_hidden = tf.squeeze(output[:, 0:1, :], axis=1)
            bert_sent_output = tf.layers.dense(bert_sent_hidden,
                                               config_downstream.hidden_dim,
                                               activation=tf.tanh)
            # Dropout is only active in TRAIN global mode.
            output = tf.layers.dropout(bert_sent_output,
                                       rate=0.1,
                                       training=tx.global_mode_train())

    # Adds the final classification layer
    logits = tf.layers.dense(
        output,
        num_classes,
        kernel_initializer=tf.truncated_normal_initializer(stddev=0.02))
    preds = tf.argmax(logits, axis=-1, output_type=tf.int32)
    accu = tx.evals.accuracy(batch['label_ids'], preds)

    # Optimization

    loss = tf.losses.sparse_softmax_cross_entropy(labels=batch["label_ids"],
                                                  logits=logits)
    global_step = tf.Variable(0, trainable=False)

    # Builds learning rate decay scheduler
    static_lr = config_downstream.lr['static_lr']
    num_train_steps = int(num_train_data / config_data.train_batch_size *
                          config_data.max_train_epoch)
    num_warmup_steps = int(num_train_steps * config_data.warmup_proportion)
    lr = model_utils.get_lr(
        global_step,
        num_train_steps,  # lr is a Tensor
        num_warmup_steps,
        static_lr)

    train_op = tx.core.get_train_op(loss,
                                    global_step=global_step,
                                    learning_rate=lr,
                                    hparams=config_downstream.opt)

    # Train/eval/test routine

    def _run(sess, mode):
        """Run one pass over the split named by `mode` ('train'/'eval'/'test')."""
        fetches = {
            'accu': accu,
            'batch_size': batch_size,
            'step': global_step,
            'loss': loss,
        }

        if mode == 'train':
            fetches['train_op'] = train_op
            # Train until num_train_steps or the dataset is exhausted.
            while True:
                try:
                    feed_dict = {
                        iterator.handle: iterator.get_handle(sess, 'train'),
                        tx.global_mode(): tf.estimator.ModeKeys.TRAIN,
                    }
                    rets = sess.run(fetches, feed_dict)
                    if rets['step'] % 50 == 0:
                        tf.logging.info('step:%d loss:%f' %
                                        (rets['step'], rets['loss']))
                    if rets['step'] == num_train_steps:
                        break
                except tf.errors.OutOfRangeError:
                    break

        if mode == 'eval':
            # Accumulate sample-weighted accuracy over the whole eval split.
            cum_acc = 0.0
            nsamples = 0
            while True:
                try:
                    feed_dict = {
                        iterator.handle: iterator.get_handle(sess, 'eval'),
                        tx.context.global_mode(): tf.estimator.ModeKeys.EVAL,
                    }
                    rets = sess.run(fetches, feed_dict)

                    cum_acc += rets['accu'] * rets['batch_size']
                    nsamples += rets['batch_size']
                except tf.errors.OutOfRangeError:
                    break

            # NOTE(review): raises ZeroDivisionError if the eval split is
            # empty — assumes at least one eval batch.
            tf.logging.info('dev accu: {}'.format(cum_acc / nsamples))

        if mode == 'test':
            # Collect predictions for the whole test split, then dump them
            # one-per-line to test_results.tsv.
            _all_preds = []
            while True:
                try:
                    feed_dict = {
                        iterator.handle: iterator.get_handle(sess, 'test'),
                        tx.context.global_mode():
                        tf.estimator.ModeKeys.PREDICT,
                    }
                    _preds = sess.run(preds, feed_dict=feed_dict)
                    _all_preds.extend(_preds.tolist())
                except tf.errors.OutOfRangeError:
                    break

            output_file = os.path.join(FLAGS.output_dir, "test_results.tsv")
            with tf.gfile.GFile(output_file, "w") as writer:
                writer.write('\n'.join(str(p) for p in _all_preds))

    with tf.Session() as sess:
        # Loads pretrained BERT model parameters
        init_checkpoint = os.path.join(bert_pretrain_dir, 'bert_model.ckpt')
        model_utils.init_bert_checkpoint(init_checkpoint)

        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        sess.run(tf.tables_initializer())

        # Restores trained model if specified
        saver = tf.train.Saver()
        if FLAGS.checkpoint:
            saver.restore(sess, FLAGS.checkpoint)

        iterator.initialize_dataset(sess)

        if FLAGS.do_train:
            iterator.restart_dataset(sess, 'train')
            _run(sess, mode='train')
            saver.save(sess, FLAGS.output_dir + '/model.ckpt')

        if FLAGS.do_eval:
            iterator.restart_dataset(sess, 'eval')
            _run(sess, mode='eval')

        if FLAGS.do_test:
            iterator.restart_dataset(sess, 'test')
            _run(sess, mode='test')
Ejemplo n.º 11
0
def main(_):
    """
    Builds a BERT sentence classifier and runs train/eval/test, with
    optional Horovod data-parallel training.

    When FLAGS.distributed is set, the training dataset is sharded across
    Horovod ranks, the optimizer is wrapped in hvd.DistributedOptimizer,
    and rank-0 variables are broadcast to all workers before training.

    NOTE(review): relies on module-level names (`FLAGS`, `config_data`,
    `config_downstream`, `model_utils`, `tx`) — confirm against the
    enclosing module.
    """

    if FLAGS.distributed:
        import horovod.tensorflow as hvd
        hvd.init()

    tf.logging.set_verbosity(tf.logging.INFO)

    tx.utils.maybe_create_dir(FLAGS.output_dir)

    bert_pretrain_dir = ('bert_pretrained_models'
                         '/%s') % FLAGS.config_bert_pretrain
    # Loads BERT model configuration
    if FLAGS.config_format_bert == "json":
        bert_config = model_utils.transform_bert_to_texar_config(
            os.path.join(bert_pretrain_dir, 'bert_config.json'))
    elif FLAGS.config_format_bert == 'texar':
        bert_config = importlib.import_module(
            ('bert_config_lib.'
             'config_model_%s') % FLAGS.config_bert_pretrain)
    else:
        raise ValueError('Unknown config_format_bert.')

    # Loads data

    num_classes = config_data.num_classes
    num_train_data = config_data.num_train_data

    # Configures distribued mode
    # Shard the training data across ranks and split the per-step batch so
    # the effective global batch size stays the same.
    if FLAGS.distributed:
        config_data.train_hparam["dataset"]["num_shards"] = hvd.size()
        config_data.train_hparam["dataset"]["shard_id"] = hvd.rank()
        config_data.train_hparam["batch_size"] //= hvd.size()

    train_dataset = tx.data.TFRecordData(hparams=config_data.train_hparam)
    eval_dataset = tx.data.TFRecordData(hparams=config_data.eval_hparam)
    test_dataset = tx.data.TFRecordData(hparams=config_data.test_hparam)

    # One feedable iterator switches between the three splits via handles.
    iterator = tx.data.FeedableDataIterator({
        'train': train_dataset,
        'eval': eval_dataset,
        'test': test_dataset
    })
    batch = iterator.get_next()
    input_ids = batch["input_ids"]
    segment_ids = batch["segment_ids"]
    batch_size = tf.shape(input_ids)[0]
    # Sequence length = count of non-zero token ids (id 0 treated as padding).
    input_length = tf.reduce_sum(1 - tf.cast(tf.equal(input_ids, 0), tf.int32),
                                 axis=1)
    # Builds BERT
    with tf.variable_scope('bert'):
        # Word embedding
        embedder = tx.modules.WordEmbedder(vocab_size=bert_config.vocab_size,
                                           hparams=bert_config.embed)
        word_embeds = embedder(input_ids)

        # Segment embedding for each type of tokens
        segment_embedder = tx.modules.WordEmbedder(
            vocab_size=bert_config.type_vocab_size,
            hparams=bert_config.segment_embed)
        segment_embeds = segment_embedder(segment_ids)

        # Position embedding
        position_embedder = tx.modules.PositionEmbedder(
            position_size=bert_config.position_size,
            hparams=bert_config.position_embed)
        # All rows share the padded length, so positions run 0..max_len-1.
        seq_length = tf.ones([batch_size], tf.int32) * tf.shape(input_ids)[1]
        pos_embeds = position_embedder(sequence_length=seq_length)

        # Aggregates embeddings
        input_embeds = word_embeds + segment_embeds + pos_embeds

        # The BERT model (a TransformerEncoder)
        encoder = tx.modules.TransformerEncoder(hparams=bert_config.encoder)
        output = encoder(input_embeds, input_length)

        # Builds layers for downstream classification, which is also
        # initialized with BERT pre-trained checkpoint.
        with tf.variable_scope("pooler"):
            # Uses the projection of the 1st-step hidden vector of BERT output
            # as the representation of the sentence
            bert_sent_hidden = tf.squeeze(output[:, 0:1, :], axis=1)
            bert_sent_output = tf.layers.dense(bert_sent_hidden,
                                               config_downstream.hidden_dim,
                                               activation=tf.tanh)
            # Dropout is only active in TRAIN global mode.
            output = tf.layers.dropout(bert_sent_output,
                                       rate=0.1,
                                       training=tx.global_mode_train())

    # Adds the final classification layer
    logits = tf.layers.dense(
        output,
        num_classes,
        kernel_initializer=tf.truncated_normal_initializer(stddev=0.02))
    preds = tf.argmax(logits, axis=-1, output_type=tf.int32)
    accu = tx.evals.accuracy(batch['label_ids'], preds)

    # Optimization

    loss = tf.losses.sparse_softmax_cross_entropy(labels=batch["label_ids"],
                                                  logits=logits)
    global_step = tf.Variable(0, trainable=False)

    # Builds learning rate decay scheduler
    static_lr = config_downstream.lr['static_lr']
    num_train_steps = int(num_train_data / config_data.train_batch_size *
                          config_data.max_train_epoch)
    num_warmup_steps = int(num_train_steps * config_data.warmup_proportion)
    lr = model_utils.get_lr(
        global_step,
        num_train_steps,  # lr is a Tensor
        num_warmup_steps,
        static_lr)

    opt = tx.core.get_optimizer(global_step=global_step,
                                learning_rate=lr,
                                hparams=config_downstream.opt)

    # Wrap the optimizer so gradients are all-reduced across ranks.
    if FLAGS.distributed:
        opt = hvd.DistributedOptimizer(opt)

    # learning_rate=None: the rate is already baked into `opt` above.
    train_op = tf.contrib.layers.optimize_loss(loss=loss,
                                               global_step=global_step,
                                               learning_rate=None,
                                               optimizer=opt)

    # Train/eval/test routine

    def _is_head():
        """Return True on the single process that should log/eval/save."""
        if not FLAGS.distributed:
            return True
        else:
            return hvd.rank() == 0

    def _train_epoch(sess):
        """Trains on the training set, and evaluates on the dev set
        periodically.
        """
        iterator.restart_dataset(sess, 'train')

        fetches = {
            'train_op': train_op,
            'loss': loss,
            'batch_size': batch_size,
            'step': global_step
        }

        while True:
            try:
                feed_dict = {
                    iterator.handle: iterator.get_handle(sess, 'train'),
                    tx.global_mode(): tf.estimator.ModeKeys.TRAIN,
                }
                rets = sess.run(fetches, feed_dict)
                step = rets['step']

                # Only the head rank logs and runs periodic evaluation.
                dis_steps = config_data.display_steps
                if _is_head() and dis_steps > 0 and step % dis_steps == 0:
                    tf.logging.info('step:%d; loss:%f' % (step, rets['loss']))

                eval_steps = config_data.eval_steps
                if _is_head() and eval_steps > 0 and step % eval_steps == 0:
                    _eval_epoch(sess)

            except tf.errors.OutOfRangeError:
                break

    def _eval_epoch(sess):
        """Evaluates on the dev set.
        """
        iterator.restart_dataset(sess, 'eval')

        # Sample-weighted running totals over the whole dev split.
        cum_acc = 0.0
        cum_loss = 0.0
        nsamples = 0
        fetches = {
            'accu': accu,
            'loss': loss,
            'batch_size': batch_size,
        }
        while True:
            try:
                feed_dict = {
                    iterator.handle: iterator.get_handle(sess, 'eval'),
                    tx.context.global_mode(): tf.estimator.ModeKeys.EVAL,
                }
                rets = sess.run(fetches, feed_dict)

                cum_acc += rets['accu'] * rets['batch_size']
                cum_loss += rets['loss'] * rets['batch_size']
                nsamples += rets['batch_size']
            except tf.errors.OutOfRangeError:
                break

        # NOTE(review): raises ZeroDivisionError if the eval split is
        # empty — assumes at least one eval batch.
        tf.logging.info('eval accu: {}; loss: {}; nsamples: {}'.format(
            cum_acc / nsamples, cum_loss / nsamples, nsamples))

    def _test_epoch(sess):
        """Does predictions on the test set.
        """
        iterator.restart_dataset(sess, 'test')

        # Collect predictions, then dump one per line to test_results.tsv.
        _all_preds = []
        while True:
            try:
                feed_dict = {
                    iterator.handle: iterator.get_handle(sess, 'test'),
                    tx.context.global_mode(): tf.estimator.ModeKeys.PREDICT,
                }
                _preds = sess.run(preds, feed_dict=feed_dict)
                _all_preds.extend(_preds.tolist())
            except tf.errors.OutOfRangeError:
                break

        output_file = os.path.join(FLAGS.output_dir, "test_results.tsv")
        with tf.gfile.GFile(output_file, "w") as writer:
            writer.write('\n'.join(str(p) for p in _all_preds))

    # Loads pretrained BERT model parameters
    init_checkpoint = os.path.join(bert_pretrain_dir, 'bert_model.ckpt')
    model_utils.init_bert_checkpoint(init_checkpoint)

    # Broadcasts global variables from rank-0 process
    if FLAGS.distributed:
        bcast = hvd.broadcast_global_variables(0)

    # Pin each distributed worker to its local GPU.
    session_config = tf.ConfigProto()
    if FLAGS.distributed:
        session_config.gpu_options.visible_device_list = str(hvd.local_rank())

    with tf.Session(config=session_config) as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        sess.run(tf.tables_initializer())

        # Broadcast must run after initialization so all ranks start from
        # rank-0's (BERT-initialized) weights.
        if FLAGS.distributed:
            bcast.run()

        # Restores trained model if specified
        saver = tf.train.Saver()
        if FLAGS.checkpoint:
            saver.restore(sess, FLAGS.checkpoint)

        iterator.initialize_dataset(sess)

        if FLAGS.do_train:
            for i in range(config_data.max_train_epoch):
                _train_epoch(sess)
            saver.save(sess, FLAGS.output_dir + '/model.ckpt')

        if FLAGS.do_eval:
            _eval_epoch(sess)

        if FLAGS.do_test:
            _test_epoch(sess)