Exemplo n.º 1
0
class TUM(chainer.Chain):
    """Thinned U-shape Module: an encoder/decoder producing multi-scale maps.

    The encoder applies ``scales - 1`` convolutions (3x3, stride 2) to the
    input. The decoder then starts from the deepest encoder output and moves
    back up. At each level it convolves the running feature, resizes it to
    the spatial size of the matching encoder feature, and adds that feature.
    At the input level, the input itself is first projected to 256 channels
    by its own convolution. Every decoder level is finally projected to 128
    channels by a 1x1 convolution.

    Args:
        inplanes (int): Number of channels of the input ``x``.
        scales (int): Number of feature maps returned. Must be at least 2.
        mingrid (int): When 1, the deepest encoder convolution is built
            without an explicit ``pad`` argument, so it uses
            ``Conv2DBNActiv``'s default padding. Any other value uses
            ``pad=1`` for every encoder convolution.

    Raises:
        ValueError: If ``scales`` is less than 2.
    """

    def __init__(self, inplanes, scales=6, mingrid=1):
        super(TUM, self).__init__()

        if scales < 2:
            # With fewer than two scales there is no encoder stage: the
            # decoder would index self.dcs[-1] (the lateral input conv) and
            # feed the raw input to a 256-channel smoothing conv.
            raise ValueError(
                'scales must be at least 2, got {}'.format(scales))

        self.scales = scales

        with self.init_scope():
            ecs = []
            for s in range(scales - 1):
                # Only the first encoder conv consumes the input channels.
                in_ch = inplanes if s == 0 else 256
                # The padding rule for the deepest conv is decided
                # independently of in_ch, so it still applies when the
                # deepest conv is also the first one (scales == 2).
                if s == scales - 2 and mingrid == 1:
                    conv = Conv2DBNActiv(in_ch, 256, 3, 2, nobias=True)
                else:
                    conv = Conv2DBNActiv(in_ch, 256, 3, 2, pad=1,
                                         nobias=True)
                ecs.append(conv)
            self.ecs = ChainList(*ecs)

            # dcs[0 .. scales-2] act on the running decoder feature (256 ch).
            dcs = [Conv2DBNActiv(256, 256, 3, pad=1, nobias=True)
                   for _ in range(scales - 1)]
            # dcs[scales-1] projects the raw input x to 256 channels.
            dcs.append(Conv2DBNActiv(inplanes, 256, 3, pad=1, nobias=True))
            self.dcs = ChainList(*dcs)

            # One 1x1 smoothing conv per output scale: 256 -> 128 channels.
            self.scs = ChainList(*[
                Conv2DBNActiv(256, 128, 1, nobias=True) for _ in range(scales)
            ])

    def __call__(self, x):
        """Run the encoder and decoder and return ``scales`` feature maps.

        Args:
            x: Input feature map with ``inplanes`` channels. Its spatial size
                is read from ``shape[2]`` and ``shape[3]`` (presumably NCHW;
                confirm against callers).

        Returns:
            list: ``self.scales`` 128-channel maps. The first is at ``x``'s
            resolution and the last is at the deepest encoder resolution.
        """
        # Encoder: es[0] is the input, es[i] is the output of the i-th conv.
        es = [x]
        for conv in self.ecs.children():
            es.append(conv(es[-1]))

        # Decoder: ds[0] is the deepest encoder feature, unchanged.
        d = es[-1]
        ds = [d]
        for s in range(self.scales - 2):
            skip = es[-(s + 2)]
            d = F.resize_images(
                self.dcs[s](d), (skip.shape[2], skip.shape[3])) + skip
            ds.append(d)
        # At the input level, x has inplanes channels, so it goes through
        # the dedicated lateral conv dcs[-1] before the addition.
        d = F.resize_images(
            self.dcs[self.scales - 2](d),
            (x.shape[2], x.shape[3])) + self.dcs[self.scales - 1](x)
        ds.append(d)

        ys = [self.scs[s](ds[s]) for s in range(self.scales)]
        # ds runs deepest -> shallowest; return shallowest first.
        return ys[::-1]
 def __call__(self, X):
     """Feed X through every child link in registration order.

     ChainList.children is called explicitly on the base class, so a
     subclass override of children() does not affect the iteration.
     Each link's output becomes the next link's input; the last output
     is returned (X itself if there are no children).
     """
     h = X
     layers = ChainList.children(self)
     for f in layers:
         h = f(h)
     return h