Example 1
def train_tnn_alexnet():
    imagenet = setup.get_imagenet()
    images_plc = tf.placeholder(tf.float32, shape=[BATCH_SIZE, 224, 224, 3])
    labels_plc = tf.placeholder(tf.int64, shape=[BATCH_SIZE])

    with tf.variable_scope('tconvnet'):
        G = main.graph_from_json('json/alexnet.json')
        main.init_nodes(G, batch_size=BATCH_SIZE)
        main.unroll(G, input_seq=images_plc)
        loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=G.node['fc8']['outputs'][-1], labels=labels_plc)
        loss = tf.reduce_mean(loss)

    optimizer = tf.train.MomentumOptimizer(learning_rate=.01,
                                           momentum=.9).minimize(loss)

    init = tf.global_variables_initializer()
    sess = tf.Session()
    sess.run(init)

    losses = []
    for step in range(1000):
        start = time.time()
        images_batch, labels_batch = next(imagenet)
        lo, _ = sess.run([loss, optimizer],
                         feed_dict={
                             images_plc: images_batch,
                             labels_plc: labels_batch
                         })
        end = time.time()
        losses.append(lo)
        print(step, '{:.4f}'.format(lo), '{:.3f}'.format(end - start))
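    # sanity check: mean loss should drop below chance level (ln(1000) ≈ 6.91)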
    assert np.mean(losses[-20:]) < 6.8
Example 2
def test_alexnet():
    ims = np.random.standard_normal([BATCH_SIZE, 224, 224, 3])
    labels = np.random.randint(1000, size=[BATCH_SIZE])
    data = {
        'images': tf.constant(ims.astype(np.float32)),
        'labels': tf.constant(labels.astype(np.int32))
    }
    # initialize the benchmark model
    with tf.variable_scope('benchmark'):
        bench_targets = setup.alexnet(data['images'],
                                      data['labels'],
                                      'benchmark',
                                      train=False)
    bench_targets = {'loss': bench_targets['loss']}

    # initialize the tconvnet model
    with tf.variable_scope('tconvnet'):
        G = main.graph_from_json('json/alexnet.json')
        main.init_nodes(G, batch_size=BATCH_SIZE)
        main.unroll(G, input_seq=data['images'])
        tnn_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=G.node['fc8']['outputs'][-1], labels=data['labels'])
        tnn_targets = {'loss': tf.reduce_mean(tnn_loss)}

    run(bench_targets, tnn_targets, nsteps=10, n_initial=10)
Example 3
def memory_usage():
    ims = np.random.standard_normal([BATCH_SIZE, 224, 224, 3])
    labels = np.random.randint(1000, size=[BATCH_SIZE])
    data = {
        'images': tf.constant(ims.astype(np.float32)),
        'labels': tf.constant(labels.astype(np.int32))
    }
    # initialize the benchmark model
    # with tf.variable_scope('benchmark'):
    #     bench_targets = setup.alexnet(data['images'], data['labels'], 'benchmark', train=False)
    #     loss = bench_targets['loss']

    with tf.variable_scope('tconvnet'):
        G = main.graph_from_json('json/alexnet.json')
        main.init_nodes(G, batch_size=BATCH_SIZE)
        main.unroll(G, input_seq=data['images'])
        loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=G.node['fc8']['outputs'][-1], labels=data['labels'])

    init = tf.global_variables_initializer()
    sess = tf.Session(config=tf.ConfigProto(gpu_options=tf.GPUOptions(
        allow_growth=True)))
    sess.run(init)
    sess.run(loss)
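    # drop into the debugger to keep the session alive while inspecting GPU memory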
    import pdb
    pdb.set_trace()
Example 4
def test_feedback():
    images = tf.constant(np.random.standard_normal([BATCH_SIZE, 224, 224, 3]).astype(np.float32))
    # initialize the tconvnet model
    with tf.variable_scope('tconvnet'):
        json_path = os.path.join(json_dir, 'alexnet.json')
        G = main.graph_from_json(json_path)
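        # add top-down feedback edges: conv5 -> conv3, conv5 -> conv4, conv4 -> conv3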
        G.add_edges_from([('conv5', 'conv3'), ('conv5', 'conv4'), ('conv4', 'conv3')])
        main.init_nodes(G, input_nodes=['conv1'], batch_size=BATCH_SIZE)
        main.unroll(G, input_seq={'conv1': images})

    test_state_and_output_sizes(G)

    graph = tf.get_default_graph()

    # harbor output sizes
    harbor = graph.get_tensor_by_name('tconvnet/conv1/harbor:0')
    assert harbor.shape.as_list() == [BATCH_SIZE, 224, 224, 3]
    harbor = graph.get_tensor_by_name('tconvnet/conv2/harbor:0')
    assert harbor.shape.as_list() == [BATCH_SIZE, 27, 27, 96]
    harbor = graph.get_tensor_by_name('tconvnet/conv3/harbor:0')
    assert harbor.shape.as_list() == [BATCH_SIZE, 14, 14, 256+384+256]
    harbor = graph.get_tensor_by_name('tconvnet/conv4/harbor:0')
    assert harbor.shape.as_list() == [BATCH_SIZE, 14, 14, 384+256]
    harbor = graph.get_tensor_by_name('tconvnet/conv5/harbor:0')
    assert harbor.shape.as_list() == [BATCH_SIZE, 14, 14, 384]
    harbor = graph.get_tensor_by_name('tconvnet/fc6/harbor:0')
    assert harbor.shape.as_list() == [BATCH_SIZE, 7, 7, 256]
    harbor = graph.get_tensor_by_name('tconvnet/fc7/harbor:0')
    assert harbor.shape.as_list() == [BATCH_SIZE, 4096]
    harbor = graph.get_tensor_by_name('tconvnet/fc8/harbor:0')
    assert harbor.shape.as_list() == [BATCH_SIZE, 4096]

    # check if harbor outputs at t are equal to the concat of outputs
    # from incoming nodes at t-1

    # conv4's harbor receives conv3 (feedforward) and conv5 (feedback)
    conv4h = graph.get_tensor_by_name('tconvnet/conv4_5/harbor:0')
    conv3o = G.node['conv3']['outputs'][4]
    conv5o = G.node['conv5']['outputs'][4]
    conv5om = tf.image.resize_images(conv5o, conv4h.shape.as_list()[1:3])

    concat = tf.concat([conv3o, conv5om], axis=3)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        conv4hr, concatr = sess.run([conv4h, concat])
        assert np.array_equal(conv4hr, concatr)

    # conv3's harbor receives conv2 (feedforward) plus conv4 and conv5 (feedback)
    conv3h = graph.get_tensor_by_name('tconvnet/conv3_7/harbor:0')
    conv2o = G.node['conv2']['outputs'][6]
    conv5o = G.node['conv5']['outputs'][6]
    conv5om = tf.image.resize_images(conv5o, conv3h.shape.as_list()[1:3])
    conv4o = G.node['conv4']['outputs'][6]
    conv4om = tf.image.resize_images(conv4o, conv3h.shape.as_list()[1:3])

    concat = tf.concat([conv2o, conv4om, conv5om], axis=3)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        conv3hr, concatr = sess.run([conv3h, concat])
        assert np.array_equal(conv3hr, concatr)
Example 5
def test_feedback2():
    images = tf.constant(np.random.standard_normal([BATCH_SIZE, 224, 224, 3]).astype(np.float32))
    # initialize the tconvnet model
    with tf.variable_scope('tconvnet'):
        json_path = os.path.join(json_dir, 'alexnet.json')
        G = main.graph_from_json(json_path)
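        # feedback from fc7 (a flat 4096-d vector) into conv5; the harbor
        # spatializes it, adding ceil(4096 / (14 * 14)) extra channels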
        G.add_edges_from([('fc7', 'conv5')])
        main.init_nodes(G, input_nodes=['conv1'], batch_size=BATCH_SIZE)
        main.unroll(G, input_seq={'conv1': images})

    test_state_and_output_sizes(G)

    graph = tf.get_default_graph()

    harbor = graph.get_tensor_by_name('tconvnet/conv1/harbor:0')
    assert harbor.shape.as_list() == [BATCH_SIZE, 224, 224, 3]
    harbor = graph.get_tensor_by_name('tconvnet/conv2/harbor:0')
    assert harbor.shape.as_list() == [BATCH_SIZE, 27, 27, 96]
    harbor = graph.get_tensor_by_name('tconvnet/conv3/harbor:0')
    assert harbor.shape.as_list() == [BATCH_SIZE, 14, 14, 256]
    harbor = graph.get_tensor_by_name('tconvnet/conv4/harbor:0')
    assert harbor.shape.as_list() == [BATCH_SIZE, 14, 14, 384]
    harbor = graph.get_tensor_by_name('tconvnet/conv5_1/harbor:0')
    assert harbor.shape.as_list() == [BATCH_SIZE, 14, 14, 384 + int(math.ceil(4096 / (14 * 14)))]
    harbor = graph.get_tensor_by_name('tconvnet/fc6/harbor:0')
    assert harbor.shape.as_list() == [BATCH_SIZE, 7, 7, 256]
    harbor = graph.get_tensor_by_name('tconvnet/fc7/harbor:0')
    assert harbor.shape.as_list() == [BATCH_SIZE, 4096]
    harbor = graph.get_tensor_by_name('tconvnet/fc8/harbor:0')
    assert harbor.shape.as_list() == [BATCH_SIZE, 4096]
Example 6
def model_func(input_images,
               train=True,
               ntimes=TOTAL_TIMESTEPS,
               batch_size=batch_size,
               edges_arr=[],
               base_name='alexnet_tnn',
               tau=0.0,
               trainable_flag=False,
               channel_op='concat'):
    with tf.variable_scope("alexnet_tnn"):
        base_name += '.json'
        print('Using base: ', base_name)
        # creates the feedforward network graph from json
        G = main.graph_from_json(base_name)

        for node, attr in G.nodes(data=True):
            if node in ['fc6', 'fc7']:
                if train:  # we add dropout to fc6 and fc7 during training
                    print('Using dropout for ' + node)
                    attr['kwargs']['post_memory'][1][1]['keep_prob'] = 0.5
                else:  # turn off dropout during evaluation
                    print('Not using dropout')
                    attr['kwargs']['post_memory'][1][1]['keep_prob'] = 1.0

            memory_func, memory_param = attr['kwargs']['memory']
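            # a spatial filter size in the memory params signals a ConvLSTM
            # cell; otherwise fall back to a simple decaying memory with decay tau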
            if 'filter_size' in memory_param:
                attr['cell'] = tnn_ConvLSTMCell
            else:
                attr['kwargs']['memory'][1]['memory_decay'] = tau
                attr['kwargs']['memory'][1]['trainable'] = trainable_flag

        # add any non feedforward connections here: [('L2', 'L1')]
        G.add_edges_from(edges_arr)

        # initialize network to infer the shapes of all the parameters
        main.init_nodes(G,
                        input_nodes=['conv1'],
                        batch_size=batch_size,
                        channel_op=channel_op)
        # unroll the network through time
        main.unroll(G, input_seq={'conv1': input_images}, ntimes=ntimes)

        outputs = {}
        # option: collect outputs from the last NUM_TIMESTEPS timesteps
        # for t in range(ntimes - NUM_TIMESTEPS, ntimes):
        #     idx = t - (ntimes - NUM_TIMESTEPS)  # keys start at timepoint 0
        #     outputs[idx] = G.node['fc8']['outputs'][t]

        # alternatively, we return the final output of the model at the last timestep
        outputs[0] = G.node['fc8']['outputs'][-1]
        return outputs
Example 7
def test_mnist(kind='conv'):
    data = {
        'images':
        np.random.standard_normal([BATCH_SIZE, 28 * 28]).astype(np.float32),
        'labels':
        np.random.randint(10, size=BATCH_SIZE).astype(np.int32)
    }
    if kind == 'conv':
        data['images'] = np.reshape(data['images'], [-1, 28, 28, 1])

    # initialize the benchmark model
    with tf.variable_scope('benchmark'):
        if kind == 'conv':
            bench_targets = setup.mnist_conv(**data)
        elif kind == 'fc':
            bench_targets = setup.mnist_fc(**data)
        else:
            raise ValueError("kind must be 'conv' or 'fc'")

    bench_vars = {
        v.name[len('benchmark') + 1:]: v
        for v in tf.global_variables() if v.name.startswith('benchmark')
    }
    bench_targets.update(bench_vars)
    for name, var in bench_vars.items():
        bench_targets['grad_' + name] = tf.gradients(bench_targets['loss'],
                                                     var)

    # initialize the tconvnet model
    with tf.variable_scope('tconvnet'):
        G = main.graph_from_json('json/mnist_{}.json'.format(kind))
        main.init_nodes(G, batch_size=BATCH_SIZE)
        input_seq = tf.constant(data['images'])
        main.unroll(G, input_seq=input_seq)
        tnn_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=G.node['fc2']['outputs'][-1],
            labels=tf.constant(data['labels']))

    tnn_targets = {n: G.node[n]['outputs'][-1] for n in G}
    tnn_targets['loss'] = tf.reduce_mean(tnn_loss)
    tnn_vars = {
        v.name[len('tconvnet') + 1:]: v
        for v in tf.global_variables()
        if v.name.startswith('tconvnet') and 'memory_decay' not in v.name
    }
    tnn_targets.update(tnn_vars)
    for name, var in tnn_vars.items():
        tnn_targets['grad_' + name] = tf.gradients(tnn_targets['loss'], var)

    run(bench_targets, tnn_targets, nsteps=100)
Example 8
def model_func(input_images,
               ntimes=TOTAL_TIMESTEPS,
               batch_size=batch_size,
               edges_arr=[],
               base_name='../json/VanillaRNN',
               tau=0.0,
               trainable_flag=False):

    with tf.variable_scope("my_model"):
        # reshape the 784 dimension MNIST digits to be 28x28 images
        input_images = tf.reshape(input_images, [-1, 28, 28, 1])
        base_name += '.json'
        print('Using base: ', base_name)
        # creates the feedforward network graph from json
        G = main.graph_from_json(base_name)

        for node, attr in G.nodes(data=True):
            memory_func, memory_params = attr['kwargs']['memory']
            if any(p in memory_params for p in
                   ['filter_size', 'gate_filter_size', 'tau_filter_size']):
                # this is where you add your custom cell
                attr['cell'] = CUSTOM_CELL
            else:
                # default to not having a memory cell
                # tau = 0.0, trainable = False
                attr['kwargs']['memory'][1]['memory_decay'] = tau
                attr['kwargs']['memory'][1]['trainable'] = trainable_flag

        # add any non feedforward connections here: e.g. [('L2', 'L1')]
        G.add_edges_from(edges_arr)

        # initialize network to infer the shapes of all the parameters
        main.init_nodes(G, input_nodes=[INPUT_LAYER], batch_size=batch_size)
        # unroll the network through time
        main.unroll(G, input_seq={INPUT_LAYER: input_images}, ntimes=ntimes)

        outputs = {}
        # collect outputs from the last NUM_TIMESTEPS timesteps
        for t in range(ntimes - NUM_TIMESTEPS, ntimes):
            idx = t - (ntimes - NUM_TIMESTEPS)  # keys start at timepoint 0
            outputs[idx] = G.node[READOUT_LAYER]['outputs'][t]

        return outputs
Example 9
def test_memory():
    images = tf.constant(np.random.standard_normal([BATCH_SIZE, 28, 28, 1]).astype(np.float32))

    with tf.variable_scope('tconvnet'):
        json_path = os.path.join(json_dir, 'alexnet.json')
        G = main.graph_from_json(json_path)
        for node, attr in G.nodes(data=True):
            if node in ['conv1', 'conv2']:
                attr['kwargs']['memory'][1]['memory_decay'] = MEM
        main.init_nodes(G, input_nodes=['conv1'], batch_size=BATCH_SIZE)
        main.unroll(G, input_seq={'conv1': images}, ntimes=6)

    init = tf.global_variables_initializer()
    sess = tf.Session()
    sess.run(init)

    graph = tf.get_default_graph()

    conv1_state = np.zeros(G.node['conv1']['states'][0].get_shape().as_list())
    conv2_state = np.zeros(G.node['conv2']['states'][0].get_shape().as_list())
    state1, state2 = sess.run([G.node['conv1']['states'], G.node['conv2']['states']])
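    # each state should follow the leaky integrator: s_t = MEM * s_{t-1} + input_t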
    for i, (s1, s2) in enumerate(zip(state1, state2)):
        if i == 0:
            state1_inp = graph.get_tensor_by_name('tconvnet/conv1/conv:0')
            state2_inp = graph.get_tensor_by_name('tconvnet/conv2/conv:0')
        else:
            state1_inp = graph.get_tensor_by_name('tconvnet/conv1_{}/conv:0'.format(i))
            state2_inp = graph.get_tensor_by_name('tconvnet/conv2_{}/conv:0'.format(i))
        state1_inp, state2_inp = sess.run([state1_inp, state2_inp])

        conv1_state = conv1_state * MEM + state1_inp
        assert np.allclose(s1, conv1_state)
        conv2_state = conv2_state * MEM + state2_inp
        assert np.allclose(s2, conv2_state)

    sess.close()
Example 10
from __future__ import absolute_import, division, print_function

import tensorflow as tf

from tnn import main


def loss(G, input_dict, t):
    output_nodes = [n for n in G if len(list(G.successors(n))) == 0]

    for node in output_nodes:
        attr = G.node[node]
        assert len(list(G.predecessors(node))) == 2
        for pred in sorted(G.predecessors(node)):
            if pred == 'labels':
                labels = input_dict['labels']
            else:  # must be logits then
                logits = G.node[pred]['outputs'][t]

        with tf.variable_scope(attr['name']):
            loss = attr['function'](logits=logits, labels=labels)

        attr['outputs'].append(loss)


G = main.graph_from_json('../json/alexnet.json')
main.init_nodes(G, input_nodes=['conv1'], batch_size=256)
input_images = tf.zeros([256, 224, 224, 3])
main.unroll(G, input_seq={'conv1': input_images}, ntimes=9)
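
The loss() helper above expects the caller to supply input_dict and the timestep t. A minimal per-timestep driver might look like the sketch below, assuming the loaded JSON defines a terminal loss node whose two predecessors are a 'labels' input and a logits layer (the zero labels are illustrative only):

labels = tf.zeros([256], dtype=tf.int64)  # illustrative labels; real ones come from a data pipeline
for t in range(9):
    loss(G, input_dict={'labels': labels}, t=t)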
Example 11
def catenet_tnn(inputs,
                cfg_path,
                train=True,
                tnndecay=0.1,
                decaytrain=0,
                cfg_initial=None,
                cmu=0,
                fixweights=False,
                seed=0,
                **kwargs):
    m = model.ConvNet(fixweights=fixweights, seed=seed, **kwargs)

    params = {'input': inputs.name, 'type': 'fc'}

    dropout = 0.5 if train else None

    # Get inputs
    shape_list = inputs.get_shape().as_list()
    assert shape_list[2] == 35, 'Must set expand==1'
    sep_num = shape_list[1]
    if cfg_initial is not None and 'sep_num' in cfg_initial:
        sep_num = cfg_initial['sep_num']
    small_inputs = tf.split(inputs, sep_num, 1)
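    # per chunk: move the 35-dim axis forward and reshape it into a 5x7 spatial map (35 = 5 * 7)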
    for indx_time in range(len(small_inputs)):
        small_inputs[indx_time] = tf.transpose(small_inputs[indx_time],
                                               [0, 2, 1, 3])
        small_inputs[indx_time] = tf.reshape(small_inputs[indx_time],
                                             [shape_list[0], 5, 7, -1])

    G = main.graph_from_json(cfg_path)

    if cfg_initial is not None and 'all_conn' in cfg_initial:
        node_list = ['conv1', 'conv2', 'conv3', 'conv4', 'conv5', 'conv6']

        MASTER_EDGES = []
        for i in range(len(node_list)):
            for j in range(len(node_list)):
                # skip existing feedforward edges (j == i + 1), self-loops,
                # and any edge into the input layer (j == 0)
                if (j > i + 1 or i > j) and j > 0:
                    MASTER_EDGES.append((node_list[i], node_list[j]))

        print(MASTER_EDGES)
        G.add_edges_from(MASTER_EDGES)

    for node, attr in G.nodes(data=True):

        memory_func, memory_param = attr['kwargs']['memory']
        if 'nunits' in memory_param:
            attr['cell'] = tnn_LSTMCell
        else:
            memory_param['memory_decay'] = tnndecay
            memory_param['trainable'] = decaytrain == 1
            attr['kwargs']['memory'] = (memory_func, memory_param)

        if fixweights:
            if node.startswith('conv'):
                _, prememory_param = attr['kwargs']['pre_memory'][0]
                attr['kwargs']['pre_memory'][0] = (model.conv_fix,
                                                   prememory_param)

            if node.startswith('fc'):
                _, prememory_param = attr['kwargs']['pre_memory'][0]
                attr['kwargs']['pre_memory'][0] = (model.fc_fix,
                                                   prememory_param)

        if seed != 0:
            for sub_prememory in attr['kwargs']['pre_memory']:
                prememory_func, prememory_param = sub_prememory
                if 'kernel_init_kwargs' in prememory_param:
                    prememory_param['kernel_init_kwargs']['seed'] = seed

        if node in ['fc7', 'fc8']:
            attr['kwargs']['pre_memory'][0][1]['dropout'] = dropout

    main.init_nodes(G, batch_size=shape_list[0])
    main.unroll(G, input_seq={'conv1': small_inputs}, ntimes=len(small_inputs))

    if cfg_initial is None or 'retres' not in cfg_initial:
        if cmu == 0:
            m.output = G.node['fc8']['outputs'][-1]
        else:
            m.output = tf.transpose(tf.stack(G.node['fc8']['outputs']),
                                    [1, 2, 0])
    else:
        m.output = tf.concat(
            [G.node['fc8']['outputs'][x] for x in cfg_initial['retres']], 1)

    print(len(G.node['fc8']['outputs']))
    m.params = params

    return m
Example 12
def rnn_fc(inputs,
           train=True,
           prefix=MODEL_PREFIX,
           devices=DEVICES,
           num_gpus=NUM_GPUS,
           ntimes=TOTAL_TIMESTEPS,
           edges_arr=[],
           base_name='retina_tnn_fc_lstm',
           tau=0.0,
           trainable_flag=False,
           channel_op='concat',
           seed=0,
           cfg_final=None):

    params = OrderedDict()
    batch_size = inputs['images'].get_shape().as_list()[0]

    params['stim_type'] = stim_type
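    # NOTE: stim_type is assumed to be defined at module scope; it is not a parameter here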
    params['train'] = train
    params['batch_size'] = batch_size

    input_images = inputs['images']
    input_images = tf.reshape(input_images, [batch_size, 40, 50, 50, 1])

    # slice the 40-frame movie into a list of [batch_size, 50, 50, 1] tensors
    input_list = []
    for i in range(40):
        slice_val = tf.squeeze(tf.slice(input_images, [0, i, 0, 0, 0],
                                        [-1, 1, -1, -1, -1]),
                               axis=1)
        input_list.append(slice_val)
    # repeat the final frame to pad the sequence past the movie's 40 steps
    input_list.append(slice_val)
    input_list.append(slice_val)
    input_list.append(slice_val)

    with tf.variable_scope("retina_tnn_fc_lstm"):
        base_name += '.json'
        print('Using base: ', base_name)
        # creates the feedforward network graph from json
        G = main.graph_from_json(base_name)

        for node, attr in G.nodes(data=True):
            if node in ['conv1', 'conv2']:
                if train:  # we add dropout to conv1 and conv2 during training
                    print('Using dropout for ' + node)
                    attr['kwargs']['post_memory'][1][1]['keep_prob'] = 0.75
                else:  # turn off dropout during evaluation
                    print('Not using dropout')
                    attr['kwargs']['post_memory'][1][1]['keep_prob'] = 1.0


#            if node in ['fc3']:
#                if train:  # we add dropout to fc3 during training
#                    print('Using dropout for ' + node)
#                    attr['kwargs']['post_memory'][1][1]['keep_prob'] = 0.5
#                else:  # turn off dropout during evaluation
#                    print('Not using dropout')
#                    attr['kwargs']['post_memory'][1][1]['keep_prob'] = 1.0

            memory_func, memory_param = attr['kwargs']['memory']
            if 'filter_size' in memory_param:
                assert False, 'should not reach here in this specific architecture'
                attr['cell'] = tnn_ConvLSTMCell
            elif 'nunits' in memory_param:
                attr['cell'] = tnn_DenseRNNCell
            else:
                attr['kwargs']['memory'][1]['memory_decay'] = tau
                attr['kwargs']['memory'][1]['trainable'] = trainable_flag

        # add any non feedforward connections here: [('L2', 'L1')]
        G.add_edges_from(edges_arr)

        # initialize network to infer the shapes of all the parameters
        main.init_nodes(G,
                        input_nodes=['conv1'],
                        batch_size=batch_size,
                        channel_op=channel_op)
        # unroll the network through time
        main.unroll(G, input_seq={'conv1': input_list}, ntimes=ntimes)

        outputs = {}
        # option: collect outputs from the last NUM_TIMESTEPS timesteps
        # for t in range(ntimes - NUM_TIMESTEPS, ntimes):
        #     idx = t - (ntimes - NUM_TIMESTEPS)  # keys start at timepoint 0
        #     outputs[idx] = G.node['fc3']['outputs'][t]

        # alternatively, we return the final output of the model at the last timestep
        outputs['pred'] = G.node['fc3']['outputs'][-1]
        return outputs, params
Example 13
def tnn_base_edges(inputs,
                   train=True,
                   basenet_layers=['conv' + str(l) for l in range(1, 11)],
                   alter_layers=None,
                   unroll_tf=False,
                   const_pres=False,
                   out_layers='imnetds',
                   base_name='model_jsons/10Lv9_imnet128_res23_rrgctx',
                   times=range(18),
                   image_on=0,
                   image_off=11,
                   delay=10,
                   random_off=None,
                   dropout=dropout10L,
                   edges_arr=[],
                   convrnn_type='recipcell',
                   mem_val=0.0,
                   train_tau_fg=False,
                   apply_bn=False,
                   channel_op='concat',
                   seed=0,
                   min_duration=11,
                   use_legacy_cell=False,
                   layer_params={},
                   p_edge=1.0,
                   decoder_start=18,
                   decoder_end=26,
                   decoder_type='last',
                   ff_weight_decay=0.0,
                   ff_kernel_initializer_kwargs={},
                   final_max_pool=True,
                   tpu_name=None,
                   gcp_project=None,
                   tpu_zone=None,
                   num_shards=None,
                   iterations_per_loop=None,
                   **kwargs):

    mo_params = {}
    print("using multicell model!")
    # set ds dropout
    # dropout[out_layers] = ds_dropout

    # times may be an int or a list/array; timestep t spans 10t to 10(t+1) ms.
    # if times is a list, it must be a subset of range(26).
    # input reaches convT at time t (no activations at t=0)
    image_off = int(image_off)
    decoder_start = int(decoder_start)
    decoder_end = int(decoder_end)
    if isinstance(times, (int, float)):
        ntimes = times
        times = range(ntimes)
    else:
        ntimes = times[-1] + 1

    if random_off is not None and train:
        print("max duration", random_off - image_on)
        print("min duration", min_duration)
        image_times = np.random.choice(
            range(min_duration, random_off - image_on + 1))
        image_off = image_on + image_times
        print("image times", image_times)
        times = range(image_on + delay, image_off + delay)
        readout_time = times[-1]
        print("model times", times)
        print("readout_time", readout_time)
    else:
        image_times = image_off - image_on

    # set up image presentation, note that inputs is a tensor now, not a dictionary
    ims = tf.identity(inputs, name='split')
    batch_size = ims.get_shape().as_list()[0]
    print('IM SHAPE', ims.shape)

    if const_pres:
        print('Using constant image presentation')
        pres = ims
    else:
        print('Making movie')
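        # movie: blanks until image_on, the image for image_times steps, then blanks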
        blank = tf.constant(value=0.5,
                            shape=ims.get_shape().as_list(),
                            name='split')
        pres = ([blank] * image_on) + ([ims] * image_times) + \
               ([blank] * (ntimes - image_off))

    # graph building stage
    with tf.compat.v1.variable_scope('tnn_model'):
        if '.json' not in base_name:
            base_name += '.json'
        print('Using base: ', base_name)
        G = tnn_main.graph_from_json(base_name)
        print("graph build from JSON")

        # memory_cell_params = cell_params.copy()
        # print("CELL PARAMS:", cell_params)

        # dealing with dropout between training and validation
        for node, attr in G.nodes(data=True):
            if apply_bn:
                if 'conv' in node:
                    print('Applying batch norm to ', node)
                    # set train flag of batch norm for conv layers
                    attr['kwargs']['pre_memory'][0][1]['batch_norm'] = True
                    attr['kwargs']['pre_memory'][0][1]['is_training'] = train

            this_layer_params = layer_params.get(node, {})
            # set ksize, out depth, and training flag for batch_norm
            for func, kwargs in (attr['kwargs']['pre_memory'] +
                                 attr['kwargs']['post_memory']):

                if func.__name__ in ['component_conv', 'conv']:
                    ksize_val = this_layer_params.get('ksize')
                    if ksize_val is not None:
                        kwargs['ksize'] = ksize_val
                    print("using ksize {} for {}".format(
                        kwargs['ksize'], node))
                    out_depth_val = this_layer_params.get('out_depth')
                    if out_depth_val is not None:
                        kwargs['out_depth'] = out_depth_val
                    print("using out depth {} for {}".format(
                        kwargs['out_depth'], node))
                    if ff_weight_decay is not None:  # otherwise uses json
                        kwargs['weight_decay'] = ff_weight_decay
                    if kwargs.get('kernel_init') == "variance_scaling":
                        if ff_kernel_initializer_kwargs is not None:  # otherwise uses json
                            kwargs['kernel_init_kwargs'] = ff_kernel_initializer_kwargs

            # optional max pooling at the end of conv10
            if node == 'conv10':
                if final_max_pool:
                    attr['kwargs']['post_memory'][-1] = (tf.nn.max_pool2d, {
                        'ksize': [1, 2, 2, 1],
                        'strides': [1, 2, 2, 1],
                        'padding': 'SAME'
                    })
                    print("using a final max pool")
                else:
                    attr['kwargs']['post_memory'][-1] = (tf.identity, {})
                    print("not using a final max pool")

            # set memory params, including cell config
            memory_func, memory_param = attr['kwargs']['memory']

            if any(s in memory_param
                   for s in ('gate_filter_size', 'tau_filter_size')):
                if convrnn_type == 'recipcell':
                    print('using reciprocal gated cell for ', node)
                    if use_legacy_cell:
                        print('Using legacy cell to preserve scoping to load checkpoint')
                        attr['cell'] = legacy_tnn_ReciprocalGateCell
                    else:
                        attr['cell'] = tnn_ReciprocalGateCell

                    recip_cell_params = this_layer_params.get('cell_params')
                    assert recip_cell_params is not None
                    for k, v in recip_cell_params.items():
                        attr['kwargs']['memory'][1][k] = v

            else:
                if alter_layers is None:
                    alter_layers = basenet_layers
                if node in alter_layers:
                    attr['kwargs']['memory'][1]['memory_decay'] = mem_val
                    attr['kwargs']['memory'][1]['trainable'] = train_tau_fg
                if node in basenet_layers:
                    print(node, attr['kwargs']['memory'][1])

        # add non feedforward edges
        if len(edges_arr) > 0:
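            # edges_arr holds (edge, p) pairs; keep an edge only when p <= p_edge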
            edges = []
            for edge, p in edges_arr:
                if p <= p_edge:
                    edges.append(edge)
            print("applying edges,", edges)
            G.add_edges_from(edges)

        # initialize graph structure
        tnn_main.init_nodes(G,
                            input_nodes=['conv1'],
                            batch_size=batch_size,
                            channel_op=channel_op)

        # unroll graph
        if unroll_tf:
            print('Unroll tf way')
            tnn_main.unroll_tf(G, input_seq={'conv1': pres}, ntimes=ntimes)
        else:
            print('Unrolling tnn way')
            tnn_main.unroll(G, input_seq={'conv1': pres}, ntimes=ntimes)

        # collect outputs over the decoder window
        logits_list = [
            G.node['imnetds']['outputs'][t]
            for t in range(decoder_start, decoder_end)
        ]

        print("decoder_type", decoder_type, "from", decoder_start, "to",
              decoder_end)
        if decoder_type == 'last':
            logits = logits_list[-1]
        elif decoder_type == 'sum':
            logits = tf.add_n(logits_list)
        elif decoder_type == 'avg':
            logits = tf.add_n(logits_list) / len(logits_list)
        elif decoder_type == 'random':
            if train:
                # sample one timestep's logits at random during training
                logits = logits_list[np.random.randint(len(logits_list))]
            else:  # eval: use the last timepoint with the image on
                t_eval = (image_off + delay - 1) - decoder_start
                logits = logits_list[t_eval]
        else:
            raise ValueError('unknown decoder_type: {}'.format(decoder_type))

        logits = tf.squeeze(logits)
        print("logits shape", logits.shape)

    outputs = {}
    outputs['logits'] = logits
    outputs['times'] = {}
    for t in times:
        outputs['times'][t] = tf.squeeze(G.node[out_layers]['outputs'][t])
    return outputs, mo_params