# Example 1
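# NOTE: These examples assume module-level imports that the listing omits.
# Plausible forms (the exact import paths for ConstantOp, MaximumOp, and
# UnknownOp are assumptions, inferred from the AssignOp/VariableOp imports
# used below):
#   import os, pickle
#   import numpy as np
#   import sympy
#   import catamount
#   import catamount.frameworks.tensorflow
#   from catamount.ops import ConstantOp, MaximumOp, UnknownOp
#   from catamount import utils
#   is_pytest_run = False  # set True when running under pytest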
def run_tf_image_resnet(depth, filter_scale=1.0):
    global is_pytest_run
    model_string = '_d{}_fs{}_'.format(depth, filter_scale)
    test_outputs_dir = 'catamount/frameworks/example_graphs/tensorflow/full_models/image_classification'
    graph_meta = None
    for root, dirs, files in os.walk(test_outputs_dir):
        for filename in files:
            if 'graph{}'.format(
                    model_string) in filename and '.meta' in filename:
                # Take the first graph that we find in the directory
                graph_meta = os.path.join(root, filename)
                break
        if graph_meta is not None:
            break

    if graph_meta is None:
        raise FileNotFoundError(
            'Unable to find model string {} in directory {}'.format(
                model_string, test_outputs_dir))

    graph = catamount.frameworks.tensorflow.import_graph(graph_meta)
    assert graph.isValid()

    # ============ TO REMOVE INITIALIZATION OPS! =============
    # NOTE: This code is pretty general and is likely to be migrated into
    # Catamount code for removing TF-specific initialization ops
    from catamount.ops import AssignOp
    from catamount.ops import VariableOp
    assign_ops = set()
    for op in graph.opsByName.values():
        if isinstance(op, AssignOp):
            assign_ops.add(op)
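    # Walk backward from each AssignOp through its transitive input
    # producers, stopping at VariableOps so the variables themselves
    # survive. Everything collected feeds only initialization and can be
    # removed from the training graph.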
    for assign_op in assign_ops:
        my_ancestors = set()
        my_frontier = set()
        my_frontier.add(assign_op)
        while len(my_frontier) > 0:
            next_op = my_frontier.pop()
            for in_tensor in next_op.inputs:
                if not isinstance(in_tensor.producer, VariableOp):
                    my_frontier.add(in_tensor.producer)
            my_ancestors.add(next_op)
        for next_op in my_ancestors:
            graph.removeOp(next_op)
    assert graph.isValid()

    # Manually remove the inference parts of graph
    graph_ops = list(graph._ops_by_name.values())
    for op in graph_ops:
        # Certain ops are only used for inference
        if 'InferenceTower/' in op.name or \
           'InferenceRunner/' in op.name or \
           op.name == 'MergeAllSummariesRunWithOp/Merge/MergeSummary':
            graph.removeOp(op)
    assert graph.isValid()

    print('Initial graph:\n{}\n'.format(graph))
    init_params = graph.calcModelParameters()
    print('Initial parameters: {}'.format(init_params))
    print('Initial Flops: {}\n'.format(graph.calcAlgFlops()))

    print('Placeholders:')
    for op in graph.getPlaceholders():
        print(op.debugString())
    print('')

    # Set up symbols to name dimensions
    output_classes_symbol = utils.getPositiveIntSymbolFromString('out_classes')
    subbatch_size_symbol = utils.getPositiveIntSymbolFromString(
        'subbatch_size')
    image_height_symbol = utils.getPositiveIntSymbolFromString('image_height')
    image_width_symbol = utils.getPositiveIntSymbolFromString('image_width')
    num_in_channels_symbol = utils.getPositiveIntSymbolFromString(
        'num_in_channels')
    graph_iters_symbol = utils.getIntSymbolFromString('graph::iters')
    feature_channels_symbol = utils.getPositiveIntSymbolFromString(
        'feature_channels')

    # Find and replace convolution/pooling dimensions also:
    # Dimension(64 * 2^k): conv/pool feature channels
    base_output_classes = 1000
    base_num_in_channels = 3
    base_feature_channels = 64

    base_image_height = 224
    base_image_width = 224
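    # Each stride-2 stage of the network halves the spatial dims with
    # flooring, so the derived symbols below are built by nested floor
    # division (sympy's //) to track shapes stage by stage:
    # 224 -> 112 -> 56 -> 28 -> 14 -> 7 for the 224x224 base size.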
    base_half_im_height = 112
    half_im_height_symbol = image_height_symbol // 2
    base_half_im_width = 112
    half_im_width_symbol = image_width_symbol // 2
    base_quart_im_height = 56
    quart_im_height_symbol = (image_height_symbol // 2) // 2
    base_quart_im_width = 56
    quart_im_width_symbol = (image_width_symbol // 2) // 2
    base_eighth_im_height = 28
    eighth_im_height_symbol = ((image_height_symbol // 2) // 2) // 2
    base_eighth_im_width = 28
    eighth_im_width_symbol = ((image_width_symbol // 2) // 2) // 2
    base_sixtnth_im_height = 14
    sixtnth_im_height_symbol = (((image_height_symbol // 2) // 2) // 2) // 2
    base_sixtnth_im_width = 14
    sixtnth_im_width_symbol = (((image_width_symbol // 2) // 2) // 2) // 2
    base_small_im_height = 7
    small_im_height_symbol = (
        (((image_height_symbol // 2) // 2) // 2) // 2) // 2
    base_small_im_width = 7
    small_im_width_symbol = ((((image_width_symbol // 2) // 2) // 2) // 2) // 2

    # Set up a dictionary of placeholders and variables for which we want
    # to make dimensions symbolic. Sift out their dimensions
    bind_dict = { # Placeholders
                      'label': [subbatch_size_symbol],
                      'input': [subbatch_size_symbol, image_height_symbol, image_width_symbol, num_in_channels_symbol],
                    }
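    # Each bind_dict entry maps an op name to the symbolic shape its output
    # tensor should take; bindTensorShapeDimensions later overwrites the
    # concrete dimensions with these symbols and propagates them through
    # the graph.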
    # Parameterize all variable tensor dimensions
    for op in graph._ops_by_name.values():
        if isinstance(op, VariableOp):
            op_name_suffix = op.name.split('/')[-1]
            if op_name_suffix == 'W':
                if op._outputs[0].shape.rank == 4:
                    assert 'conv' in op.name
                    new_shape = []
                    for i in range(op._outputs[0].shape.rank):
                        new_shape.append(
                            op._outputs[0].shape.getDimension(i).value)
                    if new_shape[2] % base_feature_channels == 0:
                        in_filters = (new_shape[2] // \
                                      base_feature_channels) * \
                                     feature_channels_symbol
                    elif new_shape[2] == 3:
                        # This is the first convolution on image channels (3)
                        assert op.name == 'conv0/W'
                        in_filters = num_in_channels_symbol
                    else:
                        print('FIX ME: base in filters {}'.format(
                            new_shape[2]))
                        assert 0
                    if new_shape[3] % base_feature_channels == 0:
                        out_filters = (new_shape[3] // \
                                       base_feature_channels) * \
                                       feature_channels_symbol
                    else:
                        print('FIX ME: base out filters {}'.format(
                            new_shape[3]))
                        assert 0
                    new_shape[2] = in_filters
                    new_shape[3] = out_filters
                else:
                    # This is the output layer with output_classes dimension
                    assert op.name == 'linear/W'
                    assert op._outputs[0].shape.rank == 2
                    in_dim = op._outputs[0].shape.getDimension(0).value
                    assert in_dim % base_feature_channels == 0
                    in_dim = (in_dim // base_feature_channels) * \
                             feature_channels_symbol
                    new_shape = [in_dim, output_classes_symbol]
                bind_dict[op.name] = new_shape
                momentum_op_name = '{}/Momentum'.format(op.name)
                momentum_op = graph._ops_by_name[momentum_op_name]
                bind_dict[momentum_op.name] = new_shape
            elif op_name_suffix == 'b':
                # This is the output layer with output_classes dimension
                assert op.name == 'linear/b'
                assert op._outputs[0].shape.rank == 1
                assert op._outputs[0].shape.getDimension(0).value == \
                       base_output_classes
                new_shape = [output_classes_symbol]
                bind_dict[op.name] = new_shape
                momentum_op_name = '{}/Momentum'.format(op.name)
                momentum_op = graph._ops_by_name[momentum_op_name]
                bind_dict[momentum_op.name] = new_shape
            elif op_name_suffix in ('beta', 'gamma', 'EMA'):
                assert op._outputs[0].shape.rank == 1
                in_dim = op._outputs[0].shape.getDimension(0).value
                assert in_dim % base_feature_channels == 0
                in_dim = (in_dim // base_feature_channels) * \
                         feature_channels_symbol
                new_shape = [in_dim]
                bind_dict[op.name] = new_shape
                if op_name_suffix != 'EMA':
                    momentum_op_name = '{}/Momentum'.format(op.name)
                    momentum_op = graph._ops_by_name[momentum_op_name]
                    bind_dict[momentum_op.name] = new_shape

    # Now handle constant values in the graph
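    # Rank-1 ConstantOps often feed shape inputs (e.g., Reshape targets)
    # and carry literal channel counts. Rewrite any positive multiple of
    # the base feature-channel count as a symbolic multiple so downstream
    # shapes stay parameterized.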
    const_dict = {}
    for op in graph._ops_by_name.values():
        if isinstance(op, ConstantOp):
            if op._outputs[0].value is None:
                continue
            if op._outputs[0].shape.rank == 0:
                print('{}'.format(op.debugString()))
                continue
            assert op._outputs[0].shape.rank == 1
            values = op._outputs[0].value.tolist()
            new_values = []
            changed = False
            for value in values:
                if value > 0 and value % base_feature_channels == 0:
                    value = (value //
                             base_feature_channels) * feature_channels_symbol
                    changed = True
                new_values.append(value)
            # HACKY SPECIAL CASE:
            if op.name == 'tower0/gradients/tower0/conv0/Conv2D_grad/Const':
                assert new_values[2] == base_num_in_channels
                new_values[2] = num_in_channels_symbol
            if changed:
                const_dict[op.name] = new_values
    for const_key, const_val in const_dict.items():
        try:
            if const_key in graph.opsByName.keys():
                const_op = graph.opsByName[const_key]
                assert isinstance(const_op, ConstantOp)
                const_op._outputs[0].setValue(const_val)
            else:
                print('WARN: ConstantOp not found: {}'.format(const_key))
        except Exception as exc:
            print('WARN: ConstantOp unknown problem: {}: {}'.format(
                const_key, exc))

    # TODO (Joel): Add InputQueue ops to avoid manually setting dimensions
    in_deque_op = graph.opsByName['QueueInput/input_deque']
    # print(in_deque_op.debugString())
    out_tensor = in_deque_op._outputs[0]
    for idx, sym in enumerate([
            subbatch_size_symbol, image_height_symbol, image_width_symbol,
            num_in_channels_symbol
    ]):
        out_tensor.shape.setDimension(idx, sym)
        out_tensor.shape.dims[idx]._value = None
    out_tensor = in_deque_op._outputs[1]
    out_tensor.shape.setDimension(0, subbatch_size_symbol)

    graph.bindTensorShapeDimensions(bind_dict,
                                    warn_if_ill_defined=(not is_pytest_run),
                                    make_symbolic=True)

    # Nice little hack to actually propagate MaximumOp values to outputs
    for op in graph._ops_by_name.values():
        if isinstance(op, MaximumOp):
            if op._inputs[0].value is not None and \
               op._inputs[1].value is not None:
                vmax = np.vectorize(lambda x, y: sympy.Max(x, y))
                out_val = vmax(op._inputs[0].value, op._inputs[1].value)
                op._outputs[0].setValue(out_val)
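    # With Maximum outputs now populated, the second propagation pass
    # below lets downstream shape-consuming ops resolve symbolically.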

    graph.bindTensorShapeDimensions(bind_dict,
                                    warn_if_ill_defined=(not is_pytest_run),
                                    make_symbolic=True)

    print('Bound values')

    print(graph)

    bind_subs = {
        graph_iters_symbol: 1,
        output_classes_symbol: base_output_classes,
        subbatch_size_symbol: 32,
        image_height_symbol: base_image_height,
        image_width_symbol: base_image_width,
        num_in_channels_symbol: base_num_in_channels,
        feature_channels_symbol: base_feature_channels,
    }
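    # bind_subs is a plain sympy substitution map; any symbolic count
    # computed below is resolved with expr.subs(bind_subs). Sanity sketch
    # (not one of the original checks): under these bindings the derived
    # spatial symbols collapse to their base values, e.g. floor(224/2) == 112.
    assert half_im_height_symbol.subs(bind_subs) == base_half_im_height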

    correct_params = -1
    correct_flops = -1
    correct_bytes = -1
    correct_footprint = -1
    if depth == 18:
        correct_params = 11689514
        correct_flops = 349684163360
        correct_bytes = 7186222676
        correct_footprint = 2802084304
    elif depth == 34:
        correct_params = 21797674
        correct_flops = 705506994208
        correct_bytes = 11162578644
        correct_footprint = 4368689744
    elif depth == 50:
        correct_params = 25557034
        correct_flops = 790954958112
        correct_bytes = 32896462028
        correct_footprint = 12909734408
    elif depth == 101:
        correct_params = 44549162
        correct_flops = 1506507229472
        correct_bytes = 50026672916
        correct_footprint = 19690293072
    elif depth == 152:
        correct_params = 60192810
        correct_flops = 2222688328992
        correct_bytes = 70967716188
        correct_footprint = 27971880088
    else:
        print('WARN: Tests not defined for depth {}'.format(depth))

    # Calculate parameters
    # NOTE: Need to remove Momentum optimizer parameters and moving average values
    momentum_params = 0
    parameters = 0
    for op_name in sorted(graph.opsByName.keys()):
        op = graph.opsByName[op_name]
        if isinstance(op, VariableOp):
            if "Momentum" in op.name or "EMA" in op.name:
                momentum_params += op.calcModelParameters()
            else:
                parameters += op.calcModelParameters()

    all_weights = graph.calcModelParameters()
    assert (all_weights - momentum_params - parameters) == 0

    # Calculate model parameter count
    resolved_params = parameters.subs(bind_subs)
    try:
        resolved_params = int(resolved_params)
    except Exception:
        print('ERROR: resolved_params should be int, but is {} = {}'.format(
            type(resolved_params), resolved_params))
    assert correct_params < 0 or resolved_params == correct_params, \
           'Incorrect model params: {}'.format(resolved_params)
    print('Parameters: {}\nWith specified dims: {}\n'.format(
        parameters, resolved_params))

    # Calculate algorithmic Flops
    alg_flops = graph.calcAlgFlops()
    resolved_flops = alg_flops.subs(bind_subs)
    try:
        resolved_flops = int(resolved_flops)
    except Exception:
        print('ERROR: resolved_flops should be int, but is {} = {}'.format(
            type(resolved_flops), resolved_flops))
    assert correct_flops < 0 or resolved_flops == correct_flops, \
           'Incorrect algorithmic flops: {}'.format(resolved_flops)
    print('Algorithmic Flops: {}\nWith specified dims: {}\n'.format(
        alg_flops, resolved_flops))
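    # alg_flops stays symbolic in subbatch_size, so per-step Flops at
    # another batch size need only a different substitution (a sketch,
    # not one of the original checks):
    #   alg_flops.subs({**bind_subs, subbatch_size_symbol: 64})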

    # Calculate algorithmic Bytes accessed
    alg_bytes = graph.calcAlgBytes()
    resolved_bytes = alg_bytes.subs(bind_subs)
    try:
        resolved_bytes = int(resolved_bytes)
    except Exception:
        print('ERROR: resolved_bytes should be int, but is {} = {}'.format(
            type(resolved_bytes), resolved_bytes))
    assert correct_bytes < 0 or resolved_bytes == correct_bytes, \
           'Incorrect algorithmic bytes: {}'.format(resolved_bytes)
    print('Alg bytes accessed: {}\nWith specified dims: {}\n'.format(
        alg_bytes, resolved_bytes))

    # Calculate total memory footprint
    alg_footprint = graph.calcAlgFootprint()
    resolved_footprint = alg_footprint.subs(bind_subs)
    try:
        resolved_footprint = int(resolved_footprint)
    except Exception:
        print('ERROR: resolved_footprint should be int, but is {} = {}'.format(
            type(resolved_footprint), resolved_footprint))
    assert correct_footprint < 0 or resolved_footprint == correct_footprint, \
           'Incorrect algorithmic footprint: {}'.format(resolved_footprint)
    print('Alg mem footprint: {}\nWith specified dims: {}\n'.format(
        alg_footprint, resolved_footprint))

    # Calculate algorithmic IO per step
    total_io_footprint = 0
    for op in graph.getPlaceholders():
        total_io_footprint += op.calcAlgFootprint()
    resolved_io_footprint = total_io_footprint.subs(bind_subs)
    print('Alg IO footprint: {}\nWith specified dims: {}\n'.format(
        total_io_footprint, resolved_io_footprint))

    try:  # In case the footprint code is not complete
        # Calculate minimal memory footprint
        print('Alg min mem footprint {}'.format(
            graph.calcMinimalFootprint(symbol_subs=bind_subs)))
    except Exception:
        pass

    if not is_pytest_run:
        print('VERBOSE ALGORITHMIC FLOPS:')
        graph.calcAlgFlops(verbose=True)
        print('')

        print('VERBOSE ALGORITHMIC BYTES:')
        graph.calcAlgBytes(verbose=True)
        print('')

        print('VERBOSE ALGORITHMIC FOOTPRINT:')
        graph.calcAlgFootprint(verbose=True)
        print('')

    # HACKY WAY TO SAVE MODELS FOR NOW!
    pickle_path = ('catamount/frameworks/example_graphs/tensorflow/'
                   'full_models/image_classification/'
                   'graph_image_resnet_d{}_fs{}.p'.format(depth, filter_scale))
    with open(pickle_path, 'wb') as out_file:
        pickle.dump(graph, out_file)

    if is_pytest_run:
        return

    print('\n\n======= Algorithmic graph-level analytics: =======')

    feature_channel_dims = [32, 48, 64, 96, 128]

    bind_subs.pop(feature_channels_symbol)
    resolved_params = parameters.subs(bind_subs)

    print('Symbol associations: {}\n'.format(bind_subs))

    print(
        'Algorithmic Flops by feature channels, params, and per-batch-sample:')
    resolved_flops = alg_flops.subs(bind_subs)
    for features_dim in feature_channel_dims:
        graph_params = resolved_params.subs(
            {feature_channels_symbol: features_dim})
        graph_flops = resolved_flops.subs(
            {feature_channels_symbol: features_dim})
        graph_flops_per_sample = float(graph_flops) / \
                                 bind_subs[subbatch_size_symbol]
        print('{}\t{}\t{}\t{}'.format(features_dim, graph_params, graph_flops,
                                      int(graph_flops_per_sample)))

    print('\nAlgorithmic bytes accessed by feature channels, params:')
    resolved_bytes = alg_bytes.subs(bind_subs)
    for features_dim in feature_channel_dims:
        graph_params = resolved_params.subs(
            {feature_channels_symbol: features_dim})
        graph_bytes = resolved_bytes.subs(
            {feature_channels_symbol: features_dim})
        print('{}\t{}\t{}'.format(features_dim, graph_params, graph_bytes))

    print('\nAlgorithmic total memory footprint by feature channels, params:')
    resolved_footprint = alg_footprint.subs(bind_subs)
    for features_dim in feature_channel_dims:
        graph_params = resolved_params.subs(
            {feature_channels_symbol: features_dim})
        graph_footprint = resolved_footprint.subs(
            {feature_channels_symbol: features_dim})
        print('{}\t{}\t{}'.format(features_dim, graph_params, graph_footprint))

    print(
        '\nAlgorithmic minimal memory footprint by feature channels, params:')
    full_subs = dict(bind_subs)
    for features_dim in feature_channel_dims:
        graph_params = resolved_params.subs(
            {feature_channels_symbol: features_dim})
        full_subs[feature_channels_symbol] = features_dim
        graph_min_foot = graph.calcMinimalFootprint(symbol_subs=full_subs)
        print('{}\t{}\t{}'.format(features_dim, graph_params, graph_min_foot))


# Example 2
def run_tf_speech_attention():
    global is_pytest_run
    graph_meta = 'catamount/frameworks/example_graphs/tensorflow/full_models/speech_attention/model.ckpt.meta'

    graph = catamount.frameworks.tensorflow.import_graph(graph_meta)
    assert graph.isValid()

    # HACK: These checker ops have to be removed manually (unclear why):
    remove_ops = [
        'DevArgmaxWERChecker/Less', 'DevLossChecker/Less',
        'DevArgmaxWERChecker/best_dev', 'DevLossChecker/best_dev'
    ]
    for op_name in remove_ops:
        op = graph.opsByName[op_name]
        graph.removeOp(op)
    assert graph.isValid()

    # Remove ops that are not executed during a standard training step:
    graph_ops = list(graph._ops_by_name.values())
    for op in graph_ops:
        # Ops in attn_model_[1-3] are used for inference
        if 'attn_model_1' in op.name or \
           'attn_model_2' in op.name or \
           'attn_model_3' in op.name:
            graph.removeOp(op)
    assert graph.isValid()

    print('Initial graph:\n{}\n'.format(graph))
    init_params = graph.calcModelParameters()
    print('Initial parameters: {}'.format(init_params))
    print('Initial Flops: {}\n'.format(graph.calcAlgFlops()))

    print('Placeholders:')
    for op in graph.getPlaceholders():
        print(op.debugString())
    print('')

    # Set up symbols to name dimensions
    audio_features_symbol = utils.getPositiveIntSymbolFromString(
        'audio_features')
    encoder_steps_symbol = utils.getPositiveIntSymbolFromString(
        'encoder_steps')
    decoder_steps_symbol = utils.getPositiveIntSymbolFromString(
        'decoder_steps')
    subbatch_size_symbol = utils.getPositiveIntSymbolFromString(
        'subbatch_size')
    attn_dim_symbol = utils.getPositiveIntSymbolFromString('attn_dim')
    attn_hidden_dim_symbol = utils.getPositiveIntSymbolFromString(
        'attn_hidden_dim')
    dec_hidden_dim_symbol = utils.getPositiveIntSymbolFromString(
        'dec_hidden_dim')
    enc_hidden_dim_symbol = utils.getPositiveIntSymbolFromString(
        'enc_hidden_dim')
    graph_iters_symbol = utils.getIntSymbolFromString('graph::iters')
    output_vocab_symbol = utils.getPositiveIntSymbolFromString('output_vocab')
    conv_width_symbol = utils.getPositiveIntSymbolFromString('conv_width')
    num_conv_filters_symbol = utils.getPositiveIntSymbolFromString(
        'num_conv_filters')

    # Convert these constant dimensions to symbols
    base_encoder_steps = 300
    base_decoder_steps = 300
    base_subbatch_size = 32
    base_output_vocab = 31
    base_audio_features = 40
    base_conv_width = 53
    base_attn_dim = 137
    base_attn_hidden_dim = 509
    base_dec_hidden_dim = 571
    base_enc_hidden_dim = 1051
    base_enc_input_dim = 1091  # Input + recurrent state
    enc_input_dim_symbol = audio_features_symbol + enc_hidden_dim_symbol
    base_dec_attn_rec = 2133
    dec_attn_rec_symbol = 2 * enc_hidden_dim_symbol + output_vocab_symbol
    base_attn_cell_inputs = 2611
    attn_cell_inputs_symbol = 2 * enc_hidden_dim_symbol + attn_hidden_dim_symbol
    base_attn_cell_in_dim = 2642
    attn_cell_in_dim_symbol = 2 * enc_hidden_dim_symbol + output_vocab_symbol + \
                              attn_hidden_dim_symbol
    base_dec_attn_dim = 3182
    dec_attn_dim_symbol = attn_hidden_dim_symbol + 2 * enc_hidden_dim_symbol + \
                          dec_hidden_dim_symbol
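    # The base values above are the concrete dims in the imported graph;
    # each matches its symbolic decomposition, e.g.:
    #   1091 = 40 + 1051            (enc_input_dim)
    #   2133 = 2*1051 + 31          (dec_attn_rec)
    #   2611 = 2*1051 + 509         (attn_cell_inputs)
    #   2642 = 2*1051 + 31 + 509    (attn_cell_in_dim)
    #   3182 = 509 + 2*1051 + 571   (dec_attn_dim)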

    bind_dict = { # Placeholders
                  'attn_model/input_seq': [encoder_steps_symbol, subbatch_size_symbol, audio_features_symbol],
                  'attn_model/input_len': [subbatch_size_symbol],
                  'attn_model/output_seq': [decoder_steps_symbol, subbatch_size_symbol],
                  'attn_model/output_mask': [decoder_steps_symbol, subbatch_size_symbol],

                  # Variables
                  'InputNormalizer/means': [audio_features_symbol],
                  'InputNormalizer/std': [audio_features_symbol],
                  'attn_model/AffineAttentionStateNN/W': [2 * enc_hidden_dim_symbol, attn_dim_symbol],
                  'attn_model/AffineAttentionStateNN/b': [attn_dim_symbol],
                  'attn_model/AffineOutputProjection/W': [dec_hidden_dim_symbol, output_vocab_symbol],
                  'attn_model/AffineOutputProjection/b': [output_vocab_symbol],
                  'attn_model/Decoder/attn_model/attention_cell/biases': [4 * attn_hidden_dim_symbol],
                  'attn_model/Decoder/attn_model/attention_cell/weights': [attn_hidden_dim_symbol + 2 * enc_hidden_dim_symbol + output_vocab_symbol, 4 * attn_hidden_dim_symbol],
                  'attn_model/Decoder/attn_model/decoder_cell/biases': [4 * dec_hidden_dim_symbol],
                  'attn_model/Decoder/attn_model/decoder_cell/weights': [attn_hidden_dim_symbol + dec_hidden_dim_symbol + 2 * enc_hidden_dim_symbol, 4 * dec_hidden_dim_symbol],
                  'attn_model/HybridAttentionContext/Q': [conv_width_symbol, 1, num_conv_filters_symbol],
                  'attn_model/HybridAttentionContext/U': [1, num_conv_filters_symbol, attn_dim_symbol],
                  'attn_model/HybridAttentionContext/W': [2 * attn_hidden_dim_symbol, attn_dim_symbol],
                  'attn_model/HybridAttentionContext/b': [attn_dim_symbol],
                  'attn_model/HybridAttentionContext/w': [attn_dim_symbol],
                  'attn_model/StackedEncoder/Layer0/RNNEncoder/bidirectional_rnn/bw/basic_lstm_cell/bias': [4 * enc_hidden_dim_symbol],
                  'attn_model/StackedEncoder/Layer0/RNNEncoder/bidirectional_rnn/bw/basic_lstm_cell/kernel': [audio_features_symbol + enc_hidden_dim_symbol, 4 * enc_hidden_dim_symbol],
                  'attn_model/StackedEncoder/Layer0/RNNEncoder/bidirectional_rnn/fw/basic_lstm_cell/bias': [4 * enc_hidden_dim_symbol],
                  'attn_model/StackedEncoder/Layer0/RNNEncoder/bidirectional_rnn/fw/basic_lstm_cell/kernel': [audio_features_symbol + enc_hidden_dim_symbol, 4 * enc_hidden_dim_symbol],
                  'attn_model/StackedEncoder/Layer2/RNNEncoder/bidirectional_rnn/bw/basic_lstm_cell/bias': [4 * enc_hidden_dim_symbol],
                  'attn_model/StackedEncoder/Layer2/RNNEncoder/bidirectional_rnn/bw/basic_lstm_cell/kernel': [3 * enc_hidden_dim_symbol, 4 * enc_hidden_dim_symbol],
                  'attn_model/StackedEncoder/Layer2/RNNEncoder/bidirectional_rnn/fw/basic_lstm_cell/bias': [4 * enc_hidden_dim_symbol],
                  'attn_model/StackedEncoder/Layer2/RNNEncoder/bidirectional_rnn/fw/basic_lstm_cell/kernel': [3 * enc_hidden_dim_symbol, 4 * enc_hidden_dim_symbol],
                  'attn_model/StackedEncoder/Layer4/RNNEncoder/bidirectional_rnn/bw/basic_lstm_cell/bias': [4 * enc_hidden_dim_symbol],
                  'attn_model/StackedEncoder/Layer4/RNNEncoder/bidirectional_rnn/bw/basic_lstm_cell/kernel': [3 * enc_hidden_dim_symbol, 4 * enc_hidden_dim_symbol],
                  'attn_model/StackedEncoder/Layer4/RNNEncoder/bidirectional_rnn/fw/basic_lstm_cell/bias': [4 * enc_hidden_dim_symbol],
                  'attn_model/StackedEncoder/Layer4/RNNEncoder/bidirectional_rnn/fw/basic_lstm_cell/kernel': [3 * enc_hidden_dim_symbol, 4 * enc_hidden_dim_symbol],

                  # Constants
                  'attn_model/AttentionModel/gradients/attn_model/Decoder/while/MatMul/Enter_grad/b_acc': [dec_hidden_dim_symbol, output_vocab_symbol],
                  'attn_model/AttentionModel/gradients/attn_model/Decoder/while/add/Enter_grad/b_acc': [output_vocab_symbol],
                  'attn_model/AttentionModel/gradients/attn_model/Decoder/while/attn_model/MatMul/Enter_grad/b_acc': [2 * attn_hidden_dim_symbol, attn_dim_symbol],
                  'attn_model/AttentionModel/gradients/attn_model/Decoder/while/attn_model/add_2/Enter_grad/b_acc': [attn_dim_symbol],
                  'attn_model/AttentionModel/gradients/attn_model/Decoder/while/attn_model/attention_cell/BiasAdd/Enter_grad/b_acc': [4 * attn_hidden_dim_symbol],
                  'attn_model/AttentionModel/gradients/attn_model/Decoder/while/attn_model/attention_cell/attention_cell/add/Enter_grad/b_acc': [attn_hidden_dim_symbol + 2 * enc_hidden_dim_symbol + output_vocab_symbol, 4 * attn_hidden_dim_symbol],
                  'attn_model/AttentionModel/gradients/attn_model/Decoder/while/attn_model/conv1d/ExpandDims_1/Enter_grad/b_acc': [conv_width_symbol, 1, 4],
                  'attn_model/AttentionModel/gradients/attn_model/Decoder/while/attn_model/conv1d_1/ExpandDims_1/Enter_grad/b_acc': [1, 4, attn_dim_symbol],
                  'attn_model/AttentionModel/gradients/attn_model/Decoder/while/attn_model/decoder_cell/BiasAdd/Enter_grad/b_acc': [4 * dec_hidden_dim_symbol],
                  'attn_model/AttentionModel/gradients/attn_model/Decoder/while/attn_model/decoder_cell/decoder_cell/add/Enter_grad/b_acc': [attn_hidden_dim_symbol + dec_hidden_dim_symbol + 2 * enc_hidden_dim_symbol, 4 * dec_hidden_dim_symbol],
                  'attn_model/AttentionModel/gradients/attn_model/Decoder/while/attn_model/mul/Enter_grad/b_acc': [attn_dim_symbol],
                  'attn_model/AttentionModel/gradients/attn_model/StackedEncoder/Layer0/RNNEncoder/bidirectional_rnn/bw/bw/while/basic_lstm_cell/BiasAdd/Enter_grad/b_acc': [4 * enc_hidden_dim_symbol],
                  'attn_model/AttentionModel/gradients/attn_model/StackedEncoder/Layer0/RNNEncoder/bidirectional_rnn/bw/bw/while/basic_lstm_cell/MatMul/Enter_grad/b_acc': [audio_features_symbol + enc_hidden_dim_symbol, 4 * enc_hidden_dim_symbol],
                  'attn_model/AttentionModel/gradients/attn_model/StackedEncoder/Layer0/RNNEncoder/bidirectional_rnn/fw/fw/while/basic_lstm_cell/BiasAdd/Enter_grad/b_acc': [4 * enc_hidden_dim_symbol],
                  'attn_model/AttentionModel/gradients/attn_model/StackedEncoder/Layer0/RNNEncoder/bidirectional_rnn/fw/fw/while/basic_lstm_cell/MatMul/Enter_grad/b_acc': [audio_features_symbol + enc_hidden_dim_symbol, 4 * enc_hidden_dim_symbol],
                  'attn_model/AttentionModel/gradients/attn_model/StackedEncoder/Layer2/RNNEncoder/bidirectional_rnn/bw/bw/while/basic_lstm_cell/BiasAdd/Enter_grad/b_acc': [4 * enc_hidden_dim_symbol],
                  'attn_model/AttentionModel/gradients/attn_model/StackedEncoder/Layer2/RNNEncoder/bidirectional_rnn/bw/bw/while/basic_lstm_cell/MatMul/Enter_grad/b_acc': [3 * enc_hidden_dim_symbol, 4 * enc_hidden_dim_symbol],
                  'attn_model/AttentionModel/gradients/attn_model/StackedEncoder/Layer2/RNNEncoder/bidirectional_rnn/fw/fw/while/basic_lstm_cell/BiasAdd/Enter_grad/b_acc': [4 * enc_hidden_dim_symbol],
                  'attn_model/AttentionModel/gradients/attn_model/StackedEncoder/Layer2/RNNEncoder/bidirectional_rnn/fw/fw/while/basic_lstm_cell/MatMul/Enter_grad/b_acc': [3 * enc_hidden_dim_symbol, 4 * enc_hidden_dim_symbol],
                  'attn_model/AttentionModel/gradients/attn_model/StackedEncoder/Layer4/RNNEncoder/bidirectional_rnn/bw/bw/while/basic_lstm_cell/BiasAdd/Enter_grad/b_acc': [4 * enc_hidden_dim_symbol],
                  'attn_model/AttentionModel/gradients/attn_model/StackedEncoder/Layer4/RNNEncoder/bidirectional_rnn/bw/bw/while/basic_lstm_cell/MatMul/Enter_grad/b_acc': [3 * enc_hidden_dim_symbol, 4 * enc_hidden_dim_symbol],
                  'attn_model/AttentionModel/gradients/attn_model/StackedEncoder/Layer4/RNNEncoder/bidirectional_rnn/fw/fw/while/basic_lstm_cell/BiasAdd/Enter_grad/b_acc': [4 * enc_hidden_dim_symbol],
                  'attn_model/AttentionModel/gradients/attn_model/StackedEncoder/Layer4/RNNEncoder/bidirectional_rnn/fw/fw/while/basic_lstm_cell/MatMul/Enter_grad/b_acc': [3 * enc_hidden_dim_symbol, 4 * enc_hidden_dim_symbol],
                }
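    # LSTM kernel shapes follow TensorFlow's [input_dim + hidden_dim,
    # 4 * hidden_dim] convention (four gates). Layer0 sees the raw audio
    # features; Layers 2 and 4 consume the concatenated fw/bw outputs of
    # the previous layer (2 * enc_hidden), hence the 3 * enc_hidden rows.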

    # Update constant values
    const_dict = {
        'attn_model/AffineAttentionStateNN/Reshape/shape':
        [-1, 2 * enc_hidden_dim_symbol],
        'attn_model/AffineAttentionStateNN/Reshape_1/shape/2':
        attn_dim_symbol,
        'attn_model/AttentionEncoderDecoder/Reshape/shape/1':
        output_vocab_symbol,
        'attn_model/AttentionModel/gradients/attn_model/AffineAttentionStateNN/add_grad/Shape_1':
        [attn_dim_symbol],
        'attn_model/AttentionModel/gradients/attn_model/Decoder/while/add_grad/Shape_1':
        [output_vocab_symbol],
        'attn_model/AttentionModel/gradients/attn_model/Decoder/while/attn_model/add_2_grad/Shape_1':
        [attn_dim_symbol],
        'attn_model/AttentionModel/gradients/attn_model/Decoder/while/attn_model/conv1d/Conv2D_grad/Const':
        [1, conv_width_symbol, 1, num_conv_filters_symbol],
        'attn_model/AttentionModel/gradients/attn_model/Decoder/while/attn_model/conv1d/ExpandDims_1_grad/Shape':
        [conv_width_symbol, 1, num_conv_filters_symbol],
        'attn_model/AttentionModel/gradients/attn_model/Decoder/while/attn_model/conv1d_1/Conv2D_grad/Const':
        [1, 1, num_conv_filters_symbol, attn_dim_symbol],
        'attn_model/AttentionModel/gradients/attn_model/Decoder/while/attn_model/conv1d_1/ExpandDims_1_grad/Shape':
        [1, num_conv_filters_symbol, attn_dim_symbol],
        'attn_model/AttentionModel/gradients/attn_model/Decoder/while/attn_model/mul_grad/Shape_1':
        [attn_dim_symbol],
        'attn_model/Decoder/CustomLSTMCellZeroState/Const':
        [2 * attn_hidden_dim_symbol],
        'attn_model/Decoder/CustomLSTMCellZeroState/Const_1':
        [2 * attn_hidden_dim_symbol],
        'attn_model/Decoder/CustomLSTMCellZeroState_1/Const': [
            2 * dec_hidden_dim_symbol
        ],
        'attn_model/Decoder/CustomLSTMCellZeroState_1/Const_1': [
            2 * dec_hidden_dim_symbol
        ],
        'attn_model/Decoder/while/attn_model/attention_cell/attention_cell/Shape':
        [
            attn_hidden_dim_symbol + 2 * enc_hidden_dim_symbol +
            output_vocab_symbol, 4 * attn_hidden_dim_symbol
        ],
        'attn_model/Decoder/while/attn_model/decoder_cell/decoder_cell/Shape':
        [
            attn_hidden_dim_symbol + dec_hidden_dim_symbol +
            2 * enc_hidden_dim_symbol, 4 * dec_hidden_dim_symbol
        ],
        'attn_model/Decoder/while/attn_model/one_hot/depth':
        output_vocab_symbol,
        'attn_model/Decoder/zeros/shape/1':
        2 * enc_hidden_dim_symbol,
        'attn_model/Decoder/zeros_2/shape/1':
        output_vocab_symbol,
        'attn_model/Reshape/shape': [1, 1, audio_features_symbol],
        'attn_model/Reshape_1/shape': [1, 1, audio_features_symbol],
        'attn_model/Reshape_2/shape/2':
        2 * enc_hidden_dim_symbol,
        'attn_model/StackedEncoder/Layer0/RNNEncoder/Reshape/shape/2':
        audio_features_symbol,
        'attn_model/StackedEncoder/Layer0/RNNEncoder/bidirectional_rnn/bw/bw/BasicLSTMCellZeroState/Const':
        [2 * enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer0/RNNEncoder/bidirectional_rnn/bw/bw/BasicLSTMCellZeroState/Const_1':
        [2 * enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer0/RNNEncoder/bidirectional_rnn/bw/bw/Const_1':
        [enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer0/RNNEncoder/bidirectional_rnn/bw/bw/Const_4':
        [enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer0/RNNEncoder/bidirectional_rnn/fw/fw/BasicLSTMCellZeroState/Const':
        [2 * enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer0/RNNEncoder/bidirectional_rnn/fw/fw/BasicLSTMCellZeroState/Const_1':
        [2 * enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer0/RNNEncoder/bidirectional_rnn/fw/fw/Const_1':
        [enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer0/RNNEncoder/bidirectional_rnn/fw/fw/Const_4':
        [enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer2/RNNEncoder/Reshape/shape/2':
        2 * enc_hidden_dim_symbol,
        'attn_model/StackedEncoder/Layer2/RNNEncoder/bidirectional_rnn/bw/bw/BasicLSTMCellZeroState/Const':
        [2 * enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer2/RNNEncoder/bidirectional_rnn/bw/bw/BasicLSTMCellZeroState/Const_1':
        [2 * enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer2/RNNEncoder/bidirectional_rnn/bw/bw/Const_1':
        [enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer2/RNNEncoder/bidirectional_rnn/bw/bw/Const_4':
        [enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer2/RNNEncoder/bidirectional_rnn/fw/fw/BasicLSTMCellZeroState/Const':
        [2 * enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer2/RNNEncoder/bidirectional_rnn/fw/fw/BasicLSTMCellZeroState/Const_1':
        [2 * enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer2/RNNEncoder/bidirectional_rnn/fw/fw/Const_1':
        [enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer2/RNNEncoder/bidirectional_rnn/fw/fw/Const_4':
        [enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer4/RNNEncoder/Reshape/shape/2':
        2 * enc_hidden_dim_symbol,
        'attn_model/StackedEncoder/Layer4/RNNEncoder/bidirectional_rnn/bw/bw/BasicLSTMCellZeroState/Const':
        [2 * enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer4/RNNEncoder/bidirectional_rnn/bw/bw/BasicLSTMCellZeroState/Const_1':
        [2 * enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer4/RNNEncoder/bidirectional_rnn/bw/bw/Const_1':
        [enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer4/RNNEncoder/bidirectional_rnn/bw/bw/Const_4':
        [enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer4/RNNEncoder/bidirectional_rnn/fw/fw/BasicLSTMCellZeroState/Const':
        [2 * enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer4/RNNEncoder/bidirectional_rnn/fw/fw/BasicLSTMCellZeroState/Const_1':
        [2 * enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer4/RNNEncoder/bidirectional_rnn/fw/fw/Const_1':
        [enc_hidden_dim_symbol],
        'attn_model/StackedEncoder/Layer4/RNNEncoder/bidirectional_rnn/fw/fw/Const_4':
        [enc_hidden_dim_symbol],
    }
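    # bindConstantValues overwrites the values of these shape-feeding
    # ConstantOps so that downstream Reshape and ZeroState ops produce
    # symbolic shapes instead of the baked-in concrete dims.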

    graph.bindConstantValues(const_dict)

    # TODO: Catamount doesn't yet automatically handle TensorFlow TensorArray
    # or Stack ops. Here, manually set the dimensions of these ops' tensors.
    for op in graph._ops_by_name.values():
        op_name_suffix = op.name.split('/')[-1]
        if 'TensorArrayGather' in op_name_suffix:
            assert isinstance(op, UnknownOp)
            assert len(op._inputs) == 3
            assert len(op._outputs) == 1
            if op._outputs[0].shape.rank == 1 or op._outputs[0].shape.rank == 2:
                if len(op._outputs[0].consumers) > 0:
                    print(
                        'TODO: Unknown TensorArrayGather (rank {}): {}'.format(
                            op._outputs[0].shape.rank, op.debugString()))
            elif op._outputs[0].shape.isUnknown() or \
                    op._outputs[0].shape.rank == 3:
                if len(op._outputs[0].consumers) > 0:
                    # If output rank is 3, then appears to be:
                    # [seq_length, batch_size, enc_hid], where
                    # seq_length depends on layer
                    out_shape = None
                    if 'StackedEncoder/Layer0' in op.name:
                        out_shape = [
                            encoder_steps_symbol, subbatch_size_symbol,
                            enc_hidden_dim_symbol
                        ]
                    elif 'StackedEncoder/Layer2' in op.name:
                        if 'attn_model/AttentionModel/gradients' in op.name:
                            # Backprop stores concatenated state
                            out_shape = [
                                encoder_steps_symbol // 2,
                                subbatch_size_symbol, 2 * enc_hidden_dim_symbol
                            ]
                        else:
                            out_shape = [
                                encoder_steps_symbol // 2,
                                subbatch_size_symbol, enc_hidden_dim_symbol
                            ]
                    elif 'StackedEncoder/Layer4' in op.name:
                        if 'attn_model/AttentionModel/gradients' in op.name:
                            # Backprop stores concatenated state
                            out_shape = [(encoder_steps_symbol // 2) // 2,
                                         subbatch_size_symbol,
                                         2 * enc_hidden_dim_symbol]
                        else:
                            out_shape = [(encoder_steps_symbol // 2) // 2,
                                         subbatch_size_symbol,
                                         enc_hidden_dim_symbol]
                    elif 'Decoder' in op.name:
                        # HACK: Manually specify a few known shapes
                        if op.name == 'attn_model/Decoder/TensorArrayStack/TensorArrayGatherV3':
                            out_shape = [
                                decoder_steps_symbol, subbatch_size_symbol,
                                output_vocab_symbol
                            ]
                        else:
                            out_shape = [
                                decoder_steps_symbol, subbatch_size_symbol,
                                dec_hidden_dim_symbol
                            ]
                    else:
                        print('TODO: Unknown TensorArrayGather {}'.format(
                            op.debugString()))
                    if out_shape is not None:
                        op._outputs[0].mergeShape(out_shape,
                                                  make_symbolic=True)
            else:
                print('TODO: Unknown TensorArrayGather {}'.format(
                    op.debugString()))
        elif 'TensorArraySize' in op_name_suffix:
            assert isinstance(op, UnknownOp)
            assert len(op._inputs) == 2
            assert len(op._outputs) == 1
            assert op._outputs[0].shape.rank == 0
            # NOTES:
            # StackedEncoder Layer0: enc_seq
            # StackedEncoder Layer2: enc_seq / 2 # Due to stride 2 in time
            # StackedEncoder Layer4: enc_seq / 4 # Due to stride 2 in time
            # Decoder: dec_seq
            if 'StackedEncoder/Layer0' in op.name:
                op._outputs[0].setValue(encoder_steps_symbol)
            elif 'StackedEncoder/Layer2' in op.name:
                op._outputs[0].setValue(encoder_steps_symbol // 2)
            elif 'StackedEncoder/Layer4' in op.name:
                op._outputs[0].setValue((encoder_steps_symbol // 2) // 2)
            elif 'Decoder' in op.name:
                op._outputs[0].setValue(decoder_steps_symbol)
            else:
                print('WARN: Unknown TensorArraySizeV3: {}'.format(
                    op.debugString()))
        elif 'TensorArrayRead' in op_name_suffix:
            assert isinstance(op, UnknownOp)
            assert len(op._inputs) == 3
            assert len(op._outputs) == 1
            assert op._outputs[0].shape.isUnknown() or \
                   op._outputs[0].shape.rank == 2, \
                   '{}'.format(op.name)
            if op._outputs[0].shape.isUnknown():
                if len(op._outputs[0].consumers) > 0:
                    out_shape = None
                    if 'attn_model/AttentionModel/gradients/attn_model/StackedEncoder/Layer' in op.name and \
                       ('/RNNEncoder/bidirectional_rnn/fw/fw/while/TensorArrayWrite/TensorArrayWriteV3_grad/TensorArrayReadV3' in op.name or \
                        '/RNNEncoder/bidirectional_rnn/bw/bw/while/TensorArrayWrite/TensorArrayWriteV3_grad/TensorArrayReadV3' in op.name):
                        out_shape = [
                            subbatch_size_symbol, enc_hidden_dim_symbol
                        ]
                    elif op.name == 'attn_model/AttentionModel/gradients/attn_model/Decoder/TensorArrayWrite/TensorArrayWriteV3_grad/TensorArrayReadV3' or \
                         op.name == 'attn_model/AttentionModel/gradients/attn_model/Decoder/while/TensorArrayWrite_1/TensorArrayWriteV3_grad/TensorArrayReadV3' or \
                         op.name == 'attn_model_2/Decoder/while/cond/TensorArrayReadV3' or \
                         op.name == 'attn_model/Decoder/while/cond/TensorArrayReadV3':
                        out_shape = [subbatch_size_symbol, output_vocab_symbol]
                    else:
                        print('WARN: Unknown TensorArrayReadV3 out shape: {}'.
                              format(op.debugString()))
                    if out_shape is not None:
                        op._outputs[0].mergeShape(out_shape,
                                                  make_symbolic=True)
            else:
                # NOTES: Many are (?, 40 "features"), (?, 1051 "enc_hid"), or (?, 2102 "2*enc_hid")
                dim_1_val = op._outputs[0].shape.getDimension(1).value
                assert dim_1_val == base_audio_features or \
                       dim_1_val == base_enc_hidden_dim or \
                       dim_1_val == 2 * base_enc_hidden_dim, \
                       'Op: {}\n   Dim 1 value: {}'.format(op.debugString(), dim_1_val)
                out_shape = None
                if dim_1_val == base_audio_features:
                    out_shape = [subbatch_size_symbol, audio_features_symbol]
                elif dim_1_val > 0 and dim_1_val % base_enc_hidden_dim == 0:
                    mult = dim_1_val // base_enc_hidden_dim
                    out_shape = [
                        subbatch_size_symbol, mult * enc_hidden_dim_symbol
                    ]
                else:
                    print('Unhandled TensorArrayRead: {}'.format(
                        op.debugString()))
                if out_shape is not None:
                    op._outputs[0].mergeShape(out_shape, make_symbolic=True)

    # Manually set a couple shapes for max ops that can't yet resolve
    # maximums of 1 vs. positive symbols:
    max_op = graph._ops_by_name[
        'attn_model/AttentionModel/gradients/attn_model/AttentionEncoderDecoder/Sum_grad/Maximum']
    max_op._outputs[0].mergeShape([2])
    max_op._outputs[0].setValue([1, subbatch_size_symbol])

    max_op = graph._ops_by_name[
        'attn_model/AttentionModel/gradients/attn_model/Decoder/while/attn_model/Sum_grad/Maximum']
    max_op._outputs[0].mergeShape([3])
    # [floor(floor(encoder_steps/2)/2) subbatch_size 1]
    max_op._outputs[0].setValue([(encoder_steps_symbol // 2) // 2,
                                 subbatch_size_symbol, 1])

    max_op = graph._ops_by_name[
        'attn_model/AttentionModel/gradients/attn_model/Decoder/while/attn_model/Sum_1_grad/Maximum']
    max_op._outputs[0].mergeShape([3])
    # [1 subbatch_size 2*enc_hidden_dim]
    max_op._outputs[0].setValue(
        [1, subbatch_size_symbol, 2 * enc_hidden_dim_symbol])

    print('Binding variables')

    graph.bindShapesAndPropagate(bind_dict,
                                 warn_if_ill_defined=(not is_pytest_run),
                                 make_symbolic=True)
    assert graph.isValid()

    print('\n\nCleaned Graph:\n{}'.format(graph))

    print('\n\nBound values')

    # Set base values to be subbed in:
    base_encoder_steps = 96
    base_decoder_steps = 24
    base_attn_dim = 128
    base_conv_width = 50
    base_attn_hidden_dim = 512
    base_dec_hidden_dim = 512
    base_enc_hidden_dim = 1024

    bind_subs = {
        audio_features_symbol: base_audio_features,
        encoder_steps_symbol: base_encoder_steps,
        decoder_steps_symbol: (encoder_steps_symbol // 2) // 2,
        subbatch_size_symbol: base_subbatch_size,
        attn_dim_symbol: base_attn_dim,
        attn_hidden_dim_symbol: enc_hidden_dim_symbol // 2,
        dec_hidden_dim_symbol: enc_hidden_dim_symbol // 2,
        output_vocab_symbol: base_output_vocab,
        conv_width_symbol: base_conv_width,
        enc_hidden_dim_symbol: base_enc_hidden_dim,
        num_conv_filters_symbol: 4,
        graph_iters_symbol: 1,
    }
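    # Note the attention and decoder hidden dims are tied to
    # enc_hidden_dim // 2 rather than their base values, so the sweep
    # over enc_hidden_dim at the end scales all three together.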
    # Add loop iteration counts to bind_subs
    bind_str_subs = {
        'attn_model/AttentionModel/gradients/b_count_2_block::iters':
        decoder_steps_symbol,
        'attn_model/Decoder/while/LoopCond_block::iters':
        decoder_steps_symbol,
        'attn_model/AttentionModel/gradients/b_count_22_block::iters':
        encoder_steps_symbol,
        'attn_model/AttentionModel/gradients/b_count_26_block::iters':
        encoder_steps_symbol,
        'attn_model/StackedEncoder/Layer0/RNNEncoder/bidirectional_rnn/bw/bw/while/LoopCond_block::iters':
        encoder_steps_symbol,
        'attn_model/StackedEncoder/Layer0/RNNEncoder/bidirectional_rnn/fw/fw/while/LoopCond_block::iters':
        encoder_steps_symbol,
        'attn_model/AttentionModel/gradients/b_count_14_block::iters':
        encoder_steps_symbol // 2,
        'attn_model/AttentionModel/gradients/b_count_18_block::iters':
        encoder_steps_symbol // 2,
        'attn_model/StackedEncoder/Layer2/RNNEncoder/bidirectional_rnn/bw/bw/while/LoopCond_block::iters':
        encoder_steps_symbol // 2,
        'attn_model/StackedEncoder/Layer2/RNNEncoder/bidirectional_rnn/fw/fw/while/LoopCond_block::iters':
        encoder_steps_symbol // 2,
        'attn_model/AttentionModel/gradients/b_count_6_block::iters':
        (encoder_steps_symbol // 2) // 2,
        'attn_model/AttentionModel/gradients/b_count_10_block::iters':
        (encoder_steps_symbol // 2) // 2,
        'attn_model/StackedEncoder/Layer4/RNNEncoder/bidirectional_rnn/bw/bw/while/LoopCond_block::iters':
        (encoder_steps_symbol // 2) // 2,
        'attn_model/StackedEncoder/Layer4/RNNEncoder/bidirectional_rnn/fw/fw/while/LoopCond_block::iters':
        (encoder_steps_symbol // 2) // 2,
    }
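    # Catamount names each while-loop's trip count '<block>::iters'.
    # Encoder Layers 2 and 4 run at half and quarter the sequence length
    # (stride 2 in time, per the NOTES above), and each has matching
    # backprop counter blocks (b_count_*).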

    for var_name, sub_val in bind_str_subs.items():
        var_ref = utils.getIntSymbolFromString(var_name)
        assert var_name not in bind_subs.keys()
        bind_subs[var_ref] = sub_val

    # Calculate model parameter count
    parameters = graph.calcModelParameters()
    resolved_params = parameters.subs(bind_subs)
    try:
        resolved_params = int(resolved_params)
    except Exception:
        print('ERROR: resolved_params should be int, but is {} = {}'.format(
            type(resolved_params), resolved_params))
    correct_params = 71084729
    assert resolved_params == correct_params, \
           'Incorrect model params: {}'.format(resolved_params)
    print('Parameters: {}\nWith specified dims: {}\n'.format(
        parameters, resolved_params))

    # Calculate algorithmic Flops
    alg_flops = graph.calcAlgFlops()
    resolved_flops = alg_flops.subs(bind_subs)
    try:
        resolved_flops = int(resolved_flops)
    except Exception:
        print('ERROR: resolved_flops should be int, but is {} = {}'.format(
            type(resolved_flops), resolved_flops))
    correct_flops = 568878183032
    assert resolved_flops == correct_flops, \
           'Incorrect algorithmic flops: {}'.format(resolved_flops)
    print('Algorithmic Flops: {}\nWith specified dims: {}\n'.format(
        alg_flops, resolved_flops))

    # Calculate algorithmic Bytes accessed
    alg_bytes = graph.calcAlgBytes()
    resolved_bytes = alg_bytes.subs(bind_subs)
    try:
        resolved_bytes = int(resolved_bytes)
    except Exception:
        print('ERROR: resolved_bytes should be int, but is {} = {}'.format(
            type(resolved_bytes), resolved_bytes))
    correct_bytes = 92231419797
    assert resolved_bytes == correct_bytes, \
           'Incorrect algorithmic bytes: {}'.format(resolved_bytes)
    print('Alg bytes accessed: {}\nWith specified dims: {}\n'.format(
        alg_bytes, resolved_bytes))

    # Calculate total memory footprint
    alg_footprint = graph.calcAlgFootprint()
    resolved_footprint = alg_footprint.subs(bind_subs)
    try:
        resolved_footprint = int(resolved_footprint)
    except Exception:
        print('ERROR: resolved_footprint should be int, but is {} = {}'.format(
            type(resolved_footprint), resolved_footprint))
    correct_footprint = 32624988214
    assert resolved_footprint == correct_footprint, \
           'Incorrect algorithmic footprint: {}'.format(resolved_footprint)
    print('Alg mem footprint: {}\nWith specified dims: {}\n'.format(
        alg_footprint, resolved_footprint))

    # Calculate algorithmic IO per step
    total_io_footprint = 0
    for op in graph.getPlaceholders():
        total_io_footprint += op.calcAlgFootprint()
    resolved_io_footprint = total_io_footprint.subs(bind_subs)
    print('Alg IO footprint: {}\nWith specified dims: {}\n'.format(
        total_io_footprint, resolved_io_footprint))

    print('VERBOSE ALGORITHMIC FLOPS:')
    graph.calcAlgFlops(verbose=True)
    print('')

    print('VERBOSE ALGORITHMIC BYTES:')
    graph.calcAlgBytes(verbose=True)
    print('')

    print('VERBOSE ALGORITHMIC FOOTPRINT:')
    graph.calcAlgFootprint(verbose=True)
    print('')

    # HACKY WAY TO SAVE MODELS FOR NOW!
    pickle_path = ('catamount/frameworks/example_graphs/tensorflow/'
                   'full_models/speech_attention/graph_speech_attention.p')
    with open(pickle_path, 'wb') as out_file:
        pickle.dump(graph, out_file)

    if is_pytest_run:
        return

    print('\n\n======= Algorithmic graph-level analytics: =======')

    encoder_dims = [
        32, 64, 96, 128, 160, 192, 256, 320, 384, 448, 512, 640, 768, 892,
        1024, 1152, 1280, 1408, 1548, 1702, 1872, 2059, 2264, 2490, 2739, 3012,
        3289
    ]
    base_encoder_steps = 335
    base_subbatch_size = 32
    base_attn_dim = 128
    base_conv_width = 50
    base_attn_hidden_dim = 512
    base_dec_hidden_dim = 512
    base_enc_hidden_dim = 1024

    bind_subs[audio_features_symbol] = base_audio_features
    bind_subs[encoder_steps_symbol] = base_encoder_steps
    bind_subs[decoder_steps_symbol] = (encoder_steps_symbol // 2) // 2
    bind_subs[subbatch_size_symbol] = base_subbatch_size
    bind_subs[attn_dim_symbol] = base_attn_dim
    bind_subs[attn_hidden_dim_symbol] = enc_hidden_dim_symbol // 2
    bind_subs[dec_hidden_dim_symbol] = enc_hidden_dim_symbol // 2
    bind_subs[output_vocab_symbol] = base_output_vocab
    bind_subs[conv_width_symbol] = base_conv_width
    # bind_subs[enc_hidden_dim_symbol] = base_enc_hidden_dim
    bind_subs[num_conv_filters_symbol] = 4
    bind_subs[graph_iters_symbol] = 1

    bind_subs.pop(enc_hidden_dim_symbol)
    resolved_params = parameters.subs(bind_subs)

    print('Symbol associations: {}\n'.format(bind_subs))

    print(
        'Algorithmic Flops by hidden dimension, params, and per-batch-sample:')
    resolved_flops = alg_flops.subs(bind_subs)
    for enc_dim in encoder_dims:
        graph_params = resolved_params.subs({enc_hidden_dim_symbol: enc_dim})
        graph_flops = resolved_flops.subs({enc_hidden_dim_symbol: enc_dim})
        graph_flops_per_sample = float(graph_flops) / \
                                 bind_subs[subbatch_size_symbol]
        print('{}\t{}\t{}\t{}'.format(enc_dim, graph_params, graph_flops,
                                      int(graph_flops_per_sample)))

    print('\nAlgorithmic bytes accessed by hidden dimension, params:')
    resolved_bytes = alg_bytes.subs(bind_subs)
    for enc_dim in encoder_dims:
        graph_params = resolved_params.subs({enc_hidden_dim_symbol: enc_dim})
        graph_bytes = resolved_bytes.subs({enc_hidden_dim_symbol: enc_dim})
        print('{}\t{}\t{}'.format(enc_dim, graph_params, graph_bytes))

    print('\nAlgorithmic memory footprint by hidden dimension, params:')
    resolved_footprint = alg_footprint.subs(bind_subs)
    for enc_dim in encoder_dims:
        graph_params = resolved_params.subs({enc_hidden_dim_symbol: enc_dim})
        graph_footprint = resolved_footprint.subs(
            {enc_hidden_dim_symbol: enc_dim})
        print('{}\t{}\t{}'.format(enc_dim, graph_params, graph_footprint))

    print(
        '\nAlgorithmic minimal memory footprint by hidden dimension, params:')
    full_subs = dict(bind_subs)
    for enc_dim in encoder_dims:
        graph_params = resolved_params.subs({enc_hidden_dim_symbol: enc_dim})
        full_subs[enc_hidden_dim_symbol] = enc_dim
        graph_min_foot = graph.calcMinimalFootprint(symbol_subs=full_subs)
        print('{}\t{}\t{}'.format(enc_dim, graph_params, graph_min_foot))


# Example 3
def run_tf_w2v_model():
    global is_pytest_run

    graph_meta = 'catamount/frameworks/example_graphs/tensorflow/full_models/language_models/word2vec_n200-latest_model.meta'

    graph = catamount.frameworks.tensorflow.import_graph(graph_meta)
    assert graph.isValid()

    # Next, remove ops that are not executed during a standard training step:
    graph_ops = list(graph._ops_by_name.values())
    for op in graph_ops:
        # Certain ops are only used for inference
        if 'Model/NceLoss_1_3/' in op.name or \
           'Model/Collapse_1/' in op.name or \
           'Model/Embedding_1_3/' in op.name or \
           'Model/Labels_1/' in op.name or \
           'Model/SkipGramSampler_1/' in op.name or \
           'Model/Mask_1/' in op.name:
            graph.removeOp(op)
        elif \
             op.name == 'Model/Cast_1' or \
             op.name == 'Model/Sum_1' or \
             op.name == 'Model/Size_1' or \
             op.name == 'Model/Exp_1' or \
             op.name == 'Model/truediv_2' or \
             op.name == 'Model/truediv_3':
            graph.removeOp(op)

    if not is_pytest_run:
        print('Initial graph:\n{}\n'.format(graph))
    init_params = graph.calcModelParameters()
    print('Initial parameters: {}'.format(init_params))
    print('Initial Flops: {}\n'.format(graph.calcAlgFlops()))

    print('Placeholders:')
    for op in graph.getPlaceholders():
        print(op.debugString())
    print('')

    # Set up symbols to name dimensions
    skip_window_symbol = utils.getPositiveIntSymbolFromString('skip_window')
    num_skips_symbol = utils.getPositiveIntSymbolFromString('num_skips')
    nce_samples_symbol = utils.getPositiveIntSymbolFromString('nce_samples')
    hidden_dim_symbol = utils.getIntSymbolFromString('hidden_dim')
    vocab_size_symbol = utils.getIntSymbolFromString('vocab_size')
    subbatch_size_symbol = utils.getIntSymbolFromString('subbatch_size')
    sequence_length_symbol = utils.getIntSymbolFromString('sequence_length')
    batch_times_seq_symbol = sequence_length_symbol * subbatch_size_symbol
    graph_iters_symbol = utils.getIntSymbolFromString('graph::iters')
    # For simplicity, assign the NCE samples symbol directly on the sampler op
    nce_samp_op = graph.opsByName[
        'Model/NceLoss_1_1/nce_loss/LogUniformCandidateSampler']
    nce_samp_op._num_samples_symbol = nce_samples_symbol

    # Convert these constant dimensions to symbols
    base_skip_window = 8
    base_num_skips = 8
    base_nce_samples = 64
    base_hidden_dim = 400
    base_vocab_size = 40004
    base_sequence_length = 32
    base_subbatch_size = 1

    # Find and set constants that contain model hyperparameters
    const_dict = {
        'Model/Gradient/Compute/gradients/Model/NceLoss_1_1/nce_loss/sub_1_grad/Shape_1':
            [nce_samples_symbol],
        'Model/SkipGramSampler/Const': 2 * skip_window_symbol,
        'Model/SkipGramSampler/strided_slice/stack': [0, skip_window_symbol],
        'Model/SkipGramSampler/strided_slice/stack_1':
            [0, -skip_window_symbol],
        'Model/Collapse/Reshape/shape': [-1, hidden_dim_symbol],
        'Model/Gradient/Compute/gradients/Model/Embedding_1_1/Gather_grad/Shape':
            [vocab_size_symbol, hidden_dim_symbol],
        'Model/Gradient/Compute/gradients/Model/NceLoss_1_1/nce_loss/embedding_lookup_1_grad/Shape':
            [vocab_size_symbol],
        'Model/Gradient/Compute/gradients/Model/NceLoss_1_1/nce_loss/embedding_lookup_grad/Shape':
            [vocab_size_symbol, hidden_dim_symbol],
        'Model/Mask/NotEqual/y': vocab_size_symbol - 3,
        'Model/SkipGramSampler/Const_2': num_skips_symbol,
        'Model/SkipGramSampler/Tile_1/multiples': [1, num_skips_symbol],
        'Model/SkipGramSampler/Tile/multiples': [1, num_skips_symbol],
    }
    graph.bindConstantValues(const_dict)
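
    # Quick sanity sketch (assumes the base hyperparameters defined above
    # match the values baked into the trained graph): the symbolic constants
    # just bound resolve back to plain numbers, e.g. the skip-gram window
    # constant 2 * skip_window -> 16 and the mask comparison value
    # vocab_size - 3 -> 40001 (presumably reserving a few special tokens).
    assert (2 * skip_window_symbol).subs(
        {skip_window_symbol: base_skip_window}) == 16
    assert (vocab_size_symbol - 3).subs(
        {vocab_size_symbol: base_vocab_size}) == 40001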

    # Next, bind the constant, placeholder, and variable shapes and propagate
    bind_dict = {
        # Constants: none to bind for this model
        # Placeholders
        'Input/Input': [subbatch_size_symbol, sequence_length_symbol],
        'Labels/Labels': [subbatch_size_symbol, sequence_length_symbol],
        # Variables
        'Model/NceLoss_1/b_Softmax': [vocab_size_symbol],
        'Model/NceLoss_1/W_Softmax': [vocab_size_symbol, hidden_dim_symbol],
        'Model/Embedding_1/EmbeddingWeights': [vocab_size_symbol, hidden_dim_symbol],
    }

    print('Binding variables')

    # HACK: For now, manually set the GatherNd output shape; later, implement
    # shape propagation for GatherNd
    gnd_op = graph.opsByName['Model/SkipGramSampler/GatherNd']
    gnd_op.outputs[0].mergeShape([
        subbatch_size_symbol,
        num_skips_symbol * (sequence_length_symbol - 2 * skip_window_symbol)
    ])
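
    # Worked check of the merged width: under the base hyperparameters,
    # each sequence yields num_skips targets for every position with a
    # full window on both sides, i.e. 8 * (32 - 2 * 8) = 128 skip-gram
    # pairs per sequence.
    assert base_num_skips * \
           (base_sequence_length - 2 * base_skip_window) == 128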

    graph.bindShapesAndPropagate(bind_dict,
                                 warn_if_ill_defined=(not is_pytest_run),
                                 make_symbolic=True)
    assert graph.isValid()

    if not is_pytest_run:
        print('\n\nCleaned Graph:\n{}'.format(graph))

    print('\n\nBound values')

    bind_subs = {
        graph_iters_symbol: 1,
        hidden_dim_symbol: base_hidden_dim,
        sequence_length_symbol: base_sequence_length,
        subbatch_size_symbol: base_subbatch_size,
        vocab_size_symbol: base_vocab_size,
        skip_window_symbol: base_skip_window,
        num_skips_symbol: base_num_skips,
        nce_samples_symbol: base_nce_samples,
    }

    # Verify parameter counts first
    parameters = graph.calcModelParameters()

    correct_params = 32043205
    correct_flops = 21148823
    correct_bytes = 23762537
    correct_total_footprint = 137949925

    print('Symbol associations: {}\n'.format(bind_subs))

    # Calculate model parameter count
    resolved_params = parameters.subs(bind_subs)
    try:
        resolved_params = int(resolved_params)
    except (TypeError, ValueError):
        print('ERROR: resolved_params should be int, but is {} = {}'.format(
            type(resolved_params), resolved_params))
    assert resolved_params == correct_params, \
           'Incorrect model params: {}'.format(resolved_params)
    print('Parameters: {}\nWith specified dims: {}\n'.format(
        parameters, resolved_params))
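
    # Hand check of correct_params from the variable shapes bound above:
    # an embedding matrix, a softmax weight matrix, a softmax bias, plus
    # (presumably) one scalar counter such as a global step:
    # 2 * 40004 * 400 + 40004 + 1 = 32043205.
    hand_params = (base_vocab_size * base_hidden_dim    # EmbeddingWeights
                   + base_vocab_size * base_hidden_dim  # W_Softmax
                   + base_vocab_size                    # b_Softmax
                   + 1)                                 # assumed: scalar counter
    assert hand_params == correct_params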

    # Calculate algorithmic Flops
    alg_flops = graph.calcAlgFlops()
    resolved_flops = alg_flops.subs(bind_subs)
    try:
        resolved_flops = int(resolved_flops)
    except (TypeError, ValueError):
        print('ERROR: resolved_flops should be int, but is {} = {}'.format(
            type(resolved_flops), resolved_flops))
    assert resolved_flops == correct_flops, \
           'Incorrect algorithmic flops: {}'.format(resolved_flops)
    print('Algorithmic Flops: {}\nWith specified dims: {}\n'.format(
        alg_flops, resolved_flops))

    # Calculate algorithmic Bytes accessed
    alg_bytes = graph.calcAlgBytes()
    resolved_bytes = alg_bytes.subs(bind_subs)
    try:
        resolved_bytes = int(resolved_bytes)
    except (TypeError, ValueError):
        print('ERROR: resolved_bytes should be int, but is {} = {}'.format(
            type(resolved_bytes), resolved_bytes))
    assert resolved_bytes == correct_bytes, \
           'Incorrect algorithmic bytes: {}'.format(resolved_bytes)
    print('Alg bytes accessed: {}\nWith specified dims: {}\n'.format(
        alg_bytes, resolved_bytes))

    # Calculate total memory footprint
    alg_footprint = graph.calcAlgFootprint()
    resolved_footprint = alg_footprint.subs(bind_subs)
    try:
        resolved_footprint = int(resolved_footprint)
    except (TypeError, ValueError):
        print('ERROR: resolved_footprint should be int, but is {} = {}'.format(
            type(resolved_footprint), resolved_footprint))
    assert resolved_footprint == correct_total_footprint, \
           'Incorrect algorithmic footprint: {}'.format(resolved_footprint)
    print('Alg mem footprint: {}\nWith specified dims: {}\n'.format(
        alg_footprint, resolved_footprint))

    # Calculate minimal memory footprint
    alg_min_footprint = graph.calcMinimalFootprint(symbol_subs=bind_subs)
    print('Alg minimal footprint (With specified dims): {}\n'.format(
        alg_min_footprint))

    # Calculate algorithmic IO per step
    total_io_footprint = 0
    for op in graph.getPlaceholders():
        total_io_footprint += op.calcAlgFootprint()
    if isinstance(total_io_footprint, int):
        resolved_io_footprint = total_io_footprint
    else:
        resolved_io_footprint = total_io_footprint.subs(bind_subs)
    print('Alg IO footprint: {}\nWith specified dims: {}\n'.format(
        total_io_footprint, resolved_io_footprint))
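
    # Rough hand check of the per-step IO: two [subbatch_size,
    # sequence_length] placeholders at an assumed 4 bytes per element give
    # 2 * 1 * 32 * 4 = 256 bytes. The element size is an assumption here;
    # adjust it if the placeholders use a different dtype.
    assumed_bytes_per_elem = 4
    hand_io_bytes = 2 * base_subbatch_size * base_sequence_length * \
                    assumed_bytes_per_elem
    print('Hand-check IO bytes (4B elements assumed): {}'.format(
        hand_io_bytes))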

    if not is_pytest_run:
        print('VERBOSE ALGORITHMIC FLOPS:')
        graph.calcAlgFlops(verbose=True)
        print('')

        print('VERBOSE ALGORITHMIC BYTES:')
        graph.calcAlgBytes(verbose=True)
        print('')

        print('VERBOSE ALGORITHMIC FOOTPRINT:')
        graph.calcAlgFootprint(verbose=True)
        print('')

    # HACK: For now, save the model by pickling the whole graph
    with open(
            'catamount/frameworks/example_graphs/tensorflow/full_models/language_models/graph_word2vec.p',
            'wb') as graph_file:
        pickle.dump(graph, graph_file)
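
    # Round-trip sketch: the pickled graph can be reloaded later without
    # re-importing the TensorFlow metagraph (path matches the dump above).
    with open(
            'catamount/frameworks/example_graphs/tensorflow/full_models/language_models/graph_word2vec.p',
            'rb') as graph_file:
        reloaded_graph = pickle.load(graph_file)
    assert reloaded_graph.isValid()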

    if is_pytest_run:
        return

    print('\n\n======= Algorithmic graph-level analytics: =======')

    hidden_dims = [
        1, 2, 3, 4, 5, 6, 7, 9, 10, 12, 14, 18, 20, 25, 28, 35, 40, 50, 56, 69,
        78, 86, 96, 108, 119, 123, 133, 148, 163, 182, 202, 221, 246, 273, 297,
        329, 330, 364, 396, 436, 437, 520, 572, 617, 676, 740, 796, 869, 948,
        1017, 1106, 1202, 1286, 1394, 1510, 1611, 1742, 1882, 2004, 2161, 2476,
        3040, 3714, 4520, 5478, 6628, 8019, 9702, 11739, 14204, 17186, 20795,
        25161, 30444, 36837, 38100
    ]

    bind_subs.pop(hidden_dim_symbol)
    resolved_params = parameters.subs(bind_subs)

    print('Symbol associations: {}\n'.format(bind_subs))

    print(
        'Algorithmic Flops by hidden dimension, params, and per-batch-sample:')
    resolved_flops = alg_flops.subs(bind_subs)
    for hid_dim in hidden_dims:
        graph_params = resolved_params.subs({hidden_dim_symbol: hid_dim})
        graph_flops = resolved_flops.subs({hidden_dim_symbol: hid_dim})
        graph_flops_per_sample = float(graph_flops) / \
                                 bind_subs[subbatch_size_symbol]
        print('{}\t{}\t{}\t{}'.format(hid_dim, graph_params, graph_flops,
                                      int(graph_flops_per_sample)))

    print('\nAlgorithmic bytes accessed by hidden dimension, params:')
    resolved_bytes = alg_bytes.subs(bind_subs)
    for hid_dim in hidden_dims:
        graph_params = resolved_params.subs({hidden_dim_symbol: hid_dim})
        graph_bytes = resolved_bytes.subs({hidden_dim_symbol: hid_dim})
        print('{}\t{}\t{}'.format(hid_dim, graph_params, graph_bytes))

    print('\nAlgorithmic total memory footprint by hidden dimension, params:')
    resolved_footprint = alg_footprint.subs(bind_subs)
    for hid_dim in hidden_dims:
        graph_params = resolved_params.subs({hidden_dim_symbol: hid_dim})
        graph_footprint = resolved_footprint.subs({hidden_dim_symbol: hid_dim})
        print('{}\t{}\t{}'.format(hid_dim, graph_params, graph_footprint))

    print(
        '\nAlgorithmic minimal memory footprint by hidden dimension, params:')
    full_subs = dict(bind_subs)
    for hid_dim in hidden_dims:
        graph_params = resolved_params.subs({hidden_dim_symbol: hid_dim})
        full_subs[hidden_dim_symbol] = hid_dim
        graph_min_foot = graph.calcMinimalFootprint(symbol_subs=full_subs)
        print('{}\t{}\t{}'.format(hid_dim, graph_params, graph_min_foot))