Example No. 1
def test_tf_simple_while_loop():
    graph = catamount.frameworks.tensorflow.import_graph(tf_example_filename)
    assert graph.isValid()

    algorithmic_flops = graph.calcAlgFlops()

    while_iters = utils.getIntSymbolFromString('while/LoopCond_block::iters')
    wba_dim_0 = utils.getIntSymbolFromString('while/body/add_1:0::dim_0')
    correct_alg_flops = while_iters * (wba_dim_0 + 2)

    print('Loaded Flops test:')
    print('    Catamount:   {}'.format(algorithmic_flops))
    print('    Correct: {}'.format(correct_alg_flops))
    assert sympy.simplify(algorithmic_flops - correct_alg_flops) == 0, \
        'Initial alg flops incorrect!\n  Expecting: {}\n  Calculated: {}' \
        .format(correct_alg_flops, algorithmic_flops)

    # Now, bind tensor names in the graph and verify that the algorithmic
    # Flop counts reflect the new name bindings
    batch_size = utils.getIntSymbolFromString('batch_size')
    # NOTE: This also works: batch_size = 'batch_size'
    # Bind placeholder a's output dimension 0 to the symbol batch_size
    bind_dict = {'a': [batch_size, 1]}
    graph.bindTensorShapeDimensions(bind_dict)

    algorithmic_flops = graph.calcAlgFlops()

    # Update the algorithmic Flops formula
    correct_alg_flops = correct_alg_flops.subs({wba_dim_0: batch_size})
    print('Bound Flops test:')
    print('    Catamount:   {}'.format(algorithmic_flops))
    print('    Correct: {}'.format(correct_alg_flops))
    assert sympy.simplify(algorithmic_flops - correct_alg_flops) == 0, \
        'Bound alg flops incorrect!\n  Expecting: {}\n  Calculated: {}' \
        .format(correct_alg_flops, algorithmic_flops)
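A minimal standalone sketch of the symbolic-Flops pattern used above, assuming only sympy (the symbol names mirror the test; no Catamount imports are needed):

import sympy

# Hypothetical stand-ins for the symbols the test pulls out of the graph.
iters = sympy.Symbol('while/LoopCond_block::iters', integer=True)
dim_0 = sympy.Symbol('while/body/add_1:0::dim_0', integer=True)
batch_size = sympy.Symbol('batch_size', integer=True)

flops = iters * (dim_0 + 2)                    # symbolic Flop count before binding
bound_flops = flops.subs({dim_0: batch_size})  # rebind the dimension to batch_size
assert sympy.simplify(bound_flops - iters * (batch_size + 2)) == 0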
Example No. 2
def test_expanddims_op():
    ''' Specify graphs with ExpandDimsOps and make sure dimensions behave
    as desired.
    '''

    combos = [
        ([3], 0),
        ([None], 0),
        ([None, None], 0),
        ([None, None], 1),
        ([None, None], 2),
        ([None, None], -1),
        ([None, None], -2),
        ([None, None], -3),
    ]

    for combo in combos:
        graph = Graph()
        with graph.asDefault():
            ph_dims, expand_dim = combo
            if isinstance(ph_dims, list):
                ed_out_dims = list(ph_dims)
                insert_dim = expand_dim
                if insert_dim < 0:
                    insert_dim += len(ed_out_dims) + 1
                ed_out_dims.insert(insert_dim, 1)
            else:
                ed_out_dims = [ph_dims]
            print('Testing expand dims with in_dims {}, expand dim {} to {}'.
                  format(ph_dims, expand_dim, ed_out_dims))

            # Build model
            in_ph = placeholder('in', ph_dims)
            expanddims_out = expanddims('expanddims',
                                        ed_out_dims,
                                        in_ph,
                                        axis=expand_dim)

            assert graph.isValid()

            feed_dict = {}
            if isinstance(ph_dims, list):
                for idx in range(len(ph_dims)):
                    if ph_dims[idx] is None:
                        ph_dims[idx] = utils.getIntSymbolFromString(
                            'in::dim_{}'.format(idx))
            else:
                if ph_dims is None:
                    ph_dims = utils.getIntSymbolFromString('in::dim_0')
            feed_dict['in'] = ph_dims
            print('    Feed dict: {}'.format(feed_dict))

            graph.bindTensorShapeDimensions(feed_dict)
            assert expanddims_out.shape == TensorShape(ed_out_dims)
        reset_symbols()
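For reference, the negative-axis handling exercised above reduces to the following rule (a sketch, not the Catamount API): a negative axis is offset by the output rank, i.e. the input rank plus one, before the unit dimension is inserted.

def normalized_expand_dims(dims, axis):
    # Mirror the insert_dim computation in the test above.
    out = list(dims)
    if axis < 0:
        axis += len(out) + 1
    out.insert(axis, 1)
    return out

assert normalized_expand_dims([3], 0) == [1, 3]
assert normalized_expand_dims([None, None], -1) == [None, None, 1]
assert normalized_expand_dims([None, None], -3) == [1, None, None]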
Example No. 3
def test_tf_dynamic_rnn():
    graph = catamount.frameworks.tensorflow.import_graph(tf_example_filename)
    assert graph.isValid()

    algorithmic_flops = graph.calcAlgFlops()

    rwb_iters = utils.getIntSymbolFromString('rnn/while/LoopCond_block::iters')
    ba_0 = utils.getIntSymbolFromString(
        'rnn/while/basic_rnn_cell/BiasAdd:0::dim_0')
    mm_0 = utils.getIntSymbolFromString(
        'rnn/while/basic_rnn_cell/MatMul:0::dim_0')
    th_0 = utils.getIntSymbolFromString(
        'rnn/while/basic_rnn_cell/Tanh:0::dim_0')
    correct_alg_flops = rwb_iters * \
                        (24 * ba_0 + 2304 * mm_0 + 144 * th_0 + 5) + \
                        2305

    print('Loaded Flops test:')
    print('    Catamount:   {}'.format(algorithmic_flops))
    print('    Correct: {}'.format(correct_alg_flops))
    assert sympy.simplify(algorithmic_flops - correct_alg_flops) == 0, \
        'Initial alg flops incorrect!\n  Expecting: {}\n  Calculated: {}' \
        .format(correct_alg_flops, algorithmic_flops)

    # Now, bind tensor names in the graph and verify that the algorithmic
    # Flop counts reflect the new name bindings
    batch_size = utils.getIntSymbolFromString('batch_size')
    # NOTE: This also works: batch_size = 'batch_size'
    # Bind placeholder (a and init_state) output dimensions to symbolic names
    bind_dict = {
        'a': ['batch_size', 'seq_length', 'hidden_dim'],
        'init_state': ['batch_size', 'hidden_dim']
    }
    graph.bindTensorShapeDimensions(bind_dict)

    algorithmic_flops = graph.calcAlgFlops()

    # Update the algorithmic Flops formula
    correct_alg_flops = correct_alg_flops.subs({
        ba_0: batch_size,
        mm_0: batch_size,
        th_0: batch_size
    })
    print('Bound Flops test:')
    print('    Catamount:   {}'.format(algorithmic_flops))
    print('    Correct: {}'.format(correct_alg_flops))
    assert sympy.simplify(algorithmic_flops - correct_alg_flops) == 0, \
        'Bound alg flops incorrect!\n  Expecting: {}\n  Calculated: {}' \
        .format(correct_alg_flops, algorithmic_flops)
Example No. 4
    def propagateShapes(self, make_symbolic=False):
        self.debugAssert(len(self._inputs) == 2)
        self.debugAssert(len(self._outputs) == 1)
        # Assume that there are multiple workers contributing to this
        # collective operation and that their matrix sizes match the first
        # input tensor passed in here. Create a symbol to represent the
        # number of participating workers.
        num_workers_str = '{}::num_workers'.format(self.name)
        num_workers_symbol = utils.getIntSymbolFromString(num_workers_str)

        # TODO (Joel): We could take another input tensor to specify the axis
        # on which to concatenate values. For now, axis = 0
        axis = 0
        final_shape = []
        for idx in range(len(self._inputs[0].shape.dims)):
            dim = self._inputs[0].shape.getDimension(idx)
            if idx == axis:
                # Manipulate the dimension: make the value None (it is
                # necessarily symbolic), and set the symbol to reflect
                # multiple workers
                dim_val = dim.value
                new_dim = Dimension(None)
                new_symbol = dim.symbol * num_workers_symbol
                new_dim.setSymbolOrName(new_symbol)
                final_shape.append(new_dim)
            else:
                final_shape.append(dim)
        self._outputs[0].mergeShape(final_shape, make_symbolic=make_symbolic)
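The shape rule implemented above can be illustrated with plain sympy symbols standing in for Catamount Dimension objects (a sketch under that assumption): the dimension on the concatenation axis is scaled by a per-op num_workers symbol, while all other dimensions pass through unchanged.

import sympy

batch = sympy.Symbol('batch_size', integer=True)
hidden = sympy.Symbol('hidden_dim', integer=True)
# Hypothetical op name; the real symbol is '<op name>::num_workers'.
num_workers = sympy.Symbol('allgather::num_workers', integer=True)

in_shape = [batch, hidden]
axis = 0
out_shape = [dim * num_workers if idx == axis else dim
             for idx, dim in enumerate(in_shape)]
assert out_shape == [batch * num_workers, hidden]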
Example No. 5
def test_tf_load_and_calculate():
    graph = catamount.frameworks.tensorflow.import_graph(tf_example_filename)
    assert graph.isValid()

    algorithmic_flops = graph.calcAlgFlops()

    add_dim_0 = utils.getIntSymbolFromString('add:0::dim_0')
    matmul_dim_0 = utils.getIntSymbolFromString('matmul:0::dim_0')
    mul_dim_0 = utils.getIntSymbolFromString('mul:0::dim_0')
    correct_alg_flops = 256 * add_dim_0 + \
                        65536 * matmul_dim_0 + \
                        256 * mul_dim_0 + \
                        98307

    print('Loaded Flops test:')
    print('    Catamount:   {}'.format(algorithmic_flops))
    print('    Correct: {}'.format(correct_alg_flops))
    assert sympy.simplify(algorithmic_flops - correct_alg_flops) == 0, \
        'Initial alg flops incorrect!\n  Expecting: {}\n  Calculated: {}' \
        .format(correct_alg_flops, algorithmic_flops)

    # Now, bind tensor names in the graph and verify that the algorithmic
    # Flop counts reflect the new name bindings
    batch_size = utils.getIntSymbolFromString('batch_size')
    # NOTE: This also works: batch_size = 'batch_size'
    # Bind placeholders (a and b) output dimension 0 to the symbol batch_size
    bind_dict = {'a': [batch_size, None], 'b': [batch_size, None]}
    graph.bindTensorShapeDimensions(bind_dict)

    algorithmic_flops = graph.calcAlgFlops()

    # Update the algorithmic Flops formula
    correct_alg_flops = correct_alg_flops.subs({
        add_dim_0: batch_size,
        matmul_dim_0: batch_size,
        mul_dim_0: batch_size,
    })
    print('Bound Flops test:')
    print('    Catamount:   {}'.format(algorithmic_flops))
    print('    Correct: {}'.format(correct_alg_flops))
    assert sympy.simplify(algorithmic_flops - correct_alg_flops) == 0, \
        'Bound alg flops incorrect!\n  Expecting: {}\n  Calculated: {}' \
        .format(correct_alg_flops, algorithmic_flops)
Example No. 6
    def __init__(self, name):
        super(CandidateSamplerOp, self).__init__(name)
        # TODO (Joel): Read these from compute graph op attributes
        self.setNumTrue(1)
        self.setNumSampled(None)
        # TODO: Depending on the generator, there should be some small number
        # of Flops per sampled element. Using (incorrect) 1 for now...
        self._flops_per_element = 1
        samps_name = '{}::rand_samps'.format(self.name)
        self._num_samples_symbol = utils.getIntSymbolFromString(samps_name)
Example No. 7
def add_symbols(name, out_shape):
    # print('  Adding symbols for {}: out_shape: {}'.format(name, out_shape))
    global symbol_table
    global subs_table

    def add_symbol(symbol, dim):
        assert sym_name not in symbol_table.keys()
        symbol_table[sym_name] = symbol
        # print('Added symbol name {} with sym {}'.format(sym_name, symbol))
        if isinstance(dim, Dimension):
            dim = dim.value
        if dim is not None:
            subs_table[symbol] = dim

    if isinstance(out_shape, list):
        for idx, dim in enumerate(out_shape):
            sym_name = '{}::dim_{}'.format(name, idx)
            add_symbol(utils.getIntSymbolFromString(sym_name), dim)
    else:
        sym_name = '{}::unk'.format(name)
        add_symbol(utils.getIntSymbolFromString(sym_name), out_shape)
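A hedged usage sketch of the helper above, assuming the module-level symbol_table and subs_table dicts start out empty (as after reset_symbols()):

symbol_table = {}
subs_table = {}

# Register symbols for a rank-2 output named 'matmul:0' whose second
# dimension is known to be 32.
add_symbols('matmul:0', [None, 32])

assert 'matmul:0::dim_0' in symbol_table
assert 'matmul:0::dim_1' in symbol_table
# Only the known dimension gets a substitution entry.
assert list(subs_table.values()) == [32]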
Example No. 8
    def propagateShapes(self, make_symbolic=False):
        self.debugAssert(len(self._inputs) == 2)
        self.debugAssert(len(self._outputs) == 1)

        # Cannot propagate shapes if first input shape undefined
        if not self._inputs[0].shape.isFullySymbolic():
            return
        self.debugAssert(self._inputs[0].shape.rank == 2)
        num_samples = self._inputs[1].value
        if num_samples is None:
            samps_name = '{}::rand_samps'.format(self.name)
            num_samples = utils.getIntSymbolFromString(samps_name)
        out_shape = []
        out_shape.append(self._inputs[0].shape.getDimension(0))
        out_shape.append(num_samples)
        self._outputs[0].shape.mergeShape(out_shape,
                                          make_symbolic=make_symbolic)
Example No. 9
    def calcAlgFlops(self):
        self.debugAssert(len(self._inputs) == 2)
        self.debugAssert(len(self._outputs) == 1)
        # Steps in multinomial sampling:
        # 1) Draw uniform random sample, "noises", of size
        #        [batch_size, num_samples, num_classes]
        num_samples = self._inputs[1].value
        if num_samples is None:
            samps_name = '{}::rand_samps'.format(self.name)
            num_samples = utils.getIntSymbolFromString(samps_name)
        in_0_shape = self._inputs[0].shape
        full_shape_elts = in_0_shape.numElements() * num_samples
        total_flops = full_shape_elts
        # 2) Calculate scores = logits - log(-log(noises)) with broadcasting
        total_flops += 3 * full_shape_elts
        # 3) Minimum reduction along classes dimension
        total_flops += full_shape_elts
        return total_flops
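A small numeric check of the formula above (a sketch, not part of the op): each element of the [batch_size, num_samples, num_classes] noise tensor costs 1 Flop to draw, 3 Flops for scores = logits - log(-log(noises)), and 1 Flop for the minimum reduction, for 5 Flops per element in total.

# Hypothetical concrete sizes, chosen only to make the arithmetic checkable.
batch_size, num_classes, num_samples = 8, 1000, 4
full_shape_elts = batch_size * num_classes * num_samples

total_flops = full_shape_elts        # 1) draw the uniform noise
total_flops += 3 * full_shape_elts   # 2) logits - log(-log(noises))
total_flops += full_shape_elts       # 3) min-reduce over the class dimension
assert total_flops == 5 * full_shape_elts == 160000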
Example No. 10
    def propagateShapes(self, make_symbolic=False):
        self.debugAssert(len(self._inputs) == 1)
        # The second output (outputs[1]) is the true expected count and has
        # shape equal to the input tensor unless the num_true attribute is
        # changed
        self.debugAssert(len(self._outputs) == 3)

        if self._num_true != 1:
            self.notImplemented('CandidateSamplerOp propagateShapes ' \
                                'num_true != 1')
        self._outputs[1].shape.mergeShape(self._inputs[0].shape,
                                          make_symbolic=make_symbolic)
        num_samples = None
        if self._num_sampled is None:
            samps_name = '{}::rand_samps'.format(self.name)
            num_samples = utils.getIntSymbolFromString(samps_name)
        else:
            self.notImplemented('CandidateSamplerOp: propagateShapes '\
                                'num_sampled != None')
        self._outputs[0].shape.mergeShape([num_samples])
        self._outputs[2].shape.mergeShape([num_samples])
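For reference, the three outputs handled above follow the TensorFlow candidate-sampler convention; a sketch of their expected shapes with sympy placeholders (symbol names here are illustrative):

import sympy

batch_size = sympy.Symbol('batch_size', integer=True)
num_sampled = sympy.Symbol('sampler::rand_samps', integer=True)  # hypothetical name
num_true = 1  # matches setNumTrue(1) in the constructor

expected_shapes = {
    'sampled_candidates': [num_sampled],            # outputs[0]
    'true_expected_count': [batch_size, num_true],  # outputs[1], same shape as the input
    'sampled_expected_count': [num_sampled],        # outputs[2]
}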
Example No. 11
def test_concat_op():
    ''' Specify graphs with concat operations and make sure dimensions behave
    as desired.
    '''

    combos = [
        ([[None, None], [None, None]], 0),
        ([[None, None], [None, None]], 1),
        ([[3, None], [3, None]], 0),
        ([[3, None], [3, None]], 1),
        ([[3, 7], [6, 7]], 0),
        ([[3, 15], [3, None]], 1),
        ([[3, None, 7, 15], [3, 15, 7, None]], 0),
        ([[3, None, 7, 15], [3, 15, 7, None]], 1),
        ([[3, None, 7, 15], [3, 15, 7, None]], 2),
        ([[3, None, 7, 15], [3, 15, 7, None]], 3),
        ([[3, None, 7, 15], [3, 15, 7, None], [None, 15, 7, 30]], 3),
    ]

    for combo in combos:
        graph = Graph()
        with graph.asDefault():
            ph_dims, axis = combo
            print('Testing concat with in dims {}, axis {}'.format(
                ph_dims, axis))

            # Build model
            in_phs = []
            rank = None
            for idx, ph_dim in enumerate(ph_dims):
                ph_name = 'in_{}'.format(idx)
                in_phs.append(placeholder(ph_name, ph_dim))
                if rank is None:
                    rank = in_phs[idx].shape.rank
                else:
                    assert rank == in_phs[idx].shape.rank
            concat_out = concat('concat', [None] * rank, in_phs, axis=axis)

            assert graph.isValid()

            feed_dict = {}
            out_c_dim = Dimension(0)
            for in_ph, ph_dim in zip(in_phs, ph_dims):
                in_ph_dims = []
                for idx, dim in enumerate(ph_dim):
                    append_dim_sym = None
                    if dim is None:
                        dim_name = 'bind_{}_{}'.format(in_ph.name, idx)
                        append_dim_sym = utils.getIntSymbolFromString(dim_name)
                    else:
                        append_dim_sym = dim
                    in_ph_dims.append(append_dim_sym)
                    if idx == axis:
                        append_dim = Dimension(None)
                        append_dim.setSymbolOrName(append_dim_sym)
                        out_c_dim += append_dim
                feed_dict[in_ph.name] = in_ph_dims
            print('    Feed dict: {}'.format(feed_dict))

            graph.bindTensorShapeDimensions(feed_dict)

            out_dims = TensorShape(in_phs[-1].shape.dims)
            out_dims.dims[axis] = out_c_dim
            check_symbol_table = {}
            for idx in range(concat_out.shape.rank):
                c_out_dim = concat_out.shape.getDimension(idx).symbol
                out_dim = out_dims.getDimension(idx).symbol
                if isinstance(out_dim, sympy.Symbol) and \
                   isinstance(c_out_dim, int):
                    if out_dim not in check_symbol_table.keys():
                        check_symbol_table[out_dim] = c_out_dim
                        out_dim = c_out_dim
                    else:
                        assert c_out_dim == check_symbol_table[out_dim]
                print('    Catamount dim[{}]:   {}'.format(idx, c_out_dim))
                print('    Correct dim[{}]: {}'.format(idx, out_dim))
                assert (sympy.simplify(c_out_dim - out_dim) == 0), \
                    'Concat dim[{}] incorrect!\n  Expecting:  {}\n' \
                    '  Calculated: {}'.format(idx, out_dim, c_out_dim)
        reset_symbols()
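The per-dimension check above reduces to a simple rule, sketched here with sympy symbols standing in for unbound dimensions: the output dimension on the concat axis is the sum of the inputs' dimensions on that axis, and every other dimension must agree across inputs.

import sympy

d0 = sympy.Symbol('bind_in_0_1', integer=True)
d1 = sympy.Symbol('bind_in_1_1', integer=True)
in_shapes = [[3, d0], [3, d1]]
axis = 1

out_shape = list(in_shapes[0])
out_shape[axis] = sum(shape[axis] for shape in in_shapes)
assert out_shape == [3, d0 + d1]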
Example No. 12
    def __init__(self, name):
        super(AllreduceOp, self).__init__(name)
        num_workers_str = '{}::num_workers'.format(self.name)
        self._workers_symbol = utils.getIntSymbolFromString(num_workers_str)

def run_tf_speech_attention():
    global is_pytest_run
    graph_meta = 'catamount/frameworks/example_graphs/tensorflow/full_models/speech_attention/model.ckpt.meta'

    graph = catamount.frameworks.tensorflow.import_graph(graph_meta)
    assert graph.isValid()

    # HACK: These checker ops must be removed manually (reason not yet understood):
    remove_ops = [
        'DevArgmaxWERChecker/Less', 'DevLossChecker/Less',
        'DevArgmaxWERChecker/best_dev', 'DevLossChecker/best_dev'
    ]
    for op_name in remove_ops:
        op = graph.opsByName[op_name]
        graph.removeOp(op)
    assert graph.isValid()

    # Remove ops that are not executed during a standard training step:
    graph_ops = list(graph._ops_by_name.values())
    for op in graph_ops:
        # Ops in attn_model_[1-3] are used for inference
        if 'attn_model_1' in op.name or \
           'attn_model_2' in op.name or \
           'attn_model_3' in op.name:
            graph.removeOp(op)
    assert graph.isValid()

    print('Initial graph:\n{}\n'.format(graph))
    init_params = graph.calcModelParameters()
    print('Initial parameters: {}'.format(init_params))
    print('Initial Flops: {}\n'.format(graph.calcAlgFlops()))

    print('Placeholders:')
    for op in graph.getPlaceholders():
        print(op.debugString())
    print('')

    # Set up symbols to name dimensions
    audio_features_symbol = utils.getPositiveIntSymbolFromString(
        'audio_features')
    encoder_steps_symbol = utils.getPositiveIntSymbolFromString(
        'encoder_steps')
    decoder_steps_symbol = utils.getPositiveIntSymbolFromString(
        'decoder_steps')
    subbatch_size_symbol = utils.getPositiveIntSymbolFromString(
        'subbatch_size')
    attn_dim_symbol = utils.getPositiveIntSymbolFromString('attn_dim')
    attn_hidden_dim_symbol = utils.getPositiveIntSymbolFromString(
        'attn_hidden_dim')
    dec_hidden_dim_symbol = utils.getPositiveIntSymbolFromString(
        'dec_hidden_dim')
    enc_hidden_dim_symbol = utils.getPositiveIntSymbolFromString(
        'enc_hidden_dim')
    graph_iters_symbol = utils.getIntSymbolFromString('graph::iters')
    output_vocab_symbol = utils.getPositiveIntSymbolFromString('output_vocab')
    conv_width_symbol = utils.getPositiveIntSymbolFromString('conv_width')
    num_conv_filters_symbol = utils.getPositiveIntSymbolFromString(
        'num_conv_filters')

    # Convert these constant dimensions to symbols
    base_encoder_steps = 300
    base_decoder_steps = 300
    base_subbatch_size = 32
    base_output_vocab = 31
    base_audio_features = 40
    base_conv_width = 53
    base_attn_dim = 137
    base_attn_hidden_dim = 509
    base_dec_hidden_dim = 571
    base_enc_hidden_dim = 1051
    base_enc_input_dim = 1091  # Input + recurrent state
    enc_input_dim_symbol = audio_features_symbol + enc_hidden_dim_symbol
    base_dec_attn_rec = 2133
    dec_attn_rec_symbol = 2 * enc_hidden_dim_symbol + output_vocab_symbol
    base_attn_cell_inputs = 2611
    attn_cell_inputs_symbol = 2 * enc_hidden_dim_symbol + attn_hidden_dim_symbol
    base_attn_cell_in_dim = 2642
    attn_cell_in_dim_symbol = 2 * enc_hidden_dim_symbol + output_vocab_symbol + \
                              attn_hidden_dim_symbol
    base_dec_attn_dim = 3182
    dec_attn_dim_symbol = attn_hidden_dim_symbol + 2 * enc_hidden_dim_symbol + \
                          dec_hidden_dim_symbol
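    # Sanity check (a sketch, not required by the analysis): each composite
    # base_* constant above is its symbolic expression evaluated at the base
    # dimension values, e.g. 1091 = 40 + 1051 and 3182 = 509 + 2*1051 + 571.
    assert base_enc_input_dim == base_audio_features + base_enc_hidden_dim
    assert base_dec_attn_rec == 2 * base_enc_hidden_dim + base_output_vocab
    assert base_attn_cell_inputs == 2 * base_enc_hidden_dim + base_attn_hidden_dim
    assert base_attn_cell_in_dim == (2 * base_enc_hidden_dim +
                                     base_output_vocab + base_attn_hidden_dim)
    assert base_dec_attn_dim == (base_attn_hidden_dim +
                                 2 * base_enc_hidden_dim + base_dec_hidden_dim)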

    bind_dict = { # Placeholders
                  'attn_model/input_seq': [encoder_steps_symbol, subbatch_size_symbol, audio_features_symbol],
                  'attn_model/input_len': [subbatch_size_symbol],
                  'attn_model/output_seq': [decoder_steps_symbol, subbatch_size_symbol],
                  'attn_model/output_mask': [decoder_steps_symbol, subbatch_size_symbol],

                  # Variables
                  'InputNormalizer/means': [audio_features_symbol],
                  'InputNormalizer/std': [audio_features_symbol],
                  'attn_model/AffineAttentionStateNN/W': [2 * enc_hidden_dim_symbol, attn_dim_symbol],
                  'attn_model/AffineAttentionStateNN/b': [attn_dim_symbol],
                  'attn_model/AffineOutputProjection/W': [dec_hidden_dim_symbol, output_vocab_symbol],
                  'attn_model/AffineOutputProjection/b': [output_vocab_symbol],
                  'attn_model/Decoder/attn_model/attention_cell/biases': [4 * attn_hidden_dim_symbol],
                  'attn_model/Decoder/attn_model/attention_cell/weights': [attn_hidden_dim_symbol + 2 * enc_hidden_dim_symbol + output_vocab_symbol, 4 * attn_hidden_dim_symbol],
                  'attn_model/Decoder/attn_model/decoder_cell/biases': [4 * dec_hidden_dim_symbol],
                  'attn_model/Decoder/attn_model/decoder_cell/weights': [attn_hidden_dim_symbol + dec_hidden_dim_symbol + 2 * enc_hidden_dim_symbol, 4 * dec_hidden_dim_symbol],
                  'attn_model/HybridAttentionContext/Q': [conv_width_symbol, 1, num_conv_filters_symbol],
                  'attn_model/HybridAttentionContext/U': [1, num_conv_filters_symbol, attn_dim_symbol],
                  'attn_model/HybridAttentionContext/W': [2 * attn_hidden_dim_symbol, attn_dim_symbol],
                  'attn_model/HybridAttentionContext/b': [attn_dim_symbol],
                  'attn_model/HybridAttentionContext/w': [attn_dim_symbol],
                  'attn_model/StackedEncoder/Layer0/RNNEncoder/bidirectional_rnn/bw/basic_lstm_cell/bias': [4 * enc_hidden_dim_symbol],
                  'attn_model/StackedEncoder/Layer0/RNNEncoder/bidirectional_rnn/bw/basic_lstm_cell/kernel': [audio_features_symbol + enc_hidden_dim_symbol, 4 * enc_hidden_dim_symbol],
                  'attn_model/StackedEncoder/Layer0/RNNEncoder/bidirectional_rnn/fw/basic_lstm_cell/bias': [4 * enc_hidden_dim_symbol],
                  'attn_model/StackedEncoder/Layer0/RNNEncoder/bidirectional_rnn/fw/basic_lstm_cell/kernel': [audio_features_symbol + enc_hidden_dim_symbol, 4 * enc_hidden_dim_symbol],
                  'attn_model/StackedEncoder/Layer2/RNNEncoder/bidirectional_rnn/bw/basic_lstm_cell/bias': [4 * enc_hidden_dim_symbol],
                  'attn_model/StackedEncoder/Layer2/RNNEncoder/bidirectional_rnn/bw/basic_lstm_cell/kernel': [3 * enc_hidden_dim_symbol, 4 * enc_hidden_dim_symbol],
                  'attn_model/StackedEncoder/Layer2/RNNEncoder/bidirectional_rnn/fw/basic_lstm_cell/bias': [4 * enc_hidden_dim_symbol],
                  'attn_model/StackedEncoder/Layer2/RNNEncoder/bidirectional_rnn/fw/basic_lstm_cell/kernel': [3 * enc_hidden_dim_symbol, 4 * enc_hidden_dim_symbol],
                  'attn_model/StackedEncoder/Layer4/RNNEncoder/bidirectional_rnn/bw/basic_lstm_cell/bias': [4 * enc_hidden_dim_symbol],
                  'attn_model/StackedEncoder/Layer4/RNNEncoder/bidirectional_rnn/bw/basic_lstm_cell/kernel': [3 * enc_hidden_dim_symbol, 4 * enc_hidden_dim_symbol],
                  'attn_model/StackedEncoder/Layer4/RNNEncoder/bidirectional_rnn/fw/basic_lstm_cell/bias': [4 * enc_hidden_dim_symbol],
                  'attn_model/StackedEncoder/Layer4/RNNEncoder/bidirectional_rnn/fw/basic_lstm_cell/kernel': [3 * enc_hidden_dim_symbol, 4 * enc_hidden_dim_symbol],

                  # Constants
                  'attn_model/AttentionModel/gradients/attn_model/Decoder/while/MatMul/Enter_grad/b_acc': [dec_hidden_dim_symbol, output_vocab_symbol],
                  'attn_model/AttentionModel/gradients/attn_model/Decoder/while/add/Enter_grad/b_acc': [output_vocab_symbol],
                  'attn_model/AttentionModel/gradients/attn_model/Decoder/while/attn_model/MatMul/Enter_grad/b_acc': [2 * attn_hidden_dim_symbol, attn_dim_symbol],
                  'attn_model/AttentionModel/gradients/attn_model/Decoder/while/attn_model/add_2/Enter_grad/b_acc': [attn_dim_symbol],
                  'attn_model/AttentionModel/gradients/attn_model/Decoder/while/attn_model/attention_cell/BiasAdd/Enter_grad/b_acc': [4 * attn_hidden_dim_symbol],
                  'attn_model/AttentionModel/gradients/attn_model/Decoder/while/attn_model/attention_cell/attention_cell/add/Enter_grad/b_acc': [attn_hidden_dim_symbol + 2 * enc_hidden_dim_symbol + output_vocab_symbol, 4 * attn_hidden_dim_symbol],
                  'attn_model/AttentionModel/gradients/attn_model/Decoder/while/attn_model/conv1d/ExpandDims_1/Enter_grad/b_acc': [conv_width_symbol, 1, 4],
                  'attn_model/AttentionModel/gradients/attn_model/Decoder/while/attn_model/conv1d_1/ExpandDims_1/Enter_grad/b_acc': [1, 4, attn_dim_symbol],
                  'attn_model/AttentionModel/gradients/attn_model/Decoder/while/attn_model/decoder_cell/BiasAdd/Enter_grad/b_acc': [4 * dec_hidden_dim_symbol],
                  'attn_model/AttentionModel/gradients/attn_model/Decoder/while/attn_model/decoder_cell/decoder_cell/add/Enter_grad/b_acc': [attn_hidden_dim_symbol + dec_hidden_dim_symbol + 2 * enc_hidden_dim_symbol, 4 * dec_hidden_dim_symbol],
                  'attn_model/AttentionModel/gradients/attn_model/Decoder/while/attn_model/mul/Enter_grad/b_acc': [attn_dim_symbol],
                  'attn_model/AttentionModel/gradients/attn_model/StackedEncoder/Layer0/RNNEncoder/bidirectional_rnn/bw/bw/while/basic_lstm_cell/BiasAdd/Enter_grad/b_acc': [4 * enc_hidden_dim_symbol],
                  'attn_model/AttentionModel/gradients/attn_model/StackedEncoder/Layer0/RNNEncoder/bidirectional_rnn/bw/bw/while/basic_lstm_cell/MatMul/Enter_grad/b_acc': [audio_features_symbol + enc_hidden_dim_symbol, 4 * enc_hidden_dim_symbol],
                  'attn_model/AttentionModel/gradients/attn_model/StackedEncoder/Layer0/RNNEncoder/bidirectional_rnn/fw/fw/while/basic_lstm_cell/BiasAdd/Enter_grad/b_acc': [4 * enc_hidden_dim_symbol],
                  'attn_model/AttentionModel/gradients/attn_model/StackedEncoder/Layer0/RNNEncoder/bidirectional_rnn/fw/fw/while/basic_lstm_cell/MatMul/Enter_grad/b_acc': [audio_features_symbol + enc_hidden_dim_symbol, 4 * enc_hidden_dim_symbol],
                  'attn_model/AttentionModel/gradients/attn_model/StackedEncoder/Layer2/RNNEncoder/bidirectional_rnn/bw/bw/while/basic_lstm_cell/BiasAdd/Enter_grad/b_acc': [4 * enc_hidden_dim_symbol],
                  'attn_model/AttentionModel/gradients/attn_model/StackedEncoder/Layer2/RNNEncoder/bidirectional_rnn/bw/bw/while/basic_lstm_cell/MatMul/Enter_grad/b_acc': [3 * enc_hidden_dim_symbol, 4 * enc_hidden_dim_symbol],
                  'attn_model/AttentionModel/gradients/attn_model/StackedEncoder/Layer2/RNNEncoder/bidirectional_rnn/fw/fw/while/basic_lstm_cell/BiasAdd/Enter_grad/b_acc': [4 * enc_hidden_dim_symbol],
                  'attn_model/AttentionModel/gradients/attn_model/StackedEncoder/Layer2/RNNEncoder/bidirectional_rnn/fw/fw/while/basic_lstm_cell/MatMul/Enter_grad/b_acc': [3 * enc_hidden_dim_symbol, 4 * enc_hidden_dim_symbol],
                  'attn_model/AttentionModel/gradients/attn_model/StackedEncoder/Layer4/RNNEncoder/bidirectional_rnn/bw/bw/while/basic_lstm_cell/BiasAdd/Enter_grad/b_acc': [4 * enc_hidden_dim_symbol],
                  'attn_model/AttentionModel/gradients/attn_model/StackedEncoder/Layer4/RNNEncoder/bidirectional_rnn/bw/bw/while/basic_lstm_cell/MatMul/Enter_grad/b_acc': [3 * enc_hidden_dim_symbol, 4 * enc_hidden_dim_symbol],
                  'attn_model/AttentionModel/gradients/attn_model/StackedEncoder/Layer4/RNNEncoder/bidirectional_rnn/fw/fw/while/basic_lstm_cell/BiasAdd/Enter_grad/b_acc': [4 * enc_hidden_dim_symbol],
                  'attn_model/AttentionModel/gradients/attn_model/StackedEncoder/Layer4/RNNEncoder/bidirectional_rnn/fw/fw/while/basic_lstm_cell/MatMul/Enter_grad/b_acc': [3 * enc_hidden_dim_symbol, 4 * enc_hidden_dim_symbol],
                }

    # Update constant values
    const_dict = {
        'attn_model/AffineAttentionStateNN/Reshape/shape':
        [-1, 2 * enc_hidden_dim_symbol],
        'attn_model/AffineAttentionStateNN/Reshape_1/shape/2':
        attn_dim_symbol,
        'attn_model/AttentionEncoderDecoder/Reshape/shape/1':
        output_vocab_symbol,
        'attn_model/AttentionModel/gradients/attn_model/AffineAttentionStateNN/add_grad/Shape_1':
        [attn_dim_symbol],
        'attn_model/AttentionModel/gradients/attn_model/Decoder/while/add_grad/Shape_1':
        [output_vocab_symbol],
        'attn_model/AttentionModel/gradients/attn_model/Decoder/while/attn_model/add_2_grad/Shape_1':
        [attn_dim_symbol],
        'attn_model/AttentionModel/gradients/attn_model/Decoder/while/attn_model/conv1d/Conv2D_grad/Const':
        [1, conv_width_symbol, 1, num_conv_filters_symbol],
        'attn_model/AttentionModel/gradients/attn_model/Decoder/while/attn_model/conv1d/ExpandDims_1_grad/Shape':
        [conv_width_symbol, 1, num_conv_filters_symbol],
        'attn_model/AttentionModel/gradients/attn_model/Decoder/while/attn_model/conv1d_1/Conv2D_grad/Const':
        [1, 1, num_conv_filters_symbol, attn_dim_symbol],
        'attn_model/AttentionModel/gradients/attn_model/Decoder/while/attn_model/conv1d_1/ExpandDims_1_grad/Shape':
        [1, num_conv_filters_symbol, attn_dim_symbol],
        'attn_model/AttentionModel/gradients/attn_model/Decoder/while/attn_model/mul_grad/Shape_1':
        [attn_dim_symbol],
        'attn_model/Decoder/CustomLSTMCellZeroState/Const':
        [2 * attn_hidden_dim_symbol],
        'attn_model/Decoder/CustomLSTMCellZeroState/Const_1':
        [2 * attn_hidden_dim_symbol],
        'attn_model/Decoder/CustomLSTMCellZeroState_1/Const': [
            2 * dec_hidden_dim_symbol
        ],
        'attn_model/Decoder/CustomLSTMCellZeroState_1/Const_1': [
            2 * dec_hidden_dim_symbol
        ],
        'attn_model/Decoder/while/attn_model/attention_cell/attention_cell/Shape':
        [
            attn_hidden_dim_symbol + 2 * enc_hidden_dim_symbol +
            output_vocab_symbol, 4 * attn_hidden_dim_symbol
        ],
        'attn_model/Decoder/while/attn_model/decoder_cell/decoder_cell/Shape':
        [
            attn_hidden_dim_symbol + dec_hidden_dim_symbol +
            2 * enc_hidden_dim_symbol, 4 * dec_hidden_dim_symbol
        ],
        'attn_model/Decoder/while/attn_model/one_hot/depth':
        output_vocab_symbol,
        'attn_model/Decoder/zeros/shape/1':
        2 * enc_hidden_dim_symbol,
        'attn_model/Decoder/zeros_2/shape/1':
        output_vocab_symbol,
        'attn_model/Reshape/shape': [1, 1, audio_features_symbol],
        'attn_model/Reshape_1/shape': [1, 1, audio_features_symbol],
        'attn_model/Reshape_2/shape/2':
        2 * enc_hidden_dim_symbol,
        'attn_model/StackedEncoder/Layer0/RNNEncoder/Reshape/shape/2':
        audio_features_symbol,
        'attn_model/StackedEncoder/Layer0/RNNEncoder/bidirectional_rnn/bw/bw/BasicLSTMCellZeroState/Const':
        [2 * enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer0/RNNEncoder/bidirectional_rnn/bw/bw/BasicLSTMCellZeroState/Const_1':
        [2 * enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer0/RNNEncoder/bidirectional_rnn/bw/bw/Const_1':
        [enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer0/RNNEncoder/bidirectional_rnn/bw/bw/Const_4':
        [enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer0/RNNEncoder/bidirectional_rnn/fw/fw/BasicLSTMCellZeroState/Const':
        [2 * enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer0/RNNEncoder/bidirectional_rnn/fw/fw/BasicLSTMCellZeroState/Const_1':
        [2 * enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer0/RNNEncoder/bidirectional_rnn/fw/fw/Const_1':
        [enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer0/RNNEncoder/bidirectional_rnn/fw/fw/Const_4':
        [enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer2/RNNEncoder/Reshape/shape/2':
        2 * enc_hidden_dim_symbol,
        'attn_model/StackedEncoder/Layer2/RNNEncoder/bidirectional_rnn/bw/bw/BasicLSTMCellZeroState/Const':
        [2 * enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer2/RNNEncoder/bidirectional_rnn/bw/bw/BasicLSTMCellZeroState/Const_1':
        [2 * enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer2/RNNEncoder/bidirectional_rnn/bw/bw/Const_1':
        [enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer2/RNNEncoder/bidirectional_rnn/bw/bw/Const_4':
        [enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer2/RNNEncoder/bidirectional_rnn/fw/fw/BasicLSTMCellZeroState/Const':
        [2 * enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer2/RNNEncoder/bidirectional_rnn/fw/fw/BasicLSTMCellZeroState/Const_1':
        [2 * enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer2/RNNEncoder/bidirectional_rnn/fw/fw/Const_1':
        [enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer2/RNNEncoder/bidirectional_rnn/fw/fw/Const_4':
        [enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer4/RNNEncoder/Reshape/shape/2':
        2 * enc_hidden_dim_symbol,
        'attn_model/StackedEncoder/Layer4/RNNEncoder/bidirectional_rnn/bw/bw/BasicLSTMCellZeroState/Const':
        [2 * enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer4/RNNEncoder/bidirectional_rnn/bw/bw/BasicLSTMCellZeroState/Const_1':
        [2 * enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer4/RNNEncoder/bidirectional_rnn/bw/bw/Const_1':
        [enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer4/RNNEncoder/bidirectional_rnn/bw/bw/Const_4':
        [enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer4/RNNEncoder/bidirectional_rnn/fw/fw/BasicLSTMCellZeroState/Const':
        [2 * enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer4/RNNEncoder/bidirectional_rnn/fw/fw/BasicLSTMCellZeroState/Const_1':
        [2 * enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer4/RNNEncoder/bidirectional_rnn/fw/fw/Const_1':
        [enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer4/RNNEncoder/bidirectional_rnn/fw/fw/Const_4':
        [enc_hidden_dim_symbol],
    }

    graph.bindConstantValues(const_dict)

    # TODO: Currently, Catamount doesn't automatically handle TensorFlow
    # TensorArray or Stack ops. Here, manually set the dimensions of these
    # ops' tensors.
    for op in graph._ops_by_name.values():
        op_name_suffix = op.name.split('/')[-1]
        if 'TensorArrayGather' in op_name_suffix:
            assert isinstance(op, UnknownOp)
            assert len(op._inputs) == 3
            assert len(op._outputs) == 1
            if op._outputs[0].shape.rank == 1 or op._outputs[0].shape.rank == 2:
                if len(op._outputs[0].consumers) > 0:
                    print(
                        'TODO: Unknown TensorArrayGather (rank {}): {}'.format(
                            op._outputs[0].shape.rank, op.debugString()))
            elif op._outputs[0].shape.isUnknown() or \
                 op._outputs[0].shape.rank == 3:
                if len(op._outputs[0].consumers) > 0:
                    # If output rank is 3, the shape appears to be
                    # [seq_length, batch_size, enc_hid], where seq_length
                    # depends on the layer
                    out_shape = None
                    if 'StackedEncoder/Layer0' in op.name:
                        out_shape = [
                            encoder_steps_symbol, subbatch_size_symbol,
                            enc_hidden_dim_symbol
                        ]
                    elif 'StackedEncoder/Layer2' in op.name:
                        if 'attn_model/AttentionModel/gradients' in op.name:
                            # Backprop stores concatenated state
                            out_shape = [
                                encoder_steps_symbol // 2,
                                subbatch_size_symbol, 2 * enc_hidden_dim_symbol
                            ]
                        else:
                            out_shape = [
                                encoder_steps_symbol // 2,
                                subbatch_size_symbol, enc_hidden_dim_symbol
                            ]
                    elif 'StackedEncoder/Layer4' in op.name:
                        if 'attn_model/AttentionModel/gradients' in op.name:
                            # Backprop stores concatenated state
                            out_shape = [(encoder_steps_symbol // 2) // 2,
                                         subbatch_size_symbol,
                                         2 * enc_hidden_dim_symbol]
                        else:
                            out_shape = [(encoder_steps_symbol // 2) // 2,
                                         subbatch_size_symbol,
                                         enc_hidden_dim_symbol]
                    elif 'Decoder' in op.name:
                        # HACK: Manually specify a few shapes
                        if op.name == 'attn_model/Decoder/TensorArrayStack/TensorArrayGatherV3':
                            out_shape = [
                                decoder_steps_symbol, subbatch_size_symbol,
                                output_vocab_symbol
                            ]
                        else:
                            out_shape = [
                                decoder_steps_symbol, subbatch_size_symbol,
                                dec_hidden_dim_symbol
                            ]
                    else:
                        print('TODO: Unknown TensorArrayGather {}'.format(
                            op.debugString()))
                    if out_shape is not None:
                        op._outputs[0].mergeShape(out_shape,
                                                  make_symbolic=True)
            else:
                print('TODO: Unknown TensorArrayGather {}'.format(
                    op.debugString()))
        elif 'TensorArraySize' in op_name_suffix:
            assert isinstance(op, UnknownOp)
            assert len(op._inputs) == 2
            assert len(op._outputs) == 1
            assert op._outputs[0].shape.rank == 0
            # NOTES:
            # StackedEncoder Layer0: enc_seq
            # StackedEncoder Layer2: enc_seq / 2 # Due to stride 2 in time
            # StackedEncoder Layer4: enc_seq / 4 # Due to stride 2 in time
            # Decoder: dec_seq
            if 'StackedEncoder/Layer0' in op.name:
                op._outputs[0].setValue(encoder_steps_symbol)
            elif 'StackedEncoder/Layer2' in op.name:
                op._outputs[0].setValue(encoder_steps_symbol // 2)
            elif 'StackedEncoder/Layer4' in op.name:
                op._outputs[0].setValue((encoder_steps_symbol // 2) // 2)
            elif 'Decoder' in op.name:
                op._outputs[0].setValue(decoder_steps_symbol)
            else:
                print('WARN: Unknown TensorArraySizeV3: {}'.format(
                    op.debugString()))
        elif 'TensorArrayRead' in op_name_suffix:
            assert isinstance(op, UnknownOp)
            assert len(op._inputs) == 3
            assert len(op._outputs) == 1
            assert op._outputs[0].shape.isUnknown() or \
                   op._outputs[0].shape.rank == 2, \
                   '{}'.format(op.name)
            if op._outputs[0].shape.isUnknown():
                if len(op._outputs[0].consumers) > 0:
                    out_shape = None
                    if 'attn_model/AttentionModel/gradients/attn_model/StackedEncoder/Layer' in op.name and \
                       ('/RNNEncoder/bidirectional_rnn/fw/fw/while/TensorArrayWrite/TensorArrayWriteV3_grad/TensorArrayReadV3' in op.name or \
                        '/RNNEncoder/bidirectional_rnn/bw/bw/while/TensorArrayWrite/TensorArrayWriteV3_grad/TensorArrayReadV3' in op.name):
                        out_shape = [
                            subbatch_size_symbol, enc_hidden_dim_symbol
                        ]
                    elif op.name == 'attn_model/AttentionModel/gradients/attn_model/Decoder/TensorArrayWrite/TensorArrayWriteV3_grad/TensorArrayReadV3' or \
                         op.name == 'attn_model/AttentionModel/gradients/attn_model/Decoder/while/TensorArrayWrite_1/TensorArrayWriteV3_grad/TensorArrayReadV3' or \
                         op.name == 'attn_model_2/Decoder/while/cond/TensorArrayReadV3' or \
                         op.name == 'attn_model/Decoder/while/cond/TensorArrayReadV3':
                        out_shape = [subbatch_size_symbol, output_vocab_symbol]
                    else:
                        print('WARN: Unknown TensorArrayReadV3 out shape: {}'.
                              format(op.debugString()))
                    if out_shape is not None:
                        op._outputs[0].mergeShape(out_shape,
                                                  make_symbolic=True)
            else:
                # NOTES: Many are (?, 40 "features"), (?, 1051 "enc_hid"), or (?, 2102 "2*enc_hid")
                dim_1_val = op._outputs[0].shape.getDimension(1).value
                assert dim_1_val == base_audio_features or \
                       dim_1_val == base_enc_hidden_dim or \
                       dim_1_val == 2 * base_enc_hidden_dim, \
                       'Op: {}\n   Dim 1 value: {}'.format(op.debugString(), dim_1_val)
                out_shape = None
                if dim_1_val == base_audio_features:
                    out_shape = [subbatch_size_symbol, audio_features_symbol]
                elif dim_1_val > 0 and dim_1_val % base_enc_hidden_dim == 0:
                    mult = dim_1_val // base_enc_hidden_dim
                    out_shape = [
                        subbatch_size_symbol, mult * enc_hidden_dim_symbol
                    ]
                else:
                    print('Unhandled TensorArrayRead: {}'.format(
                        op.debugString()))
                if out_shape is not None:
                    op._outputs[0].mergeShape(out_shape, make_symbolic=True)

    # Manually set a couple shapes for max ops that can't yet resolve
    # maximums of 1 vs. positive symbols:
    max_op = graph._ops_by_name[
        'attn_model/AttentionModel/gradients/attn_model/AttentionEncoderDecoder/Sum_grad/Maximum']
    max_op._outputs[0].mergeShape([2])
    max_op._outputs[0].setValue([1, subbatch_size_symbol])

    max_op = graph._ops_by_name[
        'attn_model/AttentionModel/gradients/attn_model/Decoder/while/attn_model/Sum_grad/Maximum']
    max_op._outputs[0].mergeShape([3])
    # [floor(floor(encoder_steps/2)/2) subbatch_size 1]
    max_op._outputs[0].setValue([(encoder_steps_symbol // 2) // 2,
                                 subbatch_size_symbol, 1])

    max_op = graph._ops_by_name[
        'attn_model/AttentionModel/gradients/attn_model/Decoder/while/attn_model/Sum_1_grad/Maximum']
    max_op._outputs[0].mergeShape([3])
    # [1 subbatch_size 2*enc_hidden_dim]
    max_op._outputs[0].setValue(
        [1, subbatch_size_symbol, 2 * enc_hidden_dim_symbol])

    print('Binding variables')

    graph.bindShapesAndPropagate(bind_dict,
                                 warn_if_ill_defined=(not is_pytest_run),
                                 make_symbolic=True)
    assert graph.isValid()

    print('\n\nCleaned Graph:\n{}'.format(graph))

    print('\n\nBound values')

    # Set base values to be subbed in:
    base_encoder_steps = 96
    base_decoder_steps = 24
    base_attn_dim = 128
    base_conv_width = 50
    base_attn_hidden_dim = 512
    base_dec_hidden_dim = 512
    base_enc_hidden_dim = 1024

    bind_subs = {
        audio_features_symbol: base_audio_features,
        encoder_steps_symbol: base_encoder_steps,
        decoder_steps_symbol: (encoder_steps_symbol // 2) // 2,
        subbatch_size_symbol: base_subbatch_size,
        attn_dim_symbol: base_attn_dim,
        attn_hidden_dim_symbol: enc_hidden_dim_symbol // 2,
        dec_hidden_dim_symbol: enc_hidden_dim_symbol // 2,
        output_vocab_symbol: base_output_vocab,
        conv_width_symbol: base_conv_width,
        enc_hidden_dim_symbol: base_enc_hidden_dim,
        num_conv_filters_symbol: 4,
        graph_iters_symbol: 1,
    }
    # Add loop iteration counts to bind_subs
    bind_str_subs = {
        'attn_model/AttentionModel/gradients/b_count_2_block::iters':
        decoder_steps_symbol,
        'attn_model/Decoder/while/LoopCond_block::iters':
        decoder_steps_symbol,
        'attn_model/AttentionModel/gradients/b_count_22_block::iters':
        encoder_steps_symbol,
        'attn_model/AttentionModel/gradients/b_count_26_block::iters':
        encoder_steps_symbol,
        'attn_model/StackedEncoder/Layer0/RNNEncoder/bidirectional_rnn/bw/bw/while/LoopCond_block::iters':
        encoder_steps_symbol,
        'attn_model/StackedEncoder/Layer0/RNNEncoder/bidirectional_rnn/fw/fw/while/LoopCond_block::iters':
        encoder_steps_symbol,
        'attn_model/AttentionModel/gradients/b_count_14_block::iters':
        encoder_steps_symbol // 2,
        'attn_model/AttentionModel/gradients/b_count_18_block::iters':
        encoder_steps_symbol // 2,
        'attn_model/StackedEncoder/Layer2/RNNEncoder/bidirectional_rnn/bw/bw/while/LoopCond_block::iters':
        encoder_steps_symbol // 2,
        'attn_model/StackedEncoder/Layer2/RNNEncoder/bidirectional_rnn/fw/fw/while/LoopCond_block::iters':
        encoder_steps_symbol // 2,
        'attn_model/AttentionModel/gradients/b_count_6_block::iters':
        (encoder_steps_symbol // 2) // 2,
        'attn_model/AttentionModel/gradients/b_count_10_block::iters':
        (encoder_steps_symbol // 2) // 2,
        'attn_model/StackedEncoder/Layer4/RNNEncoder/bidirectional_rnn/bw/bw/while/LoopCond_block::iters':
        (encoder_steps_symbol // 2) // 2,
        'attn_model/StackedEncoder/Layer4/RNNEncoder/bidirectional_rnn/fw/fw/while/LoopCond_block::iters':
        (encoder_steps_symbol // 2) // 2,
    }

    for var_name, sub_val in bind_str_subs.items():
        var_ref = utils.getIntSymbolFromString(var_name)
        assert var_ref not in bind_subs.keys()
        bind_subs[var_ref] = sub_val

    # Calculate model parameter count
    parameters = graph.calcModelParameters()
    resolved_params = parameters.subs(bind_subs)
    try:
        resolved_params = int(resolved_params)
    except Exception:
        print('ERROR: resolved_params should be int, but is {} = {}'.format(
            type(resolved_params), resolved_params))
    correct_params = 71084729
    assert resolved_params == correct_params, \
           'Incorrect model params: {}'.format(resolved_params)
    print('Parameters: {}\nWith specified dims: {}\n'.format(
        parameters, resolved_params))

    # Calculate algorithmic Flops
    alg_flops = graph.calcAlgFlops()
    resolved_flops = alg_flops.subs(bind_subs)
    try:
        resolved_flops = int(resolved_flops)
    except Exception:
        print('ERROR: resolved_flops should be int, but is {} = {}'.format(
            type(resolved_flops), resolved_flops))
    correct_flops = 568878183032
    assert resolved_flops == correct_flops, \
           'Incorrect algorithmic flops: {}'.format(resolved_flops)
    print('Algorithmic Flops: {}\nWith specified dims: {}\n'.format(
        alg_flops, resolved_flops))

    # Calculate algorithmic Bytes accessed
    alg_bytes = graph.calcAlgBytes()
    resolved_bytes = alg_bytes.subs(bind_subs)
    try:
        resolved_bytes = int(resolved_bytes)
    except Exception:
        print('ERROR: resolved_bytes should be int, but is {} = {}'.format(
            type(resolved_bytes), resolved_bytes))
    correct_bytes = 92231419797
    assert resolved_bytes == correct_bytes, \
           'Incorrect algorithmic bytes: {}'.format(resolved_bytes)
    print('Alg bytes accessed: {}\nWith specified dims: {}\n'.format(
        alg_bytes, resolved_bytes))

    # Calculate algorithmic memory footprint
    alg_footprint = graph.calcAlgFootprint()
    resolved_footprint = alg_footprint.subs(bind_subs)
    try:
        resolved_footprint = int(resolved_footprint)
    except Exception:
        print('ERROR: resolved_footprint should be int, but is {} = {}'.format(
            type(resolved_footprint), resolved_footprint))
    correct_footprint = 32624988214
    assert resolved_footprint == correct_footprint, \
           'Incorrect algorithmic footprint: {}'.format(resolved_footprint)
    print('Alg mem footprint: {}\nWith specified dims: {}\n'.format(
        alg_footprint, resolved_footprint))

    # Calculate algorithmic IO per step
    total_io_footprint = 0
    for op in graph.getPlaceholders():
        total_io_footprint += op.calcAlgFootprint()
    resolved_io_footprint = total_io_footprint.subs(bind_subs)
    print('Alg IO footprint: {}\nWith specified dims: {}\n'.format(
        total_io_footprint, resolved_io_footprint))

    print('VERBOSE ALGORITHMIC FLOPS:')
    graph.calcAlgFlops(verbose=True)
    print('')

    print('VERBOSE ALGORITHMIC BYTES:')
    graph.calcAlgBytes(verbose=True)
    print('')

    print('VERBOSE ALGORITHMIC FOOTPRINT:')
    graph.calcAlgFootprint(verbose=True)
    print('')

    # HACK: For now, save the analyzed graph by pickling it
    pickle.dump(
        graph,
        open(
            'catamount/frameworks/example_graphs/tensorflow/full_models/speech_attention/graph_speech_attention.p',
            'wb'))

    if is_pytest_run:
        return

    print('\n\n======= Algorithmic graph-level analytics: =======')

    encoder_dims = [
        32, 64, 96, 128, 160, 192, 256, 320, 384, 448, 512, 640, 768, 892,
        1024, 1152, 1280, 1408, 1548, 1702, 1872, 2059, 2264, 2490, 2739, 3012,
        3289
    ]
    base_encoder_steps = 335
    base_subbatch_size = 32
    base_attn_dim = 128
    base_conv_width = 50
    base_attn_hidden_dim = 512
    base_dec_hidden_dim = 512
    base_enc_hidden_dim = 1024

    bind_subs[audio_features_symbol] = base_audio_features
    bind_subs[encoder_steps_symbol] = base_encoder_steps
    bind_subs[decoder_steps_symbol] = (encoder_steps_symbol // 2) // 2
    bind_subs[subbatch_size_symbol] = base_subbatch_size
    bind_subs[attn_dim_symbol] = base_attn_dim
    bind_subs[attn_hidden_dim_symbol] = enc_hidden_dim_symbol // 2
    bind_subs[dec_hidden_dim_symbol] = enc_hidden_dim_symbol // 2
    bind_subs[output_vocab_symbol] = base_output_vocab
    bind_subs[conv_width_symbol] = base_conv_width
    # bind_subs[enc_hidden_dim_symbol] = base_enc_hidden_dim
    bind_subs[num_conv_filters_symbol] = 4
    bind_subs[graph_iters_symbol] = 1

    bind_subs.pop(enc_hidden_dim_symbol)
    resolved_params = parameters.subs(bind_subs)

    print('Symbol associations: {}\n'.format(bind_subs))

    print(
        'Algorithmic Flops by hidden dimension, params, and per-batch-sample:')
    resolved_flops = alg_flops.subs(bind_subs)
    for enc_dim in encoder_dims:
        graph_params = resolved_params.subs({enc_hidden_dim_symbol: enc_dim})
        graph_flops = resolved_flops.subs({enc_hidden_dim_symbol: enc_dim})
        graph_flops_per_sample = float(graph_flops) / \
                                 bind_subs[subbatch_size_symbol]
        print('{}\t{}\t{}\t{}'.format(enc_dim, graph_params, graph_flops,
                                      int(graph_flops_per_sample)))

    print('\nAlgorithmic bytes accessed by hidden dimension, params:')
    resolved_bytes = alg_bytes.subs(bind_subs)
    for enc_dim in encoder_dims:
        graph_params = resolved_params.subs({enc_hidden_dim_symbol: enc_dim})
        graph_bytes = resolved_bytes.subs({enc_hidden_dim_symbol: enc_dim})
        print('{}\t{}\t{}'.format(enc_dim, graph_params, graph_bytes))

    print('\nAlgorithmic memory footprint by hidden dimension, params:')
    resolved_footprint = alg_footprint.subs(bind_subs)
    for enc_dim in encoder_dims:
        graph_params = resolved_params.subs({enc_hidden_dim_symbol: enc_dim})
        graph_footprint = resolved_footprint.subs(
            {enc_hidden_dim_symbol: enc_dim})
        print('{}\t{}\t{}'.format(enc_dim, graph_params, graph_footprint))

    print(
        '\nAlgorithmic minimal memory footprint by hidden dimension, params:')
    full_subs = dict(bind_subs)
    for enc_dim in encoder_dims:
        graph_params = resolved_params.subs({enc_hidden_dim_symbol: enc_dim})
        full_subs[enc_hidden_dim_symbol] = enc_dim
        graph_min_foot = graph.calcMinimalFootprint(symbol_subs=full_subs)
        print('{}\t{}\t{}'.format(enc_dim, graph_params, graph_min_foot))
Example No. 14
def test_tf_static_unroll_rnn():
    graph = catamount.frameworks.tensorflow.import_graph(tf_example_filename)
    assert graph.isValid()

    algorithmic_flops = graph.calcAlgFlops()

    ba_0 = utils.getIntSymbolFromString('basic_rnn_cell/BiasAdd:0::dim_0')
    ba_1 = utils.getIntSymbolFromString('basic_rnn_cell/BiasAdd_1:0::dim_0')
    ba_2 = utils.getIntSymbolFromString('basic_rnn_cell/BiasAdd_2:0::dim_0')
    ba_3 = utils.getIntSymbolFromString('basic_rnn_cell/BiasAdd_3:0::dim_0')
    ba_4 = utils.getIntSymbolFromString('basic_rnn_cell/BiasAdd_4:0::dim_0')
    mm_0 = utils.getIntSymbolFromString('basic_rnn_cell/MatMul:0::dim_0')
    mm_1 = utils.getIntSymbolFromString('basic_rnn_cell/MatMul_1:0::dim_0')
    mm_2 = utils.getIntSymbolFromString('basic_rnn_cell/MatMul_2:0::dim_0')
    mm_3 = utils.getIntSymbolFromString('basic_rnn_cell/MatMul_3:0::dim_0')
    mm_4 = utils.getIntSymbolFromString('basic_rnn_cell/MatMul_4:0::dim_0')
    th_0 = utils.getIntSymbolFromString('basic_rnn_cell/Tanh:0::dim_0')
    th_1 = utils.getIntSymbolFromString('basic_rnn_cell/Tanh_1:0::dim_0')
    th_2 = utils.getIntSymbolFromString('basic_rnn_cell/Tanh_2:0::dim_0')
    th_3 = utils.getIntSymbolFromString('basic_rnn_cell/Tanh_3:0::dim_0')
    th_4 = utils.getIntSymbolFromString('basic_rnn_cell/Tanh_4:0::dim_0')
    correct_alg_flops = 24 * (ba_0 + ba_1 + ba_2 + ba_3 + ba_4) + \
                        2304 * (mm_0 + mm_1 + mm_2 + mm_3 + mm_4) + \
                        144 * (th_0 + th_1 + th_2 + th_3 + th_4) + 2305

    print('Loaded Flops test:')
    print('    Catamount:   {}'.format(algorithmic_flops))
    print('    Correct: {}'.format(correct_alg_flops))
    assert sympy.simplify(algorithmic_flops - correct_alg_flops) == 0, \
        'Initial alg flops incorrect!\n  Expecting: {}\n  Calculated: {}' \
        .format(correct_alg_flops, algorithmic_flops)

    # Now, bind tensor names in the graph and verify that the algorithmic
    # Flop counts reflect the new name bindings
    batch_size = utils.getIntSymbolFromString('batch_size')
    # NOTE: This also works: batch_size = 'batch_size'
    # Bind placeholder (a and init_state) output dimensions to symbolic names
    bind_dict = {
        'a': ['seq_length', 'batch_size', 'hidden_dim'],
        'init_state': ['batch_size', 'hidden_dim']
    }
    graph.bindTensorShapeDimensions(bind_dict)

    algorithmic_flops = graph.calcAlgFlops()

    # Update the algorithmic Flops formula
    correct_alg_flops = correct_alg_flops.subs({
        ba_0: batch_size,
        ba_1: batch_size,
        ba_2: batch_size,
        ba_3: batch_size,
        ba_4: batch_size,
        mm_0: batch_size,
        mm_1: batch_size,
        mm_2: batch_size,
        mm_3: batch_size,
        mm_4: batch_size,
        th_0: batch_size,
        th_1: batch_size,
        th_2: batch_size,
        th_3: batch_size,
        th_4: batch_size
    })
    print('Bound Flops test:')
    print('    Catamount:   {}'.format(algorithmic_flops))
    print('    Correct: {}'.format(correct_alg_flops))
    assert sympy.simplify(algorithmic_flops - correct_alg_flops) == 0, \
        'Bound alg flops incorrect!\n  Expecting: {}\n  Calculated: {}' \
        .format(correct_alg_flops, algorithmic_flops)
def run_tf_language_model(domain=None, build_projection=False):
    global is_pytest_run

    if domain == 'wordlm':
        graph_meta = 'catamount/frameworks/example_graphs/tensorflow/full_models/language_models/word_lm_n2004_l2_sgd_lr0.2_nodrop_b128_v10k_d20_s80-best_model.meta'
    elif domain == 'charlm':
        graph_meta = 'catamount/frameworks/example_graphs/tensorflow/full_models/language_models/char_lm_n2004_l10_sgd_lr0.15_rhn_b128_vchar_d1.0_s150-latest_model.meta'
    elif domain == 'nmt':
        graph_meta = 'catamount/frameworks/example_graphs/tensorflow/full_models/language_models/nmt_el2_dl1_n1024_b128-translate.ckpt-1000.meta'
    else:
        raise NotImplementedError('ERROR: Unknown domain: {}'.format(domain))

    graph = catamount.frameworks.tensorflow.import_graph(graph_meta)
    assert graph.isValid()

    # Next, remove ops that are not executed during a standard training step:
    # TODO: Implement feeds->fetches calcAlg*
    if domain == 'wordlm':
        graph_ops = list(graph._ops_by_name.values())
        for op in graph_ops:
            # Certain ops are only used for inference
            if 'Model/Recurrent_1_lstm_3/' in op.name or \
               'Model/Recurrent_2_lstm_3/' in op.name or \
               'Model/FullSoftmaxLoss_1_3/' in op.name or \
               'Model/Collapse_1/' in op.name or \
               'Model/Embedding_1_3/' in op.name or \
               'Model/Labels_1/' in op.name or \
               'Model/Mask_1/' in op.name:
                graph.removeOp(op)
            elif op.name == 'Model/Sum_1' or \
                 op.name == 'Model/Cast_3' or \
                 op.name == 'Model/Cast_2' or \
                 op.name == 'Model/Size_1' or \
                 op.name == 'Model/truediv_2' or \
                 op.name == 'Model/truediv_3' or \
                 op.name == 'Model/Exp_1':
                graph.removeOp(op)
    elif domain == 'charlm':
        graph_ops = list(graph._ops_by_name.values())
        for op in graph_ops:
            # Certain ops are only used for inference
            if 'Model/Recurrent_1_rhn_3/' in op.name or \
               'Model/FullSoftmaxLoss_1_3/' in op.name or \
               'Model/Collapse_1/' in op.name or \
               'Model/Embedding_1_3/' in op.name or \
               'Model/Labels_1/' in op.name or \
               'Model/Mask_1/' in op.name:
                graph.removeOp(op)
            elif op.name == 'Model/Cast_1' or \
                 op.name == 'Model/Sum_1' or \
                 op.name == 'Model/Size_1' or \
                 op.name == 'Model/truediv_2' or \
                 op.name == 'Model/truediv_3' or \
                 op.name == 'Model/Exp_1':
                graph.removeOp(op)
    elif domain == 'nmt':
        pass
    else:
        raise NotImplementedError('ERROR: Unknown domain: {}'.format(domain))

    if not is_pytest_run:
        print('Initial graph:\n{}\n'.format(graph))
    init_params = graph.calcModelParameters()
    print('Initial parameters: {}'.format(init_params))
    print('Initial Flops: {}\n'.format(graph.calcAlgFlops()))

    print('Placeholders:')
    for op in graph.getPlaceholders():
        print(op.debugString())
    print('')

    # Set up symbols to name dimensions
    hidden_dim_symbol = utils.getIntSymbolFromString('hidden_dim')
    vocab_size_symbol = utils.getIntSymbolFromString('vocab_size')
    subbatch_size_symbol = utils.getIntSymbolFromString('subbatch_size')
    sequence_length_symbol = utils.getIntSymbolFromString('sequence_length')
    batch_times_seq_symbol = sequence_length_symbol * subbatch_size_symbol
    graph_iters_symbol = utils.getIntSymbolFromString('graph::iters')

    # Convert these constant dimensions to symbols
    base_subbatch_size = None
    base_sequence_length = None
    if domain == 'wordlm':
        base_hidden_dim = 2004
        base_vocab_size = 10004
    elif domain == 'charlm':
        base_hidden_dim = 2004
        base_vocab_size = 98
    elif domain == 'nmt':
        base_hidden_dim = 1024
        base_vocab_size = 36548
        base_sequence_length = 19
        base_subbatch_size = 128
    else:
        raise NotImplementedError('ERROR: Unknown domain: {}'.format(domain))
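    # These base values appear to mirror the loaded checkpoints (e.g.,
    # n2004 -> hidden_dim 2004, n1024 -> hidden_dim 1024, b128 ->
    # subbatch_size 128); the vocabulary sizes include any special tokens.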

    # HAXXX: Manually setting TensorArray shapes!
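    # TensorArray ops are imported as UnknownOp, so their output shapes and
    # values are patched by hand below: Gather -> [seq, batch, hidden] (or the
    # NMT special cases), Size -> sequence_length, Read -> [batch, hidden].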
    if domain == 'wordlm' or domain == 'charlm' or domain == 'nmt':
        for op in graph._ops_by_name.values():
            op_name_suffix = op.name.split('/')[-1]
            if 'TensorArrayGather' in op_name_suffix:
                assert isinstance(op, UnknownOp)
                assert len(op._inputs) == 3
                assert len(op._outputs) == 1
                if domain == 'wordlm' or domain == 'charlm':
                    assert op._outputs[0].shape.isUnknown() or \
                           op._outputs[0].shape.rank == 3, \
                           '{}'.format(op.name)
                    gather_shape = [
                        sequence_length_symbol, subbatch_size_symbol,
                        hidden_dim_symbol
                    ]
                else:
                    assert domain == 'nmt'
                    assert op._outputs[0].shape.isUnknown() or \
                           op._outputs[0].shape.rank == 2 or \
                           op._outputs[0].shape.rank == 3, \
                           '{}'.format(op.name)
                    if not op._outputs[0].shape.isUnknown():
                        if op._outputs[0].shape.rank == 3:
                            out_shape = [
                                base_sequence_length, base_subbatch_size,
                                base_hidden_dim
                            ]
                            # Verify that the shape is clearly specified
                            op._outputs[0].mergeShape(out_shape,
                                                      make_symbolic=True)
                            gather_shape = [
                                sequence_length_symbol, subbatch_size_symbol,
                                hidden_dim_symbol
                            ]
                        else:
                            # This TensorArrayGather has no consumers, so its
                            # shape can be left unresolved
                            assert len(op._outputs[0].consumers) == 0
                            continue
                op._outputs[0].mergeShape(gather_shape, make_symbolic=True)
            elif 'TensorArraySize' in op_name_suffix:
                assert isinstance(op, UnknownOp)
                assert len(op._inputs) == 2
                assert len(op._outputs) == 1
                assert op._outputs[0].shape.rank == 0
                op._outputs[0].setValue(sequence_length_symbol)
            elif 'TensorArrayRead' in op_name_suffix:
                assert isinstance(op, UnknownOp)
                assert len(op._inputs) == 3
                assert len(op._outputs) == 1
                assert op._outputs[0].shape.isUnknown() or \
                       op._outputs[0].shape.rank == 2, \
                       '{}'.format(op.name)
                if not op._outputs[0].shape.isUnknown():
                    assert op._outputs[0].shape.dims[
                        1].value == base_hidden_dim
                read_shape = [subbatch_size_symbol, hidden_dim_symbol]
                op._outputs[0].mergeShape(read_shape, make_symbolic=True)
    else:
        raise NotImplementedError('ERROR: Unknown domain: {}'.format(domain))

    assert graph.isValid()

    if domain == 'wordlm':
        const_dict = {
            'Model/Collapse/Reshape/shape': [-1, hidden_dim_symbol],
            'Model/Recurrent_\d_lstm_1/rnn/Const': [hidden_dim_symbol],
            'Model/Recurrent_\d_lstm_1/rnn/Const_1': [hidden_dim_symbol],
            'Model/Gradient/Compute/gradients/Model/Embedding_1_1/Gather_grad/Shape':
            [vocab_size_symbol, hidden_dim_symbol],
            'Model/Gradient/Compute/gradients/Model/FullSoftmaxLoss_1_1/add_grad/Shape_1':
            [1, vocab_size_symbol],
        }
    elif domain == 'charlm':
        const_dict = {
            'Model/Collapse/Reshape/shape': [-1, hidden_dim_symbol],
            'Model/Recurrent_1_rhn_1/rnn/Const': [hidden_dim_symbol],
            'Model/Recurrent_1_rhn_1/rnn/Const_1': [hidden_dim_symbol],
            'Model/Gradient/Compute/gradients/Model/FullSoftmaxLoss_1_1/add_grad/Shape_1':
            [1, vocab_size_symbol],
            'Model/Gradient/Compute/gradients/Model/Embedding_1_1/Gather_grad/Shape':
            [vocab_size_symbol, hidden_dim_symbol],
        }
    elif domain == 'nmt':
        const_dict = {
            'gradients/dynamic_seq2seq/decoder/output_projection/Tensordot/Reshape_1_grad/Shape':
            [hidden_dim_symbol, vocab_size_symbol],
            'gradients/dynamic_seq2seq/decoder/embedding_lookup_grad/Shape':
            [vocab_size_symbol, hidden_dim_symbol],
            'gradients/dynamic_seq2seq/encoder/embedding_lookup_grad/Shape':
            [vocab_size_symbol, hidden_dim_symbol],
            'gradients/dynamic_seq2seq/decoder/LuongAttention/memory_layer/Tensordot/Reshape_1_grad/Shape':
            [2 * hidden_dim_symbol, hidden_dim_symbol],
            'dynamic_seq2seq/encoder/bidirectional_rnn/[bf]w/[bf]w/Const_1': [
                hidden_dim_symbol
            ],
            'dynamic_seq2seq/encoder/bidirectional_rnn/[bf]w/[bf]w/Const_4': [
                hidden_dim_symbol
            ],
            'dynamic_seq2seq/encoder/bidirectional_rnn/[bf]w/[bf]w/DeviceWrapperZeroState/DropoutWrapperZeroState/BasicLSTMCellZeroState/Const':
            [hidden_dim_symbol],
            'dynamic_seq2seq/encoder/bidirectional_rnn/[bf]w/[bf]w/DeviceWrapperZeroState/DropoutWrapperZeroState/BasicLSTMCellZeroState/Const_\d':
            [hidden_dim_symbol],
            'dynamic_seq2seq/decoder/output_projection/Tensordot/Reshape_1/shape':
            [hidden_dim_symbol, vocab_size_symbol],
            'dynamic_seq2seq/decoder/output_projection/Tensordot/Const_2': [
                vocab_size_symbol
            ],
            'dynamic_seq2seq/decoder/decoder/Const': [hidden_dim_symbol],
            'dynamic_seq2seq/decoder/decoder/Const_1': [hidden_dim_symbol],
            'dynamic_seq2seq/decoder/LuongAttention/memory_layer/Tensordot/Const_2':
            [hidden_dim_symbol],
            'dynamic_seq2seq/decoder/LuongAttention/memory_layer/Tensordot/Reshape_1/shape':
            [2 * hidden_dim_symbol, hidden_dim_symbol],
            'dynamic_seq2seq/decoder/DeviceWrapperZeroState/AttentionWrapperZeroState/Const':
            [hidden_dim_symbol],
            'dynamic_seq2seq/decoder/DeviceWrapperZeroState/AttentionWrapperZeroState/Const_1':
            [hidden_dim_symbol],
            'dynamic_seq2seq/decoder/DeviceWrapperZeroState/AttentionWrapperZeroState/DeviceWrapperZeroState/DropoutWrapperZeroState/BasicLSTMCellZeroState/Const':
            [hidden_dim_symbol],
            'dynamic_seq2seq/decoder/DeviceWrapperZeroState/AttentionWrapperZeroState/DeviceWrapperZeroState/DropoutWrapperZeroState/BasicLSTMCellZeroState/Const_\d':
            [hidden_dim_symbol],
            'buffer_size':
            256 * hidden_dim_symbol,
            'buffer_size_1':
            256 * hidden_dim_symbol,
            'buffer_size_[2-8]':
            125 * hidden_dim_symbol,
        }
    else:
        raise NotImplementedError(
            'Manually set constant op values for domain {}'.format(domain))

    graph.bindConstantValues(const_dict)

    # Next, bind the constant, placeholder, and variable shapes and propagate
    if domain == 'wordlm':
        bind_dict = { # Constants
                      'Model/Gradient/Compute/gradients/Model/Recurrent_\d_lstm_1/rnn/while/rnn/BiasAdd/Enter_grad/b_acc': [4 * hidden_dim_symbol],
                      'Model/Gradient/Compute/gradients/Model/Recurrent_\d_lstm_1/rnn/while/rnn/MatMul/Enter_grad/b_acc': [2 * hidden_dim_symbol, 4 * hidden_dim_symbol],
                      # Placeholders
                      'Input/Input': [subbatch_size_symbol, sequence_length_symbol],
                      'Labels/Labels': [subbatch_size_symbol, sequence_length_symbol],
                      'Model/Placeholder': [subbatch_size_symbol, hidden_dim_symbol],
                      'Model/Placeholder_\d': [subbatch_size_symbol, hidden_dim_symbol],
                      # Variables
                      'Model/Embedding_1/EmbeddingWeights': [vocab_size_symbol, hidden_dim_symbol],
                      'Model/FullSoftmaxLoss_1/W_Softmax': [vocab_size_symbol, hidden_dim_symbol],
                      'Model/FullSoftmaxLoss_1/b_Softmax': [1, vocab_size_symbol],
                      'Model/Recurrent_\d_lstm/rnn/Bias': [4 * hidden_dim_symbol],
                      'Model/Recurrent_\d_lstm/rnn/Matrix': [2 * hidden_dim_symbol, 4 * hidden_dim_symbol],
                    }
    elif domain == 'charlm':
        bind_dict = { # Constants
                      'Model/Gradient/Compute/gradients/Model/Recurrent_1_rhn_1/rnn/while/[ht]_0/[ht]_0/BiasAdd/Enter_grad/b_acc': [hidden_dim_symbol],
                      'Model/Gradient/Compute/gradients/Model/Recurrent_1_rhn_1/rnn/while/[ht]_0/[ht]_0/MatMul/Enter_grad/b_acc': [2 * hidden_dim_symbol, hidden_dim_symbol],
                      'Model/Gradient/Compute/gradients/Model/Recurrent_1_rhn_1/rnn/while/[ht]_[1-9]/[ht]_[1-9]/BiasAdd/Enter_grad/b_acc': [hidden_dim_symbol],
                      'Model/Gradient/Compute/gradients/Model/Recurrent_1_rhn_1/rnn/while/[ht]_[1-9]/[ht]_[1-9]/MatMul/Enter_grad/b_acc': [hidden_dim_symbol, hidden_dim_symbol],
                      'Model/Recurrent_1_rhn/rnn/[ht]_[0-9]/Bias/Initializer/Const': [hidden_dim_symbol],
                      # Placeholders
                      'Input/Input': [subbatch_size_symbol, sequence_length_symbol],
                      'Labels/Labels': [subbatch_size_symbol, sequence_length_symbol],
                      'Model/Placeholder': [subbatch_size_symbol, hidden_dim_symbol],
                      'Model/Placeholder_1': [subbatch_size_symbol, hidden_dim_symbol],
                      # Variables
                      'Model/Embedding_1/EmbeddingWeights': [vocab_size_symbol, hidden_dim_symbol],
                      'Model/FullSoftmaxLoss_1/W_Softmax': [vocab_size_symbol, hidden_dim_symbol],
                      'Model/FullSoftmaxLoss_1/b_Softmax': [1, vocab_size_symbol],
                      'Model/Recurrent_1_rhn/rnn/h_0/Bias': [hidden_dim_symbol],
                      'Model/Recurrent_1_rhn/rnn/h_0/Matrix': [2 * hidden_dim_symbol, hidden_dim_symbol],
                      'Model/Recurrent_1_rhn/rnn/h_[1-9]/Bias': [hidden_dim_symbol],
                      'Model/Recurrent_1_rhn/rnn/h_[1-9]/Matrix': [hidden_dim_symbol, hidden_dim_symbol],
                      'Model/Recurrent_1_rhn/rnn/t_0/Bias': [hidden_dim_symbol],
                      'Model/Recurrent_1_rhn/rnn/t_0/Matrix': [2 * hidden_dim_symbol, hidden_dim_symbol],
                      'Model/Recurrent_1_rhn/rnn/t_[1-9]/Bias': [hidden_dim_symbol],
                      'Model/Recurrent_1_rhn/rnn/t_[1-9]/Matrix': [hidden_dim_symbol, hidden_dim_symbol],
                    }
    elif domain == 'nmt':
        # HAX: Manually hack the iterator op
        it_op = graph.opsByName['IteratorGetNext']
        it_op._outputs[0].mergeShape(
            [subbatch_size_symbol, sequence_length_symbol], make_symbolic=True)
        it_op._outputs[1].mergeShape(
            [subbatch_size_symbol, sequence_length_symbol], make_symbolic=True)
        it_op._outputs[2].mergeShape(
            [subbatch_size_symbol, sequence_length_symbol], make_symbolic=True)
        it_op._outputs[3].mergeShape([subbatch_size_symbol],
                                     make_symbolic=True)
        it_op._outputs[4].mergeShape([subbatch_size_symbol],
                                     make_symbolic=True)

        bind_dict = { # Constants
                      'gradients/dynamic_seq2seq/decoder/decoder/while/BasicDecoderStep/decoder/attention/attention_layer/MatMul/Enter_grad/b_acc': [3 * hidden_dim_symbol, hidden_dim_symbol],
                      'gradients/dynamic_seq2seq/decoder/decoder/while/BasicDecoderStep/decoder/attention/basic_lstm_cell/BiasAdd/Enter_grad/b_acc': [4 * hidden_dim_symbol],
                      'gradients/dynamic_seq2seq/decoder/decoder/while/BasicDecoderStep/decoder/attention/basic_lstm_cell/MatMul/Enter_grad/b_acc': [3 * hidden_dim_symbol, 4 * hidden_dim_symbol],
                      'gradients/dynamic_seq2seq/encoder/bidirectional_rnn/[bf]w/[bf]w/while/basic_lstm_cell/BiasAdd/Enter_grad/b_acc': [4 * hidden_dim_symbol],
                      'gradients/dynamic_seq2seq/encoder/bidirectional_rnn/[bf]w/[bf]w/while/basic_lstm_cell/MatMul/Enter_grad/b_acc': [2 * hidden_dim_symbol, 4 * hidden_dim_symbol],
                      # Placeholders
                      # Variables
                      'dynamic_seq2seq/decoder/attention/attention_layer/kernel': [3 * hidden_dim_symbol, hidden_dim_symbol],
                      'dynamic_seq2seq/decoder/attention/basic_lstm_cell/bias': [4 * hidden_dim_symbol],
                      'dynamic_seq2seq/decoder/attention/basic_lstm_cell/kernel': [3 * hidden_dim_symbol, 4 * hidden_dim_symbol],
                      'dynamic_seq2seq/decoder/memory_layer/kernel': [2 * hidden_dim_symbol, hidden_dim_symbol],
                      'dynamic_seq2seq/decoder/output_projection/kernel': [hidden_dim_symbol, vocab_size_symbol],
                      'dynamic_seq2seq/encoder/bidirectional_rnn/[bf]w/basic_lstm_cell/bias': [4 * hidden_dim_symbol],
                      'dynamic_seq2seq/encoder/bidirectional_rnn/[bf]w/basic_lstm_cell/kernel': [2 * hidden_dim_symbol, 4 * hidden_dim_symbol],
                      'embeddings/decoder/embedding_decoder': [vocab_size_symbol, hidden_dim_symbol],
                      'embeddings/encoder/embedding_encoder': [vocab_size_symbol, hidden_dim_symbol],
                    }
    else:
        raise NotImplementedError('ERROR: Unknown domain: {}'.format(domain))

    print('Binding variables')

    graph.bindShapesAndPropagate(bind_dict,
                                 warn_if_ill_defined=(not is_pytest_run),
                                 make_symbolic=True)
    assert graph.isValid()
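    # The bindShapesAndPropagate call above merges the dictionary shapes into
    # the named ops and propagates them through the graph; make_symbolic=True
    # appears to swap known numeric dimensions for the bound symbols, and
    # ill-defined-shape warnings are disabled under pytest, presumably to
    # keep test output quiet.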

    num_sampled_vocab_symbol = subbatch_size_symbol * sequence_length_symbol
    if domain == 'wordlm':
        base_sequence_length = 80
        base_subbatch_size = 64
        base_num_sampled_vocab = base_subbatch_size * base_sequence_length
        bind_str_subs = {
            'Model/Collapse_1/boolean_mask/Reshape_1:0::num_true':
            sequence_length_symbol * subbatch_size_symbol,
            'Model/Collapse/boolean_mask/Reshape_1:0::num_true':
            sequence_length_symbol * subbatch_size_symbol,
            'Model/Labels_1/boolean_mask/Reshape_1:0::num_true':
            sequence_length_symbol * subbatch_size_symbol,
            'Model/Labels/boolean_mask/Reshape_1:0::num_true':
            sequence_length_symbol * subbatch_size_symbol,
            'Model/Gradient/Compute/gradients/b_count_2_block::iters':
            sequence_length_symbol,
            'Model/Gradient/Compute/gradients/b_count_6_block::iters':
            sequence_length_symbol,
            'Model/Recurrent_1_lstm_1/rnn/while/LoopCond_block::iters':
            sequence_length_symbol,
            'Model/Recurrent_1_lstm_3/rnn/while/LoopCond_block::iters':
            sequence_length_symbol,
            'Model/Recurrent_2_lstm_1/rnn/while/LoopCond_block::iters':
            sequence_length_symbol,
            'Model/Recurrent_2_lstm_3/rnn/while/LoopCond_block::iters':
            sequence_length_symbol,
        }
    elif domain == 'charlm':
        base_sequence_length = 150
        base_subbatch_size = 128
        bind_str_subs = {
            'Model/Collapse/boolean_mask/Reshape_1:0::num_true':
            sequence_length_symbol * subbatch_size_symbol,
            'Model/Labels/boolean_mask/Reshape_1:0::num_true':
            sequence_length_symbol * subbatch_size_symbol,
            'Model/Recurrent_1_rhn_1/rnn/while/LoopCond_block::iters':
            sequence_length_symbol,
            'Model/Gradient/Compute/gradients/b_count_2_block::iters':
            sequence_length_symbol,
        }
    elif domain == 'nmt':
        bind_str_subs = {
            'dynamic_seq2seq/decoder/decoder/while/LoopCond_block::iters':
            sequence_length_symbol,
            'dynamic_seq2seq/encoder/bidirectional_rnn/bw/bw/while/LoopCond_block::iters':
            sequence_length_symbol,
            'dynamic_seq2seq/encoder/bidirectional_rnn/fw/fw/while/LoopCond_block::iters':
            sequence_length_symbol,
            'gradients/b_count_10_block::iters':
            sequence_length_symbol,
            'gradients/b_count_2_block::iters':
            sequence_length_symbol,
            'gradients/b_count_6_block::iters':
            sequence_length_symbol,
        }
    else:
        raise NotImplementedError('ERROR: Unknown domain: {}'.format(domain))

    if not is_pytest_run:
        print('\n\nCleaned Graph:\n{}'.format(graph))

    print('\n\nBound values')

    bind_subs = {
        graph_iters_symbol: 1,
        hidden_dim_symbol: base_hidden_dim,
        sequence_length_symbol: base_sequence_length,
        subbatch_size_symbol: base_subbatch_size,
        vocab_size_symbol: base_vocab_size,
    }
    var_refs_table = {}
    for var_name, sub_val in bind_str_subs.items():
        var_ref = utils.getIntSymbolFromString(var_name)
        assert var_name not in bind_subs.keys()
        bind_subs[var_ref] = sub_val
        var_refs_table[var_name] = var_ref
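    # bind_str_subs is keyed by symbol-name strings; convert each key to a
    # Sympy symbol so the loop-iteration and num_true counts can be
    # substituted along with the dimension symbols.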

    # Verify parameter counts first
    parameters = graph.calcModelParameters()
    if domain == 'wordlm':
        correct_symbolic_params = 16 * hidden_dim_symbol**2 + \
                                  2 * hidden_dim_symbol * vocab_size_symbol + \
                                  8 * hidden_dim_symbol + \
                                  vocab_size_symbol + 2
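        # Rough accounting (not from the original source): the two LSTM
        # kernels of shape [2h, 4h] give 16h^2, the embedding and softmax
        # weight matrices give 2*h*v, the two 4h bias vectors give 8h, and
        # the softmax bias gives v; the remaining +2 is left unexplained.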
        correct_params = 104378326
        correct_flops = 2597058084257
        correct_bytes = 143652774404
        correct_total_footprint = 49660373192
    elif domain == 'charlm':
        correct_symbolic_params = 22 * hidden_dim_symbol**2 + \
                                  2 * hidden_dim_symbol * vocab_size_symbol + \
                                  20 * hidden_dim_symbol + \
                                  vocab_size_symbol + 2
        correct_params = 88785316
        correct_flops = 10228050930711
        correct_bytes = 445302356084
        correct_total_footprint = 156135676796
    elif domain == 'nmt':
        correct_symbolic_params = 33 * hidden_dim_symbol**2 + \
                                  3 * hidden_dim_symbol * vocab_size_symbol + \
                                  12 * hidden_dim_symbol + 1
        correct_params = 146890753
        correct_flops = 1053984410589
        correct_bytes = 36901043741
        correct_total_footprint = 14551615608
    else:
        raise NotImplementedError('ERROR: Unknown domain: {}'.format(domain))
    assert sympy.simplify(parameters - correct_symbolic_params) == 0, \
           'Param count incorrect!\n  Expecting: {}\n  Calculated: {}' \
           .format(correct_symbolic_params, parameters)

    print('Symbol associations: {}\n'.format(bind_subs))

    # Calculate model parameter count
    resolved_params = parameters.subs(bind_subs)
    try:
        resolved_params = int(resolved_params)
    except:
        print('ERROR: resolved_params should be int, but is {} = {}'.format(
            type(resolved_params), resolved_params))
    assert resolved_params == correct_params, \
           'Incorrect model params: {}'.format(resolved_params)
    print('Parameters: {}\nWith specified dims: {}\n'.format(
        parameters, resolved_params))

    # Calculate algorithmic Flops
    alg_flops = graph.calcAlgFlops()
    resolved_flops = alg_flops.subs(bind_subs)
    try:
        resolved_flops = int(resolved_flops)
    except:
        print('ERROR: resolved_flops should be int, but is {} = {}'.format(
            type(resolved_flops), resolved_flops))
    assert resolved_flops == correct_flops, \
           'Incorrect algorithmic flops: {}'.format(resolved_flops)
    print('Algorithmic Flops: {}\nWith specified dims: {}\n'.format(
        alg_flops, resolved_flops))

    # Calculate algorithmic Bytes accessed
    alg_bytes = graph.calcAlgBytes()
    resolved_bytes = alg_bytes.subs(bind_subs)
    try:
        resolved_bytes = int(resolved_bytes)
    except:
        print('ERROR: resolved_bytes should be int, but is {} = {}'.format(
            type(resolved_bytes), resolved_bytes))
    assert resolved_bytes == correct_bytes, \
           'Incorrect algorithmic bytes: {}'.format(resolved_bytes)
    print('Alg bytes accessed: {}\nWith specified dims: {}\n'.format(
        alg_bytes, resolved_bytes))

    # Calculate total memory footprint
    alg_footprint = graph.calcAlgFootprint()
    resolved_footprint = alg_footprint.subs(bind_subs)
    try:
        resolved_footprint = int(resolved_footprint)
    except:
        print('ERROR: resolved_footprint should be int, but is {} = {}'.format(
            type(resolved_footprint), resolved_footprint))
    assert resolved_footprint == correct_total_footprint, \
           'Incorrect algorithmic footprint: {}'.format(resolved_footprint)
    print('Alg mem footprint: {}\nWith specified dims: {}\n'.format(
        alg_footprint, resolved_footprint))

    # Calculate minimal memory footprint
    alg_min_footprint = graph.calcMinimalFootprint(symbol_subs=bind_subs)
    print('Alg minimal footprint (With specified dims): {}\n'.format(
        alg_min_footprint))

    # Calculate algorithmic IO per step
    total_io_footprint = 0
    for op in graph.getPlaceholders():
        total_io_footprint += op.calcAlgFootprint()
    if isinstance(total_io_footprint, int):
        resolved_io_footprint = total_io_footprint
    else:
        resolved_io_footprint = total_io_footprint.subs(bind_subs)
    print('Alg IO footprint: {}\nWith specified dims: {}\n'.format(
        total_io_footprint, resolved_io_footprint))

    if not is_pytest_run:
        print('VERBOSE ALGORITHMIC FLOPS:')
        graph.calcAlgFlops(verbose=True)
        print('')

        print('VERBOSE ALGORITHMIC BYTES:')
        graph.calcAlgBytes(verbose=True)
        print('')

        print('VERBOSE ALGORITHMIC FOOTPRINT:')
        graph.calcAlgFootprint(verbose=True)
        print('')

    # HACKY WAY TO SAVE MODELS FOR NOW!
    pickle.dump(
        graph,
        open(
            'catamount/frameworks/example_graphs/tensorflow/full_models/language_models/graph_{}.p'
            .format(domain), 'wb'))

    if is_pytest_run:
        return

    print('\n\n======= Algorithmic graph-level analytics: =======')

    if domain == 'wordlm':
        hidden_dims = [
            1, 2, 3, 4, 5, 6, 7, 9, 10, 12, 14, 18, 20, 25, 28, 35, 40, 50, 56,
            69, 78, 86, 96, 108, 119, 123, 133, 148, 163, 182, 202, 221, 246,
            273, 297, 329, 330, 364, 396, 436, 437, 520, 572, 617, 676, 740,
            796, 869, 948, 1017, 1106, 1202, 1286, 1394, 1510, 1611, 1742,
            1882, 2004, 2161, 2476, 3040, 3714, 4520, 5478, 6628, 8019, 9702,
            11739, 14204, 17186, 20795, 25161, 30444, 36837, 38100
        ]
        bind_subs[subbatch_size_symbol] = 128
    elif domain == 'charlm':
        hidden_dims = [
            1, 2, 3, 4, 5, 6, 7, 9, 10, 12, 14, 18, 20, 25, 28, 35, 40, 50, 56,
            69, 78, 86, 96, 108, 119, 123, 133, 148, 163, 182, 202, 221, 246,
            273, 297, 329, 330, 364, 396, 436, 437, 520, 572, 617, 676, 740,
            796, 869, 948, 1017, 1106, 1202, 1286, 1394, 1510, 1611, 1742,
            1882, 2004, 2161, 2476, 3040, 3714, 5051, 6869, 9341, 12703, 17276,
            23495, 31953, 43456, 59100, 80376, 81400
        ]
        bind_subs[subbatch_size_symbol] = 96
    elif domain == 'nmt':
        hidden_dims = [
            32, 64, 96, 128, 192, 256, 384, 512, 768, 1024, 1280, 1536, 2048,
            2560, 3072, 3747, 4571, 5576, 6802, 8298, 10123, 12350, 15067,
            18381, 22350
        ]
        bind_subs[subbatch_size_symbol] = 96
        bind_subs[sequence_length_symbol] = 26
    else:
        raise NotImplementedError('ERROR: Unknown domain: {}'.format(domain))

    bind_subs.pop(hidden_dim_symbol)
    resolved_params = parameters.subs(bind_subs)

    print('Symbol associations: {}\n'.format(bind_subs))

    print(
        'Algorithmic Flops by hidden dimension, params, and per-batch-sample:')
    resolved_flops = alg_flops.subs(bind_subs)
    for hid_dim in hidden_dims:
        graph_params = resolved_params.subs({hidden_dim_symbol: hid_dim})
        graph_flops = resolved_flops.subs({hidden_dim_symbol: hid_dim})
        graph_flops_per_sample = float(graph_flops) / \
                                 bind_subs[subbatch_size_symbol]
        print('{}\t{}\t{}\t{}'.format(hid_dim, graph_params, graph_flops,
                                      int(graph_flops_per_sample)))

    print('\nAlgorithmic bytes accessed by hidden dimension, params:')
    resolved_bytes = alg_bytes.subs(bind_subs)
    for hid_dim in hidden_dims:
        graph_params = resolved_params.subs({hidden_dim_symbol: hid_dim})
        graph_bytes = resolved_bytes.subs({hidden_dim_symbol: hid_dim})
        print('{}\t{}\t{}'.format(hid_dim, graph_params, graph_bytes))

    print('\nAlgorithmic total memory footprint by hidden dimension, params:')
    resolved_footprint = alg_footprint.subs(bind_subs)
    for hid_dim in hidden_dims:
        graph_params = resolved_params.subs({hidden_dim_symbol: hid_dim})
        graph_footprint = resolved_footprint.subs({hidden_dim_symbol: hid_dim})
        print('{}\t{}\t{}'.format(hid_dim, graph_params, graph_footprint))

    print(
        '\nAlgorithmic minimal memory footprint by hidden dimension, params:')
    full_subs = dict(bind_subs)
    for hid_dim in hidden_dims:
        graph_params = resolved_params.subs({hidden_dim_symbol: hid_dim})
        full_subs[hidden_dim_symbol] = hid_dim
        graph_min_foot = graph.calcMinimalFootprint(symbol_subs=full_subs)
        print('{}\t{}\t{}'.format(hid_dim, graph_params, graph_min_foot))


    if False:
        if domain == 'wordlm':
            if args.build_projection:
                # This is hacky anyway... Import required parts here:
                from catamount.ops.optimizer_ops import *
                from catamount.tensors.tensor import *

                projection_dim_symbol = utils.getIntSymbolFromString(
                    'projection_dim')
                # (1) Project output of the second recurrent layer. Save the
                # consumers of the output to send the projected values there
                proj_in_op = graph.opsByName['Model/Collapse/Reshape']
                proj_input = proj_in_op._outputs[0]
                proj_input_consumers = proj_input._consumers
                proj_input._consumers = {}
                # (1a) Create projection matrix
                proj_weights = catamount.variable(
                    'Model/Collapse/projection/W',
                    [hidden_dim_symbol, projection_dim_symbol], graph)
                # (1b) Create matrix multiply for projection
                proj_mm_out = catamount.matmul(
                    'Model/Collapse/projection/MatMul',
                    [None, projection_dim_symbol],
                    proj_input, proj_weights, graph)
                # (2) Feed projection to output consumers
Ejemplo n.º 17
0
def run_tf_image_resnet(depth, filter_scale=1.0):
    global is_pytest_run
    model_string = '_d{}_fs{}_'.format(depth, filter_scale)
    test_outputs_dir = 'catamount/frameworks/example_graphs/tensorflow/full_models/image_classification'
    graph_meta = None
    for root, dirs, files in os.walk(test_outputs_dir):
        for filename in files:
            if 'graph{}'.format(
                    model_string) in filename and '.meta' in filename:
                # Take the first graph that we find in the directory
                graph_meta = os.path.join(root, filename)
                break
        if graph_meta is not None:
            break

    if graph_meta is None:
        raise FileNotFoundError(
            'Unable to find model string {} in directory {}'.format(
                model_string, test_outputs_dir))

    graph = catamount.frameworks.tensorflow.import_graph(graph_meta)
    assert graph.isValid()

    # ============ TO REMOVE INITIALIZATION OPS! =============
    # NOTE: This code is pretty general and is likely to be migrated into
    # Catamount code for removing TF-specific initialization ops
    from catamount.ops import AssignOp
    from catamount.ops import VariableOp
    assign_ops = set()
    for op in graph.opsByName.values():
        if isinstance(op, AssignOp):
            assign_ops.add(op)
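    # Walk backward from each AssignOp through its inputs, stopping at
    # VariableOps; every op visited feeds only variable initialization and is
    # removed below.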
    for assign_op in assign_ops:
        my_ancestors = set()
        my_frontier = set()
        my_frontier.add(assign_op)
        while len(my_frontier) > 0:
            next_op = my_frontier.pop()
            for in_tensor in next_op.inputs:
                if not isinstance(in_tensor.producer, VariableOp):
                    my_frontier.add(in_tensor.producer)
            my_ancestors.add(next_op)
        for next_op in my_ancestors:
            graph.removeOp(next_op)
    assert graph.isValid()

    # Manually remove the inference parts of graph
    graph_ops = list(graph._ops_by_name.values())
    for op in graph_ops:
        # Certain ops are only used for inference
        if 'InferenceTower/' in op.name or \
           'InferenceRunner/' in op.name or \
           op.name == 'MergeAllSummariesRunWithOp/Merge/MergeSummary':
            graph.removeOp(op)
    assert graph.isValid()

    print('Initial graph:\n{}\n'.format(graph))
    init_params = graph.calcModelParameters()
    print('Initial parameters: {}'.format(init_params))
    print('Initial Flops: {}\n'.format(graph.calcAlgFlops()))

    print('Placeholders:')
    for op in graph.getPlaceholders():
        print(op.debugString())
    print('')

    # Set up symbols to name dimensions
    output_classes_symbol = utils.getPositiveIntSymbolFromString('out_classes')
    subbatch_size_symbol = utils.getPositiveIntSymbolFromString(
        'subbatch_size')
    image_height_symbol = utils.getPositiveIntSymbolFromString('image_height')
    image_width_symbol = utils.getPositiveIntSymbolFromString('image_width')
    num_in_channels_symbol = utils.getPositiveIntSymbolFromString(
        'num_in_channels')
    graph_iters_symbol = utils.getIntSymbolFromString('graph::iters')
    feature_channels_symbol = utils.getPositiveIntSymbolFromString(
        'feature_channels')

    # Find and replace convolution/pooling dimensions also:
    # Dimension(64 * 2^k): conv/pool feature channels
    base_output_classes = 1000
    base_num_in_channels = 3
    base_feature_channels = 64

    base_image_height = 224
    base_image_width = 224
    base_half_im_height = 112
    half_im_height_symbol = image_height_symbol // 2
    base_half_im_width = 112
    half_im_width_symbol = image_width_symbol // 2
    base_quart_im_height = 56
    quart_im_height_symbol = (image_height_symbol // 2) // 2
    base_quart_im_width = 56
    quart_im_width_symbol = (image_width_symbol // 2) // 2
    base_eighth_im_height = 28
    eighth_im_height_symbol = ((image_height_symbol // 2) // 2) // 2
    base_eighth_im_width = 28
    eighth_im_width_symbol = ((image_width_symbol // 2) // 2) // 2
    base_sixtnth_im_height = 14
    sixtnth_im_height_symbol = (((image_height_symbol // 2) // 2) // 2) // 2
    base_sixtnth_im_width = 14
    sixtnth_im_width_symbol = (((image_width_symbol // 2) // 2) // 2) // 2
    base_small_im_height = 7
    small_im_height_symbol = (((
        (image_height_symbol // 2) // 2) // 2) // 2) // 2
    base_small_im_width = 7
    small_im_width_symbol = ((((image_width_symbol // 2) // 2) // 2) // 2) // 2
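    # Each successive // 2 corresponds to one stride-2 stage of the ResNet:
    # with a 224x224 input the feature maps shrink to 112, 56, 28, 14, and
    # finally 7, matching the base_* constants above.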

    # Set up a dictionary of placeholders and variables for which we want
    # to make dimensions symbolic. Sift out their dimensions
    bind_dict = { # Placeholders
                      'label': [subbatch_size_symbol],
                      'input': [subbatch_size_symbol, image_height_symbol, image_width_symbol, num_in_channels_symbol],
                    }
    # Parameterize all variable tensor dimensions
    for op in graph._ops_by_name.values():
        if isinstance(op, VariableOp):
            op_name_suffix = op.name.split('/')[-1]
            if op_name_suffix == 'W':
                if op._outputs[0].shape.rank == 4:
                    assert 'conv' in op.name
                    new_shape = []
                    for i in range(op._outputs[0].shape.rank):
                        new_shape.append(
                            op._outputs[0].shape.getDimension(i).value)
                    if new_shape[2] % base_feature_channels == 0:
                        in_filters = (new_shape[2] // \
                                      base_feature_channels) * \
                                     feature_channels_symbol
                    elif new_shape[2] == 3:
                        # This is the first convolution on image channels (3)
                        assert op.name == 'conv0/W'
                        in_filters = num_in_channels_symbol
                    else:
                        print('FIX ME: base in filters {}'.format(
                            new_shape[2]))
                        assert 0
                    if new_shape[3] % base_feature_channels == 0:
                        out_filters = (new_shape[3] // \
                                       base_feature_channels) * \
                                       feature_channels_symbol
                    else:
                        print('FIX ME: base out filters {}'.format(
                            new_shape[3]))
                        assert 0
                    new_shape[2] = in_filters
                    new_shape[3] = out_filters
                else:
                    # This is the output layer with output_classes dimension
                    assert op.name == 'linear/W'
                    assert op._outputs[0].shape.rank == 2
                    in_dim = op._outputs[0].shape.getDimension(0).value
                    assert in_dim % base_feature_channels == 0
                    in_dim = (in_dim // base_feature_channels) * \
                             feature_channels_symbol
                    new_shape = [in_dim, output_classes_symbol]
                bind_dict[op.name] = new_shape
                momentum_op_name = '{}/Momentum'.format(op.name)
                momentum_op = graph._ops_by_name[momentum_op_name]
                bind_dict[momentum_op.name] = new_shape
            elif op_name_suffix == 'b':
                # This is the output layer with output_classes dimension
                assert op.name == 'linear/b'
                assert op._outputs[0].shape.rank == 1
                assert op._outputs[0].shape.getDimension(0).value == \
                       base_output_classes
                new_shape = [output_classes_symbol]
                bind_dict[op.name] = new_shape
                momentum_op_name = '{}/Momentum'.format(op.name)
                momentum_op = graph._ops_by_name[momentum_op_name]
                bind_dict[momentum_op.name] = new_shape
            elif op_name_suffix == 'beta' or op_name_suffix == 'gamma' or \
                 op_name_suffix == 'EMA':
                assert op._outputs[0].shape.rank == 1
                in_dim = op._outputs[0].shape.getDimension(0).value
                assert in_dim % base_feature_channels == 0
                in_dim = (in_dim // base_feature_channels) * \
                         feature_channels_symbol
                new_shape = [in_dim]
                bind_dict[op.name] = new_shape
                if op_name_suffix != 'EMA':
                    momentum_op_name = '{}/Momentum'.format(op.name)
                    momentum_op = graph._ops_by_name[momentum_op_name]
                    bind_dict[momentum_op.name] = new_shape

    # Now handle constant values in the graph
    const_dict = {}
    for op in graph._ops_by_name.values():
        if isinstance(op, ConstantOp):
            if op._outputs[0].value is None:
                continue
            if op._outputs[0].shape.rank == 0:
                print('{}'.format(op.debugString()))
                continue
            assert op._outputs[0].shape.rank == 1
            values = op._outputs[0].value.tolist()
            new_values = []
            changed = False
            for value in values:
                if value > 0 and value % base_feature_channels == 0:
                    value = (value //
                             base_feature_channels) * feature_channels_symbol
                    changed = True
                new_values.append(value)
            # HACKY SPECIAL CASE:
            if op.name == 'tower0/gradients/tower0/conv0/Conv2D_grad/Const':
                assert new_values[2] == base_num_in_channels
                new_values[2] = num_in_channels_symbol
            if changed:
                const_dict[op.name] = new_values
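    # Write the rewritten (now symbolic) values back onto their ConstantOps;
    # constants that disappeared during graph cleanup only trigger a warning.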
    for const_key, const_val in const_dict.items():
        try:
            if const_key in graph.opsByName.keys():
                const_op = graph.opsByName[const_key]
                assert isinstance(const_op, ConstantOp)
                const_op._outputs[0].setValue(const_val)
            else:
                print('WARN: ConstantOp not found: {}'.format(const_key))
        except Exception as exc:
            print('WARN: ConstantOp unknown problem: {}: {}'.format(
                const_key, exc))

    # TODO (Joel): Add InputQueue ops to avoid manually setting dimensions
    in_deque_op = graph.opsByName['QueueInput/input_deque']
    # print(in_deque_op.debugString())
    out_tensor = in_deque_op._outputs[0]
    for idx, sym in enumerate([
            subbatch_size_symbol, image_height_symbol, image_width_symbol,
            num_in_channels_symbol
    ]):
        out_tensor.shape.setDimension(idx, sym)
        out_tensor.shape.dims[idx]._value = None
    out_tensor = in_deque_op._outputs[1]
    out_tensor.shape.setDimension(0, subbatch_size_symbol)

    graph.bindTensorShapeDimensions(bind_dict,
                                    warn_if_ill_defined=(not is_pytest_run),
                                    make_symbolic=True)

    # Nice little hack to actually propagate MaximumOp values to outputs
    for op in graph._ops_by_name.values():
        if isinstance(op, MaximumOp):
            if op._inputs[0].value is not None and \
               op._inputs[1].value is not None:
                vmax = np.vectorize(lambda x, y: sympy.Max(x, y))
                out_val = vmax(op._inputs[0].value, op._inputs[1].value)
                op._outputs[0].setValue(out_val)
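    # Bind a second time (below), presumably so shapes that depend on the
    # newly propagated MaximumOp values can resolve.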

    graph.bindTensorShapeDimensions(bind_dict,
                                    warn_if_ill_defined=(not is_pytest_run),
                                    make_symbolic=True)

    print('Bound values')

    print(graph)

    bind_subs = {
        graph_iters_symbol: 1,
        output_classes_symbol: base_output_classes,
        subbatch_size_symbol: 32,
        image_height_symbol: base_image_height,
        image_width_symbol: base_image_width,
        num_in_channels_symbol: base_num_in_channels,
        feature_channels_symbol: base_feature_channels,
    }

    correct_params = -1
    correct_flops = -1
    correct_bytes = -1
    correct_footprint = -1
    if depth == 18:
        correct_params = 11689514
        correct_flops = 349684163360
        correct_bytes = 7186222676
        correct_footprint = 2802084304
    elif depth == 34:
        correct_params = 21797674
        correct_flops = 705506994208
        correct_bytes = 11162578644
        correct_footprint = 4368689744
    elif depth == 50:
        correct_params = 25557034
        correct_flops = 790954958112
        correct_bytes = 32896462028
        correct_footprint = 12909734408
    elif depth == 101:
        correct_params = 44549162
        correct_flops = 1506507229472
        correct_bytes = 50026672916
        correct_footprint = 19690293072
    elif depth == 152:
        correct_params = 60192810
        correct_flops = 2222688328992
        correct_bytes = 70967716188
        correct_footprint = 27971880088
    else:
        print('WARN: Tests not defined for depth {}'.format(depth))

    # Calculate parameters
    # NOTE: Need to remove Momentum optimizer parameters and moving average values
    momentum_params = 0
    parameters = 0
    for op_name in sorted(graph.opsByName.keys()):
        op = graph.opsByName[op_name]
        if isinstance(op, VariableOp):
            if "Momentum" in op.name or "EMA" in op.name:
                momentum_params += op.calcModelParameters()
            else:
                parameters += op.calcModelParameters()

    all_weights = graph.calcModelParameters()
    assert (all_weights - momentum_params - parameters) == 0

    # Calculate model parameter count
    resolved_params = parameters.subs(bind_subs)
    try:
        resolved_params = int(resolved_params)
    except:
        print('ERROR: resolved_params should be int, but is {} = {}'.format(
            type(resolved_params), resolved_params))
    assert correct_params < 0 or resolved_params == correct_params, \
           'Incorrect model params: {}'.format(resolved_params)
    print('Parameters: {}\nWith specified dims: {}\n'.format(
        parameters, resolved_params))

    # Calculate algorithmic Flops
    alg_flops = graph.calcAlgFlops()
    resolved_flops = alg_flops.subs(bind_subs)
    try:
        resolved_flops = int(resolved_flops)
    except:
        print('ERROR: resolved_flops should be int, but is {} = {}'.format(
            type(resolved_flops), resolved_flops))
    assert correct_flops < 0 or resolved_flops == correct_flops, \
           'Incorrect algorithmic flops: {}'.format(resolved_flops)
    print('Algorithmic Flops: {}\nWith specified dims: {}\n'.format(
        alg_flops, resolved_flops))

    # Calculate algorithmic Bytes accessed
    alg_bytes = graph.calcAlgBytes()
    resolved_bytes = alg_bytes.subs(bind_subs)
    try:
        resolved_bytes = int(resolved_bytes)
    except:
        print('ERROR: resolved_bytes should be int, but is {} = {}'.format(
            type(resolved_bytes), resolved_bytes))
    assert correct_bytes < 0 or resolved_bytes == correct_bytes, \
           'Incorrect algorithmic bytes: {}'.format(resolved_bytes)
    print('Alg bytes accessed: {}\nWith specified dims: {}\n'.format(
        alg_bytes, resolved_bytes))

    # Calculate total memory footprint
    alg_footprint = graph.calcAlgFootprint()
    resolved_footprint = alg_footprint.subs(bind_subs)
    try:
        resolved_footprint = int(resolved_footprint)
    except:
        print('ERROR: resolved_footprint should be int, but is {} = {}'.format(
            type(resolved_footprint), resolved_footprint))
    assert correct_footprint < 0 or resolved_footprint == correct_footprint, \
           'Incorrect algorithmic footprint: {}'.format(resolved_footprint)
    print('Alg mem footprint: {}\nWith specified dims: {}\n'.format(
        alg_footprint, resolved_footprint))

    # Calculate algorithmic IO per step
    total_io_footprint = 0
    for op in graph.getPlaceholders():
        total_io_footprint += op.calcAlgFootprint()
    resolved_io_footprint = total_io_footprint.subs(bind_subs)
    print('Alg IO footprint: {}\nWith specified dims: {}\n'.format(
        total_io_footprint, resolved_io_footprint))

    try:  # In case the footprint code is not complete
        # Calculate minimal memory footprint
        print('Alg min mem footprint {}'.format(
            graph.calcMinimalFootprint(symbol_subs=bind_subs)))
    except:
        pass

    if not is_pytest_run:
        print('VERBOSE ALGORITHMIC FLOPS:')
        graph.calcAlgFlops(verbose=True)
        print('')

        print('VERBOSE ALGORITHMIC BYTES:')
        graph.calcAlgBytes(verbose=True)
        print('')

        print('VERBOSE ALGORITHMIC FOOTPRINT:')
        graph.calcAlgFootprint(verbose=True)
        print('')

    # HACKY WAY TO SAVE MODELS FOR NOW!
    pickle.dump(
        graph,
        open(
            'catamount/frameworks/example_graphs/tensorflow/full_models/image_classification/graph_image_resnet_d{}_fs{}.p'
            .format(depth, filter_scale), 'wb'))

    if is_pytest_run:
        return

    print('\n\n======= Algorithmic graph-level analytics: =======')

    feature_channel_dims = [32, 48, 64, 96, 128]

    bind_subs.pop(feature_channels_symbol)
    resolved_params = parameters.subs(bind_subs)

    print('Symbol associations: {}\n'.format(bind_subs))

    print(
        'Algorithmic Flops by feature channels, params, and per-batch-sample:')
    resolved_flops = alg_flops.subs(bind_subs)
    for features_dim in feature_channel_dims:
        graph_params = resolved_params.subs(
            {feature_channels_symbol: features_dim})
        graph_flops = resolved_flops.subs(
            {feature_channels_symbol: features_dim})
        graph_flops_per_sample = float(graph_flops) / \
                                 bind_subs[subbatch_size_symbol]
        print('{}\t{}\t{}\t{}'.format(features_dim, graph_params, graph_flops,
                                      int(graph_flops_per_sample)))

    print('\nAlgorithmic bytes accessed by feature channels, params:')
    resolved_bytes = alg_bytes.subs(bind_subs)
    for features_dim in feature_channel_dims:
        graph_params = resolved_params.subs(
            {feature_channels_symbol: features_dim})
        graph_bytes = resolved_bytes.subs(
            {feature_channels_symbol: features_dim})
        print('{}\t{}\t{}'.format(features_dim, graph_params, graph_bytes))

    print('\nAlgorithmic total memory footprint by feature channels, params:')
    resolved_footprint = alg_footprint.subs(bind_subs)
    for features_dim in feature_channel_dims:
        graph_params = resolved_params.subs(
            {feature_channels_symbol: features_dim})
        graph_footprint = resolved_footprint.subs(
            {feature_channels_symbol: features_dim})
        print('{}\t{}\t{}'.format(features_dim, graph_params, graph_footprint))

    print(
        '\nAlgorithmic minimal memory footprint by feature channels, params:')
    full_subs = dict(bind_subs)
    for features_dim in feature_channel_dims:
        graph_params = resolved_params.subs(
            {feature_channels_symbol: features_dim})
        full_subs[feature_channels_symbol] = features_dim
        graph_min_foot = graph.calcMinimalFootprint(symbol_subs=full_subs)
        print('{}\t{}\t{}'.format(features_dim, graph_params, graph_min_foot))
Ejemplo n.º 18
0
def test_tf_dynamic_rnn():
    graph = catamount.frameworks.tensorflow.import_graph(tf_example_filename)

    print('INITIAL GRAPH: {}\n\n'.format(graph))

    assert graph.isValid()

    algorithmic_flops = graph.calcAlgFlops()

    # Symbols we will use
    batch_size = utils.getIntSymbolFromString('batch_size')
    seq_length = utils.getIntSymbolFromString('seq_length')
    hidden_dim = utils.getIntSymbolFromString('hidden_dim')

    graph_iters = utils.getIntSymbolFromString('graph::iters')
    rwb_iters = utils.getIntSymbolFromString('rnn/while/LoopCond_block::iters')
    a_0 = utils.getIntSymbolFromString(
        'rnn/while/basic_lstm_cell/Add:0::dim_0')
    a1_0 = utils.getIntSymbolFromString(
        'rnn/while/basic_lstm_cell/Add_1:0::dim_0')
    ba_0 = utils.getIntSymbolFromString(
        'rnn/while/basic_lstm_cell/BiasAdd:0::dim_0')
    mm_0 = utils.getIntSymbolFromString(
        'rnn/while/basic_lstm_cell/MatMul:0::dim_0')
    m_0 = utils.getIntSymbolFromString(
        'rnn/while/basic_lstm_cell/Mul:0::dim_0')
    m1_0 = utils.getIntSymbolFromString(
        'rnn/while/basic_lstm_cell/Mul_1:0::dim_0')
    m2_0 = utils.getIntSymbolFromString(
        'rnn/while/basic_lstm_cell/Mul_2:0::dim_0')
    s_0 = utils.getIntSymbolFromString(
        'rnn/while/basic_lstm_cell/Sigmoid:0::dim_0')
    s1_0 = utils.getIntSymbolFromString(
        'rnn/while/basic_lstm_cell/Sigmoid_1:0::dim_0')
    s2_0 = utils.getIntSymbolFromString(
        'rnn/while/basic_lstm_cell/Sigmoid_2:0::dim_0')
    sm_r_0 = utils.getIntSymbolFromString('softmax/Reshape:0::dim_0')
    sm_r_1 = utils.getIntSymbolFromString('softmax/Reshape:0::dim_1')
    th_0 = utils.getIntSymbolFromString(
        'rnn/while/basic_lstm_cell/Tanh:0::dim_0')
    th1_0 = utils.getIntSymbolFromString(
        'rnn/while/basic_lstm_cell/Tanh_1:0::dim_0')
    forward_flops = rwb_iters * (hidden_dim * a_0 + hidden_dim * a1_0 + 4 * hidden_dim * ba_0 + 16 * hidden_dim**2 * mm_0 + \
                                 hidden_dim * m_0 + hidden_dim * m1_0 + hidden_dim * m2_0 + 4 * hidden_dim * s_0 + \
                                 4 * hidden_dim * s1_0 + 4 * hidden_dim * s2_0 + 6 * hidden_dim * th_0 + 6 * hidden_dim * th1_0 + 6) + \
                    3 * sm_r_0 * sm_r_1 + \
                    16 * hidden_dim**2 + 8 * hidden_dim + 3
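    # Coefficient check (a reading of the formula, not from the original):
    # the LSTM MatMul of [dim_0, 2h] by [2h, 4h] costs 2*2h*4h = 16h^2 per
    # row, the BiasAdd over the 4h gate units costs 4h, Sigmoid is counted as
    # 4 Flops and Tanh as 6 Flops per element, and each element-wise Add/Mul
    # over the h-wide state costs h per row.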
    grad_iters = utils.getIntSymbolFromString(
        'Gradient/Compute/gradients/b_count_2_block::iters')
    g_an_0 = utils.getIntSymbolFromString(
        'Gradient/Compute/gradients/AddN:0::dim_0')
    g_an1_0 = utils.getIntSymbolFromString(
        'Gradient/Compute/gradients/AddN_1:0::dim_0')
    g_mm_0 = utils.getIntSymbolFromString(
        'Gradient/Compute/gradients/rnn/while/basic_lstm_cell/MatMul_grad/MatMul:0::dim_0'
    )
    g_mms_0 = utils.getIntSymbolFromString(
        'Gradient/Compute/gradients/rnn/while/basic_lstm_cell/MatMul_grad/MatMul_1/StackPopV2:0::dim_0'
    )
    g_ms_0 = utils.getIntSymbolFromString(
        'Gradient/Compute/gradients/rnn/while/basic_lstm_cell/Mul_1_grad/mul/StackPopV2:0::dim_0'
    )
    g_m_0 = utils.getIntSymbolFromString(
        'Gradient/Compute/gradients/rnn/while/basic_lstm_cell/Mul_1_grad/mul:0::dim_0'
    )
    g_ms1_0 = utils.getIntSymbolFromString(
        'Gradient/Compute/gradients/rnn/while/basic_lstm_cell/Mul_1_grad/mul_1/StackPopV2:0::dim_0'
    )
    g_m1_0 = utils.getIntSymbolFromString(
        'Gradient/Compute/gradients/rnn/while/basic_lstm_cell/Mul_1_grad/mul_1:0::dim_0'
    )
    g_ms2_0 = utils.getIntSymbolFromString(
        'Gradient/Compute/gradients/rnn/while/basic_lstm_cell/Mul_2_grad/mul/StackPopV2:0::dim_0'
    )
    g_m2_0 = utils.getIntSymbolFromString(
        'Gradient/Compute/gradients/rnn/while/basic_lstm_cell/Mul_2_grad/mul:0::dim_0'
    )
    g_ms21_0 = utils.getIntSymbolFromString(
        'Gradient/Compute/gradients/rnn/while/basic_lstm_cell/Mul_2_grad/mul_1/StackPopV2:0::dim_0'
    )
    g_m21_0 = utils.getIntSymbolFromString(
        'Gradient/Compute/gradients/rnn/while/basic_lstm_cell/Mul_2_grad/mul_1:0::dim_0'
    )
    g_mus_0 = utils.getIntSymbolFromString(
        'Gradient/Compute/gradients/rnn/while/basic_lstm_cell/Mul_grad/mul/StackPopV2:0::dim_0'
    )
    g_mu_0 = utils.getIntSymbolFromString(
        'Gradient/Compute/gradients/rnn/while/basic_lstm_cell/Mul_grad/mul:0::dim_0'
    )
    g_mu1_0 = utils.getIntSymbolFromString(
        'Gradient/Compute/gradients/rnn/while/basic_lstm_cell/Mul_grad/mul_1:0::dim_0'
    )
    g_s_0 = utils.getIntSymbolFromString(
        'Gradient/Compute/gradients/rnn/while/basic_lstm_cell/Sigmoid_grad/SigmoidGrad:0::dim_0'
    )
    g_c_0 = utils.getIntSymbolFromString(
        'Gradient/Compute/gradients/rnn/while/basic_lstm_cell/split_grad/concat:0::dim_0'
    )
    g_sm_m_0 = utils.getIntSymbolFromString(
        'Gradient/Compute/gradients/softmax/softmax_grad/mul:0::dim_0')
    g_sm_m_1 = utils.getIntSymbolFromString(
        'Gradient/Compute/gradients/softmax/softmax_grad/mul:0::dim_1')
    backward_flops = grad_iters * (hidden_dim * g_an_0 + 3 * hidden_dim * g_an1_0 + 16 * hidden_dim**2 * g_mm_0 + 16 * hidden_dim**2 * g_mms_0 + \
                                   3 * hidden_dim * g_ms_0 + 2 * hidden_dim * g_m_0 + 3 * hidden_dim * g_ms1_0 + 2 * hidden_dim * g_m1_0 + \
                                   3 * hidden_dim * g_ms2_0 + 2 * hidden_dim * g_m2_0 + 3 * hidden_dim * g_ms21_0 + 2 * hidden_dim * g_m21_0 + \
                                   3 * hidden_dim * g_mus_0 + 2 * hidden_dim * g_mu_0 + 2 * hidden_dim * g_mu1_0 + 2 * hidden_dim * g_s_0 + \
                                   4 * hidden_dim * g_c_0 + 8 * hidden_dim**2 + 4 * hidden_dim + 2) + \
                     g_sm_m_0 * g_sm_m_1
    general_correct_alg_flops = forward_flops + backward_flops
    correct_alg_flops = general_correct_alg_flops.subs({hidden_dim: 24})
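    # Before any binding, the imported graph still carries the checkpoint's
    # concrete cell size, so hidden_dim is substituted with 24 for the
    # initial comparison.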

    print('Loaded Flops test:')
    print('    Catamount:   {}'.format(algorithmic_flops))
    print('    Correct: {}'.format(correct_alg_flops))
    assert sympy.simplify(algorithmic_flops - correct_alg_flops) == 0, \
        'Initial alg flops incorrect!\n  Expecting: {}\n  Calculated: {}' \
        .format(correct_alg_flops, algorithmic_flops)

    # Now, bind tensor names in the graph and verify that the algorithmic
    # Flop counts reflect the new name bindings
    # Manually set some variables
    # TODO (Joel): Fix this up when all tensor arrays work!
    ta_op = graph.opsByName['rnn/TensorArrayStack/TensorArraySizeV3']
    ta_op._outputs[0].setValue(seq_length)
    ta_op = graph.opsByName['rnn/TensorArrayStack/TensorArrayGatherV3']
    ta_op._outputs[0].mergeShape([seq_length, batch_size, hidden_dim],
                                 make_symbolic=True)
    ta_op = graph.opsByName['rnn/while/TensorArrayReadV3']
    ta_op._outputs[0].mergeShape([batch_size, hidden_dim], make_symbolic=True)

    ta_op = graph.opsByName[
        'Gradient/Compute/gradients/rnn/while/TensorArrayWrite/TensorArrayWriteV3_grad/TensorArrayReadV3']
    ta_op._outputs[0].mergeShape([batch_size, hidden_dim], make_symbolic=True)

    # Bind constant values first
    const_dict = { # Store the shapes of certain tensors as constants
                   'rnn/Const': [hidden_dim],
                   'rnn/Const_1': [hidden_dim],
                 }
    graph.bindConstantValues(const_dict)

    # NOTE: This also works: batch_size = 'batch_size'
    # Bind the variable, placeholder, and gradient-accumulator shapes to the
    # symbolic batch_size, seq_length, and hidden_dim dimensions
    bind_dict = { # Variables
                  'rnn/basic_lstm_cell/kernel': [2 * hidden_dim, 4 * hidden_dim],
                  'rnn/basic_lstm_cell/bias': [4 * hidden_dim],
                  # Placeholders
                  'a': [batch_size, seq_length, hidden_dim],
                  'c_init_state': [batch_size, hidden_dim],
                  'h_init_state': [batch_size, hidden_dim],
                  'out_correct': [batch_size, seq_length],
                  # Constants
                  'Gradient/Compute/gradients/rnn/while/basic_lstm_cell/MatMul/Enter_grad/b_acc': [2 * hidden_dim, 4 * hidden_dim],
                  'Gradient/Compute/gradients/rnn/while/basic_lstm_cell/BiasAdd/Enter_grad/b_acc': [4 * hidden_dim],
                }
    graph.bindShapesAndPropagate(bind_dict,
                                 make_symbolic=True,
                                 warn_if_ill_defined=True)

    algorithmic_flops = graph.calcAlgFlops()

    # Update the algorithmic Flops formula
    # Sub the forward prop values
    correct_alg_flops = general_correct_alg_flops.subs({
        ba_0: batch_size,
        a_0: batch_size,
        a1_0: batch_size,
        mm_0: batch_size,
        m_0: batch_size,
        m1_0: batch_size,
        m2_0: batch_size,
        s_0: batch_size,
        s1_0: batch_size,
        s2_0: batch_size,
        th_0: batch_size,
        th1_0: batch_size,
        sm_r_0: batch_size * seq_length,
        sm_r_1: hidden_dim,
    })

    # Sub the backward prop values
    # TODO (Joel): Fix this up when all backprop works!
    correct_alg_flops = correct_alg_flops.subs({
        g_mm_0: batch_size,
        g_mms_0: batch_size,
        g_ms_0: batch_size,
        g_m_0: batch_size,
        g_ms1_0: batch_size,
        g_m1_0: batch_size,
        g_ms2_0: batch_size,
        g_m2_0: batch_size,
        g_ms21_0: batch_size,
        g_m21_0: batch_size,
        g_mus_0: batch_size,
        g_mu_0: batch_size,
        g_mu1_0: batch_size,
        g_s_0: batch_size,
        g_c_0: batch_size,
        g_an_0: batch_size,
        g_an1_0: batch_size,
        g_sm_m_0: batch_size * seq_length,
        g_sm_m_1: hidden_dim,
    })

    assert graph.isValid()

    print('BOUND GRAPH:\n{}\n\n'.format(graph))

    print('Bound Flops test:')
    print('    Catamount:   {}'.format(algorithmic_flops))
    print('    Correct: {}'.format(correct_alg_flops))
    assert sympy.simplify(algorithmic_flops - correct_alg_flops) == 0, \
        'Bound alg flops incorrect!\n  Expecting: {}\n  Calculated: {}\n  Difference: {}' \
        .format(correct_alg_flops, algorithmic_flops, algorithmic_flops - correct_alg_flops)

    bind_subs = { # Symbols/names to bind in next tests:
                  graph_iters: 1,
                  rwb_iters: seq_length,
                  grad_iters: seq_length,
                  batch_size: 128,
                  seq_length: 32,
                  hidden_dim: 256,
                }

    print('\n\nBound values')
    print('Symbol associations: {}\n'.format(bind_subs))

    # Calculate model parameter count
    parameters = graph.calcModelParameters()
    resolved_params = parameters.subs(bind_subs)
    try:
        resolved_params = int(resolved_params)
    except:
        print('ERROR: resolved_params should be int, but is {} = {}'.format(
            type(resolved_params), resolved_params))
    correct_params = 525312
    assert resolved_params == correct_params, \
           'Incorrect model params: {}'.format(resolved_params)
    print('Parameters: {}\nWith specified dims: {}\n'.format(
        parameters, resolved_params))

    # Calculate algorithmic Flops
    alg_flops = graph.calcAlgFlops()
    resolved_flops = alg_flops.subs(bind_subs)
    try:
        resolved_flops = int(resolved_flops)
    except:
        print('ERROR: resolved_flops should be int, but is {} = {}'.format(
            type(resolved_flops), resolved_flops))
    correct_flops = 12980357379
    assert resolved_flops == correct_flops, \
           'Incorrect algorithmic flops: {}'.format(resolved_flops)
    print('Algorithmic Flops: {}\nWith specified dims: {}\n'.format(
        alg_flops, resolved_flops))

    # Calculate algorithmic Bytes accessed
    alg_bytes = graph.calcAlgBytes()
    resolved_bytes = alg_bytes.subs(bind_subs)
    try:
        resolved_bytes = int(resolved_bytes)
    except:
        print('ERROR: resolved_bytes should be int, but is {} = {}'.format(
            type(resolved_bytes), resolved_bytes))
    correct_bytes = 1134017252
    assert resolved_bytes == correct_bytes, \
           'Incorrect algorithmic bytes: {}'.format(resolved_bytes)
    print('Alg bytes accessed: {}\nWith specified dims: {}\n'.format(
        alg_bytes, resolved_bytes))

    # Calculate total memory footprint
    alg_footprint = graph.calcAlgFootprint()
    resolved_footprint = alg_footprint.subs(bind_subs)
    try:
        resolved_footprint = int(resolved_footprint)
    except:
        print('ERROR: resolved_footprint should be int, but is {} = {}'.format(
            type(resolved_footprint), resolved_footprint))
    correct_footprint = 441447784
    assert resolved_footprint == correct_footprint, \
           'Incorrect algorithmic footprint: {}'.format(resolved_footprint)
    print('Alg mem footprint: {}\nWith specified dims: {}\n'.format(
        alg_footprint, resolved_footprint))

    # Calculate the minimal memory footprint for a step
    alg_min_footprint = graph.calcMinimalFootprint(symbol_subs=bind_subs)
    resolved_min_footprint = alg_min_footprint
    try:
        resolved_min_footprint = int(resolved_min_footprint)
    except:
        print('ERROR: resolved_min_footprint should be int, but is {} = {}'.
              format(type(resolved_min_footprint), resolved_min_footprint))
    correct_min_footprint = 38153640
    error_percent = abs(correct_min_footprint -
                        resolved_min_footprint) / correct_min_footprint
    if error_percent > 0.15:
        print('Incorrect algorithmic footprint: {} (err: {})!'.format(
            resolved_min_footprint, error_percent))
    print('Alg minimal footprint: {}\nWith specified dims: {} (err: {})\n'.
          format(alg_min_footprint, resolved_min_footprint, error_percent))

    # Calculate algorithmic IO per step
    total_io_footprint = 0
    for op in graph.getPlaceholders():
        total_io_footprint += op.calcAlgFootprint()
    resolved_io_footprint = total_io_footprint.subs(bind_subs)
    print('Alg IO footprint: {}\nWith specified dims: {}\n'.format(
        total_io_footprint, resolved_io_footprint))
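The resolve-and-check pattern used throughout these tests (substitute concrete dimensions into a symbolic expression, then confirm it collapses to an integer) can be reproduced with plain sympy. The snippet below is a minimal, Catamount-independent sketch; the expression and dimension values are illustrative only.

import sympy

# Minimal sketch of the resolve-and-check pattern: build a symbolic Flop
# expression, substitute concrete dimensions, and verify the result is a
# plain integer (values here are illustrative, not from any test above).
batch_size, hidden_dim = sympy.symbols('batch_size hidden_dim', integer=True)
flops_expr = 16 * hidden_dim**2 * batch_size + 4 * hidden_dim * batch_size + 2
resolved_flops = flops_expr.subs({batch_size: 128, hidden_dim: 256})
assert resolved_flops.is_Integer
print('Resolved Flops: {}'.format(int(resolved_flops)))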
Example No. 19
0
def __init__(self, name):
    super(MultinomialOp, self).__init__(name)
    samps_name = '{}::rand_samps'.format(self.name)
    self._num_samples_symbol = utils.getIntSymbolFromString(samps_name)
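The '::rand_samps' symbol created here behaves like any other sympy integer symbol, so later analyses can substitute or replace it (the word2vec example below overwrites _num_samples_symbol on the sampler op directly). A hedged, standalone sketch of that symbol handling, assuming getIntSymbolFromString simply wraps sympy.Symbol:

import sympy

# Assumption: utils.getIntSymbolFromString returns a sympy integer Symbol for
# the given name, as its use with sympy.simplify elsewhere suggests.
samps_name = 'multinomial::rand_samps'   # illustrative op name
num_samples_symbol = sympy.Symbol(samps_name, integer=True)
sample_flops = 2 * num_samples_symbol    # illustrative Flop expression
print(sample_flops.subs({num_samples_symbol: 64}))  # prints 128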
Example No. 20
0
def run_tf_w2v_model():
    global is_pytest_run

    graph_meta = 'catamount/frameworks/example_graphs/tensorflow/full_models/language_models/word2vec_n200-latest_model.meta'

    graph = catamount.frameworks.tensorflow.import_graph(graph_meta)
    assert graph.isValid()

    # Next, remove ops that are not executed during a standard training step:
    graph_ops = list(graph._ops_by_name.values())
    for op in graph_ops:
        # Certain ops are only used for inference
        if 'Model/NceLoss_1_3/' in op.name or \
           'Model/Collapse_1/' in op.name or \
           'Model/Embedding_1_3/' in op.name or \
           'Model/Labels_1/' in op.name or \
           'Model/SkipGramSampler_1/' in op.name or \
           'Model/Mask_1/' in op.name:
            graph.removeOp(op)
        elif \
             op.name == 'Model/Cast_1' or \
             op.name == 'Model/Sum_1' or \
             op.name == 'Model/Size_1' or \
             op.name == 'Model/Exp_1' or \
             op.name == 'Model/truediv_2' or \
             op.name == 'Model/truediv_3':
            graph.removeOp(op)

    if not is_pytest_run:
        print('Initial graph:\n{}\n'.format(graph))
    init_params = graph.calcModelParameters()
    print('Initial parameters: {}'.format(init_params))
    print('Initial Flops: {}\n'.format(graph.calcAlgFlops()))

    print('Placeholders:')
    for op in graph.getPlaceholders():
        print(op.debugString())
    print('')

    # Set up symbols to name dimensions
    skip_window_symbol = utils.getPositiveIntSymbolFromString('skip_window')
    num_skips_symbol = utils.getPositiveIntSymbolFromString('num_skips')
    nce_samples_symbol = utils.getPositiveIntSymbolFromString('nce_samples')
    hidden_dim_symbol = utils.getIntSymbolFromString('hidden_dim')
    vocab_size_symbol = utils.getIntSymbolFromString('vocab_size')
    subbatch_size_symbol = utils.getIntSymbolFromString('subbatch_size')
    sequence_length_symbol = utils.getIntSymbolFromString('sequence_length')
    batch_times_seq_symbol = sequence_length_symbol * subbatch_size_symbol
    graph_iters_symbol = utils.getIntSymbolFromString('graph::iters')
    # For simplicity, assign the sample-count symbol directly on the op
    nce_samp_op = graph.opsByName[
        'Model/NceLoss_1_1/nce_loss/LogUniformCandidateSampler']
    nce_samp_op._num_samples_symbol = nce_samples_symbol

    # Concrete base values for the symbolic dimensions above
    base_skip_window = 8
    base_num_skips = 8
    base_nce_samples = 64
    base_hidden_dim = 400
    base_vocab_size = 40004
    base_sequence_length = 32
    base_subbatch_size = 1

    # Find and set constants that contain model hyperparameters
    const_dict = {
        'Model/Gradient/Compute/gradients/Model/NceLoss_1_1/nce_loss/sub_1_grad/Shape_1':
        [nce_samples_symbol],
        'Model/SkipGramSampler/Const':
        2 * skip_window_symbol,
        'Model/SkipGramSampler/strided_slice/stack': [0, skip_window_symbol],
        'Model/SkipGramSampler/strided_slice/stack_1':
        [0, -skip_window_symbol],
        'Model/Collapse/Reshape/shape': [-1, hidden_dim_symbol],
        'Model/Gradient/Compute/gradients/Model/Embedding_1_1/Gather_grad/Shape':
        [vocab_size_symbol, hidden_dim_symbol],
        'Model/Gradient/Compute/gradients/Model/NceLoss_1_1/nce_loss/embedding_lookup_1_grad/Shape':
        [vocab_size_symbol],
        'Model/Gradient/Compute/gradients/Model/NceLoss_1_1/nce_loss/embedding_lookup_grad/Shape':
        [vocab_size_symbol, hidden_dim_symbol],
        'Model/Mask/NotEqual/y':
        vocab_size_symbol - 3,
        'Model/SkipGramSampler/Const_2':
        num_skips_symbol,
        'Model/SkipGramSampler/Tile_1/multiples': [1, num_skips_symbol],
        'Model/SkipGramSampler/Tile/multiples': [1, num_skips_symbol],
    }
    graph.bindConstantValues(const_dict)

    # Next, bind the constant, placeholder, and variable shapes and propagate
    bind_dict = { # Constants
                  # Placeholders
                  'Input/Input': [subbatch_size_symbol, sequence_length_symbol],
                  'Labels/Labels': [subbatch_size_symbol, sequence_length_symbol],
                  # Variables
                  'Model/NceLoss_1/b_Softmax': [vocab_size_symbol],
                  'Model/NceLoss_1/W_Softmax': [vocab_size_symbol, hidden_dim_symbol],
                  'Model/Embedding_1/EmbeddingWeights': [vocab_size_symbol, hidden_dim_symbol],
                }

    print('Binding variables')

    # HACK: For now, manually set GatherNd op shapes. Later, implement GatherNd
    gnd_op = graph.opsByName['Model/SkipGramSampler/GatherNd']
    gnd_op.outputs[0].mergeShape([
        subbatch_size_symbol,
        num_skips_symbol * (sequence_length_symbol - 2 * skip_window_symbol)
    ])

    graph.bindShapesAndPropagate(bind_dict,
                                 warn_if_ill_defined=(not is_pytest_run),
                                 make_symbolic=True)
    assert graph.isValid()

    if not is_pytest_run:
        print('\n\nCleaned Graph:\n{}'.format(graph))

    print('\n\nBound values')

    bind_subs = {
        graph_iters_symbol: 1,
        hidden_dim_symbol: base_hidden_dim,
        sequence_length_symbol: base_sequence_length,
        subbatch_size_symbol: base_subbatch_size,
        vocab_size_symbol: base_vocab_size,
        skip_window_symbol: base_skip_window,
        num_skips_symbol: base_num_skips,
        nce_samples_symbol: base_nce_samples,
    }

    # Verify parameter counts first
    parameters = graph.calcModelParameters()

    correct_params = 32043205
    correct_flops = 21148823
    correct_bytes = 23762537
    correct_total_footprint = 137949925

    print('Symbol associations: {}\n'.format(bind_subs))

    # Calculate model parameter count
    resolved_params = parameters.subs(bind_subs)
    try:
        resolved_params = int(resolved_params)
    except:
        print('ERROR: resolved_params should be int, but is {} = {}'.format(
            type(resolved_params), resolved_params))
    assert resolved_params == correct_params, \
           'Incorrect model params: {}'.format(resolved_params)
    print('Parameters: {}\nWith specified dims: {}\n'.format(
        parameters, resolved_params))

    # Calculate algorithmic Flops
    alg_flops = graph.calcAlgFlops()
    resolved_flops = alg_flops.subs(bind_subs)
    try:
        resolved_flops = int(resolved_flops)
    except:
        print('ERROR: resolved_flops should be int, but is {} = {}'.format(
            type(resolved_flops), resolved_flops))
    assert resolved_flops == correct_flops, \
           'Incorrect algorithmic flops: {}'.format(resolved_flops)
    print('Algorithmic Flops: {}\nWith specified dims: {}\n'.format(
        alg_flops, resolved_flops))

    # Calculate algorithmic Bytes accessed
    alg_bytes = graph.calcAlgBytes()
    resolved_bytes = alg_bytes.subs(bind_subs)
    try:
        resolved_bytes = int(resolved_bytes)
    except:
        print('ERROR: resolved_bytes should be int, but is {} = {}'.format(
            type(resolved_bytes), resolved_bytes))
    assert resolved_bytes == correct_bytes, \
           'Incorrect algorithmic bytes: {}'.format(resolved_bytes)
    print('Alg bytes accessed: {}\nWith specified dims: {}\n'.format(
        alg_bytes, resolved_bytes))

    # Calculate total memory footprint
    alg_footprint = graph.calcAlgFootprint()
    resolved_footprint = alg_footprint.subs(bind_subs)
    try:
        resolved_footprint = int(resolved_footprint)
    except:
        print('ERROR: resolved_footprint should be int, but is {} = {}'.format(
            type(resolved_footprint), resolved_footprint))
    assert resolved_footprint == correct_total_footprint, \
           'Incorrect algorithmic footprint: {}'.format(resolved_footprint)
    print('Alg mem footprint: {}\nWith specified dims: {}\n'.format(
        alg_footprint, resolved_footprint))

    # Calculate minimal memory footprint
    alg_min_footprint = graph.calcMinimalFootprint(symbol_subs=bind_subs)
    print('Alg minimal footprint (With specified dims): {}\n'.format(
        alg_min_footprint))

    # Calculate algorithmic IO per step
    total_io_footprint = 0
    for op in graph.getPlaceholders():
        total_io_footprint += op.calcAlgFootprint()
    if isinstance(total_io_footprint, int):
        resolved_io_footprint = total_io_footprint
    else:
        resolved_io_footprint = total_io_footprint.subs(bind_subs)
    print('Alg IO footprint: {}\nWith specified dims: {}\n'.format(
        total_io_footprint, resolved_io_footprint))

    if not is_pytest_run:
        print('VERBOSE ALGORITHMIC FLOPS:')
        graph.calcAlgFlops(verbose=True)
        print('')

        print('VERBOSE ALGORITHMIC BYTES:')
        graph.calcAlgBytes(verbose=True)
        print('')

        print('VERBOSE ALGORITHMIC FOOTPRINT:')
        graph.calcAlgFootprint(verbose=True)
        print('')

    # HACK: For now, save the analyzed model by pickling the graph
    pickle.dump(
        graph,
        open(
            'catamount/frameworks/example_graphs/tensorflow/full_models/language_models/graph_word2vec.p',
            'wb'))

    if is_pytest_run:
        return

    print('\n\n======= Algorithmic graph-level analytics: =======')

    hidden_dims = [
        1, 2, 3, 4, 5, 6, 7, 9, 10, 12, 14, 18, 20, 25, 28, 35, 40, 50, 56, 69,
        78, 86, 96, 108, 119, 123, 133, 148, 163, 182, 202, 221, 246, 273, 297,
        329, 330, 364, 396, 436, 437, 520, 572, 617, 676, 740, 796, 869, 948,
        1017, 1106, 1202, 1286, 1394, 1510, 1611, 1742, 1882, 2004, 2161, 2476,
        3040, 3714, 4520, 5478, 6628, 8019, 9702, 11739, 14204, 17186, 20795,
        25161, 30444, 36837, 38100
    ]

    bind_subs.pop(hidden_dim_symbol)
    resolved_params = parameters.subs(bind_subs)

    print('Symbol associations: {}\n'.format(bind_subs))

    print(
        'Algorithmic Flops by hidden dimension, params, and per-batch-sample:')
    resolved_flops = alg_flops.subs(bind_subs)
    for hid_dim in hidden_dims:
        graph_params = resolved_params.subs({hidden_dim_symbol: hid_dim})
        graph_flops = resolved_flops.subs({hidden_dim_symbol: hid_dim})
        graph_flops_per_sample = float(graph_flops) / \
                                 bind_subs[subbatch_size_symbol]
        print('{}\t{}\t{}\t{}'.format(hid_dim, graph_params, graph_flops,
                                      int(graph_flops_per_sample)))

    print('\nAlgorithmic bytes accessed by hidden dimension, params:')
    resolved_bytes = alg_bytes.subs(bind_subs)
    for hid_dim in hidden_dims:
        graph_params = resolved_params.subs({hidden_dim_symbol: hid_dim})
        graph_bytes = resolved_bytes.subs({hidden_dim_symbol: hid_dim})
        print('{}\t{}\t{}'.format(hid_dim, graph_params, graph_bytes))

    print('\nAlgorithmic total memory footprint by hidden dimension, params:')
    resolved_footprint = alg_footprint.subs(bind_subs)
    for hid_dim in hidden_dims:
        graph_params = resolved_params.subs({hidden_dim_symbol: hid_dim})
        graph_footprint = resolved_footprint.subs({hidden_dim_symbol: hid_dim})
        print('{}\t{}\t{}'.format(hid_dim, graph_params, graph_footprint))

    print(
        '\nAlgorithmic minimal memory footprint by hidden dimension, params:')
    full_subs = dict(bind_subs)
    for hid_dim in hidden_dims:
        graph_params = resolved_params.subs({hidden_dim_symbol: hid_dim})
        full_subs[hidden_dim_symbol] = hid_dim
        graph_min_foot = graph.calcMinimalFootprint(symbol_subs=full_subs)
        print('{}\t{}\t{}'.format(hid_dim, graph_params, graph_min_foot))
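The sweep above binds every symbol except hidden_dim and then substitutes one hidden dimension per data point. A standalone sketch of that pattern follows; the parameter expression is illustrative and not the exact word2vec count.

import sympy

# Sketch of the sweep pattern: resolve all symbols except hidden_dim up
# front, then substitute hidden_dim once per sweep point.
hidden_dim, vocab_size = sympy.symbols('hidden_dim vocab_size', integer=True)
params_expr = 2 * vocab_size * hidden_dim + vocab_size   # illustrative only
partial_params = params_expr.subs({vocab_size: 40004})
for hid_dim in [100, 200, 400]:
    print('{}\t{}'.format(hid_dim, partial_params.subs({hidden_dim: hid_dim})))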
Example No. 21
0
def test_tf_dynamic_rnn():
    graph = catamount.frameworks.tensorflow.import_graph(tf_example_filename)

    print('INITIAL GRAPH: {}\n\n'.format(graph))

    assert graph.isValid()

    algorithmic_flops = graph.calcAlgFlops()

    rwb_iters = utils.getIntSymbolFromString('rnn/while/LoopCond_block::iters')
    a_0 = utils.getIntSymbolFromString(
        'rnn/while/basic_lstm_cell/Add:0::dim_0')
    a1_0 = utils.getIntSymbolFromString(
        'rnn/while/basic_lstm_cell/Add_1:0::dim_0')
    ba_0 = utils.getIntSymbolFromString(
        'rnn/while/basic_lstm_cell/BiasAdd:0::dim_0')
    mm_0 = utils.getIntSymbolFromString(
        'rnn/while/basic_lstm_cell/MatMul:0::dim_0')
    m_0 = utils.getIntSymbolFromString(
        'rnn/while/basic_lstm_cell/Mul:0::dim_0')
    m1_0 = utils.getIntSymbolFromString(
        'rnn/while/basic_lstm_cell/Mul_1:0::dim_0')
    m2_0 = utils.getIntSymbolFromString(
        'rnn/while/basic_lstm_cell/Mul_2:0::dim_0')
    s_0 = utils.getIntSymbolFromString(
        'rnn/while/basic_lstm_cell/Sigmoid:0::dim_0')
    s1_0 = utils.getIntSymbolFromString(
        'rnn/while/basic_lstm_cell/Sigmoid_1:0::dim_0')
    s2_0 = utils.getIntSymbolFromString(
        'rnn/while/basic_lstm_cell/Sigmoid_2:0::dim_0')
    sm_r_0 = utils.getIntSymbolFromString('softmax/Reshape:0::dim_0')
    sm_r_1 = utils.getIntSymbolFromString('softmax/Reshape:0::dim_1')
    th_0 = utils.getIntSymbolFromString(
        'rnn/while/basic_lstm_cell/Tanh:0::dim_0')
    th1_0 = utils.getIntSymbolFromString(
        'rnn/while/basic_lstm_cell/Tanh_1:0::dim_0')
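    # The constant coefficients below assume this graph's fixed hidden_dim
    # of 24 (e.g., 9216 = 16 * 24**2 for the basic_lstm_cell MatMul).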
    forward_flops = rwb_iters * (24 * a_0 + 24 * a1_0 + 96 * ba_0 + 9216 * mm_0 + \
                                 24 * m_0 + 24 * m1_0 + 24 * m2_0 + 96 * s_0 + \
                                 96 * s1_0 + 96 * s2_0 + 144 * th_0 + 144 * th1_0 + 6) + \
                    3 * sm_r_0 * sm_r_1 + \
                    18628
    grad_iters = utils.getIntSymbolFromString(
        'Gradient/Compute/gradients/b_count_2_block::iters')
    g_an_0 = utils.getIntSymbolFromString(
        'Gradient/Compute/gradients/AddN:0::dim_0')
    g_an1_0 = utils.getIntSymbolFromString(
        'Gradient/Compute/gradients/AddN_1:0::dim_0')
    g_mm_0 = utils.getIntSymbolFromString(
        'Gradient/Compute/gradients/rnn/while/basic_lstm_cell/MatMul_grad/MatMul:0::dim_0'
    )
    g_mms_0 = utils.getIntSymbolFromString(
        'Gradient/Compute/gradients/rnn/while/basic_lstm_cell/MatMul_grad/MatMul_1/StackPopV2:0::dim_0'
    )
    g_ms_0 = utils.getIntSymbolFromString(
        'Gradient/Compute/gradients/rnn/while/basic_lstm_cell/Mul_1_grad/mul/StackPopV2:0::dim_0'
    )
    g_m_0 = utils.getIntSymbolFromString(
        'Gradient/Compute/gradients/rnn/while/basic_lstm_cell/Mul_1_grad/mul:0::dim_0'
    )
    g_ms1_0 = utils.getIntSymbolFromString(
        'Gradient/Compute/gradients/rnn/while/basic_lstm_cell/Mul_1_grad/mul_1/StackPopV2:0::dim_0'
    )
    g_m1_0 = utils.getIntSymbolFromString(
        'Gradient/Compute/gradients/rnn/while/basic_lstm_cell/Mul_1_grad/mul_1:0::dim_0'
    )
    g_ms2_0 = utils.getIntSymbolFromString(
        'Gradient/Compute/gradients/rnn/while/basic_lstm_cell/Mul_2_grad/mul/StackPopV2:0::dim_0'
    )
    g_m2_0 = utils.getIntSymbolFromString(
        'Gradient/Compute/gradients/rnn/while/basic_lstm_cell/Mul_2_grad/mul:0::dim_0'
    )
    g_ms21_0 = utils.getIntSymbolFromString(
        'Gradient/Compute/gradients/rnn/while/basic_lstm_cell/Mul_2_grad/mul_1/StackPopV2:0::dim_0'
    )
    g_m21_0 = utils.getIntSymbolFromString(
        'Gradient/Compute/gradients/rnn/while/basic_lstm_cell/Mul_2_grad/mul_1:0::dim_0'
    )
    g_mus_0 = utils.getIntSymbolFromString(
        'Gradient/Compute/gradients/rnn/while/basic_lstm_cell/Mul_grad/mul/StackPopV2:0::dim_0'
    )
    g_mu_0 = utils.getIntSymbolFromString(
        'Gradient/Compute/gradients/rnn/while/basic_lstm_cell/Mul_grad/mul:0::dim_0'
    )
    g_mu1_0 = utils.getIntSymbolFromString(
        'Gradient/Compute/gradients/rnn/while/basic_lstm_cell/Mul_grad/mul_1:0::dim_0'
    )
    g_s_0 = utils.getIntSymbolFromString(
        'Gradient/Compute/gradients/rnn/while/basic_lstm_cell/Sigmoid_grad/SigmoidGrad:0::dim_0'
    )
    g_c_0 = utils.getIntSymbolFromString(
        'Gradient/Compute/gradients/rnn/while/basic_lstm_cell/split_grad/concat:0::dim_0'
    )
    g_sm_m_0 = utils.getIntSymbolFromString(
        'Gradient/Compute/gradients/softmax/softmax_grad/mul:0::dim_0')
    g_sm_m_1 = utils.getIntSymbolFromString(
        'Gradient/Compute/gradients/softmax/softmax_grad/mul:0::dim_1')
    backward_flops = grad_iters * (24 * g_an_0 + 72 * g_an1_0 + 9216 * g_mm_0 + 9216 * g_mms_0 + \
                                   72 * g_ms_0 + 48 * g_m_0 + 72 * g_ms1_0 + 48 * g_m1_0 + \
                                   72 * g_ms2_0 + 48 * g_m2_0 + 72 * g_ms21_0 + 48 * g_m21_0 + \
                                   72 * g_mus_0 + 48 * g_mu_0 + 48 * g_mu1_0 + 48 * g_s_0 + \
                                   96 * g_c_0 + 4706) + \
                     g_sm_m_0 * g_sm_m_1
    correct_alg_flops = forward_flops + backward_flops

    print('Loaded Flops test:')
    print('    Catamount:   {}'.format(algorithmic_flops))
    print('    Correct: {}'.format(correct_alg_flops))
    assert sympy.simplify(algorithmic_flops - correct_alg_flops) == 0, \
        'Initial alg flops incorrect!\n  Expecting: {}\n  Calculated: {}' \
        .format(correct_alg_flops, algorithmic_flops)

    # Now, bind tensor names in the graph and verify that the algorithmic
    # Flop counts reflect the new name bindings
    batch_size = utils.getIntSymbolFromString('batch_size')
    seq_length = utils.getIntSymbolFromString('seq_length')
    hidden_dim = utils.getIntSymbolFromString('hidden_dim')

    # Manually set some variables
    # TODO (Joel): Fix this up when all tensor arrays work!
    ta_op = graph.opsByName['rnn/TensorArrayStack/TensorArraySizeV3']
    ta_op._outputs[0].setValue(seq_length)
    ta_op = graph.opsByName['rnn/TensorArrayStack/TensorArrayGatherV3']
    ta_op._outputs[0].shape.mergeShape([seq_length, batch_size, hidden_dim])
    ta_op = graph.opsByName['rnn/while/TensorArrayReadV3']
    ta_op._outputs[0].shape.mergeShape([batch_size, hidden_dim])

    # TODO (Joel): Fix this up when all stack ops work!
    find_stack_shape = TensorShape([None, 24])
    find_stack_shape_2 = TensorShape([None, 48])
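    # Here 24 is this graph's hidden_dim and 48 is 2 * hidden_dim, the width
    # of the concatenated [input, h] operand fed to the LSTM MatMul.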
    for op in graph.opsByName.values():
        op_name_suffix = op.name.split('/')[-1]
        if 'StackPopV2' in op_name_suffix:
            if op._outputs[0].shape == find_stack_shape:
                op._outputs[0].shape.mergeShape([batch_size, hidden_dim])
            elif op._outputs[0].shape == find_stack_shape_2:
                op._outputs[0].shape.mergeShape([batch_size, 2 * hidden_dim])

    # NOTE: This also works: batch_size = 'batch_size'
    # Bind the placeholder shapes to the symbolic batch_size, seq_length,
    # and hidden_dim dimensions
    bind_dict = {
        'a': [batch_size, seq_length, hidden_dim],
        'c_init_state': [batch_size, hidden_dim],
        'h_init_state': [batch_size, hidden_dim],
        'out_correct': [batch_size, seq_length]
    }
    graph.bindTensorShapeDimensions(bind_dict, warn_if_ill_defined=True)

    algorithmic_flops = graph.calcAlgFlops()

    # Update the algorithmic Flops formula
    # Sub the forward prop values
    correct_alg_flops = correct_alg_flops.subs({
        ba_0: batch_size,
        a_0: batch_size,
        a1_0: batch_size,
        mm_0: batch_size,
        m_0: batch_size,
        m1_0: batch_size,
        m2_0: batch_size,
        s_0: batch_size,
        s1_0: batch_size,
        s2_0: batch_size,
        th_0: batch_size,
        th1_0: batch_size,
        sm_r_0: batch_size * seq_length,
        sm_r_1: 24,
    })

    # Sub the backward prop values
    # TODO (Joel): Fix this up when all backprop works!
    correct_alg_flops = correct_alg_flops.subs({
        g_mm_0: batch_size,
        g_mms_0: batch_size,
        g_ms_0: batch_size,
        g_m_0: batch_size,
        g_ms1_0: batch_size,
        g_m1_0: batch_size,
        g_ms2_0: batch_size,
        g_m2_0: batch_size,
        g_ms21_0: batch_size,
        g_m21_0: batch_size,
        g_mus_0: batch_size,
        g_mu_0: batch_size,
        g_mu1_0: batch_size,
        g_s_0: batch_size,
        g_c_0: batch_size,
        g_an_0: batch_size,
        g_an1_0: batch_size,
        g_sm_m_0: batch_size * seq_length,
        g_sm_m_1: 24,
    })

    assert graph.isValid()

    print('BOUND GRAPH: {}\n\n'.format(graph))

    # HACK: Fix this! Due to ShapeOp symbol propagation, hidden_dim leaks into
    # the calculated formula, so substitute its known value (24) here.
    algorithmic_flops = algorithmic_flops.subs({hidden_dim: 24})

    print('Bound Flops test:')
    print('    Catamount:   {}'.format(algorithmic_flops))
    print('    Correct: {}'.format(correct_alg_flops))
    assert sympy.simplify(algorithmic_flops - correct_alg_flops) == 0, \
        'Bound alg flops incorrect!\n  Expecting: {}\n  Calculated: {}\n  Difference: {}' \
        .format(correct_alg_flops, algorithmic_flops, algorithmic_flops - correct_alg_flops)
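The assertions in these tests treat two Flop formulas as equivalent when their difference simplifies to zero. A minimal standalone illustration of that check:

import sympy

# Two symbolic expressions match when their difference simplifies to zero,
# which is how the Flop-count assertions above compare formulas.
batch_size, seq_length = sympy.symbols('batch_size seq_length', integer=True)
calculated = 3 * batch_size * seq_length + 2 * batch_size * seq_length
expected = 5 * batch_size * seq_length
assert sympy.simplify(calculated - expected) == 0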