예제 #1
0
    def InitializeNetwork(self):
        """Construct the fixed modules and the entries of ``self.models``.

        Side effects:
          * sets ``self.vgg``, ``self.stn`` and ``self.stn_destroy``;
          * sets ``self.models`` to a dict with the keys ``'seg_feat_src'``,
            ``'seg_feat_tgt'``, ``'seg_pred'`` and ``'flow'``;
          * prints every module in ``self.models``, then one line per module
            with its total parameter count in millions.

        Reads ``self.cuda_id`` (entries 0 and 1), ``self.SIZE`` and the
        ``flow_*`` options on ``self.opt``. Everything except the flow U-Net
        receives ``self.cuda_id[0]``; the flow U-Net receives
        ``self.cuda_id[1]`` (presumably device indices -- TODO confirm).
        """
        main_device = self.cuda_id[0]

        # These modules live outside ``self.models``, so the summary printed
        # at the end neither shows them nor counts their parameters.
        self.vgg = Network.vgg16_features(cuda_id=main_device)
        self.stn = Network.STN_Flow_relative(size=self.SIZE, cuda_id=main_device)
        self.stn_destroy = Network.STN_Flow_relative(size=self.SIZE,
                                                     cuda_id=main_device)

        # Bind the dict to ``self.models`` before filling it, so a failing
        # constructor leaves the already-built entries visible on ``self``.
        nets = self.models = {}
        # Two independently constructed feature branches, keyed src / tgt.
        nets['seg_feat_src'] = Network.DRIU_novgg_siamese_feat(
            with_relu=False, cuda_id=main_device)
        nets['seg_feat_tgt'] = Network.DRIU_novgg_siamese_feat(
            with_relu=False, cuda_id=main_device)
        nets['seg_pred'] = Network.DRIU_novgg_siamese_seg(
            with_relu=True, with_sigmoid=True, cuda_id=main_device)

        opt = self.opt
        # NOTE(review): input_channels=128 presumably matches the combined
        # src/tgt feature width -- TODO confirm against the DRIU feature nets.
        nets['flow'] = UNet.UNetFlow(
            down_scales=opt.flow_feat_scales,
            output_scale=opt.flow_scale_times,
            num_filters_base=opt.flow_base_filters,
            max_filters=opt.flow_max_filters,
            input_channels=128,
            output_channels=2,
            downsampling=opt.flow_downsample,
            cuda_id=self.cuda_id[1])

        # First dump every architecture, then report the sizes.
        for net in nets.values():
            print(net)
        for name, net in nets.items():
            total = sum(p.numel() for p in net.parameters())
            print('[Network %s] Total number of parameters : %.3f M' %
                  (name, total / 1e6))