Пример #1
0
def call_with_critical_point_scanner(f, *args):
    """Calls f(scanner, *args) in TensorFlow session-context.

    A fresh graph and session are built, then `f` is invoked as
    `f(scanner, *args)` while that session is still open; whatever `f`
    returns is returned from here.

    `scanner(seed, scale)` draws a random 70-vector from
    numpy.random.RandomState(seed) with standard deviation `scale`, loads it
    into the position variable, runs a SciPy minimization (maxiter=500) of
    asinh(stationarity), and returns the list
    [potential, stationarity, pos_vector] evaluated at the result.

    `scanner` only works while `f` is running, because the session it uses
    is closed as soon as `f` returns.
    """
    with tf.Graph().as_default():
        t_feed = tf.placeholder(tf.float64, shape=[70])
        t_position = tf.Variable(initial_value=numpy.zeros([70]),
                                 trainable=True,
                                 dtype=tf.float64)
        op_load_position = tf.assign(t_position, t_feed)
        sugra = tf_so8_sugra_potential(t_position)
        t_potential = sugra['potential']
        t_stationarity = tf_so8_sugra_stationarity(sugra['a1'], sugra['a2'])
        minimizer = contrib_opt.ScipyOptimizerInterface(
            tf.asinh(t_stationarity), options={'maxiter': 500})
        with tf.Session() as sess:
            sess.run([tf.global_variables_initializer()])

            def _scan(seed, scale):
                """Minimizes from one random start; see enclosing docstring."""
                start = numpy.random.RandomState(seed).normal(scale=scale,
                                                              size=[70])
                sess.run([op_load_position], feed_dict={t_feed: start})
                minimizer.minimize(session=sess)
                return sess.run([t_potential, t_stationarity, t_position])

            return f(_scan, *args)
def find_transforms():
    """Numerically searches for two 8x8 transforms matching the octonion table.

    A (2, 8, 8) variable is initialized from a seeded normal distribution and
    refined with SciPy (up to 1000 iterations). The loss is the sum of:
      * l2_loss(rotated gamma - octonion multiplication table), where
        transforms[0] contracts with the first and transforms[1] with the
        second matrix index of every gamma matrix;
      * l2_loss(T T^T - I) for each of the two transforms (orthogonality).

    Returns:
      [final_loss_value, transforms_value], as produced by Session.run.
    """
    with tf.Graph().as_default():
        # Seed the graph-level RNG so the random starting point is
        # reproducible.
        tf.set_random_seed(0)
        transforms = tf.get_variable(
            name='transforms',
            shape=(2, 8, 8),
            dtype=tf.float64,
            trainable=True,
            initializer=tf.random_normal_initializer())
        eye8 = tf.constant(numpy.eye(8), dtype=tf.float64)
        gamma = tf.constant(get_gamma_vsc(), dtype=tf.float64)
        octonion_table = tf.constant(get_octonion_mult_table(),
                                     dtype=tf.float64)
        s_transform = transforms[0]  # spinor transform
        c_transform = transforms[1]  # cospinor transform
        # Contract one index at a time: tf.einsum() does not choose a good
        # contraction order by itself.
        half_rotated = tf.einsum('vab,aA->vAb', gamma, s_transform)
        rotated_gamma = tf.einsum('vAb,bB->vAB', half_rotated, c_transform)
        mult_mismatch = rotated_gamma - octonion_table
        s_ortho_defect = tf.einsum('ab,cb->ac', s_transform, s_transform) - eye8
        c_ortho_defect = tf.einsum('ab,cb->ac', c_transform, c_transform) - eye8
        # Penalize mismatch with the multiplication table plus any departure
        # of either transform from orthogonality.
        loss = (tf.nn.l2_loss(mult_mismatch) +
                tf.nn.l2_loss(s_ortho_defect) +
                tf.nn.l2_loss(c_ortho_defect))
        minimizer = contrib_opt.ScipyOptimizerInterface(
            loss, options={'maxiter': 1000})
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            minimizer.minimize(session=sess)
            return sess.run([loss, transforms])
Пример #3
0
def get_scanner(output_path, maxiter=1000, stationarity_threshold=1e-7):
    """Obtains a basic TensorFlow-based scanner for extremal points.

    Args:
      output_path: if not None, a text file that is rewritten (via a
        temporary file plus os.replace) with all results so far every time a
        new point passing the stationarity threshold is found.
      maxiter: maximum number of iterations per SciPy minimization.
      stationarity_threshold: a minimum is recorded only if its
        stationarity is <= this value.

    Returns:
      A function `scanner(seed, scale=0.1, num_iterations=1)`. It runs
      `num_iterations` minimizations, each starting from a 70-vector drawn
      from numpy.random.RandomState(seed) with standard deviation `scale`.
      It returns a dict mapping S_id(potential) to a list of
      (iteration, potential, stationarity, list_of_70_coordinates) tuples.
    """
    graph = tf.Graph()
    with graph.as_default():
        tf_scalar_evaluator = get_tf_scalar_evaluator()
        t_input = tf.compat.v1.placeholder(tf.float64, shape=[70])
        t_v70 = tf.Variable(initial_value=numpy.zeros([70]),
                            trainable=True,
                            dtype=tf.float64)
        op_assign_input = tf.compat.v1.assign(t_v70, t_input)
        sinfo = tf_scalar_evaluator(tf.cast(t_v70, tf.complex128))
        t_potential = sinfo.potential
        t_stationarity = sinfo.stationarity
        # Minimize asinh(stationarity) rather than the raw value; asinh
        # compresses large values (presumably to help the minimizer far away
        # from critical points).
        op_opt = contrib_opt.ScipyOptimizerInterface(
            tf.asinh(t_stationarity), options={'maxiter': maxiter})
        # Build the initializer once. Creating it inside scanner() would add
        # a new op to the graph on every call.
        op_init = tf.compat.v1.global_variables_initializer()

        def scanner(seed, scale=0.1, num_iterations=1):
            """Runs `num_iterations` minimizations; see get_scanner()."""
            results = collections.defaultdict(list)
            rng = numpy.random.RandomState(seed)
            with graph.as_default():
                with tf.compat.v1.Session() as sess:
                    sess.run([op_init])
                    for n in range(num_iterations):
                        v70 = rng.normal(scale=scale, size=[70])
                        sess.run([op_assign_input], feed_dict={t_input: v70})
                        op_opt.minimize(sess)
                        n_pot, n_stat, n_v70 = sess.run(
                            [t_potential, t_stationarity, t_v70])
                        if n_stat <= stationarity_threshold:
                            results[S_id(n_pot)].append(
                                (n, n_pot, n_stat, list(n_v70)))
                            # Rewrite the output file after every new hit.
                            # Write to a temp file, then os.replace() it over
                            # the target. Unlike os.rename(), this also works
                            # on Windows when the target already exists.
                            if output_path is not None:
                                tmp_out = output_path + '.tmp'
                                with open(tmp_out, 'w') as h:
                                    h.write('n=%4d: p=%.12g s=%.12g\n' %
                                            (n, n_pot, n_stat))
                                    h.write(pprint.pformat(dict(results)))
                                os.replace(tmp_out, output_path)
            return dict(results)

        return scanner
Пример #4
0
def _reduce_second_m35(m35s, m35c, is_diagonal_35s, seed=0):
    """Reduces the 2nd 35-irrep.

    One of the two matrices (m35s if `is_diagonal_35s`, else m35c) is taken
    as already diagonal. Generators that act on the other matrix are obtained
    from its diagonal. The other matrix is then rotated by
    expm(sum_i c_i * gen_i), with coefficients c chosen by SciPy minimization
    to make it as diagonal as possible, i.e. to minimize the L1 norm of its
    off-diagonal entries.

    Args:
      m35s: 8x8 matrix for the 35s irrep.
      m35c: 8x8 matrix for the 35c irrep.
      is_diagonal_35s: True if m35s is the already-diagonal one.
      seed: seed for the small random starting coefficients.

    Returns:
      (m35s, m35c) with the non-diagonal matrix replaced by its rotated
      version. Both matrices are returned unchanged if there are no
      generators.
    """
    diag = numpy.diagonal(m35s if is_diagonal_35s else m35c)
    gens = _get_generators_for_reducing_second_m35(
        diag, 'gsS,sScC->gcC' if is_diagonal_35s else 'gcC,sScC->gsS',
        algebra.spin8.gamma_sscc)
    num_gens = len(gens)
    if num_gens == 0:
        return m35s, m35c  # No residual symmetry to exploit.
    # This residual symmetry is typically rather small.
    # So, doing a direct minimization is perhaps appropriate.
    rng = numpy.random.RandomState(seed=seed)
    v_coeffs_initial = rng.normal(
        scale=1e-3, size=(num_gens, ))  # Break symmetry with noise.
    graph = tf.Graph()
    with graph.as_default():
        tc_gens = tf.constant(gens, dtype=tf.float64)
        tc_m35 = tf.constant(m35c if is_diagonal_35s else m35s,
                             dtype=tf.float64)
        t_coeffs = tf.Variable(initial_value=v_coeffs_initial,
                               trainable=True,
                               dtype=tf.float64)
        t_rot = tf_cexpm.cexpm(tf.einsum('i,iab->ab', t_coeffs, tc_gens),
                               complex_arg=False)
        # Computes R @ M @ R^T in two contractions.
        t_m35_rotated = tf.einsum('Ab,Bb->AB',
                                  tf.einsum('ab,Aa->Ab', tc_m35, t_rot), t_rot)
        # Our 'loss' is the sum of magnitudes of the off-diagonal parts after
        # rotation. (tf.norm with ord=1 and no axis is the sum of |entries|.)
        t_loss = (tf.norm(t_m35_rotated, ord=1) -
                  tf.norm(tf.linalg.diag_part(t_m35_rotated), ord=1))
        optimizer = contrib_opt.ScipyOptimizerInterface(t_loss)
        with tf.compat.v1.Session() as sess:
            # Use the compat.v1 initializer to match tf.compat.v1.Session:
            # plain tf.global_variables_initializer does not exist in TF2.
            sess.run([tf.compat.v1.global_variables_initializer()])
            optimizer.minimize(sess)
            # We are only interested in the diagonalized matrix.
            m_diag = sess.run([t_m35_rotated])[0]
            return (m35s, m_diag) if is_diagonal_35s else (m_diag, m35c)
Пример #5
0
import tensorflow as tf
import tensorflow.contrib.opt as opt

X = tf.Variable([1.0, 2.0])
X0 = tf.Variable([3.0])

Y = tf.constant([2.0, -3.0])

# Op that writes X0 into X[0].
scatter = tf.scatter_update(X, [0], X0)

# The loss ops are created under a control dependency on the scatter update.
with tf.control_dependencies([scatter]):
    loss = tf.reduce_sum(tf.squared_difference(X, Y))

# Named `optimizer` so the imported `opt` module is not shadowed.
# NOTE(review): the loss reaches X0 only through scatter_update, which is
# likely not differentiable w.r.t. its `updates` argument. The gradient seen
# for X0 may therefore be zero/None -- verify before relying on this.
optimizer = opt.ScipyOptimizerInterface(loss, [X0])

init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    optimizer.minimize(sess)

    print("X: {}".format(X.eval()))
    print("X0: {}".format(X0.eval()))
Пример #6
0
# create sparsity regularizers
with tf.name_scope('regs'):
    mean_act1 = tf.reduce_mean(hidden1, 0)
    sparsity1 = 3.0 * tf.reduce_sum(kl_divergence(mean_act1, 0.01))
    decay1 = 1e-3 * tf.reduce_sum(tf.square(weights['w0']))
    decay2 = 1e-3 * tf.reduce_sum(tf.square(weights['w1']))

# create loss: reconstruction error + sparsity penalty + weight decay
with tf.name_scope('loss'):
    sse = tf.reduce_sum(tf.square(output - X))
    total_loss = sse + sparsity1 + decay1 + decay2

# create train ops
with tf.name_scope('train'):
    optimizer = opt.ScipyOptimizerInterface(total_loss,
                                            method='L-BFGS-B',
                                            options={'maxiter': 10000})

# create initializer
init = tf.global_variables_initializer()

# Function-call form of print: valid in Python 3 (the statement form is a
# syntax error there) and prints the same text in Python 2.
print('Running Optimization..')
with tf.Session() as sess:
    sess.run(init)
    optimizer.minimize(sess)
    # turn tensors into numpy arrays
    for k in weights:
        weights[k] = sess.run(weights[k])
    for k in biases:
        biases[k] = sess.run(biases[k])
"""
Пример #7
0
    def __init__(self,
                 input_size,
                 output_size,
                 max_length,
                 layers=3,
                 filter_size=11,
                 filter_depth=10,
                 crf_output_layer=False,
                 regularization_factor=0.001,
                 optimize_using_lbfgs=False,
                 lbfgs_maxiter=100):
        """Builds a stacked 1-D convolutional sequence labeller.

        This method:
          * creates an InteractiveSession and the input placeholders;
          * builds `layers` conv1d layers: ReLU between layers, and on the
            output layer softmax, or tanh when `crf_output_layer` is set;
          * builds the loss with L2 regularization and the optimizer;
          * initializes all variables.

        Args:
            input_size: channel count of `self.x` (its last dimension).
            output_size: channel count of `self.y` / of the last layer.
            max_length: second dimension of the sequence placeholders.
            layers: number of convolution layers.
            filter_size: convolution width (forced to 1 on the output layer
                when `crf_output_layer` is set).
            filter_depth: channel count of the hidden layers.
            crf_output_layer: if True, the loss is the mean negative CRF
                log-likelihood; otherwise softmax cross-entropy on the last
                layer's pre-activations.
            regularization_factor: weight of the L2 penalty added to the loss.
            optimize_using_lbfgs: if True, sets `self.optimizer` (SciPy
                L-BFGS-B); otherwise sets `self.train_step` (Adam).
            lbfgs_maxiter: maximum L-BFGS-B iterations.
        """
        self.optimize_using_lbfgs = optimize_using_lbfgs
        self.crf_output_layer = crf_output_layer
        self.session = tf.InteractiveSession()

        self.x = tf.placeholder(tf.float32, [None, max_length, input_size])
        self.y = tf.placeholder(tf.float32, [None, max_length, output_size])
        self.y_argmax = tf.placeholder(tf.int32, [None, max_length])
        self.sequence_lengths = tf.placeholder(tf.int64, [None])

        # Convolution Layers
        self.Ws = []
        self.bs = []
        self.convs = []
        self.activations = []

        for i in range(layers):
            # conv1d filter layout: [width, in_channels, out_channels]. The
            # first/last layers adapt the channel counts to input/output.
            filter_shape = [filter_size, filter_depth, filter_depth]
            if i == 0:
                filter_shape[1] = input_size
            if i == layers - 1:
                filter_shape[2] = output_size
                if self.crf_output_layer:
                    filter_shape[0] = 1

            value = self.x if i == 0 else self.activations[i - 1]

            self.Ws.append(
                tf.Variable(tf.truncated_normal(filter_shape, stddev=0.1),
                            name="W%d" % i))
            self.bs.append(
                tf.Variable(tf.truncated_normal(filter_shape[-1:], stddev=0.1),
                            name="b%d" % i))
            self.convs.append(
                tf.nn.bias_add(
                    tf.nn.conv1d(value,
                                 self.Ws[-1],
                                 stride=1,
                                 padding="SAME",
                                 name="conv%d" % i), self.bs[-1]))

            if i < (layers - 1):
                self.activations.append(tf.nn.relu(self.convs[-1]))
            elif crf_output_layer:
                self.activations.append(tf.nn.tanh(self.convs[-1]))
            else:
                self.activations.append(tf.nn.softmax(self.convs[-1]))

        if crf_output_layer:
            # Transition weights start as the identity matrix.
            self.weights_crf = tf.Variable(tf.eye(output_size), name="W_crf")
            log_likelihood, self.transition_params, self.seq_scores = crf.crf_log_likelihood(
                self.activations[-1], self.y_argmax, self.sequence_lengths,
                self.weights_crf)
            self.loss = tf.reduce_mean(-log_likelihood)
        else:
            # The softmax is applied inside the loss, so feed the
            # pre-activation outputs of the last layer.
            self.loss = tf.reduce_mean(
                tf.nn.softmax_cross_entropy_with_logits(logits=self.convs[-1],
                                                        labels=self.y))

        # Add regularization (should be estimated using cross validation).
        # The penalty covers every trainable variable in the default graph.
        # NOTE(review): that includes the biases in self.bs (and W_crf when
        # present), although biases are usually left unregularized. Kept
        # as-is to preserve training behavior.
        self.loss += tf.add_n(
            [tf.nn.l2_loss(v)
             for v in tf.trainable_variables()]) * regularization_factor

        if self.optimize_using_lbfgs:

            from tensorflow.contrib import opt

            self.optimizer = opt.ScipyOptimizerInterface(
                self.loss,
                method='L-BFGS-B',
                options={'maxiter': lbfgs_maxiter})
        else:
            self.train_step = tf.train.AdamOptimizer(0.001,
                                                     beta1=0.9,
                                                     beta2=0.999,
                                                     epsilon=1e-08).minimize(
                                                         self.loss)

        try:
            tf.global_variables_initializer().run()
        except AttributeError:
            # Older TensorFlow releases lack global_variables_initializer and
            # only provide the deprecated initialize_all_variables.
            tf.initialize_all_variables().run()

        self.saver = tf.train.Saver(max_to_keep=1)
def train(settings, warm_start_nn=None):
    """Trains a QLKNet network and writes it to nn.json.

    Builds the graph (placeholders, network, MSE/MABSE/L1/L2 loss terms and
    the configured optimizer). It then trains over minibatches, evaluating on
    the validation set after each epoch, with TensorBoard summaries, CSV logs,
    per-epoch checkpoints and early stopping. Finally it restores the best
    checkpoint (when early stopping is enabled) and writes the trainable
    variables plus a `_metadata` entry to nn.json.

    Args:
        settings: dict of options. Keys read here include 'goodness',
            'cost_l1_scale', 'cost_l2_scale', 'optimizer', 'learning_rate',
            'minibatches', 'early_stop_after', 'early_stop_measure',
            'validation_fraction' and 'test_fraction', plus optimizer-specific
            and optional debugging keys.
        warm_start_nn: passed through to standardize() and QLKNet().

    Raises:
        ValueError: if settings['goodness'] or settings['optimizer'] is not
            one of the supported values.

    Side effects: recreates ./tf_logs, creates ./checkpoints, and writes
    train_log.csv, validation_log.csv and nn.json (plus timeline*.json when
    debug reporting is enabled) in the current directory.
    """
    tf.reset_default_graph()
    start = time.time()

    # Load and standardize the input and target data
    input_df, target_df = prep_dataset(settings)
    input_df, target_df, scale_factor, scale_bias = standardize(
        input_df, target_df, settings, warm_start_nn=warm_start_nn)
    timediff(start, 'Scaling defined')

    train_dims = target_df.columns
    scan_dims = input_df.columns

    datasets = convert_panda(input_df, target_df,
                             settings['validation_fraction'],
                             settings['test_fraction'])

    # Start tensorflow session. It is not made the default session, so every
    # evaluation below passes `sess` explicitly.
    config = tf.ConfigProto()
    #config = tf.ConfigProto(intra_op_parallelism_threads=1, inter_op_parallelism_threads=1, \
    #                    allow_soft_placement=True, device_count = {'CPU': 1})
    sess = tf.Session(config=config)

    # Input placeholders
    with tf.name_scope('input'):
        # NOTE(review): the input dtype is taken from the *target* data;
        # presumably inputs and targets share a dtype.
        x = tf.placeholder(datasets.train._target.dtypes.iloc[0],
                           [None, len(scan_dims)],
                           name='x-input')
        y_ds = tf.placeholder(x.dtype, [None, len(train_dims)], name='y-input')

    net = QLKNet(x, len(train_dims), settings, warm_start_nn=warm_start_nn)
    y = net.y
    # (y - bias) / factor: presumably undoes the scaling from standardize()
    # so that errors can also be reported in the original units.
    y_descale = (net.y - scale_bias[train_dims].values
                 ) / scale_factor[train_dims].values
    y_ds_descale = (
        y_ds - scale_bias[train_dims].values) / scale_factor[train_dims].values
    is_train = net.is_train

    timediff(start, 'NN defined')

    # Define loss functions
    with tf.name_scope('Loss'):
        with tf.name_scope('mse'):
            mse = tf.losses.mean_squared_error(y_ds, y)
            mse_descale = tf.losses.mean_squared_error(y_ds_descale, y_descale)
            tf.summary.scalar('MSE', mse)
        with tf.name_scope('mabse'):
            mabse = tf.losses.absolute_difference(y_ds, y)
            tf.summary.scalar('MABSE', mabse)
        with tf.name_scope('l2'):
            l2_scale = tf.Variable(settings['cost_l2_scale'],
                                   dtype=x.dtype,
                                   trainable=False)
            # Sum of l2_loss over trainable variables whose name contains
            # 'weights'.
            l2_norm = (tf.add_n([
                tf.nn.l2_loss(var) for var in tf.trainable_variables()
                if 'weights' in var.name
            ]))
            # TODO: Check normalization
            l2_loss = l2_scale * l2_norm
            tf.summary.scalar('l2_norm', l2_norm)
            tf.summary.scalar('l2_scale', l2_scale)
            tf.summary.scalar('l2_loss', l2_loss)
        with tf.name_scope('l1'):
            l1_scale = tf.Variable(settings['cost_l1_scale'],
                                   dtype=x.dtype,
                                   trainable=False)
            # Sum of absolute values of trainable 'weights' variables.
            l1_norm = (tf.add_n([
                tf.reduce_sum(tf.abs(var)) for var in tf.trainable_variables()
                if 'weights' in var.name
            ]))
            # TODO: Check normalization
            l1_loss = l1_scale * l1_norm
            tf.summary.scalar('l1_norm', l1_norm)
            tf.summary.scalar('l1_scale', l1_scale)
            tf.summary.scalar('l1_loss', l1_loss)

        if settings['goodness'] == 'mse':
            loss = mse
        elif settings['goodness'] == 'mabse':
            loss = mabse
        else:
            # Without this, `loss` would be undefined below (NameError).
            raise ValueError(
                "Unknown goodness {!r}; expected 'mse' or 'mabse'".format(
                    settings['goodness']))
        if settings['cost_l1_scale'] != 0:
            loss += l1_loss
        if settings['cost_l2_scale'] != 0:
            loss += l2_loss
        tf.summary.scalar('loss', loss)

    # Exactly one of these is set: `optimizer` for the SciPy (L-BFGS)
    # interface, `train_step` for the TensorFlow optimizers.
    optimizer = None
    train_step = None
    # Define optimizer algorithm.
    with tf.name_scope('train'):
        lr = settings['learning_rate']
        if settings['optimizer'] == 'adam':
            beta1 = settings['adam_beta1']
            beta2 = settings['adam_beta2']
            train_step = tf.train.AdamOptimizer(
                lr,
                beta1,
                beta2,
            ).minimize(loss)
        elif settings['optimizer'] == 'adadelta':
            rho = settings['adadelta_rho']
            train_step = tf.train.AdadeltaOptimizer(
                lr,
                rho,
            ).minimize(loss)
        elif settings['optimizer'] == 'rmsprop':
            decay = settings['rmsprop_decay']
            momentum = settings['rmsprop_momentum']
            train_step = tf.train.RMSPropOptimizer(lr, decay,
                                                   momentum).minimize(loss)
        elif settings['optimizer'] == 'grad':
            train_step = tf.train.GradientDescentOptimizer(lr).minimize(loss)
        elif settings['optimizer'] == 'lbfgs':
            optimizer = opt.ScipyOptimizerInterface(
                loss,
                options={
                    'maxiter': settings['lbfgs_maxiter'],
                    'maxfun': settings['lbfgs_maxfun'],
                    'maxls': settings['lbfgs_maxls']
                })
        else:
            raise ValueError('Unknown optimizer {!r}'.format(
                settings['optimizer']))

    # Merge all the summaries
    merged = tf.summary.merge_all()

    # Initialize writers, variables and logdir
    log_dir = 'tf_logs'
    if tf.gfile.Exists(log_dir):
        tf.gfile.DeleteRecursively(log_dir)
    tf.gfile.MakeDirs(log_dir)
    train_writer = tf.summary.FileWriter(log_dir + '/train', sess.graph)
    validation_writer = tf.summary.FileWriter(log_dir + '/validation',
                                              sess.graph)
    tf.global_variables_initializer().run(session=sess)
    timediff(start, 'Variables initialized')

    epoch = 0

    train_log = pd.DataFrame(columns=[
        'epoch', 'walltime', 'loss', 'mse', 'mabse', 'l1_norm', 'l2_norm'
    ])
    validation_log = pd.DataFrame(columns=[
        'epoch', 'walltime', 'loss', 'mse', 'mabse', 'l1_norm', 'l2_norm'
    ])

    # Split dataset in minibatches
    minibatches = settings['minibatches']
    batch_size = int(np.floor(datasets.train.num_examples / minibatches))

    # Row 0 of both logs: validation metrics of the untrained network
    timediff(start, 'Starting loss calculation')
    xs, ys = datasets.validation.next_batch(-1, shuffle=False)
    feed_dict = {x: xs, y_ds: ys, is_train: False}
    summary, lo, meanse, meanabse, l1norm, l2norm = sess.run(
        [merged, loss, mse, mabse, l1_norm, l2_norm], feed_dict=feed_dict)
    train_log.loc[0] = (epoch, 0, lo, meanse, meanabse, l1norm, l2norm)
    validation_log.loc[0] = (epoch, 0, lo, meanse, meanabse, l1norm, l2norm)

    # Save checkpoints of training to restore for early-stopping. Keeping
    # early_stop_after + 1 checkpoints guarantees the best one is still
    # available when training stops.
    saver = tf.train.Saver(max_to_keep=settings['early_stop_after'] + 1)
    checkpoint_dir = 'checkpoints'
    # MakeDirs (unlike MkDir) does not fail if the directory already exists.
    tf.gfile.MakeDirs(checkpoint_dir)

    # Define variables for early stopping
    not_improved = 0
    best_early_measure = np.inf
    early_measure = np.inf

    max_epoch = settings.get('max_epoch') or sys.maxsize

    # Set debugging parameters
    setting = lambda x, default: default if x is None else x
    steps_per_report = setting(settings.get('steps_per_report'), np.inf)
    epochs_per_report = setting(settings.get('epochs_per_report'), np.inf)
    save_checkpoint_networks = setting(
        settings.get('save_checkpoint_networks'), False)
    save_best_networks = setting(settings.get('save_best_networks'), False)
    track_training_time = setting(settings.get('track_training_time'), False)

    # Set up log files (line-buffered; emptied before use)
    train_log_file = open('train_log.csv', 'a', 1)
    train_log_file.truncate(0)
    train_log.to_csv(train_log_file)
    validation_log_file = open('validation_log.csv', 'a', 1)
    validation_log_file.truncate(0)
    validation_log.to_csv(validation_log_file)

    timediff(start, 'Training started')
    train_start = time.time()
    # Global training-step counter: x-axis of the TensorBoard summaries and
    # global_step of every saved checkpoint (so checkpoints get distinct
    # names and the saver keeps several of them).
    ii = 0
    # Defined up front so the final log write below works even if training
    # is interrupted before the first minibatch.
    step = 0
    try:
        for epoch in range(max_epoch):
            for step in range(minibatches):
                # Extra debugging every steps_per_report
                if not step % steps_per_report and steps_per_report != np.inf:
                    print('debug!', epoch, step)
                    run_options = tf.RunOptions(
                        trace_level=tf.RunOptions.FULL_TRACE)
                    run_metadata = tf.RunMetadata()
                else:
                    run_options = None
                    run_metadata = None
                xs, ys = datasets.train.next_batch(batch_size, shuffle=True)
                feed_dict = {x: xs, y_ds: ys, is_train: True}
                if optimizer is not None:
                    # Scipy-style optimizer: minimize on this batch, then
                    # evaluate the metrics in one run. Pass the session
                    # explicitly: `sess` is not the default session, so
                    # Tensor.eval() without session= would fail here.
                    optimizer.minimize(sess, feed_dict=feed_dict)
                    summary, lo, meanse, meanabse, l1norm, l2norm = sess.run(
                        [merged, loss, mse, mabse, l1_norm, l2_norm],
                        feed_dict=feed_dict,
                        options=run_options,
                        run_metadata=run_metadata)
                else:  # If we have a TensorFlow-style optimizer
                    summary, lo, meanse, meanabse, l1norm, l2norm, _ = sess.run(
                        [
                            merged, loss, mse, mabse, l1_norm, l2_norm,
                            train_step
                        ],
                        feed_dict=feed_dict,
                        options=run_options,
                        run_metadata=run_metadata)
                train_writer.add_summary(summary, ii)

                # Extra debugging every steps_per_report
                if not step % steps_per_report and steps_per_report != np.inf:
                    tl = timeline.Timeline(run_metadata.step_stats)
                    ctf = tl.generate_chrome_trace_format()
                    with open('timeline_run.json', 'w') as f:
                        f.write(ctf)

                    train_writer.add_run_metadata(
                        run_metadata, 'epoch%d step%d' % (epoch, step))
                # Add to CSV log buffer
                if track_training_time is True:
                    train_log.loc[epoch * minibatches +
                                  step] = (epoch, time.time() - train_start,
                                           lo, meanse, meanabse, l1norm,
                                           l2norm)
                ii += 1
            ########
            # After-epoch stuff
            ########

            epoch = datasets.train.epochs_completed
            xs, ys = datasets.validation.next_batch(-1, shuffle=False)
            feed_dict = {x: xs, y_ds: ys, is_train: False}
            # Run with full trace every epochs_per_report Gives full runtime information
            if not epoch % epochs_per_report and epochs_per_report != np.inf:
                print('epoch_debug!', epoch)
                run_options = tf.RunOptions(
                    trace_level=tf.RunOptions.FULL_TRACE)
                run_metadata = tf.RunMetadata()
            else:
                run_options = None
                run_metadata = None

            # Calculate all variables with the validation set
            summary, lo, meanse, meanabse, l1norm, l2norm = sess.run(
                [merged, loss, mse, mabse, l1_norm, l2_norm],
                feed_dict=feed_dict,
                options=run_options,
                run_metadata=run_metadata)

            validation_writer.add_summary(summary, ii)
            # More debugging every epochs_per_report
            if not epoch % epochs_per_report and epochs_per_report != np.inf:
                tl = timeline.Timeline(run_metadata.step_stats)
                ctf = tl.generate_chrome_trace_format()
                with open('timeline.json', 'w') as f:
                    f.write(ctf)

                validation_writer.add_run_metadata(run_metadata,
                                                   'epoch%d' % epoch)

            # Save checkpoint (one per epoch, tagged with the global step)
            saver.save(sess,
                       os.path.join(checkpoint_dir, 'model.ckpt'),
                       global_step=ii,
                       write_meta_graph=False)

            # Update CSV logs
            if track_training_time is True:
                validation_log.loc[epoch] = (epoch, time.time() - train_start,
                                             lo, meanse, meanabse, l1norm,
                                             l2norm)

                validation_log.loc[epoch:].to_csv(validation_log_file,
                                                  header=False)
                validation_log = validation_log[0:0]  #Flush validation log
                train_log.loc[epoch * minibatches:].to_csv(train_log_file,
                                                           header=False)
                train_log = train_log[0:0]  #Flush train_log

            # Determine early-stopping criterion
            if settings['early_stop_measure'] == 'mse':
                early_measure = meanse
            elif settings['early_stop_measure'] == 'loss':
                early_measure = lo
            elif settings['early_stop_measure'] == 'none':
                early_measure = np.nan

            # Early stopping, check if measure is better
            if early_measure < best_early_measure:
                best_early_measure = early_measure
                if save_best_networks:
                    nn_best_file = os.path.join(
                        checkpoint_dir,
                        'nn_checkpoint_' + str(epoch) + '.json')
                    trainable = {
                        x.name: tf.to_double(x).eval(session=sess).tolist()
                        for x in tf.trainable_variables()
                    }
                    model_to_json(nn_best_file, trainable,
                                  scan_dims.values.tolist(),
                                  train_dims.values.tolist(), datasets.train,
                                  scale_factor.astype('float64'),
                                  scale_bias.astype('float64'), l2_scale,
                                  settings)
                not_improved = 0
            else:  # If early measure is not better
                not_improved += 1
            # If not improved in 'early_stop' epoch, stop
            if settings[
                    'early_stop_measure'] != 'none' and not_improved >= settings[
                        'early_stop_after']:
                if save_checkpoint_networks:
                    nn_checkpoint_file = os.path.join(
                        checkpoint_dir,
                        'nn_checkpoint_' + str(epoch) + '.json')
                    trainable = {
                        x.name: tf.to_double(x).eval(session=sess).tolist()
                        for x in tf.trainable_variables()
                    }
                    model_to_json(nn_checkpoint_file, trainable,
                                  scan_dims.values.tolist(),
                                  train_dims.values.tolist(), datasets.train,
                                  scale_factor.astype('float64'),
                                  scale_bias.astype('float64'), l2_scale,
                                  settings)

                print('Not improved for %s epochs, stopping..' %
                      (not_improved))
                break

            # Stop if loss is nan or inf
            if np.isnan(lo) or np.isinf(lo):
                print('Loss is {}! Stopping..'.format(lo))
                break

    # Stop on Ctrl-C
    except KeyboardInterrupt:
        print('KeyboardInterrupt Stopping..')

    train_writer.close()
    validation_writer.close()

    # Restore checkpoint with best epoch. One checkpoint is saved per epoch
    # and saver.last_checkpoints is ordered oldest -> newest. The best epoch
    # (not_improved saves before the newest) is therefore at index
    # -(not_improved + 1); max_to_keep = early_stop_after + 1 keeps it.
    best_epoch = epoch
    if settings['early_stop_measure'] == 'none':
        # No measure means no "best" checkpoint: keep the current values.
        print('Early stopping disabled, just saving current values')
    else:
        try:
            saver.restore(sess, saver.last_checkpoints[-(not_improved + 1)])
            best_epoch = epoch - not_improved
        except IndexError:
            print("Can't restore old checkpoint, just saving current values")

    validation_log.loc[epoch] = (epoch, time.time() - train_start, lo, meanse,
                                 meanabse, l1norm, l2norm)
    train_log.loc[epoch * minibatches +
                  step] = (epoch, time.time() - train_start, lo, meanse,
                           meanabse, l1norm, l2norm)
    validation_log.loc[epoch:].to_csv(validation_log_file, header=False)
    train_log.loc[epoch * minibatches:].to_csv(train_log_file, header=False)
    train_log_file.close()
    del train_log
    validation_log_file.close()
    del validation_log

    trainable = {
        x.name: tf.to_double(x).eval(session=sess).tolist()
        for x in tf.trainable_variables()
    }
    # Cast the scaling like the checkpoint exports above do.
    model_to_json('nn.json', trainable, scan_dims.values.tolist(),
                  train_dims.values.tolist(), datasets.train,
                  scale_factor.astype('float64'),
                  scale_bias.astype('float64'), l2_scale, settings)

    print("Best epoch was {:d} with measure '{:s}' of {:f} ".format(
        best_epoch, settings['early_stop_measure'], best_early_measure))
    print("Training time was {:.0f} seconds".format(time.time() - train_start))

    # Finally, check against validation set
    xs, ys = datasets.validation.next_batch(-1, shuffle=False)
    feed_dict = {x: xs, y_ds: ys, is_train: False}
    rms_val = np.round(np.sqrt(mse.eval(feed_dict, session=sess)), 4)
    rms_val_descale = np.round(
        np.sqrt(mse_descale.eval(feed_dict, session=sess)), 4)
    loss_val = np.round(loss.eval(feed_dict, session=sess), 4)
    print('{:22} {:5.2f}'.format('Validation RMS error: ', rms_val))
    print('{:22} {:5.2f}'.format('Descaled validation RMS error: ',
                                 rms_val_descale))
    print('{:22} {:5.2f}'.format('Validation loss: ', loss_val))

    metadata = {
        'epoch': epoch,
        'best_epoch': best_epoch,
        'rms_validation': float(rms_val),
        'loss_validation': float(loss_val),
        'rms_validation_descaled': float(rms_val_descale),
    }

    # Add metadata dict to nn.json
    with open('nn.json') as nn_file:
        data = json.load(nn_file)

    data['_metadata'] = metadata

    with open('nn.json', 'w') as nn_file:
        json.dump(data,
                  nn_file,
                  sort_keys=True,
                  indent=4,
                  separators=(',', ': '))
    sess.close()