def forward(self, x, segments):
        """Predict per-pixel cluster embeddings and their clustering losses.

        Args:
            x: Input image batch; presumably (N, C, H, W) — TODO confirm.
            segments: Segment labels forwarded unchanged to ``cluster_loss``.

        Returns:
            Tuple ``(y, loss_var, loss_dis, loss_reg)``: the output of the
            ``class1``..``class4`` head, and the three terms returned by
            ``cluster_loss``, each reshaped to ``(y.shape[0], 1)``.
        """
        x = self.relu(self.gn1(self.conv1(x)))
        x = self.layer1(x)
        x = self.full1(x)
        # Side branch: concatenated with the decoder output before the head.
        skip = self.full2(x)
        x = self.layer2(x)
        # layer3 is not applied in this variant.
        x = self.layer4(x)
        x = self.pyramid_pooling(x)
        x = self.deconv0(x)
        x = self.deconv1(x)
        x = self.deconv2(x)
        x = self.regress1(x)

        y = torch.cat((x, skip), 1)
        for head in (self.class1, self.class2, self.class3, self.class4):
            y = head(y)

        losses = cluster_loss(y, segments)
        loss_var, loss_dis, loss_reg = (
            loss.reshape((y.shape[0], 1)) for loss in losses)
        return y, loss_var, loss_dis, loss_reg
Exemplo n.º 2
0
    def forward(self, x, segments, labels, flag, task):
        """Predict depth, segment-refined depth and clustering losses.

        Args:
            x: Input image batch. ``segments`` is reshaped to ``(1, 1, H, W)``
                below, so a batch size of 1 is assumed — TODO confirm.
            segments: Per-pixel segment labels with ``H * W`` elements at the
                resolution of the shared decoder features. Labels ``<= 0``
                are treated as background; labels ``1..80`` are one-hot
                encoded for the depth regressor.
            labels: Unused; kept for interface compatibility.
            flag: ``0`` selects the full refinement path; any other value
                selects a ``task``-specific path.
            task: ``'train'``, ``'test'`` or ``'eval'`` (only read when
                ``flag != 0``).

        Returns:
            flag == 0:    ``(initial_depth, accurate_depth, loss_var, loss_dis, loss_reg)``
            task 'train': ``(depth, masks, loss_var, loss_dis, loss_reg)``, where
                          ``masks`` comes from ``fast_cluster`` on the embeddings.
            task 'test':  ``(depth, loss_var, loss_dis, loss_reg)``
            task 'eval':  ``(depth, depth)``
            Every loss is reshaped to ``(batch, 1)``.

        Raises:
            ValueError: if ``segments`` contains a label above 80, or if
                ``flag != 0`` and ``task`` is not recognised.
        """
        x_share = self._shared_features(x)
        height, width = x_share.shape[-2:]
        # Normalised pixel coordinates. They are built at the x_share
        # resolution; the concatenation requires this to match anyway.
        coords = self._coordinate_map(x_share.shape[0], height, width,
                                      x_share.device)
        y = self._embed_clusters(torch.cat([x_share, coords], 1))

        with torch.no_grad():
            masks = segments.view(1, 1, height, width)
            mask_volume = self._one_hot_segments(masks)
        # NOTE(review): the original flag != 0 branches read `y` and `depth`
        # without ever assigning them (UnboundLocalError). Both are now computed
        # here and shared by every branch; confirm this matches the intent.
        depth = self._regress_depth(x_share, mask_volume)

        if flag == 0:
            initial_depth = depth.clone()
            with torch.no_grad():
                coarse_depth, coarse_feature = self._segment_mean_maps(
                    masks, depth, x_share)
            inner_variance = self._inner_refine(coarse_depth, x_share,
                                                coarse_feature)
            reliable_variance = self._inner_refine(depth, x_share,
                                                   coarse_feature)
            # `cuda_id` is not defined in this method; presumably a
            # module-level global.
            loss_var, loss_dis, loss_reg = self._cluster_losses(
                y, segments, device_id=cuda_id)
            coarse_refined = self.output(inner_variance + coarse_depth)
            depth_refined = self.output(reliable_variance + depth)
            # Average both refinements inside labelled segments. Keep the
            # initial prediction on background pixels (label <= 0).
            accurate_depth = torch.where(
                masks > 0, (depth_refined + coarse_refined) / 2, initial_depth)
            return initial_depth, accurate_depth, loss_var, loss_dis, loss_reg

        if task == 'train':
            with torch.no_grad():
                predicted_masks = fast_cluster(y).view_as(depth)
            loss_var, loss_dis, loss_reg = self._cluster_losses(y, segments)
            return depth, predicted_masks, loss_var, loss_dis, loss_reg
        if task == 'test':
            loss_var, loss_dis, loss_reg = self._cluster_losses(y, segments)
            return depth, loss_var, loss_dis, loss_reg
        if task == 'eval':
            # The original eval path built segment means and an inner
            # refinement, then discarded the result (accurate_depth was
            # reassigned to depth). Its refinement input also used a different
            # channel layout from the flag == 0 path. Only the surviving
            # result is kept.
            return depth, depth
        raise ValueError(f"unknown task {task!r} for flag={flag!r}")

    def _shared_features(self, x):
        """Run the encoder/decoder trunk and return the skip-fused features."""
        x = self.conv2(self.conv1(x))
        x1 = self.layer1(x)
        x2 = self.layer2(x1)
        x = self.layer4(self.layer3(x2))
        x = self.pyramid_pooling(x)
        x = self.fuse1(self.fuse0(x))
        x = self.deconv1(x)
        x = self.fuse2(torch.cat((x, x2), 1))
        x = self.deconv2(x)
        return self.fuse3(torch.cat((x, x1), 1))

    def _embed_clusters(self, features):
        """Apply the class0..class4 clustering-embedding head."""
        for layer in (self.class0, self.class1, self.class2, self.class3,
                      self.class4):
            features = layer(features)
        return features

    def _regress_depth(self, features, mask_volume):
        """Regress depth from features concatenated with the one-hot masks."""
        out = torch.cat([features, mask_volume], 1)
        for layer in (self.regress1, self.regress2, self.regress3,
                      self.regress4):
            out = layer(out)
        return out

    def _inner_refine(self, depth_map, features, coarse_feature):
        """Run inrefine1..inrefine5 on [depth_map, features, coarse_feature]."""
        out = torch.cat([depth_map, features, coarse_feature], 1)
        for layer in (self.inrefine1, self.inrefine2, self.inrefine3,
                      self.inrefine4, self.inrefine5):
            out = layer(out)
        return out

    @staticmethod
    def _coordinate_map(batch, height, width, device):
        """Return a (batch, 2, H, W) grid: channel 0 = col/W, channel 1 = row/H."""
        # Float arange so the division is true division on every torch
        # version. Older releases floor-divide integer tensors, giving zeros.
        cols = torch.arange(width, dtype=torch.float32, device=device) / width
        rows = torch.arange(height, dtype=torch.float32, device=device) / height
        grid = torch.stack([cols.view(1, width).expand(height, width),
                            rows.view(height, 1).expand(height, width)])
        return grid.unsqueeze(0).expand(batch, -1, -1, -1)

    @staticmethod
    def _one_hot_segments(masks, num_segments=80):
        """One-hot encode labels 1..num_segments into a (1, num_segments, H, W) volume.

        Channel ``k`` is 1.0 where ``masks == k + 1``. Labels <= 0 produce
        all-zero columns.

        Raises:
            ValueError: if any label exceeds ``num_segments`` (the channel
                count regress1 is built for).
        """
        max_label = int(masks.max().item())
        if max_label > num_segments:
            raise ValueError(
                f"segment label {max_label} exceeds the supported maximum "
                f"of {num_segments}")
        ids = torch.arange(1, num_segments + 1,
                           device=masks.device).view(1, -1, 1, 1)
        return (masks == ids).float()

    @staticmethod
    def _segment_mean_maps(masks, depth, features):
        """Replace each pixel of every positive segment with that segment's mean.

        Returns ``(coarse_depth, coarse_feature)``. Pixels with label <= 0
        keep their original values.
        """
        coarse_depth = depth.clone()
        coarse_feature = features.clone()
        batch, channels = features.shape[:2]
        for label in range(1, int(masks.max().item()) + 1):
            region = masks == label
            weight = region.float()
            # An absent label gives 0/0 here, but `region` is then all-False,
            # so torch.where never selects the NaN.
            area = weight.sum()
            mean_d = (weight * depth).sum() / area
            coarse_depth = torch.where(region, mean_d, coarse_depth)
            mean_f = (weight * features).view(batch, channels, -1).sum(
                dim=-1) / area
            coarse_feature = torch.where(
                region, mean_f.view(batch, channels, 1, 1), coarse_feature)
        return coarse_depth, coarse_feature

    @staticmethod
    def _cluster_losses(y, segments, **kwargs):
        """Call cluster_loss and reshape each returned term to (batch, 1)."""
        losses = cluster_loss(y, segments.long(), **kwargs)
        return tuple(loss.reshape((y.shape[0], 1)) for loss in losses)