Пример #1
0
    def wrapper(**kwargs):
        r""" Runs the wrapped training function `func` inside an epoch / batch loop.

        Handles checkpoint resume and periodic saving, periodic logging
        (console or summary writer), and loss-based learning rate decay
        with optional early stopping.

        Args:
          **kwargs:
            lr: Learning rate assigned to `_learning_rate` when training starts
              from epoch 1 or when `lr_reset` is set. Default is 0.001.
            save_dir: A string. Root directory; summaries are written to
              `<save_dir>/log` and checkpoints to `<save_dir>/ckpt`.
              Default is `asset/train`.
            max_ep: Last epoch number to run (inclusive). Default is 1000.
            ep_size: Number of `func` calls per epoch. Default is 100000.
            save_interval: Seconds between periodic checkpoint saves. Default is 600.
            log_interval: Seconds between logging passes. Default is 60.
            early_stop: Boolean. If True (default), at every logging pass where
              the smoothed loss is not below 0.95 * its value at the previous
              logging pass, the learning rate is halved; if the learning rate
              is already below 5e-6 at that point, training stops instead.
            lr_reset: Boolean. If True, `_learning_rate` is set to `lr` even when
              resuming from a checkpoint. Default is False.
            eval_metric: A list of tensors registered through `tf.sg_summary_metric`
              and run at every logging pass. Default is [].
            max_keep: Passed to the saver as `max_to_keep`. Default is 5.
            keep_interval: Passed to the saver as `keep_checkpoint_every_n_hours`.
              Default is 1.
            tqdm: Boolean. If True (default), the batch loop is wrapped in a
              progress bar.
            console_log: Boolean. If True, each logging pass prints to the console
              instead of writing summaries. Default is False.
            sess: Optional session. When given it is used as-is (no `sg_init`
              call) and is not closed on exit.
        """
        opt = tf.sg_opt(kwargs)

        # default training options
        opt += tf.sg_opt(lr=0.001,
                         save_dir='asset/train',
                         max_ep=1000,
                         ep_size=100000,
                         save_interval=600,
                         log_interval=60,
                         early_stop=True,
                         lr_reset=False,
                         eval_metric=[],
                         max_keep=5,
                         keep_interval=1,
                         tqdm=True,
                         console_log=False)

        # make directory if not exist
        if not os.path.exists(opt.save_dir + '/log'):
            os.makedirs(opt.save_dir + '/log')
        if not os.path.exists(opt.save_dir + '/ckpt'):
            os.makedirs(opt.save_dir + '/ckpt')

        # find last checkpoint
        last_file = tf.train.latest_checkpoint(opt.save_dir + '/ckpt')
        if last_file:
            # Checkpoints are written below as 'model-<epoch>-<global_step>'.
            # Parse only the file name so that a '-' anywhere in save_dir
            # cannot shift the field positions.
            name_parts = os.path.basename(last_file).split('-')
            ep = start_ep = int(name_parts[1]) + 1
            start_step = int(name_parts[2])
        else:
            ep = start_ep = 1
            start_step = 0

        # checkpoint saver
        saver = tf.train.Saver(max_to_keep=opt.max_keep,
                               keep_checkpoint_every_n_hours=opt.keep_interval)

        # summary writer
        summary_writer = tf.train.SummaryWriter(opt.save_dir + '/log',
                                                graph=tf.get_default_graph())

        # add learning rate summary
        with tf.name_scope('summary'):
            tf.scalar_summary('60. learning_rate/learning_rate',
                              _learning_rate)

        # add evaluation metric summary
        for m in opt.eval_metric:
            tf.sg_summary_metric(m)

        # summary op
        summary_op = tf.merge_all_summaries()

        # create session
        if opt.sess:
            sess = opt.sess
        else:
            # session with multiple GPU support
            sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
            # initialize variables
            sg_init(sess)

        # restore last checkpoint
        if last_file:
            saver.restore(sess, last_file)

        # set learning rate (a resumed run keeps the restored value unless lr_reset)
        if start_ep == 1 or opt.lr_reset:
            sess.run(_learning_rate.assign(opt.lr))

        # logging
        tf.sg_info('Training started from epoch[%03d]-step[%d].' %
                   (start_ep, start_step))

        try:
            # start data queue runner
            with tf.sg_queue_context(sess):

                # set session mode to train
                tf.sg_set_train(sess)

                # loss history for learning rate decay
                loss, loss_prev, early_stopped = None, None, False

                # time stamp for saving and logging
                last_saved = last_logged = time.time()

                # epoch loop
                for ep in range(start_ep, opt.max_ep + 1):

                    # show progressbar
                    if opt.tqdm:
                        iterator = tqdm(range(opt.ep_size),
                                        desc='train',
                                        ncols=70,
                                        unit='b',
                                        leave=False)
                    else:
                        iterator = range(opt.ep_size)

                    # batch loop
                    for _ in iterator:

                        # call train function
                        batch_loss = func(sess, opt)

                        # loss history update (exponential moving average)
                        if batch_loss is not None:
                            if loss is None:
                                loss = np.mean(batch_loss)
                            else:
                                loss = loss * 0.9 + np.mean(batch_loss) * 0.1

                        # saving
                        if time.time() - last_saved > opt.save_interval:
                            last_saved = time.time()
                            saver.save(sess,
                                       opt.save_dir + '/ckpt/model-%03d' % ep,
                                       write_meta_graph=False,
                                       global_step=sess.run(
                                           tf.sg_global_step()))

                        # logging
                        if time.time() - last_logged > opt.log_interval:
                            last_logged = time.time()

                            # set session mode to infer
                            tf.sg_set_infer(sess)

                            # run evaluation op
                            if len(opt.eval_metric) > 0:
                                sess.run(opt.eval_metric)

                            if opt.console_log:  # console logging
                                # log epoch information
                                tf.sg_info(
                                    '\tEpoch[%03d:lr=%7.5f:gs=%d] - loss = %s'
                                    % (ep, sess.run(_learning_rate),
                                       sess.run(tf.sg_global_step()),
                                       ('NA' if loss is None else '%8.6f' %
                                        loss)))
                            else:  # tensorboard logging
                                # run logging op
                                summary_writer.add_summary(
                                    sess.run(summary_op),
                                    global_step=sess.run(tf.sg_global_step()))

                            # learning rate decay
                            if opt.early_stop and loss_prev:
                                # if loss stalling (less than 5% improvement
                                # since the previous logging pass)
                                if loss >= 0.95 * loss_prev:
                                    # early stopping
                                    current_lr = sess.run(_learning_rate)
                                    if current_lr < 5e-6:
                                        early_stopped = True
                                        break
                                    else:
                                        # decrease learning rate by half
                                        sess.run(
                                            _learning_rate.assign(current_lr /
                                                                  2.))

                            # update loss history
                            loss_prev = loss

                            # revert session mode to train
                            tf.sg_set_train(sess)

                    # log epoch information
                    if not opt.console_log:
                        tf.sg_info(
                            '\tEpoch[%03d:lr=%7.5f:gs=%d] - loss = %s' %
                            (ep, sess.run(_learning_rate),
                             sess.run(tf.sg_global_step()),
                             ('NA' if loss is None else '%8.6f' % loss)))

                    if early_stopped:
                        tf.sg_info('\tEarly stopped ( no loss progress ).')
                        break
        finally:
            # save last epoch (runs on normal exit, early stop and exceptions)
            saver.save(sess,
                       opt.save_dir + '/ckpt/model-%03d' % ep,
                       write_meta_graph=False,
                       global_step=sess.run(tf.sg_global_step()))

            # set session mode to infer
            tf.sg_set_infer(sess)

            # logging
            tf.sg_info('Training finished at epoch[%d]-step[%d].' %
                       (ep, sess.run(tf.sg_global_step())))

            # close session only if it was created here
            if opt.sess is None:
                sess.close()
Пример #2
0
    def wrapper(**kwargs):
        r""" Runs the wrapped training function `func` under a `tf.train.Supervisor`.

        After every epoch a validation pass is run through `show_metrics`;
        whenever its returned score exceeds the best seen so far, the model is
        saved to `opt.save_dir + max_model_name` and the test metrics are shown.
        `show_metrics` and `max_model_name` are taken from the enclosing scope.

        Args:
          **kwargs:
            lr: A Python Scalar (optional). Learning rate. Default is .001.

            save_dir: A string. The root path to which checkpoint and log files are saved.
              Default is `asset/train`.
            max_ep: A positive integer. Maximum number of epochs. Default is 1000.
            ep_size: A positive integer. Number of Total batches in an epoch.
              For proper display of log. Default is 1e5.

            save_interval: A Python scalar. The interval of saving checkpoint files.
              By default, for every 600 seconds, a checkpoint file is written.
            log_interval: A Python scalar. The interval of recording logs.
              By default, for every 60 seconds, logging is executed.
            max_keep: A positive integer. Maximum number of recent checkpoints to keep. Default is 5.
            keep_interval: A Python scalar. How often to keep checkpoints. Default is 1 hour.

            eval_metric: A list of tensors containing the value to evaluate. Default is [].
              Entries [2] and [3] are passed to `show_metrics` for the validation
              pass, so at least four entries are needed once an epoch completes.
            val_ep_size: Passed to `show_metrics` for the validation pass. No default.
            test_metric: Entries [0] and [1] are passed to `show_metrics` for the
              test pass. No default.
            test_ep_size: Passed to `show_metrics` for the test pass. No default.

            tqdm: Boolean. If True (Default), progress bars are shown. If False, a series of loss
                will be shown on the console.
        """
        opt = tf.sg_opt(kwargs)

        # default training options
        opt += tf.sg_opt(lr=0.001,
                         save_dir='asset/train',
                         max_ep=1000, ep_size=100000,
                         save_interval=600, log_interval=60,
                         eval_metric=[],
                         max_keep=5, keep_interval=1,
                         tqdm=True)

        # training epoch and loss (read by the console_log closure below)
        epoch, loss = -1, None

        # checkpoint saver
        saver = tf.train.Saver(max_to_keep=opt.max_keep,
                               keep_checkpoint_every_n_hours=opt.keep_interval)

        # add evaluation summary
        for m in opt.eval_metric:
            tf.sg_summary_metric(m)

        # summary writer
        # NOTE(review): assumes `tf.time` is the stdlib `time` module, in which
        # case fields [1:5] are month, day, hour and minute - confirm.
        log_dir = opt.save_dir + '/run-%02d%02d-%02d%02d' % tuple(tf.time.localtime(tf.time.time()))[1:5]
        summary_writer = tf.summary.FileWriter(log_dir)

        # console logging function
        def console_log(sess_):
            if epoch >= 0:
                tf.sg_info('\tEpoch[%03d:gs=%d] - loss = %s' %
                           (epoch, sess_.run(tf.sg_global_step()),
                            ('NA' if loss is None else '%8.6f' % loss)))

        local_init_op = tf.group(tf.sg_phase().assign(True), tf.tables_initializer(), tf.local_variables_initializer())

        # create supervisor
        sv = tf.train.Supervisor(logdir=opt.save_dir,
                                 saver=saver,
                                 save_model_secs=opt.save_interval,
                                 summary_writer=summary_writer,
                                 save_summaries_secs=opt.log_interval,
                                 global_step=tf.sg_global_step(),
                                 local_init_op=local_init_op)

        # create session
        with sv.managed_session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:

            # console logging loop
            if not opt.tqdm:
                sv.loop(opt.log_interval, console_log, args=(sess,))

            # get start epoch
            _step = sess.run(tf.sg_global_step())
            ep = _step // opt.ep_size

            # best validation score returned by show_metrics so far
            best_f1 = 0
            # check if already finished
            if ep <= opt.max_ep:

                # logging
                tf.sg_info('Training started from epoch[%03d]-step[%d].' % (ep, _step))

                # epoch loop
                for ep in range(ep, opt.max_ep + 1):

                    # update epoch info
                    start_step = sess.run(tf.sg_global_step()) % opt.ep_size
                    epoch = ep

                    # create progressbar iterator
                    if opt.tqdm:
                        iterator = tf.tqdm(range(start_step, opt.ep_size), total=opt.ep_size, initial=start_step,
                                           desc='train', ncols=70, unit='b', leave=False)
                    else:
                        iterator = range(start_step, opt.ep_size)

                    # batch loop
                    for _ in iterator:

                        # exit loop
                        if sv.should_stop():
                            break

                        # call train function
                        batch_loss = func(sess, opt)

                        # loss history update; a batch containing any NaN / Inf
                        # value is skipped so it cannot poison the running
                        # average (the check is element-wise, not on `.all()`)
                        if batch_loss is not None and np.all(np.isfinite(batch_loss)):
                            if loss is None:
                                loss = np.mean(batch_loss)
                            else:
                                loss = loss * 0.9 + np.mean(batch_loss) * 0.1

                    # log epoch information
                    console_log(sess)

                    # a stop request only broke the batch loop above; leave the
                    # epoch loop too instead of running validation for every
                    # remaining epoch
                    if sv.should_stop():
                        break

                    f1_stat = show_metrics(sv, sess, opt.eval_metric[2], opt.eval_metric[3], ep, opt.val_ep_size,
                                              'val', use_tqdm=True)

                    if f1_stat > best_f1:
                        best_f1 = f1_stat

                        max_model_file = opt.save_dir + max_model_name

                        # save best version
                        saver.save(sess, max_model_file)
                        print("Improved F1 score, max model saved in file: %s" % max_model_file)

                        print('Test metrics:')
                        show_metrics(sv, sess, opt.test_metric[0], opt.test_metric[1], ep, opt.test_ep_size,
                                        'test', use_tqdm=True)

                # save last version
                saver.save(sess, opt.save_dir + '/model.ckpt', global_step=sess.run(tf.sg_global_step()))

                # logging
                tf.sg_info('Training finished at epoch[%d]-step[%d].' % (ep, sess.run(tf.sg_global_step())))
            else:
                tf.sg_info('Training already finished at epoch[%d]-step[%d].' %
                           (ep - 1, sess.run(tf.sg_global_step())))
Пример #3
0
    def wrapper(**kwargs):
        r""" Runs the wrapped training function `func` inside an epoch / batch loop.

        Handles checkpoint resume and periodic saving, periodic logging
        (console or summary writer), and loss-based learning rate decay
        with optional early stopping.

        Args:
          **kwargs:
            lr: A Python Scalar (optional). Learning rate. Default is .001.

            eval_metric: A list of tensors containing the value to evaluate. Default is [].
            early_stop: Boolean. If True (default), the loss is checked at every logging pass.
              When the current loss is not less than .95 * the loss at the previous logging pass:
              i. if the current learning rate is less than 5e-6, training stops;
              ii. otherwise the learning rate is halved and training continues.
            lr_reset: Boolean. If True, learning rate is set to opt.lr. when training restarts.
              Otherwise (Default), the value of the stored `_learning_rate` is taken.
            save_dir: A string. The root path to which checkpoint and log files are saved.
              Default is `asset/train`.
            max_ep: A positive integer. Maximum number of epochs. Default is 1000.
            ep_size: A positive integer. Number of Total batches in an epoch.
              For proper display of log. Default is 1e5.

            save_interval: A Python scalar. The interval of saving checkpoint files.
              By default, for every 600 seconds, a checkpoint file is written.
            log_interval: A Python scalar. The interval of recording logs.
              By default, for every 60 seconds, logging is executed.
            max_keep: A positive integer. Maximum number of recent checkpoints to keep. Default is 5.
            keep_interval: A Python scalar. How often to keep checkpoints. Default is 1 hour.

            tqdm: Boolean. If True (Default), progress bars are shown.
            console_log: Boolean. If True, a series of loss will be shown
              on the console instead of tensorboard. Default is False.
            sess: Optional session. When given it is used instead of creating
              a new one and is not closed on exit.
        """
        opt = tf.sg_opt(kwargs)

        # default training options
        opt += tf.sg_opt(lr=0.001,
                         save_dir='asset/train',
                         max_ep=1000, ep_size=100000,
                         save_interval=600, log_interval=60,
                         early_stop=True, lr_reset=False,
                         eval_metric=[],
                         max_keep=5, keep_interval=1,
                         tqdm=True, console_log=False)

        # make directory if not exist
        if not os.path.exists(opt.save_dir):
            os.makedirs(opt.save_dir)

        # find last checkpoint
        last_file = tf.train.latest_checkpoint(opt.save_dir)
        if last_file:
            # Checkpoints are written below as 'model-<epoch>-<global_step>'.
            # Parse only the file name so that a '-' anywhere in save_dir
            # cannot shift the field positions.
            name_parts = os.path.basename(last_file).split('-')
            ep = start_ep = int(name_parts[1]) + 1
            start_step = int(name_parts[2])
        else:
            ep = start_ep = 1
            start_step = 0

        # checkpoint saver
        saver = tf.train.Saver(max_to_keep=opt.max_keep,
                               keep_checkpoint_every_n_hours=opt.keep_interval)

        # summary writer
        summary_writer = tf.summary.FileWriter(opt.save_dir, graph=tf.get_default_graph())

        # add learning rate summary
        tf.summary.scalar('learning_r', _learning_rate)

        # add evaluation metric summary
        for m in opt.eval_metric:
            tf.sg_summary_metric(m)

        # summary op
        summary_op = tf.summary.merge_all()

        # create session
        if opt.sess:
            sess = opt.sess
        else:
            # session with multiple GPU support
            sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
        # initialize variables
        # NOTE(review): sg_init is also called on a caller-supplied session
        # here - confirm that this is intended.
        sg_init(sess)

        # restore last checkpoint
        if last_file:
            saver.restore(sess, last_file)

        # set learning rate (a resumed run keeps the restored value unless lr_reset)
        if start_ep == 1 or opt.lr_reset:
            sess.run(_learning_rate.assign(opt.lr))

        # logging
        tf.sg_info('Training started from epoch[%03d]-step[%d].' % (start_ep, start_step))

        try:
            # start data queue runner
            with tf.sg_queue_context(sess):

                # set session mode to train
                tf.sg_set_train(sess)

                # loss history for learning rate decay
                loss, loss_prev, early_stopped = None, None, False

                # time stamp for saving and logging
                last_saved = last_logged = time.time()

                # epoch loop
                for ep in range(start_ep, opt.max_ep + 1):

                    # show progressbar
                    if opt.tqdm:
                        iterator = tqdm(range(opt.ep_size), desc='train', ncols=70, unit='b', leave=False)
                    else:
                        iterator = range(opt.ep_size)

                    # batch loop
                    for _ in iterator:

                        # call train function
                        batch_loss = func(sess, opt)

                        # loss history update (exponential moving average)
                        if batch_loss is not None:
                            if loss is None:
                                loss = np.mean(batch_loss)
                            else:
                                loss = loss * 0.9 + np.mean(batch_loss) * 0.1

                        # saving
                        if time.time() - last_saved > opt.save_interval:
                            last_saved = time.time()
                            saver.save(sess, opt.save_dir + '/model-%03d' % ep,
                                       write_meta_graph=False,
                                       global_step=sess.run(tf.sg_global_step()))

                        # logging
                        if time.time() - last_logged > opt.log_interval:
                            last_logged = time.time()

                            # set session mode to infer
                            tf.sg_set_infer(sess)

                            # run evaluation op
                            if len(opt.eval_metric) > 0:
                                sess.run(opt.eval_metric)

                            if opt.console_log:   # console logging
                                # log epoch information
                                tf.sg_info('\tEpoch[%03d:lr=%7.5f:gs=%d] - loss = %s' %
                                           (ep, sess.run(_learning_rate), sess.run(tf.sg_global_step()),
                                            ('NA' if loss is None else '%8.6f' % loss)))
                            else:   # tensorboard logging
                                # run logging op
                                summary_writer.add_summary(sess.run(summary_op),
                                                           global_step=sess.run(tf.sg_global_step()))

                            # learning rate decay
                            if opt.early_stop and loss_prev:
                                # if loss stalling (less than 5% improvement
                                # since the previous logging pass)
                                if loss >= 0.95 * loss_prev:
                                    # early stopping
                                    current_lr = sess.run(_learning_rate)
                                    if current_lr < 5e-6:
                                        early_stopped = True
                                        break
                                    else:
                                        # decrease learning rate by half
                                        sess.run(_learning_rate.assign(current_lr / 2.))

                            # update loss history
                            loss_prev = loss

                            # revert session mode to train
                            tf.sg_set_train(sess)

                    # log epoch information
                    if not opt.console_log:
                        tf.sg_info('\tEpoch[%03d:lr=%7.5f:gs=%d] - loss = %s' %
                                   (ep, sess.run(_learning_rate), sess.run(tf.sg_global_step()),
                                    ('NA' if loss is None else '%8.6f' % loss)))

                    if early_stopped:
                        tf.sg_info('\tEarly stopped ( no loss progress ).')
                        break
        finally:
            # save last epoch (runs on normal exit, early stop and exceptions)
            saver.save(sess, opt.save_dir + '/model-%03d' % ep,
                       write_meta_graph=False,
                       global_step=sess.run(tf.sg_global_step()))

            # set session mode to infer
            tf.sg_set_infer(sess)

            # logging
            tf.sg_info('Training finished at epoch[%d]-step[%d].' % (ep, sess.run(tf.sg_global_step())))

            # close session only if it was created here
            if opt.sess is None:
                sess.close()