Example #1
    def forward(self, inputs, dropout_path_rate):  # pylint: disable=arguments-differ
        states = [
            op(_input) for op, _input in zip(self.preprocess_ops, inputs)
        ]
        batch_size, _, height, width = states[0].shape
        o_height, o_width = height // self.stride, width // self.stride

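        # each internal node (to_) sums the outputs of the ops on all edges from
        # earlier nodes; Zero edges contribute nothing and are skipped below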
        for to_ in range(self.num_init_nodes, self._num_nodes):
            state_to_ = torch.zeros(
                [batch_size, self.num_out_channels, o_height, o_width],
                device=states[0].device,
            )
            for from_ in range(to_):
                op_ = self.edges[from_][to_]
                if isinstance(op_, ops.Zero):
                    continue
                out = op_(states[from_])
                if self.training and dropout_path_rate > 0:
                    if not isinstance(op_, ops.Identity):
                        out = utils.drop_path(out, dropout_path_rate)
                state_to_ = state_to_ + out
            states.append(state_to_)

        # concat all internal nodes
        return torch.cat(states[self.num_init_nodes:], dim=1)
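`utils.drop_path` itself is not shown in these examples. For reference, here is a minimal sketch of the standard DARTS-style drop-path it presumably corresponds to (zero an entire sample's path with probability `drop_prob` and rescale the survivors); the function below is illustrative, not the codebase's actual implementation:

import torch

def drop_path(x, drop_prob):
    """Per-sample path dropping: zero the whole path for a sample with
    probability `drop_prob`, and scale surviving samples by 1 / keep_prob."""
    if drop_prob <= 0.:
        return x
    keep_prob = 1. - drop_prob
    # one Bernoulli draw per sample, broadcast over channels and spatial dims
    mask = torch.empty(x.shape[0], 1, 1, 1, device=x.device).bernoulli_(keep_prob)
    return x / keep_prob * mask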
Example #2
    def forward(self, inputs, dropout_path_rate):  # pylint: disable=arguments-differ
        assert self._num_init == len(inputs)
        states = [
            op(_input) for op, _input in zip(self.preprocess_ops, inputs)
        ]

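        # `_num_conn` counts how many connections with the same (from_, to_, op_type)
        # have already been used, so that repeated connections pick separate op
        # instances when `independent_conn` is set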
        _num_conn = defaultdict(int)
        for to_, connections in self.conns_grouped:
            state_to_ = 0.
            for op_type, from_, _ in connections:
                conn_ind = 0 if not self.independent_conn else _num_conn[(
                    from_, to_, op_type)]
                op = self.edges[from_][to_][op_type][conn_ind]
                _num_conn[(from_, to_, op_type)] += 1
                out = op(states[from_])
                if self.training and dropout_path_rate > 0:
                    if not isinstance(op, ops.Identity):
                        out = utils.drop_path(out, dropout_path_rate)
                state_to_ = state_to_ + out
            states.append(state_to_)

        out = self.concat_op([states[ind] for ind in self.concat_nodes])
        if self.use_shortcut and self.layer_index != 0:
            out = out + self.shortcut_reduction_op(inputs[-1])

        return out
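The loop above unpacks `self.conns_grouped` as pairs of a destination node and its incoming connections, each connection being an `(op_type, from_node, ...)` tuple. A hypothetical layout, with made-up op names and node indices purely for illustration:

conns_grouped = [
    # (to_node, [(op_type, from_node, _), ...]); the third element is unused above
    (2, [("sep_conv_3x3", 0, 0), ("skip_connect", 1, 1)]),
    (3, [("dil_conv_5x5", 2, 2), ("max_pool_3x3", 0, 3)]),
]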
Example #3
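    # Step-wise cell forward (compare the `forward` in Example #2): each call executes
    # a single preprocess op, connection op, or the final concat, with progress tracked
    # in the `context` object (next_step_index, next_op_index, current_cell, previous_op, ...)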
    def forward_one_step(self, context, dropout_path_rate):
        to_ = cur_step = context.next_step_index[1]
        if cur_step == 0:
            context._num_conn[self] = defaultdict(int)
        if cur_step < self._num_init:  # `self._num_init` preprocess steps
            ind = len(context.previous_cells) - (self._num_init - cur_step)
            ind = max(ind, 0)
            # state = self.preprocess_ops[cur_step](context.previous_cells[ind])
            # context.current_cell.append(state)
            # context.last_conv_module = self.preprocess_ops[cur_step].get_last_conv_module()
            current_op = context.next_op_index[1]
            state, context = self.preprocess_ops[cur_step].forward_one_step(
                context=context,
                inputs=context.previous_cells[ind] if current_op == 0 else None)
            if context.next_op_index[1] == 0:
                # this preprocess op is finished; append its output to `current_cell`
                assert len(context.previous_op) == 1
                context.current_cell.append(context.previous_op[0])
                context.previous_op = []
                context.last_conv_module = \
                    self.preprocess_ops[cur_step].get_last_conv_module()
        elif cur_step < self._num_init + self._steps:  # the following steps
            conns = self.conns_grouped[cur_step - self._num_init][1]
            op_ind, current_op = context.next_op_index
            if op_ind == len(conns):
                # all connections have been added to `context.previous_op`; sum them up
                state = sum(context.previous_op)
                context.current_cell.append(state)
                context.previous_op = []
            else:
                op_type, from_, _ = conns[op_ind]
                conn_ind = 0 if not self.independent_conn else \
                    context._num_conn[self][(from_, to_, op_type)]
                op = self.edges[from_][to_][op_type][conn_ind]
                state, context = op.forward_one_step(
                    context=context,
                    inputs=context.current_cell[from_] if current_op == 0 else None)
                if self.training and dropout_path_rate > 0:
                    if not isinstance(op, ops.Identity):
                        context.last_state = state = utils.drop_path(
                            state, dropout_path_rate)
                if context.next_op_index[0] != op_ind:
                    # this op is finished
                    context._num_conn[self][(from_, to_, op_type)] += 1
        else:  # final concat
            state = self.concat_op(
                [context.current_cell[ind] for ind in self.concat_nodes])
            context.current_cell = []
            context.previous_cells.append(state)
        return state, context
Example #4
    def forward(self, inputs, dropout_path_rate):

        if self.skip_cell:
            out = self.shortcut_reduction_op(inputs)
            return out

        if not self.postprocess:
            states = [self.process_op(inputs)]
        else:
            states = [inputs]

        batch_size, _, height, width = states[0].shape
        o_height, o_width = height // self.stride, width // self.stride

        for to_ in range(self._num_init_nodes, self._num_nodes):
            state_to_ = torch.zeros(
                [
                    batch_size,
                    # when the cell does not postprocess and `use_next_stage_width`
                    # is set, the node states use that width (a channel count)
                    self.use_next_stage_width if
                    (not self.postprocess
                     and self.use_next_stage_width) else self.num_out_channels,
                    o_height,
                    o_width,
                ],
                device=states[0].device,
            )
            for from_ in range(to_):
                for op_ in self.edges[from_][to_].values():
                    if isinstance(op_, ops.Zero):
                        continue
                    out = op_(states[from_])
                    if self.training and dropout_path_rate > 0:
                        if not isinstance(op_, ops.Identity):
                            out = utils.drop_path(out, dropout_path_rate)
                    state_to_ = state_to_ + out
            states.append(state_to_)
        """
        (Maybe not so elegant)
        extra operations are applied here to align the width when 'use-next-stage-width'
        since the last cell is the normal cell, the cell-wise shortcut is the identity(),
        which could not change the dimension. but the conv inside should use-next-stage-width,
        which causes disalignment in width. for cell-wise shortcut, if ch_shortcut < ch_cell_out,
        only applying shortcut to former chs.
        """
        if self.use_concat:
            # concat all internal nodes
            out = torch.cat(states[self._num_init_nodes:], dim=1)
            if self.use_shortcut:
                assert not (self.is_last_cell and self.use_next_stage_width), \
                    "is_last_cell and use_next_stage_width should not happen together"
                if self.use_next_stage_width:
                    # with `use_next_stage_width`, the shortcut and the cell output can
                    # have different channel counts, so the shortcut is added only over
                    # the overlapping leading channels
                    shortcut = self.shortcut_reduction_op(inputs)
                    if self.postprocess:
                        out = self.process_op(out)
                    if shortcut.shape[1] > out.shape[1]:
                        out += shortcut[:, :out.shape[1], :, :]
                    else:
                        out[:, :shortcut.shape[1], :, :] += shortcut
                else:
                    if self.postprocess:
                        out = self.process_op(out)
                    out = out + self.shortcut_reduction_op(inputs)
        else:
            out = sum(states[self._num_init_nodes:])
            if self.use_shortcut:
                out = out + self.shortcut_reduction_op(inputs)

        return out
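The partial-channel shortcut above only sums the overlapping leading channels when the shortcut and the cell output have different widths. A small self-contained illustration of that slicing, with arbitrarily chosen tensor shapes:

import torch

out = torch.randn(2, 48, 8, 8)       # cell output, e.g. next-stage width
shortcut = torch.randn(2, 32, 8, 8)  # shortcut path with a narrower width
if shortcut.shape[1] > out.shape[1]:
    out = out + shortcut[:, :out.shape[1]]
else:
    # add the shortcut to the first `shortcut.shape[1]` channels of `out`
    out[:, :shortcut.shape[1]] += shortcut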