Exemplo n.º 1
0
    def forward_one_step(self, context, inputs, genotypes, **kwargs):
        """Run a single incremental step of the network.

        Exactly one of ``context`` and ``inputs`` must be given.

        * First step (``inputs`` given, ``context is None``): apply the stem
          when ``self.use_stem`` is truthy, otherwise pass ``inputs`` through
          unchanged. Then build a fresh ``Context`` seeded with that output and
          return ``(stemed, context)``.
        * Later steps (``context`` given, ``inputs is None``): read the next
          cell index from ``context.next_step_index``.
          - If it equals ``self._num_layers``, apply the head (global pooling
            -> dropout -> classifier) to the last entry of
            ``context.previous_cells``, append the logits to that list and
            return ``(logits, context)``.
          - Otherwise delegate to ``self.cells[idx].forward_one_step`` with the
            genotype chosen through ``self._cell_layout``.

        Args:
            context: stepping state returned by a previous call, or ``None``
                to start a new pass.
            inputs: network input for the first step, or ``None`` afterwards.
            genotypes: indexed by ``self._cell_layout[cell_index]`` to select
                the genotype of the current cell.
            **kwargs: forwarded unchanged to the cell's ``forward_one_step``.

        Returns:
            An ``(output, context)`` pair. ``output`` is the stem output, the
            logits, or whatever the cell step returns, depending on the step.

        Raises:
            AssertionError: if both or neither of ``context`` / ``inputs``
                are provided.
        """
        assert (context is None) != (inputs is None), \
            "exactly one of `context` and `inputs` must be provided"

        # stem
        if inputs is not None:
            if self.use_stem:
                stemed = self.stem(inputs)
                # The stem's own last conv is the most recent conv applied.
                last_conv = self.stem.get_last_conv_module()
            else:
                # The stem was not applied, so it must not be consulted: it
                # may be absent, and any conv it holds never touched `stemed`.
                stemed = inputs
                last_conv = None
            context = Context(self._num_init,
                              self._num_layers,
                              previous_cells=[stemed],
                              current_cell=[])
            context.last_conv_module = last_conv
            return stemed, context

        cur_cell_ind, _ = context.next_step_index

        # final: pooling -> dropout -> classifier
        if cur_cell_ind == self._num_layers:
            out = self.global_pooling(context.previous_cells[-1])
            out = self.dropout(out)
            # Keep dim 0 (presumably the batch) and flatten everything else.
            logits = self.classifier(out.view(out.size(0), -1))
            context.previous_cells.append(logits)
            return logits, context

        # cells
        cell_genotype = genotypes[self._cell_layout[cur_cell_ind]]
        return self.cells[cur_cell_ind].forward_one_step(
            context, cell_genotype, **kwargs)
Exemplo n.º 2
0
def test_context_last_conv_module():
    """Step several ops one sub-op at a time and check ``last_conv_module``.

    DilConv, SepConv and two ResNetBlock variants are driven through
    ``forward_one_step``. The ResNet variants are stride 1 with 3 -> 3
    channels, and stride 2 with 3 -> 16 channels.

    After every step the context must reference the expected submodule
    (identity check) or ``None``. For the ResNet blocks, the step-wise result
    must also equal an ordinary forward call on the same input, element-wise.
    """
    from aw_nas.ops import DilConv, SepConv, ResNetBlock
    from aw_nas.utils.common_utils import Context

    sample = _cnn_data()[0]

    def walk(op, expected_first, expected_rest):
        """Feed ``sample`` once, then keep stepping from the context alone."""
        ctx = Context(0, 1, use_stem=False)
        _, ctx = op.forward_one_step(context=ctx, inputs=sample)
        assert ctx.last_conv_module is expected_first
        for expected in expected_rest:
            _, ctx = op.forward_one_step(context=ctx)
            assert ctx.last_conv_module is expected

    def walk_resnet(block, expected_modules):
        """Step ``block`` once per entry; real ``inputs`` only on step 0."""
        ctx = Context(0, 1, use_stem=False)
        state = None
        for step, expected in enumerate(expected_modules):
            state, ctx = block.forward_one_step(
                context=ctx, inputs=sample if step == 0 else None)
            assert ctx.last_conv_module is expected
        return state, ctx

    dil_conv = DilConv(3, 16, 3, 1, 1, 1).to("cuda")
    walk(dil_conv, dil_conv.op[1], [dil_conv.op[2]])

    sep_conv = SepConv(3, 16, 3, 1, 1).to("cuda")
    walk(sep_conv, sep_conv.op[1], [sep_conv.op[i] for i in (2, 5, 6)])

    plain = ResNetBlock(3, 3, 1, True).to("cuda")
    plain.train()
    reference = plain(sample)
    state, ctx = walk_resnet(
        plain, [plain.op_1.op[0], plain.op_2.op[0], None, None])
    assert ctx.is_end_of_op
    assert (state == reference).all()

    strided = ResNetBlock(3, 16, 2, True).to("cuda")
    reference = strided(sample)
    state, _ = walk_resnet(strided, [
        strided.op_1.op[0], strided.op_2.op[0], strided.skip_op.op[0], None
    ])
    assert (state == reference).all()