def forward(self, ed, es, source):
        """Register ``ed`` and ``es`` onto ``source`` with two UNet branches.

        Each branch does the following:
        - concatenates its moving image with ``source`` along dim 1;
        - runs the result through its own UNet (``self.unet1`` for ``ed``,
          ``self.unet2`` for ``es``);
        - predicts a flow with the shared ``self.flow`` head;
        - warps the moving image with ``layer.SpatialTransformer``.

        Args:
            ed: moving image of the first branch.
            es: moving image of the second branch.
            source: fixed image that both inputs are registered to. Assumed to
                be a 5-D (batch, channel, D, H, W) volume — TODO confirm.

        Returns:
            ``(flow_field_x1, ed_source, flow_field_x2, es_source)``: the flow
            and warped ``ed`` from the first branch, then the flow and warped
            ``es`` from the second branch.
        """
        def register(unet, moving):
            # Predict a flow from (moving, source) and warp ``moving`` with it.
            flow_field = self.flow(unet(torch.cat([moving, source], dim=1)))
            # Size the transformer grid from the flow itself. The previous
            # hard-coded (depth, 128, 128) only worked for 128x128 inputs;
            # for those inputs the result is unchanged.
            warped = layer.SpatialTransformer(tuple(flow_field.shape[2:]))(
                moving, flow_field)
            return flow_field, warped

        flow_field_x1, ed_source = register(self.unet1, ed)
        flow_field_x2, es_source = register(self.unet2, es)
        return flow_field_x1, ed_source, flow_field_x2, es_source
    def forward(self, source, target):
        """Register ``source`` onto ``target`` and return the warped image.

        Args:
            source: moving image. Assumed to be a 5-D (batch, channel, D, H, W)
                volume — TODO confirm.
            target: fixed image, concatenated with ``source`` along dim 1.

        Returns:
            ``(y_source, flow_field)``: ``source`` warped by the predicted
            flow, and the flow itself.
        """
        x = self.unet(torch.cat([source, target], dim=1))
        flow_field = self.flow(x)
        # Take the grid size from the flow rather than the hard-coded
        # (depth, 192, 192), so the transformer always matches the flow.
        y_source = layer.SpatialTransformer(tuple(flow_field.shape[2:]))(
            source, flow_field)
        return y_source, flow_field
                       criterion_ncc(img_pre[:, :, 1:-1, mi3:mx3, mi4:mx4], step3_flow[:, :, 1:-1, mi3:mx3, mi4:mx4])+ \
                       criterion_ncc(img_ed[:, :, 1:-1, mi3:mx3, mi4:mx4], step4_flow[:, :, 1:-1, mi3:mx3, mi4:mx4]))/4)/2
        ##############
        # Average the four 3-channel groups of speed_field (channels 0-11)
        # into a single 3-channel field.
        speed_combin = 0.25 * (speed_field[:, 0:3, :, :, :] + speed_field[:, 3:6, :, :, :] + \
                               speed_field[:, 6:9, :, :, :] + speed_field[:, 9:12, :, :, :])
        # Standardize flow_field and speed_combin over all of their elements
        # (subtract the mean, divide by the std), then compare them with
        # criterion_MSE.
        # NOTE(review): .std() of a constant tensor is 0, which yields inf/NaN.
        fm = flow_field - flow_field.mean()
        f = fm / fm.std()
        sm = speed_combin - speed_combin.mean()
        s = sm / sm.std()
        loss_com = criterion_MSE(f, s)
        #####################
        # Take batch element 0 of labeles, split channels 0-3 into separate
        # (1, 1, ...) volumes, and warp each of them with flow_field.
        # NOTE(review): the transformer size is hard-coded to
        # (depth, 128, 128). It must equal flow_field's spatial shape — confirm.
        w0 = torch.unsqueeze(torch.unsqueeze(labeles[0, 0, :, :, :], 0), 0)
        w1 = torch.unsqueeze(torch.unsqueeze(labeles[0, 1, :, :, :], 0), 0)
        w2 = torch.unsqueeze(torch.unsqueeze(labeles[0, 2, :, :, :], 0), 0)
        w3 = torch.unsqueeze(torch.unsqueeze(labeles[0, 3, :, :, :], 0), 0)
        w0 = layer.SpatialTransformer((depth, 128, 128))(w0, flow_field)
        w1 = layer.SpatialTransformer((depth, 128, 128))(w1, flow_field)
        w2 = layer.SpatialTransformer((depth, 128, 128))(w2, flow_field)
        w3 = layer.SpatialTransformer((depth, 128, 128))(w3, flow_field)
        # print(flfinseg0.shape)
        # Softmax across the four warped channels, then a Dice loss against
        # labeled. Both exclude the first and last index of dim 2
        # (presumably the depth axis).
        ws = F.softmax(torch.cat([w0, w1, w2, w3], dim=1), dim=1)
        loss_fse = criterion_dice(ws[:, :, 1:-1, :, :], labeled[:, :,
                                                                1:-1, :, :])

        ###############
        # Total registration loss. loss_smooth and loss_reg_ncc are computed
        # outside this excerpt.
        loss_reg = loss_smooth + loss_com + loss_reg_ncc + 0.01 * loss_fse
        loss_continous_motion = loss_reg_ncc
        flow_field_reg = flow_field
        ################################
        # NOTE(review): loss_seg.backward() runs without retain_graph, so its
        # graph is freed. If loss_reg shares any of those nodes, the next call
        # raises "Trying to backward through the graph a second time".
        # retain_graph=True normally belongs on the FIRST of two backward
        # calls, not the last — confirm the intended order.
        loss_seg.backward()
        loss_reg.backward(retain_graph=True)
    def forward(self, start, final, pre_img, mid_img, aft_img, labeled,
                labeles):
        """Register ``start`` onto ``final`` and fuse both frames' segmentations.

        The steps are:
        1. ``self.reg(start, final)`` yields ``(fsim, fsflow)``.
        2. ``start`` is warped by 0.25, 0.5 and 0.75 of ``fsflow``, and
           ``final`` by ``-fsflow``.
        3. Each frame is segmented by ``self.unet``. The other frame's
           segmentation is warped (start's by ``fsflow``, final's by
           ``-fsflow``) and concatenated with the frame's own segmentation.
        4. The result is refined by ``self.v1`` -> ``self.v2`` -> softmax
           over dim 1.
        5. The label volumes are warped the same way.

        Channel-wise warps take batch element 0 and channels 0-3 only. The
        ``torch.cat`` with the unwarped segmentations therefore fails unless
        the batch size is 1.

        Args:
            start: first frame, registered onto ``final`` by ``self.reg``.
            final: last frame.
            pre_img: not used by this method (kept for the caller's interface).
            mid_img: not used by this method.
            aft_img: not used by this method.
            labeled: label volume warped by ``-fsflow`` (returned as
                ``flowst``).
            labeles: label volume warped by ``fsflow`` (returned as
                ``flowfn``).

        Returns:
            ``(fsim, fake_es, preimflow, midimflow, afimflow, fsflow,
            startseg, finalseg, flowfn, flowst)``.
        """
        fsim, fsflow = self.reg(start, final)
        # Build one transformer, sized from the flow, and reuse it for every
        # warp below. Previously each call built a fresh instance with a
        # hard-coded (depth, 192, 192) size; all warps use a flow of this shape.
        warp = layer.SpatialTransformer(tuple(fsflow.shape[2:]))
        # NOTE(review): negating a displacement field only approximates its
        # inverse (closely only for small, smooth flows) — confirm intended.
        back_flow = -1 * fsflow

        preimflow = warp(start, 0.25 * fsflow)
        midimflow = warp(start, 0.5 * fsflow)
        afimflow = warp(start, 0.75 * fsflow)
        fake_es = warp(final, back_flow)

        def warp_channels(volume, flow):
            # Warp channels 0-3 of batch element 0 one at a time with the
            # same flow, then concatenate the results along dim 1.
            return torch.cat(
                [warp(volume[0:1, c:c + 1], flow) for c in range(4)], dim=1)

        startseg1 = self.unet(start)
        flfinseg = warp_channels(startseg1, fsflow)

        finalseg1 = self.unet(final)
        flstaseg = warp_channels(finalseg1, back_flow)

        # Fuse each frame's own segmentation with the other frame's warped
        # segmentation, then refine it with v1 -> v2 -> softmax.
        startseg = F.softmax(
            self.v2(self.v1(torch.cat([startseg1, flstaseg], dim=1))), dim=1)
        finalseg = F.softmax(
            self.v2(self.v1(torch.cat([finalseg1, flfinseg], dim=1))), dim=1)

        # Warp the labels: labeles with fsflow, labeled with the negated flow.
        flowfn = warp_channels(labeles, fsflow)
        flowst = warp_channels(labeled, back_flow)

        return (fsim, fake_es, preimflow, midimflow, afimflow, fsflow,
                startseg, finalseg, flowfn, flowst)
        # Move the batch tensors to `device` as float32.
        img_ed = img_ed.to(device).float()
        img_es = img_es.to(device).float()
        labeled = labeled.to(device).float()
        labeles = labeles.to(device).float()
        source = source.to(device).float()
        # mi3, mx3, mi4, mx4 = criterion.location(labeles, labeled, 5, 4)
        depth = img_ed.shape[2]
        # ---- Segment img_ed, img_es and source with the same Segnet ----
        ed_seg = Segnet(img_ed)
        es_seg = Segnet(img_es)
        seg = Segnet(source)

        # Flownet runs under no_grad, so gradients from the warps below reach
        # Segnet but not Flownet.
        # NOTE(review): the inputs are passed as (img_es, img_ed, source). If
        # Flownet.forward takes (ed, es, source), then flow_field_x1 is
        # estimated from img_es but is used below to warp ed_seg (from
        # img_ed). flow_field_x2 is likewise paired with es_seg. Confirm that
        # this swap is intended.
        # The ed_source returned here is overwritten by the torch.cat below
        # without being read.
        with torch.no_grad():
            flow_field_x1, ed_source, flow_field_x2, es_source = Flownet(
                img_es, img_ed, source)
        # Warp each of channels 0-3 of ed_seg separately with flow_field_x1
        # and concatenate them along dim 1.
        # NOTE(review): the transformer size is hard-coded to
        # (depth, 128, 128) — it must match the flow's spatial shape.
        ed_source_1 = layer.SpatialTransformer(
            (depth, 128, 128))(ed_seg[:, 0:1, :, :, :], flow_field_x1)
        ed_source_2 = layer.SpatialTransformer(
            (depth, 128, 128))(ed_seg[:, 1:2, :, :, :], flow_field_x1)
        ed_source_3 = layer.SpatialTransformer(
            (depth, 128, 128))(ed_seg[:, 2:3, :, :, :], flow_field_x1)
        ed_source_4 = layer.SpatialTransformer(
            (depth, 128, 128))(ed_seg[:, 3:4, :, :, :], flow_field_x1)
        ed_source = torch.cat(
            (ed_source_1, ed_source_2, ed_source_3, ed_source_4), 1)
        # ---- Same channel-wise warp for es_seg with flow_field_x2 ----
        # (continues past this excerpt)
        es_source_1 = layer.SpatialTransformer(
            (depth, 128, 128))(es_seg[:, 0:1, :, :, :], flow_field_x2)
        es_source_2 = layer.SpatialTransformer(
            (depth, 128, 128))(es_seg[:, 1:2, :, :, :], flow_field_x2)
        es_source_3 = layer.SpatialTransformer(
            (depth, 128, 128))(es_seg[:, 2:3, :, :, :], flow_field_x2)