def flatten_if_needed(nodes, complete_dag, is_flattened, flatten_layer):
    """Ensure that the DAG outputs a two dimensional tensor

    To maintain a consistent prediction interface, all converted pytorch
    networks output a 2d tensor (rather than a 4d tensor). This function
    ensures that the DAG represented by the list of nodes meets this
    requirement.

    Args:
        nodes (List): A directed acyclic network, represented as a list of
          nodes, where each node is a dict containing layer attributes and
          pointers to its parents and children.
        complete_dag (bool): whether or not the current list of nodes
          represents the complete DAG.
        is_flattened (bool): whether or not the DAG has already been
          "flattened" to produce tensors of the required dimensionality.
        flatten_layer (str): can be either the name of the layer after which
          flattening is performed, or 'last', in which case it is performed
          at the end of the network.

    Returns:
        (List, bool): the updated DAG, together with a flag indicating
          whether flattening has now been performed.
    """
    if not is_flattened:
        prev = nodes[-1]
        flatten_condition = (flatten_layer == 'last' and complete_dag) \
            or (flatten_layer == prev['name'])
        if flatten_condition:
            # TODO(samuel): Make network surgery more robust (it currently
            # relies on a unique variable naming modification to maintain a
            # consistent execution order)
            name = '{}_flatten'.format(prev['name'])
            outputs = prev['outputs']
            prev['outputs'] = [
                '{}_preflatten'.format(x) for x in prev['outputs']
            ]
            node = {
                'name': name,
                'inputs': prev['outputs'],
                'outputs': outputs,
                'params': [],
            }
            node['mod'] = pmu.Flatten()
            nodes.append(node)
            is_flattened = True
    return nodes, is_flattened
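# The sketch below is illustrative only (the layer and variable names are
# hypothetical) and is not used by the importer; it shows the surgery
# performed by flatten_if_needed: the output variables of the final node are
# renamed with a `_preflatten` suffix and a new Flatten node is appended that
# restores the original output names.
def _flatten_if_needed_sketch():
    """Minimal usage sketch for flatten_if_needed (hypothetical layer names)."""
    nodes = [{
        'name': 'fc8',  # hypothetical final prediction layer
        'inputs': ['pool5'],
        'outputs': ['prediction'],
        'params': ['fc8_filter', 'fc8_bias'],
        'mod': nn.Conv2d(512, 1000, kernel_size=1),
    }]
    nodes, flattened = flatten_if_needed(nodes, complete_dag=True,
                                         is_flattened=False,
                                         flatten_layer='last')
    assert flattened
    assert nodes[-1]['name'] == 'fc8_flatten'
    assert nodes[-2]['outputs'] == ['prediction_preflatten']
    assert nodes[-1]['inputs'] == ['prediction_preflatten']
    assert nodes[-1]['outputs'] == ['prediction']
    return nodes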
def simplify_dag(nodes):
    """Simplify unnecessary chains of operations

    Certain combinations of MCN operations can be simplified to a single
    PyTorch operation. For example, because matlab tensors are stored in
    column major order, the common `x.view(x.size(0),-1)` function maps to
    a combination of `Permute` and `Flatten` layers.

    Args:
        nodes (List): A directed acyclic network, represented as a list of
          nodes, where each node is a dict containing layer attributes and
          pointers to its parents and children.

    Returns:
        (List): a list of nodes, in which a subset of the operations have
          been simplified.
    """
    simplified = []
    skip = False
    for prev, node in zip(nodes, nodes[1:]):
        if (isinstance(node["mod"], pmu.Flatten)
                and isinstance(prev["mod"], pmu.Permute)
                and np.array_equal(prev["mod"].order, [1, 0, 2, 3])):
            new_node = {
                "name": node["name"],
                "inputs": prev["inputs"],
                "outputs": node["outputs"],
                "mod": pmu.Flatten(),
                "params": None,
            }
            simplified.append(new_node)
            skip = True
            print("simplifying {}, {}".format(prev["name"], node["name"]))
        elif skip:
            skip = False
        else:
            simplified.append(prev)
    if not skip:
        simplified.append(node)  # handle final node
    return simplified
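# Illustrative example (hypothetical layer and variable names): given the
# consecutive nodes
#
#   {'name': 'permute1', 'inputs': ['pool5'], 'outputs': ['pool5p'],
#    'mod': <pmu.Permute with order [1, 0, 2, 3]>, ...}
#   {'name': 'flatten1', 'inputs': ['pool5p'], 'outputs': ['pool5f'],
#    'mod': <pmu.Flatten>, ...}
#
# simplify_dag collapses the pair into the single node
#
#   {'name': 'flatten1', 'inputs': ['pool5'], 'outputs': ['pool5f'],
#    'mod': pmu.Flatten(), 'params': None}
#
# which corresponds directly to `x.view(x.size(0), -1)` in PyTorch.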
def extract_dag(mcn_net, inplace, drop_prob_softmax=True, in_ch=3,
                flatten_layer='last', **kwargs):
    """Extract DAG nodes from stored matconvnet network

    Transform a stored mcn dagnn network structure into a Directed Acyclic
    Graph, represented as a list of nodes, each of which has pointers to its
    inputs and outputs. Since MatConvNet computes convolution groups online,
    rather than as a stored attribute (as done in PyTorch), the number of
    channels is tracked during DAG construction. Loss layers are not included
    in the imported model (they are skipped during dag construction).

    The number of channels associated with each named variable (feature
    activations) is tracked during construction in the `in_ch_store`
    dictionary. This makes it possible to determine the number of groups in
    each convolution (as well as enabling sanity checking for general
    convolution layers). Named variables which are not created by the network
    itself are assumed to be inputs, with the number of channels given by the
    `in_ch` argument.

    Args:
        mcn_net (dict): a native Python dictionary containing the network
          parameters, layers and meta information.
        inplace (bool): whether to convert ReLU modules to run "in place".
        drop_prob_softmax (bool) [True]: whether to remove the final softmax
          layer of a network, if present.
        in_ch (int) [3]: the number of channels expected in input data
          processed by the network.
        flatten_layer (str) ['last']: the layer after which a "flatten"
          operation should be inserted (if one is not present in the
          matconvnet network).

    Keyword Args:
        verbose (bool): whether to display more detailed information during
          the conversion process.

    Returns:
        nodes (List): A directed acyclic network, represented as a list of
          nodes, where each node is a dict containing layer attributes and
          pointers to its parents and children.
        uses_functional (bool): whether or not any of the mcn blocks have
          been mapped to members of the torch.functional module
    """
    # TODO(samuel): improve state management
    nodes = []
    is_flattened = False
    uses_functional = False
    num_layers = len(mcn_net['layers']['name'])
    in_ch_store = defaultdict(lambda: in_ch)  # track the channels
    for ii in range(num_layers):
        params = mcn_net['layers']['params'][ii]
        if params == {'': []}:
            params = None
        node = {
            'name': mcn_net['layers']['name'][ii],
            'inputs': mcn_net['layers']['inputs'][ii],
            'outputs': mcn_net['layers']['outputs'][ii],
            'params': params,
        }
        bt = mcn_net['layers']['type'][ii]
        block = mcn_net['layers']['block'][ii]
        opts = {'block': block, 'block_type': bt}
        in_chs = [in_ch_store[x] for x in node['inputs']]
        out_chs = in_chs  # by default, maintain the same number of channels
        if bt == 'dagnn.Conv':
            msg = 'conv layers should only take a single input'
            assert len(in_chs) == 1, msg
            mod, out_ch = pmu.conv2d_mod(block, in_chs[0], is_flattened,
                                         **kwargs)
            out_chs = [out_ch]
        elif bt == 'dagnn.BatchNorm':
            mod = pmu.batchnorm2d_mod(block, mcn_net, params)
        elif bt == 'dagnn.GlobalPooling':
            mod = pmu.globalpool_mod(block)
        elif bt == 'dagnn.ReLU':
            mod = nn.ReLU(inplace=inplace)
        elif bt == 'dagnn.Sigmoid':
            mod = nn.Sigmoid()
        elif bt == 'dagnn.Pooling':
            pad, ceil_mode = pmu.convert_padding(block['pad'])
            pool_opts = {
                'kernel_size': pmu.int_list(block['poolSize']),
                'stride': pmu.int_list(block['stride']),
                'padding': pad,
                'ceil_mode': ceil_mode,
            }
            if block['method'] == 'avg':
                # mcn never includes padding in average pooling.
                # TODO(samuel): cleanup and add explanation
                pool_opts['count_include_pad'] = False
                mod = nn.AvgPool2d(**pool_opts)
            elif block['method'] == 'max':
                mod = nn.MaxPool2d(**pool_opts)
            else:
                msg = 'unknown pooling type: {}'.format(block['method'])
                raise ValueError(msg)
        elif bt == 'dagnn.DropOut':
            # both frameworks use p=prob(zeroed)
            mod = nn.Dropout(p=block['rate'])
        elif bt == 'dagnn.Permute':
            mod = pmu.Permute(**opts)
        elif bt == 'dagnn.Reshape':
            mod = pmu.Reshape(**opts)
        elif bt == 'dagnn.Axpy':
            mod = pmu.Axpy(**opts)
        elif bt == 'dagnn.Flatten':
            mod = pmu.Flatten(**opts)
            is_flattened = True
            out_chs = [1]
        elif bt == 'dagnn.Concat':
            mod = pmu.Concat(**opts)
            out_chs = [sum(in_chs)]
        elif bt == 'dagnn.Sum':
            mod = pmu.Sum(**opts)
            out_chs = [in_chs[0]]  # (all input channels must be the same)
        elif bt == 'dagnn.AffineGridGenerator':
            mod = pmu.AffineGridGen(height=block['Ho'], width=block['Wo'],
                                    **opts)
            uses_functional = True
        elif bt == 'dagnn.BilinearSampler':
            mod = pmu.BilinearSampler(**opts)
            uses_functional = True
        elif bt in ['dagnn.Loss', 'dagnn.SoftmaxCELoss']:
            if kwargs['verbose']:
                print('skipping loss layer: {}'.format(node['name']))
            continue
        elif (bt == 'dagnn.SoftMax' and (ii == num_layers - 1)
              and drop_prob_softmax):
            continue  # remove softmax prediction layer
        else:
            raise ValueError('unsupported block type: {}'.format(bt))
        for output, out_ch in zip(node['outputs'], out_chs):
            in_ch_store[output] = out_ch
        node['mod'] = mod
        nodes += [node]
        complete_dag = (ii == num_layers - 1)
        nodes, is_flattened = flatten_if_needed(nodes, complete_dag,
                                                is_flattened, flatten_layer)
    return nodes, uses_functional
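# Usage sketch (illustrative only): `mcn_net` is the native Python dictionary
# produced elsewhere in the importer from a stored .mat model, with the layout
# accessed above (mcn_net['layers']['name'], ['type'], ['block'], ...).
#
#   nodes, uses_functional = extract_dag(mcn_net, inplace=True,
#                                        flatten_layer='last', verbose=True)
#   nodes = simplify_dag(nodes)
#
# Note that `verbose` is forwarded via **kwargs and is read directly by the
# loss-layer branch above, so it should be supplied whenever the source
# network may contain loss layers.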
def extract_dag(mcn_net, drop_prob_softmax=True, in_ch=3, flatten_layer='last'):
    """Extract DAG nodes from stored matconvnet network

    Transform a stored mcn dagnn network structure into a Directed Acyclic
    Graph, represented as a list of nodes, each of which has pointers to its
    inputs and outputs. Since MatConvNet computes convolution groups online,
    rather than as a stored attribute (as done in PyTorch), the number of
    channels is tracked during DAG construction.

    Args:
        mcn_net (dict): a native Python dictionary containing the network
          parameters, layers and meta information.
        drop_prob_softmax (bool) [True]: whether to remove the final softmax
          layer of a network, if present.
        in_ch (int) [3]: the number of channels expected in input data
          processed by the network.
        flatten_layer (str) ['last']: the layer after which a "flatten"
          operation should be inserted (if one is not present in the
          matconvnet network).

    Returns:
        (List): the extracted list of DAG nodes.
    """
    # TODO(samuel): improve state management
    nodes = []
    is_flattened = False
    num_layers = len(mcn_net['layers']['name'])
    for ii in range(num_layers):
        params = mcn_net['layers']['params'][ii]
        if params == {'': []}:
            params = None
        node = {
            'name': mcn_net['layers']['name'][ii],
            'inputs': mcn_net['layers']['inputs'][ii],
            'outputs': mcn_net['layers']['outputs'][ii],
            'params': params,
        }
        bt = mcn_net['layers']['type'][ii]
        block = mcn_net['layers']['block'][ii]
        opts = {'block': block, 'block_type': bt}
        if bt == 'dagnn.Conv':
            mod, in_ch = pmu.conv2d_mod(block, in_ch, is_flattened)
        elif bt == 'dagnn.BatchNorm':
            mod = pmu.batchnorm2d_mod(block, mcn_net, params)
        elif bt == 'dagnn.ReLU':
            mod = nn.ReLU()
        elif bt == 'dagnn.Pooling':
            pad, ceil_mode = pmu.convert_padding(block['pad'])
            pool_opts = {'kernel_size': pmu.int_list(block['poolSize']),
                         'stride': pmu.int_list(block['stride']),
                         'padding': pad,
                         'ceil_mode': ceil_mode}
            if block['method'] == 'avg':
                # mcn never includes padding in average pooling.
                # TODO(samuel): cleanup and add explanation
                pool_opts['count_include_pad'] = False
                mod = nn.AvgPool2d(**pool_opts)
            elif block['method'] == 'max':
                mod = nn.MaxPool2d(**pool_opts)
            else:
                msg = 'unknown pooling type: {}'.format(block['method'])
                raise ValueError(msg)
        elif bt == 'dagnn.DropOut':
            # both frameworks use p=prob(zeroed)
            mod = nn.Dropout(p=block['rate'])
        elif bt == 'dagnn.Permute':
            mod = pmu.Permute(**opts)
        elif bt == 'dagnn.Flatten':
            mod = pmu.Flatten(**opts)
            is_flattened = True
        elif bt == 'dagnn.Concat':
            mod = pmu.Concat(**opts)
        elif bt == 'dagnn.Sum':
            mod = pmu.Sum(**opts)
        elif bt == 'dagnn.SoftMax' \
                and (ii == num_layers - 1) and drop_prob_softmax:
            continue  # remove softmax prediction layer
        else:
            raise ValueError('unsupported block type: {}'.format(bt))
        node['mod'] = mod
        nodes += [node]
        complete_dag = (ii == num_layers - 1)
        nodes, is_flattened = flatten_if_needed(nodes, complete_dag,
                                                is_flattened, flatten_layer)
    return nodes