def test_tf_simple_while_loop():
    graph = catamount.frameworks.tensorflow.import_graph(tf_example_filename)
    assert graph.isValid()
    algorithmic_flops = graph.calcAlgFlops()

    while_iters = utils.getIntSymbolFromString('while/LoopCond_block::iters')
    wba_dim_0 = utils.getIntSymbolFromString('while/body/add_1:0::dim_0')
    correct_alg_flops = while_iters * (wba_dim_0 + 2)

    print('Loaded Flops test:')
    print('    Catamount: {}'.format(algorithmic_flops))
    print('    Correct:   {}'.format(correct_alg_flops))
    assert sympy.simplify(algorithmic_flops - correct_alg_flops) == 0, \
        'Initial alg flops incorrect!\n  Expecting: {}\n  Calculated: {}' \
        .format(correct_alg_flops, algorithmic_flops)

    # Now, bind tensor names in the graph and verify that the algorithmic
    # Flop counts reflect the new name bindings
    batch_size = utils.getIntSymbolFromString('batch_size')
    # NOTE: This also works: batch_size = 'batch_size'
    # Bind placeholder (a) output dimension 0 to the name batch_size
    bind_dict = {'a': [batch_size, 1]}
    graph.bindTensorShapeDimensions(bind_dict)

    algorithmic_flops = graph.calcAlgFlops()

    # Update the algorithmic Flops formula
    correct_alg_flops = correct_alg_flops.subs({wba_dim_0: batch_size})
    print('Bound Flops test:')
    print('    Catamount: {}'.format(algorithmic_flops))
    print('    Correct:   {}'.format(correct_alg_flops))
    assert sympy.simplify(algorithmic_flops - correct_alg_flops) == 0, \
        'Bound alg flops incorrect!\n  Expecting: {}\n  Calculated: {}' \
        .format(correct_alg_flops, algorithmic_flops)
def test_expanddims_op():
    ''' Specify graphs with ExpandDimsOps and make sure dimensions
    behave as desired.
    '''
    combos = [([3], 0),
              ([None], 0),
              ([None, None], 0),
              ([None, None], 1),
              ([None, None], 2),
              ([None, None], -1),
              ([None, None], -2),
              ([None, None], -3)]

    for combo in combos:
        graph = Graph()
        with graph.asDefault():
            ph_dims, expand_dim = combo
            if isinstance(ph_dims, list):
                ed_out_dims = list(ph_dims)
                insert_dim = expand_dim
                if insert_dim < 0:
                    # Negative axes count from the back of the output
                    # shape, which has rank len(ed_out_dims) + 1
                    insert_dim += len(ed_out_dims) + 1
                ed_out_dims.insert(insert_dim, 1)
            else:
                ed_out_dims = [ph_dims]
            print('Testing expand dims with in_dims {}, expand dim {} to {}'
                  .format(ph_dims, expand_dim, ed_out_dims))

            # Build model
            in_ph = placeholder('in', ph_dims)
            expanddims_out = expanddims('expanddims', ed_out_dims, in_ph,
                                        axis=expand_dim)

            assert graph.isValid()
            feed_dict = {}
            if isinstance(ph_dims, list):
                for idx in range(len(ph_dims)):
                    if ph_dims[idx] is None:
                        ph_dims[idx] = utils.getIntSymbolFromString(
                                           'in::dim_{}'.format(idx))
            else:
                if ph_dims is None:
                    ph_dims = utils.getIntSymbolFromString('in::dim_0')
            feed_dict['in'] = ph_dims
            print('    Feed dict: {}'.format(feed_dict))

            graph.bindTensorShapeDimensions(feed_dict)

            assert expanddims_out.shape == TensorShape(ed_out_dims)
        reset_symbols()
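# Aside: a minimal, dependency-free sketch of the negative-axis mapping the
# test above relies on. ExpandDims output has rank (in_rank + 1), so negative
# axes are normalized against that. The helper below is hypothetical and only
# illustrates the arithmetic; it is not part of Catamount.
def _sketch_normalize_expand_axis(axis, in_rank):
    # Negative axes count from the back of the *output* shape
    return axis + in_rank + 1 if axis < 0 else axis

assert _sketch_normalize_expand_axis(-1, 2) == 2  # append a trailing 1
assert _sketch_normalize_expand_axis(-3, 2) == 0  # prepend a leading 1
assert _sketch_normalize_expand_axis(1, 2) == 1   # non-negative unchanged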
def test_tf_dynamic_rnn():
    graph = catamount.frameworks.tensorflow.import_graph(tf_example_filename)
    assert graph.isValid()
    algorithmic_flops = graph.calcAlgFlops()

    rwb_iters = utils.getIntSymbolFromString('rnn/while/LoopCond_block::iters')
    ba_0 = utils.getIntSymbolFromString(
               'rnn/while/basic_rnn_cell/BiasAdd:0::dim_0')
    mm_0 = utils.getIntSymbolFromString(
               'rnn/while/basic_rnn_cell/MatMul:0::dim_0')
    th_0 = utils.getIntSymbolFromString(
               'rnn/while/basic_rnn_cell/Tanh:0::dim_0')
    correct_alg_flops = rwb_iters * \
        (24 * ba_0 + 2304 * mm_0 + 144 * th_0 + 5) + 2305

    print('Loaded Flops test:')
    print('    Catamount: {}'.format(algorithmic_flops))
    print('    Correct:   {}'.format(correct_alg_flops))
    assert sympy.simplify(algorithmic_flops - correct_alg_flops) == 0, \
        'Initial alg flops incorrect!\n  Expecting: {}\n  Calculated: {}' \
        .format(correct_alg_flops, algorithmic_flops)

    # Now, bind tensor names in the graph and verify that the algorithmic
    # Flop counts reflect the new name bindings
    batch_size = utils.getIntSymbolFromString('batch_size')
    # NOTE: This also works: batch_size = 'batch_size'
    # Bind placeholders (a and init_state) output dimensions to symbol names
    bind_dict = {'a': ['batch_size', 'seq_length', 'hidden_dim'],
                 'init_state': ['batch_size', 'hidden_dim']}
    graph.bindTensorShapeDimensions(bind_dict)

    algorithmic_flops = graph.calcAlgFlops()

    # Update the algorithmic Flops formula
    correct_alg_flops = correct_alg_flops.subs({ba_0: batch_size,
                                                mm_0: batch_size,
                                                th_0: batch_size})
    print('Bound Flops test:')
    print('    Catamount: {}'.format(algorithmic_flops))
    print('    Correct:   {}'.format(correct_alg_flops))
    assert sympy.simplify(algorithmic_flops - correct_alg_flops) == 0, \
        'Bound alg flops incorrect!\n  Expecting: {}\n  Calculated: {}' \
        .format(correct_alg_flops, algorithmic_flops)
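# Aside: a hedged back-of-envelope check of the per-iteration constants above,
# assuming (not stated in the test) that the RNN cell's hidden and input
# dimensions are both 24, so the concatenated MatMul input is 48 wide. The
# implied per-element tanh cost (144 / 24 = 6 Flops) is read off from the
# constants, not asserted as a Catamount convention.
def _sketch_rnn_cell_flop_constants():
    hidden, concat = 24, 48
    matmul_flops_per_row = 2 * concat * hidden  # 2 * 48 * 24 = 2304
    biasadd_flops_per_row = hidden              # 24, one add per element
    tanh_flops_per_row = 6 * hidden             # 144, i.e. 6 Flops/element
    assert (matmul_flops_per_row, biasadd_flops_per_row,
            tanh_flops_per_row) == (2304, 24, 144)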
def propagateShapes(self, make_symbolic=False):
    self.debugAssert(len(self._inputs) == 2)
    self.debugAssert(len(self._outputs) == 1)

    # Assume that there are multiple workers contributing to this
    # collective operation and their matrix sizes are the same as the
    # first input tensor passed in here. Create a symbol to represent
    # the number of participating workers
    num_workers_str = '{}::num_workers'.format(self.name)
    num_workers_symbol = utils.getIntSymbolFromString(num_workers_str)

    # TODO (Joel): We could take another input tensor to specify the axis
    # on which to concatenate values. For now, axis = 0
    axis = 0
    final_shape = []
    for idx in range(len(self._inputs[0].shape.dims)):
        dim = self._inputs[0].shape.getDimension(idx)
        if idx == axis:
            # Manipulate the dimension: make the value None (it is
            # necessarily symbolic), and set the symbol to reflect
            # multiple workers
            new_dim = Dimension(None)
            new_symbol = dim.symbol * num_workers_symbol
            new_dim.setSymbolOrName(new_symbol)
            final_shape.append(new_dim)
        else:
            final_shape.append(dim)
    self._outputs[0].mergeShape(final_shape, make_symbolic=make_symbolic)
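# Aside: a minimal sketch of the shape transformation above, using sympy
# directly. The op name 'allgather' and the input symbol names are
# illustrative assumptions, not from the source: with axis = 0, the first
# dimension is scaled by the op's num_workers symbol.
def _sketch_collective_output_shape():
    import sympy
    batch = sympy.Symbol('in:0::dim_0', integer=True)
    hidden = sympy.Symbol('in:0::dim_1', integer=True)
    num_workers = sympy.Symbol('allgather::num_workers', integer=True)
    in_shape = [batch, hidden]
    out_shape = [in_shape[0] * num_workers, in_shape[1]]  # axis 0 scaled
    assert out_shape == [batch * num_workers, hidden]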
def test_tf_load_and_calculate():
    graph = catamount.frameworks.tensorflow.import_graph(tf_example_filename)
    assert graph.isValid()
    algorithmic_flops = graph.calcAlgFlops()

    add_dim_0 = utils.getIntSymbolFromString('add:0::dim_0')
    matmul_dim_0 = utils.getIntSymbolFromString('matmul:0::dim_0')
    mul_dim_0 = utils.getIntSymbolFromString('mul:0::dim_0')
    correct_alg_flops = 256 * add_dim_0 + \
                        65536 * matmul_dim_0 + \
                        256 * mul_dim_0 + \
                        98307

    print('Loaded Flops test:')
    print('    Catamount: {}'.format(algorithmic_flops))
    print('    Correct:   {}'.format(correct_alg_flops))
    assert sympy.simplify(algorithmic_flops - correct_alg_flops) == 0, \
        'Initial alg flops incorrect!\n  Expecting: {}\n  Calculated: {}' \
        .format(correct_alg_flops, algorithmic_flops)

    # Now, bind tensor names in the graph and verify that the algorithmic
    # Flop counts reflect the new name bindings
    batch_size = utils.getIntSymbolFromString('batch_size')
    # NOTE: This also works: batch_size = 'batch_size'
    # Bind placeholders (a and b) output dimensions 0 to name batch_size
    bind_dict = {'a': [batch_size, None], 'b': [batch_size, None]}
    graph.bindTensorShapeDimensions(bind_dict)

    algorithmic_flops = graph.calcAlgFlops()

    # Update the algorithmic Flops formula
    correct_alg_flops = correct_alg_flops.subs({add_dim_0: batch_size,
                                                matmul_dim_0: batch_size,
                                                mul_dim_0: batch_size})
    print('Bound Flops test:')
    print('    Catamount: {}'.format(algorithmic_flops))
    print('    Correct:   {}'.format(correct_alg_flops))
    assert sympy.simplify(algorithmic_flops - correct_alg_flops) == 0, \
        'Bound alg flops incorrect!\n  Expecting: {}\n  Calculated: {}' \
        .format(correct_alg_flops, algorithmic_flops)
def __init__(self, name):
    super(CandidateSamplerOp, self).__init__(name)
    # TODO (Joel): Read these from compute graph op attributes
    self.setNumTrue(1)
    self.setNumSampled(None)
    # TODO: Depending on the generator, there should be some small number
    # of Flops per sampled element. Using (incorrect) 1 for now...
    self._flops_per_element = 1
    samps_name = '{}::rand_samps'.format(self.name)
    self._num_samples_symbol = utils.getIntSymbolFromString(samps_name)
def add_symbols(name, out_shape):
    # print('  Adding symbols for {}: out_shape: {}'.format(name, out_shape))
    global symbol_table
    global subs_table

    # NOTE: add_symbol reads sym_name from the enclosing scope; each call
    # site below assigns sym_name before calling it
    def add_symbol(symbol, dim):
        assert sym_name not in symbol_table.keys()
        symbol_table[sym_name] = symbol
        # print('Added symbol name {} with sym {}'.format(sym_name, symbol))
        if isinstance(dim, Dimension):
            dim = dim.value
        if dim is not None:
            subs_table[symbol] = dim

    if isinstance(out_shape, list):
        for idx, dim in enumerate(out_shape):
            sym_name = '{}::dim_{}'.format(name, idx)
            add_symbol(utils.getIntSymbolFromString(sym_name), dim)
    else:
        sym_name = '{}::unk'.format(name)
        add_symbol(utils.getIntSymbolFromString(sym_name), out_shape)
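# Aside: a self-contained sketch of the bookkeeping above (sympy only; the
# tensor name 'matmul:0' and its shape are hypothetical). Known dimensions
# land in the substitution table; unknown ones stay purely symbolic.
def _sketch_add_symbols():
    import sympy
    sym_table, sub_table = {}, {}
    for idx, dim in enumerate([None, 64]):
        sym_name = '{}::dim_{}'.format('matmul:0', idx)
        symbol = sympy.Symbol(sym_name, integer=True)
        sym_table[sym_name] = symbol
        if dim is not None:
            sub_table[symbol] = dim  # known dims become substitutions
    assert list(sym_table) == ['matmul:0::dim_0', 'matmul:0::dim_1']
    assert sub_table[sym_table['matmul:0::dim_1']] == 64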
def propagateShapes(self, make_symbolic=False):
    self.debugAssert(len(self._inputs) == 2)
    self.debugAssert(len(self._outputs) == 1)

    # Cannot propagate shapes until the first input shape is fully symbolic
    if not self._inputs[0].shape.isFullySymbolic():
        return
    self.debugAssert(self._inputs[0].shape.rank == 2)

    num_samples = self._inputs[1].value
    if num_samples is None:
        samps_name = '{}::rand_samps'.format(self.name)
        num_samples = utils.getIntSymbolFromString(samps_name)
    out_shape = []
    out_shape.append(self._inputs[0].shape.getDimension(0))
    out_shape.append(num_samples)
    self._outputs[0].shape.mergeShape(out_shape,
                                      make_symbolic=make_symbolic)
def calcAlgFlops(self):
    self.debugAssert(len(self._inputs) == 2)
    self.debugAssert(len(self._outputs) == 1)

    # Steps in multinomial sampling:
    # 1) Draw uniform random sample, "noises", of size
    #    [batch_size, num_samples, num_classes]
    num_samples = self._inputs[1].value
    if num_samples is None:
        samps_name = '{}::rand_samps'.format(self.name)
        num_samples = utils.getIntSymbolFromString(samps_name)
    in_0_shape = self._inputs[0].shape
    full_shape_elts = in_0_shape.numElements() * num_samples
    total_flops = full_shape_elts
    # 2) Calculate scores = logits - log(-log(noises)) with broadcasting
    total_flops += 3 * full_shape_elts
    # 3) Minimum reduction along classes dimension
    total_flops += full_shape_elts
    return total_flops
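# Aside: a worked example of the accounting above with hypothetical sizes
# (logits [batch=8, classes=100], num_samples=4). Steps 1-3 charge
# 1x + 3x + 1x = 5x the full [batch, samples, classes] element count.
def _sketch_multinomial_flops():
    batch, classes, num_samples = 8, 100, 4
    full_shape_elts = batch * classes * num_samples  # 3200
    total_flops = full_shape_elts           # 1) draw noises
    total_flops += 3 * full_shape_elts      # 2) logits - log(-log(noises))
    total_flops += full_shape_elts          # 3) reduction over classes
    assert total_flops == 5 * full_shape_elts == 16000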
def propagateShapes(self, make_symbolic=False):
    self.debugAssert(len(self._inputs) == 1)
    # Second output (outputs[1]) is the true expected count and has shape
    # equal to the input tensor unless the num_true attribute is changed
    self.debugAssert(len(self._outputs) == 3)

    if self._num_true != 1:
        self.notImplemented('CandidateSamplerOp propagateShapes '
                            'num_true != 1')
    self._outputs[1].shape.mergeShape(self._inputs[0].shape,
                                      make_symbolic=make_symbolic)

    num_samples = None
    if self._num_sampled is None:
        samps_name = '{}::rand_samps'.format(self.name)
        num_samples = utils.getIntSymbolFromString(samps_name)
    else:
        self.notImplemented('CandidateSamplerOp: propagateShapes '
                            'num_sampled != None')
    self._outputs[0].shape.mergeShape([num_samples])
    self._outputs[2].shape.mergeShape([num_samples])
def test_concat_op():
    ''' Specify graphs with concat operations and make sure dimensions
    behave as desired.
    '''
    combos = [([[None, None], [None, None]], 0),
              ([[None, None], [None, None]], 1),
              ([[3, None], [3, None]], 0),
              ([[3, None], [3, None]], 1),
              ([[3, 7], [6, 7]], 0),
              ([[3, 15], [3, None]], 1),
              ([[3, None, 7, 15], [3, 15, 7, None]], 0),
              ([[3, None, 7, 15], [3, 15, 7, None]], 1),
              ([[3, None, 7, 15], [3, 15, 7, None]], 2),
              ([[3, None, 7, 15], [3, 15, 7, None]], 3),
              ([[3, None, 7, 15], [3, 15, 7, None], [None, 15, 7, 30]], 3)]

    for combo in combos:
        graph = Graph()
        with graph.asDefault():
            ph_dims, axis = combo
            print('Testing concat with in dims {}, axis {}'
                  .format(ph_dims, axis))

            # Build model
            in_phs = []
            rank = None
            for idx, ph_dim in enumerate(ph_dims):
                ph_name = 'in_{}'.format(idx)
                in_phs.append(placeholder(ph_name, ph_dim))
                if rank is None:
                    rank = in_phs[idx].shape.rank
                else:
                    assert rank == in_phs[idx].shape.rank
            concat_out = concat('concat', [None] * rank, in_phs, axis=axis)

            assert graph.isValid()
            feed_dict = {}
            out_c_dim = Dimension(0)
            for in_ph, ph_dim in zip(in_phs, ph_dims):
                in_ph_dims = []
                for idx, dim in enumerate(ph_dim):
                    append_dim_sym = None
                    if dim is None:
                        dim_name = 'bind_{}_{}'.format(in_ph.name, idx)
                        append_dim_sym = utils.getIntSymbolFromString(dim_name)
                    else:
                        append_dim_sym = dim
                    in_ph_dims.append(append_dim_sym)
                    if idx == axis:
                        append_dim = Dimension(None)
                        append_dim.setSymbolOrName(append_dim_sym)
                        out_c_dim += append_dim
                feed_dict[in_ph.name] = in_ph_dims
            print('    Feed dict: {}'.format(feed_dict))

            graph.bindTensorShapeDimensions(feed_dict)

            out_dims = TensorShape(in_phs[-1].shape.dims)
            out_dims.dims[axis] = out_c_dim
            check_symbol_table = {}
            for idx in range(concat_out.shape.rank):
                c_out_dim = concat_out.shape.getDimension(idx).symbol
                out_dim = out_dims.getDimension(idx).symbol
                if isinstance(out_dim, sympy.Symbol) and \
                   isinstance(c_out_dim, int):
                    if out_dim not in check_symbol_table.keys():
                        check_symbol_table[out_dim] = c_out_dim
                        out_dim = c_out_dim
                    else:
                        assert c_out_dim == check_symbol_table[out_dim]
                print('    Catamount dim[{}]: {}'.format(idx, c_out_dim))
                print('    Correct dim[{}]:   {}'.format(idx, out_dim))
                assert (sympy.simplify(c_out_dim - out_dim) == 0), \
                    'Concat dim[{}] incorrect!\n  Expecting: {}\n' \
                    '  Calculated: {}'.format(idx, out_dim, c_out_dim)
        reset_symbols()
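# Aside: the expected concat output dimension is just the sum of the inputs'
# dimensions along the concat axis. A standalone sympy sketch (the symbol
# names follow the test's 'bind_<name>_<idx>' convention; the bound sizes
# 3 and 6 are hypothetical):
def _sketch_concat_axis_dim():
    import sympy
    d0 = sympy.Symbol('bind_in_0_0', integer=True)
    d1 = sympy.Symbol('bind_in_1_0', integer=True)
    out_axis_dim = d0 + d1
    # Binding both symbols to concrete sizes resolves the output dimension
    assert out_axis_dim.subs({d0: 3, d1: 6}) == 9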
def __init__(self, name):
    super(AllreduceOp, self).__init__(name)
    num_workers_str = '{}::num_workers'.format(self.name)
    self._workers_symbol = utils.getIntSymbolFromString(num_workers_str)
def run_tf_speech_attention():
    global is_pytest_run
    graph_meta = 'catamount/frameworks/example_graphs/tensorflow/full_models/speech_attention/model.ckpt.meta'

    graph = catamount.frameworks.tensorflow.import_graph(graph_meta)
    assert graph.isValid()

    # HAX: NEED TO MANUALLY REMOVE SOME?! WHY?
    remove_ops = ['DevArgmaxWERChecker/Less',
                  'DevLossChecker/Less',
                  'DevArgmaxWERChecker/best_dev',
                  'DevLossChecker/best_dev']
    for op_name in remove_ops:
        op = graph.opsByName[op_name]
        graph.removeOp(op)
    assert graph.isValid()

    # Remove ops that are not executed during a standard training step:
    graph_ops = list(graph._ops_by_name.values())
    for op in graph_ops:
        # Ops in attn_model_[1-3] are used only for inference
        if 'attn_model_1' in op.name or \
           'attn_model_2' in op.name or \
           'attn_model_3' in op.name:
            graph.removeOp(op)
    assert graph.isValid()

    print('Initial graph:\n{}\n'.format(graph))
    init_params = graph.calcModelParameters()
    print('Initial parameters: {}'.format(init_params))
    print('Initial Flops: {}\n'.format(graph.calcAlgFlops()))

    print('Placeholders:')
    for op in graph.getPlaceholders():
        print(op.debugString())
    print('')

    # Set up symbols to name dimensions
    audio_features_symbol = utils.getPositiveIntSymbolFromString('audio_features')
    encoder_steps_symbol = utils.getPositiveIntSymbolFromString('encoder_steps')
    decoder_steps_symbol = utils.getPositiveIntSymbolFromString('decoder_steps')
    subbatch_size_symbol = utils.getPositiveIntSymbolFromString('subbatch_size')
    attn_dim_symbol = utils.getPositiveIntSymbolFromString('attn_dim')
    attn_hidden_dim_symbol = utils.getPositiveIntSymbolFromString('attn_hidden_dim')
    dec_hidden_dim_symbol = utils.getPositiveIntSymbolFromString('dec_hidden_dim')
    enc_hidden_dim_symbol = utils.getPositiveIntSymbolFromString('enc_hidden_dim')
    graph_iters_symbol = utils.getIntSymbolFromString('graph::iters')
    output_vocab_symbol = utils.getPositiveIntSymbolFromString('output_vocab')
    conv_width_symbol = utils.getPositiveIntSymbolFromString('conv_width')
    num_conv_filters_symbol = utils.getPositiveIntSymbolFromString('num_conv_filters')

    # Convert these constant dimensions to symbols
    base_encoder_steps = 300
    base_decoder_steps = 300
    base_subbatch_size = 32
    base_output_vocab = 31
    base_audio_features = 40
    base_conv_width = 53
    base_attn_dim = 137
    base_attn_hidden_dim = 509
    base_dec_hidden_dim = 571
    base_enc_hidden_dim = 1051
    # Input + recurrent state
    base_enc_input_dim = 1091
    enc_input_dim_symbol = audio_features_symbol + enc_hidden_dim_symbol
    base_dec_attn_rec = 2133
    dec_attn_rec_symbol = 2 * enc_hidden_dim_symbol + output_vocab_symbol
    base_attn_cell_inputs = 2611
    attn_cell_inputs_symbol = 2 * enc_hidden_dim_symbol + \
                              attn_hidden_dim_symbol
    base_attn_cell_in_dim = 2642
    attn_cell_in_dim_symbol = 2 * enc_hidden_dim_symbol + \
                              output_vocab_symbol + attn_hidden_dim_symbol
    base_dec_attn_dim = 3182
    dec_attn_dim_symbol = attn_hidden_dim_symbol + \
                          2 * enc_hidden_dim_symbol + dec_hidden_dim_symbol

    bind_dict = {
        # Placeholders
        'attn_model/input_seq':
            [encoder_steps_symbol, subbatch_size_symbol,
             audio_features_symbol],
        'attn_model/input_len': [subbatch_size_symbol],
        'attn_model/output_seq':
            [decoder_steps_symbol, subbatch_size_symbol],
        'attn_model/output_mask':
            [decoder_steps_symbol, subbatch_size_symbol],
        # Variables
        'InputNormalizer/means': [audio_features_symbol],
        'InputNormalizer/std': [audio_features_symbol],
        'attn_model/AffineAttentionStateNN/W':
            [2 * enc_hidden_dim_symbol, attn_dim_symbol],
        'attn_model/AffineAttentionStateNN/b': [attn_dim_symbol],
        'attn_model/AffineOutputProjection/W':
            [dec_hidden_dim_symbol, output_vocab_symbol],
        'attn_model/AffineOutputProjection/b': [output_vocab_symbol],
        'attn_model/Decoder/attn_model/attention_cell/biases':
            [4 * attn_hidden_dim_symbol],
        'attn_model/Decoder/attn_model/attention_cell/weights':
            [attn_hidden_dim_symbol + 2 * enc_hidden_dim_symbol +
             output_vocab_symbol, 4 * attn_hidden_dim_symbol],
        'attn_model/Decoder/attn_model/decoder_cell/biases':
            [4 * dec_hidden_dim_symbol],
        'attn_model/Decoder/attn_model/decoder_cell/weights':
            [attn_hidden_dim_symbol + dec_hidden_dim_symbol +
             2 * enc_hidden_dim_symbol, 4 * dec_hidden_dim_symbol],
        'attn_model/HybridAttentionContext/Q':
            [conv_width_symbol, 1, num_conv_filters_symbol],
        'attn_model/HybridAttentionContext/U':
            [1, num_conv_filters_symbol, attn_dim_symbol],
        'attn_model/HybridAttentionContext/W':
            [2 * attn_hidden_dim_symbol, attn_dim_symbol],
        'attn_model/HybridAttentionContext/b': [attn_dim_symbol],
        'attn_model/HybridAttentionContext/w': [attn_dim_symbol],
        'attn_model/StackedEncoder/Layer0/RNNEncoder/bidirectional_rnn/bw/basic_lstm_cell/bias':
            [4 * enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer0/RNNEncoder/bidirectional_rnn/bw/basic_lstm_cell/kernel':
            [audio_features_symbol + enc_hidden_dim_symbol,
             4 * enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer0/RNNEncoder/bidirectional_rnn/fw/basic_lstm_cell/bias':
            [4 * enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer0/RNNEncoder/bidirectional_rnn/fw/basic_lstm_cell/kernel':
            [audio_features_symbol + enc_hidden_dim_symbol,
             4 * enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer2/RNNEncoder/bidirectional_rnn/bw/basic_lstm_cell/bias':
            [4 * enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer2/RNNEncoder/bidirectional_rnn/bw/basic_lstm_cell/kernel':
            [3 * enc_hidden_dim_symbol, 4 * enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer2/RNNEncoder/bidirectional_rnn/fw/basic_lstm_cell/bias':
            [4 * enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer2/RNNEncoder/bidirectional_rnn/fw/basic_lstm_cell/kernel':
            [3 * enc_hidden_dim_symbol, 4 * enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer4/RNNEncoder/bidirectional_rnn/bw/basic_lstm_cell/bias':
            [4 * enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer4/RNNEncoder/bidirectional_rnn/bw/basic_lstm_cell/kernel':
            [3 * enc_hidden_dim_symbol, 4 * enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer4/RNNEncoder/bidirectional_rnn/fw/basic_lstm_cell/bias':
            [4 * enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer4/RNNEncoder/bidirectional_rnn/fw/basic_lstm_cell/kernel':
            [3 * enc_hidden_dim_symbol, 4 * enc_hidden_dim_symbol],
        # Constants
        'attn_model/AttentionModel/gradients/attn_model/Decoder/while/MatMul/Enter_grad/b_acc':
            [dec_hidden_dim_symbol, output_vocab_symbol],
        'attn_model/AttentionModel/gradients/attn_model/Decoder/while/add/Enter_grad/b_acc':
            [output_vocab_symbol],
        'attn_model/AttentionModel/gradients/attn_model/Decoder/while/attn_model/MatMul/Enter_grad/b_acc':
            [2 * attn_hidden_dim_symbol, attn_dim_symbol],
        'attn_model/AttentionModel/gradients/attn_model/Decoder/while/attn_model/add_2/Enter_grad/b_acc':
            [attn_dim_symbol],
        'attn_model/AttentionModel/gradients/attn_model/Decoder/while/attn_model/attention_cell/BiasAdd/Enter_grad/b_acc':
            [4 * attn_hidden_dim_symbol],
        'attn_model/AttentionModel/gradients/attn_model/Decoder/while/attn_model/attention_cell/attention_cell/add/Enter_grad/b_acc':
            [attn_hidden_dim_symbol + 2 * enc_hidden_dim_symbol +
             output_vocab_symbol, 4 * attn_hidden_dim_symbol],
        'attn_model/AttentionModel/gradients/attn_model/Decoder/while/attn_model/conv1d/ExpandDims_1/Enter_grad/b_acc':
            [conv_width_symbol, 1, 4],
        'attn_model/AttentionModel/gradients/attn_model/Decoder/while/attn_model/conv1d_1/ExpandDims_1/Enter_grad/b_acc':
            [1, 4, attn_dim_symbol],
        'attn_model/AttentionModel/gradients/attn_model/Decoder/while/attn_model/decoder_cell/BiasAdd/Enter_grad/b_acc':
            [4 * dec_hidden_dim_symbol],
        'attn_model/AttentionModel/gradients/attn_model/Decoder/while/attn_model/decoder_cell/decoder_cell/add/Enter_grad/b_acc':
            [attn_hidden_dim_symbol + dec_hidden_dim_symbol +
             2 * enc_hidden_dim_symbol, 4 * dec_hidden_dim_symbol],
        'attn_model/AttentionModel/gradients/attn_model/Decoder/while/attn_model/mul/Enter_grad/b_acc':
            [attn_dim_symbol],
        'attn_model/AttentionModel/gradients/attn_model/StackedEncoder/Layer0/RNNEncoder/bidirectional_rnn/bw/bw/while/basic_lstm_cell/BiasAdd/Enter_grad/b_acc':
            [4 * enc_hidden_dim_symbol],
        'attn_model/AttentionModel/gradients/attn_model/StackedEncoder/Layer0/RNNEncoder/bidirectional_rnn/bw/bw/while/basic_lstm_cell/MatMul/Enter_grad/b_acc':
            [audio_features_symbol + enc_hidden_dim_symbol,
             4 * enc_hidden_dim_symbol],
        'attn_model/AttentionModel/gradients/attn_model/StackedEncoder/Layer0/RNNEncoder/bidirectional_rnn/fw/fw/while/basic_lstm_cell/BiasAdd/Enter_grad/b_acc':
            [4 * enc_hidden_dim_symbol],
        'attn_model/AttentionModel/gradients/attn_model/StackedEncoder/Layer0/RNNEncoder/bidirectional_rnn/fw/fw/while/basic_lstm_cell/MatMul/Enter_grad/b_acc':
            [audio_features_symbol + enc_hidden_dim_symbol,
             4 * enc_hidden_dim_symbol],
        'attn_model/AttentionModel/gradients/attn_model/StackedEncoder/Layer2/RNNEncoder/bidirectional_rnn/bw/bw/while/basic_lstm_cell/BiasAdd/Enter_grad/b_acc':
            [4 * enc_hidden_dim_symbol],
        'attn_model/AttentionModel/gradients/attn_model/StackedEncoder/Layer2/RNNEncoder/bidirectional_rnn/bw/bw/while/basic_lstm_cell/MatMul/Enter_grad/b_acc':
            [3 * enc_hidden_dim_symbol, 4 * enc_hidden_dim_symbol],
        'attn_model/AttentionModel/gradients/attn_model/StackedEncoder/Layer2/RNNEncoder/bidirectional_rnn/fw/fw/while/basic_lstm_cell/BiasAdd/Enter_grad/b_acc':
            [4 * enc_hidden_dim_symbol],
        'attn_model/AttentionModel/gradients/attn_model/StackedEncoder/Layer2/RNNEncoder/bidirectional_rnn/fw/fw/while/basic_lstm_cell/MatMul/Enter_grad/b_acc':
            [3 * enc_hidden_dim_symbol, 4 * enc_hidden_dim_symbol],
        'attn_model/AttentionModel/gradients/attn_model/StackedEncoder/Layer4/RNNEncoder/bidirectional_rnn/bw/bw/while/basic_lstm_cell/BiasAdd/Enter_grad/b_acc':
            [4 * enc_hidden_dim_symbol],
        'attn_model/AttentionModel/gradients/attn_model/StackedEncoder/Layer4/RNNEncoder/bidirectional_rnn/bw/bw/while/basic_lstm_cell/MatMul/Enter_grad/b_acc':
            [3 * enc_hidden_dim_symbol, 4 * enc_hidden_dim_symbol],
        'attn_model/AttentionModel/gradients/attn_model/StackedEncoder/Layer4/RNNEncoder/bidirectional_rnn/fw/fw/while/basic_lstm_cell/BiasAdd/Enter_grad/b_acc':
            [4 * enc_hidden_dim_symbol],
        'attn_model/AttentionModel/gradients/attn_model/StackedEncoder/Layer4/RNNEncoder/bidirectional_rnn/fw/fw/while/basic_lstm_cell/MatMul/Enter_grad/b_acc':
            [3 * enc_hidden_dim_symbol, 4 * enc_hidden_dim_symbol],
    }

    # Update constant values
    const_dict = {
        'attn_model/AffineAttentionStateNN/Reshape/shape':
            [-1, 2 * enc_hidden_dim_symbol],
        'attn_model/AffineAttentionStateNN/Reshape_1/shape/2': attn_dim_symbol,
        'attn_model/AttentionEncoderDecoder/Reshape/shape/1': output_vocab_symbol,
        'attn_model/AttentionModel/gradients/attn_model/AffineAttentionStateNN/add_grad/Shape_1':
            [attn_dim_symbol],
        'attn_model/AttentionModel/gradients/attn_model/Decoder/while/add_grad/Shape_1':
            [output_vocab_symbol],
        'attn_model/AttentionModel/gradients/attn_model/Decoder/while/attn_model/add_2_grad/Shape_1':
            [attn_dim_symbol],
        'attn_model/AttentionModel/gradients/attn_model/Decoder/while/attn_model/conv1d/Conv2D_grad/Const':
            [1, conv_width_symbol, 1, num_conv_filters_symbol],
        'attn_model/AttentionModel/gradients/attn_model/Decoder/while/attn_model/conv1d/ExpandDims_1_grad/Shape':
            [conv_width_symbol, 1, num_conv_filters_symbol],
        'attn_model/AttentionModel/gradients/attn_model/Decoder/while/attn_model/conv1d_1/Conv2D_grad/Const':
            [1, 1, num_conv_filters_symbol, attn_dim_symbol],
        'attn_model/AttentionModel/gradients/attn_model/Decoder/while/attn_model/conv1d_1/ExpandDims_1_grad/Shape':
            [1, num_conv_filters_symbol, attn_dim_symbol],
        'attn_model/AttentionModel/gradients/attn_model/Decoder/while/attn_model/mul_grad/Shape_1':
            [attn_dim_symbol],
        'attn_model/Decoder/CustomLSTMCellZeroState/Const':
            [2 * attn_hidden_dim_symbol],
        'attn_model/Decoder/CustomLSTMCellZeroState/Const_1':
            [2 * attn_hidden_dim_symbol],
        'attn_model/Decoder/CustomLSTMCellZeroState_1/Const':
            [2 * dec_hidden_dim_symbol],
        'attn_model/Decoder/CustomLSTMCellZeroState_1/Const_1':
            [2 * dec_hidden_dim_symbol],
        'attn_model/Decoder/while/attn_model/attention_cell/attention_cell/Shape':
            [attn_hidden_dim_symbol + 2 * enc_hidden_dim_symbol +
             output_vocab_symbol, 4 * attn_hidden_dim_symbol],
        'attn_model/Decoder/while/attn_model/decoder_cell/decoder_cell/Shape':
            [attn_hidden_dim_symbol + dec_hidden_dim_symbol +
             2 * enc_hidden_dim_symbol, 4 * dec_hidden_dim_symbol],
        'attn_model/Decoder/while/attn_model/one_hot/depth': output_vocab_symbol,
        'attn_model/Decoder/zeros/shape/1': 2 * enc_hidden_dim_symbol,
        'attn_model/Decoder/zeros_2/shape/1': output_vocab_symbol,
        'attn_model/Reshape/shape': [1, 1, audio_features_symbol],
        'attn_model/Reshape_1/shape': [1, 1, audio_features_symbol],
        'attn_model/Reshape_2/shape/2': 2 * enc_hidden_dim_symbol,
        'attn_model/StackedEncoder/Layer0/RNNEncoder/Reshape/shape/2':
            audio_features_symbol,
        'attn_model/StackedEncoder/Layer0/RNNEncoder/bidirectional_rnn/bw/bw/BasicLSTMCellZeroState/Const':
            [2 * enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer0/RNNEncoder/bidirectional_rnn/bw/bw/BasicLSTMCellZeroState/Const_1':
            [2 * enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer0/RNNEncoder/bidirectional_rnn/bw/bw/Const_1':
            [enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer0/RNNEncoder/bidirectional_rnn/bw/bw/Const_4':
            [enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer0/RNNEncoder/bidirectional_rnn/fw/fw/BasicLSTMCellZeroState/Const':
            [2 * enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer0/RNNEncoder/bidirectional_rnn/fw/fw/BasicLSTMCellZeroState/Const_1':
            [2 * enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer0/RNNEncoder/bidirectional_rnn/fw/fw/Const_1':
            [enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer0/RNNEncoder/bidirectional_rnn/fw/fw/Const_4':
            [enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer2/RNNEncoder/Reshape/shape/2':
            2 * enc_hidden_dim_symbol,
        'attn_model/StackedEncoder/Layer2/RNNEncoder/bidirectional_rnn/bw/bw/BasicLSTMCellZeroState/Const':
            [2 * enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer2/RNNEncoder/bidirectional_rnn/bw/bw/BasicLSTMCellZeroState/Const_1':
            [2 * enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer2/RNNEncoder/bidirectional_rnn/bw/bw/Const_1':
            [enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer2/RNNEncoder/bidirectional_rnn/bw/bw/Const_4':
            [enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer2/RNNEncoder/bidirectional_rnn/fw/fw/BasicLSTMCellZeroState/Const':
            [2 * enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer2/RNNEncoder/bidirectional_rnn/fw/fw/BasicLSTMCellZeroState/Const_1':
            [2 * enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer2/RNNEncoder/bidirectional_rnn/fw/fw/Const_1':
            [enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer2/RNNEncoder/bidirectional_rnn/fw/fw/Const_4':
            [enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer4/RNNEncoder/Reshape/shape/2':
            2 * enc_hidden_dim_symbol,
        'attn_model/StackedEncoder/Layer4/RNNEncoder/bidirectional_rnn/bw/bw/BasicLSTMCellZeroState/Const':
            [2 * enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer4/RNNEncoder/bidirectional_rnn/bw/bw/BasicLSTMCellZeroState/Const_1':
            [2 * enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer4/RNNEncoder/bidirectional_rnn/bw/bw/Const_1':
            [enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer4/RNNEncoder/bidirectional_rnn/bw/bw/Const_4':
            [enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer4/RNNEncoder/bidirectional_rnn/fw/fw/BasicLSTMCellZeroState/Const':
            [2 * enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer4/RNNEncoder/bidirectional_rnn/fw/fw/BasicLSTMCellZeroState/Const_1':
            [2 * enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer4/RNNEncoder/bidirectional_rnn/fw/fw/Const_1':
            [enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer4/RNNEncoder/bidirectional_rnn/fw/fw/Const_4':
            [enc_hidden_dim_symbol],
    }
    graph.bindConstantValues(const_dict)

    # TODO: Currently, Catamount doesn't automatically handle TensorFlow
    # TensorArrays or Stack ops. Here, manually set the dimensions of these
    # ops' tensors.
    for op in graph._ops_by_name.values():
        op_name_suffix = op.name.split('/')[-1]
        if 'TensorArrayGather' in op_name_suffix:
            assert isinstance(op, UnknownOp)
            assert len(op._inputs) == 3
            assert len(op._outputs) == 1
            if op._outputs[0].shape.rank == 1 or \
               op._outputs[0].shape.rank == 2:
                if len(op._outputs[0].consumers) > 0:
                    print('TODO: Unknown TensorArrayGather (rank {}): {}'
                          .format(op._outputs[0].shape.rank, op.debugString()))
            elif op._outputs[0].shape.isUnknown() or \
                 op._outputs[0].shape.rank == 3:
                if len(op._outputs[0].consumers) > 0:
                    # If output rank is 3, then the shape appears to be
                    # [seq_length, batch_size, enc_hid], where seq_length
                    # depends on the layer
                    out_shape = None
                    if 'StackedEncoder/Layer0' in op.name:
                        out_shape = [encoder_steps_symbol,
                                     subbatch_size_symbol,
                                     enc_hidden_dim_symbol]
                    elif 'StackedEncoder/Layer2' in op.name:
                        if 'attn_model/AttentionModel/gradients' in op.name:
                            # Backprop stores concatenated state
                            out_shape = [encoder_steps_symbol // 2,
                                         subbatch_size_symbol,
                                         2 * enc_hidden_dim_symbol]
                        else:
                            out_shape = [encoder_steps_symbol // 2,
                                         subbatch_size_symbol,
                                         enc_hidden_dim_symbol]
                    elif 'StackedEncoder/Layer4' in op.name:
                        if 'attn_model/AttentionModel/gradients' in op.name:
                            # Backprop stores concatenated state
                            out_shape = [(encoder_steps_symbol // 2) // 2,
                                         subbatch_size_symbol,
                                         2 * enc_hidden_dim_symbol]
                        else:
                            out_shape = [(encoder_steps_symbol // 2) // 2,
                                         subbatch_size_symbol,
                                         enc_hidden_dim_symbol]
                    elif 'Decoder' in op.name:
                        # HAXXXX: Manually specify a few
                        if op.name == 'attn_model/Decoder/TensorArrayStack/TensorArrayGatherV3':
                            out_shape = [decoder_steps_symbol,
                                         subbatch_size_symbol,
                                         output_vocab_symbol]
                        else:
                            out_shape = [decoder_steps_symbol,
                                         subbatch_size_symbol,
                                         dec_hidden_dim_symbol]
                    else:
                        print('TODO: Unknown TensorArrayGather {}'
                              .format(op.debugString()))
                    if out_shape is not None:
                        op._outputs[0].mergeShape(out_shape,
                                                  make_symbolic=True)
            else:
                print('TODO: Unknown TensorArrayGather {}'
                      .format(op.debugString()))
        elif 'TensorArraySize' in op_name_suffix:
            assert isinstance(op, UnknownOp)
            assert len(op._inputs) == 2
            assert len(op._outputs) == 1
            assert op._outputs[0].shape.rank == 0
            # NOTES:
            #   StackedEncoder Layer0: enc_seq
            #   StackedEncoder Layer2: enc_seq / 2  # Due to stride 2 in time
            #   StackedEncoder Layer4: enc_seq / 4  # Due to stride 2 in time
            #   Decoder: dec_seq
            if 'StackedEncoder/Layer0' in op.name:
                op._outputs[0].setValue(encoder_steps_symbol)
            elif 'StackedEncoder/Layer2' in op.name:
                op._outputs[0].setValue(encoder_steps_symbol // 2)
            elif 'StackedEncoder/Layer4' in op.name:
                op._outputs[0].setValue((encoder_steps_symbol // 2) // 2)
            elif 'Decoder' in op.name:
                op._outputs[0].setValue(decoder_steps_symbol)
            else:
                print('WARN: Unknown TensorArraySizeV3: {}'
                      .format(op.debugString()))
        elif 'TensorArrayRead' in op_name_suffix:
            assert isinstance(op, UnknownOp)
            assert len(op._inputs) == 3
            assert len(op._outputs) == 1
            assert op._outputs[0].shape.isUnknown() or \
                   op._outputs[0].shape.rank == 2, \
                   '{}'.format(op.name)
            if op._outputs[0].shape.isUnknown():
                if len(op._outputs[0].consumers) > 0:
                    out_shape = None
                    if 'attn_model/AttentionModel/gradients/attn_model/StackedEncoder/Layer' in op.name and \
                       ('/RNNEncoder/bidirectional_rnn/fw/fw/while/TensorArrayWrite/TensorArrayWriteV3_grad/TensorArrayReadV3' in op.name or
                        '/RNNEncoder/bidirectional_rnn/bw/bw/while/TensorArrayWrite/TensorArrayWriteV3_grad/TensorArrayReadV3' in op.name):
                        out_shape = [subbatch_size_symbol,
                                     enc_hidden_dim_symbol]
                    elif op.name == 'attn_model/AttentionModel/gradients/attn_model/Decoder/TensorArrayWrite/TensorArrayWriteV3_grad/TensorArrayReadV3' or \
                         op.name == 'attn_model/AttentionModel/gradients/attn_model/Decoder/while/TensorArrayWrite_1/TensorArrayWriteV3_grad/TensorArrayReadV3' or \
                         op.name == 'attn_model_2/Decoder/while/cond/TensorArrayReadV3' or \
                         op.name == 'attn_model/Decoder/while/cond/TensorArrayReadV3':
                        out_shape = [subbatch_size_symbol,
                                     output_vocab_symbol]
                    else:
                        print('WARN: Unknown TensorArrayReadV3 out shape: {}'
                              .format(op.debugString()))
                    if out_shape is not None:
                        op._outputs[0].mergeShape(out_shape,
                                                  make_symbolic=True)
            else:
                # NOTES: Many are (?, 40 "features"), (?, 1051 "enc_hid"),
                # or (?, 2102 "2*enc_hid")
                dim_1_val = op._outputs[0].shape.getDimension(1).value
                assert dim_1_val == base_audio_features or \
                       dim_1_val == base_enc_hidden_dim or \
                       dim_1_val == 2 * base_enc_hidden_dim, \
                       'Op: {}\n  Dim 1 value: {}'.format(op.debugString(),
                                                          dim_1_val)
                out_shape = None
                if dim_1_val == base_audio_features:
                    out_shape = [subbatch_size_symbol, audio_features_symbol]
                elif dim_1_val > 0 and dim_1_val % base_enc_hidden_dim == 0:
                    mult = dim_1_val // base_enc_hidden_dim
                    out_shape = [subbatch_size_symbol,
                                 mult * enc_hidden_dim_symbol]
                else:
                    print('Unhandled TensorArrayRead: {}'
                          .format(op.debugString()))
                if out_shape is not None:
                    op._outputs[0].mergeShape(out_shape, make_symbolic=True)

    # Manually set a couple shapes for max ops that can't yet resolve
    # maximums of 1 vs. positive symbols:
    max_op = graph._ops_by_name['attn_model/AttentionModel/gradients/attn_model/AttentionEncoderDecoder/Sum_grad/Maximum']
    max_op._outputs[0].mergeShape([2])
    max_op._outputs[0].setValue([1, subbatch_size_symbol])
    max_op = graph._ops_by_name['attn_model/AttentionModel/gradients/attn_model/Decoder/while/attn_model/Sum_grad/Maximum']
    max_op._outputs[0].mergeShape([3])
    # [floor(floor(encoder_steps/2)/2), subbatch_size, 1]
    max_op._outputs[0].setValue([(encoder_steps_symbol // 2) // 2,
                                 subbatch_size_symbol, 1])
    max_op = graph._ops_by_name['attn_model/AttentionModel/gradients/attn_model/Decoder/while/attn_model/Sum_1_grad/Maximum']
    max_op._outputs[0].mergeShape([3])
    # [1, subbatch_size, 2*enc_hidden_dim]
    max_op._outputs[0].setValue([1, subbatch_size_symbol,
                                 2 * enc_hidden_dim_symbol])

    print('Binding variables')
    graph.bindShapesAndPropagate(bind_dict,
                                 warn_if_ill_defined=(not is_pytest_run),
                                 make_symbolic=True)
    assert graph.isValid()

    print('\n\nCleaned Graph:\n{}'.format(graph))

    print('\n\nBound values')

    # Set base values to be subbed in:
    base_encoder_steps = 96
    base_decoder_steps = 24
    base_attn_dim = 128
    base_conv_width = 50
    base_attn_hidden_dim = 512
    base_dec_hidden_dim = 512
    base_enc_hidden_dim = 1024

    bind_subs = {
        audio_features_symbol: base_audio_features,
        encoder_steps_symbol: base_encoder_steps,
        decoder_steps_symbol: (encoder_steps_symbol // 2) // 2,
        subbatch_size_symbol: base_subbatch_size,
        attn_dim_symbol: base_attn_dim,
        attn_hidden_dim_symbol: enc_hidden_dim_symbol // 2,
        dec_hidden_dim_symbol: enc_hidden_dim_symbol // 2,
        output_vocab_symbol: base_output_vocab,
        conv_width_symbol: base_conv_width,
        enc_hidden_dim_symbol: base_enc_hidden_dim,
        num_conv_filters_symbol: 4,
        graph_iters_symbol: 1,
    }

    # Add loop iteration counts to bind_subs
    bind_str_subs = {
        'attn_model/AttentionModel/gradients/b_count_2_block::iters':
            decoder_steps_symbol,
        'attn_model/Decoder/while/LoopCond_block::iters':
            decoder_steps_symbol,
        'attn_model/AttentionModel/gradients/b_count_22_block::iters':
            encoder_steps_symbol,
        'attn_model/AttentionModel/gradients/b_count_26_block::iters':
            encoder_steps_symbol,
        'attn_model/StackedEncoder/Layer0/RNNEncoder/bidirectional_rnn/bw/bw/while/LoopCond_block::iters':
            encoder_steps_symbol,
        'attn_model/StackedEncoder/Layer0/RNNEncoder/bidirectional_rnn/fw/fw/while/LoopCond_block::iters':
            encoder_steps_symbol,
        'attn_model/AttentionModel/gradients/b_count_14_block::iters':
            encoder_steps_symbol // 2,
        'attn_model/AttentionModel/gradients/b_count_18_block::iters':
            encoder_steps_symbol // 2,
        'attn_model/StackedEncoder/Layer2/RNNEncoder/bidirectional_rnn/bw/bw/while/LoopCond_block::iters':
            encoder_steps_symbol // 2,
        'attn_model/StackedEncoder/Layer2/RNNEncoder/bidirectional_rnn/fw/fw/while/LoopCond_block::iters':
            encoder_steps_symbol // 2,
        'attn_model/AttentionModel/gradients/b_count_6_block::iters':
            (encoder_steps_symbol // 2) // 2,
        'attn_model/AttentionModel/gradients/b_count_10_block::iters':
            (encoder_steps_symbol // 2) // 2,
        'attn_model/StackedEncoder/Layer4/RNNEncoder/bidirectional_rnn/bw/bw/while/LoopCond_block::iters':
            (encoder_steps_symbol // 2) // 2,
        'attn_model/StackedEncoder/Layer4/RNNEncoder/bidirectional_rnn/fw/fw/while/LoopCond_block::iters':
            (encoder_steps_symbol // 2) // 2,
    }
    for var_name, sub_val in bind_str_subs.items():
        var_ref = utils.getIntSymbolFromString(var_name)
        assert var_name not in bind_subs.keys()
        bind_subs[var_ref] = sub_val

    # Calculate model parameter count
    parameters = graph.calcModelParameters()
    resolved_params = parameters.subs(bind_subs)
    try:
        resolved_params = int(resolved_params)
    except:
        print('ERROR: resolved_params should be int, but is {} = {}'
              .format(type(resolved_params), resolved_params))
    correct_params = 71084729
    assert resolved_params == correct_params, \
        'Incorrect model params: {}'.format(resolved_params)
    print('Parameters: {}\nWith specified dims: {}\n'
          .format(parameters, resolved_params))

    # Calculate algorithmic Flops
    alg_flops = graph.calcAlgFlops()
    resolved_flops = alg_flops.subs(bind_subs)
    try:
        resolved_flops = int(resolved_flops)
    except:
        print('ERROR: resolved_flops should be int, but is {} = {}'
              .format(type(resolved_flops), resolved_flops))
    correct_flops = 568878183032
    assert resolved_flops == correct_flops, \
        'Incorrect algorithmic flops: {}'.format(resolved_flops)
    print('Algorithmic Flops: {}\nWith specified dims: {}\n'
          .format(alg_flops, resolved_flops))

    # Calculate algorithmic Bytes accessed
    alg_bytes = graph.calcAlgBytes()
    resolved_bytes = alg_bytes.subs(bind_subs)
    try:
        resolved_bytes = int(resolved_bytes)
    except:
        print('ERROR: resolved_bytes should be int, but is {} = {}'
              .format(type(resolved_bytes), resolved_bytes))
    correct_bytes = 92231419797
    assert resolved_bytes == correct_bytes, \
        'Incorrect algorithmic bytes: {}'.format(resolved_bytes)
    print('Alg bytes accessed: {}\nWith specified dims: {}\n'
          .format(alg_bytes, resolved_bytes))

    # Calculate algorithmic memory footprint
    alg_footprint = graph.calcAlgFootprint()
    resolved_footprint = alg_footprint.subs(bind_subs)
    try:
        resolved_footprint = int(resolved_footprint)
    except:
        print('ERROR: resolved_footprint should be int, but is {} = {}'
              .format(type(resolved_footprint), resolved_footprint))
    correct_footprint = 32624988214
    assert resolved_footprint == correct_footprint, \
        'Incorrect algorithmic footprint: {}'.format(resolved_footprint)
    print('Alg mem footprint: {}\nWith specified dims: {}\n'
          .format(alg_footprint, resolved_footprint))

    # Calculate algorithmic IO per step
    total_io_footprint = 0
    for op in graph.getPlaceholders():
        total_io_footprint += op.calcAlgFootprint()
    resolved_io_footprint = total_io_footprint.subs(bind_subs)
    print('Alg IO footprint: {}\nWith specified dims: {}\n'
          .format(total_io_footprint, resolved_io_footprint))

    print('VERBOSE ALGORITHMIC FLOPS:')
    graph.calcAlgFlops(verbose=True)
    print('')

    print('VERBOSE ALGORITHMIC BYTES:')
    graph.calcAlgBytes(verbose=True)
    print('')

    print('VERBOSE ALGORITHMIC FOOTPRINT:')
    graph.calcAlgFootprint(verbose=True)
    print('')

    # HACKY WAY TO SAVE MODELS FOR NOW!
    pickle.dump(graph,
                open('catamount/frameworks/example_graphs/tensorflow/full_models/speech_attention/graph_speech_attention.p', 'wb'))

    if is_pytest_run:
        return

    print('\n\n======= Algorithmic graph-level analytics: =======')

    encoder_dims = [32, 64, 96, 128, 160, 192, 256, 320, 384, 448, 512, 640,
                    768, 892, 1024, 1152, 1280, 1408, 1548, 1702, 1872, 2059,
                    2264, 2490, 2739, 3012, 3289]

    base_encoder_steps = 335
    base_subbatch_size = 32
    base_attn_dim = 128
    base_conv_width = 50
    base_attn_hidden_dim = 512
    base_dec_hidden_dim = 512
    base_enc_hidden_dim = 1024

    bind_subs[audio_features_symbol] = base_audio_features
    bind_subs[encoder_steps_symbol] = base_encoder_steps
    bind_subs[decoder_steps_symbol] = (encoder_steps_symbol // 2) // 2
    bind_subs[subbatch_size_symbol] = base_subbatch_size
    bind_subs[attn_dim_symbol] = base_attn_dim
    bind_subs[attn_hidden_dim_symbol] = enc_hidden_dim_symbol // 2
    bind_subs[dec_hidden_dim_symbol] = enc_hidden_dim_symbol // 2
    bind_subs[output_vocab_symbol] = base_output_vocab
    bind_subs[conv_width_symbol] = base_conv_width
    # bind_subs[enc_hidden_dim_symbol] = base_enc_hidden_dim
    bind_subs[num_conv_filters_symbol] = 4
    bind_subs[graph_iters_symbol] = 1

    # Sweep the encoder hidden dimension rather than binding it
    bind_subs.pop(enc_hidden_dim_symbol)
    resolved_params = parameters.subs(bind_subs)

    print('Symbol associations: {}\n'.format(bind_subs))

    print('Algorithmic Flops by hidden dimension, params, and per-batch-sample:')
    resolved_flops = alg_flops.subs(bind_subs)
    for enc_dim in encoder_dims:
        graph_params = resolved_params.subs({enc_hidden_dim_symbol: enc_dim})
        graph_flops = resolved_flops.subs({enc_hidden_dim_symbol: enc_dim})
        graph_flops_per_sample = float(graph_flops) / \
                                 bind_subs[subbatch_size_symbol]
        print('{}\t{}\t{}\t{}'.format(enc_dim, graph_params, graph_flops,
                                      int(graph_flops_per_sample)))

    print('\nAlgorithmic bytes accessed by hidden dimension, params:')
    resolved_bytes = alg_bytes.subs(bind_subs)
    for enc_dim in encoder_dims:
        graph_params = resolved_params.subs({enc_hidden_dim_symbol: enc_dim})
        graph_bytes = resolved_bytes.subs({enc_hidden_dim_symbol: enc_dim})
        print('{}\t{}\t{}'.format(enc_dim, graph_params, graph_bytes))

    print('\nAlgorithmic memory footprint by hidden dimension, params:')
    resolved_footprint = alg_footprint.subs(bind_subs)
    for enc_dim in encoder_dims:
        graph_params = resolved_params.subs({enc_hidden_dim_symbol: enc_dim})
        graph_footprint = resolved_footprint.subs(
                              {enc_hidden_dim_symbol: enc_dim})
        print('{}\t{}\t{}'.format(enc_dim, graph_params, graph_footprint))

    print('\nAlgorithmic minimal memory footprint by hidden dimension, params:')
    full_subs = dict(bind_subs)
    for enc_dim in encoder_dims:
        graph_params = resolved_params.subs({enc_hidden_dim_symbol: enc_dim})
        full_subs[enc_hidden_dim_symbol] = enc_dim
        graph_min_foot = graph.calcMinimalFootprint(symbol_subs=full_subs)
        print('{}\t{}\t{}'.format(enc_dim, graph_params, graph_min_foot))
def test_tf_static_unroll_rnn():
    graph = catamount.frameworks.tensorflow.import_graph(tf_example_filename)
    assert graph.isValid()
    algorithmic_flops = graph.calcAlgFlops()

    ba_0 = utils.getIntSymbolFromString('basic_rnn_cell/BiasAdd:0::dim_0')
    ba_1 = utils.getIntSymbolFromString('basic_rnn_cell/BiasAdd_1:0::dim_0')
    ba_2 = utils.getIntSymbolFromString('basic_rnn_cell/BiasAdd_2:0::dim_0')
    ba_3 = utils.getIntSymbolFromString('basic_rnn_cell/BiasAdd_3:0::dim_0')
    ba_4 = utils.getIntSymbolFromString('basic_rnn_cell/BiasAdd_4:0::dim_0')
    mm_0 = utils.getIntSymbolFromString('basic_rnn_cell/MatMul:0::dim_0')
    mm_1 = utils.getIntSymbolFromString('basic_rnn_cell/MatMul_1:0::dim_0')
    mm_2 = utils.getIntSymbolFromString('basic_rnn_cell/MatMul_2:0::dim_0')
    mm_3 = utils.getIntSymbolFromString('basic_rnn_cell/MatMul_3:0::dim_0')
    mm_4 = utils.getIntSymbolFromString('basic_rnn_cell/MatMul_4:0::dim_0')
    th_0 = utils.getIntSymbolFromString('basic_rnn_cell/Tanh:0::dim_0')
    th_1 = utils.getIntSymbolFromString('basic_rnn_cell/Tanh_1:0::dim_0')
    th_2 = utils.getIntSymbolFromString('basic_rnn_cell/Tanh_2:0::dim_0')
    th_3 = utils.getIntSymbolFromString('basic_rnn_cell/Tanh_3:0::dim_0')
    th_4 = utils.getIntSymbolFromString('basic_rnn_cell/Tanh_4:0::dim_0')
    correct_alg_flops = 24 * (ba_0 + ba_1 + ba_2 + ba_3 + ba_4) + \
        2304 * (mm_0 + mm_1 + mm_2 + mm_3 + mm_4) + \
        144 * (th_0 + th_1 + th_2 + th_3 + th_4) + 2305

    print('Loaded Flops test:')
    print('    Catamount: {}'.format(algorithmic_flops))
    print('    Correct:   {}'.format(correct_alg_flops))
    assert sympy.simplify(algorithmic_flops - correct_alg_flops) == 0, \
        'Initial alg flops incorrect!\n  Expecting: {}\n  Calculated: {}' \
        .format(correct_alg_flops, algorithmic_flops)

    # Now, bind tensor names in the graph and verify that the algorithmic
    # Flop counts reflect the new name bindings
    batch_size = utils.getIntSymbolFromString('batch_size')
    # NOTE: This also works: batch_size = 'batch_size'
    # Bind placeholders (a and init_state) output dimensions to symbol names
    bind_dict = {'a': ['seq_length', 'batch_size', 'hidden_dim'],
                 'init_state': ['batch_size', 'hidden_dim']}
    graph.bindTensorShapeDimensions(bind_dict)

    algorithmic_flops = graph.calcAlgFlops()

    # Update the algorithmic Flops formula
    correct_alg_flops = correct_alg_flops.subs(
        {ba_0: batch_size, ba_1: batch_size, ba_2: batch_size,
         ba_3: batch_size, ba_4: batch_size,
         mm_0: batch_size, mm_1: batch_size, mm_2: batch_size,
         mm_3: batch_size, mm_4: batch_size,
         th_0: batch_size, th_1: batch_size, th_2: batch_size,
         th_3: batch_size, th_4: batch_size})
    print('Bound Flops test:')
    print('    Catamount: {}'.format(algorithmic_flops))
    print('    Correct:   {}'.format(correct_alg_flops))
    assert sympy.simplify(algorithmic_flops - correct_alg_flops) == 0, \
        'Bound alg flops incorrect!\n  Expecting: {}\n  Calculated: {}' \
        .format(correct_alg_flops, algorithmic_flops)
def run_tf_language_model(domain=None, build_projection=False):
    global is_pytest_run
    if domain == 'wordlm':
        graph_meta = 'catamount/frameworks/example_graphs/tensorflow/full_models/language_models/word_lm_n2004_l2_sgd_lr0.2_nodrop_b128_v10k_d20_s80-best_model.meta'
    elif domain == 'charlm':
        graph_meta = 'catamount/frameworks/example_graphs/tensorflow/full_models/language_models/char_lm_n2004_l10_sgd_lr0.15_rhn_b128_vchar_d1.0_s150-latest_model.meta'
    elif domain == 'nmt':
        graph_meta = 'catamount/frameworks/example_graphs/tensorflow/full_models/language_models/nmt_el2_dl1_n1024_b128-translate.ckpt-1000.meta'
    else:
        raise NotImplementedError('ERROR: Unknown domain: {}'.format(domain))

    graph = catamount.frameworks.tensorflow.import_graph(graph_meta)
    assert graph.isValid()

    # Next, remove ops that are not executed during a standard training step:
    # TODO: Implement feeds->fetches calcAlg*
    if domain == 'wordlm':
        graph_ops = list(graph._ops_by_name.values())
        for op in graph_ops:
            # Certain ops are only used for inference
            if 'Model/Recurrent_1_lstm_3/' in op.name or \
               'Model/Recurrent_2_lstm_3/' in op.name or \
               'Model/FullSoftmaxLoss_1_3/' in op.name or \
               'Model/Collapse_1/' in op.name or \
               'Model/Embedding_1_3/' in op.name or \
               'Model/Labels_1/' in op.name or \
               'Model/Mask_1/' in op.name:
                graph.removeOp(op)
            elif op.name == 'Model/Sum_1' or \
                 op.name == 'Model/Cast_3' or \
                 op.name == 'Model/Cast_2' or \
                 op.name == 'Model/Size_1' or \
                 op.name == 'Model/truediv_2' or \
                 op.name == 'Model/truediv_3' or \
                 op.name == 'Model/Exp_1':
                graph.removeOp(op)
    elif domain == 'charlm':
        graph_ops = list(graph._ops_by_name.values())
        for op in graph_ops:
            # Certain ops are only used for inference
            if 'Model/Recurrent_1_rhn_3/' in op.name or \
               'Model/FullSoftmaxLoss_1_3/' in op.name or \
               'Model/Collapse_1/' in op.name or \
               'Model/Embedding_1_3/' in op.name or \
               'Model/Labels_1/' in op.name or \
               'Model/Mask_1/' in op.name:
                graph.removeOp(op)
            elif op.name == 'Model/Cast_1' or \
                 op.name == 'Model/Sum_1' or \
                 op.name == 'Model/Size_1' or \
                 op.name == 'Model/truediv_2' or \
                 op.name == 'Model/truediv_3' or \
                 op.name == 'Model/Exp_1':
                graph.removeOp(op)
    elif domain == 'nmt':
        pass
    else:
        raise NotImplementedError('ERROR: Unknown domain: {}'.format(domain))

    if not is_pytest_run:
        print('Initial graph:\n{}\n'.format(graph))
    init_params = graph.calcModelParameters()
    print('Initial parameters: {}'.format(init_params))
    print('Initial Flops: {}\n'.format(graph.calcAlgFlops()))

    print('Placeholders:')
    for op in graph.getPlaceholders():
        print(op.debugString())
    print('')

    # Set up symbols to name dimensions
    hidden_dim_symbol = utils.getIntSymbolFromString('hidden_dim')
    vocab_size_symbol = utils.getIntSymbolFromString('vocab_size')
    subbatch_size_symbol = utils.getIntSymbolFromString('subbatch_size')
    sequence_length_symbol = utils.getIntSymbolFromString('sequence_length')
    batch_times_seq_symbol = sequence_length_symbol * subbatch_size_symbol
    graph_iters_symbol = utils.getIntSymbolFromString('graph::iters')

    # Convert these constant dimensions to symbols
    base_subbatch_size = None
    base_sequence_length = None
    if domain == 'wordlm':
        base_hidden_dim = 2004
        base_vocab_size = 10004
    elif domain == 'charlm':
        base_hidden_dim = 2004
        base_vocab_size = 98
    elif domain == 'nmt':
        base_hidden_dim = 1024
        base_vocab_size = 36548
        base_sequence_length = 19
        base_subbatch_size = 128
    else:
        raise NotImplementedError('ERROR: Unknown domain: {}'.format(domain))

    # HAXXX: Manually setting TensorArray shapes!
    if domain == 'wordlm' or domain == 'charlm' or domain == 'nmt':
        for op in graph._ops_by_name.values():
            op_name_suffix = op.name.split('/')[-1]
            if 'TensorArrayGather' in op_name_suffix:
                assert isinstance(op, UnknownOp)
                assert len(op._inputs) == 3
                assert len(op._outputs) == 1
                if domain == 'wordlm' or domain == 'charlm':
                    assert op._outputs[0].shape.isUnknown() or \
                           op._outputs[0].shape.rank == 3, \
                           '{}'.format(op.name)
                    gather_shape = [sequence_length_symbol,
                                    subbatch_size_symbol,
                                    hidden_dim_symbol]
                else:
                    assert domain == 'nmt'
                    assert op._outputs[0].shape.isUnknown() or \
                           op._outputs[0].shape.rank == 2 or \
                           op._outputs[0].shape.rank == 3, \
                           '{}'.format(op.name)
                    if not op._outputs[0].shape.isUnknown():
                        if op._outputs[0].shape.rank == 3:
                            out_shape = [base_sequence_length,
                                         base_subbatch_size,
                                         base_hidden_dim]
                            # Verify that the shape is clearly specified
                            op._outputs[0].mergeShape(out_shape,
                                                      make_symbolic=True)
                            gather_shape = [sequence_length_symbol,
                                            subbatch_size_symbol,
                                            hidden_dim_symbol]
                        else:
                            # This TAGather is known to be unused, so who
                            # cares?!
                            assert len(op._outputs[0].consumers) == 0
                            continue
                op._outputs[0].mergeShape(gather_shape, make_symbolic=True)
            elif 'TensorArraySize' in op_name_suffix:
                assert isinstance(op, UnknownOp)
                assert len(op._inputs) == 2
                assert len(op._outputs) == 1
                assert op._outputs[0].shape.rank == 0
                op._outputs[0].setValue(sequence_length_symbol)
            elif 'TensorArrayRead' in op_name_suffix:
                assert isinstance(op, UnknownOp)
                assert len(op._inputs) == 3
                assert len(op._outputs) == 1
                assert op._outputs[0].shape.isUnknown() or \
                       op._outputs[0].shape.rank == 2, \
                       '{}'.format(op.name)
                if not op._outputs[0].shape.isUnknown():
                    assert op._outputs[0].shape.dims[1].value == base_hidden_dim
                read_shape = [subbatch_size_symbol, hidden_dim_symbol]
                op._outputs[0].mergeShape(read_shape, make_symbolic=True)
    else:
        raise NotImplementedError('ERROR: Unknown domain: {}'.format(domain))

    assert graph.isValid()

    if domain == 'wordlm':
        const_dict = {
            'Model/Collapse/Reshape/shape': [-1, hidden_dim_symbol],
            'Model/Recurrent_\d_lstm_1/rnn/Const': [hidden_dim_symbol],
            'Model/Recurrent_\d_lstm_1/rnn/Const_1': [hidden_dim_symbol],
            'Model/Gradient/Compute/gradients/Model/Embedding_1_1/Gather_grad/Shape':
                [vocab_size_symbol, hidden_dim_symbol],
            'Model/Gradient/Compute/gradients/Model/FullSoftmaxLoss_1_1/add_grad/Shape_1':
                [1, vocab_size_symbol],
        }
    elif domain == 'charlm':
        const_dict = {
            'Model/Collapse/Reshape/shape': [-1, hidden_dim_symbol],
            'Model/Recurrent_1_rhn_1/rnn/Const': [hidden_dim_symbol],
            'Model/Recurrent_1_rhn_1/rnn/Const_1': [hidden_dim_symbol],
            'Model/Gradient/Compute/gradients/Model/FullSoftmaxLoss_1_1/add_grad/Shape_1':
                [1, vocab_size_symbol],
            'Model/Gradient/Compute/gradients/Model/Embedding_1_1/Gather_grad/Shape':
                [vocab_size_symbol, hidden_dim_symbol],
        }
    elif domain == 'nmt':
        const_dict = {
            'gradients/dynamic_seq2seq/decoder/output_projection/Tensordot/Reshape_1_grad/Shape':
                [hidden_dim_symbol, vocab_size_symbol],
            'gradients/dynamic_seq2seq/decoder/embedding_lookup_grad/Shape':
                [vocab_size_symbol, hidden_dim_symbol],
            'gradients/dynamic_seq2seq/encoder/embedding_lookup_grad/Shape':
                [vocab_size_symbol, hidden_dim_symbol],
            'gradients/dynamic_seq2seq/decoder/LuongAttention/memory_layer/Tensordot/Reshape_1_grad/Shape':
                [2 * hidden_dim_symbol, hidden_dim_symbol],
            'dynamic_seq2seq/encoder/bidirectional_rnn/[bf]w/[bf]w/Const_1':
                [hidden_dim_symbol],
            'dynamic_seq2seq/encoder/bidirectional_rnn/[bf]w/[bf]w/Const_4':
                [hidden_dim_symbol],
            'dynamic_seq2seq/encoder/bidirectional_rnn/[bf]w/[bf]w/DeviceWrapperZeroState/DropoutWrapperZeroState/BasicLSTMCellZeroState/Const':
                [hidden_dim_symbol],
            'dynamic_seq2seq/encoder/bidirectional_rnn/[bf]w/[bf]w/DeviceWrapperZeroState/DropoutWrapperZeroState/BasicLSTMCellZeroState/Const_\d':
                [hidden_dim_symbol],
            'dynamic_seq2seq/decoder/output_projection/Tensordot/Reshape_1/shape':
                [hidden_dim_symbol, vocab_size_symbol],
            'dynamic_seq2seq/decoder/output_projection/Tensordot/Const_2':
                [vocab_size_symbol],
            'dynamic_seq2seq/decoder/decoder/Const': [hidden_dim_symbol],
            'dynamic_seq2seq/decoder/decoder/Const_1': [hidden_dim_symbol],
            'dynamic_seq2seq/decoder/LuongAttention/memory_layer/Tensordot/Const_2':
                [hidden_dim_symbol],
            'dynamic_seq2seq/decoder/LuongAttention/memory_layer/Tensordot/Reshape_1/shape':
                [2 * hidden_dim_symbol, hidden_dim_symbol],
            'dynamic_seq2seq/decoder/DeviceWrapperZeroState/AttentionWrapperZeroState/Const':
                [hidden_dim_symbol],
            'dynamic_seq2seq/decoder/DeviceWrapperZeroState/AttentionWrapperZeroState/Const_1':
                [hidden_dim_symbol],
            'dynamic_seq2seq/decoder/DeviceWrapperZeroState/AttentionWrapperZeroState/DeviceWrapperZeroState/DropoutWrapperZeroState/BasicLSTMCellZeroState/Const':
                [hidden_dim_symbol],
            'dynamic_seq2seq/decoder/DeviceWrapperZeroState/AttentionWrapperZeroState/DeviceWrapperZeroState/DropoutWrapperZeroState/BasicLSTMCellZeroState/Const_\d':
                [hidden_dim_symbol],
            'buffer_size': 256 * hidden_dim_symbol,
            'buffer_size_1': 256 * hidden_dim_symbol,
            'buffer_size_[2-8]': 125 * hidden_dim_symbol,
        }
    else:
        raise NotImplementedError(
                  'Manually set constant op values for domain {}'.format(domain))
    graph.bindConstantValues(const_dict)

    # Next, bind the constant, placeholder, and variable shapes and propagate
    if domain == 'wordlm':
        bind_dict = {
            # Constants
            'Model/Gradient/Compute/gradients/Model/Recurrent_\d_lstm_1/rnn/while/rnn/BiasAdd/Enter_grad/b_acc':
                [4 * hidden_dim_symbol],
            'Model/Gradient/Compute/gradients/Model/Recurrent_\d_lstm_1/rnn/while/rnn/MatMul/Enter_grad/b_acc':
                [2 * hidden_dim_symbol, 4 * hidden_dim_symbol],
            # Placeholders
            'Input/Input': [subbatch_size_symbol, sequence_length_symbol],
            'Labels/Labels': [subbatch_size_symbol, sequence_length_symbol],
            'Model/Placeholder': [subbatch_size_symbol, hidden_dim_symbol],
            'Model/Placeholder_\d': [subbatch_size_symbol, hidden_dim_symbol],
            # Variables
            'Model/Embedding_1/EmbeddingWeights':
                [vocab_size_symbol, hidden_dim_symbol],
            'Model/FullSoftmaxLoss_1/W_Softmax':
                [vocab_size_symbol, hidden_dim_symbol],
            'Model/FullSoftmaxLoss_1/b_Softmax': [1, vocab_size_symbol],
            'Model/Recurrent_\d_lstm/rnn/Bias': [4 * hidden_dim_symbol],
            'Model/Recurrent_\d_lstm/rnn/Matrix':
                [2 * hidden_dim_symbol, 4 * hidden_dim_symbol],
        }
    elif domain == 'charlm':
        bind_dict = {
            # Constants
            'Model/Gradient/Compute/gradients/Model/Recurrent_1_rhn_1/rnn/while/[ht]_0/[ht]_0/BiasAdd/Enter_grad/b_acc':
                [hidden_dim_symbol],
            'Model/Gradient/Compute/gradients/Model/Recurrent_1_rhn_1/rnn/while/[ht]_0/[ht]_0/MatMul/Enter_grad/b_acc':
                [2 * hidden_dim_symbol, hidden_dim_symbol],
            'Model/Gradient/Compute/gradients/Model/Recurrent_1_rhn_1/rnn/while/[ht]_[1-9]/[ht]_[1-9]/BiasAdd/Enter_grad/b_acc':
                [hidden_dim_symbol],
            'Model/Gradient/Compute/gradients/Model/Recurrent_1_rhn_1/rnn/while/[ht]_[1-9]/[ht]_[1-9]/MatMul/Enter_grad/b_acc':
                [hidden_dim_symbol, hidden_dim_symbol],
            'Model/Recurrent_1_rhn/rnn/[ht]_[0-9]/Bias/Initializer/Const':
                [hidden_dim_symbol],
            # Placeholders
            'Input/Input': [subbatch_size_symbol, sequence_length_symbol],
            'Labels/Labels': [subbatch_size_symbol, sequence_length_symbol],
            'Model/Placeholder': [subbatch_size_symbol, hidden_dim_symbol],
            'Model/Placeholder_1': [subbatch_size_symbol, hidden_dim_symbol],
            # Variables
            'Model/Embedding_1/EmbeddingWeights':
                [vocab_size_symbol, hidden_dim_symbol],
            'Model/FullSoftmaxLoss_1/W_Softmax':
                [vocab_size_symbol, hidden_dim_symbol],
            'Model/FullSoftmaxLoss_1/b_Softmax': [1, vocab_size_symbol],
            'Model/Recurrent_1_rhn/rnn/h_0/Bias': [hidden_dim_symbol],
            'Model/Recurrent_1_rhn/rnn/h_0/Matrix':
                [2 * hidden_dim_symbol, hidden_dim_symbol],
            'Model/Recurrent_1_rhn/rnn/h_[1-9]/Bias': [hidden_dim_symbol],
            'Model/Recurrent_1_rhn/rnn/h_[1-9]/Matrix':
                [hidden_dim_symbol, hidden_dim_symbol],
            'Model/Recurrent_1_rhn/rnn/t_0/Bias': [hidden_dim_symbol],
            'Model/Recurrent_1_rhn/rnn/t_0/Matrix':
                [2 * hidden_dim_symbol, hidden_dim_symbol],
            'Model/Recurrent_1_rhn/rnn/t_[1-9]/Bias': [hidden_dim_symbol],
            'Model/Recurrent_1_rhn/rnn/t_[1-9]/Matrix':
                [hidden_dim_symbol, hidden_dim_symbol],
        }
    elif domain == 'nmt':
        # HAX: Manually hack the iterator op
        it_op = graph.opsByName['IteratorGetNext']
        it_op._outputs[0].mergeShape([subbatch_size_symbol,
                                      sequence_length_symbol],
                                     make_symbolic=True)
        it_op._outputs[1].mergeShape([subbatch_size_symbol,
                                      sequence_length_symbol],
                                     make_symbolic=True)
        it_op._outputs[2].mergeShape([subbatch_size_symbol,
                                      sequence_length_symbol],
                                     make_symbolic=True)
        it_op._outputs[3].mergeShape([subbatch_size_symbol],
                                     make_symbolic=True)
        it_op._outputs[4].mergeShape([subbatch_size_symbol],
                                     make_symbolic=True)

        bind_dict = {
            # Constants
            'gradients/dynamic_seq2seq/decoder/decoder/while/BasicDecoderStep/decoder/attention/attention_layer/MatMul/Enter_grad/b_acc':
                [3 * hidden_dim_symbol, hidden_dim_symbol],
            'gradients/dynamic_seq2seq/decoder/decoder/while/BasicDecoderStep/decoder/attention/basic_lstm_cell/BiasAdd/Enter_grad/b_acc':
                [4 * hidden_dim_symbol],
            'gradients/dynamic_seq2seq/decoder/decoder/while/BasicDecoderStep/decoder/attention/basic_lstm_cell/MatMul/Enter_grad/b_acc':
                [3 * hidden_dim_symbol, 4 * hidden_dim_symbol],
            'gradients/dynamic_seq2seq/encoder/bidirectional_rnn/[bf]w/[bf]w/while/basic_lstm_cell/BiasAdd/Enter_grad/b_acc':
                [4 * hidden_dim_symbol],
            'gradients/dynamic_seq2seq/encoder/bidirectional_rnn/[bf]w/[bf]w/while/basic_lstm_cell/MatMul/Enter_grad/b_acc':
                [2 * hidden_dim_symbol, 4 * hidden_dim_symbol],
            # Placeholders
            # Variables
            'dynamic_seq2seq/decoder/attention/attention_layer/kernel':
                [3 * hidden_dim_symbol, hidden_dim_symbol],
            'dynamic_seq2seq/decoder/attention/basic_lstm_cell/bias':
                [4 * hidden_dim_symbol],
            'dynamic_seq2seq/decoder/attention/basic_lstm_cell/kernel':
                [3 * hidden_dim_symbol, 4 * hidden_dim_symbol],
            'dynamic_seq2seq/decoder/memory_layer/kernel':
                [2 * hidden_dim_symbol, hidden_dim_symbol],
            'dynamic_seq2seq/decoder/output_projection/kernel':
                [hidden_dim_symbol, vocab_size_symbol],
            'dynamic_seq2seq/encoder/bidirectional_rnn/[bf]w/basic_lstm_cell/bias':
                [4 * hidden_dim_symbol],
            'dynamic_seq2seq/encoder/bidirectional_rnn/[bf]w/basic_lstm_cell/kernel':
                [2 * hidden_dim_symbol, 4 * hidden_dim_symbol],
            'embeddings/decoder/embedding_decoder':
                [vocab_size_symbol, hidden_dim_symbol],
            'embeddings/encoder/embedding_encoder':
                [vocab_size_symbol, hidden_dim_symbol],
        }
    else:
        raise NotImplementedError('ERROR: Unknown domain: {}'.format(domain))

    print('Binding variables')
    graph.bindShapesAndPropagate(bind_dict,
                                 warn_if_ill_defined=(not is_pytest_run),
                                 make_symbolic=True)
    assert graph.isValid()

    num_sampled_vocab_symbol = subbatch_size_symbol * sequence_length_symbol
    if domain == 'wordlm':
        base_sequence_length = 80
        base_subbatch_size = 64
        base_num_sampled_vocab = base_subbatch_size * base_sequence_length
        bind_str_subs = {
            'Model/Collapse_1/boolean_mask/Reshape_1:0::num_true':
                sequence_length_symbol * subbatch_size_symbol,
            'Model/Collapse/boolean_mask/Reshape_1:0::num_true':
                sequence_length_symbol * subbatch_size_symbol,
            'Model/Labels_1/boolean_mask/Reshape_1:0::num_true':
                sequence_length_symbol * subbatch_size_symbol,
            'Model/Labels/boolean_mask/Reshape_1:0::num_true':
                sequence_length_symbol * subbatch_size_symbol,
            'Model/Gradient/Compute/gradients/b_count_2_block::iters':
                sequence_length_symbol,
            'Model/Gradient/Compute/gradients/b_count_6_block::iters':
                sequence_length_symbol,
            'Model/Recurrent_1_lstm_1/rnn/while/LoopCond_block::iters':
                sequence_length_symbol,
            'Model/Recurrent_1_lstm_3/rnn/while/LoopCond_block::iters':
                sequence_length_symbol,
            'Model/Recurrent_2_lstm_1/rnn/while/LoopCond_block::iters':
                sequence_length_symbol,
            'Model/Recurrent_2_lstm_3/rnn/while/LoopCond_block::iters':
                sequence_length_symbol,
        }
    elif domain == 'charlm':
        base_sequence_length = 150
        base_subbatch_size = 128
        bind_str_subs = {
            'Model/Collapse/boolean_mask/Reshape_1:0::num_true':
                sequence_length_symbol * subbatch_size_symbol,
            'Model/Labels/boolean_mask/Reshape_1:0::num_true':
                sequence_length_symbol * subbatch_size_symbol,
            'Model/Recurrent_1_rhn_1/rnn/while/LoopCond_block::iters':
                sequence_length_symbol,
            'Model/Gradient/Compute/gradients/b_count_2_block::iters':
                sequence_length_symbol,
        }
    elif domain == 'nmt':
        bind_str_subs = {
            'dynamic_seq2seq/decoder/decoder/while/LoopCond_block::iters':
                sequence_length_symbol,
            'dynamic_seq2seq/encoder/bidirectional_rnn/bw/bw/while/LoopCond_block::iters':
                sequence_length_symbol,
            'dynamic_seq2seq/encoder/bidirectional_rnn/fw/fw/while/LoopCond_block::iters':
                sequence_length_symbol,
            'gradients/b_count_10_block::iters': sequence_length_symbol,
            'gradients/b_count_2_block::iters': sequence_length_symbol,
            'gradients/b_count_6_block::iters': sequence_length_symbol,
        }
    else:
        raise NotImplementedError('ERROR: Unknown domain: {}'.format(domain))

    if not is_pytest_run:
        print('\n\nCleaned Graph:\n{}'.format(graph))

    print('\n\nBound values')

    bind_subs = {
        graph_iters_symbol: 1,
        hidden_dim_symbol: base_hidden_dim,
        sequence_length_symbol: base_sequence_length,
        subbatch_size_symbol: base_subbatch_size,
        vocab_size_symbol: base_vocab_size,
    }

    var_refs_table = {}
    for var_name, sub_val in bind_str_subs.items():
        var_ref = utils.getIntSymbolFromString(var_name)
        assert var_name not in bind_subs.keys()
        bind_subs[var_ref] = sub_val
        var_refs_table[var_name] = var_ref

    # Verify parameter counts first
    parameters = graph.calcModelParameters()
    if domain == 'wordlm':
        correct_symbolic_params = 16 * hidden_dim_symbol**2 + \
            2 * hidden_dim_symbol * vocab_size_symbol + \
            8 * hidden_dim_symbol + \
            vocab_size_symbol + 2
        correct_params = 104378326
        correct_flops = 2597058084257
        correct_bytes = 143652774404
        correct_total_footprint = 49660373192
    elif domain == 'charlm':
        correct_symbolic_params = 22 * hidden_dim_symbol**2 + \
            2 * hidden_dim_symbol * vocab_size_symbol + \
            20 * hidden_dim_symbol + \
            vocab_size_symbol + 2
        correct_params = 88785316
        correct_flops = 10228050930711
        correct_bytes = 445302356084
        correct_total_footprint = 156135676796
    elif domain == 'nmt':
        correct_symbolic_params = 33 * hidden_dim_symbol**2 + \
            3 * hidden_dim_symbol * vocab_size_symbol + \
            12 * hidden_dim_symbol + 1
        correct_params = 146890753
        correct_flops = 1053984410589
        correct_bytes = 36901043741
        correct_total_footprint = 14551615608
    else:
        raise NotImplementedError('ERROR: Unknown domain: {}'.format(domain))
    assert sympy.simplify(parameters - correct_symbolic_params) == 0, \
        'Param count incorrect!\n Expecting: {}\n Calculated: {}' \
        .format(correct_symbolic_params, parameters)

    print('Symbol associations: {}\n'.format(bind_subs))

    # Calculate model parameter count
    resolved_params = parameters.subs(bind_subs)
    try:
        resolved_params = int(resolved_params)
    except Exception:
        print('ERROR: resolved_params should be int, but is {} = {}'.format(
            type(resolved_params), resolved_params))
    assert resolved_params == correct_params, \
        'Incorrect model params: {}'.format(resolved_params)
    print('Parameters: {}\nWith specified dims: {}\n'.format(
        parameters, resolved_params))

    # Calculate algorithmic Flops
    alg_flops = graph.calcAlgFlops()
    resolved_flops = alg_flops.subs(bind_subs)
    try:
        resolved_flops = int(resolved_flops)
    except Exception:
        print('ERROR: resolved_flops should be int, but is {} = {}'.format(
            type(resolved_flops), resolved_flops))
    assert resolved_flops == correct_flops, \
        'Incorrect algorithmic flops: {}'.format(resolved_flops)
    print('Algorithmic Flops: {}\nWith specified dims: {}\n'.format(
        alg_flops, resolved_flops))

    # Calculate algorithmic Bytes accessed
    alg_bytes = graph.calcAlgBytes()
    resolved_bytes = alg_bytes.subs(bind_subs)
    try:
        resolved_bytes = int(resolved_bytes)
    except Exception:
        print('ERROR: resolved_bytes should be int, but is {} = {}'.format(
            type(resolved_bytes), resolved_bytes))
    assert resolved_bytes == correct_bytes, \
        'Incorrect algorithmic bytes: {}'.format(resolved_bytes)
    print('Alg bytes accessed: {}\nWith specified dims: {}\n'.format(
        alg_bytes, resolved_bytes))

    # Calculate total memory footprint
    alg_footprint = graph.calcAlgFootprint()
    resolved_footprint = alg_footprint.subs(bind_subs)
    try:
        resolved_footprint = int(resolved_footprint)
    except Exception:
        print('ERROR: resolved_footprint should be int, but is {} = {}'.format(
            type(resolved_footprint), resolved_footprint))
    assert resolved_footprint == correct_total_footprint, \
        'Incorrect algorithmic footprint: {}'.format(resolved_footprint)
    print('Alg mem footprint: {}\nWith specified dims: {}\n'.format(
        alg_footprint, resolved_footprint))

    # Calculate minimal memory footprint
    alg_min_footprint = graph.calcMinimalFootprint(symbol_subs=bind_subs)
    print('Alg minimal footprint (With specified dims): {}\n'.format(
        alg_min_footprint))

    # Calculate algorithmic IO per step
    total_io_footprint = 0
    for op in graph.getPlaceholders():
        total_io_footprint += op.calcAlgFootprint()
    if isinstance(total_io_footprint, int):
        resolved_io_footprint = total_io_footprint
    else:
        resolved_io_footprint = total_io_footprint.subs(bind_subs)
    print('Alg IO footprint: {}\nWith specified dims: {}\n'.format(
        total_io_footprint, resolved_io_footprint))

    if not is_pytest_run:
        print('VERBOSE ALGORITHMIC FLOPS:')
        graph.calcAlgFlops(verbose=True)
        print('')

        print('VERBOSE ALGORITHMIC BYTES:')
        graph.calcAlgBytes(verbose=True)
        print('')

        print('VERBOSE ALGORITHMIC FOOTPRINT:')
        graph.calcAlgFootprint(verbose=True)
        print('')

    # HACKY WAY TO SAVE MODELS FOR NOW!
    pickle.dump(graph, open(
        'catamount/frameworks/example_graphs/tensorflow/full_models/language_models/graph_{}.p'
        .format(domain), 'wb'))

    if is_pytest_run:
        return

    print('\n\n======= Algorithmic graph-level analytics: =======')

    if domain == 'wordlm':
        hidden_dims = [
            1, 2, 3, 4, 5, 6, 7, 9, 10, 12, 14, 18, 20, 25, 28, 35, 40, 50,
            56, 69, 78, 86, 96, 108, 119, 123, 133, 148, 163, 182, 202, 221,
            246, 273, 297, 329, 330, 364, 396, 436, 437, 520, 572, 617, 676,
            740, 796, 869, 948, 1017, 1106, 1202, 1286, 1394, 1510, 1611,
            1742, 1882, 2004, 2161, 2476, 3040, 3714, 4520, 5478, 6628, 8019,
            9702, 11739, 14204, 17186, 20795, 25161, 30444, 36837, 38100
        ]
        bind_subs[subbatch_size_symbol] = 128
    elif domain == 'charlm':
        hidden_dims = [
            1, 2, 3, 4, 5, 6, 7, 9, 10, 12, 14, 18, 20, 25, 28, 35, 40, 50,
            56, 69, 78, 86, 96, 108, 119, 123, 133, 148, 163, 182, 202, 221,
            246, 273, 297, 329, 330, 364, 396, 436, 437, 520, 572, 617, 676,
            740, 796, 869, 948, 1017, 1106, 1202, 1286, 1394, 1510, 1611,
            1742, 1882, 2004, 2161, 2476, 3040, 3714, 5051, 6869, 9341,
            12703, 17276, 23495, 31953, 43456, 59100, 80376, 81400
        ]
        bind_subs[subbatch_size_symbol] = 96
    elif domain == 'nmt':
        hidden_dims = [
            32, 64, 96, 128, 192, 256, 384, 512, 768, 1024, 1280, 1536, 2048,
            2560, 3072, 3747, 4571, 5576, 6802, 8298, 10123, 12350, 15067,
            18381, 22350
        ]
        bind_subs[subbatch_size_symbol] = 96
        bind_subs[sequence_length_symbol] = 26
    else:
        raise NotImplementedError('ERROR: Unknown domain: {}'.format(domain))

    bind_subs.pop(hidden_dim_symbol)
    resolved_params = parameters.subs(bind_subs)

    print('Symbol associations: {}\n'.format(bind_subs))

    print('Algorithmic Flops by hidden dimension, params, and per-batch-sample:')
    resolved_flops = alg_flops.subs(bind_subs)
    for hid_dim in hidden_dims:
        graph_params = resolved_params.subs({hidden_dim_symbol: hid_dim})
        graph_flops = resolved_flops.subs({hidden_dim_symbol: hid_dim})
        graph_flops_per_sample = float(graph_flops) / \
            bind_subs[subbatch_size_symbol]
        print('{}\t{}\t{}\t{}'.format(hid_dim, graph_params, graph_flops,
                                      int(graph_flops_per_sample)))

    print('\nAlgorithmic bytes accessed by hidden dimension, params:')
    resolved_bytes = alg_bytes.subs(bind_subs)
    for hid_dim in hidden_dims:
        graph_params = resolved_params.subs({hidden_dim_symbol: hid_dim})
        graph_bytes = resolved_bytes.subs({hidden_dim_symbol: hid_dim})
        print('{}\t{}\t{}'.format(hid_dim, graph_params, graph_bytes))

    print('\nAlgorithmic total memory footprint by hidden dimension, params:')
    resolved_footprint = alg_footprint.subs(bind_subs)
    for hid_dim in hidden_dims:
        graph_params = resolved_params.subs({hidden_dim_symbol: hid_dim})
        graph_footprint = resolved_footprint.subs({hidden_dim_symbol: hid_dim})
        print('{}\t{}\t{}'.format(hid_dim, graph_params, graph_footprint))

    print('\nAlgorithmic minimal memory footprint by hidden dimension, params:')
    full_subs = dict(bind_subs)
    for hid_dim in hidden_dims:
        graph_params = resolved_params.subs({hidden_dim_symbol: hid_dim})
        full_subs[hidden_dim_symbol] = hid_dim
        graph_min_foot = graph.calcMinimalFootprint(symbol_subs=full_subs)
        print('{}\t{}\t{}'.format(hid_dim, graph_params, graph_min_foot))
    if False:
        if domain == 'wordlm':
            if args.build_projection:
                # This is hacky anyway... Import required parts here:
                # NOTE: `from ... import *` is a syntax error inside a
                # function, so import the modules wholesale instead
                import catamount.ops.optimizer_ops
                import catamount.tensors.tensor

                projection_dim_symbol = utils.getIntSymbolFromString(
                    'projection_dim')
                # (1) Project output of the second recurrent layer. Save the
                # consumers of the output to send the projected values there
                proj_in_op = graph.opsByName['Model/Collapse/Reshape']
                proj_input = proj_in_op._outputs[0]
                proj_input_consumers = proj_input._consumers
                proj_input._consumers = {}
                # (1a) Create projection matrix
                proj_weights = catamount.variable(
                    'Model/Collapse/projection/W',
                    [hidden_dim_symbol, projection_dim_symbol], graph)
                # (1b) Create matrix multiply for projection
                proj_mm_out = catamount.matmul(
                    'Model/Collapse/projection/MatMul',
                    [None, projection_dim_symbol], proj_input, proj_weights,
                    graph)
                # (2) Feed projection to output consumers
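
# Illustrative sketch (not part of the original test suite): all of the
# domain branches above reduce to the same pattern. Catamount hands back a
# sympy expression, so binding dimensions and checking the result is plain
# sympy substitution. The symbols and the example values (1024, 10000) here
# are local stand-ins, not the Catamount-generated symbols.
def _sketch_bind_symbolic_params():
    import sympy
    hidden_dim, vocab_size = sympy.symbols('hidden_dim vocab_size',
                                           positive=True, integer=True)
    # Same closed form the wordlm branch asserts above
    params = 16 * hidden_dim**2 + 2 * hidden_dim * vocab_size + \
        8 * hidden_dim + vocab_size + 2
    # Binding every free symbol must leave a plain integer
    resolved = params.subs({hidden_dim: 1024, vocab_size: 10000})
    assert resolved.is_Integer
    return int(resolved)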
def run_tf_image_resnet(depth, filter_scale=1.0):
    global is_pytest_run
    model_string = '_d{}_fs{}_'.format(depth, filter_scale)
    test_outputs_dir = 'catamount/frameworks/example_graphs/tensorflow/full_models/image_classification'
    graph_meta = None
    for root, dirs, files in os.walk(test_outputs_dir):
        for filename in files:
            if 'graph{}'.format(model_string) in filename and \
               '.meta' in filename:
                # Take the first graph that we find in the directory
                graph_meta = os.path.join(root, filename)
                break
        if graph_meta is not None:
            break

    if graph_meta is None:
        raise FileNotFoundError(
            'Unable to find model string {} in directory {}'.format(
                model_string, test_outputs_dir))

    graph = catamount.frameworks.tensorflow.import_graph(graph_meta)
    assert graph.isValid()

    # ============ TO REMOVE INITIALIZATION OPS! =============
    # NOTE: This code is pretty general and is likely to be migrated into
    # Catamount code for removing TF-specific initialization ops
    from catamount.ops import AssignOp
    from catamount.ops import VariableOp

    assign_ops = set()
    for op in graph.opsByName.values():
        if isinstance(op, AssignOp):
            assign_ops.add(op)
    for assign_op in assign_ops:
        my_ancestors = set()
        my_frontier = set()
        my_frontier.add(assign_op)
        while len(my_frontier) > 0:
            next_op = my_frontier.pop()
            for in_tensor in next_op.inputs:
                if not isinstance(in_tensor.producer, VariableOp):
                    my_frontier.add(in_tensor.producer)
            my_ancestors.add(next_op)
        for next_op in my_ancestors:
            graph.removeOp(next_op)
    assert graph.isValid()

    # Manually remove the inference parts of graph
    graph_ops = list(graph._ops_by_name.values())
    for op in graph_ops:
        # Certain ops are only used for inference
        if 'InferenceTower/' in op.name or \
           'InferenceRunner/' in op.name or \
           op.name == 'MergeAllSummariesRunWithOp/Merge/MergeSummary':
            graph.removeOp(op)
    assert graph.isValid()

    print('Initial graph:\n{}\n'.format(graph))
    init_params = graph.calcModelParameters()
    print('Initial parameters: {}'.format(init_params))
    print('Initial Flops: {}\n'.format(graph.calcAlgFlops()))

    print('Placeholders:')
    for op in graph.getPlaceholders():
        print(op.debugString())
    print('')

    # Set up symbols to name dimensions
    output_classes_symbol = utils.getPositiveIntSymbolFromString('out_classes')
    subbatch_size_symbol = utils.getPositiveIntSymbolFromString(
        'subbatch_size')
    image_height_symbol = utils.getPositiveIntSymbolFromString('image_height')
    image_width_symbol = utils.getPositiveIntSymbolFromString('image_width')
    num_in_channels_symbol = utils.getPositiveIntSymbolFromString(
        'num_in_channels')
    graph_iters_symbol = utils.getIntSymbolFromString('graph::iters')
    feature_channels_symbol = utils.getPositiveIntSymbolFromString(
        'feature_channels')

    # Find and replace convolution/pooling dimensions also:
    # Dimension(64 * 2^k): conv/pool feature channels
    base_output_classes = 1000
    base_num_in_channels = 3
    base_feature_channels = 64

    base_image_height = 224
    base_image_width = 224
    base_half_im_height = 112
    half_im_height_symbol = image_height_symbol // 2
    base_half_im_width = 112
    half_im_width_symbol = image_width_symbol // 2
    base_quart_im_height = 56
    quart_im_height_symbol = (image_height_symbol // 2) // 2
    base_quart_im_width = 56
    quart_im_width_symbol = (image_width_symbol // 2) // 2
    base_eighth_im_height = 28
    eighth_im_height_symbol = ((image_height_symbol // 2) // 2) // 2
    base_eighth_im_width = 28
    eighth_im_width_symbol = ((image_width_symbol // 2) // 2) // 2
    base_sixtnth_im_height = 14
    sixtnth_im_height_symbol = (((image_height_symbol // 2) // 2) // 2) // 2
    base_sixtnth_im_width = 14
    sixtnth_im_width_symbol = (((image_width_symbol // 2) // 2) // 2) // 2
    base_small_im_height = 7
    small_im_height_symbol = \
        ((((image_height_symbol // 2) // 2) // 2) // 2) // 2
    base_small_im_width = 7
    small_im_width_symbol = \
        ((((image_width_symbol // 2) // 2) // 2) // 2) // 2

    # Set up a dictionary of placeholders and variables for which we want
    # to make dimensions symbolic. Sift out their dimensions
    bind_dict = {
        # Placeholders
        'label': [subbatch_size_symbol],
        'input': [subbatch_size_symbol, image_height_symbol,
                  image_width_symbol, num_in_channels_symbol],
    }
    # Parameterize all variable tensor dimensions
    for op in graph._ops_by_name.values():
        if isinstance(op, VariableOp):
            op_name_suffix = op.name.split('/')[-1]
            if op_name_suffix == 'W':
                if op._outputs[0].shape.rank == 4:
                    assert 'conv' in op.name
                    new_shape = []
                    for i in range(op._outputs[0].shape.rank):
                        new_shape.append(
                            op._outputs[0].shape.getDimension(i).value)
                    if new_shape[2] % base_feature_channels == 0:
                        in_filters = (new_shape[2] //
                                      base_feature_channels) * \
                            feature_channels_symbol
                    elif new_shape[2] == 3:
                        # This is the first convolution on image channels (3)
                        assert op.name == 'conv0/W'
                        in_filters = num_in_channels_symbol
                    else:
                        print('FIX ME: base in filters {}'.format(
                            new_shape[2]))
                        assert 0
                    if new_shape[3] % base_feature_channels == 0:
                        out_filters = (new_shape[3] //
                                       base_feature_channels) * \
                            feature_channels_symbol
                    else:
                        print('FIX ME: base out filters {}'.format(
                            new_shape[3]))
                        assert 0
                    new_shape[2] = in_filters
                    new_shape[3] = out_filters
                else:
                    # This is the output layer with output_classes dimension
                    assert op.name == 'linear/W'
                    assert op._outputs[0].shape.rank == 2
                    in_dim = op._outputs[0].shape.getDimension(0).value
                    assert in_dim % base_feature_channels == 0
                    in_dim = (in_dim // base_feature_channels) * \
                        feature_channels_symbol
                    new_shape = [in_dim, output_classes_symbol]
                bind_dict[op.name] = new_shape
                momentum_op_name = '{}/Momentum'.format(op.name)
                momentum_op = graph._ops_by_name[momentum_op_name]
                bind_dict[momentum_op.name] = new_shape
            elif op_name_suffix == 'b':
                # This is the output layer with output_classes dimension
                assert op.name == 'linear/b'
                assert op._outputs[0].shape.rank == 1
                assert op._outputs[0].shape.getDimension(0).value == \
                    base_output_classes
                new_shape = [output_classes_symbol]
                bind_dict[op.name] = new_shape
                momentum_op_name = '{}/Momentum'.format(op.name)
                momentum_op = graph._ops_by_name[momentum_op_name]
                bind_dict[momentum_op.name] = new_shape
            elif op_name_suffix == 'beta' or op_name_suffix == 'gamma' or \
                 op_name_suffix == 'EMA':
                assert op._outputs[0].shape.rank == 1
                in_dim = op._outputs[0].shape.getDimension(0).value
                assert in_dim % base_feature_channels == 0
                in_dim = (in_dim // base_feature_channels) * \
                    feature_channels_symbol
                new_shape = [in_dim]
                bind_dict[op.name] = new_shape
                if op_name_suffix != 'EMA':
                    momentum_op_name = '{}/Momentum'.format(op.name)
                    momentum_op = graph._ops_by_name[momentum_op_name]
                    bind_dict[momentum_op.name] = new_shape

    # Now handle constant values in the graph
    const_dict = {}
    for op in graph._ops_by_name.values():
        if isinstance(op, ConstantOp):
            if op._outputs[0].value is None:
                continue
            if op._outputs[0].shape.rank == 0:
                print('{}'.format(op.debugString()))
                continue
            assert op._outputs[0].shape.rank == 1
            values = op._outputs[0].value.tolist()
            new_values = []
            changed = False
            for value in values:
                if value > 0 and value % base_feature_channels == 0:
                    value = (value // base_feature_channels) * \
                        feature_channels_symbol
                    changed = True
                new_values.append(value)
            # HACKY SPECIAL CASE:
            if op.name == 'tower0/gradients/tower0/conv0/Conv2D_grad/Const':
                assert new_values[2] == base_num_in_channels
                new_values[2] = num_in_channels_symbol
            if changed:
                const_dict[op.name] = new_values
    for const_key, const_val in const_dict.items():
        try:
            if const_key in graph.opsByName.keys():
                const_op = graph.opsByName[const_key]
                assert isinstance(const_op, ConstantOp)
                const_op._outputs[0].setValue(const_val)
            else:
                print('WARN: ConstantOp not found: {}'.format(const_key))
        except Exception as exc:
            print('WARN: ConstantOp unknown problem: {}: {}'.format(
                const_key, exc))

    # TODO (Joel): Add InputQueue ops to avoid manually setting dimensions
    in_deque_op = graph.opsByName['QueueInput/input_deque']
    # print(in_deque_op.debugString())
    out_tensor = in_deque_op._outputs[0]
    for idx, sym in enumerate([subbatch_size_symbol, image_height_symbol,
                               image_width_symbol, num_in_channels_symbol]):
        out_tensor.shape.setDimension(idx, sym)
        out_tensor.shape.dims[idx]._value = None
    out_tensor = in_deque_op._outputs[1]
    out_tensor.shape.setDimension(0, subbatch_size_symbol)

    graph.bindTensorShapeDimensions(bind_dict,
                                    warn_if_ill_defined=(not is_pytest_run),
                                    make_symbolic=True)

    # Nice little hack to actually propagate MaximumOp values to outputs
    for op in graph._ops_by_name.values():
        if isinstance(op, MaximumOp):
            if op._inputs[0].value is not None and \
               op._inputs[1].value is not None:
                vmax = np.vectorize(lambda x, y: sympy.Max(x, y))
                out_val = vmax(op._inputs[0].value, op._inputs[1].value)
                op._outputs[0].setValue(out_val)

    graph.bindTensorShapeDimensions(bind_dict,
                                    warn_if_ill_defined=(not is_pytest_run),
                                    make_symbolic=True)

    print('Bound values')
    print(graph)

    bind_subs = {
        graph_iters_symbol: 1,
        output_classes_symbol: base_output_classes,
        subbatch_size_symbol: 32,
        image_height_symbol: base_image_height,
        image_width_symbol: base_image_width,
        num_in_channels_symbol: base_num_in_channels,
        feature_channels_symbol: base_feature_channels,
    }

    correct_params = -1
    correct_flops = -1
    correct_bytes = -1
    correct_footprint = -1
    if depth == 18:
        correct_params = 11689514
        correct_flops = 349684163360
        correct_bytes = 7186222676
        correct_footprint = 2802084304
    elif depth == 34:
        correct_params = 21797674
        correct_flops = 705506994208
        correct_bytes = 11162578644
        correct_footprint = 4368689744
    elif depth == 50:
        correct_params = 25557034
        correct_flops = 790954958112
        correct_bytes = 32896462028
        correct_footprint = 12909734408
    elif depth == 101:
        correct_params = 44549162
        correct_flops = 1506507229472
        correct_bytes = 50026672916
        correct_footprint = 19690293072
    elif depth == 152:
        correct_params = 60192810
        correct_flops = 2222688328992
        correct_bytes = 70967716188
        correct_footprint = 27971880088
    else:
        print('WARN: Tests not defined for depth {}'.format(depth))

    # Calculate parameters
    # NOTE: Need to remove Momentum optimizer parameters and moving
    # average values
    momentum_params = 0
    parameters = 0
    for op_name in sorted(graph.opsByName.keys()):
        op = graph.opsByName[op_name]
        if isinstance(op, VariableOp):
            if "Momentum" in op.name or "EMA" in op.name:
                momentum_params += op.calcModelParameters()
            else:
                parameters += op.calcModelParameters()
    all_weights = graph.calcModelParameters()
    assert (all_weights - momentum_params - parameters) == 0

    # Calculate model parameter count
    resolved_params = parameters.subs(bind_subs)
    try:
        resolved_params = int(resolved_params)
    except Exception:
        print('ERROR: resolved_params should be int, but is {} = {}'.format(
            type(resolved_params), resolved_params))
    assert correct_params < 0 or resolved_params == correct_params, \
        'Incorrect model params: {}'.format(resolved_params)
    print('Parameters: {}\nWith specified dims: {}\n'.format(
        parameters, resolved_params))

    # Calculate algorithmic Flops
    alg_flops = graph.calcAlgFlops()
    resolved_flops = alg_flops.subs(bind_subs)
    try:
        resolved_flops = int(resolved_flops)
    except Exception:
        print('ERROR: resolved_flops should be int, but is {} = {}'.format(
            type(resolved_flops), resolved_flops))
    assert correct_flops < 0 or resolved_flops == correct_flops, \
        'Incorrect algorithmic flops: {}'.format(resolved_flops)
    print('Algorithmic Flops: {}\nWith specified dims: {}\n'.format(
        alg_flops, resolved_flops))

    # Calculate algorithmic Bytes accessed
    alg_bytes = graph.calcAlgBytes()
    resolved_bytes = alg_bytes.subs(bind_subs)
    try:
        resolved_bytes = int(resolved_bytes)
    except Exception:
        print('ERROR: resolved_bytes should be int, but is {} = {}'.format(
            type(resolved_bytes), resolved_bytes))
    assert correct_bytes < 0 or resolved_bytes == correct_bytes, \
        'Incorrect algorithmic bytes: {}'.format(resolved_bytes)
    print('Alg bytes accessed: {}\nWith specified dims: {}\n'.format(
        alg_bytes, resolved_bytes))

    # Calculate total memory footprint
    alg_footprint = graph.calcAlgFootprint()
    resolved_footprint = alg_footprint.subs(bind_subs)
    try:
        resolved_footprint = int(resolved_footprint)
    except Exception:
        print('ERROR: resolved_footprint should be int, but is {} = {}'.format(
            type(resolved_footprint), resolved_footprint))
    assert correct_footprint < 0 or resolved_footprint == correct_footprint, \
        'Incorrect algorithmic footprint: {}'.format(resolved_footprint)
    print('Alg mem footprint: {}\nWith specified dims: {}\n'.format(
        alg_footprint, resolved_footprint))

    # Calculate algorithmic IO per step
    total_io_footprint = 0
    for op in graph.getPlaceholders():
        total_io_footprint += op.calcAlgFootprint()
    resolved_io_footprint = total_io_footprint.subs(bind_subs)
    print('Alg IO footprint: {}\nWith specified dims: {}\n'.format(
        total_io_footprint, resolved_io_footprint))

    try:
        # In case the footprint code is not complete
        # Calculate minimal memory footprint
        print('Alg min mem footprint {}'.format(
            graph.calcMinimalFootprint(symbol_subs=bind_subs)))
    except Exception:
        pass

    if not is_pytest_run:
        print('VERBOSE ALGORITHMIC FLOPS:')
        graph.calcAlgFlops(verbose=True)
        print('')

        print('VERBOSE ALGORITHMIC BYTES:')
        graph.calcAlgBytes(verbose=True)
        print('')

        print('VERBOSE ALGORITHMIC FOOTPRINT:')
        graph.calcAlgFootprint(verbose=True)
        print('')

    # HACKY WAY TO SAVE MODELS FOR NOW!
    pickle.dump(graph, open(
        'catamount/frameworks/example_graphs/tensorflow/full_models/image_classification/graph_image_resnet_d{}_fs{}.p'
        .format(depth, filter_scale), 'wb'))

    if is_pytest_run:
        return

    print('\n\n======= Algorithmic graph-level analytics: =======')

    feature_channel_dims = [32, 48, 64, 96, 128]

    bind_subs.pop(feature_channels_symbol)
    resolved_params = parameters.subs(bind_subs)

    print('Symbol associations: {}\n'.format(bind_subs))

    print('Algorithmic Flops by feature channels, params, and per-batch-sample:')
    resolved_flops = alg_flops.subs(bind_subs)
    for features_dim in feature_channel_dims:
        graph_params = resolved_params.subs(
            {feature_channels_symbol: features_dim})
        graph_flops = resolved_flops.subs(
            {feature_channels_symbol: features_dim})
        graph_flops_per_sample = float(graph_flops) / \
            bind_subs[subbatch_size_symbol]
        print('{}\t{}\t{}\t{}'.format(features_dim, graph_params, graph_flops,
                                      int(graph_flops_per_sample)))

    print('\nAlgorithmic bytes accessed by feature channels, params:')
    resolved_bytes = alg_bytes.subs(bind_subs)
    for features_dim in feature_channel_dims:
        graph_params = resolved_params.subs(
            {feature_channels_symbol: features_dim})
        graph_bytes = resolved_bytes.subs(
            {feature_channels_symbol: features_dim})
        print('{}\t{}\t{}'.format(features_dim, graph_params, graph_bytes))

    print('\nAlgorithmic total memory footprint by feature channels, params:')
    resolved_footprint = alg_footprint.subs(bind_subs)
    for features_dim in feature_channel_dims:
        graph_params = resolved_params.subs(
            {feature_channels_symbol: features_dim})
        graph_footprint = resolved_footprint.subs(
            {feature_channels_symbol: features_dim})
        print('{}\t{}\t{}'.format(features_dim, graph_params, graph_footprint))

    print('\nAlgorithmic minimal memory footprint by feature channels, params:')
    full_subs = dict(bind_subs)
    for features_dim in feature_channel_dims:
        graph_params = resolved_params.subs(
            {feature_channels_symbol: features_dim})
        full_subs[feature_channels_symbol] = features_dim
        graph_min_foot = graph.calcMinimalFootprint(symbol_subs=full_subs)
        print('{}\t{}\t{}'.format(features_dim, graph_params, graph_min_foot))
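
# Illustrative sketch (not part of the original test suite): how the resnet
# parameterization above rewrites concrete filter counts as multiples of the
# base feature width (64). A shape like [3, 3, 128, 256] becomes
# [3, 3, 2*feature_channels, 4*feature_channels]; dims that are not positive
# multiples of the base (spatial 3s, RGB channels) are left alone.
# `feature_channels` is a local stand-in symbol, not the Catamount one.
def _sketch_parameterize_filters(shape, base=64):
    import sympy
    fc = sympy.Symbol('feature_channels', positive=True, integer=True)
    return [(d // base) * fc if d > 0 and d % base == 0 else d
            for d in shape]
# e.g. _sketch_parameterize_filters([3, 3, 128, 256])
#   -> [3, 3, 2*feature_channels, 4*feature_channels]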
def test_tf_dynamic_rnn():
    graph = catamount.frameworks.tensorflow.import_graph(tf_example_filename)
    print('INITIAL GRAPH: {}\n\n'.format(graph))
    assert graph.isValid()
    algorithmic_flops = graph.calcAlgFlops()

    # Symbols we will use
    batch_size = utils.getIntSymbolFromString('batch_size')
    seq_length = utils.getIntSymbolFromString('seq_length')
    hidden_dim = utils.getIntSymbolFromString('hidden_dim')
    graph_iters = utils.getIntSymbolFromString('graph::iters')
    rwb_iters = utils.getIntSymbolFromString('rnn/while/LoopCond_block::iters')
    a_0 = utils.getIntSymbolFromString(
        'rnn/while/basic_lstm_cell/Add:0::dim_0')
    a1_0 = utils.getIntSymbolFromString(
        'rnn/while/basic_lstm_cell/Add_1:0::dim_0')
    ba_0 = utils.getIntSymbolFromString(
        'rnn/while/basic_lstm_cell/BiasAdd:0::dim_0')
    mm_0 = utils.getIntSymbolFromString(
        'rnn/while/basic_lstm_cell/MatMul:0::dim_0')
    m_0 = utils.getIntSymbolFromString(
        'rnn/while/basic_lstm_cell/Mul:0::dim_0')
    m1_0 = utils.getIntSymbolFromString(
        'rnn/while/basic_lstm_cell/Mul_1:0::dim_0')
    m2_0 = utils.getIntSymbolFromString(
        'rnn/while/basic_lstm_cell/Mul_2:0::dim_0')
    s_0 = utils.getIntSymbolFromString(
        'rnn/while/basic_lstm_cell/Sigmoid:0::dim_0')
    s1_0 = utils.getIntSymbolFromString(
        'rnn/while/basic_lstm_cell/Sigmoid_1:0::dim_0')
    s2_0 = utils.getIntSymbolFromString(
        'rnn/while/basic_lstm_cell/Sigmoid_2:0::dim_0')
    sm_r_0 = utils.getIntSymbolFromString('softmax/Reshape:0::dim_0')
    sm_r_1 = utils.getIntSymbolFromString('softmax/Reshape:0::dim_1')
    th_0 = utils.getIntSymbolFromString(
        'rnn/while/basic_lstm_cell/Tanh:0::dim_0')
    th1_0 = utils.getIntSymbolFromString(
        'rnn/while/basic_lstm_cell/Tanh_1:0::dim_0')

    forward_flops = rwb_iters * (hidden_dim * a_0 + hidden_dim * a1_0 +
                                 4 * hidden_dim * ba_0 +
                                 16 * hidden_dim**2 * mm_0 +
                                 hidden_dim * m_0 + hidden_dim * m1_0 +
                                 hidden_dim * m2_0 + 4 * hidden_dim * s_0 +
                                 4 * hidden_dim * s1_0 +
                                 4 * hidden_dim * s2_0 +
                                 6 * hidden_dim * th_0 +
                                 6 * hidden_dim * th1_0 + 6) + \
        3 * sm_r_0 * sm_r_1 + \
        16 * hidden_dim**2 + 8 * hidden_dim + 3

    grad_iters = utils.getIntSymbolFromString(
        'Gradient/Compute/gradients/b_count_2_block::iters')
    g_an_0 = utils.getIntSymbolFromString(
        'Gradient/Compute/gradients/AddN:0::dim_0')
    g_an1_0 = utils.getIntSymbolFromString(
        'Gradient/Compute/gradients/AddN_1:0::dim_0')
    g_mm_0 = utils.getIntSymbolFromString(
        'Gradient/Compute/gradients/rnn/while/basic_lstm_cell/MatMul_grad/MatMul:0::dim_0')
    g_mms_0 = utils.getIntSymbolFromString(
        'Gradient/Compute/gradients/rnn/while/basic_lstm_cell/MatMul_grad/MatMul_1/StackPopV2:0::dim_0')
    g_ms_0 = utils.getIntSymbolFromString(
        'Gradient/Compute/gradients/rnn/while/basic_lstm_cell/Mul_1_grad/mul/StackPopV2:0::dim_0')
    g_m_0 = utils.getIntSymbolFromString(
        'Gradient/Compute/gradients/rnn/while/basic_lstm_cell/Mul_1_grad/mul:0::dim_0')
    g_ms1_0 = utils.getIntSymbolFromString(
        'Gradient/Compute/gradients/rnn/while/basic_lstm_cell/Mul_1_grad/mul_1/StackPopV2:0::dim_0')
    g_m1_0 = utils.getIntSymbolFromString(
        'Gradient/Compute/gradients/rnn/while/basic_lstm_cell/Mul_1_grad/mul_1:0::dim_0')
    g_ms2_0 = utils.getIntSymbolFromString(
        'Gradient/Compute/gradients/rnn/while/basic_lstm_cell/Mul_2_grad/mul/StackPopV2:0::dim_0')
    g_m2_0 = utils.getIntSymbolFromString(
        'Gradient/Compute/gradients/rnn/while/basic_lstm_cell/Mul_2_grad/mul:0::dim_0')
    g_ms21_0 = utils.getIntSymbolFromString(
        'Gradient/Compute/gradients/rnn/while/basic_lstm_cell/Mul_2_grad/mul_1/StackPopV2:0::dim_0')
    g_m21_0 = utils.getIntSymbolFromString(
        'Gradient/Compute/gradients/rnn/while/basic_lstm_cell/Mul_2_grad/mul_1:0::dim_0')
    g_mus_0 = utils.getIntSymbolFromString(
        'Gradient/Compute/gradients/rnn/while/basic_lstm_cell/Mul_grad/mul/StackPopV2:0::dim_0')
    g_mu_0 = utils.getIntSymbolFromString(
        'Gradient/Compute/gradients/rnn/while/basic_lstm_cell/Mul_grad/mul:0::dim_0')
    g_mu1_0 = utils.getIntSymbolFromString(
        'Gradient/Compute/gradients/rnn/while/basic_lstm_cell/Mul_grad/mul_1:0::dim_0')
    g_s_0 = utils.getIntSymbolFromString(
        'Gradient/Compute/gradients/rnn/while/basic_lstm_cell/Sigmoid_grad/SigmoidGrad:0::dim_0')
    g_c_0 = utils.getIntSymbolFromString(
        'Gradient/Compute/gradients/rnn/while/basic_lstm_cell/split_grad/concat:0::dim_0')
    g_sm_m_0 = utils.getIntSymbolFromString(
        'Gradient/Compute/gradients/softmax/softmax_grad/mul:0::dim_0')
    g_sm_m_1 = utils.getIntSymbolFromString(
        'Gradient/Compute/gradients/softmax/softmax_grad/mul:0::dim_1')

    backward_flops = grad_iters * (hidden_dim * g_an_0 +
                                   3 * hidden_dim * g_an1_0 +
                                   16 * hidden_dim**2 * g_mm_0 +
                                   16 * hidden_dim**2 * g_mms_0 +
                                   3 * hidden_dim * g_ms_0 +
                                   2 * hidden_dim * g_m_0 +
                                   3 * hidden_dim * g_ms1_0 +
                                   2 * hidden_dim * g_m1_0 +
                                   3 * hidden_dim * g_ms2_0 +
                                   2 * hidden_dim * g_m2_0 +
                                   3 * hidden_dim * g_ms21_0 +
                                   2 * hidden_dim * g_m21_0 +
                                   3 * hidden_dim * g_mus_0 +
                                   2 * hidden_dim * g_mu_0 +
                                   2 * hidden_dim * g_mu1_0 +
                                   2 * hidden_dim * g_s_0 +
                                   4 * hidden_dim * g_c_0 +
                                   8 * hidden_dim**2 + 4 * hidden_dim + 2) + \
        g_sm_m_0 * g_sm_m_1

    general_correct_alg_flops = forward_flops + backward_flops
    correct_alg_flops = general_correct_alg_flops.subs({hidden_dim: 24})

    # Now, bind tensor names in the graph and verify that the algorithmic
    # Flop counts reflect the new name bindings
    print('Loaded Flops test:')
    print(' Catamount: {}'.format(algorithmic_flops))
    print(' Correct: {}'.format(correct_alg_flops))
    assert sympy.simplify(algorithmic_flops - correct_alg_flops) == 0, \
        'Initial alg flops incorrect!\n Expecting: {}\n Calculated: {}' \
        .format(correct_alg_flops, algorithmic_flops)

    # Manually set some variables
    # TODO (Joel): Fix this up when all tensor arrays work!
    ta_op = graph.opsByName['rnn/TensorArrayStack/TensorArraySizeV3']
    ta_op._outputs[0].setValue(seq_length)
    ta_op = graph.opsByName['rnn/TensorArrayStack/TensorArrayGatherV3']
    ta_op._outputs[0].mergeShape([seq_length, batch_size, hidden_dim],
                                 make_symbolic=True)
    ta_op = graph.opsByName['rnn/while/TensorArrayReadV3']
    ta_op._outputs[0].mergeShape([batch_size, hidden_dim],
                                 make_symbolic=True)
    ta_op = graph.opsByName[
        'Gradient/Compute/gradients/rnn/while/TensorArrayWrite/TensorArrayWriteV3_grad/TensorArrayReadV3']
    ta_op._outputs[0].mergeShape([batch_size, hidden_dim],
                                 make_symbolic=True)

    # Bind constant values first
    const_dict = {
        # Store the shapes of certain tensors as constants
        'rnn/Const': [hidden_dim],
        'rnn/Const_1': [hidden_dim],
    }
    graph.bindConstantValues(const_dict)

    # NOTE: This also works: batch_size = 'batch_size'
    # Bind placeholders (a and b) output dimensions 0 to name batch_size
    bind_dict = {
        # Variables
        'rnn/basic_lstm_cell/kernel': [2 * hidden_dim, 4 * hidden_dim],
        'rnn/basic_lstm_cell/bias': [4 * hidden_dim],
        # Placeholders
        'a': [batch_size, seq_length, hidden_dim],
        'c_init_state': [batch_size, hidden_dim],
        'h_init_state': [batch_size, hidden_dim],
        'out_correct': [batch_size, seq_length],
        # Constants
        'Gradient/Compute/gradients/rnn/while/basic_lstm_cell/MatMul/Enter_grad/b_acc': [2 * hidden_dim, 4 * hidden_dim],
        'Gradient/Compute/gradients/rnn/while/basic_lstm_cell/BiasAdd/Enter_grad/b_acc': [4 * hidden_dim],
    }
    graph.bindShapesAndPropagate(bind_dict, make_symbolic=True,
                                 warn_if_ill_defined=True)
    algorithmic_flops = graph.calcAlgFlops()

    # Update the algorithmic Flops formula
    # Sub the forward prop values
    correct_alg_flops = general_correct_alg_flops.subs({
        ba_0: batch_size,
        a_0: batch_size,
        a1_0: batch_size,
        mm_0: batch_size,
        m_0: batch_size,
        m1_0: batch_size,
        m2_0: batch_size,
        s_0: batch_size,
        s1_0: batch_size,
        s2_0: batch_size,
        th_0: batch_size,
        th1_0: batch_size,
        sm_r_0: batch_size * seq_length,
        sm_r_1: hidden_dim,
    })
    # Sub the backward prop values
    # TODO (Joel): Fix this up when all backprop works!
    correct_alg_flops = correct_alg_flops.subs({
        g_mm_0: batch_size,
        g_mms_0: batch_size,
        g_ms_0: batch_size,
        g_m_0: batch_size,
        g_ms1_0: batch_size,
        g_m1_0: batch_size,
        g_ms2_0: batch_size,
        g_m2_0: batch_size,
        g_ms21_0: batch_size,
        g_m21_0: batch_size,
        g_mus_0: batch_size,
        g_mu_0: batch_size,
        g_mu1_0: batch_size,
        g_s_0: batch_size,
        g_c_0: batch_size,
        g_an_0: batch_size,
        g_an1_0: batch_size,
        g_sm_m_0: batch_size * seq_length,
        g_sm_m_1: hidden_dim,
    })

    assert graph.isValid()

    print('BOUND GRAPH:\n{}\n\n'.format(graph))

    print('Bound Flops test:')
    print(' Catamount: {}'.format(algorithmic_flops))
    print(' Correct: {}'.format(correct_alg_flops))
    assert sympy.simplify(algorithmic_flops - correct_alg_flops) == 0, \
        'Bound alg flops incorrect!\n Expecting: {}\n Calculated: {}\n Difference: {}' \
        .format(correct_alg_flops, algorithmic_flops,
                algorithmic_flops - correct_alg_flops)

    bind_subs = {
        # Symbols/names to bind in next tests:
        graph_iters: 1,
        rwb_iters: seq_length,
        grad_iters: seq_length,
        batch_size: 128,
        seq_length: 32,
        hidden_dim: 256,
    }

    print('\n\nBound values')
    print('Symbol associations: {}\n'.format(bind_subs))

    # Calculate model parameter count
    parameters = graph.calcModelParameters()
    resolved_params = parameters.subs(bind_subs)
    try:
        resolved_params = int(resolved_params)
    except Exception:
        print('ERROR: resolved_params should be int, but is {} = {}'.format(
            type(resolved_params), resolved_params))
    correct_params = 525312
    assert resolved_params == correct_params, \
        'Incorrect model params: {}'.format(resolved_params)
    print('Parameters: {}\nWith specified dims: {}\n'.format(
        parameters, resolved_params))

    # Calculate algorithmic Flops
    alg_flops = graph.calcAlgFlops()
    resolved_flops = alg_flops.subs(bind_subs)
    try:
        resolved_flops = int(resolved_flops)
    except Exception:
        print('ERROR: resolved_flops should be int, but is {} = {}'.format(
            type(resolved_flops), resolved_flops))
    correct_flops = 12980357379
    assert resolved_flops == correct_flops, \
        'Incorrect algorithmic flops: {}'.format(resolved_flops)
    print('Algorithmic Flops: {}\nWith specified dims: {}\n'.format(
        alg_flops, resolved_flops))

    # Calculate algorithmic Bytes accessed
    alg_bytes = graph.calcAlgBytes()
    resolved_bytes = alg_bytes.subs(bind_subs)
    try:
        resolved_bytes = int(resolved_bytes)
    except Exception:
        print('ERROR: resolved_bytes should be int, but is {} = {}'.format(
            type(resolved_bytes), resolved_bytes))
    correct_bytes = 1134017252
    assert resolved_bytes == correct_bytes, \
        'Incorrect algorithmic bytes: {}'.format(resolved_bytes)
    print('Alg bytes accessed: {}\nWith specified dims: {}\n'.format(
        alg_bytes, resolved_bytes))

    # Calculate total memory footprint
    alg_footprint = graph.calcAlgFootprint()
    resolved_footprint = alg_footprint.subs(bind_subs)
    try:
        resolved_footprint = int(resolved_footprint)
    except Exception:
        print('ERROR: resolved_footprint should be int, but is {} = {}'.format(
            type(resolved_footprint), resolved_footprint))
    correct_footprint = 441447784
    assert resolved_footprint == correct_footprint, \
        'Incorrect algorithmic footprint: {}'.format(resolved_footprint)
    print('Alg mem footprint: {}\nWith specified dims: {}\n'.format(
        alg_footprint, resolved_footprint))

    # Calculate the minimal memory footprint for a step
    alg_min_footprint = graph.calcMinimalFootprint(symbol_subs=bind_subs)
    resolved_min_footprint = alg_min_footprint
    try:
        resolved_min_footprint = int(resolved_min_footprint)
    except Exception:
        print('ERROR: resolved_min_footprint should be int, but is {} = {}'
              .format(type(resolved_min_footprint), resolved_min_footprint))
    correct_min_footprint = 38153640
    error_percent = abs(correct_min_footprint - resolved_min_footprint) / \
        correct_min_footprint
    if error_percent > 0.15:
        print('Incorrect algorithmic footprint: {} (err: {})!'.format(
            resolved_min_footprint, error_percent))
    print('Alg minimal footprint: {}\nWith specified dims: {} (err: {})\n'
          .format(alg_min_footprint, resolved_min_footprint, error_percent))

    # Calculate algorithmic IO per step
    total_io_footprint = 0
    for op in graph.getPlaceholders():
        total_io_footprint += op.calcAlgFootprint()
    resolved_io_footprint = total_io_footprint.subs(bind_subs)
    print('Alg IO footprint: {}\nWith specified dims: {}\n'.format(
        total_io_footprint, resolved_io_footprint))
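
# Illustrative sketch (not part of the original test suite): where the
# 16*hidden_dim**2 MatMul terms in the expected formulas above come from.
# The BasicLSTMCell multiplies a [batch, 2*h] concatenated input-and-state
# matrix by a [2*h, 4*h] kernel; counting a multiply-accumulate as 2 Flops
# gives 2 * batch * (2*h) * (4*h) = 16 * h**2 * batch per loop iteration.
def _sketch_lstm_matmul_flops(batch, hidden):
    m, k, n = batch, 2 * hidden, 4 * hidden
    # Dense [m, k] x [k, n] matmul costs 2*m*k*n Flops
    return 2 * m * k * n
# e.g. _sketch_lstm_matmul_flops(128, 256) == 16 * 256**2 * 128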
    def __init__(self, name):
        super(MultinomialOp, self).__init__(name)
        samps_name = '{}::rand_samps'.format(self.name)
        self._num_samples_symbol = utils.getIntSymbolFromString(samps_name)
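
# Added note (not part of the original file): the __init__ fragment above
# only caches a '<op name>::rand_samps' symbol so the sample count can be
# bound later, exactly like the nce_samples binding in run_tf_w2v_model
# below. Assuming utils.getIntSymbolFromString is a thin wrapper over sympy
# (the real helper may cache or sanitize names), a stand-in would be:
def _sketch_get_int_symbol(name):
    import sympy
    return sympy.Symbol(name, integer=True)
# e.g. _sketch_get_int_symbol('multinomial::rand_samps')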
def run_tf_w2v_model():
    global is_pytest_run
    graph_meta = 'catamount/frameworks/example_graphs/tensorflow/full_models/language_models/word2vec_n200-latest_model.meta'

    graph = catamount.frameworks.tensorflow.import_graph(graph_meta)
    assert graph.isValid()

    # Next, remove ops that are not executed during a standard training step:
    graph_ops = list(graph._ops_by_name.values())
    for op in graph_ops:
        # Certain ops are only used for inference
        if 'Model/NceLoss_1_3/' in op.name or \
           'Model/Collapse_1/' in op.name or \
           'Model/Embedding_1_3/' in op.name or \
           'Model/Labels_1/' in op.name or \
           'Model/SkipGramSampler_1/' in op.name or \
           'Model/Mask_1/' in op.name:
            graph.removeOp(op)
        elif op.name == 'Model/Cast_1' or \
             op.name == 'Model/Sum_1' or \
             op.name == 'Model/Size_1' or \
             op.name == 'Model/Exp_1' or \
             op.name == 'Model/truediv_2' or \
             op.name == 'Model/truediv_3':
            graph.removeOp(op)

    if not is_pytest_run:
        print('Initial graph:\n{}\n'.format(graph))
    init_params = graph.calcModelParameters()
    print('Initial parameters: {}'.format(init_params))
    print('Initial Flops: {}\n'.format(graph.calcAlgFlops()))

    print('Placeholders:')
    for op in graph.getPlaceholders():
        print(op.debugString())
    print('')

    # Set up symbols to name dimensions
    skip_window_symbol = utils.getPositiveIntSymbolFromString('skip_window')
    num_skips_symbol = utils.getPositiveIntSymbolFromString('num_skips')
    nce_samples_symbol = utils.getPositiveIntSymbolFromString('nce_samples')
    hidden_dim_symbol = utils.getIntSymbolFromString('hidden_dim')
    vocab_size_symbol = utils.getIntSymbolFromString('vocab_size')
    subbatch_size_symbol = utils.getIntSymbolFromString('subbatch_size')
    sequence_length_symbol = utils.getIntSymbolFromString('sequence_length')
    batch_times_seq_symbol = sequence_length_symbol * subbatch_size_symbol
    graph_iters_symbol = utils.getIntSymbolFromString('graph::iters')

    # For simplicity, assign samples symbol in the op
    nce_samp_op = graph.opsByName[
        'Model/NceLoss_1_1/nce_loss/LogUniformCandidateSampler']
    nce_samp_op._num_samples_symbol = nce_samples_symbol

    # Convert these constant dimensions to symbols
    base_skip_window = 8
    base_num_skips = 8
    base_nce_samples = 64
    base_hidden_dim = 400
    base_vocab_size = 40004
    base_sequence_length = 32
    base_subbatch_size = 1

    # Find and set constants that contain model hyperparameters
    const_dict = {
        'Model/Gradient/Compute/gradients/Model/NceLoss_1_1/nce_loss/sub_1_grad/Shape_1': [nce_samples_symbol],
        'Model/SkipGramSampler/Const': 2 * skip_window_symbol,
        'Model/SkipGramSampler/strided_slice/stack': [0, skip_window_symbol],
        'Model/SkipGramSampler/strided_slice/stack_1': [0, -skip_window_symbol],
        'Model/Collapse/Reshape/shape': [-1, hidden_dim_symbol],
        'Model/Gradient/Compute/gradients/Model/Embedding_1_1/Gather_grad/Shape': [vocab_size_symbol, hidden_dim_symbol],
        'Model/Gradient/Compute/gradients/Model/NceLoss_1_1/nce_loss/embedding_lookup_1_grad/Shape': [vocab_size_symbol],
        'Model/Gradient/Compute/gradients/Model/NceLoss_1_1/nce_loss/embedding_lookup_grad/Shape': [vocab_size_symbol, hidden_dim_symbol],
        'Model/Mask/NotEqual/y': vocab_size_symbol - 3,
        'Model/SkipGramSampler/Const_2': num_skips_symbol,
        'Model/SkipGramSampler/Tile_1/multiples': [1, num_skips_symbol],
        'Model/SkipGramSampler/Tile/multiples': [1, num_skips_symbol],
    }
    graph.bindConstantValues(const_dict)

    # Next, bind the constant, placeholder, and variable shapes and propagate
    bind_dict = {
        # Constants
        # Placeholders
        'Input/Input': [subbatch_size_symbol, sequence_length_symbol],
        'Labels/Labels': [subbatch_size_symbol, sequence_length_symbol],
        # Variables
        'Model/NceLoss_1/b_Softmax': [vocab_size_symbol],
        'Model/NceLoss_1/W_Softmax': [vocab_size_symbol, hidden_dim_symbol],
        'Model/Embedding_1/EmbeddingWeights': [vocab_size_symbol, hidden_dim_symbol],
    }

    print('Binding variables')

    # HACK: For now, manually set GatherNd op shapes. Later, implement GatherNd
    gnd_op = graph.opsByName['Model/SkipGramSampler/GatherNd']
    gnd_op.outputs[0].mergeShape([
        subbatch_size_symbol,
        num_skips_symbol * (sequence_length_symbol - 2 * skip_window_symbol)
    ])

    graph.bindShapesAndPropagate(bind_dict,
                                 warn_if_ill_defined=(not is_pytest_run),
                                 make_symbolic=True)
    assert graph.isValid()

    if not is_pytest_run:
        print('\n\nCleaned Graph:\n{}'.format(graph))

    print('\n\nBound values')

    bind_subs = {
        graph_iters_symbol: 1,
        hidden_dim_symbol: base_hidden_dim,
        sequence_length_symbol: base_sequence_length,
        subbatch_size_symbol: base_subbatch_size,
        vocab_size_symbol: base_vocab_size,
        skip_window_symbol: base_skip_window,
        num_skips_symbol: base_num_skips,
        nce_samples_symbol: base_nce_samples,
    }

    # Verify parameter counts first
    parameters = graph.calcModelParameters()

    correct_params = 32043205
    correct_flops = 21148823
    correct_bytes = 23762537
    correct_total_footprint = 137949925

    print('Symbol associations: {}\n'.format(bind_subs))

    # Calculate model parameter count
    resolved_params = parameters.subs(bind_subs)
    try:
        resolved_params = int(resolved_params)
    except Exception:
        print('ERROR: resolved_params should be int, but is {} = {}'.format(
            type(resolved_params), resolved_params))
    assert resolved_params == correct_params, \
        'Incorrect model params: {}'.format(resolved_params)
    print('Parameters: {}\nWith specified dims: {}\n'.format(
        parameters, resolved_params))

    # Calculate algorithmic Flops
    alg_flops = graph.calcAlgFlops()
    resolved_flops = alg_flops.subs(bind_subs)
    try:
        resolved_flops = int(resolved_flops)
    except Exception:
        print('ERROR: resolved_flops should be int, but is {} = {}'.format(
            type(resolved_flops), resolved_flops))
    assert resolved_flops == correct_flops, \
        'Incorrect algorithmic flops: {}'.format(resolved_flops)
    print('Algorithmic Flops: {}\nWith specified dims: {}\n'.format(
        alg_flops, resolved_flops))

    # Calculate algorithmic Bytes accessed
    alg_bytes = graph.calcAlgBytes()
    resolved_bytes = alg_bytes.subs(bind_subs)
    try:
        resolved_bytes = int(resolved_bytes)
    except Exception:
        print('ERROR: resolved_bytes should be int, but is {} = {}'.format(
            type(resolved_bytes), resolved_bytes))
    assert resolved_bytes == correct_bytes, \
        'Incorrect algorithmic bytes: {}'.format(resolved_bytes)
    print('Alg bytes accessed: {}\nWith specified dims: {}\n'.format(
        alg_bytes, resolved_bytes))

    # Calculate total memory footprint
    alg_footprint = graph.calcAlgFootprint()
    resolved_footprint = alg_footprint.subs(bind_subs)
    try:
        resolved_footprint = int(resolved_footprint)
    except Exception:
        print('ERROR: resolved_footprint should be int, but is {} = {}'.format(
            type(resolved_footprint), resolved_footprint))
    assert resolved_footprint == correct_total_footprint, \
        'Incorrect algorithmic footprint: {}'.format(resolved_footprint)
    print('Alg mem footprint: {}\nWith specified dims: {}\n'.format(
        alg_footprint, resolved_footprint))

    # Calculate minimal memory footprint
    alg_min_footprint = graph.calcMinimalFootprint(symbol_subs=bind_subs)
    print('Alg minimal footprint (With specified dims): {}\n'.format(
        alg_min_footprint))

    # Calculate algorithmic IO per step
    total_io_footprint = 0
    for op in graph.getPlaceholders():
        total_io_footprint += op.calcAlgFootprint()
    if isinstance(total_io_footprint, int):
        resolved_io_footprint = total_io_footprint
    else:
        resolved_io_footprint = total_io_footprint.subs(bind_subs)
    print('Alg IO footprint: {}\nWith specified dims: {}\n'.format(
        total_io_footprint, resolved_io_footprint))

    if not is_pytest_run:
        print('VERBOSE ALGORITHMIC FLOPS:')
        graph.calcAlgFlops(verbose=True)
        print('')

        print('VERBOSE ALGORITHMIC BYTES:')
        graph.calcAlgBytes(verbose=True)
        print('')

        print('VERBOSE ALGORITHMIC FOOTPRINT:')
        graph.calcAlgFootprint(verbose=True)
        print('')

    # HACKY WAY TO SAVE MODELS FOR NOW!
    pickle.dump(graph, open(
        'catamount/frameworks/example_graphs/tensorflow/full_models/language_models/graph_word2vec.p',
        'wb'))

    if is_pytest_run:
        return

    print('\n\n======= Algorithmic graph-level analytics: =======')

    hidden_dims = [
        1, 2, 3, 4, 5, 6, 7, 9, 10, 12, 14, 18, 20, 25, 28, 35, 40, 50, 56,
        69, 78, 86, 96, 108, 119, 123, 133, 148, 163, 182, 202, 221, 246,
        273, 297, 329, 330, 364, 396, 436, 437, 520, 572, 617, 676, 740, 796,
        869, 948, 1017, 1106, 1202, 1286, 1394, 1510, 1611, 1742, 1882, 2004,
        2161, 2476, 3040, 3714, 4520, 5478, 6628, 8019, 9702, 11739, 14204,
        17186, 20795, 25161, 30444, 36837, 38100
    ]

    bind_subs.pop(hidden_dim_symbol)
    resolved_params = parameters.subs(bind_subs)

    print('Symbol associations: {}\n'.format(bind_subs))

    print('Algorithmic Flops by hidden dimension, params, and per-batch-sample:')
    resolved_flops = alg_flops.subs(bind_subs)
    for hid_dim in hidden_dims:
        graph_params = resolved_params.subs({hidden_dim_symbol: hid_dim})
        graph_flops = resolved_flops.subs({hidden_dim_symbol: hid_dim})
        graph_flops_per_sample = float(graph_flops) / \
            bind_subs[subbatch_size_symbol]
        print('{}\t{}\t{}\t{}'.format(hid_dim, graph_params, graph_flops,
                                      int(graph_flops_per_sample)))

    print('\nAlgorithmic bytes accessed by hidden dimension, params:')
    resolved_bytes = alg_bytes.subs(bind_subs)
    for hid_dim in hidden_dims:
        graph_params = resolved_params.subs({hidden_dim_symbol: hid_dim})
        graph_bytes = resolved_bytes.subs({hidden_dim_symbol: hid_dim})
        print('{}\t{}\t{}'.format(hid_dim, graph_params, graph_bytes))

    print('\nAlgorithmic total memory footprint by hidden dimension, params:')
    resolved_footprint = alg_footprint.subs(bind_subs)
    for hid_dim in hidden_dims:
        graph_params = resolved_params.subs({hidden_dim_symbol: hid_dim})
        graph_footprint = resolved_footprint.subs({hidden_dim_symbol: hid_dim})
        print('{}\t{}\t{}'.format(hid_dim, graph_params, graph_footprint))

    print('\nAlgorithmic minimal memory footprint by hidden dimension, params:')
    full_subs = dict(bind_subs)
    for hid_dim in hidden_dims:
        graph_params = resolved_params.subs({hidden_dim_symbol: hid_dim})
        full_subs[hidden_dim_symbol] = hid_dim
        graph_min_foot = graph.calcMinimalFootprint(symbol_subs=full_subs)
        print('{}\t{}\t{}'.format(hid_dim, graph_params, graph_min_foot))
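
# Illustrative sketch (not part of the original test suite): the GatherNd
# shape hack in run_tf_w2v_model encodes how many skip-gram pairs one
# sequence yields. Each of the (seq_len - 2*skip_window) usable center
# words contributes num_skips samples.
def _sketch_skipgram_pairs(seq_len, skip_window, num_skips):
    centers = seq_len - 2 * skip_window
    return num_skips * centers
# With the base values bound above: _sketch_skipgram_pairs(32, 8, 8) == 128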
def test_tf_dynamic_rnn():
    graph = catamount.frameworks.tensorflow.import_graph(tf_example_filename)
    print('INITIAL GRAPH: {}\n\n'.format(graph))
    assert graph.isValid()
    algorithmic_flops = graph.calcAlgFlops()

    rwb_iters = utils.getIntSymbolFromString('rnn/while/LoopCond_block::iters')
    a_0 = utils.getIntSymbolFromString(
        'rnn/while/basic_lstm_cell/Add:0::dim_0')
    a1_0 = utils.getIntSymbolFromString(
        'rnn/while/basic_lstm_cell/Add_1:0::dim_0')
    ba_0 = utils.getIntSymbolFromString(
        'rnn/while/basic_lstm_cell/BiasAdd:0::dim_0')
    mm_0 = utils.getIntSymbolFromString(
        'rnn/while/basic_lstm_cell/MatMul:0::dim_0')
    m_0 = utils.getIntSymbolFromString(
        'rnn/while/basic_lstm_cell/Mul:0::dim_0')
    m1_0 = utils.getIntSymbolFromString(
        'rnn/while/basic_lstm_cell/Mul_1:0::dim_0')
    m2_0 = utils.getIntSymbolFromString(
        'rnn/while/basic_lstm_cell/Mul_2:0::dim_0')
    s_0 = utils.getIntSymbolFromString(
        'rnn/while/basic_lstm_cell/Sigmoid:0::dim_0')
    s1_0 = utils.getIntSymbolFromString(
        'rnn/while/basic_lstm_cell/Sigmoid_1:0::dim_0')
    s2_0 = utils.getIntSymbolFromString(
        'rnn/while/basic_lstm_cell/Sigmoid_2:0::dim_0')
    sm_r_0 = utils.getIntSymbolFromString('softmax/Reshape:0::dim_0')
    sm_r_1 = utils.getIntSymbolFromString('softmax/Reshape:0::dim_1')
    th_0 = utils.getIntSymbolFromString(
        'rnn/while/basic_lstm_cell/Tanh:0::dim_0')
    th1_0 = utils.getIntSymbolFromString(
        'rnn/while/basic_lstm_cell/Tanh_1:0::dim_0')

    forward_flops = rwb_iters * (24 * a_0 + 24 * a1_0 + 96 * ba_0 +
                                 9216 * mm_0 + 24 * m_0 + 24 * m1_0 +
                                 24 * m2_0 + 96 * s_0 + 96 * s1_0 +
                                 96 * s2_0 + 144 * th_0 + 144 * th1_0 + 6) + \
        3 * sm_r_0 * sm_r_1 + \
        18628

    grad_iters = utils.getIntSymbolFromString(
        'Gradient/Compute/gradients/b_count_2_block::iters')
    g_an_0 = utils.getIntSymbolFromString(
        'Gradient/Compute/gradients/AddN:0::dim_0')
    g_an1_0 = utils.getIntSymbolFromString(
        'Gradient/Compute/gradients/AddN_1:0::dim_0')
    g_mm_0 = utils.getIntSymbolFromString(
        'Gradient/Compute/gradients/rnn/while/basic_lstm_cell/MatMul_grad/MatMul:0::dim_0')
    g_mms_0 = utils.getIntSymbolFromString(
        'Gradient/Compute/gradients/rnn/while/basic_lstm_cell/MatMul_grad/MatMul_1/StackPopV2:0::dim_0')
    g_ms_0 = utils.getIntSymbolFromString(
        'Gradient/Compute/gradients/rnn/while/basic_lstm_cell/Mul_1_grad/mul/StackPopV2:0::dim_0')
    g_m_0 = utils.getIntSymbolFromString(
        'Gradient/Compute/gradients/rnn/while/basic_lstm_cell/Mul_1_grad/mul:0::dim_0')
    g_ms1_0 = utils.getIntSymbolFromString(
        'Gradient/Compute/gradients/rnn/while/basic_lstm_cell/Mul_1_grad/mul_1/StackPopV2:0::dim_0')
    g_m1_0 = utils.getIntSymbolFromString(
        'Gradient/Compute/gradients/rnn/while/basic_lstm_cell/Mul_1_grad/mul_1:0::dim_0')
    g_ms2_0 = utils.getIntSymbolFromString(
        'Gradient/Compute/gradients/rnn/while/basic_lstm_cell/Mul_2_grad/mul/StackPopV2:0::dim_0')
    g_m2_0 = utils.getIntSymbolFromString(
        'Gradient/Compute/gradients/rnn/while/basic_lstm_cell/Mul_2_grad/mul:0::dim_0')
    g_ms21_0 = utils.getIntSymbolFromString(
        'Gradient/Compute/gradients/rnn/while/basic_lstm_cell/Mul_2_grad/mul_1/StackPopV2:0::dim_0')
    g_m21_0 = utils.getIntSymbolFromString(
        'Gradient/Compute/gradients/rnn/while/basic_lstm_cell/Mul_2_grad/mul_1:0::dim_0')
    g_mus_0 = utils.getIntSymbolFromString(
        'Gradient/Compute/gradients/rnn/while/basic_lstm_cell/Mul_grad/mul/StackPopV2:0::dim_0')
    g_mu_0 = utils.getIntSymbolFromString(
        'Gradient/Compute/gradients/rnn/while/basic_lstm_cell/Mul_grad/mul:0::dim_0')
    g_mu1_0 = utils.getIntSymbolFromString(
        'Gradient/Compute/gradients/rnn/while/basic_lstm_cell/Mul_grad/mul_1:0::dim_0')
    g_s_0 = utils.getIntSymbolFromString(
        'Gradient/Compute/gradients/rnn/while/basic_lstm_cell/Sigmoid_grad/SigmoidGrad:0::dim_0')
    g_c_0 = utils.getIntSymbolFromString(
        'Gradient/Compute/gradients/rnn/while/basic_lstm_cell/split_grad/concat:0::dim_0')
    g_sm_m_0 = utils.getIntSymbolFromString(
        'Gradient/Compute/gradients/softmax/softmax_grad/mul:0::dim_0')
    g_sm_m_1 = utils.getIntSymbolFromString(
        'Gradient/Compute/gradients/softmax/softmax_grad/mul:0::dim_1')

    backward_flops = grad_iters * (24 * g_an_0 + 72 * g_an1_0 +
                                   9216 * g_mm_0 + 9216 * g_mms_0 +
                                   72 * g_ms_0 + 48 * g_m_0 + 72 * g_ms1_0 +
                                   48 * g_m1_0 + 72 * g_ms2_0 + 48 * g_m2_0 +
                                   72 * g_ms21_0 + 48 * g_m21_0 +
                                   72 * g_mus_0 + 48 * g_mu_0 +
                                   48 * g_mu1_0 + 48 * g_s_0 +
                                   96 * g_c_0 + 4706) + \
        g_sm_m_0 * g_sm_m_1

    correct_alg_flops = forward_flops + backward_flops

    print('Loaded Flops test:')
    print(' Catamount: {}'.format(algorithmic_flops))
    print(' Correct: {}'.format(correct_alg_flops))
    assert sympy.simplify(algorithmic_flops - correct_alg_flops) == 0, \
        'Initial alg flops incorrect!\n Expecting: {}\n Calculated: {}' \
        .format(correct_alg_flops, algorithmic_flops)

    # Now, bind tensor names in the graph and verify that the algorithmic
    # Flop counts reflect the new name bindings
    batch_size = utils.getIntSymbolFromString('batch_size')
    seq_length = utils.getIntSymbolFromString('seq_length')
    hidden_dim = utils.getIntSymbolFromString('hidden_dim')

    # Manually set some variables
    # TODO (Joel): Fix this up when all tensor arrays work!
    ta_op = graph.opsByName['rnn/TensorArrayStack/TensorArraySizeV3']
    ta_op._outputs[0].setValue(seq_length)
    ta_op = graph.opsByName['rnn/TensorArrayStack/TensorArrayGatherV3']
    ta_op._outputs[0].shape.mergeShape([seq_length, batch_size, hidden_dim])
    ta_op = graph.opsByName['rnn/while/TensorArrayReadV3']
    ta_op._outputs[0].shape.mergeShape([batch_size, hidden_dim])

    # TODO (Joel): Fix this up when all stack ops work!
    find_stack_shape = TensorShape([None, 24])
    find_stack_shape_2 = TensorShape([None, 48])
    for op in graph.opsByName.values():
        op_name_suffix = op.name.split('/')[-1]
        if 'StackPopV2' in op_name_suffix:
            if op._outputs[0].shape == find_stack_shape:
                op._outputs[0].shape.mergeShape([batch_size, hidden_dim])
            elif op._outputs[0].shape == find_stack_shape_2:
                op._outputs[0].shape.mergeShape([batch_size, 2 * hidden_dim])

    # NOTE: This also works: batch_size = 'batch_size'
    # Bind placeholders (a and b) output dimensions 0 to name batch_size
    bind_dict = {
        'a': [batch_size, seq_length, hidden_dim],
        'c_init_state': [batch_size, hidden_dim],
        'h_init_state': [batch_size, hidden_dim],
        'out_correct': [batch_size, seq_length]
    }
    graph.bindTensorShapeDimensions(bind_dict, warn_if_ill_defined=True)

    algorithmic_flops = graph.calcAlgFlops()

    # Update the algorithmic Flops formula
    # Sub the forward prop values
    correct_alg_flops = correct_alg_flops.subs({
        ba_0: batch_size,
        a_0: batch_size,
        a1_0: batch_size,
        mm_0: batch_size,
        m_0: batch_size,
        m1_0: batch_size,
        m2_0: batch_size,
        s_0: batch_size,
        s1_0: batch_size,
        s2_0: batch_size,
        th_0: batch_size,
        th1_0: batch_size,
        sm_r_0: batch_size * seq_length,
        sm_r_1: 24,
    })
    # Sub the backward prop values
    # TODO (Joel): Fix this up when all backprop works!
    correct_alg_flops = correct_alg_flops.subs({
        g_mm_0: batch_size,
        g_mms_0: batch_size,
        g_ms_0: batch_size,
        g_m_0: batch_size,
        g_ms1_0: batch_size,
        g_m1_0: batch_size,
        g_ms2_0: batch_size,
        g_m2_0: batch_size,
        g_ms21_0: batch_size,
        g_m21_0: batch_size,
        g_mus_0: batch_size,
        g_mu_0: batch_size,
        g_mu1_0: batch_size,
        g_s_0: batch_size,
        g_c_0: batch_size,
        g_an_0: batch_size,
        g_an1_0: batch_size,
        g_sm_m_0: batch_size * seq_length,
        g_sm_m_1: 24,
    })

    assert graph.isValid()

    print('BOUND GRAPH: {}\n\n'.format(graph))

    # HHHHAAAAAAXXXXXX: FIX THIS! DUE TO SHAPEOP SYMBOL PROPAGATION!
    algorithmic_flops = algorithmic_flops.subs({hidden_dim: 24})

    print('Bound Flops test:')
    print(' Catamount: {}'.format(algorithmic_flops))
    print(' Correct: {}'.format(correct_alg_flops))
    assert sympy.simplify(algorithmic_flops - correct_alg_flops) == 0, \
        'Bound alg flops incorrect!\n Expecting: {}\n Calculated: {}\n Difference: {}' \
        .format(correct_alg_flops, algorithmic_flops,
                algorithmic_flops - correct_alg_flops)
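
# Illustrative sketch (not part of the original test suite): the equality
# check every test above relies on. Two sympy expressions are treated as
# equal when their difference simplifies to zero, which tolerates
# algebraically equivalent but differently arranged Flop formulas.
def _sketch_flops_match(calculated, expected):
    import sympy
    return sympy.simplify(calculated - expected) == 0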