Example #1
File: train.py Project: eiroW/FDM
def train_rule_only(
    model_dir,
    rule_trains,
    max_steps,
    hp=None,
    ruleset='all',
    seed=0,
):
    '''Customized training function.

    Train the network sequentially, but train only rule connections for the
    second set of tasks. First train the network to perform the tasks in
    group 1, then train on group 2; while training group 2, only the rule
    input connections are updated.

    Args:
        model_dir: str, training directory
        rule_trains: a list of lists of tasks to train sequentially
        max_steps: int or list of int, maximum number of training steps
            for each list of tasks
        hp: dictionary of hyperparameters
        ruleset: the set of rules to train
        seed: int, random seed to be used

    Returns:
        model is stored at model_dir/model.ckpt
        training configuration is stored at model_dir/hp.json
    '''

    tools.mkdir_p(model_dir)

    # Network parameters
    default_hp = get_default_hp(ruleset)
    if hp is not None:
        default_hp.update(hp)
    hp = default_hp
    hp['seed'] = seed
    hp['rng'] = np.random.RandomState(seed)
    hp['rule_trains'] = rule_trains
    # Get all rules by flattening the list of lists
    hp['rules'] = [r for rs in rule_trains for r in rs]

    # Number of training iterations for each group of rules
    if hasattr(max_steps, '__iter__'):
        rule_train_iters = max_steps
    else:
        rule_train_iters = [len(r) * max_steps for r in rule_trains]
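    # e.g. (hypothetical tasks) max_steps=1e5 with
    # rule_trains=[['taskA', 'taskB'], ['taskC']] gives
    # rule_train_iters = [2e5, 1e5]; a list-valued max_steps is used
    # directly as the per-group iteration counts.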

    tools.save_hp(hp, model_dir)
    # Display hp
    for key, val in hp.items():
        print('{:20s} = '.format(key) + str(val))

    # Build the model
    model = Model(model_dir, hp=hp)

    # Store results
    log = defaultdict(list)
    log['model_dir'] = model_dir

    # Record time
    t_start = time.time()

    # Launch the graph in a session
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        # penalty on deviation from initial weight
        if hp['l2_weight_init'] > 0:
            raise NotImplementedError()

        # Looping
        step_total = 0
        for i_rule_train, rule_train in enumerate(hp['rule_trains']):
            step = 0

            if i_rule_train == 0:
                display_step = 200
            else:
                display_step = 50

            if i_rule_train > 0:
                # var_list = [v for v in model.var_list
                #             if ('input' in v.name) and ('rnn' not in v.name)]
                var_list = [
                    v for v in model.var_list if 'rule_input' in v.name
                ]
                model.set_optimizer(var_list=var_list)

            # Keep training until reaching the maximum number of iterations
            while (step * hp['batch_size_train'] <=
                   rule_train_iters[i_rule_train]):
                # Validation
                if step % display_step == 0:
                    trial = step_total * hp['batch_size_train']
                    log['trials'].append(trial)
                    log['times'].append(time.time() - t_start)
                    log['rule_now'].append(rule_train)
                    log = do_eval(sess, model, log, rule_train)
                    if log['perf_avg'][-1] > model.hp['target_perf']:
                        print('Perf reached the target: {:0.2f}'.format(
                            hp['target_perf']))
                        break

                # Training
                rule_train_now = hp['rng'].choice(rule_train)
                # Generate a random batch of trials.
                # Each batch has the same trial length
                trial = generate_trials(rule_train_now,
                                        hp,
                                        'random',
                                        batch_size=hp['batch_size_train'])

                # Generating feed_dict.
                feed_dict = tools.gen_feed_dict(model, trial, hp)

                # Apply one training step
                sess.run(model.train_step, feed_dict=feed_dict)

                step += 1
                step_total += 1

        print("Optimization Finished!")
Example #2
File: train.py Project: eiroW/FDM
def train_sequential(
    model_dir,
    rule_trains,
    hp=None,
    max_steps=1e7,
    display_step=500,
    ruleset='mante',
    seed=0,
):
    '''Train the network sequentially.

    Args:
        model_dir: str, training directory
        rule_trains: a list of lists of tasks to train sequentially
        hp: dictionary of hyperparameters
        max_steps: int, maximum number of training steps for each list of tasks
        display_step: int, display steps
        ruleset: the set of rules to train
        seed: int, random seed to be used

    Returns:
        model is stored at model_dir/model.ckpt
        training configuration is stored at model_dir/hp.json
    '''

    tools.mkdir_p(model_dir)

    # Network parameters
    default_hp = get_default_hp(ruleset)
    if hp is not None:
        default_hp.update(hp)
    hp = default_hp
    hp['seed'] = seed
    hp['rng'] = np.random.RandomState(seed)
    hp['rule_trains'] = rule_trains
    # Get all rules by flattening the list of lists
    hp['rules'] = [r for rs in rule_trains for r in rs]

    # Number of training iterations for each group of rules
    rule_train_iters = [len(r) * max_steps for r in rule_trains]

    tools.save_hp(hp, model_dir)
    # Display hp
    for key, val in hp.items():
        print('{:20s} = '.format(key) + str(val))

    # Continual-learning (intelligent synapses) strength and damping;
    # continual learning is active only when c > 0
    c, ksi = hp['c_intsyn'], hp['ksi_intsyn']

    # Build the model
    model = Model(model_dir, hp=hp)

    grad_unreg = tf.gradients(model.cost_lsq, model.var_list)

    # Store results
    log = defaultdict(list)
    log['model_dir'] = model_dir

    # Record time
    t_start = time.time()

    # tensorboard summaries
    placeholders = list()
    for v_name in ['Omega0', 'omega0', 'vdelta']:
        for v in model.var_list:
            placeholder = tf.placeholder(tf.float32, shape=v.shape)
            tf.summary.histogram(v_name + '/' + v.name, placeholder)
            placeholders.append(placeholder)
    merged = tf.summary.merge_all()
    test_writer = tf.summary.FileWriter(model_dir + '/tb')

    def relu(x):
        return x * (x > 0.)

    # Launch the graph in a session
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        # penalty on deviation from initial weight
        if hp['l2_weight_init'] > 0:
            raise NotImplementedError()

        # Looping
        step_total = 0
        for i_rule_train, rule_train in enumerate(hp['rule_trains']):
            step = 0

            # At the start of each task, snapshot the current weights
            # (used by the intelligent-synapses bookkeeping below)
            v_current = sess.run(model.var_list)

            if i_rule_train == 0:
                v_anc0 = v_current
                Omega0 = [np.zeros(v.shape, dtype='float32') for v in v_anc0]
                omega0 = [np.zeros(v.shape, dtype='float32') for v in v_anc0]
                v_delta = [np.zeros(v.shape, dtype='float32') for v in v_anc0]
            elif c > 0:
                v_anc0_prev = v_anc0
                v_anc0 = v_current
                v_delta = [
                    v - v_prev for v, v_prev in zip(v_anc0, v_anc0_prev)
                ]

                # Consolidate penalty weights; relu ensures all
                # elements of Omega0 are non-negative
                Omega0 = [
                    relu(O + o / (v_d**2 + ksi))
                    for O, o, v_d in zip(Omega0, omega0, v_delta)
                ]

                # Update cost
                model.cost_reg = tf.constant(0.)
                for v, w, v_val in zip(model.var_list, Omega0, v_current):
                    model.cost_reg += c * tf.reduce_sum(
                        tf.multiply(tf.constant(w),
                                    tf.square(v - tf.constant(v_val))))
                model.set_optimizer()

            # Store Omega0 to tf summary
            feed_dict = dict(zip(placeholders, Omega0 + omega0 + v_delta))
            summary = sess.run(merged, feed_dict=feed_dict)
            test_writer.add_summary(summary, i_rule_train)

            # Reset the per-task importance accumulator
            omega0 = [np.zeros(v.shape, dtype='float32') for v in v_anc0]

            # Keep training until reaching the maximum number of iterations
            while (step * hp['batch_size_train'] <=
                   rule_train_iters[i_rule_train]):
                # Validation
                if step % display_step == 0:
                    trial = step_total * hp['batch_size_train']
                    log['trials'].append(trial)
                    log['times'].append(time.time() - t_start)
                    log['rule_now'].append(rule_train)
                    log = do_eval(sess, model, log, rule_train)
                    if log['perf_avg'][-1] > model.hp['target_perf']:
                        print('Perf reached the target: {:0.2f}'.format(
                            hp['target_perf']))
                        break

                # Training
                rule_train_now = hp['rng'].choice(rule_train)
                # Generate a random batch of trials.
                # Each batch has the same trial length
                trial = generate_trials(rule_train_now,
                                        hp,
                                        'random',
                                        batch_size=hp['batch_size_train'])

                # Generating feed_dict.
                feed_dict = tools.gen_feed_dict(model, trial, hp)

                # Remember the weights before this training step
                # (continual learning with intelligent synapses)
                v_prev = v_current

                # This will compute the gradient BEFORE train step
                _, v_grad = sess.run([model.train_step, grad_unreg],
                                     feed_dict=feed_dict)
                # Get the weights after the train step
                v_current = sess.run(model.var_list)

                # Update synaptic importance
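                # omega0 accumulates the path integral
                #   omega0 -= (w_after - w_before) * grad(cost_lsq),
                # i.e. each parameter's contribution to reducing the
                # unregularized loss (synaptic intelligence,
                # Zenke et al. 2017)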
                omega0 = [
                    o - (v_c - v_p) * v_g for o, v_c, v_p, v_g in zip(
                        omega0, v_current, v_prev, v_grad)
                ]

                step += 1
                step_total += 1

        print("Optimization Finished!")
Example #3
def train_sequential_orthogonalized(model_dir,
                                    rule_trains,
                                    hp=None,
                                    max_steps=1e7,
                                    display_step=500,
                                    rich_output=False,
                                    ruleset='mante',
                                    applyProj='both',
                                    seed=0,
                                    nEpisodeBatches=100,
                                    projGrad=True,
                                    alpha=0.001,
                                    fixReadout=False):
    '''Train the network sequentially, orthogonalizing updates across tasks.

    Args:
        model_dir: str, training directory
        rule_trains: a list of tasks to train sequentially
        hp: dictionary of hyperparameters
        max_steps: int, maximum number of training steps for each task
        display_step: int, display steps
        ruleset: the set of rules to train
        applyProj: which projections to apply during training
        seed: int, random seed to be used
        projGrad: bool, whether to project gradients
        alpha: float, regularizer for the projection matrices
        fixReadout: bool, if True the readout weights are not trained

    Returns:
        model is stored at model_dir/model.ckpt
        training configuration is stored at model_dir/hp.json
    '''

    tools.mkdir_p(model_dir)

    # Network parameters
    default_hp = get_default_hp(ruleset)
    if hp is not None:
        default_hp.update(hp)
    hp = default_hp
    hp['seed'] = seed
    hp['rng'] = np.random.RandomState(seed)
    hp['rule_trains'] = rule_trains
    # Get all rules by flattening the list of lists
    # hp['rules'] = [r for rs in rule_trains for r in rs]
    hp['rules'] = rule_trains

    # save some other parameters
    hp['alpha_projection'] = alpha
    hp['max_steps'] = max_steps

    # Number of training iterations for each rule
    rule_train_iters = [max_steps for _ in rule_trains]

    tools.save_hp(hp, model_dir)
    # Display hp
    for key, val in hp.items():
        print('{:20s} = '.format(key) + str(val))

    # Build the model
    model = Sequential_Model(model_dir,
                             projGrad=projGrad,
                             applyProj=applyProj,
                             hp=hp)

    # Store results
    log = defaultdict(list)
    log['model_dir'] = model_dir

    # Record time
    t_start = time.time()

    def relu(x):
        return x * (x > 0.)

    # -------------------------------------------------------

    # Launch the graph in a session
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        # penalty on deviation from initial weight
        if hp['l2_weight_init'] > 0:
            raise NotImplementedError()

        # Looping
        step_total = 0
        taskNumber = 0

        if fixReadout is True:
            my_var_list = [
                var for var in model.var_list
                if 'rnn/leaky_rnn_cell/kernel:0' in var.name
            ]
        else:
            my_var_list = [
                var for var in model.var_list
                if 'rnn/leaky_rnn_cell/kernel:0' in var.name
                or 'output/weights:0' in var.name
            ]

        # initialise projection matrices
        input_proj = tf.zeros(
            (hp['n_rnn'] + hp['n_input'], hp['n_rnn'] + hp['n_input']))
        activity_proj = tf.zeros((hp['n_rnn'], hp['n_rnn']))
        output_proj = tf.zeros((hp['n_output'], hp['n_output']))
        recurrent_proj = tf.zeros((hp['n_rnn'], hp['n_rnn']))

        for i_rule_train, rule_train in enumerate(hp['rule_trains']):

            step = 0

            model.set_optimizer(activity_proj=activity_proj,
                                input_proj=input_proj,
                                output_proj=output_proj,
                                recurrent_proj=recurrent_proj,
                                taskNumber=taskNumber,
                                var_list=my_var_list,
                                alpha=alpha)

            # Keep training until reaching the maximum number of iterations
            while (step * hp['batch_size_train'] <=
                   rule_train_iters[i_rule_train]):
                # Validation
                if step % display_step == 0:
                    trial = step_total * hp['batch_size_train']
                    log['trials'].append(trial)
                    log['times'].append(time.time() - t_start)
                    log['rule_now'].append(rule_train)
                    log = do_eval(sess, model, log, rule_train)
                    if log['perf_avg'][-1] > model.hp['target_perf']:
                        print('Perf reached the target: {:0.2f}'.format(
                            hp['target_perf']))
                        break

                # Training
                # rule_train_now = hp['rng'].choice(rule_train)

                # Generate a random batch of trials.
                # Each batch has the same trial length
                trial = generate_trials(rule_train,
                                        hp,
                                        'random',
                                        batch_size=hp['batch_size_train'],
                                        delay_fac=hp['delay_fac'])

                # Generating feed_dict.
                feed_dict = tools.gen_feed_dict(model, trial, hp)

                # update model
                sess.run(model.train_step, feed_dict=feed_dict)

                # # Get the weight after train step
                # v_current = sess.run(model.var_list)

                step += 1
                step_total += 1

                if step % display_step == 0:
                    model.save_ckpt(step_total)

            # ---------- save the model after it completes the current task ----------
            model.save_after_task(taskNumber)

            # ---------- generate task activity for continual learning -------
            trial = generate_trials(rule_train,
                                    hp,
                                    'random',
                                    batch_size=hp['batch_size_test'],
                                    delay_fac=hp['delay_fac'])

            # Generating feed_dict.
            feed_dict = tools.gen_feed_dict(model, trial, hp)
            eval_h, eval_x, eval_y, Wrec, Win = sess.run(
                [model.h, model.x, model.y, model.w_rec, model.w_in],
                feed_dict=feed_dict)
            full_state = np.concatenate([eval_x, eval_h], -1)

            # get weight matrix after current task
            Wfull = np.concatenate([Win, Wrec], 0)

            # joint covariance matrix of input and activity
            Shx_task = compute_covariance(
                np.reshape(full_state, (-1, hp['n_rnn'] + hp['n_input'])).T)

            # covariance matrix of output
            Sy_task = compute_covariance(
                np.reshape(eval_y, (-1, hp['n_output'])).T)

            # covariance of the recurrent inputs, propagated through the
            # full weight matrix: Sh = Wfull^T Shx Wfull
            # Sh_task = Shx_task[-hp['n_rnn']:, -hp['n_rnn']:]
            Sh_task = np.matmul(np.matmul(Wfull.T, Shx_task), Wfull)

            # ---------- update stored covariance matrices for continual learning -------
            if taskNumber == 0:
                input_cov = Shx_task
                activity_cov = Sh_task
                output_cov = Sy_task
            else:
                input_cov = (taskNumber * input_cov +
                             Shx_task) / (taskNumber + 1)
                activity_cov = (taskNumber * activity_cov +
                                Sh_task) / (taskNumber + 1)
                output_cov = (taskNumber * output_cov +
                              Sy_task) / (taskNumber + 1)
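            # Both branches above implement an incremental mean over tasks:
            #   cov <- (n * cov + cov_task) / (n + 1), with n = taskNumber,
            # so every completed task contributes equally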

            # ---------- update projection matrices for continual learning ----------
            activity_proj, input_proj, output_proj, recurrent_proj = compute_projection_matrices(
                activity_cov, input_cov, output_cov,
                input_cov[-hp['n_rnn']:, -hp['n_rnn']:], alpha)

            # update task number
            taskNumber += 1

        print("Optimization Finished!")
Example #4
File: train.py Project: eiroW/FDM
def train(
    model_dir,
    hp=None,
    max_steps=1e7,
    display_step=500,
    ruleset='mante',
    rule_trains=None,
    rule_prob_map=None,
    seed=0,
    rich_output=False,
    load_dir=None,
    trainables=None,
):
    """Train the network.

    Args:
        model_dir: str, training directory
        hp: dictionary of hyperparameters
        max_steps: int, maximum number of training steps
        display_step: int, display steps
        ruleset: the set of rules to train
        rule_trains: list of rules to train; if None, train all rules in the ruleset
        rule_prob_map: None or dictionary of relative rule probability
        seed: int, random seed to be used

    Returns:
        model is stored at model_dir/model.ckpt
        training configuration is stored at model_dir/hp.json
    """

    tools.mkdir_p(model_dir)

    # Network parameters
    default_hp = get_default_hp(ruleset)
    if hp is not None:
        default_hp.update(hp)
    hp = default_hp
    hp['seed'] = seed
    hp['rng'] = np.random.RandomState(seed)

    # Rules to train and test. Rules in a set are trained together
    if rule_trains is None:
        # By default, training all rules available to this ruleset
        hp['rule_trains'] = task.rules_dict[ruleset]
    else:
        hp['rule_trains'] = rule_trains
    hp['rules'] = hp['rule_trains']

    # Assign probabilities for rule_trains.
    if rule_prob_map is None:
        rule_prob_map = dict()

    # Turn the map into per-rule probabilities aligned with rule_trains
    hp['rule_probs'] = None
    if hasattr(hp['rule_trains'], '__iter__'):
        # Set default as 1.
        rule_prob = np.array(
            [rule_prob_map.get(r, 1.) for r in hp['rule_trains']])
        hp['rule_probs'] = list(rule_prob / np.sum(rule_prob))
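        # e.g. (hypothetical) rule_prob_map={'taskA': 3.} with
        # rule_trains=['taskA', 'taskB'] gives rule_probs=[0.75, 0.25]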
    tools.save_hp(hp, model_dir)

    # Build the model
    model = Model(model_dir, hp=hp)

    # Display hp
    for key, val in hp.items():
        print('{:20s} = '.format(key) + str(val))

    # Store results
    log = defaultdict(list)
    log['model_dir'] = model_dir

    # Record time
    t_start = time.time()

    with tf.Session() as sess:
        if load_dir is not None:
            model.restore(load_dir)  # complete restore
        else:
            # Initialize all variables from scratch
            sess.run(tf.global_variables_initializer())

        # Set trainable parameters
        if trainables is None or trainables == 'all':
            var_list = model.var_list  # train everything
        elif trainables == 'input':
            # train all input weights (excluding recurrent connections)
            var_list = [
                v for v in model.var_list
                if ('input' in v.name) and ('rnn' not in v.name)
            ]
        elif trainables == 'rule':
            # train rule inputs only
            var_list = [v for v in model.var_list if 'rule_input' in v.name]
        else:
            raise ValueError('Unknown trainables')
        model.set_optimizer(var_list=var_list)

        # penalty on deviation from initial weight
        if hp['l2_weight_init'] > 0:
            anchor_ws = sess.run(model.weight_list)
            for w, w_val in zip(model.weight_list, anchor_ws):
                model.cost_reg += (hp['l2_weight_init'] *
                                   tf.nn.l2_loss(w - w_val))

            model.set_optimizer(var_list=var_list)

        # partial weight training
        if ('p_weight_train' in hp and (hp['p_weight_train'] is not None)
                and hp['p_weight_train'] < 1.0):
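            # Roughly a fraction p_weight_train of each weight tensor stays
            # free to train; the rest is anchored to its current value by an
            # l2 penalty (the 1e-1 mask entries are squared inside l2_loss)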
            for w in model.weight_list:
                w_val = sess.run(w)
                w_size = sess.run(tf.size(w))
                w_mask_tmp = np.linspace(0, 1, w_size)
                hp['rng'].shuffle(w_mask_tmp)
                ind_fix = w_mask_tmp > hp['p_weight_train']
                w_mask = np.zeros(w_size, dtype=np.float32)
                w_mask[ind_fix] = 1e-1  # will be squared in l2_loss
                w_mask = tf.constant(w_mask)
                w_mask = tf.reshape(w_mask, w.shape)
                model.cost_reg += tf.nn.l2_loss((w - w_val) * w_mask)
            model.set_optimizer(var_list=var_list)

        step = 0
        while step * hp['batch_size_train'] <= max_steps:
            try:
                # Validation
                if step % display_step == 0:
                    log['trials'].append(step * hp['batch_size_train'])
                    log['times'].append(time.time() - t_start)
                    log = do_eval(sess, model, log, hp['rule_trains'])
                    # if log['perf_avg'][-1] > model.hp['target_perf']:
                    # check if minimum performance is above target
                    if log['perf_min'][-1] > model.hp['target_perf']:
                        print('Perf reached the target: {:0.2f}'.format(
                            hp['target_perf']))
                        break

                    if rich_output:
                        display_rich_output(model, sess, step, log, model_dir)

                # Training
                rule_train_now = hp['rng'].choice(hp['rule_trains'],
                                                  p=hp['rule_probs'])
                # Generate a random batch of trials.
                # Each batch has the same trial length
                trial = generate_trials(rule_train_now,
                                        hp,
                                        'random',
                                        batch_size=hp['batch_size_train'])

                # Generating feed_dict.
                feed_dict = tools.gen_feed_dict(model, trial, hp)
                sess.run(model.train_step, feed_dict=feed_dict)

                step += 1

            except KeyboardInterrupt:
                print("Optimization interrupted by user")
                break

        print("Optimization finished!")
Example #5
def train(
    model_dir,
    hp=None,
    max_steps=1e7,
    display_step=500,
    ruleset='mante',
    rule_trains=None,
    rule_prob_map=None,
    seed=0,
    rich_output=True,
    load_dir=None,
    trainables=None,
    fixReadoutandBias=False,
    fixBias=False,
):
    """Train the network.

    Args:
        model_dir: str, training directory
        hp: dictionary of hyperparameters
        max_steps: int, maximum number of training steps
        display_step: int, display steps
        ruleset: the set of rules to train
        rule_trains: list of rules to train; if None, train all rules in the ruleset
        rule_prob_map: None or dictionary of relative rule probability
        seed: int, random seed to be used

    Returns:
        model is stored at model_dir/model.ckpt
        training configuration is stored at model_dir/hp.json
    """

    tools.mkdir_p(model_dir)

    # Network parameters
    default_hp = get_default_hp(ruleset)
    if hp is not None:
        default_hp.update(hp)
    hp = default_hp
    hp['seed'] = seed
    hp['rng'] = np.random.RandomState(seed)

    # Rules to train and test. Rules in a set are trained together
    if rule_trains is None:
        # By default, training all rules available to this ruleset
        hp['rule_trains'] = task.rules_dict[ruleset]
    else:
        hp['rule_trains'] = rule_trains
    hp['rules'] = hp['rule_trains']

    # Assign probabilities for rule_trains.
    if rule_prob_map is None:
        rule_prob_map = dict()

    # Turn the map into per-rule probabilities aligned with rule_trains
    hp['rule_probs'] = None
    if hasattr(hp['rule_trains'], '__iter__'):
        # Set default as 1.
        rule_prob = np.array(
            [rule_prob_map.get(r, 1.) for r in hp['rule_trains']])
        hp['rule_probs'] = list(rule_prob / np.sum(rule_prob))

    tools.save_hp(hp, model_dir)

    # Build the model
    with tf.device('gpu:0'):
        model = Model(model_dir, hp=hp)

        # Display hp
        for key, val in hp.items():
            print('{:20s} = '.format(key) + str(val))

        if fixReadoutandBias is True:
            my_var_list = [
                var for var in model.var_list
                if 'rnn/leaky_rnn_cell/kernel:0' in var.name
            ]
            print(my_var_list)
        elif fixBias is True:
            my_var_list = [
                var for var in model.var_list
                if 'rnn/leaky_rnn_cell/kernel:0' in var.name
                or 'output/weights:0' in var.name
            ]
        else:
            my_var_list = model.var_list

        model.set_optimizer(var_list=my_var_list)

        # Store results
        log = defaultdict(list)
        log['model_dir'] = model_dir

        # Record time
        t_start = time.time()

        # Launch the graph in a session
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())

            # penalty on deviation from initial weight
            if hp['l2_weight_init'] > 0:
                anchor_ws = sess.run(model.weight_list)
                for w, w_val in zip(model.weight_list, anchor_ws):
                    model.cost_reg += (hp['l2_weight_init'] *
                                       tf.nn.l2_loss(w - w_val))

                model.set_optimizer(var_list=my_var_list)

            # partial weight training
            if ('p_weight_train' in hp and (hp['p_weight_train'] is not None)
                    and hp['p_weight_train'] < 1.0):
                for w in model.weight_list:
                    w_val = sess.run(w)
                    w_size = sess.run(tf.size(w))
                    w_mask_tmp = np.linspace(0, 1, w_size)
                    hp['rng'].shuffle(w_mask_tmp)
                    ind_fix = w_mask_tmp > hp['p_weight_train']
                    w_mask = np.zeros(w_size, dtype=np.float32)
                    w_mask[ind_fix] = 1e-1  # will be squared in l2_loss
                    w_mask = tf.constant(w_mask)
                    w_mask = tf.reshape(w_mask, w.shape)
                    model.cost_reg += tf.nn.l2_loss((w - w_val) * w_mask)
                model.set_optimizer(var_list=my_var_list)

            step = 0
            run_ave_time = []
            # Build the gradient-norm op once, outside the training loop,
            # so new graph nodes are not created at every display step
            grad_norm = tf.global_norm(model.clipped_gs)
            while step * hp['batch_size_train'] <= max_steps:
                try:
                    # Validation
                    if step % display_step == 0:
                        grad_norm_np = sess.run(grad_norm)
                        log['grad_norm'].append(grad_norm_np.item())
                        log['trials'].append(step * hp['batch_size_train'])
                        log['times'].append(time.time() - t_start)
                        log = do_eval(sess, model, log, hp['rule_trains'])
                        # if log['perf_avg'][-1] > model.hp['target_perf']:
                        # check if minimum performance is above target
                        if log['perf_min'][-1] > model.hp['target_perf']:
                            print('Perf reached the target: {:0.2f}'.format(
                                hp['target_perf']))
                            break

                        if rich_output:
                            display_rich_output(model, sess, step, log,
                                                model_dir)

                    # Training

                    dtStart = datetime.now()
                    sess.run(model.train_step)
                    dtEnd = datetime.now()

                    if len(run_ave_time) == 0:
                        run_ave_time = np.expand_dims(
                            (dtEnd - dtStart).total_seconds(), axis=0)
                    else:
                        run_ave_time = np.concatenate(
                            (run_ave_time,
                             np.expand_dims((dtEnd - dtStart).total_seconds(),
                                            axis=0)))

                    # print(np.mean(run_ave_time))
                    # print((dtEnd-dtStart).total_seconds())

                    step += 1

                    if step < 10:
                        model.save_ckpt(step)

                    if step < 1000:
                        # checkpoint more frequently early in training
                        if step % (display_step // 10) == 0:
                            model.save_ckpt(step)

                    if step % display_step == 0:
                        model.save_ckpt(step)

                except KeyboardInterrupt:
                    print("Optimization interrupted by user")
                    break

            print("Optimization finished!")