class MacroGraph(NodeOpGraph):
    """Macro-level (network-level) search graph for NAS-Bench-201.

    The graph is a linear chain: input -> stem -> N normal cells ->
    reduction -> N normal cells -> reduction -> N normal cells ->
    post-processing (BN+ReLU) -> global pooling -> linear classifier.

    Queries against the NAS-Bench-201 tabular benchmark go through
    ``self.nasbench_api``.
    """

    def __init__(self, config, primitives, ops_dict, *args, **kwargs):
        """
        Args:
            config: dict with keys 'num_cells_per_stack', 'init_channels',
                'num_classes' (and, for `from_config`, 'stem_multiplier',
                'channel_multiplier').
            primitives: candidate operation names for the cell edges.
            ops_dict: mapping from primitive name to op constructor,
                forwarded to `Cell`.
        """
        self.config = config
        self.primitives = primitives
        self.ops_dict = ops_dict
        # NOTE(review): hard-coded, user-specific path; this also loads the
        # (large) benchmark file on every construction. Consider taking the
        # path from `config` and loading lazily.
        self.nasbench_api = API('/home/siemsj/nasbench_201.pth')
        super(MacroGraph, self).__init__(*args, **kwargs)

    def _build_graph(self):
        """Build the NAS-Bench-201 macro skeleton as graph nodes/edges."""
        num_cells_per_stack = self.config['num_cells_per_stack']
        C = self.config['init_channels']
        # Channel schedule: each reduction cell doubles the width.
        layer_channels = [C] * num_cells_per_stack + [C * 2] + \
                         [C * 2] * num_cells_per_stack + [C * 4] + \
                         [C * 4] * num_cells_per_stack
        layer_reductions = [False] * num_cells_per_stack + [True] + \
                           [False] * num_cells_per_stack + [True] + \
                           [False] * num_cells_per_stack

        stem = NASBENCH_201_Stem(C=C)
        self.add_node(0, type='input')
        self.add_node(1, op=stem, type='stem')

        C_prev = C
        self.cells = nn.ModuleList()
        for cell_num, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)):
            if reduction:
                # NAS-Bench-201 uses a fixed ResNet basic block (stride 2)
                # as its reduction cell; it is not searched over.
                cell = ResNetBasicblock(C_prev, C_curr, 2, True)
                self.add_node(cell_num + 2, op=cell, primitives=self.primitives,
                              transform=lambda x: x[0])
            else:
                cell = Cell(primitives=self.primitives, stride=1, C_prev=C_prev,
                            C=C_curr, ops_dict=self.ops_dict, cell_type='normal')
                self.add_node(cell_num + 2, op=cell, primitives=self.primitives)
            C_prev = C_curr

        lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
        pooling = nn.AdaptiveAvgPool2d(1)
        classifier = nn.Linear(C_prev, self.config['num_classes'])

        # `cell_num` deliberately leaks out of the loop: the tail nodes are
        # appended right after the last cell.
        self.add_node(cell_num + 3, op=lastact, transform=lambda x: x[0],
                      type='postprocessing_nb201')
        self.add_node(cell_num + 4, op=pooling, transform=lambda x: x[0],
                      type='pooling')
        self.add_node(cell_num + 5, op=classifier,
                      transform=lambda x: x[0].view(x[0].size(0), -1),
                      type='output')

        # Edges: plain chain, each node fed by its predecessor.
        for i in range(1, cell_num + 6):
            self.add_edge(i - 1, i, type='input', desc='previous')

    def sample(self, same_cell_struct=True, n_ops_per_edge=1,
               n_input_edges=None, dist=None, seed=1):
        """Sample a discrete architecture from this super-graph.

        Args:
            same_cell_struct: if True, every cell of the same type
                ('normal'/'reduction') is sampled with the same RNG seed and
                therefore gets the same topology.
            n_ops_per_edge: number of operations sampled per cell edge.
            n_input_edges: optional list, one entry per intermediate node,
                giving how many predecessor edges that node keeps.
            dist: optional probability distribution over the op choices
                (passed to `np.random.choice`); uniform when None.
            seed: base random seed.

        Returns:
            A new, discretized `MacroGraph`.
        """
        # Build a fresh graph and discretize it in place.
        new_graph = MacroGraph(self.config, self.primitives, self.ops_dict)
        np.random.seed(seed)
        seeds = {'normal': seed + 1, 'reduction': seed + 2}
        for node in new_graph:
            cell = new_graph.get_node_op(node)
            if not isinstance(cell, Cell):
                continue

            if same_cell_struct:
                # Re-seeding per cell type makes all same-type cells draw
                # identical samples.
                np.random.seed(seeds[new_graph.get_node_type(node)])

            for edge in cell.edges:
                op_choices = cell.get_edge_op_choices(*edge)
                sampled_op = np.random.choice(op_choices, n_ops_per_edge,
                                              False, p=dist)
                cell[edge[0]][edge[1]]['op_choices'] = [*sampled_op]

            if n_input_edges is not None:
                for inter_node, k in zip(cell.inter_nodes(), n_input_edges):
                    # in case the start node index is not 0
                    node_idx = list(cell.nodes).index(inter_node)
                    prev_node_choices = list(cell.nodes)[:node_idx]
                    assert k <= len(prev_node_choices), 'cannot sample more' \
                        ' than number of predecesor nodes'

                    sampled_input_edges = np.random.choice(prev_node_choices,
                                                           k, False)
                    # Drop every predecessor edge that was not sampled.
                    for i in set(prev_node_choices) - set(sampled_input_edges):
                        cell.remove_edge(i, inter_node)

        return new_graph

    @classmethod
    def from_config(cls, config=None, filename=None):
        """Reconstruct a graph from a YAML description plus a config dict.

        Args:
            config: run configuration (channels, multipliers, num_classes).
            filename: path to the YAML file describing nodes and edges.

        Raises:
            ValueError: if `config` is None.
        """
        # Validate before touching the file: no point parsing YAML we will
        # immediately discard. (Was `raise ('...')`, which itself raised
        # TypeError because a str is not an exception.)
        if config is None:
            raise ValueError('No configuration provided')

        with open(filename, 'r') as f:
            graph_dict = yaml.safe_load(f)

        # BUGFIX(review): `ops_dict` is a required positional argument of
        # __init__ and was previously omitted here.
        graph = cls(config, [], None)
        graph_type = graph_dict['type']
        # NOTE(review): security — `eval` on edge keys read from the YAML
        # file executes arbitrary code if the file is untrusted; prefer
        # `ast.literal_eval` if the keys are plain tuple literals.
        edges = [(*eval(e), attr) for e, attr in graph_dict['edges'].items()]
        graph.clear()
        graph.add_edges_from(edges)

        C = config['init_channels']
        C_curr = config['stem_multiplier'] * C

        stem = Stem(C_curr=C_curr)
        C_prev_prev, C_prev, C_curr = C_curr, C_curr, C

        for node, attr in graph_dict['nodes'].items():
            node_type = attr['type']
            if node_type == 'input':
                graph.add_node(node, type='input')
            elif node_type == 'stem':
                graph.add_node(node, op=stem, type='stem')
            elif node_type in ['normal', 'reduction']:
                assert attr['op']['type'] == 'Cell'
                graph.add_node(node,
                               op=Cell.from_config(attr['op'],
                                                   primitives=attr['op']['primitives'],
                                                   C_prev_prev=C_prev_prev,
                                                   C_prev=C_prev,
                                                   C=C_curr,
                                                   reduction_prev=graph_dict['nodes'][node - 1]['type'] == 'reduction',
                                                   cell_type=node_type),
                               type=node_type)
                C_prev_prev, C_prev = C_prev, config['channel_multiplier'] * C_curr
            elif node_type == 'pooling':
                pooling = nn.AdaptiveAvgPool2d(1)
                graph.add_node(node, op=pooling, transform=lambda x: x[0],
                               type='pooling')
            elif node_type == 'output':
                classifier = nn.Linear(C_prev, config['num_classes'])
                graph.add_node(node, op=classifier,
                               transform=lambda x: x[0].view(x[0].size(0), -1),
                               type='output')

        return graph

    @staticmethod
    def export_nasbench_201_results_to_dict(information):
        """Flatten a NAS-Bench-201 `ArchResults`-style record into a dict.

        Args:
            information: benchmark record exposing `get_dataset_names`,
                `get_compute_costs` and `get_metrics`.

        Returns:
            dict keyed by dataset name (plus 'arch' and 'datasets'), each
            value holding cost metrics and train/valid/test loss & accuracy.
        """
        results_dict = {}

        dataset_names = information.get_dataset_names()
        results_dict['arch'] = information.arch_str
        results_dict['datasets'] = dataset_names

        for dataset in dataset_names:
            dataset_results = {}
            dataset_results['dataset'] = dataset

            metric = information.get_compute_costs(dataset)
            flop, param, latency, training_time = metric['flops'], metric['params'], \
                metric['latency'], metric['T-train@total']
            dataset_results['flop'] = flop
            dataset_results['params (MB)'] = param
            # Latency is reported in seconds; convert to ms, or None when
            # the benchmark has no valid measurement.
            dataset_results['latency (ms)'] = latency * 1000 \
                if latency is not None and latency > 0 else None
            dataset_results['training_time'] = training_time

            train_info = information.get_metrics(dataset, 'train')
            # The available splits differ per dataset in NAS-Bench-201.
            if dataset == 'cifar10-valid':
                valid_info = information.get_metrics(dataset, 'x-valid')
                dataset_results['train_loss'] = train_info['loss']
                dataset_results['train_accuracy'] = train_info['accuracy']
                dataset_results['valid_loss'] = valid_info['loss']
                dataset_results['valid_accuracy'] = valid_info['accuracy']
            elif dataset == 'cifar10':
                test_info = information.get_metrics(dataset, 'ori-test')
                dataset_results['train_loss'] = train_info['loss']
                dataset_results['train_accuracy'] = train_info['accuracy']
                dataset_results['test_loss'] = test_info['loss']
                dataset_results['test_accuracy'] = test_info['accuracy']
            else:
                valid_info = information.get_metrics(dataset, 'x-valid')
                test_info = information.get_metrics(dataset, 'x-test')
                dataset_results['train_loss'] = train_info['loss']
                dataset_results['train_accuracy'] = train_info['accuracy']
                dataset_results['valid_loss'] = valid_info['loss']
                dataset_results['valid_accuracy'] = valid_info['accuracy']
                dataset_results['test_loss'] = test_info['loss']
                dataset_results['test_accuracy'] = test_info['accuracy']

            results_dict[dataset] = dataset_results

        return results_dict

    def query_architecture(self, arch_weights):
        """Look up the benchmark results for the arch given by `arch_weights`.

        Args:
            arch_weights: mapping from edge key (e.g.
                'cell_normal_from_0_to_1') to a tensor of op logits; the
                argmax op per edge defines the architecture.

        Returns:
            dict of benchmark results (see
            `export_nasbench_201_results_to_dict`).
        """
        # i-th edge in the fixed NB201 edge order feeds from this parent node.
        arch_weight_idx_to_parent = {0: 0, 1: 0, 2: 1, 3: 0, 4: 1, 5: 2}
        arch_strs = {
            'cell_normal_from_0_to_1': '',
            'cell_normal_from_0_to_2': '',
            'cell_normal_from_1_to_2': '',
            'cell_normal_from_0_to_3': '',
            'cell_normal_from_1_to_3': '',
            'cell_normal_from_2_to_3': '',
        }

        for arch_weight_idx, (edge_key, edge_weights) in enumerate(arch_weights.items()):
            edge_weights_norm = torch.softmax(edge_weights, dim=-1)
            selected_op_str = PRIMITIVES[edge_weights_norm.argmax()]
            arch_strs[edge_key] = '{}~{}'.format(selected_op_str,
                                                arch_weight_idx_to_parent[arch_weight_idx])

        # NB201 canonical string: |op~0|+|op~0|op~1|+|op~0|op~1|op~2|
        arch_str = '|{}|+|{}|{}|+|{}|{}|{}|'.format(*arch_strs.values())
        if not hasattr(self, 'nasbench_api'):
            self.nasbench_api = API('/home/siemsj/nasbench_201.pth')
        index = self.nasbench_api.query_index_by_arch(arch_str)
        self.nasbench_api.show(index)
        info = self.nasbench_api.query_by_index(index)
        return self.export_nasbench_201_results_to_dict(info)