def train_tnn_alexnet():
    """Train the TNN AlexNet on ImageNet batches and check that the loss drops.

    Builds the unrolled AlexNet graph from ``json/alexnet.json`` under the
    ``tconvnet`` variable scope, minimizes the mean sparse softmax
    cross-entropy of the last ``fc8`` output with momentum SGD
    (lr=0.01, momentum=0.9) for 1000 steps, prints per-step loss and
    wall-clock time, and asserts that the mean loss over the final 20 steps
    is below 6.8.
    """
    imagenet = setup.get_imagenet()
    images_plc = tf.placeholder(tf.float32, shape=[BATCH_SIZE, 224, 224, 3])
    labels_plc = tf.placeholder(tf.int64, shape=[BATCH_SIZE])

    with tf.variable_scope('tconvnet'):
        G = main.graph_from_json('json/alexnet.json')
        main.init_nodes(G, batch_size=BATCH_SIZE)
        main.unroll(G, input_seq=images_plc)

    loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=G.node['fc8']['outputs'][-1], labels=labels_plc)
    loss = tf.reduce_mean(loss)
    # minimize() returns the training op, not the optimizer object.
    train_op = tf.train.MomentumOptimizer(learning_rate=.01,
                                          momentum=.9).minimize(loss)
    # Created after minimize() so the optimizer's slot variables are included.
    init = tf.global_variables_initializer()

    losses = []
    # The context manager closes the session (and frees its device memory)
    # even if a step raises; the previous version never closed it.
    with tf.Session() as sess:
        sess.run(init)
        for step in range(1000):
            start = time.time()
            images_batch, labels_batch = imagenet.next()
            lo, _ = sess.run([loss, train_op], feed_dict={
                images_plc: images_batch,
                labels_plc: labels_batch
            })
            end = time.time()
            losses.append(lo)
            print(step, '{:.4f}'.format(lo), '{:.3f}'.format(end - start))
    assert np.mean(losses[-20:]) < 6.8
def test_alexnet():
    """Compare the benchmark AlexNet loss against the TNN AlexNet loss.

    Both models see the same fixed random batch, held in constant tensors.
    The shared ``run`` helper drives the comparison with 10 initial steps
    followed by 10 compared steps.
    """
    image_array = np.random.standard_normal([BATCH_SIZE, 224, 224, 3])
    label_array = np.random.randint(1000, size=[BATCH_SIZE])
    images = tf.constant(image_array.astype(np.float32))
    labels = tf.constant(label_array.astype(np.int32))

    # Reference implementation of the model.
    with tf.variable_scope('benchmark'):
        benchmark = setup.alexnet(images, labels, 'benchmark', train=False)
    bench_targets = {'loss': benchmark['loss']}

    # Same architecture expressed as an unrolled TNN graph.
    with tf.variable_scope('tconvnet'):
        G = main.graph_from_json('json/alexnet.json')
        main.init_nodes(G, batch_size=BATCH_SIZE)
        main.unroll(G, input_seq=images)
    final_logits = G.node['fc8']['outputs'][-1]
    per_example_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=final_logits, labels=labels)
    tnn_targets = {'loss': tf.reduce_mean(per_example_loss)}

    run(bench_targets, tnn_targets, nsteps=10, n_initial=10)
def memory_usage():
    """Run one forward pass of the TNN AlexNet, then break into pdb.

    Builds the TNN AlexNet on a fixed random batch and evaluates its
    per-example loss once in a session with GPU ``allow_growth`` enabled.
    It then stops in the debugger while the session is still open, so that
    memory usage can be inspected interactively.

    To profile the reference model instead, build
    ``setup.alexnet(data['images'], data['labels'], 'benchmark', train=False)``
    under a ``benchmark`` variable scope and evaluate its ``'loss'``.
    """
    ims = np.random.standard_normal([BATCH_SIZE, 224, 224, 3])
    labels = np.random.randint(1000, size=[BATCH_SIZE])
    data = {
        'images': tf.constant(ims.astype(np.float32)),
        'labels': tf.constant(labels.astype(np.int32))
    }

    with tf.variable_scope('tconvnet'):
        G = main.graph_from_json('json/alexnet.json')
        main.init_nodes(G, batch_size=BATCH_SIZE)
        main.unroll(G, input_seq=data['images'])
    loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=G.node['fc8']['outputs'][-1], labels=data['labels'])

    init = tf.global_variables_initializer()
    config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))
    # The session stays open while in pdb and is closed once the debugger
    # continues; the previous version never closed it.
    with tf.Session(config=config) as sess:
        sess.run(init)
        sess.run(loss)
        import pdb
        pdb.set_trace()
def test_feedback():
    """Check harbor shapes and contents for AlexNet with feedback edges.

    Adds the feedback edges conv5->conv3, conv5->conv4 and conv4->conv3, then
    checks two things:

    1. The static shape of every node's harbor tensor.
    2. A harbor at step t equals the channel-wise concatenation of its
       incoming nodes' outputs at step t-1. Feedback inputs are resized to
       the harbor's spatial size before concatenation.
    """
    images = tf.constant(
        np.random.standard_normal([BATCH_SIZE, 224, 224, 3]).astype(np.float32))

    with tf.variable_scope('tconvnet'):
        json_path = os.path.join(json_dir, 'alexnet.json')
        G = main.graph_from_json(json_path)
        G.add_edges_from([('conv5', 'conv3'), ('conv5', 'conv4'),
                          ('conv4', 'conv3')])
        main.init_nodes(G, input_nodes=['conv1'], batch_size=BATCH_SIZE)
        main.unroll(G, input_seq={'conv1': images})

    test_state_and_output_sizes(G)
    graph = tf.get_default_graph()

    # Harbor output sizes; channel counts add up all incoming nodes.
    expected_harbor_shapes = [
        ('conv1', [BATCH_SIZE, 224, 224, 3]),
        ('conv2', [BATCH_SIZE, 27, 27, 96]),
        ('conv3', [BATCH_SIZE, 14, 14, 256 + 384 + 256]),
        ('conv4', [BATCH_SIZE, 14, 14, 384 + 256]),
        ('conv5', [BATCH_SIZE, 14, 14, 384]),
        ('fc6', [BATCH_SIZE, 7, 7, 256]),
        ('fc7', [BATCH_SIZE, 4096]),
        ('fc8', [BATCH_SIZE, 4096]),
    ]
    for node_name, expected in expected_harbor_shapes:
        tensor = graph.get_tensor_by_name(
            'tconvnet/{}/harbor:0'.format(node_name))
        assert tensor.shape.as_list() == expected

    def check_harbor(harbor_name, prev_t, direct, resized):
        """Assert the harbor equals concat(direct + resized outputs at prev_t)."""
        harbor = graph.get_tensor_by_name(harbor_name)
        spatial = harbor.shape.as_list()[1:3]
        parts = [G.node[n]['outputs'][prev_t] for n in direct]
        parts += [tf.image.resize_images(G.node[n]['outputs'][prev_t], spatial)
                  for n in resized]
        expected = tf.concat(parts, axis=3)
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            harbor_val, expected_val = sess.run([harbor, expected])
        assert np.array_equal(harbor_val, expected_val)

    # layer 4 gets inputs from 3 (feedforward) and 5 (feedback)
    check_harbor('tconvnet/conv4_5/harbor:0', 4, ['conv3'], ['conv5'])
    # layer 3 gets inputs from 2 (feedforward), 4 and 5 (feedback)
    check_harbor('tconvnet/conv3_7/harbor:0', 6, ['conv2'], ['conv4', 'conv5'])
def test_feedback2():
    """Check harbor shapes for AlexNet with an fc7 -> conv5 feedback edge.

    The conv5 harbor's expected channel count adds
    ``ceil(4096 / (14 * 14))`` channels for the flat fc7 input.
    """
    images = tf.constant(
        np.random.standard_normal([BATCH_SIZE, 224, 224, 3]).astype(np.float32))

    with tf.variable_scope('tconvnet'):
        json_path = os.path.join(json_dir, 'alexnet.json')
        G = main.graph_from_json(json_path)
        G.add_edges_from([('fc7', 'conv5')])
        main.init_nodes(G, input_nodes=['conv1'], batch_size=BATCH_SIZE)
        main.unroll(G, input_seq={'conv1': images})

    test_state_and_output_sizes(G)
    graph = tf.get_default_graph()

    expected_harbor_shapes = [
        ('tconvnet/conv1/harbor:0', [BATCH_SIZE, 224, 224, 3]),
        ('tconvnet/conv2/harbor:0', [BATCH_SIZE, 27, 27, 96]),
        ('tconvnet/conv3/harbor:0', [BATCH_SIZE, 14, 14, 256]),
        ('tconvnet/conv4/harbor:0', [BATCH_SIZE, 14, 14, 384]),
        ('tconvnet/conv5_1/harbor:0',
         [BATCH_SIZE, 14, 14, 384 + int(math.ceil(4096 / (14 * 14)))]),
        ('tconvnet/fc6/harbor:0', [BATCH_SIZE, 7, 7, 256]),
        ('tconvnet/fc7/harbor:0', [BATCH_SIZE, 4096]),
        ('tconvnet/fc8/harbor:0', [BATCH_SIZE, 4096]),
    ]
    for tensor_name, expected in expected_harbor_shapes:
        assert graph.get_tensor_by_name(tensor_name).shape.as_list() == expected
def model_func(input_images, train=True, ntimes=TOTAL_TIMESTEPS,
               batch_size=batch_size, edges_arr=None, base_name='alexnet_tnn',
               tau=0.0, trainable_flag=False, channel_op='concat'):
    """Build and unroll the AlexNet TNN and return its final fc8 output.

    Args:
        input_images: image tensor fed to ``conv1`` during unrolling.
        train: if True, fc6/fc7 dropout uses keep_prob 0.5; otherwise 1.0
            (dropout disabled for evaluation).
        ntimes: number of unroll steps.
        batch_size: batch size used to initialize node shapes.
        edges_arr: extra ``(source, target)`` edges to add, e.g. feedback
            ``[('L2', 'L1')]``. ``None`` (default) adds no edges.
        base_name: path of the graph JSON without the ``.json`` suffix.
        tau: ``memory_decay`` for nodes whose memory has no ``filter_size``.
        trainable_flag: ``trainable`` flag for that memory decay.
        channel_op: forwarded to ``main.init_nodes``.

    Returns:
        dict ``{0: fc8 output at the last unroll step}``.
    """
    # None instead of a shared mutable [] default.
    if edges_arr is None:
        edges_arr = []

    with tf.variable_scope("alexnet_tnn"):
        base_name += '.json'
        print('Using base: ', base_name)
        # creates the feedforward network graph from json
        G = main.graph_from_json(base_name)

        for node, attr in G.nodes(data=True):
            if node in ['fc6', 'fc7']:
                if train:  # dropout on fc6 and fc7 only while training
                    print('Using dropout for ' + node)
                    attr['kwargs']['post_memory'][1][1]['keep_prob'] = 0.5
                else:  # disable dropout at evaluation time
                    print('Not using dropout')
                    attr['kwargs']['post_memory'][1][1]['keep_prob'] = 1.0

            # memory_param is the same dict stored in attr['kwargs']['memory'].
            _, memory_param = attr['kwargs']['memory']
            if 'filter_size' in memory_param:
                attr['cell'] = tnn_ConvLSTMCell
            else:
                memory_param['memory_decay'] = tau
                memory_param['trainable'] = trainable_flag

        # add any non feedforward connections here: [('L2', 'L1')]
        G.add_edges_from(edges_arr)

        # initialize network to infer the shapes of all the parameters
        main.init_nodes(G, input_nodes=['conv1'], batch_size=batch_size,
                        channel_op=channel_op)
        # unroll the network through time
        main.unroll(G, input_seq={'conv1': input_images}, ntimes=ntimes)

        # Only the final output at the last timestep is returned; it is keyed
        # by 0, the index of the first readout step.
        outputs = {0: G.node['fc8']['outputs'][-1]}
    return outputs
def test_mnist(kind='conv'):
    """Compare a benchmark MNIST model with its TNN counterpart.

    ``kind`` selects the architecture: ``'conv'`` (images reshaped to
    28x28x1) or ``'fc'`` (flat 784-d images). Any other value raises
    ``ValueError``. For each model, the loss, its variables and the loss
    gradient with respect to each variable are collected. ``run`` then
    compares the two models over 100 steps. TNN ``memory_decay`` variables
    are excluded from the comparison.
    """
    images = np.random.standard_normal([BATCH_SIZE, 28 * 28]).astype(np.float32)
    labels = np.random.randint(10, size=BATCH_SIZE).astype(np.int32)
    if kind == 'conv':
        images = np.reshape(images, [-1, 28, 28, 1])
    data = {'images': images, 'labels': labels}

    def scoped_variables(scope, exclude=None):
        """Map name-without-scope-prefix -> variable for globals under scope."""
        return {
            var.name[len(scope) + 1:]: var
            for var in tf.global_variables()
            if var.name.startswith(scope)
            and (exclude is None or exclude not in var.name)
        }

    def add_variables_and_gradients(targets, variables):
        """Add variables and d(loss)/d(variable) entries to targets."""
        targets.update(variables)
        for name, var in variables.items():
            targets['grad_' + name] = tf.gradients(targets['loss'], var)

    # initialize the benchmark model
    with tf.variable_scope('benchmark'):
        if kind == 'conv':
            bench_targets = setup.mnist_conv(**data)
        elif kind == 'fc':
            bench_targets = setup.mnist_fc(**data)
        else:
            raise ValueError
    add_variables_and_gradients(bench_targets, scoped_variables('benchmark'))

    # initialize the tconvnet model
    with tf.variable_scope('tconvnet'):
        G = main.graph_from_json('json/mnist_{}.json'.format(kind))
        main.init_nodes(G, batch_size=BATCH_SIZE)
        main.unroll(G, input_seq=tf.constant(data['images']))
    tnn_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=G.node['fc2']['outputs'][-1],
        labels=tf.constant(data['labels']))
    tnn_targets = {n: G.node[n]['outputs'][-1] for n in G}
    tnn_targets['loss'] = tf.reduce_mean(tnn_loss)
    add_variables_and_gradients(
        tnn_targets, scoped_variables('tconvnet', exclude='memory_decay'))

    run(bench_targets, tnn_targets, nsteps=100)
def model_func(input_images, ntimes=TOTAL_TIMESTEPS, batch_size=batch_size,
               edges_arr=[], base_name='../json/VanillaRNN', tau=0.0,
               trainable_flag=False):
    """Build and unroll an MNIST TNN and return its last readout steps.

    Args:
        input_images: flat MNIST batch, reshaped to ``[-1, 28, 28, 1]``.
        ntimes: number of unroll steps.
        batch_size: batch size used to initialize node shapes.
        edges_arr: extra ``(source, target)`` edges, e.g. ``[('L2', 'L1')]``.
        base_name: path of the graph JSON without the ``.json`` suffix.
        tau: ``memory_decay`` for nodes without a custom memory cell.
        trainable_flag: ``trainable`` flag for that memory decay.

    Returns:
        dict mapping ``0 .. NUM_TIMESTEPS - 1`` to the ``READOUT_LAYER``
        output at the last ``NUM_TIMESTEPS`` unroll steps, in order.
    """
    with tf.variable_scope("my_model"):
        # reshape the 784 dimension MNIST digits to be 28x28 images
        input_images = tf.reshape(input_images, [-1, 28, 28, 1])
        json_path = base_name + '.json'
        print('Using base: ', json_path)
        G = main.graph_from_json(json_path)

        cell_keys = ['filter_size', 'gate_filter_size', 'tau_filter_size']
        for _, attr in G.nodes(data=True):
            _, memory_params = attr['kwargs']['memory']
            uses_custom_cell = any(key in memory_params for key in cell_keys)
            if not uses_custom_cell:
                # default: plain memory with decay tau (trainable or not)
                memory_params['memory_decay'] = tau
                memory_params['trainable'] = trainable_flag
            else:
                attr['cell'] = CUSTOM_CELL

        # non-feedforward connections, e.g. [('L2', 'L1')]
        G.add_edges_from(edges_arr)
        main.init_nodes(G, input_nodes=[INPUT_LAYER], batch_size=batch_size)
        main.unroll(G, input_seq={INPUT_LAYER: input_images}, ntimes=ntimes)

        first_step = ntimes - NUM_TIMESTEPS
        outputs = {
            idx: G.node[READOUT_LAYER]['outputs'][t]
            for idx, t in enumerate(range(first_step, ntimes))
        }
    return outputs
def test_memory():
    """Check the leaky-memory recurrence on conv1 and conv2.

    With ``memory_decay = MEM`` on conv1/conv2, each state must satisfy
    ``state_t = MEM * state_{t-1} + conv_t``, starting from a zero state.
    Here ``conv_t`` is the node's ``conv`` tensor at unroll step ``t``.
    """
    images = tf.constant(
        np.random.standard_normal([BATCH_SIZE, 28, 28, 1]).astype(np.float32))

    with tf.variable_scope('tconvnet'):
        json_path = os.path.join(json_dir, 'alexnet.json')
        G = main.graph_from_json(json_path)
        for node, attr in G.nodes(data=True):
            if node in ['conv1', 'conv2']:
                attr['kwargs']['memory'][1]['memory_decay'] = MEM
        main.init_nodes(G, input_nodes=['conv1'], batch_size=BATCH_SIZE)
        main.unroll(G, input_seq={'conv1': images}, ntimes=6)

    init = tf.global_variables_initializer()
    graph = tf.get_default_graph()
    # The context manager closes the session even when an assertion fails;
    # the previous trailing sess.close() was skipped in that case.
    with tf.Session() as sess:
        sess.run(init)
        conv1_state = np.zeros(
            G.node['conv1']['states'][0].get_shape().as_list())
        conv2_state = np.zeros(
            G.node['conv2']['states'][0].get_shape().as_list())
        state1, state2 = sess.run(
            [G.node['conv1']['states'], G.node['conv2']['states']])

        for i, (s1, s2) in enumerate(zip(state1, state2)):
            # The first unroll step has no scope suffix; later ones use _<i>.
            suffix = '' if i == 0 else '_{}'.format(i)
            state1_inp, state2_inp = sess.run([
                graph.get_tensor_by_name(
                    'tconvnet/conv1{}/conv:0'.format(suffix)),
                graph.get_tensor_by_name(
                    'tconvnet/conv2{}/conv:0'.format(suffix)),
            ])
            conv1_state = conv1_state * MEM + state1_inp
            assert np.allclose(s1, conv1_state)
            conv2_state = conv2_state * MEM + state2_inp
            assert np.allclose(s2, conv2_state)
from __future__ import absolute_import, division, print_function

import tensorflow as tf

from tnn import main


def loss(G, input_dict=None, t=-1):
    """Attach a loss tensor to the outputs of every sink node in ``G``.

    Requirements and behavior:

    * Each sink (a node with no successors) must have exactly two
      predecessors: a node named ``'labels'`` and a node that produces
      logits.
    * The sink's ``function`` is called as ``function(logits=..., labels=...)``
      inside a variable scope named ``attr['name']``.
    * The result is appended to the sink's ``outputs`` list.

    Args:
        G: unrolled TNN graph.
        input_dict: mapping that supplies ``input_dict['labels']``.
            Previously this was an undefined free name.
        t: index into the logits node's ``outputs`` (default: last step).
            Previously this was also an undefined free name.

    Raises:
        ValueError: if labels are needed but ``input_dict`` is None, or a sink
            lacks a ``'labels'`` predecessor or a logits predecessor.
    """
    # out_degree/in-list work with both networkx 1.x (lists) and 2.x
    # (iterators), unlike len(G.successors(n)).
    output_nodes = [n for n in G if G.out_degree(n) == 0]
    for node in output_nodes:
        attr = G.node[node]
        preds = sorted(G.predecessors(node))
        assert len(preds) == 2
        labels = None
        logits = None
        for pred in preds:
            if pred == 'labels':
                if input_dict is None:
                    raise ValueError(
                        'loss() needs input_dict to supply labels for node '
                        '{!r}'.format(node))
                labels = input_dict['labels']
            else:  # must be logits then
                logits = G.node[pred]['outputs'][t]
        if labels is None or logits is None:
            raise ValueError(
                'output node {!r} needs one "labels" predecessor and one '
                'logits predecessor, got {!r}'.format(node, preds))
        with tf.variable_scope(attr['name']):
            node_loss = attr['function'](logits=logits, labels=labels)
        attr['outputs'].append(node_loss)


G = main.graph_from_json('../json/alexnet.json')
main.init_nodes(G, input_nodes=['conv1'], batch_size=256)
input_images = tf.zeros([256, 224, 224, 3])
main.unroll(G, input_seq={'conv1': input_images}, ntimes=9)
def catenet_tnn(inputs, cfg_path, train=True, tnndecay=0.1, decaytrain=0,
                cfg_initial=None, cmu=0, fixweights=False, seed=0, **kwargs):
    """Build and unroll a TNN from ``cfg_path`` on time-split inputs.

    Input handling:

    * ``inputs`` is split along axis 1 into ``sep_num`` chunks. ``sep_num``
      defaults to ``inputs.shape[1]`` and can be overridden by
      ``cfg_initial['sep_num']``.
    * Each chunk is transposed (axes 1 and 2 swapped) and reshaped to
      ``[batch, 5, 7, -1]``.
    * The chunks are fed to ``conv1`` over ``len(chunks)`` unroll steps.

    Args:
        inputs: input tensor; ``inputs.shape[2]`` must be 35.
        cfg_path: graph JSON path.
        train: if True, fc7/fc8 pre-memory dropout is 0.5; otherwise None.
        tnndecay: ``memory_decay`` for nodes without an LSTM memory.
        decaytrain: memory decay is trainable iff ``decaytrain == 1``.
        cfg_initial: optional dict. Recognized keys:

            * ``'sep_num'``: number of time chunks.
            * ``'all_conn'``: add all skip and feedback edges among
              conv1-conv6.
            * ``'retres'``: list of unroll steps whose fc8 outputs are
              concatenated as the model output.

            ``None`` is treated as ``{}``.
        cmu: when ``'retres'`` is absent, 0 returns the last fc8 output.
            Otherwise all fc8 outputs are stacked and transposed
            ``[1, 2, 0]``.
        fixweights: replace the first pre-memory op of conv*/fc* nodes with
            ``model.conv_fix`` / ``model.fc_fix``.
        seed: if non-zero, written into every ``kernel_init_kwargs``.
        **kwargs: forwarded to ``model.ConvNet``.

    Returns:
        the ``model.ConvNet`` instance with ``output`` and ``params`` set.
    """
    # Previously 'all_conn' / 'retres' membership tests raised TypeError
    # when cfg_initial was left at its default of None.
    if cfg_initial is None:
        cfg_initial = {}

    m = model.ConvNet(fixweights=fixweights, seed=seed, **kwargs)
    params = {'input': inputs.name, 'type': 'fc'}
    dropout = 0.5 if train else None

    # Get inputs
    shape_list = inputs.get_shape().as_list()
    assert shape_list[2] == 35, 'Must set expand==1'
    sep_num = shape_list[1]
    if 'sep_num' in cfg_initial:
        sep_num = cfg_initial['sep_num']
    # range (not Python-2-only xrange) via a comprehension over the chunks.
    small_inputs = [
        tf.reshape(tf.transpose(chunk, [0, 2, 1, 3]),
                   [shape_list[0], 5, 7, -1])
        for chunk in tf.split(inputs, sep_num, 1)
    ]

    G = main.graph_from_json(cfg_path)
    if 'all_conn' in cfg_initial:
        node_list = ['conv1', 'conv2', 'conv3', 'conv4', 'conv5', 'conv6']
        # (i, i + 1) edges already exist; conv1 (j == 0) receives no edges.
        master_edges = [
            (node_list[i], node_list[j])
            for i in range(len(node_list))
            for j in range(len(node_list))
            if (j > i + 1 or i > j) and j > 0
        ]
        print(master_edges)
        G.add_edges_from(master_edges)

    for node, attr in G.nodes(data=True):
        memory_func, memory_param = attr['kwargs']['memory']
        if 'nunits' in memory_param:
            attr['cell'] = tnn_LSTMCell
        else:
            memory_param['memory_decay'] = tnndecay
            memory_param['trainable'] = decaytrain == 1
        attr['kwargs']['memory'] = (memory_func, memory_param)

        if fixweights:
            if node.startswith('conv'):
                _, prememory_param = attr['kwargs']['pre_memory'][0]
                attr['kwargs']['pre_memory'][0] = (model.conv_fix,
                                                   prememory_param)
            if node.startswith('fc'):
                _, prememory_param = attr['kwargs']['pre_memory'][0]
                attr['kwargs']['pre_memory'][0] = (model.fc_fix,
                                                   prememory_param)

        if seed != 0:
            for _, prememory_param in attr['kwargs']['pre_memory']:
                if 'kernel_init_kwargs' in prememory_param:
                    prememory_param['kernel_init_kwargs']['seed'] = seed

        if node in ['fc7', 'fc8']:
            attr['kwargs']['pre_memory'][0][1]['dropout'] = dropout

    main.init_nodes(G, batch_size=shape_list[0])
    main.unroll(G, input_seq={'conv1': small_inputs}, ntimes=len(small_inputs))

    fc8_outputs = G.node['fc8']['outputs']
    if 'retres' not in cfg_initial:
        if cmu == 0:
            m.output = fc8_outputs[-1]
        else:
            m.output = tf.transpose(tf.stack(fc8_outputs), [1, 2, 0])
    else:
        m.output = tf.concat([fc8_outputs[x] for x in cfg_initial['retres']],
                             1)
    print(len(fc8_outputs))
    m.params = params
    return m
def rnn_fc(inputs, train=True, prefix=MODEL_PREFIX, devices=DEVICES,
           num_gpus=NUM_GPUS, ntimes=TOTAL_TIMESTEPS, edges_arr=None,
           base_name='retina_tnn_fc_lstm', tau=0.0, trainable_flag=False,
           channel_op='concat', seed=0, cfg_final=None):
    """Build the retina TNN and return its fc3 prediction.

    ``inputs['images']`` is reshaped to ``[batch, 40, 50, 50, 1]``. Each of
    the 40 frames is fed to ``conv1`` for 4 consecutive unroll steps, giving
    160 inputs.

    Args:
        inputs: dict with an ``'images'`` tensor of statically known batch size.
        train: if True, conv1/conv2 post-memory dropout uses keep_prob 0.75;
            otherwise 1.0.
        ntimes: number of unroll steps.
        edges_arr: extra ``(source, target)`` edges. ``None`` (default) adds
            none.
        base_name: graph JSON path without the ``.json`` suffix.
        tau: ``memory_decay`` for nodes without an RNN cell.
        trainable_flag: ``trainable`` flag for that memory decay.
        channel_op: forwarded to ``main.init_nodes``.
        prefix, devices, num_gpus, seed, cfg_final: accepted for interface
            compatibility; not used here.

    Returns:
        ``(outputs, params)``:

        * ``outputs['pred']`` is the fc3 output at the last unroll step.
        * ``params`` is an OrderedDict with ``stim_type``, ``train`` and
          ``batch_size``.

    Raises:
        ValueError: if any node's memory config has ``filter_size`` (ConvLSTM
            memory), which this architecture does not support.
    """
    # None instead of a shared mutable [] default.
    if edges_arr is None:
        edges_arr = []

    params = OrderedDict()
    batch_size = inputs['images'].get_shape().as_list()[0]
    params['stim_type'] = stim_type
    params['train'] = train
    params['batch_size'] = batch_size

    input_images = tf.reshape(inputs['images'], [batch_size, 40, 50, 50, 1])
    # Accepts a list of per-step inputs: each [batch, 50, 50, 1] frame is
    # repeated for 4 consecutive steps.
    input_list = []
    for frame in tf.unstack(input_images, num=40, axis=1):
        input_list.extend([frame] * 4)

    with tf.variable_scope("retina_tnn_fc_lstm"):
        base_name += '.json'
        print('Using base: ', base_name)
        # creates the feedforward network graph from json
        G = main.graph_from_json(base_name)
        for node, attr in G.nodes(data=True):
            if node in ['conv1', 'conv2']:
                if train:  # dropout on conv1/conv2 only while training
                    print('Using dropout for ' + node)
                    attr['kwargs']['post_memory'][1][1]['keep_prob'] = 0.75
                else:  # disable dropout at evaluation time
                    print('Not using dropout')
                    attr['kwargs']['post_memory'][1][1]['keep_prob'] = 1.0

            _, memory_param = attr['kwargs']['memory']
            if 'filter_size' in memory_param:
                # An explicit raise replaces `assert (0)`, which python -O
                # stripped, after which a ConvLSTM cell was silently installed.
                raise ValueError(
                    'node {!r} has a ConvLSTM memory (filter_size), which '
                    'rnn_fc does not support'.format(node))
            elif 'nunits' in memory_param:
                attr['cell'] = tnn_DenseRNNCell
            else:
                memory_param['memory_decay'] = tau
                memory_param['trainable'] = trainable_flag

        # add any non feedforward connections here: [('L2', 'L1')]
        G.add_edges_from(edges_arr)
        # initialize network to infer the shapes of all the parameters
        main.init_nodes(G, input_nodes=['conv1'], batch_size=batch_size,
                        channel_op=channel_op)
        # unroll the network through time
        main.unroll(G, input_seq={'conv1': input_list}, ntimes=ntimes)

        # return the final output of the model at the last timestep
        outputs = {'pred': G.node['fc3']['outputs'][-1]}
    return outputs, params
def _select_decoder_logits(logits_list, decoder_type, train, t_eval):
    """Combine per-timestep readout logits according to ``decoder_type``.

    Supported values:

    * ``'last'``: the final entry.
    * ``'sum'``: the sum of all entries.
    * ``'avg'``: the mean of all entries.
    * ``'random'``: a uniformly random entry when training; entry ``t_eval``
      otherwise.

    Raises:
        ValueError: unknown ``decoder_type``, or ``t_eval`` outside
            ``logits_list`` for ``'random'`` evaluation.
    """
    if decoder_type == 'last':
        return logits_list[-1]
    if decoder_type == 'sum':
        return tf.add_n(logits_list)
    if decoder_type == 'avg':
        return tf.add_n(logits_list) / len(logits_list)
    if decoder_type == 'random':
        if train:
            # np.random.choice would try to convert the list of symbolic
            # tensors into a numpy array; draw an index instead.
            return logits_list[np.random.randint(len(logits_list))]
        # eval -- use last timepoint with image on. Reject out-of-range
        # indices rather than letting a negative one wrap around silently.
        if not 0 <= t_eval < len(logits_list):
            raise ValueError(
                'evaluation step {} is outside the decoder window of {} '
                'steps'.format(t_eval, len(logits_list)))
        return logits_list[t_eval]
    raise ValueError('Unknown decoder_type: {!r}'.format(decoder_type))


def tnn_base_edges(inputs, train=True,
                   basenet_layers=['conv' + str(l) for l in range(1, 11)],
                   alter_layers=None, unroll_tf=False, const_pres=False,
                   out_layers='imnetds',
                   base_name='model_jsons/10Lv9_imnet128_res23_rrgctx',
                   times=range(18), image_on=0, image_off=11, delay=10,
                   random_off=None, dropout=dropout10L, edges_arr=[],
                   convrnn_type='recipcell', mem_val=0.0, train_tau_fg=False,
                   apply_bn=False, channel_op='concat', seed=0,
                   min_duration=11, use_legacy_cell=False, layer_params={},
                   p_edge=1.0, decoder_start=18, decoder_end=26,
                   decoder_type='last', ff_weight_decay=0.0,
                   ff_kernel_initializer_kwargs={}, final_max_pool=True,
                   tpu_name=None, gcp_project=None, tpu_zone=None,
                   num_shards=None, iterations_per_loop=None, **kwargs):
    """Build and unroll a TNN from JSON and decode logits over a time window.

    Image presentation:

    * ``const_pres=True``: ``inputs`` itself is the presentation.
    * Otherwise a "movie" is built: blank (0.5) frames for ``image_on``
      steps, the image for ``image_times`` steps, then blanks up to
      ``ntimes``.
    * When training with ``random_off`` set, the image duration is drawn
      uniformly from ``[min_duration, random_off - image_on]``.

    Node configuration applied on top of the JSON graph:

    * optional batch norm on conv nodes;
    * per-node ``ksize``/``out_depth`` from ``layer_params``;
    * feedforward weight decay and kernel-initializer kwargs;
    * an optional final max pool on ``conv10``;
    * reciprocal gated cells, configured from
      ``layer_params[node]['cell_params']``, for nodes whose memory has
      gate/tau filter sizes;
    * ``memory_decay``/``trainable`` for the remaining nodes in
      ``alter_layers`` (default ``basenet_layers``).

    ``edges_arr`` holds ``(edge, p)`` pairs; an edge is added when
    ``p <= p_edge``.

    Readout logits are the ``imnetds`` outputs at steps
    ``[decoder_start, decoder_end)``, combined per ``decoder_type``
    (``'last'``, ``'sum'``, ``'avg'`` or ``'random'``).

    Parameters not used by this body: ``dropout``, ``seed``, the TPU
    arguments and ``**kwargs``.

    Returns:
        ``(outputs, mo_params)``:

        * ``outputs['logits']`` holds the squeezed decoded logits.
        * ``outputs['times'][t]`` holds the squeezed ``out_layers`` output
          for each ``t`` in ``times``.
        * ``mo_params`` is an empty dict.

    Raises:
        ValueError: unknown ``decoder_type``; an out-of-range evaluation step
            for ``'random'``; or a reciprocal-gated-cell node with no
            ``cell_params``.
    """
    mo_params = {}
    print("using multicell model!")

    # times may be a list or array, where t = 10t-10(t+1)ms.
    # if times is a list, it must be a subset of range(26).
    # input reaches convT at time t (no activations at t=0)
    image_off = int(image_off)
    decoder_start = int(decoder_start)
    decoder_end = int(decoder_end)
    if isinstance(times, (int, float)):
        ntimes = int(times)  # range() and list repetition need an int
        times = range(ntimes)
    else:
        ntimes = times[-1] + 1

    if random_off is not None and train == True:
        print("max duration", random_off - image_on)
        print("min duration", min_duration)
        image_times = np.random.choice(
            range(min_duration, random_off - image_on + 1))
        image_off = image_on + image_times
        print("image times", image_times)
        times = range(image_on + delay, image_off + delay)
        readout_time = times[-1]
        print("model times", times)
        print("readout_time", readout_time)
    else:
        image_times = image_off - image_on

    # set up image presentation, note that inputs is a tensor now, not a dictionary
    ims = tf.identity(inputs, name='split')
    batch_size = ims.get_shape().as_list()[0]
    print('IM SHAPE', ims.shape)
    if const_pres:
        print('Using constant image presentation')
        pres = ims
    else:
        print('Making movie')
        blank = tf.constant(value=0.5, shape=ims.get_shape().as_list(),
                            name='split')
        pres = ([blank] * image_on) + ([ims] * image_times) + (
            [blank] * (ntimes - image_off))

    # graph building stage
    with tf.compat.v1.variable_scope('tnn_model'):
        if '.json' not in base_name:
            base_name += '.json'
        print('Using base: ', base_name)
        G = tnn_main.graph_from_json(base_name)
        print("graph build from JSON")

        if alter_layers is None:
            alter_layers = basenet_layers

        for node, attr in G.nodes(data=True):
            if apply_bn and 'conv' in node:
                print('Applying batch norm to ', node)
                # set train flag of batch norm for conv layers
                attr['kwargs']['pre_memory'][0][1]['batch_norm'] = True
                attr['kwargs']['pre_memory'][0][1]['is_training'] = train

            # Nodes missing from layer_params get no overrides; previously
            # layer_params[node] raised KeyError (always, with the {} default).
            this_layer_params = layer_params.get(node, {})

            # set ksize, out depth, weight decay and initializer for conv ops.
            # func_kwargs avoids shadowing this function's **kwargs.
            for func, func_kwargs in (attr['kwargs']['pre_memory'] +
                                      attr['kwargs']['post_memory']):
                if func.__name__ not in ['component_conv', 'conv']:
                    continue
                ksize_val = this_layer_params.get('ksize')
                if ksize_val is not None:
                    func_kwargs['ksize'] = ksize_val
                    print("using ksize {} for {}".format(
                        func_kwargs['ksize'], node))
                out_depth_val = this_layer_params.get('out_depth')
                if out_depth_val is not None:
                    func_kwargs['out_depth'] = out_depth_val
                    print("using out depth {} for {}".format(
                        func_kwargs['out_depth'], node))
                if ff_weight_decay is not None:  # otherwise uses json
                    func_kwargs['weight_decay'] = ff_weight_decay
                if (func_kwargs.get('kernel_init') == "variance_scaling" and
                        ff_kernel_initializer_kwargs is not None):
                    # otherwise uses json; copy so layers (and later calls)
                    # do not share one mutable dict
                    func_kwargs['kernel_init_kwargs'] = dict(
                        ff_kernel_initializer_kwargs)

            # optional max pooling at end of conv10
            if node == 'conv10':
                if final_max_pool:
                    attr['kwargs']['post_memory'][-1] = (tf.nn.max_pool2d, {
                        'ksize': [1, 2, 2, 1],
                        'strides': [1, 2, 2, 1],
                        'padding': 'SAME'
                    })
                    print("using a final max pool")
                else:
                    attr['kwargs']['post_memory'][-1] = (tf.identity, {})
                    print("not using a final max pool")

            # set memory params, including cell config
            _, memory_param = attr['kwargs']['memory']
            if any(s in memory_param
                   for s in ('gate_filter_size', 'tau_filter_size')):
                if convrnn_type == 'recipcell':
                    print('using reciprocal gated cell for ', node)
                    if use_legacy_cell:
                        print('Using legacy cell to preserve scoping to load checkpoint')
                        attr['cell'] = legacy_tnn_ReciprocalGateCell
                    else:
                        attr['cell'] = tnn_ReciprocalGateCell
                    # Validate before copying: the old assert ran after
                    # .copy() and could never fire.
                    cell_params = this_layer_params.get('cell_params')
                    if cell_params is None:
                        raise ValueError(
                            'layer_params[{!r}] must provide cell_params for '
                            'the reciprocal gated cell'.format(node))
                    for k, v in cell_params.copy().items():
                        attr['kwargs']['memory'][1][k] = v
            else:
                if node in alter_layers:
                    attr['kwargs']['memory'][1]['memory_decay'] = mem_val
                    attr['kwargs']['memory'][1]['trainable'] = train_tau_fg
                if node in basenet_layers:
                    print(node, attr['kwargs']['memory'][1])

        # add non feedforward edges
        if len(edges_arr) > 0:
            edges = [edge for edge, p in edges_arr if p <= p_edge]
            print("applying edges,", edges)
            G.add_edges_from(edges)

        # initialize graph structure
        tnn_main.init_nodes(G, input_nodes=['conv1'], batch_size=batch_size,
                            channel_op=channel_op)

        # unroll graph
        if unroll_tf:
            print('Unroll tf way')
            tnn_main.unroll_tf(G, input_seq={'conv1': pres}, ntimes=ntimes)
        else:
            print('Unrolling tnn way')
            tnn_main.unroll(G, input_seq={'conv1': pres}, ntimes=ntimes)

    # collect readout logits over the decoder window
    logits_list = [
        G.node['imnetds']['outputs'][t]
        for t in range(decoder_start, decoder_end)
    ]
    print("decoder_type", decoder_type, "from", decoder_start, "to",
          decoder_end)
    # last timepoint with the image on, relative to the decoder window
    t_eval = image_off + delay - 1 - decoder_start
    logits = _select_decoder_logits(logits_list, decoder_type, train, t_eval)
    logits = tf.squeeze(logits)
    print("logits shape", logits.shape)

    outputs = {'logits': logits, 'times': {}}
    for t in times:
        outputs['times'][t] = tf.squeeze(G.node[out_layers]['outputs'][t])
    return outputs, mo_params