Code Example #1
def main(_):
    # Configuration
    # FLAGS.train = False
    # FLAGS.smart_train = True

    FLAGS.overwrite = True
    FLAGS.summary = True
    FLAGS.save_model = False
    FLAGS.snapshot = False

    MEMORY_DEPTH = 1
    EPOCH = 2

    # Start
    console.start('rnn_task')

    # Initiate model
    model = rnn_models.vanilla_RNN('rnn00')

    # Load data
    train_set, val_set, test_set = load_wiener_hammerstein(
        r'../data/wiener_hammerstein/whb.tfd', depth=MEMORY_DEPTH)
    assert isinstance(train_set, DataSet)
    assert isinstance(val_set, DataSet)
    assert isinstance(test_set, DataSet)

    # Train or evaluate
    if FLAGS.train:
        pass
    else:
        console.show_status('Evaluating ...')

    # End
    console.end()
Code Example #2
def main(_):
  console.start('RNN task')

  # Configurations
  th = NlsHub(as_global=True)
  th.memory_depth = 10
  th.num_blocks = 1
  th.multiplier = 8
  th.hidden_dim = th.memory_depth * th.multiplier
  th.num_steps = 32

  th.epoch = 100000
  th.batch_size = 32
  th.learning_rate = 1e-4
  th.validation_per_round = 20
  th.print_cycle = 0

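  # Smart-training, early-stopping and checkpoint-saving options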
  # th.train = False
  th.smart_train = True
  th.max_bad_apples = 4
  th.lr_decay = 0.6

  th.early_stop = True
  th.idle_tol = 20
  th.save_mode = SaveMode.ON_RECORD
  th.warm_up_thres = 1
  th.at_most_save_once_per_round = True

  th.overwrite = True                        # Default: False
  th.export_note = True
  th.summary = True
  th.monitor_preact = False
  th.save_model = True

  th.allow_growth = False
  th.gpu_memory_fraction = 0.4

  description = '0'
  th.mark = 'rnn-{}x({}x{})-{}steps-{}'.format(
    th.num_blocks, th.memory_depth, th.multiplier, th.num_steps, description)
  # Get model
  model = model_lib.rnn0(th)
  # Load data
  train_set, val_set, test_set = load_wiener_hammerstein(
    th.data_dir, depth=th.memory_depth, validation_size=2000)
  assert isinstance(train_set, DataSet)
  assert isinstance(val_set, DataSet)
  assert isinstance(test_set, DataSet)

  # Train or evaluate
  if th.train:
    model.nn.train(train_set, validation_set=val_set, trainer_hub=th)
  else:
    console.show_status('Evaluating ...')
    model.evaluate(train_set, start_at=th.memory_depth)
    model.evaluate(val_set, start_at=th.memory_depth)
    model.evaluate(test_set, start_at=th.memory_depth)

  # End
  console.end()
Code Example #3
def main(_):
    console.start('trainer.task')

    EPOCH = 1000
    # FLAGS.train = False
    FLAGS.overwrite = True
    # FLAGS.save_best = True
    FLAGS.smart_train = True

    # Hyper parameters
    LEARNING_RATE = 0.001
    LAYER_NUM = 4
    BATCH_SIZE = 32
    MEMORY_DEPTH = 80
    LAYER_DIM = MEMORY_DEPTH * 2
    ACTIVATION = 'relu'

    # Set default flags
    FLAGS.progress_bar = True

    FLAGS.save_model = True
    FLAGS.summary = False
    FLAGS.snapshot = False

    PRINT_CYCLE = 100

    WH_PATH = os.path.join(nls_root, 'data/wiener_hammerstein/whb.tfd')
    MARK = 'mlp00'

    # Get model
    model = model_lib.mlp_00(MARK,
                             MEMORY_DEPTH,
                             LAYER_DIM,
                             LAYER_NUM,
                             LEARNING_RATE,
                             activation=ACTIVATION)

    # Load data set
    train_set, val_set, test_set = load_wiener_hammerstein(WH_PATH,
                                                           depth=MEMORY_DEPTH)
    assert isinstance(train_set, DataSet)
    assert isinstance(val_set, DataSet)
    assert isinstance(test_set, DataSet)

    # Train or evaluate
    if FLAGS.train:
        model.identify(train_set,
                       val_set,
                       batch_size=BATCH_SIZE,
                       print_cycle=PRINT_CYCLE,
                       epoch=EPOCH)
    else:
        model.evaluate(train_set, start_at=MEMORY_DEPTH, plot=False)
        model.evaluate(val_set, start_at=MEMORY_DEPTH, plot=False)
        model.evaluate(test_set, start_at=MEMORY_DEPTH, plot=False)

    console.end()
Code Example #4
def main(_):
    console.start('BResNet task')

    description = '0'
    # Configurations
    th = NlsHub(as_global=True)
    th.memory_depth = 80
    th.num_blocks = 3
    th.multiplier = 1
    th.hidden_dim = th.memory_depth * th.multiplier

    th.mark = 'bres-{}x({}x{})-{}'.format(th.num_blocks, th.memory_depth,
                                          th.multiplier, description)
    th.epoch = 50000
    th.batch_size = 64
    th.learning_rate = 0.0001
    th.start_at = 0
    th.reg_strength = 0.000
    th.validation_per_round = 30

    th.train = True
    th.smart_train = True
    th.idle_tol = 30
    th.max_bad_apples = 5
    th.lr_decay = 0.6
    th.early_stop = True
    th.save_mode = SaveMode.ON_RECORD
    th.warm_up_rounds = 50
    th.overwrite = True
    th.export_note = True
    th.summary = False
    th.save_model = False
    # Only overwrite when training from scratch (start_at == 0)
    th.overwrite = th.overwrite and th.start_at == 0

    # Get model
    model = model_lib.bres_net_res0(th)
    # Load data
    train_set, val_set, test_set = load_wiener_hammerstein(
        th.data_dir, depth=th.memory_depth)
    assert isinstance(train_set, DataSet)
    assert isinstance(val_set, DataSet)
    assert isinstance(test_set, DataSet)

    # Train or evaluate
    if th.train:
        model.nn.train(train_set,
                       validation_set=val_set,
                       trainer_hub=th,
                       start_at=th.start_at)
    else:
        model.evaluate(train_set, start_at=th.memory_depth)
        model.evaluate(val_set, start_at=th.memory_depth)
        model.evaluate(test_set, start_at=th.memory_depth)

    # End
    console.end()
Code Example #5
def main(_):
    console.start('mlp task')

    description = 'm'
    # Configurations
    th = NlsHub(as_global=True)
    th.memory_depth = 40
    th.num_blocks = 2
    th.multiplier = 2
    th.hidden_dim = th.memory_depth * th.multiplier

    th.mark = 'mlp-{}x({}x{})-{}'.format(th.num_blocks, th.memory_depth,
                                         th.multiplier, description)
    th.epoch = 50000
    th.batch_size = 64
    th.learning_rate = 0.001
    th.validation_per_round = 20

    th.train = True
    th.smart_train = False
    th.idle_tol = 20
    th.max_bad_apples = 4
    th.lr_decay = 0.6
    th.early_stop = True
    th.save_mode = SaveMode.ON_RECORD
    th.warm_up_rounds = 50
    th.overwrite = True
    th.export_note = True
    th.summary = True
    th.monitor = True
    th.save_model = False

    th.allow_growth = False
    th.gpu_memory_fraction = 0.4

    # Get model
    model = model_lib.mlp_00(th)
    # Load data
    train_set, val_set, test_set = load_wiener_hammerstein(
        th.data_dir, depth=th.memory_depth)
    assert isinstance(train_set, DataSet)
    assert isinstance(val_set, DataSet)
    assert isinstance(test_set, DataSet)

    # Train or evaluate
    if th.train:
        model.nn.train(train_set, validation_set=val_set, trainer_hub=th)
    else:
        console.show_status('Evaluating ...')
        model.evaluate(train_set, start_at=th.memory_depth)
        model.evaluate(val_set, start_at=th.memory_depth)
        model.evaluate(test_set, start_at=th.memory_depth, plot=True)

    # End
    console.end()
Code Example #6
File: hpt_task_lottery.py (Project: zkmartin/nls)
def main(_):
    console.start('trainer.task')

    # Set default flags
    FLAGS.train = True
    if FLAGS.use_default:
        FLAGS.overwrite = True
        FLAGS.smart_train = False
        FLAGS.save_best = False

    FLAGS.smart_train = True
    FLAGS.save_best = False

    WH_PATH = FLAGS.data_dir

    MARK = 'lottery02'
    MEMORY_DEPTH = 80
    PRINT_CYCLE = 50
    EPOCH = 1000
    LR = 0.000088

    LAYER_DIM = MEMORY_DEPTH * FLAGS.coe
    # ACTIVATION = FLAGS.activation
    ACTIVATION = 'relu'
    # BRANCHES = FLAGS.branches
    BRANCHES = 6
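    # BRANCHES + 1 learning rates, all set to FLAGS.lr1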
    LR_LIST = [FLAGS.lr1] * (BRANCHES + 1)
    FLAGS.smart_train = True

    # Get model
    model = model_lib.mlp02(MARK,
                            MEMORY_DEPTH,
                            BRANCHES,
                            LAYER_DIM,
                            LR,
                            ACTIVATION,
                            identity_init=True)

    # Load data set
    train_set, val_set, test_set = load_wiener_hammerstein(WH_PATH,
                                                           depth=MEMORY_DEPTH)
    assert isinstance(train_set, DataSet)
    assert isinstance(val_set, DataSet)
    assert isinstance(test_set, DataSet)

    # Train
    if FLAGS.train:
        model.identify(train_set,
                       val_set,
                       batch_size=64,
                       print_cycle=PRINT_CYCLE,
                       epoch=EPOCH,
                       lr_list=LR_LIST)

    console.end()
Code Example #7
File: task.py (Project: zkmartin/nls)
def main(_):
    console.start('trainer.task')

    # Set default flags
    if FLAGS.use_default:
        FLAGS.train = True
        FLAGS.overwrite = True
        FLAGS.smart_train = False
        FLAGS.save_best = False
        FLAGS.progress_bar = False

    if FLAGS.data_dir == "":
        WH_PATH = os.path.join(nls_root, 'data/wiener_hammerstein/whb.tfd')
    else:
        WH_PATH = FLAGS.data_dir
    MARK = 'mlp00'
    MEMORY_DEPTH = 40
    PRINT_CYCLE = 100
    EPOCH = 2

    LAYER_DIM = MEMORY_DEPTH * 2
    LAYER_NUM = 2
    LEARNING_RATE = 0.001
    BATCH_SIZE = 64

    # Get model
    model = model_lib.mlp_00(MARK, MEMORY_DEPTH, LAYER_DIM, LAYER_NUM,
                             LEARNING_RATE)

    # Load data set
    train_set, val_set, test_set = load_wiener_hammerstein(WH_PATH,
                                                           depth=MEMORY_DEPTH)
    assert isinstance(train_set, DataSet)
    assert isinstance(val_set, DataSet)
    assert isinstance(test_set, DataSet)

    # Train or evaluate
    if FLAGS.train:
        model.identify(train_set,
                       val_set,
                       batch_size=BATCH_SIZE,
                       print_cycle=PRINT_CYCLE,
                       epoch=EPOCH)
    else:
        pass

    console.end()
Code Example #8
def main(_):
    # =============================================================================
    # Global configuration
    WH_PATH = './data/wiener_hammerstein/whb.tfd'

    NN_LEARNING_RATE = 0.001
    BATCH_SIZE = 32

    MEMORY_DEPTH = 40
    NN_EPOCH = 5
    PRINT_CYCLE = 100

    FLAGS.train = True
    # FLAGS.train = False
    FLAGS.overwrite = True
    # FLAGS.overwrite = False
    FLAGS.save_best = False
    # FLAGS.save_best = True

    FLAGS.smart_train = False
    FLAGS.epoch_tol = 20

    # Turn off overwrite while in save-best mode or when not training
    FLAGS.overwrite = FLAGS.overwrite and not FLAGS.save_best and FLAGS.train
    # =============================================================================

    console.start('NN Demo')

    # Load data set
    train_set, val_set, test_set = load_wiener_hammerstein(WH_PATH,
                                                           depth=MEMORY_DEPTH)
    assert isinstance(train_set, DataSet)
    assert isinstance(val_set, DataSet)
    assert isinstance(test_set, DataSet)

    model = nn_models.mlp_00(NN_LEARNING_RATE, MEMORY_DEPTH)

    # Define model and identify
    if FLAGS.train:
        model.identify(train_set,
                       val_set,
                       batch_size=BATCH_SIZE,
                       print_cycle=PRINT_CYCLE,
                       epoch=NN_EPOCH)

    console.end()
Code Example #9
File: hpt_task_resnet.py (Project: zkmartin/nls)
def main(_):
  console.start('trainer.task')

  # Set default flags
  FLAGS.train = True
  if FLAGS.use_default:
    FLAGS.overwrite = True
    FLAGS.smart_train = False
    FLAGS.save_best = False

  FLAGS.smart_train = True
  FLAGS.save_best = True

  WH_PATH = FLAGS.data_dir

  MARK = 'resnet01'
  MEMORY_DEPTH = 80
  PRINT_CYCLE = 50
  EPOCH = 100
  NN_BLOCKS = FLAGS.blocks
  order1 = FLAGS.order1
  order2 = FLAGS.order2


  LAYER_DIM = MEMORY_DEPTH * FLAGS.coe
  LEARNING_RATE = FLAGS.lr
  ACTIVATION = FLAGS.activation


  # Get model
  model = model_lib.res_00(memory=MEMORY_DEPTH, blocks=NN_BLOCKS, order1=order1, order2=order2,
                           activation=ACTIVATION, learning_rate=LEARNING_RATE)

  # Load data set
  train_set, val_set, test_set = load_wiener_hammerstein(
    WH_PATH, depth=MEMORY_DEPTH)
  assert isinstance(train_set, DataSet)
  assert isinstance(val_set, DataSet)
  assert isinstance(test_set, DataSet)

  # Train
  if FLAGS.train:
    model.identify(train_set, val_set, batch_size=64,
                   print_cycle=PRINT_CYCLE, epoch=EPOCH)

  console.end()
Code Example #10
def main(_):
    console.start('trainer.task')

    # Set default flags
    FLAGS.train = True
    if FLAGS.use_default:
        FLAGS.overwrite = True
        FLAGS.smart_train = False
        FLAGS.save_best = False

    FLAGS.smart_train = True
    FLAGS.save_best = True

    WH_PATH = FLAGS.data_dir

    MARK = 'svn00'
    MEMORY_DEPTH = FLAGS.memory
    PRINT_CYCLE = 50
    EPOCH = 100

    LAYER_DIM = MEMORY_DEPTH * 2
    LEARNING_RATE = FLAGS.lr
    BATCH_SIZE = 64
    ORDER1 = FLAGS.order1
    ORDER2 = FLAGS.order2
    ORDER3 = FLAGS.order3

    # Get model
    model = model_lib.svn_00(MEMORY_DEPTH, MARK, LAYER_DIM, ORDER1, ORDER2,
                             ORDER3, LEARNING_RATE)
    # Load data set
    train_set, val_set, test_set = load_wiener_hammerstein(WH_PATH,
                                                           depth=MEMORY_DEPTH)
    assert isinstance(train_set, DataSet)
    assert isinstance(val_set, DataSet)
    assert isinstance(test_set, DataSet)

    # Train
    if FLAGS.train:
        model.identify(train_set,
                       val_set,
                       batch_size=BATCH_SIZE,
                       print_cycle=PRINT_CYCLE,
                       epoch=EPOCH)

    console.end()
Code Example #11
def main(_):
    console.start('trainer.task')

    # Set default flags
    FLAGS.train = True
    if FLAGS.use_default:
        FLAGS.overwrite = True
        FLAGS.smart_train = False
        FLAGS.save_best = False

    WH_PATH = FLAGS.data_dir
    FLAGS.smart_train = True
    FLAGS.save_best = True

    MARK = 'mlp00'
    MEMORY_DEPTH = 80
    PRINT_CYCLE = 50
    EPOCH = 300

    LAYER_DIM = MEMORY_DEPTH * FLAGS.coe
    LAYER_NUM = FLAGS.layer_num
    LEARNING_RATE = FLAGS.lr
    BATCH_SIZE = 32
    ACTIVATION = FLAGS.activation

    # Get model
    model = model_lib.mlp_00(MARK, MEMORY_DEPTH, LAYER_DIM, LAYER_NUM,
                             LEARNING_RATE, ACTIVATION)

    # Load data set
    train_set, val_set, test_set = load_wiener_hammerstein(WH_PATH,
                                                           depth=MEMORY_DEPTH)
    assert isinstance(train_set, DataSet)
    assert isinstance(val_set, DataSet)
    assert isinstance(test_set, DataSet)

    # Train
    if FLAGS.train:
        model.identify(train_set,
                       val_set,
                       batch_size=BATCH_SIZE,
                       print_cycle=PRINT_CYCLE,
                       epoch=EPOCH)

    console.end()
Code Example #12
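# NOTE: excerpt from a larger script; WH_PATH, MEMORY_DEPTH, D, NN_LEARNING_RATE
# and NN_EPOCH are presumably defined earlier in the source file (not shown here).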
# FLAGS.save_best = True

FLAGS.smart_train = True
FLAGS.epoch_tol = 50

FLAGS.summary = False
FLAGS.snapshot = False

# Turn off overwrite while in save-best mode or when not training
FLAGS.overwrite = FLAGS.overwrite and not FLAGS.save_best and FLAGS.train

EVALUATION = not FLAGS.train
PLOT = EVALUATION
# =============================================================================
# Load data set
train_set, val_set, test_set = load_wiener_hammerstein(WH_PATH,
                                                       depth=MEMORY_DEPTH)
assert isinstance(train_set, DataSet)
assert isinstance(val_set, DataSet)
assert isinstance(test_set, DataSet)

model = wh_model_lib.svn_02(MEMORY_DEPTH, 'svn_poly_relu', D * 2, 3, 3, 3,
                            NN_LEARNING_RATE)

# Define model and identify
if FLAGS.train:
    model.identify(train_set,
                   val_set,
                   batch_size=32,
                   print_cycle=100,
                   epoch=NN_EPOCH)
Code Example #13
import numpy as np
from signals.utils.dataset import load_wiener_hammerstein, DataSet

# Configurations
memory_depth = 80
data_dir = '../data/wiener_hammerstein/whb.tfd'
err_amplitude = 1e-5

# Load data
train_set, val_set, test_set = load_wiener_hammerstein(data_dir,
                                                       depth=memory_depth)
assert isinstance(train_set, DataSet)
assert isinstance(val_set, DataSet)
assert isinstance(test_set, DataSet)
u, y = test_set.signls[0], test_set.responses[0]

# Helper: express a value as a percentage of the signal's RMS
f_ratio = lambda val: 100 * val / y.rms


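# Pseudo-evaluation: perturb the target signal with uniform noise bounded by
# err_bound and report the mean and standard deviation of the resulting error.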
def pseud_evaluate(err_bound):
    print('-' * 79)
    pred_y = y + np.random.random_sample(y.shape) * err_bound
    # Error
    err = y - pred_y
    # Mean value
    val = err.average
    print('E[err] = {:.4f} mv ({:.3f}%)'.format(val * 1000, f_ratio(val)))
    # Standard deviation
    val = float(np.std(err))
    print('STD[err] = {:.4f} mv ({:.3f}%)'.format(val * 1000, f_ratio(val)))
Code Example #14
def main(_):
    console.start('Lottery')

    # Configurations
    MARK = 'mlp_block_test'
    MEMORY_DEPTH = 80
    coe = 2
    HIDDEN_DIM = MEMORY_DEPTH * coe
    BRANCH_NUM = 6
    T_BRANCH_INDEX_S = 0
    T_BRANCH_INDEX_E = 0

    EPOCH = 5000
    LR = 0.000088
    LR_LIST = [0.000088] * (BRANCH_NUM + 1)
    BATCH_SIZE = 32
    PRINT_CYCLE = 10
    BRANCH_INDEX = 2
    # FIX_PRE_WEIGHT = True
    freeze_index = 1
    LAYER_TRAIN = True
    BRANCH_TRAIN = False
    ACTIVATION = 'relu'

    # FLAGS.train = False
    # FLAGS.overwrite = True and BRANCH_INDEX == 0
    FLAGS.overwrite = False
    FLAGS.smart_train = True
    FLAGS.save_best = True and BRANCH_INDEX > 0
    # FLAGS.save_best = False
    FLAGS.summary = True
    # FLAGS.save_model = False
    FLAGS.snapshot = False
    FLAGS.epoch_tol = 30

    # Load data
    train_set, val_set, test_set = load_wiener_hammerstein(
        r'../data/wiener_hammerstein/whb.tfd', depth=MEMORY_DEPTH)
    assert isinstance(train_set, DataSet)
    assert isinstance(val_set, DataSet)
    assert isinstance(test_set, DataSet)

    # Get model
    model = lott_lib.mlp01(MARK, MEMORY_DEPTH, BRANCH_NUM, HIDDEN_DIM, LR,
                           ACTIVATION)

    model.nn._branches_variables_assign(BRANCH_INDEX)

    # Train or evaluate
    if FLAGS.train:
        model.identify(train_set,
                       val_set,
                       batch_size=BATCH_SIZE,
                       print_cycle=PRINT_CYCLE,
                       epoch=EPOCH,
                       branch_index=BRANCH_INDEX,
                       lr_list=LR_LIST,
                       freeze_index=freeze_index,
                       t_branch_s_index=T_BRANCH_INDEX_S,
                       t_branch_e_index=T_BRANCH_INDEX_E,
                       layer_train=LAYER_TRAIN,
                       branch_train=BRANCH_TRAIN)

    else:
        BRANCH_INDEX = 2
        model.evaluate(train_set,
                       start_at=MEMORY_DEPTH,
                       branch_index=BRANCH_INDEX)
        model.evaluate(val_set,
                       start_at=MEMORY_DEPTH,
                       branch_index=BRANCH_INDEX)
        model.evaluate(test_set,
                       start_at=MEMORY_DEPTH,
                       branch_index=BRANCH_INDEX)

    console.end()
Code Example #15
File: lott_fortune_script.py (Project: zkmartin/nls)
def main(_):
    console.start('Lottery')

    # Configurations
    MARK = 'mlp02_broad'
    BRANCH_NUM = 3
    MEMORY_DEPTH = 80
    coe = 8
    HIDDEN_DIM = MEMORY_DEPTH * coe

    EPOCH = 500
    LR = 0.00088
    BATCH_SIZE = 32
    PRINT_CYCLE = 10
    BRANCH_INDEX = 1
    FIX_PRE_WEIGHT = True
    ACTIVATION = 'relu'

    # FLAGS.train = False
    FLAGS.overwrite = True and BRANCH_INDEX == 0
    FLAGS.smart_train = True
    FLAGS.save_best = True and BRANCH_INDEX > 0
    # FLAGS.save_best = False
    FLAGS.summary = True
    # FLAGS.save_model = False
    FLAGS.snapshot = False
    FLAGS.epoch_tol = 100

    # Load data
    train_set, val_set, test_set = load_wiener_hammerstein(
        r'../data/wiener_hammerstein/whb.tfd', depth=MEMORY_DEPTH)
    assert isinstance(train_set, DataSet)
    assert isinstance(val_set, DataSet)
    assert isinstance(test_set, DataSet)

    # Get model
    model = lott_lib.mlp02(MARK, MEMORY_DEPTH, BRANCH_NUM, HIDDEN_DIM, LR,
                           ACTIVATION)

    # Train or evaluate
    if FLAGS.train:
        model.identify(train_set,
                       val_set,
                       batch_size=BATCH_SIZE,
                       print_cycle=PRINT_CYCLE,
                       epoch=EPOCH,
                       branch_index=BRANCH_INDEX,
                       freeze=FIX_PRE_WEIGHT)
    else:
        BRANCH_INDEX = 1
        model.evaluate(train_set,
                       start_at=MEMORY_DEPTH,
                       branch_index=BRANCH_INDEX)
        model.evaluate(val_set,
                       start_at=MEMORY_DEPTH,
                       branch_index=BRANCH_INDEX)
        model.evaluate(test_set,
                       start_at=MEMORY_DEPTH,
                       branch_index=BRANCH_INDEX)

    console.end()
Code Example #16
File: lott_script.py (Project: zkmartin/nls)
def main(_):
    console.start('Lottery')

    # Configurations
    MARK = 'mlp00'
    MEMORY_DEPTH = 80
    coe = 8
    HIDDEN_DIM = MEMORY_DEPTH * coe

    EPOCH = 500
    LR = 0.000058
    BATCH_SIZE = 32
    PRINT_CYCLE = 10
    BRANCH_INDEX = 1
    FIX_PRE_WEIGHT = True
    ACTIVATION = 'relu'

    # FLAGS.train = False
    FLAGS.overwrite = True and BRANCH_INDEX == 0
    FLAGS.smart_train = True
    FLAGS.save_best = True and BRANCH_INDEX > 0
    FLAGS.summary = True
    # FLAGS.save_model = False
    FLAGS.snapshot = False
    FLAGS.epoch_tol = 50

    # Load data
    train_set, val_set, test_set = load_wiener_hammerstein(
        r'../data/wiener_hammerstein/whb.tfd', depth=MEMORY_DEPTH)
    assert isinstance(train_set, DataSet)
    assert isinstance(val_set, DataSet)
    assert isinstance(test_set, DataSet)

    # Get model
    model = lott_lib.mlp00(MARK, MEMORY_DEPTH, HIDDEN_DIM, LR, ACTIVATION)

    branch_1_weights = 'FeedforwardNet/branch/linear/weights:0'
    branch_1_bias = 'FeedforwardNet/branch/linear/biases:0'
    branch_2_weights = 'FeedforwardNet/branch2/linear/weights:0'
    branch_2_bias = 'FeedforwardNet/branch2/linear/biases:0'
    # model.nn.variable_assign(branch_1_weights, branch_2_weights)
    # model.nn.variable_assign(branch_1_bias, branch_2_bias)
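    # Inspect the model's trainable variables (tf is assumed to be TensorFlow,
    # imported earlier in the original file)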
    with model.nn._graph.as_default():
        variables = tf.trainable_variables()
        b = 1  # no-op placeholder, presumably kept as a debugger breakpoint anchor
    #   (Debugging) print selected branch variables, e.g.:
    #   for i in (2, 6, 3, 7, 4, 5):
    #       print(model.nn._session.run(variables[i]))
    #   a = 1

    # Train or evaluate
    if FLAGS.train:
        model.identify(train_set,
                       val_set,
                       batch_size=BATCH_SIZE,
                       print_cycle=PRINT_CYCLE,
                       epoch=EPOCH,
                       branch_index=BRANCH_INDEX,
                       freeze=FIX_PRE_WEIGHT)
    else:
        BRANCH_INDEX = 1
        model.evaluate(train_set,
                       start_at=MEMORY_DEPTH,
                       branch_index=BRANCH_INDEX)
        model.evaluate(val_set,
                       start_at=MEMORY_DEPTH,
                       branch_index=BRANCH_INDEX)
        model.evaluate(test_set,
                       start_at=MEMORY_DEPTH,
                       branch_index=BRANCH_INDEX)

    console.end()