def run_tf_image_resnet(depth, filter_scale=1.0):
    global is_pytest_run
    model_string = '_d{}_fs{}_'.format(depth, filter_scale)
    test_outputs_dir = 'catamount/frameworks/example_graphs/tensorflow/full_models/image_classification'
    graph_meta = None
    for root, dirs, files in os.walk(test_outputs_dir):
        for filename in files:
            if 'graph{}'.format(model_string) in filename and '.meta' in filename:
                # Take the first graph that we find in the directory
                graph_meta = os.path.join(root, filename)
                break
        if graph_meta is not None:
            break
    if graph_meta is None:
        raise FileNotFoundError(
            'Unable to find model string {} in directory {}'.format(
                model_string, test_outputs_dir))

    graph = catamount.frameworks.tensorflow.import_graph(graph_meta)
    assert graph.isValid()

    # ============ TO REMOVE INITIALIZATION OPS! =============
    # NOTE: This code is pretty general and is likely to be migrated into
    # Catamount code for removing TF-specific initialization ops
    from catamount.ops import AssignOp
    from catamount.ops import VariableOp

    assign_ops = set()
    for op in graph.opsByName.values():
        if isinstance(op, AssignOp):
            assign_ops.add(op)
    for assign_op in assign_ops:
        my_ancestors = set()
        my_frontier = set()
        my_frontier.add(assign_op)
        while len(my_frontier) > 0:
            next_op = my_frontier.pop()
            for in_tensor in next_op.inputs:
                if not isinstance(in_tensor.producer, VariableOp):
                    my_frontier.add(in_tensor.producer)
            my_ancestors.add(next_op)
        for next_op in my_ancestors:
            graph.removeOp(next_op)
    assert graph.isValid()

    # Manually remove the inference parts of graph
    graph_ops = list(graph._ops_by_name.values())
    for op in graph_ops:
        # Certain ops are only used for inference
        if 'InferenceTower/' in op.name or \
           'InferenceRunner/' in op.name or \
           op.name == 'MergeAllSummariesRunWithOp/Merge/MergeSummary':
            graph.removeOp(op)
    assert graph.isValid()

    print('Initial graph:\n{}\n'.format(graph))
    init_params = graph.calcModelParameters()
    print('Initial parameters: {}'.format(init_params))
    print('Initial Flops: {}\n'.format(graph.calcAlgFlops()))

    print('Placeholders:')
    for op in graph.getPlaceholders():
        print(op.debugString())
    print('')

    # Set up symbols to name dimensions
    output_classes_symbol = utils.getPositiveIntSymbolFromString('out_classes')
    subbatch_size_symbol = utils.getPositiveIntSymbolFromString('subbatch_size')
    image_height_symbol = utils.getPositiveIntSymbolFromString('image_height')
    image_width_symbol = utils.getPositiveIntSymbolFromString('image_width')
    num_in_channels_symbol = utils.getPositiveIntSymbolFromString('num_in_channels')
    graph_iters_symbol = utils.getIntSymbolFromString('graph::iters')
    feature_channels_symbol = utils.getPositiveIntSymbolFromString('feature_channels')

    # Find and replace convolution/pooling dimensions also:
    # Dimension(64 * 2^k): conv/pool feature channels
    base_output_classes = 1000
    base_num_in_channels = 3
    base_feature_channels = 64

    base_image_height = 224
    base_image_width = 224
    base_half_im_height = 112
    half_im_height_symbol = image_height_symbol // 2
    base_half_im_width = 112
    half_im_width_symbol = image_width_symbol // 2
    base_quart_im_height = 56
    quart_im_height_symbol = (image_height_symbol // 2) // 2
    base_quart_im_width = 56
    quart_im_width_symbol = (image_width_symbol // 2) // 2
    base_eighth_im_height = 28
    eighth_im_height_symbol = ((image_height_symbol // 2) // 2) // 2
    base_eighth_im_width = 28
    eighth_im_width_symbol = ((image_width_symbol // 2) // 2) // 2
    base_sixtnth_im_height = 14
    sixtnth_im_height_symbol = (((image_height_symbol // 2) // 2) // 2) // 2
    base_sixtnth_im_width = 14
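    # Note: each stride-2 stage of the network halves the spatial dims with
    # floor division, so the symbolic sizes here mirror the standard ResNet
    # 224 -> 112 -> 56 -> 28 -> 14 -> 7 downsampling ladder
    # (e.g., 224 // 2 // 2 // 2 // 2 // 2 == 7).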
    sixtnth_im_width_symbol = (((image_width_symbol // 2) // 2) // 2) // 2
    base_small_im_height = 7
    small_im_height_symbol = ((((image_height_symbol // 2) // 2) // 2) // 2) // 2
    base_small_im_width = 7
    small_im_width_symbol = ((((image_width_symbol // 2) // 2) // 2) // 2) // 2

    # Set up a dictionary of placeholders and variables for which we want
    # to make dimensions symbolic. Sift out their dimensions
    bind_dict = {
        # Placeholders
        'label': [subbatch_size_symbol],
        'input': [subbatch_size_symbol, image_height_symbol,
                  image_width_symbol, num_in_channels_symbol],
    }
    # Parameterize all variable tensor dimensions
    for op in graph._ops_by_name.values():
        if isinstance(op, VariableOp):
            op_name_suffix = op.name.split('/')[-1]
            if op_name_suffix == 'W':
                if op._outputs[0].shape.rank == 4:
                    assert 'conv' in op.name
                    new_shape = []
                    for i in range(op._outputs[0].shape.rank):
                        new_shape.append(
                            op._outputs[0].shape.getDimension(i).value)
                    if new_shape[2] % base_feature_channels == 0:
                        in_filters = (new_shape[2] // base_feature_channels) * \
                            feature_channels_symbol
                    elif new_shape[2] == 3:
                        # This is the first convolution on image channels (3)
                        assert op.name == 'conv0/W'
                        in_filters = num_in_channels_symbol
                    else:
                        print('FIX ME: base in filters {}'.format(new_shape[2]))
                        assert 0
                    if new_shape[3] % base_feature_channels == 0:
                        out_filters = (new_shape[3] // base_feature_channels) * \
                            feature_channels_symbol
                    else:
                        print('FIX ME: base out filters {}'.format(new_shape[3]))
                        assert 0
                    new_shape[2] = in_filters
                    new_shape[3] = out_filters
                else:
                    # This is the output layer with output_classes dimension
                    assert op.name == 'linear/W'
                    assert op._outputs[0].shape.rank == 2
                    in_dim = op._outputs[0].shape.getDimension(0).value
                    assert in_dim % base_feature_channels == 0
                    in_dim = (in_dim // base_feature_channels) * \
                        feature_channels_symbol
                    new_shape = [in_dim, output_classes_symbol]
                bind_dict[op.name] = new_shape
                momentum_op_name = '{}/Momentum'.format(op.name)
                momentum_op = graph._ops_by_name[momentum_op_name]
                bind_dict[momentum_op.name] = new_shape
            elif op_name_suffix == 'b':
                # This is the output layer with output_classes dimension
                assert op.name == 'linear/b'
                assert op._outputs[0].shape.rank == 1
                assert op._outputs[0].shape.getDimension(0).value == \
                    base_output_classes
                new_shape = [output_classes_symbol]
                bind_dict[op.name] = new_shape
                momentum_op_name = '{}/Momentum'.format(op.name)
                momentum_op = graph._ops_by_name[momentum_op_name]
                bind_dict[momentum_op.name] = new_shape
            elif op_name_suffix == 'beta' or op_name_suffix == 'gamma' or \
                 op_name_suffix == 'EMA':
                assert op._outputs[0].shape.rank == 1
                in_dim = op._outputs[0].shape.getDimension(0).value
                assert in_dim % base_feature_channels == 0
                in_dim = (in_dim // base_feature_channels) * \
                    feature_channels_symbol
                new_shape = [in_dim]
                bind_dict[op.name] = new_shape
                if op_name_suffix != 'EMA':
                    momentum_op_name = '{}/Momentum'.format(op.name)
                    momentum_op = graph._ops_by_name[momentum_op_name]
                    bind_dict[momentum_op.name] = new_shape
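    # At this point, bind_dict maps op names to symbolic shapes; for example,
    # a conv kernel whose static shape is [3, 3, 64, 128] would be bound to
    # [3, 3, feature_channels, 2 * feature_channels] by the loop above
    # (values illustrative).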
    # Now handle constant values in the graph
    const_dict = {}
    for op in graph._ops_by_name.values():
        if isinstance(op, ConstantOp):
            if op._outputs[0].value is None:
                continue
            if op._outputs[0].shape.rank == 0:
                print('{}'.format(op.debugString()))
                continue
            assert op._outputs[0].shape.rank == 1
            values = op._outputs[0].value.tolist()
            new_values = []
            changed = False
            for value in values:
                if value > 0 and value % base_feature_channels == 0:
                    value = (value // base_feature_channels) * \
                        feature_channels_symbol
                    changed = True
                new_values.append(value)
            # HACKY SPECIAL CASE:
            if op.name == 'tower0/gradients/tower0/conv0/Conv2D_grad/Const':
                assert new_values[2] == base_num_in_channels
                new_values[2] = num_in_channels_symbol
            if changed:
                const_dict[op.name] = new_values
    for const_key, const_val in const_dict.items():
        try:
            if const_key in graph.opsByName.keys():
                const_op = graph.opsByName[const_key]
                assert isinstance(const_op, ConstantOp)
                const_op._outputs[0].setValue(const_val)
            else:
                print('WARN: ConstantOp not found: {}'.format(const_key))
        except Exception as exc:
            print('WARN: ConstantOp unknown problem: {}: {}'.format(
                const_key, exc))

    # TODO (Joel): Add InputQueue ops to avoid manually setting dimensions
    in_deque_op = graph.opsByName['QueueInput/input_deque']
    # print(in_deque_op.debugString())
    out_tensor = in_deque_op._outputs[0]
    for idx, sym in enumerate([subbatch_size_symbol, image_height_symbol,
                               image_width_symbol, num_in_channels_symbol]):
        out_tensor.shape.setDimension(idx, sym)
        out_tensor.shape.dims[idx]._value = None
    out_tensor = in_deque_op._outputs[1]
    out_tensor.shape.setDimension(0, subbatch_size_symbol)

    graph.bindTensorShapeDimensions(bind_dict,
                                    warn_if_ill_defined=(not is_pytest_run),
                                    make_symbolic=True)

    # Nice little hack to actually propagate MaximumOp values to outputs
    for op in graph._ops_by_name.values():
        if isinstance(op, MaximumOp):
            if op._inputs[0].value is not None and \
               op._inputs[1].value is not None:
                vmax = np.vectorize(lambda x, y: sympy.Max(x, y))
                out_val = vmax(op._inputs[0].value, op._inputs[1].value)
                op._outputs[0].setValue(out_val)

    graph.bindTensorShapeDimensions(bind_dict,
                                    warn_if_ill_defined=(not is_pytest_run),
                                    make_symbolic=True)

    print('Bound values')
    print(graph)

    bind_subs = {
        graph_iters_symbol: 1,
        output_classes_symbol: base_output_classes,
        subbatch_size_symbol: 32,
        image_height_symbol: base_image_height,
        image_width_symbol: base_image_width,
        num_in_channels_symbol: base_num_in_channels,
        feature_channels_symbol: base_feature_channels,
    }

    correct_params = -1
    correct_flops = -1
    correct_bytes = -1
    correct_footprint = -1
    if depth == 18:
        correct_params = 11689514
        correct_flops = 349684163360
        correct_bytes = 7186222676
        correct_footprint = 2802084304
    elif depth == 34:
        correct_params = 21797674
        correct_flops = 705506994208
        correct_bytes = 11162578644
        correct_footprint = 4368689744
    elif depth == 50:
        correct_params = 25557034
        correct_flops = 790954958112
        correct_bytes = 32896462028
        correct_footprint = 12909734408
    elif depth == 101:
        correct_params = 44549162
        correct_flops = 1506507229472
        correct_bytes = 50026672916
        correct_footprint = 19690293072
    elif depth == 152:
        correct_params = 60192810
        correct_flops = 2222688328992
        correct_bytes = 70967716188
        correct_footprint = 27971880088
    else:
        print('WARN: Tests not defined for depth {}'.format(depth))

    # Calculate parameters
    # NOTE: Need to remove Momentum optimizer parameters and moving
    # average values
    momentum_params = 0
    parameters = 0
    for op_name in sorted(graph.opsByName.keys()):
        op = graph.opsByName[op_name]
        if isinstance(op, VariableOp):
            if 'Momentum' in op.name or 'EMA' in op.name:
                momentum_params += op.calcModelParameters()
            else:
                parameters += op.calcModelParameters()
    all_weights = graph.calcModelParameters()
    assert (all_weights - momentum_params - parameters) == 0
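    # Momentum slots and EMA copies shadow the trainable variables, so they
    # are excluded from `parameters` above; the assert confirms that
    # optimizer/EMA state plus trainable weights account for every parameter
    # in the graph.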
    # Calculate model parameter count
    resolved_params = parameters.subs(bind_subs)
    try:
        resolved_params = int(resolved_params)
    except Exception:
        print('ERROR: resolved_params should be int, but is {} = {}'.format(
            type(resolved_params), resolved_params))
    assert correct_params < 0 or resolved_params == correct_params, \
        'Incorrect model params: {}'.format(resolved_params)
    print('Parameters: {}\nWith specified dims: {}\n'.format(
        parameters, resolved_params))

    # Calculate algorithmic Flops
    alg_flops = graph.calcAlgFlops()
    resolved_flops = alg_flops.subs(bind_subs)
    try:
        resolved_flops = int(resolved_flops)
    except Exception:
        print('ERROR: resolved_flops should be int, but is {} = {}'.format(
            type(resolved_flops), resolved_flops))
    assert correct_flops < 0 or resolved_flops == correct_flops, \
        'Incorrect algorithmic flops: {}'.format(resolved_flops)
    print('Algorithmic Flops: {}\nWith specified dims: {}\n'.format(
        alg_flops, resolved_flops))

    # Calculate algorithmic Bytes accessed
    alg_bytes = graph.calcAlgBytes()
    resolved_bytes = alg_bytes.subs(bind_subs)
    try:
        resolved_bytes = int(resolved_bytes)
    except Exception:
        print('ERROR: resolved_bytes should be int, but is {} = {}'.format(
            type(resolved_bytes), resolved_bytes))
    assert correct_bytes < 0 or resolved_bytes == correct_bytes, \
        'Incorrect algorithmic bytes: {}'.format(resolved_bytes)
    print('Alg bytes accessed: {}\nWith specified dims: {}\n'.format(
        alg_bytes, resolved_bytes))

    # Calculate total memory footprint
    alg_footprint = graph.calcAlgFootprint()
    resolved_footprint = alg_footprint.subs(bind_subs)
    try:
        resolved_footprint = int(resolved_footprint)
    except Exception:
        print('ERROR: resolved_footprint should be int, but is {} = {}'.format(
            type(resolved_footprint), resolved_footprint))
    assert correct_footprint < 0 or resolved_footprint == correct_footprint, \
        'Incorrect algorithmic footprint: {}'.format(resolved_footprint)
    print('Alg mem footprint: {}\nWith specified dims: {}\n'.format(
        alg_footprint, resolved_footprint))

    # Calculate algorithmic IO per step
    total_io_footprint = 0
    for op in graph.getPlaceholders():
        total_io_footprint += op.calcAlgFootprint()
    resolved_io_footprint = total_io_footprint.subs(bind_subs)
    print('Alg IO footprint: {}\nWith specified dims: {}\n'.format(
        total_io_footprint, resolved_io_footprint))

    try:
        # In case the footprint code is not complete,
        # calculate minimal memory footprint
        print('Alg min mem footprint {}'.format(
            graph.calcMinFootprint(symbol_subs=bind_subs)))
    except Exception:
        pass

    if not is_pytest_run:
        print('VERBOSE ALGORITHMIC FLOPS:')
        graph.calcAlgFlops(verbose=True)
        print('')
        print('VERBOSE ALGORITHMIC BYTES:')
        graph.calcAlgBytes(verbose=True)
        print('')
        print('VERBOSE ALGORITHMIC FOOTPRINT:')
        graph.calcAlgFootprint(verbose=True)
        print('')
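    # The resolve-then-int pattern above is plain sympy substitution; a
    # minimal self-contained sketch of the same idea (the symbol here is
    # illustrative, not one of the graph's):
    #
    #     import sympy
    #     n = sympy.Symbol('n', positive=True)
    #     expr = 2 * n**2 + n                 # stand-in for a derived count
    #     resolved = int(expr.subs({n: 64}))  # -> 8256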
    # HACKY WAY TO SAVE MODELS FOR NOW!
    pickle.dump(graph, open(
        'catamount/frameworks/example_graphs/tensorflow/full_models/image_classification/graph_image_resnet_d{}_fs{}.p'
        .format(depth, filter_scale), 'wb'))

    if is_pytest_run:
        return

    print('\n\n======= Algorithmic graph-level analytics: =======')

    feature_channel_dims = [32, 48, 64, 96, 128]

    bind_subs.pop(feature_channels_symbol)
    resolved_params = parameters.subs(bind_subs)

    print('Symbol associations: {}\n'.format(bind_subs))

    print('Algorithmic Flops by feature channels, params, and per-batch-sample:')
    resolved_flops = alg_flops.subs(bind_subs)
    for features_dim in feature_channel_dims:
        graph_params = resolved_params.subs(
            {feature_channels_symbol: features_dim})
        graph_flops = resolved_flops.subs(
            {feature_channels_symbol: features_dim})
        graph_flops_per_sample = float(graph_flops) / \
            bind_subs[subbatch_size_symbol]
        print('{}\t{}\t{}\t{}'.format(features_dim, graph_params, graph_flops,
                                      int(graph_flops_per_sample)))

    print('\nAlgorithmic bytes accessed by feature channels, params:')
    resolved_bytes = alg_bytes.subs(bind_subs)
    for features_dim in feature_channel_dims:
        graph_params = resolved_params.subs(
            {feature_channels_symbol: features_dim})
        graph_bytes = resolved_bytes.subs(
            {feature_channels_symbol: features_dim})
        print('{}\t{}\t{}'.format(features_dim, graph_params, graph_bytes))

    print('\nAlgorithmic total memory footprint by feature channels, params:')
    resolved_footprint = alg_footprint.subs(bind_subs)
    for features_dim in feature_channel_dims:
        graph_params = resolved_params.subs(
            {feature_channels_symbol: features_dim})
        graph_footprint = resolved_footprint.subs(
            {feature_channels_symbol: features_dim})
        print('{}\t{}\t{}'.format(features_dim, graph_params, graph_footprint))

    print('\nAlgorithmic minimal memory footprint by feature channels, params:')
    full_subs = dict(bind_subs)
    for features_dim in feature_channel_dims:
        graph_params = resolved_params.subs(
            {feature_channels_symbol: features_dim})
        full_subs[feature_channels_symbol] = features_dim
        graph_min_foot = graph.calcMinimalFootprint(symbol_subs=full_subs)
        print('{}\t{}\t{}'.format(features_dim, graph_params, graph_min_foot))
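# Hypothetical usage sketch (the surrounding test harness is assumed to set
# the module-level `is_pytest_run` flag before calling in here):
#
#     run_tf_image_resnet(depth=50, filter_scale=1.0)
#
# Depths 18/34/50/101/152 correspond to the reference values checked above.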
def run_tf_speech_attention():
    global is_pytest_run

    graph_meta = 'catamount/frameworks/example_graphs/tensorflow/full_models/speech_attention/model.ckpt.meta'

    graph = catamount.frameworks.tensorflow.import_graph(graph_meta)
    assert graph.isValid()

    # HAX: NEED TO MANUALLY REMOVE SOME?! WHY?
    remove_ops = ['DevArgmaxWERChecker/Less', 'DevLossChecker/Less',
                  'DevArgmaxWERChecker/best_dev', 'DevLossChecker/best_dev']
    for op_name in remove_ops:
        op = graph.opsByName[op_name]
        graph.removeOp(op)
    assert graph.isValid()

    # Remove ops that are not executed during a standard training step:
    graph_ops = list(graph._ops_by_name.values())
    for op in graph_ops:
        # Ops in attn_model_[1-3] are used for inference
        if 'attn_model_1' in op.name or \
           'attn_model_2' in op.name or \
           'attn_model_3' in op.name:
            graph.removeOp(op)
    assert graph.isValid()

    print('Initial graph:\n{}\n'.format(graph))
    init_params = graph.calcModelParameters()
    print('Initial parameters: {}'.format(init_params))
    print('Initial Flops: {}\n'.format(graph.calcAlgFlops()))

    print('Placeholders:')
    for op in graph.getPlaceholders():
        print(op.debugString())
    print('')

    # Set up symbols to name dimensions
    audio_features_symbol = utils.getPositiveIntSymbolFromString('audio_features')
    encoder_steps_symbol = utils.getPositiveIntSymbolFromString('encoder_steps')
    decoder_steps_symbol = utils.getPositiveIntSymbolFromString('decoder_steps')
    subbatch_size_symbol = utils.getPositiveIntSymbolFromString('subbatch_size')
    attn_dim_symbol = utils.getPositiveIntSymbolFromString('attn_dim')
    attn_hidden_dim_symbol = utils.getPositiveIntSymbolFromString('attn_hidden_dim')
    dec_hidden_dim_symbol = utils.getPositiveIntSymbolFromString('dec_hidden_dim')
    enc_hidden_dim_symbol = utils.getPositiveIntSymbolFromString('enc_hidden_dim')
    graph_iters_symbol = utils.getIntSymbolFromString('graph::iters')
    output_vocab_symbol = utils.getPositiveIntSymbolFromString('output_vocab')
    conv_width_symbol = utils.getPositiveIntSymbolFromString('conv_width')
    num_conv_filters_symbol = utils.getPositiveIntSymbolFromString('num_conv_filters')

    # Convert these constant dimensions to symbols
    base_encoder_steps = 300
    base_decoder_steps = 300
    base_subbatch_size = 32
    base_output_vocab = 31
    base_audio_features = 40
    base_conv_width = 53
    base_attn_dim = 137
    base_attn_hidden_dim = 509
    base_dec_hidden_dim = 571
    base_enc_hidden_dim = 1051
    base_enc_input_dim = 1091  # Input + recurrent state
    enc_input_dim_symbol = audio_features_symbol + enc_hidden_dim_symbol
    base_dec_attn_rec = 2133
    dec_attn_rec_symbol = 2 * enc_hidden_dim_symbol + output_vocab_symbol
    base_attn_cell_inputs = 2611
    attn_cell_inputs_symbol = 2 * enc_hidden_dim_symbol + \
        attn_hidden_dim_symbol
    base_attn_cell_in_dim = 2642
    attn_cell_in_dim_symbol = 2 * enc_hidden_dim_symbol + \
        output_vocab_symbol + attn_hidden_dim_symbol
    base_dec_attn_dim = 3182
    dec_attn_dim_symbol = attn_hidden_dim_symbol + \
        2 * enc_hidden_dim_symbol + dec_hidden_dim_symbol

    bind_dict = {
        # Placeholders
        'attn_model/input_seq':
            [encoder_steps_symbol, subbatch_size_symbol,
             audio_features_symbol],
        'attn_model/input_len': [subbatch_size_symbol],
        'attn_model/output_seq': [decoder_steps_symbol, subbatch_size_symbol],
        'attn_model/output_mask': [decoder_steps_symbol, subbatch_size_symbol],
        # Variables
        'InputNormalizer/means': [audio_features_symbol],
        'InputNormalizer/std': [audio_features_symbol],
        'attn_model/AffineAttentionStateNN/W':
            [2 * enc_hidden_dim_symbol, attn_dim_symbol],
        'attn_model/AffineAttentionStateNN/b': [attn_dim_symbol],
        'attn_model/AffineOutputProjection/W':
            [dec_hidden_dim_symbol, output_vocab_symbol],
        'attn_model/AffineOutputProjection/b': [output_vocab_symbol],
        'attn_model/Decoder/attn_model/attention_cell/biases':
            [4 * attn_hidden_dim_symbol],
        'attn_model/Decoder/attn_model/attention_cell/weights':
            [attn_hidden_dim_symbol + 2 * enc_hidden_dim_symbol +
             output_vocab_symbol, 4 * attn_hidden_dim_symbol],
        'attn_model/Decoder/attn_model/decoder_cell/biases':
            [4 * dec_hidden_dim_symbol],
        'attn_model/Decoder/attn_model/decoder_cell/weights':
            [attn_hidden_dim_symbol + dec_hidden_dim_symbol +
             2 * enc_hidden_dim_symbol, 4 * dec_hidden_dim_symbol],
        'attn_model/HybridAttentionContext/Q':
            [conv_width_symbol, 1, num_conv_filters_symbol],
        'attn_model/HybridAttentionContext/U':
            [1, num_conv_filters_symbol, attn_dim_symbol],
        'attn_model/HybridAttentionContext/W':
            [2 * attn_hidden_dim_symbol, attn_dim_symbol],
        'attn_model/HybridAttentionContext/b': [attn_dim_symbol],
        'attn_model/HybridAttentionContext/w': [attn_dim_symbol],
        'attn_model/StackedEncoder/Layer0/RNNEncoder/bidirectional_rnn/bw/basic_lstm_cell/bias':
            [4 * enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer0/RNNEncoder/bidirectional_rnn/bw/basic_lstm_cell/kernel':
            [audio_features_symbol + enc_hidden_dim_symbol,
             4 * enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer0/RNNEncoder/bidirectional_rnn/fw/basic_lstm_cell/bias':
            [4 * enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer0/RNNEncoder/bidirectional_rnn/fw/basic_lstm_cell/kernel':
            [audio_features_symbol + enc_hidden_dim_symbol,
             4 * enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer2/RNNEncoder/bidirectional_rnn/bw/basic_lstm_cell/bias':
            [4 * enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer2/RNNEncoder/bidirectional_rnn/bw/basic_lstm_cell/kernel':
            [3 * enc_hidden_dim_symbol, 4 * enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer2/RNNEncoder/bidirectional_rnn/fw/basic_lstm_cell/bias':
            [4 * enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer2/RNNEncoder/bidirectional_rnn/fw/basic_lstm_cell/kernel':
            [3 * enc_hidden_dim_symbol, 4 * enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer4/RNNEncoder/bidirectional_rnn/bw/basic_lstm_cell/bias':
            [4 * enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer4/RNNEncoder/bidirectional_rnn/bw/basic_lstm_cell/kernel':
            [3 * enc_hidden_dim_symbol, 4 * enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer4/RNNEncoder/bidirectional_rnn/fw/basic_lstm_cell/bias':
            [4 * enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer4/RNNEncoder/bidirectional_rnn/fw/basic_lstm_cell/kernel':
            [3 * enc_hidden_dim_symbol, 4 * enc_hidden_dim_symbol],

        # Constants
        'attn_model/AttentionModel/gradients/attn_model/Decoder/while/MatMul/Enter_grad/b_acc':
            [dec_hidden_dim_symbol, output_vocab_symbol],
        'attn_model/AttentionModel/gradients/attn_model/Decoder/while/add/Enter_grad/b_acc':
            [output_vocab_symbol],
        'attn_model/AttentionModel/gradients/attn_model/Decoder/while/attn_model/MatMul/Enter_grad/b_acc':
            [2 * attn_hidden_dim_symbol, attn_dim_symbol],
        'attn_model/AttentionModel/gradients/attn_model/Decoder/while/attn_model/add_2/Enter_grad/b_acc':
            [attn_dim_symbol],
        'attn_model/AttentionModel/gradients/attn_model/Decoder/while/attn_model/attention_cell/BiasAdd/Enter_grad/b_acc':
            [4 * attn_hidden_dim_symbol],
        'attn_model/AttentionModel/gradients/attn_model/Decoder/while/attn_model/attention_cell/attention_cell/add/Enter_grad/b_acc':
            [attn_hidden_dim_symbol + 2 * enc_hidden_dim_symbol +
             output_vocab_symbol, 4 * attn_hidden_dim_symbol],
        'attn_model/AttentionModel/gradients/attn_model/Decoder/while/attn_model/conv1d/ExpandDims_1/Enter_grad/b_acc':
            [conv_width_symbol, 1, 4],
        'attn_model/AttentionModel/gradients/attn_model/Decoder/while/attn_model/conv1d_1/ExpandDims_1/Enter_grad/b_acc':
            [1, 4, attn_dim_symbol],
        'attn_model/AttentionModel/gradients/attn_model/Decoder/while/attn_model/decoder_cell/BiasAdd/Enter_grad/b_acc':
            [4 * dec_hidden_dim_symbol],
        'attn_model/AttentionModel/gradients/attn_model/Decoder/while/attn_model/decoder_cell/decoder_cell/add/Enter_grad/b_acc':
            [attn_hidden_dim_symbol + dec_hidden_dim_symbol +
             2 * enc_hidden_dim_symbol, 4 * dec_hidden_dim_symbol],
        'attn_model/AttentionModel/gradients/attn_model/Decoder/while/attn_model/mul/Enter_grad/b_acc':
            [attn_dim_symbol],
        'attn_model/AttentionModel/gradients/attn_model/StackedEncoder/Layer0/RNNEncoder/bidirectional_rnn/bw/bw/while/basic_lstm_cell/BiasAdd/Enter_grad/b_acc':
            [4 * enc_hidden_dim_symbol],
        'attn_model/AttentionModel/gradients/attn_model/StackedEncoder/Layer0/RNNEncoder/bidirectional_rnn/bw/bw/while/basic_lstm_cell/MatMul/Enter_grad/b_acc':
            [audio_features_symbol + enc_hidden_dim_symbol,
             4 * enc_hidden_dim_symbol],
        'attn_model/AttentionModel/gradients/attn_model/StackedEncoder/Layer0/RNNEncoder/bidirectional_rnn/fw/fw/while/basic_lstm_cell/BiasAdd/Enter_grad/b_acc':
            [4 * enc_hidden_dim_symbol],
        'attn_model/AttentionModel/gradients/attn_model/StackedEncoder/Layer0/RNNEncoder/bidirectional_rnn/fw/fw/while/basic_lstm_cell/MatMul/Enter_grad/b_acc':
            [audio_features_symbol + enc_hidden_dim_symbol,
             4 * enc_hidden_dim_symbol],
        'attn_model/AttentionModel/gradients/attn_model/StackedEncoder/Layer2/RNNEncoder/bidirectional_rnn/bw/bw/while/basic_lstm_cell/BiasAdd/Enter_grad/b_acc':
            [4 * enc_hidden_dim_symbol],
        'attn_model/AttentionModel/gradients/attn_model/StackedEncoder/Layer2/RNNEncoder/bidirectional_rnn/bw/bw/while/basic_lstm_cell/MatMul/Enter_grad/b_acc':
            [3 * enc_hidden_dim_symbol, 4 * enc_hidden_dim_symbol],
        'attn_model/AttentionModel/gradients/attn_model/StackedEncoder/Layer2/RNNEncoder/bidirectional_rnn/fw/fw/while/basic_lstm_cell/BiasAdd/Enter_grad/b_acc':
            [4 * enc_hidden_dim_symbol],
        'attn_model/AttentionModel/gradients/attn_model/StackedEncoder/Layer2/RNNEncoder/bidirectional_rnn/fw/fw/while/basic_lstm_cell/MatMul/Enter_grad/b_acc':
            [3 * enc_hidden_dim_symbol, 4 * enc_hidden_dim_symbol],
        'attn_model/AttentionModel/gradients/attn_model/StackedEncoder/Layer4/RNNEncoder/bidirectional_rnn/bw/bw/while/basic_lstm_cell/BiasAdd/Enter_grad/b_acc':
            [4 * enc_hidden_dim_symbol],
        'attn_model/AttentionModel/gradients/attn_model/StackedEncoder/Layer4/RNNEncoder/bidirectional_rnn/bw/bw/while/basic_lstm_cell/MatMul/Enter_grad/b_acc':
            [3 * enc_hidden_dim_symbol, 4 * enc_hidden_dim_symbol],
        'attn_model/AttentionModel/gradients/attn_model/StackedEncoder/Layer4/RNNEncoder/bidirectional_rnn/fw/fw/while/basic_lstm_cell/BiasAdd/Enter_grad/b_acc':
            [4 * enc_hidden_dim_symbol],
        'attn_model/AttentionModel/gradients/attn_model/StackedEncoder/Layer4/RNNEncoder/bidirectional_rnn/fw/fw/while/basic_lstm_cell/MatMul/Enter_grad/b_acc':
            [3 * enc_hidden_dim_symbol, 4 * enc_hidden_dim_symbol],
    }

    # Update constant values
    const_dict = {
        'attn_model/AffineAttentionStateNN/Reshape/shape':
            [-1, 2 * enc_hidden_dim_symbol],
        'attn_model/AffineAttentionStateNN/Reshape_1/shape/2': attn_dim_symbol,
        'attn_model/AttentionEncoderDecoder/Reshape/shape/1':
            output_vocab_symbol,
        'attn_model/AttentionModel/gradients/attn_model/AffineAttentionStateNN/add_grad/Shape_1':
            [attn_dim_symbol],
        'attn_model/AttentionModel/gradients/attn_model/Decoder/while/add_grad/Shape_1':
            [output_vocab_symbol],
        'attn_model/AttentionModel/gradients/attn_model/Decoder/while/attn_model/add_2_grad/Shape_1':
            [attn_dim_symbol],
        'attn_model/AttentionModel/gradients/attn_model/Decoder/while/attn_model/conv1d/Conv2D_grad/Const':
            [1, conv_width_symbol, 1, num_conv_filters_symbol],
        'attn_model/AttentionModel/gradients/attn_model/Decoder/while/attn_model/conv1d/ExpandDims_1_grad/Shape':
            [conv_width_symbol, 1, num_conv_filters_symbol],
        'attn_model/AttentionModel/gradients/attn_model/Decoder/while/attn_model/conv1d_1/Conv2D_grad/Const':
            [1, 1, num_conv_filters_symbol, attn_dim_symbol],
        'attn_model/AttentionModel/gradients/attn_model/Decoder/while/attn_model/conv1d_1/ExpandDims_1_grad/Shape':
            [1, num_conv_filters_symbol, attn_dim_symbol],
        'attn_model/AttentionModel/gradients/attn_model/Decoder/while/attn_model/mul_grad/Shape_1':
            [attn_dim_symbol],
        'attn_model/Decoder/CustomLSTMCellZeroState/Const':
            [2 * attn_hidden_dim_symbol],
        'attn_model/Decoder/CustomLSTMCellZeroState/Const_1':
            [2 * attn_hidden_dim_symbol],
        'attn_model/Decoder/CustomLSTMCellZeroState_1/Const':
            [2 * dec_hidden_dim_symbol],
        'attn_model/Decoder/CustomLSTMCellZeroState_1/Const_1':
            [2 * dec_hidden_dim_symbol],
        'attn_model/Decoder/while/attn_model/attention_cell/attention_cell/Shape':
            [attn_hidden_dim_symbol + 2 * enc_hidden_dim_symbol +
             output_vocab_symbol, 4 * attn_hidden_dim_symbol],
        'attn_model/Decoder/while/attn_model/decoder_cell/decoder_cell/Shape':
            [attn_hidden_dim_symbol + dec_hidden_dim_symbol +
             2 * enc_hidden_dim_symbol, 4 * dec_hidden_dim_symbol],
        'attn_model/Decoder/while/attn_model/one_hot/depth':
            output_vocab_symbol,
        'attn_model/Decoder/zeros/shape/1': 2 * enc_hidden_dim_symbol,
        'attn_model/Decoder/zeros_2/shape/1': output_vocab_symbol,
        'attn_model/Reshape/shape': [1, 1, audio_features_symbol],
        'attn_model/Reshape_1/shape': [1, 1, audio_features_symbol],
        'attn_model/Reshape_2/shape/2': 2 * enc_hidden_dim_symbol,
        'attn_model/StackedEncoder/Layer0/RNNEncoder/Reshape/shape/2':
            audio_features_symbol,
        'attn_model/StackedEncoder/Layer0/RNNEncoder/bidirectional_rnn/bw/bw/BasicLSTMCellZeroState/Const':
            [2 * enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer0/RNNEncoder/bidirectional_rnn/bw/bw/BasicLSTMCellZeroState/Const_1':
            [2 * enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer0/RNNEncoder/bidirectional_rnn/bw/bw/Const_1':
            [enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer0/RNNEncoder/bidirectional_rnn/bw/bw/Const_4':
            [enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer0/RNNEncoder/bidirectional_rnn/fw/fw/BasicLSTMCellZeroState/Const':
            [2 * enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer0/RNNEncoder/bidirectional_rnn/fw/fw/BasicLSTMCellZeroState/Const_1':
            [2 * enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer0/RNNEncoder/bidirectional_rnn/fw/fw/Const_1':
            [enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer0/RNNEncoder/bidirectional_rnn/fw/fw/Const_4':
            [enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer2/RNNEncoder/Reshape/shape/2':
            2 * enc_hidden_dim_symbol,
        'attn_model/StackedEncoder/Layer2/RNNEncoder/bidirectional_rnn/bw/bw/BasicLSTMCellZeroState/Const':
            [2 * enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer2/RNNEncoder/bidirectional_rnn/bw/bw/BasicLSTMCellZeroState/Const_1':
            [2 * enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer2/RNNEncoder/bidirectional_rnn/bw/bw/Const_1':
            [enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer2/RNNEncoder/bidirectional_rnn/bw/bw/Const_4':
            [enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer2/RNNEncoder/bidirectional_rnn/fw/fw/BasicLSTMCellZeroState/Const':
            [2 * enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer2/RNNEncoder/bidirectional_rnn/fw/fw/BasicLSTMCellZeroState/Const_1':
            [2 * enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer2/RNNEncoder/bidirectional_rnn/fw/fw/Const_1':
            [enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer2/RNNEncoder/bidirectional_rnn/fw/fw/Const_4':
            [enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer4/RNNEncoder/Reshape/shape/2':
            2 * enc_hidden_dim_symbol,
        'attn_model/StackedEncoder/Layer4/RNNEncoder/bidirectional_rnn/bw/bw/BasicLSTMCellZeroState/Const':
            [2 * enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer4/RNNEncoder/bidirectional_rnn/bw/bw/BasicLSTMCellZeroState/Const_1':
            [2 * enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer4/RNNEncoder/bidirectional_rnn/bw/bw/Const_1':
            [enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer4/RNNEncoder/bidirectional_rnn/bw/bw/Const_4':
            [enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer4/RNNEncoder/bidirectional_rnn/fw/fw/BasicLSTMCellZeroState/Const':
            [2 * enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer4/RNNEncoder/bidirectional_rnn/fw/fw/BasicLSTMCellZeroState/Const_1':
            [2 * enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer4/RNNEncoder/bidirectional_rnn/fw/fw/Const_1':
            [enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer4/RNNEncoder/bidirectional_rnn/fw/fw/Const_4':
            [enc_hidden_dim_symbol],
    }
    graph.bindConstantValues(const_dict)

    # TODO: Currently, Catamount doesn't automatically handle Tensorflow
    # TensorArrays or Stack ops. Here, manually set the dimensions of these
    # ops' tensors.
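    # The TensorArray ops below import as UnknownOp, so Catamount cannot
    # infer their output shapes; each branch merges in a hand-derived
    # symbolic shape via mergeShape(..., make_symbolic=True) instead.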
    for op in graph._ops_by_name.values():
        op_name_suffix = op.name.split('/')[-1]
        if 'TensorArrayGather' in op_name_suffix:
            assert isinstance(op, UnknownOp)
            assert len(op._inputs) == 3
            assert len(op._outputs) == 1
            if op._outputs[0].shape.rank == 1 or \
               op._outputs[0].shape.rank == 2:
                if len(op._outputs[0].consumers) > 0:
                    print('TODO: Unknown TensorArrayGather (rank {}): {}'
                          .format(op._outputs[0].shape.rank, op.debugString()))
            elif op._outputs[0].shape.isUnknown() or \
                 op._outputs[0].shape.rank == 3:
                if len(op._outputs[0].consumers) > 0:
                    # If output rank is 3, then appears to be:
                    # [seq_length, batch_size, enc_hid], where
                    # seq_length depends on layer
                    out_shape = None
                    if 'StackedEncoder/Layer0' in op.name:
                        out_shape = [encoder_steps_symbol,
                                     subbatch_size_symbol,
                                     enc_hidden_dim_symbol]
                    elif 'StackedEncoder/Layer2' in op.name:
                        if 'attn_model/AttentionModel/gradients' in op.name:
                            # Backprop stores concatenated state
                            out_shape = [encoder_steps_symbol // 2,
                                         subbatch_size_symbol,
                                         2 * enc_hidden_dim_symbol]
                        else:
                            out_shape = [encoder_steps_symbol // 2,
                                         subbatch_size_symbol,
                                         enc_hidden_dim_symbol]
                    elif 'StackedEncoder/Layer4' in op.name:
                        if 'attn_model/AttentionModel/gradients' in op.name:
                            # Backprop stores concatenated state
                            out_shape = [(encoder_steps_symbol // 2) // 2,
                                         subbatch_size_symbol,
                                         2 * enc_hidden_dim_symbol]
                        else:
                            out_shape = [(encoder_steps_symbol // 2) // 2,
                                         subbatch_size_symbol,
                                         enc_hidden_dim_symbol]
                    elif 'Decoder' in op.name:
                        # HAXXXX: Manually specify a few
                        if op.name == 'attn_model/Decoder/TensorArrayStack/TensorArrayGatherV3':
                            out_shape = [decoder_steps_symbol,
                                         subbatch_size_symbol,
                                         output_vocab_symbol]
                        else:
                            out_shape = [decoder_steps_symbol,
                                         subbatch_size_symbol,
                                         dec_hidden_dim_symbol]
                    else:
                        print('TODO: Unknown TensorArrayGather {}'.format(
                            op.debugString()))
                    if out_shape is not None:
                        op._outputs[0].mergeShape(out_shape,
                                                  make_symbolic=True)
            else:
                print('TODO: Unknown TensorArrayGather {}'.format(
                    op.debugString()))
        elif 'TensorArraySize' in op_name_suffix:
            assert isinstance(op, UnknownOp)
            assert len(op._inputs) == 2
            assert len(op._outputs) == 1
            assert op._outputs[0].shape.rank == 0
            # NOTES:
            # StackedEncoder Layer0: enc_seq
            # StackedEncoder Layer2: enc_seq / 2  # Due to stride 2 in time
            # StackedEncoder Layer4: enc_seq / 4  # Due to stride 2 in time
            # Decoder: dec_seq
            if 'StackedEncoder/Layer0' in op.name:
                op._outputs[0].setValue(encoder_steps_symbol)
            elif 'StackedEncoder/Layer2' in op.name:
                op._outputs[0].setValue(encoder_steps_symbol // 2)
            elif 'StackedEncoder/Layer4' in op.name:
                op._outputs[0].setValue((encoder_steps_symbol // 2) // 2)
            elif 'Decoder' in op.name:
                op._outputs[0].setValue(decoder_steps_symbol)
            else:
                print('WARN: Unknown TensorArraySizeV3: {}'.format(
                    op.debugString()))
        elif 'TensorArrayRead' in op_name_suffix:
            assert isinstance(op, UnknownOp)
            assert len(op._inputs) == 3
            assert len(op._outputs) == 1
            assert op._outputs[0].shape.isUnknown() or \
                op._outputs[0].shape.rank == 2, \
                '{}'.format(op.name)
            if op._outputs[0].shape.isUnknown():
                if len(op._outputs[0].consumers) > 0:
                    out_shape = None
                    if 'attn_model/AttentionModel/gradients/attn_model/StackedEncoder/Layer' in op.name and \
                       ('/RNNEncoder/bidirectional_rnn/fw/fw/while/TensorArrayWrite/TensorArrayWriteV3_grad/TensorArrayReadV3' in op.name or
                        '/RNNEncoder/bidirectional_rnn/bw/bw/while/TensorArrayWrite/TensorArrayWriteV3_grad/TensorArrayReadV3' in op.name):
                        out_shape = [subbatch_size_symbol,
                                     enc_hidden_dim_symbol]
                    elif op.name == 'attn_model/AttentionModel/gradients/attn_model/Decoder/TensorArrayWrite/TensorArrayWriteV3_grad/TensorArrayReadV3' or \
                         op.name == 'attn_model/AttentionModel/gradients/attn_model/Decoder/while/TensorArrayWrite_1/TensorArrayWriteV3_grad/TensorArrayReadV3' or \
                         op.name == 'attn_model_2/Decoder/while/cond/TensorArrayReadV3' or \
                         op.name == 'attn_model/Decoder/while/cond/TensorArrayReadV3':
                        out_shape = [subbatch_size_symbol, output_vocab_symbol]
                    else:
                        print('WARN: Unknown TensorArrayReadV3 out shape: {}'
                              .format(op.debugString()))
                    if out_shape is not None:
                        op._outputs[0].mergeShape(out_shape,
                                                  make_symbolic=True)
            else:
                # NOTES: Many are (?, 40 "features"), (?, 1051 "enc_hid"),
                # or (?, 2102 "2*enc_hid")
                dim_1_val = op._outputs[0].shape.getDimension(1).value
                assert dim_1_val == base_audio_features or \
                    dim_1_val == base_enc_hidden_dim or \
                    dim_1_val == 2 * base_enc_hidden_dim, \
                    'Op: {}\n Dim 1 value: {}'.format(op.debugString(),
                                                      dim_1_val)
                out_shape = None
                if dim_1_val == base_audio_features:
                    out_shape = [subbatch_size_symbol, audio_features_symbol]
                elif dim_1_val > 0 and dim_1_val % base_enc_hidden_dim == 0:
                    mult = dim_1_val // base_enc_hidden_dim
                    out_shape = [subbatch_size_symbol,
                                 mult * enc_hidden_dim_symbol]
                else:
                    print('Unhandled TensorArrayRead: {}'.format(
                        op.debugString()))
                if out_shape is not None:
                    op._outputs[0].mergeShape(out_shape, make_symbolic=True)

    # Manually set a couple shapes for max ops that can't yet resolve
    # maximums of 1 vs. positive symbols:
    max_op = graph._ops_by_name['attn_model/AttentionModel/gradients/attn_model/AttentionEncoderDecoder/Sum_grad/Maximum']
    max_op._outputs[0].mergeShape([2])
    max_op._outputs[0].setValue([1, subbatch_size_symbol])
    max_op = graph._ops_by_name['attn_model/AttentionModel/gradients/attn_model/Decoder/while/attn_model/Sum_grad/Maximum']
    max_op._outputs[0].mergeShape([3])
    # [floor(floor(encoder_steps/2)/2), subbatch_size, 1]
    max_op._outputs[0].setValue([(encoder_steps_symbol // 2) // 2,
                                 subbatch_size_symbol, 1])
    max_op = graph._ops_by_name['attn_model/AttentionModel/gradients/attn_model/Decoder/while/attn_model/Sum_1_grad/Maximum']
    max_op._outputs[0].mergeShape([3])
    # [1, subbatch_size, 2*enc_hidden_dim]
    max_op._outputs[0].setValue([1, subbatch_size_symbol,
                                 2 * enc_hidden_dim_symbol])

    print('Binding variables')
    graph.bindShapesAndPropagate(bind_dict,
                                 warn_if_ill_defined=(not is_pytest_run),
                                 make_symbolic=True)
    assert graph.isValid()

    print('\n\nCleaned Graph:\n{}'.format(graph))

    print('\n\nBound values')

    # Set base values to be subbed in:
    base_encoder_steps = 96
    base_decoder_steps = 24
    base_attn_dim = 128
    base_conv_width = 50
    base_attn_hidden_dim = 512
    base_dec_hidden_dim = 512
    base_enc_hidden_dim = 1024

    bind_subs = {
        audio_features_symbol: base_audio_features,
        encoder_steps_symbol: base_encoder_steps,
        decoder_steps_symbol: (encoder_steps_symbol // 2) // 2,
        subbatch_size_symbol: base_subbatch_size,
        attn_dim_symbol: base_attn_dim,
        attn_hidden_dim_symbol: enc_hidden_dim_symbol // 2,
        dec_hidden_dim_symbol: enc_hidden_dim_symbol // 2,
        output_vocab_symbol: base_output_vocab,
        conv_width_symbol: base_conv_width,
        enc_hidden_dim_symbol: base_enc_hidden_dim,
        num_conv_filters_symbol: 4,
        graph_iters_symbol: 1,
    }
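    # Note the tied dimensions above: attention and decoder widths track half
    # the encoder width, and decoder steps track encoder_steps // 4, so
    # sweeping enc_hidden_dim_symbol later scales the model consistently.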
    # Add loop iteration counts to bind_subs
    bind_str_subs = {
        'attn_model/AttentionModel/gradients/b_count_2_block::iters':
            decoder_steps_symbol,
        'attn_model/Decoder/while/LoopCond_block::iters':
            decoder_steps_symbol,
        'attn_model/AttentionModel/gradients/b_count_22_block::iters':
            encoder_steps_symbol,
        'attn_model/AttentionModel/gradients/b_count_26_block::iters':
            encoder_steps_symbol,
        'attn_model/StackedEncoder/Layer0/RNNEncoder/bidirectional_rnn/bw/bw/while/LoopCond_block::iters':
            encoder_steps_symbol,
        'attn_model/StackedEncoder/Layer0/RNNEncoder/bidirectional_rnn/fw/fw/while/LoopCond_block::iters':
            encoder_steps_symbol,
        'attn_model/AttentionModel/gradients/b_count_14_block::iters':
            encoder_steps_symbol // 2,
        'attn_model/AttentionModel/gradients/b_count_18_block::iters':
            encoder_steps_symbol // 2,
        'attn_model/StackedEncoder/Layer2/RNNEncoder/bidirectional_rnn/bw/bw/while/LoopCond_block::iters':
            encoder_steps_symbol // 2,
        'attn_model/StackedEncoder/Layer2/RNNEncoder/bidirectional_rnn/fw/fw/while/LoopCond_block::iters':
            encoder_steps_symbol // 2,
        'attn_model/AttentionModel/gradients/b_count_6_block::iters':
            (encoder_steps_symbol // 2) // 2,
        'attn_model/AttentionModel/gradients/b_count_10_block::iters':
            (encoder_steps_symbol // 2) // 2,
        'attn_model/StackedEncoder/Layer4/RNNEncoder/bidirectional_rnn/bw/bw/while/LoopCond_block::iters':
            (encoder_steps_symbol // 2) // 2,
        'attn_model/StackedEncoder/Layer4/RNNEncoder/bidirectional_rnn/fw/fw/while/LoopCond_block::iters':
            (encoder_steps_symbol // 2) // 2,
    }
    for var_name, sub_val in bind_str_subs.items():
        var_ref = utils.getIntSymbolFromString(var_name)
        assert var_name not in bind_subs.keys()
        bind_subs[var_ref] = sub_val

    # Calculate model parameter count
    parameters = graph.calcModelParameters()
    resolved_params = parameters.subs(bind_subs)
    try:
        resolved_params = int(resolved_params)
    except Exception:
        print('ERROR: resolved_params should be int, but is {} = {}'.format(
            type(resolved_params), resolved_params))
    correct_params = 71084729
    assert resolved_params == correct_params, \
        'Incorrect model params: {}'.format(resolved_params)
    print('Parameters: {}\nWith specified dims: {}\n'.format(
        parameters, resolved_params))

    # Calculate algorithmic Flops
    alg_flops = graph.calcAlgFlops()
    resolved_flops = alg_flops.subs(bind_subs)
    try:
        resolved_flops = int(resolved_flops)
    except Exception:
        print('ERROR: resolved_flops should be int, but is {} = {}'.format(
            type(resolved_flops), resolved_flops))
    correct_flops = 568878183032
    assert resolved_flops == correct_flops, \
        'Incorrect algorithmic flops: {}'.format(resolved_flops)
    print('Algorithmic Flops: {}\nWith specified dims: {}\n'.format(
        alg_flops, resolved_flops))

    # Calculate algorithmic Bytes accessed
    alg_bytes = graph.calcAlgBytes()
    resolved_bytes = alg_bytes.subs(bind_subs)
    try:
        resolved_bytes = int(resolved_bytes)
    except Exception:
        print('ERROR: resolved_bytes should be int, but is {} = {}'.format(
            type(resolved_bytes), resolved_bytes))
    correct_bytes = 92231419797
    assert resolved_bytes == correct_bytes, \
        'Incorrect algorithmic bytes: {}'.format(resolved_bytes)
    print('Alg bytes accessed: {}\nWith specified dims: {}\n'.format(
        alg_bytes, resolved_bytes))

    # Calculate total memory footprint
    alg_footprint = graph.calcAlgFootprint()
    resolved_footprint = alg_footprint.subs(bind_subs)
    try:
        resolved_footprint = int(resolved_footprint)
    except Exception:
        print('ERROR: resolved_footprint should be int, but is {} = {}'.format(
            type(resolved_footprint), resolved_footprint))
    correct_footprint = 32624988214
    assert resolved_footprint == correct_footprint, \
        'Incorrect algorithmic footprint: {}'.format(resolved_footprint)
    print('Alg mem footprint: {}\nWith specified dims: {}\n'.format(
        alg_footprint, resolved_footprint))

    # Calculate algorithmic IO per step
    total_io_footprint = 0
    for op in graph.getPlaceholders():
        total_io_footprint += op.calcAlgFootprint()
    resolved_io_footprint = total_io_footprint.subs(bind_subs)
    print('Alg IO footprint: {}\nWith specified dims: {}\n'.format(
        total_io_footprint, resolved_io_footprint))

    print('VERBOSE ALGORITHMIC FLOPS:')
    graph.calcAlgFlops(verbose=True)
    print('')
    print('VERBOSE ALGORITHMIC BYTES:')
    graph.calcAlgBytes(verbose=True)
    print('')
    print('VERBOSE ALGORITHMIC FOOTPRINT:')
    graph.calcAlgFootprint(verbose=True)
    print('')

    # HACKY WAY TO SAVE MODELS FOR NOW!
    pickle.dump(graph, open(
        'catamount/frameworks/example_graphs/tensorflow/full_models/speech_attention/graph_speech_attention.p',
        'wb'))

    if is_pytest_run:
        return

    print('\n\n======= Algorithmic graph-level analytics: =======')

    encoder_dims = [32, 64, 96, 128, 160, 192, 256, 320, 384, 448, 512, 640,
                    768, 892, 1024, 1152, 1280, 1408, 1548, 1702, 1872, 2059,
                    2264, 2490, 2739, 3012, 3289]

    base_encoder_steps = 335
    base_subbatch_size = 32
    base_attn_dim = 128
    base_conv_width = 50
    base_attn_hidden_dim = 512
    base_dec_hidden_dim = 512
    base_enc_hidden_dim = 1024

    bind_subs[audio_features_symbol] = base_audio_features
    bind_subs[encoder_steps_symbol] = base_encoder_steps
    bind_subs[decoder_steps_symbol] = (encoder_steps_symbol // 2) // 2
    bind_subs[subbatch_size_symbol] = base_subbatch_size
    bind_subs[attn_dim_symbol] = base_attn_dim
    bind_subs[attn_hidden_dim_symbol] = enc_hidden_dim_symbol // 2
    bind_subs[dec_hidden_dim_symbol] = enc_hidden_dim_symbol // 2
    bind_subs[output_vocab_symbol] = base_output_vocab
    bind_subs[conv_width_symbol] = base_conv_width
    # bind_subs[enc_hidden_dim_symbol] = base_enc_hidden_dim
    bind_subs[num_conv_filters_symbol] = 4
    bind_subs[graph_iters_symbol] = 1

    bind_subs.pop(enc_hidden_dim_symbol)
    resolved_params = parameters.subs(bind_subs)

    print('Symbol associations: {}\n'.format(bind_subs))

    print('Algorithmic Flops by hidden dimension, params, and per-batch-sample:')
    resolved_flops = alg_flops.subs(bind_subs)
    for enc_dim in encoder_dims:
        graph_params = resolved_params.subs({enc_hidden_dim_symbol: enc_dim})
        graph_flops = resolved_flops.subs({enc_hidden_dim_symbol: enc_dim})
        graph_flops_per_sample = float(graph_flops) / \
            bind_subs[subbatch_size_symbol]
        print('{}\t{}\t{}\t{}'.format(enc_dim, graph_params, graph_flops,
                                      int(graph_flops_per_sample)))

    print('\nAlgorithmic bytes accessed by hidden dimension, params:')
    resolved_bytes = alg_bytes.subs(bind_subs)
    for enc_dim in encoder_dims:
        graph_params = resolved_params.subs({enc_hidden_dim_symbol: enc_dim})
        graph_bytes = resolved_bytes.subs({enc_hidden_dim_symbol: enc_dim})
        print('{}\t{}\t{}'.format(enc_dim, graph_params, graph_bytes))

    print('\nAlgorithmic memory footprint by hidden dimension, params:')
    resolved_footprint = alg_footprint.subs(bind_subs)
    for enc_dim in encoder_dims:
        graph_params = resolved_params.subs({enc_hidden_dim_symbol: enc_dim})
        graph_footprint = resolved_footprint.subs(
            {enc_hidden_dim_symbol: enc_dim})
        print('{}\t{}\t{}'.format(enc_dim, graph_params, graph_footprint))

    print('\nAlgorithmic minimal memory footprint by hidden dimension, params:')
    full_subs = dict(bind_subs)
    for enc_dim in encoder_dims:
        graph_params = resolved_params.subs({enc_hidden_dim_symbol: enc_dim})
        full_subs[enc_hidden_dim_symbol] = enc_dim
        graph_min_foot = graph.calcMinimalFootprint(symbol_subs=full_subs)
        print('{}\t{}\t{}'.format(enc_dim, graph_params, graph_min_foot))
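# Hypothetical usage sketch (no arguments; assumes the checked-in
# speech_attention checkpoint and the module-level `is_pytest_run` flag):
#
#     run_tf_speech_attention()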
def run_tf_w2v_model():
    global is_pytest_run

    graph_meta = 'catamount/frameworks/example_graphs/tensorflow/full_models/language_models/word2vec_n200-latest_model.meta'

    graph = catamount.frameworks.tensorflow.import_graph(graph_meta)
    assert graph.isValid()

    # Next, remove ops that are not executed during a standard training step:
    graph_ops = list(graph._ops_by_name.values())
    for op in graph_ops:
        # Certain ops are only used for inference
        if 'Model/NceLoss_1_3/' in op.name or \
           'Model/Collapse_1/' in op.name or \
           'Model/Embedding_1_3/' in op.name or \
           'Model/Labels_1/' in op.name or \
           'Model/SkipGramSampler_1/' in op.name or \
           'Model/Mask_1/' in op.name:
            graph.removeOp(op)
        elif op.name == 'Model/Cast_1' or \
             op.name == 'Model/Sum_1' or \
             op.name == 'Model/Size_1' or \
             op.name == 'Model/Exp_1' or \
             op.name == 'Model/truediv_2' or \
             op.name == 'Model/truediv_3':
            graph.removeOp(op)

    if not is_pytest_run:
        print('Initial graph:\n{}\n'.format(graph))
    init_params = graph.calcModelParameters()
    print('Initial parameters: {}'.format(init_params))
    print('Initial Flops: {}\n'.format(graph.calcAlgFlops()))

    print('Placeholders:')
    for op in graph.getPlaceholders():
        print(op.debugString())
    print('')

    # Set up symbols to name dimensions
    skip_window_symbol = utils.getPositiveIntSymbolFromString('skip_window')
    num_skips_symbol = utils.getPositiveIntSymbolFromString('num_skips')
    nce_samples_symbol = utils.getPositiveIntSymbolFromString('nce_samples')
    hidden_dim_symbol = utils.getIntSymbolFromString('hidden_dim')
    vocab_size_symbol = utils.getIntSymbolFromString('vocab_size')
    subbatch_size_symbol = utils.getIntSymbolFromString('subbatch_size')
    sequence_length_symbol = utils.getIntSymbolFromString('sequence_length')
    batch_times_seq_symbol = sequence_length_symbol * subbatch_size_symbol
    graph_iters_symbol = utils.getIntSymbolFromString('graph::iters')

    # For simplicity, assign samples symbol in the op
    nce_samp_op = graph.opsByName['Model/NceLoss_1_1/nce_loss/LogUniformCandidateSampler']
    nce_samp_op._num_samples_symbol = nce_samples_symbol

    # Convert these constant dimensions to symbols
    base_skip_window = 8
    base_num_skips = 8
    base_nce_samples = 64
    base_hidden_dim = 400
    base_vocab_size = 40004
    base_sequence_length = 32
    base_subbatch_size = 1

    # Find and set constants that contain model hyperparameters
    const_dict = {
        'Model/Gradient/Compute/gradients/Model/NceLoss_1_1/nce_loss/sub_1_grad/Shape_1':
            [nce_samples_symbol],
        'Model/SkipGramSampler/Const': 2 * skip_window_symbol,
        'Model/SkipGramSampler/strided_slice/stack': [0, skip_window_symbol],
        'Model/SkipGramSampler/strided_slice/stack_1':
            [0, -skip_window_symbol],
        'Model/Collapse/Reshape/shape': [-1, hidden_dim_symbol],
        'Model/Gradient/Compute/gradients/Model/Embedding_1_1/Gather_grad/Shape':
            [vocab_size_symbol, hidden_dim_symbol],
        'Model/Gradient/Compute/gradients/Model/NceLoss_1_1/nce_loss/embedding_lookup_1_grad/Shape':
            [vocab_size_symbol],
        'Model/Gradient/Compute/gradients/Model/NceLoss_1_1/nce_loss/embedding_lookup_grad/Shape':
            [vocab_size_symbol, hidden_dim_symbol],
        'Model/Mask/NotEqual/y': vocab_size_symbol - 3,
        'Model/SkipGramSampler/Const_2': num_skips_symbol,
        'Model/SkipGramSampler/Tile_1/multiples': [1, num_skips_symbol],
        'Model/SkipGramSampler/Tile/multiples': [1, num_skips_symbol],
    }
    graph.bindConstantValues(const_dict)

    # Next, bind the constant, placeholder, and variable shapes and propagate
    bind_dict = {
        # Constants
        # Placeholders
        'Input/Input': [subbatch_size_symbol, sequence_length_symbol],
        'Labels/Labels': [subbatch_size_symbol, sequence_length_symbol],
        # Variables
        'Model/NceLoss_1/b_Softmax': [vocab_size_symbol],
        'Model/NceLoss_1/W_Softmax': [vocab_size_symbol, hidden_dim_symbol],
        'Model/Embedding_1/EmbeddingWeights':
            [vocab_size_symbol, hidden_dim_symbol],
    }

    print('Binding variables')

    # HACK: For now, manually set GatherNd op shapes. Later, implement GatherNd
    gnd_op = graph.opsByName['Model/SkipGramSampler/GatherNd']
    gnd_op.outputs[0].mergeShape([
        subbatch_size_symbol,
        num_skips_symbol * (sequence_length_symbol - 2 * skip_window_symbol)
    ])

    graph.bindShapesAndPropagate(bind_dict,
                                 warn_if_ill_defined=(not is_pytest_run),
                                 make_symbolic=True)
    assert graph.isValid()

    if not is_pytest_run:
        print('\n\nCleaned Graph:\n{}'.format(graph))

    print('\n\nBound values')

    bind_subs = {
        graph_iters_symbol: 1,
        hidden_dim_symbol: base_hidden_dim,
        sequence_length_symbol: base_sequence_length,
        subbatch_size_symbol: base_subbatch_size,
        vocab_size_symbol: base_vocab_size,
        skip_window_symbol: base_skip_window,
        num_skips_symbol: base_num_skips,
        nce_samples_symbol: base_nce_samples,
    }

    # Verify parameter counts first
    parameters = graph.calcModelParameters()

    correct_params = 32043205
    correct_flops = 21148823
    correct_bytes = 23762537
    correct_total_footprint = 137949925
    print('Symbol associations: {}\n'.format(bind_subs))

    # Calculate model parameter count
    resolved_params = parameters.subs(bind_subs)
    try:
        resolved_params = int(resolved_params)
    except Exception:
        print('ERROR: resolved_params should be int, but is {} = {}'.format(
            type(resolved_params), resolved_params))
    assert resolved_params == correct_params, \
        'Incorrect model params: {}'.format(resolved_params)
    print('Parameters: {}\nWith specified dims: {}\n'.format(
        parameters, resolved_params))

    # Calculate algorithmic Flops
    alg_flops = graph.calcAlgFlops()
    resolved_flops = alg_flops.subs(bind_subs)
    try:
        resolved_flops = int(resolved_flops)
    except Exception:
        print('ERROR: resolved_flops should be int, but is {} = {}'.format(
            type(resolved_flops), resolved_flops))
    assert resolved_flops == correct_flops, \
        'Incorrect algorithmic flops: {}'.format(resolved_flops)
    print('Algorithmic Flops: {}\nWith specified dims: {}\n'.format(
        alg_flops, resolved_flops))

    # Calculate algorithmic Bytes accessed
    alg_bytes = graph.calcAlgBytes()
    resolved_bytes = alg_bytes.subs(bind_subs)
    try:
        resolved_bytes = int(resolved_bytes)
    except Exception:
        print('ERROR: resolved_bytes should be int, but is {} = {}'.format(
            type(resolved_bytes), resolved_bytes))
    assert resolved_bytes == correct_bytes, \
        'Incorrect algorithmic bytes: {}'.format(resolved_bytes)
    print('Alg bytes accessed: {}\nWith specified dims: {}\n'.format(
        alg_bytes, resolved_bytes))

    # Calculate total memory footprint
    alg_footprint = graph.calcAlgFootprint()
    resolved_footprint = alg_footprint.subs(bind_subs)
    try:
        resolved_footprint = int(resolved_footprint)
    except Exception:
        print('ERROR: resolved_footprint should be int, but is {} = {}'.format(
            type(resolved_footprint), resolved_footprint))
    assert resolved_footprint == correct_total_footprint, \
        'Incorrect algorithmic footprint: {}'.format(resolved_footprint)
    print('Alg mem footprint: {}\nWith specified dims: {}\n'.format(
        alg_footprint, resolved_footprint))

    # Calculate minimal memory footprint
    alg_min_footprint = graph.calcMinimalFootprint(symbol_subs=bind_subs)
    print('Alg minimal footprint (With specified dims): {}\n'.format(
        alg_min_footprint))

    # Calculate algorithmic IO per step
    total_io_footprint = 0
    for op in graph.getPlaceholders():
        total_io_footprint += op.calcAlgFootprint()
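    # If the graph exposed no placeholders, the sum above would remain the
    # plain int 0 with no .subs() method, hence the type check below.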
    if isinstance(total_io_footprint, int):
        resolved_io_footprint = total_io_footprint
    else:
        resolved_io_footprint = total_io_footprint.subs(bind_subs)
    print('Alg IO footprint: {}\nWith specified dims: {}\n'.format(
        total_io_footprint, resolved_io_footprint))

    if not is_pytest_run:
        print('VERBOSE ALGORITHMIC FLOPS:')
        graph.calcAlgFlops(verbose=True)
        print('')
        print('VERBOSE ALGORITHMIC BYTES:')
        graph.calcAlgBytes(verbose=True)
        print('')
        print('VERBOSE ALGORITHMIC FOOTPRINT:')
        graph.calcAlgFootprint(verbose=True)
        print('')

    # HACKY WAY TO SAVE MODELS FOR NOW!
    pickle.dump(graph, open(
        'catamount/frameworks/example_graphs/tensorflow/full_models/language_models/graph_word2vec.p',
        'wb'))

    if is_pytest_run:
        return

    print('\n\n======= Algorithmic graph-level analytics: =======')

    hidden_dims = [1, 2, 3, 4, 5, 6, 7, 9, 10, 12, 14, 18, 20, 25, 28, 35, 40,
                   50, 56, 69, 78, 86, 96, 108, 119, 123, 133, 148, 163, 182,
                   202, 221, 246, 273, 297, 329, 330, 364, 396, 436, 437, 520,
                   572, 617, 676, 740, 796, 869, 948, 1017, 1106, 1202, 1286,
                   1394, 1510, 1611, 1742, 1882, 2004, 2161, 2476, 3040, 3714,
                   4520, 5478, 6628, 8019, 9702, 11739, 14204, 17186, 20795,
                   25161, 30444, 36837, 38100]

    bind_subs.pop(hidden_dim_symbol)
    resolved_params = parameters.subs(bind_subs)

    print('Symbol associations: {}\n'.format(bind_subs))

    print('Algorithmic Flops by hidden dimension, params, and per-batch-sample:')
    resolved_flops = alg_flops.subs(bind_subs)
    for hid_dim in hidden_dims:
        graph_params = resolved_params.subs({hidden_dim_symbol: hid_dim})
        graph_flops = resolved_flops.subs({hidden_dim_symbol: hid_dim})
        graph_flops_per_sample = float(graph_flops) / \
            bind_subs[subbatch_size_symbol]
        print('{}\t{}\t{}\t{}'.format(hid_dim, graph_params, graph_flops,
                                      int(graph_flops_per_sample)))

    print('\nAlgorithmic bytes accessed by hidden dimension, params:')
    resolved_bytes = alg_bytes.subs(bind_subs)
    for hid_dim in hidden_dims:
        graph_params = resolved_params.subs({hidden_dim_symbol: hid_dim})
        graph_bytes = resolved_bytes.subs({hidden_dim_symbol: hid_dim})
        print('{}\t{}\t{}'.format(hid_dim, graph_params, graph_bytes))

    print('\nAlgorithmic total memory footprint by hidden dimension, params:')
    resolved_footprint = alg_footprint.subs(bind_subs)
    for hid_dim in hidden_dims:
        graph_params = resolved_params.subs({hidden_dim_symbol: hid_dim})
        graph_footprint = resolved_footprint.subs({hidden_dim_symbol: hid_dim})
        print('{}\t{}\t{}'.format(hid_dim, graph_params, graph_footprint))

    print('\nAlgorithmic minimal memory footprint by hidden dimension, params:')
    full_subs = dict(bind_subs)
    for hid_dim in hidden_dims:
        graph_params = resolved_params.subs({hidden_dim_symbol: hid_dim})
        full_subs[hidden_dim_symbol] = hid_dim
        graph_min_foot = graph.calcMinimalFootprint(symbol_subs=full_subs)
        print('{}\t{}\t{}'.format(hid_dim, graph_params, graph_min_foot))
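# Hypothetical top-level driver for all three analyses (sketch only; the real
# harness sets `is_pytest_run` and selects depths):
#
#     if __name__ == '__main__':
#         is_pytest_run = False
#         for depth in [18, 34, 50, 101, 152]:
#             run_tf_image_resnet(depth)
#         run_tf_speech_attention()
#         run_tf_w2v_model()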