Ejemplo n.º 1
0
    def train(self):
        """Train the feature extractor with ``RagContrastiveWeights`` on superpixel edges.

        Optimizer steps are taken over ``train_loader`` (cycling over epochs)
        until ``iteration`` exceeds ``self.cfg.fe.trainer.n_iterations``.
        Every 100 iterations the model is evaluated in eval mode, without
        gradients, on the whole validation loader:

        - the mean validation loss drives the ``ReduceLROnPlateau`` scheduler;
        - improvements are saved to ``<save_dir>/best_val_model.pth``;
        - a figure of the last validation batch is logged.

        The final weights are saved to ``<save_dir>/last_model.pth``.
        """
        writer = SummaryWriter(logdir=self.log_dir)
        device = "cuda:0"
        wu_cfg = self.cfg.fe.trainer
        model = UNet2D(**self.cfg.fe.backbone)
        model.cuda(device)
        train_set = SpgDset(self.cfg.gen.data_dir_raw_train, reorder_sp=False)
        val_set = SpgDset(self.cfg.gen.data_dir_raw_val, reorder_sp=False)
        pm = NoPatches2D()
        train_set.length = len(train_set.graph_file_names) * np.prod(pm.n_patch_per_dim)
        train_set.n_patch_per_dim = pm.n_patch_per_dim
        val_set.length = len(val_set.graph_file_names)
        train_loader = DataLoader(train_set, batch_size=wu_cfg.batch_size, shuffle=True, pin_memory=True,
                                  num_workers=0)
        val_loader = DataLoader(val_set, batch_size=wu_cfg.batch_size, shuffle=True, pin_memory=True,
                                num_workers=0)
        optimizer = torch.optim.Adam(model.parameters(), lr=self.cfg.fe.lr)
        scheduler = ReduceLROnPlateau(optimizer,
                                      patience=20,
                                      threshold=1e-4,
                                      min_lr=1e-5,
                                      factor=0.1)
        criterion = RagContrastiveWeights(delta_var=0.1, delta_dist=0.4)
        valit = 0
        iteration = 0
        best_loss = np.inf

        def embed_and_loss(raw, sp_seg, affinities, offs):
            """Return ``(loss, embeddings)`` for one batch.

            ``embeddings`` are the model outputs scaled to (almost) unit norm
            along dim 1.  ``loss`` is ``criterion`` evaluated with the edges and
            edge weights (first feature column) that ``get_edge_features_1d``
            computes from each sample's superpixels and affinities.
            """
            embeddings = model(raw[:, :, None]).squeeze(2)
            # put embeddings on unit sphere so we can use cosine distance
            embeddings = embeddings / (torch.norm(embeddings, dim=1, keepdim=True) + 1e-9)

            offs = offs.numpy().tolist()
            edge_feat, edges = tuple(zip(*[
                get_edge_features_1d(seg.squeeze().cpu().numpy(), off, affs.squeeze().cpu().numpy())
                for seg, off, affs in zip(sp_seg, offs, affinities)]))
            # np.long does not exist in NumPy 1.24-1.26 (and denotes the platform
            # C long in NumPy 2.x); int64 is explicit and matches torch.long.
            edges = [torch.from_numpy(e.astype(np.int64)).to(device).T for e in edges]
            edge_weights = [torch.from_numpy(ew.astype(np.float32)).to(device)[:, 0][None] for ew in edge_feat]

            loss = criterion(embeddings, sp_seg.long(), edges, edge_weights,
                             chunks=int(sp_seg.max().item() // self.cfg.gen.train_chunk_size),
                             sigm_factor=self.cfg.gen.sigm_factor, pull_factor=self.cfg.gen.pull_factor)
            return loss, embeddings

        while iteration <= wu_cfg.n_iterations:
            for raw, gt, sp_seg, affinities, offs, indices in train_loader:
                raw, sp_seg, affinities = raw.to(device), sp_seg.to(device), affinities.to(device)
                loss, _ = embed_and_loss(raw, sp_seg, affinities, offs)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                print(loss.item())
                writer.add_scalar("fe_train/lr", optimizer.param_groups[0]['lr'], iteration)
                writer.add_scalar("fe_train/loss", loss.item(), iteration)
                # An empty validation set has nothing to average, checkpoint or plot.
                if iteration % 100 == 0 and len(val_loader) > 0:
                    # Eval mode so normalisation/dropout layers (if any) do not use
                    # batch statistics or update running stats from validation data.
                    model.eval()
                    acc_loss = 0.0
                    with torch.no_grad():
                        for raw, gt, sp_seg, affinities, offs, indices in val_loader:
                            raw, sp_seg, affinities = raw.to(device), sp_seg.to(device), affinities.to(device)
                            ls, embeddings = embed_and_loss(raw, sp_seg, affinities, offs)
                            acc_loss += ls.item()
                            writer.add_scalar("fe_val/loss", ls.item(), valit)
                            valit += 1
                    model.train()
                    val_loss = acc_loss / len(val_loader)
                    if val_loss < best_loss:
                        print(self.save_dir)
                        torch.save(model.state_dict(), os.path.join(self.save_dir, "best_val_model.pth"))
                        best_loss = val_loss
                    scheduler.step(val_loss)
                    # The figure shows the last validation batch (raw, sp_seg and
                    # embeddings still hold the values from the loop above).
                    fig, ((a1, a2), (a3, a4)) = plt.subplots(2, 2, sharex='col', sharey='row',
                                                             gridspec_kw={'hspace': 0, 'wspace': 0})
                    a1.imshow(raw[0].cpu().permute(1, 2, 0)[..., 0].squeeze())
                    a1.set_title('raw')
                    a2.imshow(cm.prism(sp_seg[0, 0].cpu().squeeze() / sp_seg[0, 0].cpu().squeeze().max()))
                    a2.set_title('sp')
                    a3.imshow(pca_project(get_angles(embeddings)[0].detach().cpu()))
                    a3.set_title('angle_embed')
                    a4.imshow(pca_project(embeddings[0].detach().cpu()))
                    a4.set_title('embed')
                    writer.add_figure("examples", fig, iteration // 100)
                iteration += 1
                print(iteration)
                if iteration > wu_cfg.n_iterations:
                    print(self.save_dir)
                    torch.save(model.state_dict(), os.path.join(self.save_dir, "last_model.pth"))
                    break
Ejemplo n.º 2
0
    def train(self):
        """Warm-start the feature extractor against a momentum copy of itself.

        Each step:

        1. Samples a spatial and an intensity augmentation.
        2. Embeds a noisy, differently intensity-augmented input with
           ``momentum_model`` (whose parameters are frozen for the optimizer).
        3. Spatially transforms that prediction together with the input and a
           validity mask.
        4. Embeds the augmented input with ``model``.
        5. Minimises ``EntrInfoNCE`` between the two unit-normalised embeddings.

        ``psi`` is reduced by ``psi_step`` every iteration and ``k`` by 1 every
        ``k_step`` iterations, clamped at ``psi_stop`` / ``k_stop``.
        ``momentum_model`` is updated via ``soft_update_params`` every
        ``wu_cfg.momentum`` iterations.  Every 50 iterations the mean training
        loss since the previous check is passed to the LR scheduler and an
        example figure is logged.
        """
        writer = SummaryWriter(logdir=self.log_dir)
        writer.add_text("conf", self.cfg.pretty())
        device = "cuda:0"
        wu_cfg = self.cfg.fe.trainer
        model = UNet2D(self.cfg.fe.n_raw_channels,
                       self.cfg.fe.n_embedding_features,
                       final_sigmoid=False,
                       num_levels=5)
        momentum_model = UNet2D(self.cfg.fe.n_raw_channels,
                                self.cfg.fe.n_embedding_features,
                                final_sigmoid=False,
                                num_levels=5)
        if wu_cfg.identical_initialization:
            soft_update_params(model, momentum_model, 1)
        momentum_model.cuda(device)
        for param in momentum_model.parameters():
            param.requires_grad = False
        model.cuda(device)
        dset = SpgDset(self.cfg.gen.data_dir, wu_cfg.patch_manager,
                       wu_cfg.patch_stride, wu_cfg.patch_shape,
                       wu_cfg.reorder_sp)
        dloader = DataLoader(dset,
                             batch_size=wu_cfg.batch_size,
                             shuffle=True,
                             pin_memory=True,
                             num_workers=0)
        optimizer = torch.optim.Adam(model.parameters(), lr=self.cfg.fe.lr)
        scheduler = ReduceLROnPlateau(optimizer,
                                      patience=100,
                                      threshold=1e-3,
                                      min_lr=1e-6,
                                      factor=0.1)
        criterion = EntrInfoNCE(alpha=self.cfg.fe.alpha,
                                beta=self.cfg.fe.beta,
                                lbd=self.cfg.fe.lbd,
                                tau=self.cfg.fe.tau,
                                gamma=self.cfg.fe.gamma,
                                num_neg=self.cfg.fe.num_neg,
                                subs_size=self.cfg.fe.subs_size)
        tfs = RndAugmentationTfs(wu_cfg.patch_shape)
        log_every = 50  # iterations between scheduler steps / example figures
        acc_loss = 0
        n_acc = 0  # number of step losses currently summed in acc_loss
        iteration = 0
        # NOTE(review): raises ZeroDivisionError if k_start == k_stop or
        # n_iterations == n_k_stop_it -- assumed not to happen in configs.
        n_anneal_it = wu_cfg.n_iterations - wu_cfg.n_k_stop_it
        k_step = math.ceil(n_anneal_it / (wu_cfg.k_start - wu_cfg.k_stop))
        k = wu_cfg.k_start
        psi_step = (wu_cfg.psi_start - wu_cfg.psi_stop) / n_anneal_it
        psi = wu_cfg.psi_start

        while iteration <= wu_cfg.n_iterations:
            for raw, gt, sp_seg, indices in dloader:
                inp = raw.to(device)
                n_ch = inp.shape[1]
                mask = torch.ones((inp.shape[0], 1) + inp.shape[2:],
                                  device=device).float()
                # get transforms
                spat_tf, int_tf = tfs.sample(1, 1)
                _, _int_tf = tfs.sample(1, 1)
                # add noise to intensity tf of input for momentum network
                mom_inp = add_sp_gauss_noise(_int_tf(inp), 0.2, 0.1, 0.3)
                # get momentum prediction
                embeddings_mom = momentum_model(
                    mom_inp.unsqueeze(2)).squeeze(2)
                # do the same spatial tf for input, mask and momentum prediction
                # (channel layout of `paired`: [mask | input | momentum embedding])
                paired = spat_tf(torch.cat((mask, inp, embeddings_mom), -3))
                embeddings_mom = paired[..., n_ch + 1:, :, :]
                mask = paired[..., 0, :, :][:, None]
                # do intensity transform for spatial transformed input
                aug_inp = int_tf(paired[..., 1:n_ch + 1, :, :])
                # and add some noise
                aug_inp = add_sp_gauss_noise(aug_inp, 0.2, 0.1, 0.3)
                # get prediction of the augmented input
                embeddings = model(aug_inp.unsqueeze(2)).squeeze(2)

                # put embeddings on unit sphere so we can use cosine distance
                embeddings = embeddings / torch.norm(
                    embeddings, dim=1, keepdim=True)
                # set the void of the image (mask == 0) to the 1-vector
                embeddings_mom = embeddings_mom + (mask == 0)
                embeddings_mom = embeddings_mom / torch.norm(
                    embeddings_mom, dim=1, keepdim=True)

                loss = criterion(embeddings.squeeze(0),
                                 embeddings_mom.squeeze(0),
                                 k,
                                 mask.squeeze(0),
                                 whiten=wu_cfg.whitened_embeddings,
                                 warmup=iteration < wu_cfg.n_warmup_it,
                                 psi=psi)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                acc_loss += loss.item()
                n_acc += 1

                print(loss.item())
                writer.add_scalar("fe_warm_start/loss", loss.item(), iteration)
                writer.add_scalar("fe_warm_start/lr",
                                  optimizer.param_groups[0]['lr'], iteration)
                if iteration % log_every == 0:
                    # Mean over the steps actually accumulated: the first window
                    # holds a single step, later ones `log_every` steps.  A fixed
                    # divisor would put the first metric on a different scale
                    # and make every later value look like "no improvement".
                    scheduler.step(acc_loss / n_acc)
                    acc_loss = 0
                    n_acc = 0
                    fig, (a1, a2, a3, a4) = plt.subplots(1,
                                                         4,
                                                         sharex='col',
                                                         sharey='row',
                                                         gridspec_kw={
                                                             'hspace': 0,
                                                             'wspace': 0
                                                         })
                    a1.imshow(inp[0].cpu().permute(1, 2, 0).squeeze())
                    a1.set_title('raw')
                    a2.imshow(aug_inp[0].cpu().permute(1, 2, 0))
                    a2.set_title('augment')
                    a3.imshow(
                        pca_project(
                            get_angles(embeddings).squeeze(0).detach().cpu()))
                    a3.set_title('embed')
                    a4.imshow(
                        pca_project(
                            get_angles(embeddings_mom).squeeze(
                                0).detach().cpu()))
                    a4.set_title('mom_embed')
                    # one distinct step per logged figure
                    writer.add_figure("examples", fig, iteration // log_every)
                iteration += 1
                psi = max(psi - psi_step, wu_cfg.psi_stop)
                if iteration % k_step == 0:
                    k = max(k - 1, wu_cfg.k_stop)

                if iteration > wu_cfg.n_iterations:
                    break
                if iteration % wu_cfg.momentum == 0:
                    soft_update_params(model, momentum_model,
                                       wu_cfg.momentum_tau)
Ejemplo n.º 3
0
    def train(self):
        """Warm-start the feature extractor with ``RagInfoNCE`` on batched superpixel graphs.

        For each batch, the superpixel ids and graph edges of every sample are
        offset so that ids are unique across the batch, and the per-sample edge
        tensors are concatenated along dim 1.  The unit-normalised embeddings
        are then trained with ``criterion(embeddings, sp_seg, edges)``.  Every
        50 iterations the mean training loss since the previous check is passed
        to the LR scheduler and an example figure is logged.  Training stops
        once ``iteration`` exceeds ``wu_cfg.n_iterations``.
        """
        writer = SummaryWriter(logdir=self.log_dir)
        writer.add_text("conf", self.cfg.pretty())
        device = "cuda:0"
        wu_cfg = self.cfg.fe.trainer
        model = UNet2D(self.cfg.fe.n_raw_channels,
                       self.cfg.fe.n_embedding_features,
                       final_sigmoid=False,
                       num_levels=5)
        model.cuda(device)
        dset = SpgDset(self.cfg.gen.data_dir, wu_cfg.patch_manager,
                       wu_cfg.patch_stride, wu_cfg.patch_shape,
                       wu_cfg.reorder_sp)
        dloader = DataLoader(dset,
                             batch_size=wu_cfg.batch_size,
                             shuffle=True,
                             pin_memory=True,
                             num_workers=0)
        optimizer = torch.optim.Adam(model.parameters(), lr=self.cfg.fe.lr)
        scheduler = ReduceLROnPlateau(optimizer,
                                      patience=100,
                                      threshold=1e-3,
                                      min_lr=1e-6,
                                      factor=0.1)
        criterion = RagInfoNCE(tau=self.cfg.fe.tau)
        log_every = 50  # iterations between scheduler steps / example figures
        acc_loss = 0
        n_acc = 0  # number of step losses currently summed in acc_loss
        iteration = 0

        while iteration <= wu_cfg.n_iterations:
            for raw, gt, sp_seg, indices in dloader:
                inp, sp_seg = raw.to(device), sp_seg.to(device)
                graphs = dloader.dataset.get_graphs(indices, sp_seg, device)[0]

                # Shift each sample's superpixel ids and edges past the ids of the
                # preceding samples.  Edges are offset out of place so the tensors
                # handed out by the dataset are never mutated.
                off = 0
                edges = []
                for i, sample_edges in enumerate(graphs):
                    sp_seg[i] += off
                    edges.append(sample_edges + off)
                    off = sp_seg[i].max() + 1
                edges = torch.cat(edges, 1)
                embeddings = model(inp.unsqueeze(2)).squeeze(2)

                # put embeddings on unit sphere so we can use cosine distance
                embeddings = embeddings / torch.norm(
                    embeddings, dim=1, keepdim=True)

                loss = criterion(embeddings, sp_seg, edges)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                acc_loss += loss.item()
                n_acc += 1

                print(loss.item())
                writer.add_scalar("fe_warm_start/loss", loss.item(), iteration)
                writer.add_scalar("fe_warm_start/lr",
                                  optimizer.param_groups[0]['lr'], iteration)
                if iteration % log_every == 0:
                    # Mean over the steps actually accumulated (only one at
                    # iteration 0) keeps the plateau metric on a single scale.
                    scheduler.step(acc_loss / n_acc)
                    acc_loss = 0
                    n_acc = 0
                    fig, (a1, a2) = plt.subplots(1,
                                                 2,
                                                 sharex='col',
                                                 sharey='row',
                                                 gridspec_kw={
                                                     'hspace': 0,
                                                     'wspace': 0
                                                 })
                    a1.imshow(inp[0].cpu().permute(1, 2, 0).squeeze())
                    a1.set_title('raw')
                    a2.imshow(
                        pca_project(
                            get_angles(embeddings).squeeze(0).detach().cpu()))
                    a2.set_title('embed')
                    # one distinct step per logged figure
                    writer.add_figure("examples", fig, iteration // log_every)
                iteration += 1
                if iteration > wu_cfg.n_iterations:
                    break
    def train(self):
        """Train the feature extractor with ``AffinityContrastive`` on raw + affinity input.

        The model input is the raw image concatenated with the affinity
        channels.  Embeddings are scaled to unit norm along dim 1 before the
        loss is computed.  Every 100 iterations the model is evaluated in eval
        mode, without gradients, on the whole validation loader:

        - the mean validation loss drives the ``ReduceLROnPlateau`` scheduler;
        - improvements are saved to ``<save_dir>/best_val_model.pth``;
        - a figure of the last validation batch is logged.

        Training stops once ``iteration`` exceeds ``wu_cfg.n_iterations``.
        """
        writer = SummaryWriter(logdir=self.log_dir)
        device = "cuda:0"
        wu_cfg = self.cfg.fe.trainer
        model = UNet2D(**self.cfg.fe.backbone)
        model.cuda(device)
        train_set = SpgDset(self.cfg.gen.data_dir_raw_train, reorder_sp=False)
        val_set = SpgDset(self.cfg.gen.data_dir_raw_val, reorder_sp=False)
        train_loader = DataLoader(train_set,
                                  batch_size=wu_cfg.batch_size,
                                  shuffle=True,
                                  pin_memory=True,
                                  num_workers=0)
        val_loader = DataLoader(val_set,
                                batch_size=wu_cfg.batch_size,
                                shuffle=True,
                                pin_memory=True,
                                num_workers=0)

        optimizer = torch.optim.Adam(model.parameters(), lr=self.cfg.fe.lr)
        criterion = AffinityContrastive(delta_var=0.1, delta_dist=0.3)
        scheduler = ReduceLROnPlateau(optimizer,
                                      patience=5,
                                      threshold=1e-4,
                                      min_lr=1e-5,
                                      factor=0.1)
        valit = 0
        iteration = 0
        best_loss = np.inf

        def embed(raw, affinities):
            """Embed raw + affinity channels and scale to unit norm along dim 1."""
            inp = torch.cat([raw, affinities], dim=1)
            embeddings = model(inp.unsqueeze(2)).squeeze(2)
            return embeddings / torch.norm(embeddings, dim=1, keepdim=True)

        while iteration <= wu_cfg.n_iterations:
            for it, (raw, gt, sp_seg, affinities, offs,
                     indices) in enumerate(train_loader):
                # only the first sample's offsets are used for the whole batch
                raw, affinities, offs = raw.to(device), affinities.to(
                    device), offs[0].to(device)
                embeddings = embed(raw, affinities)
                loss = criterion(embeddings, affinities, offs)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                lr = optimizer.param_groups[0]['lr']
                print(f"step {it}; lr({lr}); loss({loss.item()})")
                writer.add_scalar("fe_warm_start/loss", loss.item(), iteration)
                writer.add_scalar("fe_warm_start/lr", lr, iteration)
                # An empty validation set has nothing to average, checkpoint or plot.
                if iteration % 100 == 0 and len(val_loader) > 0:
                    # Eval mode so normalisation/dropout layers (if any) do not use
                    # batch statistics or update running stats from validation data.
                    model.eval()
                    acc_loss = 0.0
                    with torch.no_grad():
                        for raw, gt, sp_seg, affinities, offs, indices in val_loader:
                            raw, affinities, offs = raw.to(device), affinities.to(
                                device), offs[0].to(device)
                            embeddings = embed(raw, affinities)
                            batch_loss = criterion(embeddings, affinities, offs).item()
                            acc_loss += batch_loss
                            writer.add_scalar("fe_val/loss", batch_loss, valit)
                            valit += 1
                    model.train()
                    val_loss = acc_loss / len(val_loader)
                    if val_loss < best_loss:
                        torch.save(
                            model.state_dict(),
                            os.path.join(self.save_dir, "best_val_model.pth"))
                        best_loss = val_loss
                    scheduler.step(val_loss)
                    # The figure shows the last validation batch.
                    fig, (a1, a2) = plt.subplots(1,
                                                 2,
                                                 sharex='col',
                                                 sharey='row',
                                                 gridspec_kw={
                                                     'hspace': 0,
                                                     'wspace': 0
                                                 })
                    a1.imshow(raw[0].cpu().permute(1, 2, 0).squeeze())
                    a1.set_title('raw')
                    a2.imshow(pca_project(embeddings[0].detach().cpu()))
                    a2.set_title('embed')
                    # Log the figure (add_figure closes it by default) instead of
                    # plt.show(): show() blocks training on interactive backends
                    # and otherwise leaves one open figure per validation round.
                    writer.add_figure("examples", fig, iteration // 100)
                iteration += 1
                if iteration > wu_cfg.n_iterations:
                    break
Ejemplo n.º 5
0
    def train(self):
        writer = SummaryWriter(logdir=self.log_dir)
        device = "cuda:0"
        wu_cfg = self.cfg.fe.trainer
        model = UNet2D(**self.cfg.fe.backbone)
        model.cuda(device)
        # train_set = SpgDset(self.cfg.gen.data_dir_raw_train, patch_manager="no_cross", patch_stride=(10,10), patch_shape=(300,300), reorder_sp=True)
        # val_set = SpgDset(self.cfg.gen.data_dir_raw_val, patch_manager="no_cross", patch_stride=(10,10), patch_shape=(300,300), reorder_sp=True)
        train_set = SpgDset(self.cfg.gen.data_dir_raw_train, reorder_sp=True)
        val_set = SpgDset(self.cfg.gen.data_dir_raw_val, reorder_sp=True)
        # pm = StridedPatches2D(wu_cfg.patch_stride, wu_cfg.patch_shape, train_set.image_shape)
        pm = NoPatches2D()
        train_set.length = len(train_set.graph_file_names) * np.prod(
            pm.n_patch_per_dim)
        train_set.n_patch_per_dim = pm.n_patch_per_dim
        val_set.length = len(val_set.graph_file_names)
        gauss_kernel = GaussianSmoothing(1, 5, 3, device=device)
        # dset = LeptinDset(self.cfg.gen.data_dir_raw, self.cfg.gen.data_dir_affs, wu_cfg.patch_manager, wu_cfg.patch_stride, wu_cfg.patch_shape, wu_cfg.reorder_sp)
        train_loader = DataLoader(train_set,
                                  batch_size=wu_cfg.batch_size,
                                  shuffle=True,
                                  pin_memory=True,
                                  num_workers=0)
        val_loader = DataLoader(val_set,
                                batch_size=wu_cfg.batch_size,
                                shuffle=True,
                                pin_memory=True,
                                num_workers=0)
        optimizer = torch.optim.Adam(model.parameters(), lr=self.cfg.fe.lr)
        sheduler = ReduceLROnPlateau(optimizer,
                                     patience=80,
                                     threshold=1e-4,
                                     min_lr=1e-8,
                                     factor=0.1)
        slcs = [
            slice(None, self.cfg.fe.embeddings_separator),
            slice(self.cfg.fe.embeddings_separator, None)
        ]
        criterion = RegRagContrastiveWeights(delta_var=0.1,
                                             delta_dist=0.3,
                                             slices=slcs)
        acc_loss = 0
        valit = 0
        iteration = 0
        best_loss = np.inf

        while iteration <= wu_cfg.n_iterations:
            for it, (raw, gt, sp_seg, affinities, offs,
                     indices) in enumerate(train_loader):
                raw, gt, sp_seg, affinities = raw.to(device), gt.to(
                    device), sp_seg.to(device), affinities.to(device)
                sp_seg = sp_seg + 1
                edge_img = F.pad(get_contour_from_2d_binary(sp_seg),
                                 (2, 2, 2, 2),
                                 mode='constant')
                edge_img = gauss_kernel(edge_img.float())
                all = torch.cat([raw, gt, sp_seg, edge_img], dim=1)

                angle = float(torch.randint(-180, 180, (1, )).item())
                rot_all = tvF.rotate(all, angle, PIL.Image.NEAREST)
                rot_raw = rot_all[:, :1]
                rot_gt = rot_all[:, 1:2]
                rot_sp = rot_all[:, 2:3]
                rot_edge_img = rot_all[:, 3:]
                angle = abs(angle / 180)
                valid_sp = []
                for i in range(len(rot_sp)):
                    _valid_sp = torch.unique(rot_sp[i], sorted=True)
                    _valid_sp = _valid_sp[1:] if _valid_sp[
                        0] == 0 else _valid_sp
                    if len(_valid_sp) > self.cfg.gen.sp_samples_per_step:
                        inds = torch.multinomial(
                            torch.ones_like(_valid_sp),
                            self.cfg.gen.sp_samples_per_step,
                            replacement=False)
                        _valid_sp = _valid_sp[inds]
                    valid_sp.append(_valid_sp)

                _rot_sp, _sp_seg = [], []
                for val_sp, rsp, sp in zip(valid_sp, rot_sp, sp_seg):
                    mask = rsp == val_sp[:, None, None]
                    _rot_sp.append((mask * (torch.arange(
                        len(val_sp), device=rsp.device)[:, None, None] + 1)
                                    ).sum(0))
                    mask = sp == val_sp[:, None, None]
                    _sp_seg.append((mask * (torch.arange(
                        len(val_sp), device=sp.device)[:, None, None] + 1)
                                    ).sum(0))

                rot_sp = torch.stack(_rot_sp)
                sp_seg = torch.stack(_sp_seg)
                valid_sp = [
                    torch.unique(_rot_sp, sorted=True) for _rot_sp in rot_sp
                ]
                valid_sp = [
                    _valid_sp[1:] if _valid_sp[0] == 0 else _valid_sp
                    for _valid_sp in valid_sp
                ]

                inp = torch.cat([
                    torch.cat([raw, edge_img], 1),
                    torch.cat([rot_raw, rot_edge_img], 1)
                ], 0)
                offs = offs.numpy().tolist()
                edge_feat, edges = tuple(
                    zip(*[
                        get_edge_features_1d(seg.squeeze().cpu().numpy(), os,
                                             affs.squeeze().cpu().numpy())
                        for seg, os, affs in zip(sp_seg, offs, affinities)
                    ]))
                edges = [
                    torch.from_numpy(e.astype(np.long)).to(device).T
                    for e in edges
                ]
                edge_weights = [
                    torch.from_numpy(ew.astype(np.float32)).to(device)[:,
                                                                       0][None]
                    for ew in edge_feat
                ]
                valid_edges_masks = [
                    (_edges[None] == _valid_sp[:, None,
                                               None]).sum(0).sum(0) == 2
                    for _valid_sp, _edges in zip(valid_sp, edges)
                ]
                edges = [
                    _edges[:, valid_edges_mask] - 1
                    for _edges, valid_edges_mask in zip(
                        edges, valid_edges_masks)
                ]
                edge_weights = [
                    _edge_weights[:, valid_edges_mask]
                    for _edge_weights, valid_edges_mask in zip(
                        edge_weights, valid_edges_masks)
                ]

                # put embeddings on unit sphere so we can use cosine distance
                loss_embeds = model(inp[:, :, None]).squeeze(2)
                loss_embeds = criterion.norm_each_space(loss_embeds, 1)

                loss = criterion(loss_embeds,
                                 sp_seg.long(),
                                 rot_sp.long(),
                                 edges,
                                 edge_weights,
                                 valid_sp,
                                 angle,
                                 chunks=int(sp_seg.max().item() //
                                            self.cfg.gen.train_chunk_size))

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                print(f"step {iteration}: {loss.item()}")
                writer.add_scalar("fe_train/lr",
                                  optimizer.param_groups[0]['lr'], iteration)
                writer.add_scalar("fe_train/loss", loss.item(), iteration)
                if (iteration) % 100 == 0:
                    with torch.set_grad_enabled(False):
                        model.eval()
                        print("####start validation####")
                        for it, (raw, gt, sp_seg, affinities, offs,
                                 indices) in enumerate(val_loader):
                            raw, gt, sp_seg, affinities = raw.to(
                                device), gt.to(device), sp_seg.to(
                                    device), affinities.to(device)
                            sp_seg = sp_seg + 1
                            edge_img = F.pad(
                                get_contour_from_2d_binary(sp_seg),
                                (2, 2, 2, 2),
                                mode='constant')
                            edge_img = gauss_kernel(edge_img.float())
                            all = torch.cat([raw, gt, sp_seg, edge_img], dim=1)

                            angle = float(
                                torch.randint(-180, 180, (1, )).item())
                            rot_all = tvF.rotate(all, angle, PIL.Image.NEAREST)
                            rot_raw = rot_all[:, :1]
                            rot_gt = rot_all[:, 1:2]
                            rot_sp = rot_all[:, 2:3]
                            rot_edge_img = rot_all[:, 3:]
                            angle = abs(angle / 180)
                            valid_sp = [
                                torch.unique(_rot_sp, sorted=True)
                                for _rot_sp in rot_sp
                            ]
                            valid_sp = [
                                _valid_sp[1:]
                                if _valid_sp[0] == 0 else _valid_sp
                                for _valid_sp in valid_sp
                            ]

                            _rot_sp, _sp_seg = [], []
                            for val_sp, rsp, sp in zip(valid_sp, rot_sp,
                                                       sp_seg):
                                mask = rsp == val_sp[:, None, None]
                                _rot_sp.append((mask * (torch.arange(
                                    len(val_sp), device=rsp.device)[:, None,
                                                                    None] + 1)
                                                ).sum(0))
                                mask = sp == val_sp[:, None, None]
                                _sp_seg.append((mask * (torch.arange(
                                    len(val_sp), device=sp.device)[:, None,
                                                                   None] + 1)
                                                ).sum(0))

                            rot_sp = torch.stack(_rot_sp)
                            sp_seg = torch.stack(_sp_seg)
                            valid_sp = [
                                torch.unique(_rot_sp, sorted=True)
                                for _rot_sp in rot_sp
                            ]
                            valid_sp = [
                                _valid_sp[1:]
                                if _valid_sp[0] == 0 else _valid_sp
                                for _valid_sp in valid_sp
                            ]

                            inp = torch.cat([
                                torch.cat([raw, edge_img], 1),
                                torch.cat([rot_raw, rot_edge_img], 1)
                            ], 0)
                            offs = offs.numpy().tolist()
                            edge_feat, edges = tuple(
                                zip(*[
                                    get_edge_features_1d(
                                        seg.squeeze().cpu().numpy(), os,
                                        affs.squeeze().cpu().numpy())
                                    for seg, os, affs in zip(
                                        sp_seg, offs, affinities)
                                ]))
                            edges = [
                                torch.from_numpy(e.astype(
                                    np.long)).to(device).T for e in edges
                            ]
                            edge_weights = [
                                torch.from_numpy(ew.astype(
                                    np.float32)).to(device)[:, 0][None]
                                for ew in edge_feat
                            ]
                            valid_edges_masks = [
                                (_edges[None] == _valid_sp[:, None, None]
                                 ).sum(0).sum(0) == 2
                                for _valid_sp, _edges in zip(valid_sp, edges)
                            ]
                            edges = [
                                _edges[:, valid_edges_mask] - 1
                                for _edges, valid_edges_mask in zip(
                                    edges, valid_edges_masks)
                            ]
                            edge_weights = [
                                _edge_weights[:, valid_edges_mask]
                                for _edge_weights, valid_edges_mask in zip(
                                    edge_weights, valid_edges_masks)
                            ]

                            # put embeddings on unit sphere so we can use cosine distance
                            embeds = model(inp[:, :, None]).squeeze(2)
                            embeds = criterion.norm_each_space(embeds, 1)

                            ls = criterion(
                                embeds,
                                sp_seg.long(),
                                rot_sp.long(),
                                edges,
                                edge_weights,
                                valid_sp,
                                angle,
                                chunks=int(sp_seg.max().item() //
                                           self.cfg.gen.train_chunk_size))

                            acc_loss += ls
                            writer.add_scalar("fe_val/loss", ls, valit)
                            print(f"step {it}: {ls.item()}")
                            valit += 1

                    acc_loss = acc_loss / len(val_loader)
                    if acc_loss < best_loss:
                        print(self.save_dir)
                        torch.save(
                            model.state_dict(),
                            os.path.join(self.save_dir, "best_val_model.pth"))
                        best_loss = acc_loss
                    sheduler.step(acc_loss)
                    acc_loss = 0
                    fig, ((a1, a2), (a3, a4)) = plt.subplots(2,
                                                             2,
                                                             sharex='col',
                                                             sharey='row',
                                                             gridspec_kw={
                                                                 'hspace': 0,
                                                                 'wspace': 0
                                                             })
                    a1.imshow(raw[0].cpu().permute(1, 2, 0).squeeze())
                    a1.set_title('raw')
                    a2.imshow(
                        cm.prism(sp_seg[0].cpu().squeeze() /
                                 sp_seg[0].cpu().squeeze().max()))
                    a2.set_title('sp')
                    a3.imshow(pca_project(embeds[0, slcs[0]].detach().cpu()))
                    a3.set_title('embed', y=-0.01)
                    a4.imshow(pca_project(embeds[0, slcs[1]].detach().cpu()))
                    a4.set_title('embed rot', y=-0.01)
                    plt.show()
                    writer.add_figure("examples", fig, iteration // 100)
                    # model.train()
                    print("####end validation####")
                iteration += 1
                if iteration > wu_cfg.n_iterations:
                    print(self.save_dir)
                    torch.save(model.state_dict(),
                               os.path.join(self.save_dir, "last_model.pth"))
                    break
        return
    def train(self):
        """Warm-start the feature extractor with an augmentation-consistency loss.

        Each step:
          1. feeds an intensity-augmented, noise-corrupted copy of the raw batch
             through the model;
          2. applies one randomly sampled spatial transform jointly to a
             valid-edge mask, the raw input and the resulting embeddings;
          3. feeds an independently intensity-augmented, noise-corrupted version
             of the spatially transformed raw input through the model;
          4. compares both (unit-normalised) embedding sets with
             ``AugmentedAffinityContrastive`` and takes an Adam step.

        Training runs for ``self.cfg.fe.trainer.n_iterations + 1`` optimizer
        steps, cycling through the data loader as often as needed. Loss and
        learning rate are logged to ``self.log_dir``; every 50 iterations an
        example figure is displayed with ``plt.show()``.

        Raises:
            RuntimeError: if the data loader yields no batches (without this
                check the outer loop could never make progress).
        """
        writer = SummaryWriter(logdir=self.log_dir)
        try:
            writer.add_text("conf", self.cfg.pretty())
            device = "cuda:0"
            wu_cfg = self.cfg.fe.trainer
            model = UNet2D(self.cfg.fe.n_raw_channels,
                           self.cfg.fe.n_embedding_features,
                           final_sigmoid=False,
                           num_levels=5)
            model.cuda(device)
            dset = SpgDset(self.cfg.gen.data_dir, wu_cfg.patch_manager,
                           wu_cfg.patch_stride, wu_cfg.patch_shape,
                           wu_cfg.reorder_sp)
            dloader = DataLoader(dset,
                                 batch_size=wu_cfg.batch_size,
                                 shuffle=True,
                                 pin_memory=True,
                                 num_workers=0)
            optimizer = torch.optim.Adam(model.parameters(), lr=self.cfg.fe.lr)
            tfs = RndAugmentationTfs(wu_cfg.patch_shape)
            criterion = AugmentedAffinityContrastive(delta_var=0.1,
                                                     delta_dist=0.3)
            # number of affinity offsets == number of leading mask channels
            n_offs = len(criterion.offs)
            iteration = 0

            while iteration <= wu_cfg.n_iterations:
                epoch_start = iteration
                # gt, sp_seg and indices are not used by this warm start
                for raw, _gt, _sp_seg, _indices in dloader:
                    raw = raw.to(device)
                    # NOTE: this is still not the correct mask calculation, as
                    # the affinity offsets are not adjusted for the spatial tf
                    mask = torch.from_numpy(
                        get_valid_edges([n_offs] + list(raw.shape[-2:]),
                                        criterion.offs)).to(device)[None]
                    spat_tf, int_tf = tfs.sample(1, 1)
                    _, _int_tf = tfs.sample(1, 1)
                    inp = add_sp_gauss_noise(_int_tf(raw), 0.2, 0.1, 0.3)
                    embeddings = model(inp.unsqueeze(2)).squeeze(2)

                    # Stack mask, raw input and embeddings along the channel
                    # axis so they all receive the *same* spatial transform.
                    # NOTE(review): mask gets a batch dim of 1 via [None], so
                    # this concat presumably requires batch_size == 1 — confirm.
                    n_raw = inp.shape[1]
                    paired = spat_tf(torch.cat((mask, raw, embeddings), -3))
                    mask = paired[..., :n_offs, :, :].detach()
                    embeddings_0 = paired[..., n_offs + n_raw:, :, :]
                    # do intensity transform for spatial transformed input
                    aug_inp = int_tf(
                        paired[..., n_offs:n_offs + n_raw, :, :]).detach()
                    # get prediction of the augmented input
                    embeddings_1 = model(
                        add_sp_gauss_noise(aug_inp, 0.2, 0.1,
                                           0.3).unsqueeze(2)).squeeze(2)

                    # put embeddings on unit sphere so we can use cosine distance
                    embeddings_0 = embeddings_0 / (
                        torch.norm(embeddings_0, dim=1, keepdim=True) + 1e-6)
                    embeddings_1 = embeddings_1 / (
                        torch.norm(embeddings_1, dim=1, keepdim=True) + 1e-6)

                    loss = criterion(embeddings_0, embeddings_1, aug_inp, mask)

                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    # one device->host sync per step instead of three
                    loss_val = loss.item()

                    print(loss_val)
                    writer.add_scalar("fe_warm_start/loss", loss_val,
                                      iteration)
                    writer.add_scalar("fe_warm_start/lr",
                                      optimizer.param_groups[0]['lr'],
                                      iteration)
                    if iteration % 50 == 0:
                        fig, ((a1, a2), (a3, a4)) = plt.subplots(2, 2)
                        a1.imshow(aug_inp[0].cpu().permute(1, 2, 0).squeeze())
                        a1.set_title('tf_raw')
                        a3.imshow(
                            pca_project(
                                get_angles(embeddings_0).squeeze(
                                    0).detach().cpu()))
                        a3.set_title('tf_embed')
                        a4.imshow(
                            pca_project(
                                get_angles(embeddings_1).squeeze(
                                    0).detach().cpu()))
                        a4.set_title('embed')
                        a2.imshow(raw[0].cpu().permute(1, 2, 0).squeeze())
                        a2.set_title('raw')
                        plt.show()
                        # release the figure; otherwise one figure leaks every
                        # 50 steps on non-interactive backends
                        plt.close(fig)
                        # writer.add_figure("examples", fig, iteration//100)
                    iteration += 1
                    if iteration > wu_cfg.n_iterations:
                        break
                if iteration == epoch_start:
                    raise RuntimeError(
                        "data loader yielded no batches; warm-start training "
                        "cannot make progress")
        finally:
            # flush pending TensorBoard events even if training fails
            writer.close()
        return