def get_model(self, optimizer='momentum', effective_lr=False, negative_momentum=False):
        """Construct the MNIST MLP model using this object's hyperparameters.

        The learning rate, momentum, decay, batch size, input size and dtype are
        read from the instance (`self._lr`, `self._momentum`, `self._decay`,
        `self._batch_size`, `self._input_size`, `self._dtype`).

        Args:
            optimizer: String. Name of the optimizer.
            effective_lr: Bool. Whether to optimize using the effective learning rate.
            negative_momentum: Bool. Whether to optimize using the negative momentum.

        Returns:
            The model object produced by `get_mnist_mlp_model`.
        """
        mlp_config = get_mnist_mlp_config(
            self._lr,
            self._momentum,
            decay=self._decay,
            effective_lr=effective_lr,
            negative_momentum=negative_momentum)
        # Fixed-size placeholders: batch dimension comes from the instance.
        input_shape = [self._batch_size] + self._input_size
        inputs = tf.placeholder(self._dtype, input_shape)
        labels = tf.placeholder(tf.int64, [self._batch_size])
        return get_mnist_mlp_model(
            mlp_config, inputs, labels, optimizer=optimizer, dtype=self._dtype)
def run_offline_smd(num_steps,
                    init_lr,
                    init_decay,
                    meta_lr,
                    num_meta_steps,
                    momentum=MOMENTUM,
                    effective_lr=False,
                    negative_momentum=False,
                    pretrain_ckpt=None,
                    output_fname=None,
                    seed=0):
    """Run offline SMD experiments on MNIST.

    Every meta step runs `meta_step` on a fresh list of `num_steps` training
    batches, then restores the model parameters saved right after
    initialization (and optional checkpoint restore) and assigns the
    hyperparameters returned by `meta_step`. The learning rate and decay
    constant are the hyperparameters being tuned.

    Args:
        num_steps: Int. Number of look-ahead training steps per meta step.
        init_lr: Float. Initial learning rate.
        init_decay: Float. Initial decay constant.
        meta_lr: Float. Meta descent learning rate.
        num_meta_steps: Int. Number of meta descent steps.
        momentum: Float. Momentum.
        effective_lr: Bool. Whether to optimize in the effective LR space.
        negative_momentum: Bool. Whether to optimize in the negative momentum space.
        pretrain_ckpt: String or None. Checkpoint to restore model weights from.
        output_fname: String or None. CSV file for per-step results; it must
            not exist yet. Its directory is also used as the log folder.
        seed: Int. Currently unused; kept for interface compatibility.
    """
    bsize = BATCH_SIZE
    if output_fname is not None:
        # dirname() is '' for a bare file name; use the CWD instead so that
        # os.makedirs is never called with an empty path.
        log_folder = os.path.dirname(output_fname) or '.'
    else:
        log_folder = os.path.join('results', 'mnist', 'offline', 'smd')
        log_folder = os.path.join(log_folder, _get_run_number(log_folder))
    if not os.path.exists(log_folder):
        os.makedirs(log_folder)
    with tf.Graph().as_default(), tf.Session() as sess:
        dataset = get_dataset('mnist')
        exp_logger = _get_exp_logger(sess, log_folder)

        # Convert the initial hyperparameters into the parameterization the
        # optimizer is configured with.
        if effective_lr:
            init_lr_ = init_lr / float(1.0 - momentum)
        else:
            init_lr_ = init_lr

        if negative_momentum:
            init_mom_ = 1.0 - momentum
        else:
            init_mom_ = momentum

        config = get_mnist_mlp_config(
            init_lr_,
            init_mom_,
            decay=init_decay,
            effective_lr=effective_lr,
            negative_momentum=negative_momentum)
        x = tf.placeholder(tf.float32, [None, 28, 28, 1], name="x")
        y = tf.placeholder(tf.int64, [None], name="y")
        with tf.name_scope('Train'):
            with tf.variable_scope('Model'):
                model = get_mnist_mlp_model(
                    config,
                    x,
                    y,
                    optimizer='momentum_inv_decay',
                    training=True)

        # Skip variables whose (lower-cased) names contain any of these
        # substrings when restoring the pretrained checkpoint; judging by the
        # names these are optimizer/step/hyperparameter variables.
        excluded = ('momentum', 'global_step', 'lr', 'mom', 'decay')
        var_to_restore = [
            v for v in tf.global_variables()
            if not any(key in v.name.lower() for key in excluded)
        ]
        saver = tf.train.Saver(var_to_restore)

        hp_dict = {'lr': init_lr, 'decay': init_decay}
        hp_names = list(hp_dict.keys())
        hyperparams = {
            hp_name: model.optimizer.hyperparams[hp_name]
            for hp_name in hp_names
        }
        grads = model.optimizer.grads
        accumulators = model.optimizer.accumulators
        new_accumulators = model.optimizer.new_accumulators
        loss = model.cost

        # Build look ahead graph.
        look_ahead_ops, hp_grad_ops, zero_out_ops = look_ahead_grads(
            hyperparams, grads, accumulators, new_accumulators, loss)

        # Meta optimizer: momentum SGD (0.9) wrapped in LogOptimizer, i.e.
        # applied in log space.
        meta_opt = LogOptimizer(tf.train.MomentumOptimizer(meta_lr, 0.9))
        hp = [model.optimizer.hyperparams[hp_name] for hp_name in hp_names]
        hp_grads_dict = {
            'lr': tf.placeholder(tf.float32, [], name='lr_grad'),
            'decay': tf.placeholder(tf.float32, [], name='decay_grad')
        }
        hp_grads_plh = [hp_grads_dict[hp_name] for hp_name in hp_names]
        hp_grads_and_vars = list(zip(hp_grads_plh, hp))
        # Per-hyperparameter clipping ranges for gradients and values.
        cgrad = {'lr': (-1e1, 1e1), 'decay': (-1e1, 1e1)}
        cval = {'lr': (1e-4, 1e1), 'decay': (1e-4, 1e3)}
        cgrad_ = [cgrad[hp_name] for hp_name in hp_names]
        cval_ = [cval[hp_name] for hp_name in hp_names]
        meta_train_op = meta_opt.apply_gradients(
            hp_grads_and_vars, clip_gradients=cgrad_, clip_values=cval_)

        if output_fname is not None:
            msg = '{} exists, please remove previous experiment data.'.format(
                output_fname)
            assert not os.path.exists(output_fname), msg
            log.info('Writing to {}'.format(output_fname))
            with open(output_fname, 'w') as f:
                f.write('Step,LR,Mom,Decay,Loss\n')

        # Initialize all variables, optionally restore pretrained weights,
        # then snapshot everything so each meta step starts from here.
        sess.run(tf.global_variables_initializer())
        var_list = tf.global_variables()
        if pretrain_ckpt is not None:
            saver.restore(sess, pretrain_ckpt)
        ckpt = build_checkpoint(var_list)
        write_op = write_checkpoint(ckpt, var_list)
        read_op = read_checkpoint(ckpt, var_list)
        sess.run(write_op)

        # Progress bar.
        it = tqdm(
            six.moves.xrange(num_meta_steps),
            ncols=0,
            desc='look_{}_ilr_{:.0e}_decay_{:.0e}'.format(
                num_steps, init_lr, init_decay))

        for run in it:
            # Stochastic data list makes the SMD converge faster.
            data_list = [
                dataset.next_batch(bsize)
                for _ in six.moves.xrange(num_steps)
            ]
            eval_data_list = [
                dataset.next_batch(bsize)
                for _ in six.moves.xrange(NUM_TRAIN // bsize)
            ]
            # Run meta descent step.
            cost, hp_dict = meta_step(sess, model, data_list, look_ahead_ops,
                                      hp_grad_ops, hp_grads_plh, meta_train_op,
                                      eval_data_list)

            # Early stop if hits NaN.
            if np.isnan(cost):
                break

            # Restore parameters, keeping the updated hyperparameters.
            sess.run(read_op)
            for hpname, hpval in hp_dict.items():
                model.optimizer.assign_hyperparam(sess, hpname, hpval)

            # Read out hyperparameters in normal parameterization.
            if negative_momentum:
                mom = 1 - hp_dict['mom']
            else:
                mom = hp_dict['mom']
            if effective_lr:
                lr = hp_dict['lr'] * (1 - mom)
            else:
                lr = hp_dict['lr']

            # Write to logs; LR and Mom are both in normal parameterization.
            if output_fname is not None:
                with open(output_fname, 'a') as f:
                    f.write('{:d},{:f},{:f},{:f},{:f}\n'.format(
                        run, lr, mom, hp_dict['decay'], cost))
            # Log to TensorBoard.
            exp_logger.log(run, 'lr', lr)
            exp_logger.log(run, 'decay', hp_dict['decay'])
            exp_logger.log(run, 'log loss', np.log10(cost))
            exp_logger.flush()

            # Update progress bar.
            it.set_postfix(
                lr='{:.3e}'.format(lr),
                decay='{:.3e}'.format(hp_dict['decay']),
                loss='{:.3e}'.format(cost))

        exp_logger.close()
# Example #3
# 0
def online_smd(dataset_name='mnist',
               init_lr=1e-1,
               momentum=0.001,
               num_steps=20000,
               middle_decay=False,
               steps_per_update=10,
               smd=True,
               steps_look_ahead=5,
               num_meta_steps=10,
               steps_per_eval=100,
               batch_size=100,
               meta_lr=1e-2,
               print_step=False,
               effective_lr=True,
               negative_momentum=True,
               optimizer='momentum',
               stochastic=True,
               exp_folder='.'):
    """Train an MLP (MNIST) or CNN (CIFAR-10) with online SMD on the learning rate.

    Every `steps_per_update` steps during the first half of training (when
    `smd` is set), the learning rate is tuned by `num_meta_steps` meta-descent
    steps. At step `num_steps // 4` the momentum is set to 0.9, and after the
    midpoint the learning rate decays exponentially to 1e-4 by the last step.

    Args:
        dataset_name: String. Name of the dataset, `mnist` or `cifar-10`.
        init_lr: Float. Initial learning rate, default 0.1.
        momentum: Float. Initial momentum, default 0.001.
        num_steps: Int. Total number of steps, default 20000.
        middle_decay: Bool. Currently ignored: the decay to 1e-4 after the
            midpoint is always applied. Kept for interface compatibility.
        steps_per_update: Int. Number of steps per update, default 10.
        smd: Bool. Whether run SMD.
        steps_look_ahead: Int. Number of steps to look ahead, default 5.
        num_meta_steps: Int. Number of meta steps, default 10.
        steps_per_eval: Int. Number of training steps per evaluation, default 100.
        batch_size: Int. Mini-batch size, default 100.
        meta_lr: Float. Meta learning rate, default 1e-2.
        print_step: Bool. Whether to print loss during training instead of
            showing a progress bar, default False.
        effective_lr: Bool. Whether to re-parameterize learning rate as lr / (1 - momentum), default True.
        negative_momentum: Bool. Whether to re-parameterize momentum as (1 - momentum), default True.
        optimizer: String. Name of the optimizer. Options: `momentum`, `adam`, default `momentum`.
        stochastic: Bool. Whether to do stochastic or deterministic look ahead, default True.
        exp_folder: String. Folder passed to the experiment logger, default `.`.

    Returns:
        results: Results tuple object.

    Raises:
        NotImplementedError: If `dataset_name` is not `mnist` or `cifar-10`.
        ValueError: If `optimizer` names neither a momentum nor an Adam variant.
    """
    # Validate arguments before loading any data or building the graph.
    if dataset_name == 'mnist':
        input_shape = [None, 28, 28, 1]
        num_train = 60000
        num_test = 10000
    elif dataset_name == 'cifar-10':
        input_shape = [None, 32, 32, 3]
        num_train = 50000
        num_test = 10000
    else:
        # Only MNIST and CIFAR-10 have config/model builders here.
        raise NotImplementedError(
            'Unsupported dataset: {}'.format(dataset_name))

    # Key of the momentum-like hyperparameter in the optimizer.
    if 'momentum' in optimizer:
        mom_name = 'mom'
    elif 'adam' in optimizer:
        mom_name = 'beta1'
    else:
        raise ValueError('Unknown optimizer')

    dataset = get_dataset(dataset_name)
    dataset_train = get_dataset(
        dataset_name)  # For evaluate training progress (full epoch).
    dataset_test = get_dataset(
        dataset_name, test=True)  # For evaluate test progress (full epoch).

    x = tf.placeholder(tf.float32, input_shape, name="x")
    y = tf.placeholder(tf.int64, [None], name="y")

    # Convert the initial hyperparameters into the optimizer's parameterization.
    if effective_lr:
        init_lr_ = init_lr / (1.0 - momentum)
    else:
        init_lr_ = init_lr

    if negative_momentum:
        init_mom_ = 1.0 - momentum
    else:
        init_mom_ = momentum

    if dataset_name == 'mnist':
        get_config, get_model = get_mnist_mlp_config, get_mnist_mlp_model
    else:
        get_config, get_model = get_cifar_cnn_config, get_cifar_cnn_model
    config = get_config(
        init_lr_,
        init_mom_,
        effective_lr=effective_lr,
        negative_momentum=negative_momentum)
    with tf.name_scope('Train'):
        with tf.variable_scope('Model'):
            model = get_model(config, x, y, optimizer=optimizer, training=True)
    with tf.name_scope('Test'):
        with tf.variable_scope('Model', reuse=True):
            mtest = get_model(config, x, y, training=False)

    final_lr = 1e-4
    midpoint = num_steps // 2
    # Momentum switched to at step midpoint // 2, expressed in the same
    # parameterization as the optimizer's momentum hyperparameter.
    late_mom = 0.9
    if negative_momentum:
        late_mom_ = 1.0 - late_mom
    else:
        late_mom_ = late_mom

    lr_ = init_lr_
    mom_ = init_mom_
    bsize = batch_size
    steps_per_epoch = num_train // bsize
    steps_test_per_epoch = num_test // bsize

    train_xent_list = []
    train_acc_list = []
    test_xent_list = []
    test_acc_list = []
    lr_list = []
    mom_list = []
    step_list = []
    log.info(
        'Applying decay at midpoint with final learning rate = {:.3e}'.format(
            final_lr))

    # Only the learning rate is meta-optimized; momentum is not.
    hp_dict = {'lr': init_lr}
    hp_names = list(hp_dict.keys())
    hyperparams = {
        hp_name: model.optimizer.hyperparams[hp_name]
        for hp_name in hp_names
    }
    grads = model.optimizer.grads
    accumulators = model.optimizer.accumulators
    new_accumulators = model.optimizer.new_accumulators
    loss = model.cost

    # Build look ahead graph.
    look_ahead_ops, hp_grad_ops, zero_out_ops = look_ahead_grads(
        hyperparams, grads, accumulators, new_accumulators, loss)

    # Meta optimizer, use Adam on the log space.
    meta_opt = LogOptimizer(tf.train.AdamOptimizer(meta_lr))
    hp = [model.optimizer.hyperparams[hp_name] for hp_name in hp_names]
    hp_grads_dict = {
        'lr': tf.placeholder(tf.float32, [], name='lr_grad'),
    }
    hp_grads_plh = [hp_grads_dict[hp_name] for hp_name in hp_names]
    hp_grads_and_vars = list(zip(hp_grads_plh, hp))
    # Clipping ranges for the hyper-gradient and the hyperparameter value.
    cgrad = {'lr': (-1e1, 1e1)}
    cval = {'lr': (1e-4, 1e1)}
    cgrad_ = [cgrad[hp_name] for hp_name in hp_names]
    cval_ = [cval[hp_name] for hp_name in hp_names]
    meta_train_op = meta_opt.apply_gradients(
        hp_grads_and_vars, clip_gradients=cgrad_, clip_values=cval_)

    var_list = tf.global_variables()
    ckpt = build_checkpoint(var_list)
    write_op = write_checkpoint(ckpt, var_list)
    read_op = read_checkpoint(ckpt, var_list)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        exp_logger = _get_exp_logger(sess, exp_folder)

        def log_hp(step, hp_values):
            """Log hyperparameters at `step`; return (lr, momentum) in normal parameterization."""
            stored_lr = hp_values['lr']
            stored_mom = hp_values[mom_name]
            # Log current learning rate and momentum.
            if negative_momentum:
                exp_logger.log(step, 'mom', 1.0 - stored_mom)
                exp_logger.log(step, 'log neg mom', np.log10(stored_mom))
                real_mom = 1.0 - stored_mom
            else:
                exp_logger.log(step, 'mom', stored_mom)
                exp_logger.log(step, 'log neg mom', np.log10(1.0 - stored_mom))
                real_mom = stored_mom

            if effective_lr:
                real_lr = stored_lr * (1.0 - real_mom)
                eff_lr = stored_lr
            else:
                real_lr = stored_lr
                eff_lr = stored_lr / (1.0 - real_mom)
            exp_logger.log(step, 'eff lr', eff_lr)
            exp_logger.log(step, 'log eff lr', np.log10(eff_lr))
            exp_logger.log(step, 'lr', real_lr)
            exp_logger.log(step, 'log lr', np.log10(real_lr))
            exp_logger.flush()
            return real_lr, real_mom

        def full_epoch_eval(net, data, num_batches):
            """Average `net`'s cost and accuracy over `num_batches` batches of `data`, then reset it."""
            xent = 0.0
            acc = 0.0
            for _ in six.moves.xrange(num_batches):
                xd, yd = data.next_batch(bsize)
                xent_, acc_ = sess.run(
                    [net.cost, net.acc], feed_dict={
                        x: xd,
                        y: yd
                    })
                xent += xent_ / float(num_batches)
                acc += acc_ / float(num_batches)
            data.reset()
            return xent, acc

        # Assign initial learning rate and momentum.
        model.optimizer.assign_hyperparam(sess, 'lr', lr_)
        model.optimizer.assign_hyperparam(sess, mom_name, mom_)
        train_iter = six.moves.xrange(num_steps)
        if not print_step:
            train_iter = tqdm(train_iter, ncols=0)
        for ii in train_iter:
            # Meta-optimization loop.
            if ii == 0 or ii % steps_per_update == 0:
                if ii < midpoint and smd:
                    if stochastic:
                        data_list = [
                            dataset.next_batch(bsize)
                            for _ in six.moves.xrange(steps_look_ahead)
                        ]
                        # Take next few batches for last step evaluation.
                        eval_data_list = [
                            dataset.next_batch(bsize)
                            for _ in six.moves.xrange(steps_look_ahead)
                        ]
                    else:
                        data_entry = dataset.next_batch(bsize)
                        data_list = [data_entry] * steps_look_ahead
                        # Use the deterministic batch for last step evaluation.
                        eval_data_list = [data_list[0]]
                    # Snapshot the current state; each meta step restores it.
                    sess.run(write_op)
                    for _ in six.moves.xrange(num_meta_steps):
                        cost, hp_dict = meta_step(sess, model, data_list,
                                                  look_ahead_ops, hp_grad_ops,
                                                  hp_grads_plh, meta_train_op,
                                                  eval_data_list)
                        sess.run(read_op)
                        for hpname, hpval in hp_dict.items():
                            model.optimizer.assign_hyperparam(
                                sess, hpname, hpval)
                    lr_ = hp_dict['lr']
                else:
                    hp_dict = sess.run(model.optimizer.hyperparams)
                lr_log, mom_log = log_hp(ii, hp_dict)
                lr_list.append(lr_log)
                mom_list.append(mom_log)

            if ii == midpoint // 2:
                model.optimizer.assign_hyperparam(sess, mom_name, late_mom_)

            if ii == midpoint:
                # Exponential schedule reaching final_lr at the last step.
                lr_before_mid = hp_dict['lr']
                tau = (num_steps - midpoint) / np.log(lr_before_mid / final_lr)

            if ii > midpoint:
                lr_ = np.exp(-(ii - midpoint) / tau) * lr_before_mid
                model.optimizer.assign_hyperparam(sess, 'lr', lr_)

            # Run regular training.
            if lr_ > 1e-6:
                # With deterministic look-ahead during the first half, train
                # on the same batch the learning rate was tuned on.
                if smd and not stochastic and ii < midpoint:
                    xd, yd = data_entry
                else:
                    xd, yd = dataset.next_batch(bsize)
                sess.run(
                    [model.cost, model.train_op], feed_dict={
                        model.x: xd,
                        model.y: yd
                    })
                if ii < midpoint:
                    sess.run(model._retrieve_ema_op)

            # Evaluate every certain number of steps.
            if ii == 0 or (ii + 1) % steps_per_eval == 0:
                # Report full epoch training and validation loss.
                train_xent, train_acc = full_epoch_eval(
                    model, dataset_train, steps_per_epoch)
                test_xent, test_acc = full_epoch_eval(
                    mtest, dataset_test, steps_test_per_epoch)
                step_list.append(ii + 1)
                train_xent_list.append(train_xent)
                train_acc_list.append(train_acc)
                test_xent_list.append(test_xent)
                test_acc_list.append(test_acc)

                # Log training progress.
                exp_logger.log(ii, 'train loss', train_xent)
                exp_logger.log(ii, 'log train loss', np.log10(train_xent))
                exp_logger.log(ii, 'test loss', test_xent)
                exp_logger.log(ii, 'log test loss', np.log10(test_xent))
                exp_logger.log(ii, 'train acc', train_acc)
                exp_logger.log(ii, 'test acc', test_acc)
                exp_logger.flush()

                if print_step:
                    log.info((
                        'Steps {:d} T Xent {:.3e} T Acc {:.3f} V Xent {:.3e} V Acc {:.3f} '
                        'LR {:.3e}').format(ii + 1, train_xent,
                                            train_acc * 100.0, test_xent,
                                            test_acc * 100.0, lr_))

    return Results(
        step=np.array(step_list),
        train_xent=np.array(train_xent_list),
        train_acc=np.array(train_acc_list),
        test_xent=np.array(test_xent_list),
        test_acc=np.array(test_acc_list),
        lr=np.array(lr_list),
        momentum=np.array(mom_list))
def run_random_search(num_steps,
                      lr_limit,
                      decay_limit,
                      num_samples,
                      ckpt,
                      output,
                      seed=0):
    """Random search hyperparameters to plot the surface.

    Learning rate and decay are sampled log-uniformly: an exponent is drawn
    uniformly between the given limits and 10 is raised to it. Every sample
    starts from the same pretrained checkpoint and trains on the same fixed
    list of `num_steps` batches.

    Args:
        num_steps: Int. Number of look ahead steps.
        lr_limit: Tuple. Lower and upper bound of the log10 learning rate.
        decay_limit: Tuple. Lower and upper bound of the log10 decay constant.
        num_samples: Int. Number of samples to try; must be positive.
        ckpt: String. Pretrain checkpoint name.
        output: String. Output CSV results file name.
        seed: Int. Seed of the random state used to sample hyperparameters.

    Returns:
        Tuple `(lr, decay, final_loss)` of the setting with the lowest loss.

    Raises:
        ValueError: If `num_samples` is not positive.
    """
    if num_samples < 1:
        raise ValueError(
            'num_samples must be positive, got {}'.format(num_samples))
    bsize = BATCH_SIZE
    log.info('Writing output to {}'.format(output))
    log_folder = os.path.dirname(output)
    # dirname() is '' for a bare file name (current directory).
    if log_folder and not os.path.exists(log_folder):
        os.makedirs(log_folder)

    with tf.Graph().as_default(), tf.Session() as sess:
        dataset = get_dataset('mnist')
        config = get_mnist_mlp_config(0.0, MOMENTUM)
        x = tf.placeholder(tf.float32, [None, 28, 28, 1], name="x")
        y = tf.placeholder(tf.int64, [None], name="y")
        with tf.name_scope('Train'):
            with tf.variable_scope('Model'):
                m = get_mnist_mlp_model(config, x, y, training=True)
        var_to_restore = [
            v for v in tf.global_variables() if 'Momentum' not in v.name
        ]
        saver = tf.train.Saver(var_to_restore)
        # Build the initializer once: creating it inside the loop would add
        # a new op to the graph for every sample.
        init_op = tf.global_variables_initializer()
        # All hyperparameter sampling goes through this seeded state so the
        # search is reproducible.
        rnd = np.random.RandomState(seed)
        # Get a list of stochastic batches first.
        data_list = [
            dataset.next_batch(bsize) for _ in six.moves.xrange(num_steps)
        ]
        settings = []
        for _ in tqdm(
                six.moves.xrange(num_samples),
                ncols=0,
                desc='{} steps'.format(num_steps)):
            sess.run(init_op)
            saver.restore(sess, ckpt)
            lr = rnd.uniform(0, 1) * (lr_limit[1] - lr_limit[0]) + lr_limit[0]
            lr = np.exp(lr * np.log(10))
            decay = rnd.uniform(0, 1) * (
                decay_limit[1] - decay_limit[0]) + decay_limit[0]
            decay = np.exp(decay * np.log(10))
            m.optimizer.assign_hyperparam(sess, 'lr', lr)
            _, final_loss = train_steps(
                sess, m, data_list, init_lr=lr, decay_const=decay)
            settings.append([lr, decay, final_loss])
        settings = np.array(settings)
        np.savetxt(output, settings, delimiter=',', header='lr,decay,loss')
        sorted_settings = settings[np.argsort(settings[:, 2])]
        print('======')
        print('Best 10 settings')
        # Slicing keeps this safe when fewer than 10 samples were drawn.
        for aa, decay, loss in sorted_settings[:10]:
            print('Alpha', aa, 'Decay', decay, 'Loss', loss)
    return sorted_settings[0, 0], sorted_settings[0, 1], sorted_settings[0, 2]
# Example #5
# 0
def train_mnist_mlp_with_test(init_lr=0.1,
                              momentum=0.9,
                              num_steps=50000,
                              middle_decay=False,
                              inverse_decay=False,
                              decay_const=0.0,
                              time_const=5000.0,
                              steps_per_eval=100,
                              batch_size=100,
                              pretrain_ckpt=None,
                              save_ckpt=None,
                              print_step=False,
                              data_list=None,
                              data_list_eval=None,
                              data_list_test=None):
    """Train an MLP for MNIST, evaluating on full train and test epochs.

    Args:
        init_lr: Float. Initial learning rate.
        momentum: Float. Momentum.
        num_steps: Int. Number of training steps.
        middle_decay: Bool. Decay the learning rate exponentially from
            `init_lr` after the midpoint so it reaches 1e-4 at the last step.
        inverse_decay: Bool. Use the schedule
            `init_lr / (1 + step / time_const) ** decay_const`.
        decay_const: Float. Exponent of the inverse decay schedule; also
            reported in the results.
        time_const: Float. Time constant of the inverse decay schedule.
        steps_per_eval: Int. Number of training steps between evaluations.
        batch_size: Int. Mini-batch size.
        pretrain_ckpt: String or None. Checkpoint to restore model weights
            from; the variables excluded from restoring are freshly initialized.
        save_ckpt: String or None. Path to save all variables to after training.
        print_step: Bool. Log metrics at each evaluation instead of showing a
            progress bar.
        data_list: List of `(inputs, labels)` training batches indexed by step,
            or None to sample from the MNIST training set.
        data_list_eval: List of batches for training-set evaluation (indexed
            by batch within the epoch), or None to use the MNIST training set.
        data_list_test: List of batches for test-set evaluation, or None to
            use the MNIST test set.

    Returns:
        results: Results tuple object.

    Raises:
        AssertionError: If both `inverse_decay` and `middle_decay` are set.
    """
    if data_list is None:
        dataset = get_dataset('mnist')
    if data_list_eval is None:
        dataset_train = get_dataset('mnist')
    if data_list_test is None:
        dataset_test = get_dataset('mnist', test=True)
    x = tf.placeholder(tf.float32, [None, 28, 28, 1], name="x")
    y = tf.placeholder(tf.int64, [None], name="y")
    config = get_mnist_mlp_config(init_lr, momentum)
    with tf.name_scope('Train'):
        with tf.variable_scope('Model'):
            m = get_mnist_mlp_model(config, x, y, training=True)
    with tf.name_scope('Test'):
        with tf.variable_scope('Model', reuse=True):
            mtest = get_mnist_mlp_model(config, x, y, training=False)

    final_lr = 1e-4
    midpoint = num_steps // 2

    num_train = 60000
    num_test = 10000
    lr_ = init_lr
    bsize = batch_size
    steps_per_epoch = num_train // bsize
    steps_test_per_epoch = num_test // bsize
    # Time constant of the exponential mid-point decay; only needed (and
    # only computed, to avoid log warnings) when that schedule is used.
    tau = None
    if middle_decay:
        tau = (num_steps - midpoint) / np.log(init_lr / final_lr)

    train_xent_list = []
    train_acc_list = []
    test_xent_list = []
    test_acc_list = []
    lr_list = []
    step_list = []

    # Variables whose (lower-cased) names contain any of these substrings
    # are not restored from the pretrained checkpoint but initialized.
    excluded = ('momentum', 'global_step', 'lr', 'mom', 'decay')
    all_vars = tf.global_variables()
    var_to_restore = [
        v for v in all_vars
        if not any(key in v.name.lower() for key in excluded)
    ]
    # Compare by name: graph variable names are unique, and this avoids
    # relying on tf.Variable.__eq__ (not identity-based under TF2 behavior).
    restore_names = set(v.name for v in var_to_restore)
    var_to_init = [v for v in all_vars if v.name not in restore_names]
    restorer = tf.train.Saver(var_to_restore)
    if inverse_decay:
        log.info(
            'Applying inverse decay with time constant = {:.3e} and decay constant = {:.3e}'.
            format(time_const, decay_const))
    if middle_decay:
        log.info(
            'Applying decay at midpoint with final learning rate = {:.3e}'.
            format(final_lr))
    assert not (
        inverse_decay and middle_decay
    ), 'Inverse decay and middle decay cannot be applied at the same time.'

    with tf.Session() as sess:
        if pretrain_ckpt is None:
            sess.run(tf.global_variables_initializer())
        else:
            sess.run(tf.variables_initializer(var_to_init))
            restorer.restore(sess, pretrain_ckpt)
        # Assign initial learning rate.
        m.optimizer.assign_hyperparam(sess, 'lr', lr_)
        train_iter = six.moves.xrange(num_steps)
        if not print_step:
            train_iter = tqdm(train_iter, ncols=0)

        for ii in train_iter:
            if data_list is None:
                xd, yd = dataset.next_batch(bsize)
            else:
                xd, yd = data_list[ii]
            if lr_ > 1e-6:
                sess.run(
                    [m.cost, m.train_op], feed_dict={
                        x: xd,
                        y: yd
                    })

            # Learning-rate schedules (mutually exclusive, checked above).
            if inverse_decay:
                lr_ = init_lr / ((1.0 + ii / time_const)**decay_const)

            if middle_decay and ii > midpoint:
                lr_ = np.exp(-(ii - midpoint) / tau) * init_lr

            m.optimizer.assign_hyperparam(sess, 'lr', lr_)

            # Evaluate every certain number of steps.
            if ii == 0 or (ii + 1) % steps_per_eval == 0:
                train_acc = 0.0
                train_xent = 0.0
                for jj in six.moves.xrange(steps_per_epoch):
                    if data_list_eval is None:
                        xd, yd = dataset_train.next_batch(bsize)
                    else:
                        xd, yd = data_list_eval[jj]
                    xent_, acc_ = sess.run(
                        [m.cost, m.acc], feed_dict={
                            x: xd,
                            y: yd
                        })
                    train_xent += xent_ / float(steps_per_epoch)
                    train_acc += acc_ / float(steps_per_epoch)
                step_list.append(ii + 1)
                train_xent_list.append(train_xent)
                train_acc_list.append(train_acc)

                if data_list_eval is None:
                    dataset_train.reset()

                test_acc = 0.0
                test_xent = 0.0
                for jj in six.moves.xrange(steps_test_per_epoch):
                    if data_list_test is None:
                        xd, yd = dataset_test.next_batch(bsize)
                    else:
                        xd, yd = data_list_test[jj]
                    xent_, acc_ = sess.run(
                        [mtest.cost, mtest.acc], feed_dict={
                            x: xd,
                            y: yd
                        })
                    test_xent += xent_ / float(steps_test_per_epoch)
                    test_acc += acc_ / float(steps_test_per_epoch)
                test_xent_list.append(test_xent)
                test_acc_list.append(test_acc)

                if data_list_test is None:
                    dataset_test.reset()

                lr_list.append(lr_)
                if print_step:
                    log.info((
                        'Steps {:d} T Xent {:.3e} T Acc {:.3f} V Xent {:.3e} V Acc {:.3f} '
                        'LR {:.3e}').format(ii + 1, train_xent,
                                            train_acc * 100.0, test_xent,
                                            test_acc * 100.0, lr_))
        if save_ckpt is not None:
            saver = tf.train.Saver()
            saver.save(sess, save_ckpt)

    return Results(
        step=np.array(step_list),
        train_xent=np.array(train_xent_list),
        train_acc=np.array(train_acc_list),
        test_xent=np.array(test_xent_list),
        test_acc=np.array(test_acc_list),
        lr=np.array(lr_list),
        decay=decay_const)