def forward(self, inputs, dropout_path_rate):  # pylint: disable=arguments-differ
    states = [
        op(_input) for op, _input in zip(self.preprocess_ops, inputs)
    ]
    batch_size, _, height, width = states[0].shape
    o_height, o_width = height // self.stride, width // self.stride

    for to_ in range(self.num_init_nodes, self._num_nodes):
        state_to_ = torch.zeros(
            [batch_size, self.num_out_channels, o_height, o_width],
            device=states[0].device,
        )
        for from_ in range(to_):
            op_ = self.edges[from_][to_]
            if isinstance(op_, ops.Zero):
                continue
            out = op_(states[from_])
            if self.training and dropout_path_rate > 0:
                if not isinstance(op_, ops.Identity):
                    out = utils.drop_path(out, dropout_path_rate)
            state_to_ = state_to_ + out
        states.append(state_to_)

    # concat all internal nodes
    return torch.cat(states[self.num_init_nodes:], dim=1)
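
# --- Hedged sketch (illustrative, not from this repo) ------------------------
# The loop above regularizes every non-identity branch with `utils.drop_path`.
# The sketch below assumes the standard per-sample drop-path formulation
# (drop a whole branch output for a sample with probability `drop_rate` and
# rescale the kept ones); the actual `utils.drop_path` may differ in details.
# `torch` is assumed to be imported at module level, as the code above implies.
def drop_path_sketch(x, drop_rate):
    if drop_rate <= 0.:
        return x
    keep_prob = 1. - drop_rate
    # one Bernoulli mask entry per sample, broadcast over C, H, W
    mask = torch.bernoulli(
        torch.full((x.shape[0], 1, 1, 1), keep_prob,
                   device=x.device, dtype=x.dtype))
    return x / keep_prob * mask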
def forward(self, inputs, dropout_path_rate):  # pylint: disable=arguments-differ
    assert self._num_init == len(inputs)
    states = [
        op(_input) for op, _input in zip(self.preprocess_ops, inputs)
    ]

    _num_conn = defaultdict(int)
    for to_, connections in self.conns_grouped:
        state_to_ = 0.
        for op_type, from_, _ in connections:
            conn_ind = (0 if not self.independent_conn
                        else _num_conn[(from_, to_, op_type)])
            op = self.edges[from_][to_][op_type][conn_ind]
            _num_conn[(from_, to_, op_type)] += 1
            out = op(states[from_])
            if self.training and dropout_path_rate > 0:
                if not isinstance(op, ops.Identity):
                    out = utils.drop_path(out, dropout_path_rate)
            state_to_ = state_to_ + out
        states.append(state_to_)

    out = self.concat_op([states[ind] for ind in self.concat_nodes])
    if self.use_shortcut and self.layer_index != 0:
        out = out + self.shortcut_reduction_op(inputs[-1])
    return out
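
# --- Illustrative data layout (an assumption, not taken from the genotype
# parser) ----------------------------------------------------------------------
# The loop above expects `self.conns_grouped` to pair each internal node index
# with the list of (op_type, from_node, to_node) connections feeding it,
# roughly like the toy example below; op names and node indices are
# placeholders for illustration only.
_example_conns_grouped = [
    (2, [("sep_conv_3x3", 0, 2), ("skip_connect", 1, 2)]),
    (3, [("sep_conv_3x3", 0, 3), ("dil_conv_3x3", 2, 3)]),
]
# iterated the same way `forward` does:
#   for to_, connections in _example_conns_grouped:
#       for op_type, from_, _ in connections:
#           ...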
def forward_one_step(self, context, dropout_path_rate):
    to_ = cur_step = context.next_step_index[1]
    if cur_step == 0:
        context._num_conn[self] = defaultdict(int)

    if cur_step < self._num_init:  # `self._num_init` preprocess steps
        ind = len(context.previous_cells) - (self._num_init - cur_step)
        ind = max(ind, 0)
        # state = self.preprocess_ops[cur_step](context.previous_cells[ind])
        # context.current_cell.append(state)
        # context.last_conv_module = self.preprocess_ops[cur_step].get_last_conv_module()
        current_op = context.next_op_index[1]
        state, context = self.preprocess_ops[cur_step].forward_one_step(
            context=context,
            inputs=context.previous_cells[ind] if current_op == 0 else None)
        if context.next_op_index[1] == 0:
            # this preprocess op finished, append its output to `current_cell`
            assert len(context.previous_op) == 1
            context.current_cell.append(context.previous_op[0])
            context.previous_op = []
            context.last_conv_module = self.preprocess_ops[
                cur_step].get_last_conv_module()
    elif cur_step < self._num_init + self._steps:  # the internal-node steps
        conns = self.conns_grouped[cur_step - self._num_init][1]
        op_ind, current_op = context.next_op_index
        if op_ind == len(conns):
            # all connections added to `context.previous_op`, sum them up
            state = sum([st for st in context.previous_op])
            context.current_cell.append(state)
            context.previous_op = []
        else:
            op_type, from_, _ = conns[op_ind]
            conn_ind = (0 if not self.independent_conn
                        else context._num_conn[self][(from_, to_, op_type)])
            op = self.edges[from_][to_][op_type][conn_ind]
            state, context = op.forward_one_step(
                context=context,
                inputs=context.current_cell[from_] if current_op == 0 else None)
            if self.training and dropout_path_rate > 0:
                if not isinstance(op, ops.Identity):
                    context.last_state = state = utils.drop_path(
                        state, dropout_path_rate)
            if context.next_op_index[0] != op_ind:  # this op finished
                context._num_conn[self][(from_, to_, op_type)] += 1
    else:  # final concat
        state = self.concat_op(
            [context.current_cell[ind] for ind in self.concat_nodes])
        context.current_cell = []
        context.previous_cells.append(state)
    return state, context
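
# --- Hedged sketch (illustrative helper, not part of the cell API) -----------
# `forward_one_step` above dispatches on the current step index into three
# phases. The helper below merely restates that mapping; `num_init` / `steps`
# mirror `self._num_init` / `self._steps`.
def cell_phase_sketch(cur_step, num_init, steps):
    if cur_step < num_init:
        return "preprocess"     # run preprocess_ops[cur_step] on a previous cell
    if cur_step < num_init + steps:
        return "internal-node"  # accumulate the node's incoming edge ops
    return "concat"             # concat_op over self.concat_nodes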
def forward(self, inputs, dropout_path_rate):
    if self.skip_cell:
        # the whole cell is skipped; only the cell-wise shortcut remains
        out = self.shortcut_reduction_op(inputs)
        return out

    if not self.postprocess:
        states = [self.process_op(inputs)]
    else:
        states = [inputs]
    batch_size, _, height, width = states[0].shape
    o_height, o_width = height // self.stride, width // self.stride

    for to_ in range(self._num_init_nodes, self._num_nodes):
        state_to_ = torch.zeros(
            [
                batch_size,
                # if preprocess & use_next_stage_width, the ops already output
                # the next stage's width
                self.use_next_stage_width
                if (not self.postprocess and self.use_next_stage_width)
                else self.num_out_channels,
                o_height,
                o_width,
            ],
            device=states[0].device,
        )
        for from_ in range(to_):
            for op_ in self.edges[from_][to_].values():
                if isinstance(op_, ops.Zero):
                    continue
                out = op_(states[from_])
                if self.training and dropout_path_rate > 0:
                    if not isinstance(op_, ops.Identity):
                        out = utils.drop_path(out, dropout_path_rate)
                state_to_ = state_to_ + out
        states.append(state_to_)

    """
    (Maybe not so elegant) Extra operations are applied here to align the
    widths when `use_next_stage_width` is set: the last cell is a normal cell,
    so the cell-wise shortcut is an identity() and cannot change the channel
    dimension, while the convs inside already use the next stage's width,
    which leaves the widths misaligned. For the cell-wise shortcut, if
    ch_shortcut < ch_cell_out, the shortcut is only applied to the first
    ch_shortcut channels (and vice versa).
    """
    if self.use_concat:
        # concat all internal nodes
        out = torch.cat(states[self._num_init_nodes:], dim=1)
        if self.use_shortcut:
            assert not (self.is_last_cell and self.use_next_stage_width), (
                "is_last_cell and use_next_stage_width should not happen together"
            )
            if self.use_next_stage_width:
                # with use_next_stage_width the shortcut width may not match
                # the cell output, so only the shared channels are added
                shortcut = self.shortcut_reduction_op(inputs)
                if self.postprocess:
                    out = self.process_op(out)
                if shortcut.shape[1] > out.shape[1]:
                    out += shortcut[:, :out.shape[1], :, :]
                else:
                    out[:, :shortcut.shape[1], :, :] += shortcut
            else:
                if self.postprocess:
                    out = self.process_op(out)
                out = out + self.shortcut_reduction_op(inputs)
    else:
        out = sum(states[self._num_init_nodes:])
        if self.use_shortcut:
            out = out + self.shortcut_reduction_op(inputs)
    return out
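
# --- Hedged sketch (illustrative, mirrors the alignment note above) ----------
# When `use_next_stage_width` makes the shortcut and the cell output disagree
# in channel count, the code above adds the shortcut only over the channels
# both tensors share. A standalone restatement of that rule:
def add_partial_shortcut_sketch(out, shortcut):
    common = min(out.shape[1], shortcut.shape[1])
    out = out.clone()  # avoid mutating the caller's tensor in this sketch
    out[:, :common] = out[:, :common] + shortcut[:, :common]
    return out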