コード例 #1
0
ファイル: a2c.py プロジェクト: paulhfu/RLForSeg
    def __init__(self, cfg, global_count):
        """Set up the A2C trainer: device, agent model, datasets, optimizers
        and LR schedulers driven by a moving average of the losses.

        Args:
            cfg: attribute-style config with training hyperparameters.
            global_count: shared counter tracking global training progress.
        """
        super(AgentA2CTrainer, self).__init__()
        # Single-GPU training only.
        assert torch.cuda.device_count() == 1
        self.device = torch.device("cuda:0")
        torch.cuda.set_device(self.device)
        torch.set_default_tensor_type(torch.FloatTensor)

        self.cfg = cfg
        self.global_count = global_count
        self.memory = TransitionData_ts(capacity=self.cfg.mem_size)
        self.best_val_reward = -np.inf
        # Embedding-space metric shared by agent and losses.
        self.distance = CosineDistance() if self.cfg.distance == 'cosine' else L2Distance()

        self.model = Agent(self.cfg, State, self.distance, self.device, with_temp=False)
        wandb.watch(self.model)
        self.model.cuda(self.device)
        self.model_mtx = Lock()

        MovSumLosses = namedtuple('mov_avg_losses', ('actor', 'critic'))
        Scalers = namedtuple('Scalers', ('critic', 'actor'))
        OptimizerContainer = namedtuple('OptimizerContainer',
                                        ('actor', 'critic', 'actor_shed', 'critic_shed'))
        opt_actor = torch.optim.Adam(self.model.actor.parameters(), lr=self.cfg.actor_lr)
        opt_critic = torch.optim.Adam(self.model.critic.parameters(), lr=self.cfg.critic_lr)

        sched_cfg = dict_to_attrdict(self.cfg.lr_sched)
        band_width = sched_cfg.mov_avg_bandwidth
        avg_offset = sched_cfg.mov_avg_offset
        avg_weights = np.linspace(sched_cfg.weight_range[0], sched_cfg.weight_range[1],
                                  band_width)
        avg_weights = avg_weights / avg_weights.sum()  # normalize to a convex combination
        torch_sched = sched_cfg.torch_sched

        self.mov_sum_losses = MovSumLosses(
            RunningAverage(avg_weights, band_width=band_width, offset=avg_offset),
            RunningAverage(avg_weights, band_width=band_width, offset=avg_offset))
        # One plateau scheduler per optimizer, both with identical settings.
        shed_actor = ReduceLROnPlateau(opt_actor, patience=torch_sched.patience,
                                       threshold=torch_sched.threshold,
                                       min_lr=torch_sched.min_lr,
                                       factor=torch_sched.factor)
        shed_critic = ReduceLROnPlateau(opt_critic, patience=torch_sched.patience,
                                        threshold=torch_sched.threshold,
                                        min_lr=torch_sched.min_lr,
                                        factor=torch_sched.factor)
        self.optimizers = OptimizerContainer(opt_actor, opt_critic, shed_actor, shed_critic)
        self.scalers = Scalers(torch.cuda.amp.GradScaler(), torch.cuda.amp.GradScaler())
        self.forwarder = Forwarder()

        # Optionally resume the agent from a checkpoint.
        if self.cfg.agent_model_name != "":
            self.model.load_state_dict(torch.load(self.cfg.agent_model_name))
        # finished with prepping

        self.train_dset = SpgDset(self.cfg.data_dir, dict_to_attrdict(self.cfg.patch_manager),
                                  dict_to_attrdict(self.cfg.data_keys),
                                  max(self.cfg.s_subgraph))
        self.val_dset = SpgDset(self.cfg.val_data_dir, dict_to_attrdict(self.cfg.patch_manager),
                                dict_to_attrdict(self.cfg.data_keys),
                                max(self.cfg.s_subgraph))

        self.segm_metric = AveragePrecision()
        self.clst_metric = ClusterMetrics()
        self.global_counter = 0
コード例 #2
0
    def pretrain_embeddings_gt(self, model, device, writer=None):
        """Warm-start the feature extractor with a contrastive loss against
        ground-truth labels.

        Args:
            model: embedding network taking raw image + superpixel edge maps.
            device: CUDA device to train on.
            writer: optional summary writer for loss/lr/figure logging.
        """
        dataset = SpgDset(root_dir=self.cfg.gen.data_dir)
        loader = DataLoader(dataset,
                            batch_size=self.cfg.fe.warmup.batch_size,
                            shuffle=True,
                            pin_memory=True,
                            num_workers=0)
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
        pool = torch.nn.MaxPool2d(3, padding=1, stride=1)
        scheduler = ReduceLROnPlateau(optimizer)
        running_loss = 0
        iteration = 0

        while iteration <= self.cfg.fe.warmup.n_iterations:
            for it, (raw, gt, sp_seg, indices) in enumerate(loader):
                raw = raw.to(device)
                gt = gt.to(device)
                sp_seg = sp_seg.to(device)
                # Boundary map of the superpixel segmentation: min- and
                # max-pool disagreement marks pixels on region borders.
                edge_lo = (-pool(-sp_seg) != sp_seg).float()
                edge_hi = (pool(sp_seg) != sp_seg).float()
                sp_seg_edge = torch.cat([edge_lo, edge_hi], 1)
                embeddings = model(torch.cat([raw, sp_seg_edge], 1))
                loss = self.contr_loss(embeddings, gt.long().squeeze(1))

                optimizer.zero_grad()
                loss.backward(retain_graph=False)
                optimizer.step()
                running_loss += loss.item()

                if writer is not None:
                    writer.add_scalar("fe_warm_start/loss", loss.item(),
                                      iteration)
                    writer.add_scalar("fe_warm_start/lr",
                                      optimizer.param_groups[0]['lr'],
                                      iteration)
                    if it % 50 == 0:
                        plt.clf()
                        fig = plt.figure(frameon=False)
                        plt.imshow(sp_seg[0].detach().squeeze().cpu().numpy())
                        plt.colorbar()
                        writer.add_figure("image/sp_seg", fig, iteration // 50)
                # Step the plateau scheduler on a 10-batch loss average.
                if it % 10 == 0:
                    scheduler.step(running_loss / 10)
                    running_loss = 0
                iteration += 1
                if iteration > self.cfg.fe.warmup.n_iterations:
                    break

                del loss
                del embeddings
        return
コード例 #3
0
    def pretrain_embeddings_sp(self, model, device, writer=None):
        """Warm-start the feature extractor with a contrastive loss against
        superpixel labels (instead of ground truth).

        Args:
            model: embedding network taking raw image + superpixel edge maps.
            device: CUDA device to train on.
            writer: optional summary writer for loss/lr/figure logging.
        """
        dataset = SpgDset(self.args.data_dir, self.cfg.fe.patch_manager,
                          self.cfg.fe.patch_stride, self.cfg.fe.patch_shape)
        loader = DataLoader(dataset,
                            batch_size=self.cfg.fe.warmup.batch_size,
                            shuffle=True,
                            pin_memory=True,
                            num_workers=0)
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
        pool = torch.nn.MaxPool2d(3, padding=1, stride=1)
        scheduler = ReduceLROnPlateau(optimizer)
        running_loss = 0

        for epoch in range(self.cfg.fe.warmup.n_iterations):
            print(f"fe ext wu iter: {epoch}")
            for it, (raw, gt, sp_seg, indices) in enumerate(loader):
                raw = raw.to(device)
                gt = gt.to(device)
                sp_seg = sp_seg.to(device)
                # Boundary map of the superpixel segmentation: min- and
                # max-pool disagreement marks pixels on region borders.
                edge_lo = (-pool(-sp_seg) != sp_seg).float()
                edge_hi = (pool(sp_seg) != sp_seg).float()
                sp_seg_edge = torch.cat([edge_lo, edge_hi], 1)
                # Second argument flags every 500th batch (e.g. for plotting
                # inside the model).
                embeddings = model(torch.cat([raw, sp_seg_edge], 1),
                                   it % 500 == 0)

                loss = self.contr_loss(embeddings, sp_seg.long().squeeze(1))

                optimizer.zero_grad()
                loss.backward(retain_graph=False)
                optimizer.step()
                running_loss += loss.item()

                # Global step index across epochs for logging.
                step = (len(loader) * epoch) + it
                if writer is not None:
                    writer.add_scalar("fe_warm_start/loss", loss.item(), step)
                    writer.add_scalar("fe_warm_start/lr",
                                      optimizer.param_groups[0]['lr'], step)
                    if it % 500 == 0:
                        plt.clf()
                        fig = plt.figure(frameon=False)
                        plt.imshow(sp_seg[0].detach().squeeze().cpu().numpy())
                        plt.colorbar()
                        writer.add_figure("image/sp_seg", fig, step // 500)

                # Step the plateau scheduler on a 10-batch loss average.
                if it % 10 == 0:
                    scheduler.step(running_loss / 10)
                    running_loss = 0
コード例 #4
0
    def validate(self):
        """Evaluate a saved model's edge predictions against ground truth.

        Returns:
            tuple of (per-sample absolute deviation sums,
                      per-sample mean deviations,
                      mean number of edges per sample,
                      mean count of edges deviating more than the threshold).
        """
        self.device = torch.device("cuda:0")
        model = GcnEdgeAC(self.cfg, self.args, self.device)
        thresh = 0.5

        assert self.args.model_name != ""
        model.load_state_dict(
            torch.load(os.path.join(self.save_dir, self.args.model_name)))

        model.cuda(self.device)
        # Inference only — freeze every parameter.
        for param in model.parameters():
            param.requires_grad = False
        dloader = DataLoader(SpgDset(root_dir=self.args.data_dir),
                             batch_size=1,
                             shuffle=True,
                             pin_memory=True,
                             num_workers=0)
        env = SpGcnEnv(self.args, self.device)

        abs_diffs = []
        rel_diffs = []
        sizes = []
        n_larger_thresh = []
        for _ in range(len(dloader)):
            self.update_env_data(env, dloader, self.device)
            env.reset()
            state = env.get_state()

            distr, _, _, _ = self.agent_forward(env, model, state=state,
                                                grad=False)
            # Map the policy distribution's mean to edge probabilities.
            actions = torch.sigmoid(distr.loc)

            diff = (actions - env.b_gt_edge_weights).squeeze().abs()
            abs_diffs.append(diff.sum().item())
            rel_diffs.append(diff.mean().item())
            sizes.append(len(diff))
            n_larger_thresh.append((diff > thresh).float().sum().item())

        mean_size = sum(sizes) / len(sizes)
        mean_n_larger_thresh = sum(n_larger_thresh) / len(n_larger_thresh)
        return abs_diffs, rel_diffs, mean_size, mean_n_larger_thresh
コード例 #5
0
    def train(self):
        """Self-supervised embedding training with a momentum (teacher) network.

        Each batch is augmented into two views (spatial + intensity transforms
        plus noise): one view goes through the gradient-free momentum model,
        the other through the online model. An entropy-regularized InfoNCE
        loss aligns the two embeddings. `k` and `psi` are annealed linearly
        over the run; the momentum model tracks the online model via periodic
        soft parameter updates.
        """
        writer = SummaryWriter(logdir=self.log_dir)
        writer.add_text("conf", self.cfg.pretty())
        device = "cuda:0"
        wu_cfg = self.cfg.fe.trainer
        model = UNet2D(self.cfg.fe.n_raw_channels,
                       self.cfg.fe.n_embedding_features,
                       final_sigmoid=False,
                       num_levels=5)
        momentum_model = UNet2D(self.cfg.fe.n_raw_channels,
                                self.cfg.fe.n_embedding_features,
                                final_sigmoid=False,
                                num_levels=5)
        # Optionally start teacher and student from identical weights.
        if wu_cfg.identical_initialization:
            soft_update_params(model, momentum_model, 1)
        momentum_model.cuda(device)
        # The momentum model is updated only via soft parameter copies,
        # never by gradients.
        for param in momentum_model.parameters():
            param.requires_grad = False
        model.cuda(device)
        dset = SpgDset(self.cfg.gen.data_dir, wu_cfg.patch_manager,
                       wu_cfg.patch_stride, wu_cfg.patch_shape,
                       wu_cfg.reorder_sp)
        dloader = DataLoader(dset,
                             batch_size=wu_cfg.batch_size,
                             shuffle=True,
                             pin_memory=True,
                             num_workers=0)
        optimizer = torch.optim.Adam(model.parameters(), lr=self.cfg.fe.lr)
        sheduler = ReduceLROnPlateau(optimizer,
                                     patience=100,
                                     threshold=1e-3,
                                     min_lr=1e-6,
                                     factor=0.1)
        criterion = EntrInfoNCE(alpha=self.cfg.fe.alpha,
                                beta=self.cfg.fe.beta,
                                lbd=self.cfg.fe.lbd,
                                tau=self.cfg.fe.tau,
                                gamma=self.cfg.fe.gamma,
                                num_neg=self.cfg.fe.num_neg,
                                subs_size=self.cfg.fe.subs_size)
        tfs = RndAugmentationTfs(wu_cfg.patch_shape)
        acc_loss = 0
        iteration = 0
        # Linear annealing schedules: k drops by 1 every k_step iterations,
        # psi drops by psi_step per iteration, both bottoming out before
        # the final n_k_stop_it iterations.
        k_step = math.ceil((wu_cfg.n_iterations - wu_cfg.n_k_stop_it) /
                           (wu_cfg.k_start - wu_cfg.k_stop))
        k = wu_cfg.k_start
        psi_step = (wu_cfg.psi_start - wu_cfg.psi_stop) / (
            wu_cfg.n_iterations - wu_cfg.n_k_stop_it)
        psi = wu_cfg.psi_start

        while iteration <= wu_cfg.n_iterations:
            for it, (raw, gt, sp_seg, indices) in enumerate(dloader):
                inp, gt, sp_seg = raw.to(device), gt.to(device), sp_seg.to(
                    device)
                # Single-channel validity mask; warped together with the input
                # so the void region can be located after the spatial
                # transform.
                mask = torch.ones((
                    inp.shape[0],
                    1,
                ) + inp.shape[2:],
                                  device=device).float()
                # get transforms
                spat_tf, int_tf = tfs.sample(1, 1)
                _, _int_tf = tfs.sample(1, 1)
                # add noise to intensity tf of input for momentum network
                mom_inp = add_sp_gauss_noise(_int_tf(inp), 0.2, 0.1, 0.3)
                # get momentum prediction
                embeddings_mom = momentum_model(
                    mom_inp.unsqueeze(2)).squeeze(2)
                # do the same spatial tf for input, mask and momentum prediction
                paired = spat_tf(torch.cat((mask, inp, embeddings_mom), -3))
                # Split the jointly-warped stack back apart: channel 0 is the
                # mask, channels 1..inp.shape[1] are the input, the remainder
                # is the momentum embedding.
                embeddings_mom, mask = paired[..., inp.shape[1] +
                                              1:, :, :], paired[...,
                                                                0, :, :][:,
                                                                         None]
                # do intensity transform for spatial transformed input
                aug_inp = int_tf(paired[..., 1:inp.shape[1] + 1, :, :])
                # and add some noise
                aug_inp = add_sp_gauss_noise(aug_inp, 0.2, 0.1, 0.3)
                # get prediction of the augmented input
                embeddings = model(aug_inp.unsqueeze(2)).squeeze(2)

                # put embeddings on unit sphere so we can use cosine distance
                embeddings = embeddings / torch.norm(
                    embeddings, dim=1, keepdim=True)
                embeddings_mom = embeddings_mom + (
                    mask == 0)  # set the void of the image to the 1-vector
                embeddings_mom = embeddings_mom / torch.norm(
                    embeddings_mom, dim=1, keepdim=True)

                loss = criterion(embeddings.squeeze(0),
                                 embeddings_mom.squeeze(0),
                                 k,
                                 mask.squeeze(0),
                                 whiten=wu_cfg.whitened_embeddings,
                                 warmup=iteration < wu_cfg.n_warmup_it,
                                 psi=psi)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                acc_loss += loss.item()

                print(loss.item())
                writer.add_scalar("fe_warm_start/loss", loss.item(), iteration)
                writer.add_scalar("fe_warm_start/lr",
                                  optimizer.param_groups[0]['lr'], iteration)
                if (iteration) % 50 == 0:
                    # NOTE(review): loss accumulates over 50 iterations but is
                    # divided by 10 here — confirm whether /50 was intended.
                    sheduler.step(acc_loss / 10)
                    acc_loss = 0
                    fig, (a1, a2, a3, a4) = plt.subplots(1,
                                                         4,
                                                         sharex='col',
                                                         sharey='row',
                                                         gridspec_kw={
                                                             'hspace': 0,
                                                             'wspace': 0
                                                         })
                    a1.imshow(inp[0].cpu().permute(1, 2, 0).squeeze())
                    a1.set_title('raw')
                    a2.imshow(aug_inp[0].cpu().permute(1, 2, 0))
                    a2.set_title('augment')
                    a3.imshow(
                        pca_project(
                            get_angles(embeddings).squeeze(0).detach().cpu()))
                    a3.set_title('embed')
                    a4.imshow(
                        pca_project(
                            get_angles(embeddings_mom).squeeze(
                                0).detach().cpu()))
                    a4.set_title('mom_embed')
                    writer.add_figure("examples", fig, iteration // 100)
                iteration += 1
                psi = max(psi - psi_step, wu_cfg.psi_stop)
                if iteration % k_step == 0:
                    k = max(k - 1, wu_cfg.k_stop)

                if iteration > wu_cfg.n_iterations:
                    break
                # Periodically blend the online weights into the momentum
                # model.
                if iteration % wu_cfg.momentum == 0:
                    soft_update_params(model, momentum_model,
                                       wu_cfg.momentum_tau)
        return
コード例 #6
0
    def train_step(self, rank, start_time, return_dict, writer):
        """One distributed worker's full training run for the edge agent.

        Sets up the environment, DDP-wrapped model, replay memory and
        optimizers on GPU `rank`; optionally warm-starts / loads the feature
        extractor; then runs episodes until the shared global counter reaches
        T_max, re-reading runtime config from disk every episode. Finally
        saves the model (rank 0 only) and returns the mean of the last 10
        episode qualities.
        """
        device = torch.device("cuda:" + str(rank))
        print('Running on device: ', device)
        torch.cuda.set_device(device)
        torch.set_default_tensor_type(torch.FloatTensor)

        # Join the distributed process group.
        self.setup(rank, self.args.num_processes)
        # Transition layout depends on whether Monte-Carlo DQL is used.
        if self.cfg.MC_DQL:
            transition = namedtuple('Transition', ('episode'))
        else:
            transition = namedtuple(
                'Transition',
                ('state', 'action', 'reward', 'next_state', 'done'))
        memory = TransitionData_ts(capacity=self.args.t_max,
                                   storage_object=transition)

        env = SpGcnEnv(self.args,
                       device,
                       writer=writer,
                       writer_counter=self.global_writer_quality_count,
                       win_event_counter=self.global_win_event_count)
        # Create shared network

        # model = GcnEdgeAC_1(self.cfg, self.args.n_raw_channels, self.args.n_embedding_features, 1, device, writer=writer)
        model = GcnEdgeAC(self.cfg, self.args, device, writer=writer)
        # model = GcnEdgeAC(self.cfg, self.args.n_raw_channels, self.args.n_embedding_features, 1, device, writer=writer)

        model.cuda(device)
        shared_model = DDP(model,
                           device_ids=[model.device],
                           find_unused_parameters=True)

        # dloader = DataLoader(MultiDiscSpGraphDsetBalanced(no_suppix=False, create=False), batch_size=1, shuffle=True, pin_memory=True,
        #                      num_workers=0)
        dloader = DataLoader(SpgDset(),
                             batch_size=self.cfg.batch_size,
                             shuffle=True,
                             pin_memory=True,
                             num_workers=0)
        # Create optimizer for shared network parameters with shared statistics
        # optimizer = CstmAdam(shared_model.parameters(), lr=self.args.lr, betas=self.args.Adam_betas,
        #                      weight_decay=self.args.Adam_weight_decay)
        ######################
        # SAC-style hyperparameters held on the trainer instance.
        self.action_range = 1
        self.device = torch.device(device)
        self.discount = 0.5
        self.critic_tau = self.cfg.critic_tau
        self.actor_update_frequency = self.cfg.actor_update_frequency
        self.critic_target_update_frequency = self.cfg.critic_target_update_frequency
        self.batch_size = self.cfg.batch_size

        # Learnable temperature (in log space for positivity).
        self.log_alpha = torch.tensor(np.log(self.cfg.init_temperature)).to(
            self.device)
        self.log_alpha.requires_grad = True
        # set target entropy to -|A|
        ######################
        # optimizers
        OptimizerContainer = namedtuple('OptimizerContainer',
                                        ('actor', 'critic', 'temperature'))
        actor_optimizer = torch.optim.Adam(
            shared_model.module.actor.parameters(),
            lr=self.cfg.actor_lr,
            betas=self.cfg.actor_betas)

        critic_optimizer = torch.optim.Adam(
            shared_model.module.critic.parameters(),
            lr=self.cfg.critic_lr,
            betas=self.cfg.critic_betas)

        temp_optimizer = torch.optim.Adam([self.log_alpha],
                                          lr=self.cfg.alpha_lr,
                                          betas=self.cfg.alpha_betas)

        optimizers = OptimizerContainer(actor_optimizer, critic_optimizer,
                                        temp_optimizer)

        # Only rank 0 runs the feature-extractor warm-up; results are saved
        # to disk so the other ranks can load them after the barrier.
        if self.args.fe_extr_warmup and rank == 0 and not self.args.test_score_only:
            fe_extr = shared_model.module.fe_ext
            fe_extr.cuda(device)
            self.fe_extr_warm_start_1(fe_extr, writer=writer)
            # self.fe_extr_warm_start(fe_extr, writer=writer)
            if self.args.model_name == "" and not self.args.no_save:
                torch.save(fe_extr.state_dict(),
                           os.path.join(self.save_dir, 'agent_model_fe_extr'))
            elif not self.args.no_save:
                torch.save(fe_extr.state_dict(),
                           os.path.join(self.save_dir, self.args.model_name))

        dist.barrier()
        # The feature extractor stays frozen during agent training.
        for param in model.fe_ext.parameters():
            param.requires_grad = False

        # Load weights in priority order: full model > fe-only model >
        # freshly warmed-up fe extractor.
        if self.args.model_name != "":
            shared_model.load_state_dict(
                torch.load(os.path.join(self.save_dir, self.args.model_name)))
        elif self.args.model_fe_name != "":
            shared_model.module.fe_ext.load_state_dict(
                torch.load(os.path.join(self.save_dir,
                                        self.args.model_fe_name)))
        elif self.args.fe_extr_warmup:
            print('loaded fe extractor')
            shared_model.module.fe_ext.load_state_dict(
                torch.load(os.path.join(self.save_dir, 'agent_model_fe_extr')))

        if not self.args.test_score_only:
            quality = self.args.stop_qual_scaling + self.args.stop_qual_offset
            best_quality = np.inf
            last_quals = []
            while self.global_count.value() <= self.args.T_max:
                # NOTE(review): debug leftover — `a = 1` looks like a
                # breakpoint anchor for episode 78; consider removing.
                if self.global_count.value() == 78:
                    a = 1
                self.update_env_data(env, dloader, device)
                # waff_dis = torch.softmax(env.edge_features[:, 0].squeeze() + 1e-30, dim=0)
                # waff_dis = torch.softmax(env.gt_edge_weights + 0.5, dim=0)
                waff_dis = torch.softmax(torch.ones_like(
                    env.b_gt_edge_weights),
                                         dim=0)
                loss_weight = torch.softmax(env.b_gt_edge_weights + 1, dim=0)
                env.reset()
                # self.target_entropy = - float(env.gt_edge_weights.shape[0])
                self.target_entropy = -8.0

                # Anneal stopping quality and (optionally) temperature with
                # training progress.
                env.stop_quality = self.stop_qual_rule.apply(
                    self.global_count.value(), quality)
                if self.cfg.temperature_regulation == 'follow_quality':
                    self.alpha = self.eps_rule.apply(self.global_count.value(),
                                                     quality)
                    print(self.alpha.item())

                # Re-read runtime config from disk so hyperparameters can be
                # tweaked while training runs; 'safe_model' is consumed
                # (reset to False) after reading.
                with open(os.path.join(self.save_dir,
                                       'runtime_cfg.yaml')) as info:
                    args_dict = yaml.full_load(info)
                    if args_dict is not None:
                        if 'safe_model' in args_dict:
                            self.args.safe_model = args_dict['safe_model']
                            args_dict['safe_model'] = False
                        if 'add_noise' in args_dict:
                            self.args.add_noise = args_dict['add_noise']
                        if 'critic_lr' in args_dict and args_dict[
                                'critic_lr'] != self.cfg.critic_lr:
                            self.cfg.critic_lr = args_dict['critic_lr']
                            adjust_learning_rate(critic_optimizer,
                                                 self.cfg.critic_lr)
                        if 'actor_lr' in args_dict and args_dict[
                                'actor_lr'] != self.cfg.actor_lr:
                            self.cfg.actor_lr = args_dict['actor_lr']
                            adjust_learning_rate(actor_optimizer,
                                                 self.cfg.actor_lr)
                        if 'alpha_lr' in args_dict and args_dict[
                                'alpha_lr'] != self.cfg.alpha_lr:
                            self.cfg.alpha_lr = args_dict['alpha_lr']
                            adjust_learning_rate(temp_optimizer,
                                                 self.cfg.alpha_lr)
                with open(os.path.join(self.save_dir, 'runtime_cfg.yaml'),
                          "w") as info:
                    yaml.dump(args_dict, info)

                # On-demand checkpointing triggered via the runtime config.
                if self.args.safe_model:
                    best_quality = quality
                    if rank == 0:
                        if self.args.model_name_dest != "":
                            torch.save(
                                shared_model.state_dict(),
                                os.path.join(self.save_dir,
                                             self.args.model_name_dest))
                        else:
                            torch.save(
                                shared_model.state_dict(),
                                os.path.join(self.save_dir, 'agent_model'))

                state = env.get_state()
                while not env.done:
                    # Calculate policy and values
                    post_input = True if (
                        self.global_count.value() +
                        1) % 15 == 0 and env.counter == 0 else False
                    round_n = env.counter
                    # sample action for data collection
                    distr = None
                    # Random exploration until enough seed steps collected.
                    if self.global_count.value() < self.cfg.num_seed_steps:
                        action = torch.rand_like(env.b_current_edge_weights)
                    else:
                        distr, _, _, action = self.agent_forward(
                            env,
                            shared_model,
                            state=state,
                            grad=False,
                            post_input=post_input)

                    logg_dict = {'temperature': self.alpha.item()}
                    if distr is not None:
                        logg_dict['mean_loc'] = distr.loc.mean().item()
                        logg_dict['mean_scale'] = distr.scale.mean().item()

                    # Gradient step only once memory is full and seeding done.
                    if self.global_count.value(
                    ) >= self.cfg.num_seed_steps and memory.is_full():
                        self._step(memory,
                                   optimizers,
                                   env,
                                   shared_model,
                                   self.global_count.value(),
                                   writer=writer)
                        self.global_writer_loss_count.increment()

                    next_state, reward, quality = env.execute_action(
                        action, logg_dict)

                    # Keep a sliding window of the last 10 episode qualities.
                    last_quals.append(quality)
                    if len(last_quals) > 10:
                        last_quals.pop(0)

                    if self.args.add_noise:
                        noise = torch.randn_like(reward) * self.alpha.item()
                        reward = reward + noise

                    # Store transition on CPU to keep GPU memory free.
                    memory.push(self.state_to_cpu(state), action, reward,
                                self.state_to_cpu(next_state), env.done)

                    # Train the network
                    # self._step(memory, shared_model, env, optimizer, loss_weight, off_policy=True, writer=writer)

                    # reward = self.args.reward_clip and min(max(reward, -1), 1) or reward  # Optionally clamp rewards
                    # done = done or episode_length >= self.args.max_episode_length  # Stop episodes at a max length
                    state = next_state

                self.global_count.increment()

        dist.barrier()
        # Final checkpoint from rank 0 only.
        if rank == 0:
            if not self.args.cross_validate_hp and not self.args.test_score_only and not self.args.no_save:
                # pass
                if self.args.model_name_dest != "":
                    torch.save(
                        shared_model.state_dict(),
                        os.path.join(self.save_dir, self.args.model_name_dest))
                    print('saved')
                else:
                    torch.save(shared_model.state_dict(),
                               os.path.join(self.save_dir, 'agent_model'))

        self.cleanup()
        return sum(last_quals) / 10
コード例 #7
0
    def train(self):
        """Warm-start embedding training using a RAG-based InfoNCE loss.

        For every batch the per-image region adjacency graphs are merged into
        a single edge list (superpixel ids are offset so they are unique
        across the batch); embeddings are L2-normalized before the
        contrastive loss.
        """
        writer = SummaryWriter(logdir=self.log_dir)
        writer.add_text("conf", self.cfg.pretty())
        device = "cuda:0"
        wu_cfg = self.cfg.fe.trainer
        model = UNet2D(self.cfg.fe.n_raw_channels,
                       self.cfg.fe.n_embedding_features,
                       final_sigmoid=False,
                       num_levels=5)
        model.cuda(device)
        dset = SpgDset(self.cfg.gen.data_dir, wu_cfg.patch_manager,
                       wu_cfg.patch_stride, wu_cfg.patch_shape,
                       wu_cfg.reorder_sp)
        dloader = DataLoader(dset,
                             batch_size=wu_cfg.batch_size,
                             shuffle=True,
                             pin_memory=True,
                             num_workers=0)
        optimizer = torch.optim.Adam(model.parameters(), lr=self.cfg.fe.lr)
        sheduler = ReduceLROnPlateau(optimizer,
                                     patience=100,
                                     threshold=1e-3,
                                     min_lr=1e-6,
                                     factor=0.1)
        criterion = RagInfoNCE(tau=self.cfg.fe.tau)
        acc_loss = 0
        iteration = 0

        while iteration <= wu_cfg.n_iterations:
            for it, (raw, gt, sp_seg, indices) in enumerate(dloader):
                inp, gt, sp_seg = raw.to(device), gt.to(device), sp_seg.to(
                    device)
                edges = dloader.dataset.get_graphs(indices, sp_seg, device)[0]

                # Shift superpixel labels and edge indices per image (in
                # place) so ids are unique across the batch, then concatenate
                # the graphs into one edge list.
                off = 0
                for i in range(len(edges)):
                    sp_seg[i] += off
                    edges[i] += off
                    off = sp_seg[i].max() + 1
                edges = torch.cat(edges, 1)
                embeddings = model(inp.unsqueeze(2)).squeeze(2)

                # put embeddings on unit sphere so we can use cosine distance
                embeddings = embeddings / torch.norm(
                    embeddings, dim=1, keepdim=True)

                loss = criterion(embeddings, sp_seg, edges)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                acc_loss += loss.item()

                print(loss.item())
                writer.add_scalar("fe_warm_start/loss", loss.item(), iteration)
                writer.add_scalar("fe_warm_start/lr",
                                  optimizer.param_groups[0]['lr'], iteration)
                if (iteration) % 50 == 0:
                    # NOTE(review): loss accumulates over 50 iterations but is
                    # divided by 10 here — confirm whether /50 was intended.
                    sheduler.step(acc_loss / 10)
                    acc_loss = 0
                    fig, (a1, a2) = plt.subplots(1,
                                                 2,
                                                 sharex='col',
                                                 sharey='row',
                                                 gridspec_kw={
                                                     'hspace': 0,
                                                     'wspace': 0
                                                 })
                    a1.imshow(inp[0].cpu().permute(1, 2, 0).squeeze())
                    a1.set_title('raw')
                    a2.imshow(
                        pca_project(
                            get_angles(embeddings).squeeze(0).detach().cpu()))
                    a2.set_title('embed')
                    writer.add_figure("examples", fig, iteration // 100)
                iteration += 1
                if iteration > wu_cfg.n_iterations:
                    break
        return
コード例 #8
0
    def __init__(self, cfg, global_count):
        """Set up the single-GPU SAC trainer with object-level rewards.

        Builds, in order: the replay memory, the (frozen) feature
        extractor, the actor/critic/temperature agent, their optimizers
        and plateau LR schedulers, AMP grad scalers, and the train/val
        datasets.

        Args:
            cfg: attribute-style config; reads mem_size, distance,
                backbone, fe_delta_dist, fe_model_name, actor_lr,
                critic_lr, alpha_lr, lr_sched, agent_model_name,
                data_dir, val_data_dir, patch_manager and data_keys.
            global_count: shared counter tracking global training steps.
        """
        super(AgentSacTrainerObjLvlReward, self).__init__()
        # this trainer assumes exactly one visible GPU; fail fast otherwise
        assert torch.cuda.device_count() == 1
        self.device = torch.device("cuda:0")
        torch.cuda.set_device(self.device)
        torch.set_default_tensor_type(torch.FloatTensor)

        self.cfg = cfg
        self.global_count = global_count
        # replay buffer for (state, action, reward, next_state, done) tuples
        self.memory = TransitionData_ts(capacity=self.cfg.mem_size)
        # best validation reward seen so far; used for checkpoint selection
        self.best_val_reward = -np.inf
        # embedding-space metric used by both the feature extractor and agent
        if self.cfg.distance == 'cosine':
            self.distance = CosineDistance()
        else:
            self.distance = L2Distance()

        # pretrained feature extractor; its weights are loaded from disk and
        # frozen below so only the agent is trained here
        self.fe_ext = FeExtractor(dict_to_attrdict(self.cfg.backbone),
                                  self.distance, cfg.fe_delta_dist,
                                  self.device)
        self.fe_ext.embed_model.load_state_dict(
            torch.load(self.cfg.fe_model_name))
        self.fe_ext.cuda(self.device)

        self.model = Agent(self.cfg, State, self.distance, self.device)
        wandb.watch(self.model)
        self.model.cuda(self.device)
        # guards the shared model against concurrent access from worker threads
        self.model_mtx = Lock()

        MovSumLosses = namedtuple('mov_avg_losses',
                                  ('actor', 'critic', 'temperature'))
        Scalers = namedtuple('Scalers', ('critic', 'actor'))
        OptimizerContainer = namedtuple(
            'OptimizerContainer', ('actor', 'critic', 'temperature',
                                   'actor_shed', 'critic_shed', 'temp_shed'))
        actor_optimizer = torch.optim.Adam(self.model.actor.parameters(),
                                           lr=self.cfg.actor_lr)
        critic_optimizer = torch.optim.Adam(self.model.critic.parameters(),
                                            lr=self.cfg.critic_lr)
        # SAC entropy temperature is optimized through its log for positivity
        temp_optimizer = torch.optim.Adam([self.model.log_alpha],
                                          lr=self.cfg.alpha_lr)

        lr_sched_cfg = dict_to_attrdict(self.cfg.lr_sched)
        bw = lr_sched_cfg.mov_avg_bandwidth
        off = lr_sched_cfg.mov_avg_offset
        # linearly ramped weights for the moving-average loss windows
        weights = np.linspace(lr_sched_cfg.weight_range[0],
                              lr_sched_cfg.weight_range[1], bw)
        weights = weights / weights.sum()  # make them sum up to one
        shed = lr_sched_cfg.torch_sched

        # smoothed loss trackers feed the ReduceLROnPlateau schedulers
        self.mov_sum_losses = MovSumLosses(
            RunningAverage(weights, band_width=bw, offset=off),
            RunningAverage(weights, band_width=bw, offset=off),
            RunningAverage(weights, band_width=bw, offset=off))
        # one plateau scheduler per optimizer, in (actor, critic, temp) order
        self.optimizers = OptimizerContainer(
            actor_optimizer, critic_optimizer, temp_optimizer, *[
                ReduceLROnPlateau(opt,
                                  patience=shed.patience,
                                  threshold=shed.threshold,
                                  min_lr=shed.min_lr,
                                  factor=shed.factor)
                for opt in (actor_optimizer, critic_optimizer, temp_optimizer)
            ])
        # AMP grad scalers for the critic and actor losses (no temp scaler)
        self.scalers = Scalers(torch.cuda.amp.GradScaler(),
                               torch.cuda.amp.GradScaler())
        self.forwarder = Forwarder()

        # optionally resume the agent from a checkpoint
        if self.cfg.agent_model_name != "":
            self.model.load_state_dict(torch.load(self.cfg.agent_model_name))
        # if "policy_warmup" in self.cfg and self.cfg.agent_model_name == "":
        #     supervised_policy_pretraining(self.model, self.env, self.cfg, device=self.device)
        #     torch.save(self.model.state_dict(), os.path.join(wandb.run.dir, "sv_pretrained_policy_agent.pth"))

        # finished with prepping
        # freeze the feature extractor; only the agent is trained
        for param in self.fe_ext.parameters():
            param.requires_grad = False

        self.train_dset = SpgDset(self.cfg.data_dir,
                                  dict_to_attrdict(self.cfg.patch_manager),
                                  dict_to_attrdict(self.cfg.data_keys))
        self.val_dset = SpgDset(self.cfg.val_data_dir,
                                dict_to_attrdict(self.cfg.patch_manager),
                                dict_to_attrdict(self.cfg.data_keys))
コード例 #9
0
    def train(self):
        """Warm-start the feature-extractor UNet with a RAG-contrastive
        loss on ground-truth segments.

        Runs for ``self.cfg.fe.trainer.n_iterations`` optimizer steps over
        the training loader and saves the final weights to
        ``<save_dir>/last_model.pth``. Returns None.
        """
        # writer is kept for parity with the other trainers even though the
        # scalar logging below is currently disabled
        writer = SummaryWriter(logdir=self.log_dir)
        device = "cuda:0"
        wu_cfg = self.cfg.fe.trainer
        model = UNet2D(**self.cfg.fe.backbone)
        model.cuda(device)
        # NOTE(review): the directory names look swapped -- "true_val" feeds
        # the train set and "train" feeds the val set. Confirm this is
        # intentional before reusing these paths.
        train_set = SpgDset(
            "/g/kreshuk/hilt/projects/data/leptin_fused_tp1_ch_0/true_val",
            reorder_sp=True)
        val_set = SpgDset(
            "/g/kreshuk/hilt/projects/data/leptin_fused_tp1_ch_0/train",
            reorder_sp=True)
        train_loader = DataLoader(train_set,
                                  batch_size=wu_cfg.batch_size,
                                  shuffle=True,
                                  pin_memory=True,
                                  num_workers=0)
        # currently unused: the validation pass below was disabled
        val_loader = DataLoader(val_set,
                                batch_size=wu_cfg.batch_size,
                                shuffle=True,
                                pin_memory=True,
                                num_workers=0)
        optimizer = torch.optim.Adam(model.parameters(), lr=self.cfg.fe.lr)
        # NOTE(review): this scheduler is never stepped anywhere in this
        # method, so the learning rate stays constant; wire up
        # sheduler.step(<metric>) or drop it.
        sheduler = ReduceLROnPlateau(optimizer,
                                     patience=40,
                                     threshold=1e-4,
                                     min_lr=1e-5,
                                     factor=0.1)
        criterion = RagContrastiveWeights(delta_var=0.1, delta_dist=0.3)
        iteration = 0

        while iteration <= wu_cfg.n_iterations:
            for it, (raw, gt, sp_seg, affinities, offs,
                     indices) in enumerate(train_loader):
                raw, gt = raw.to(device), gt.to(device)

                # embed the image; add a dummy depth axis for the 2D UNet
                loss_embeds = model(raw[:, :, None]).squeeze(2)
                # project onto the unit sphere (epsilon guards zero vectors)
                loss_embeds = loss_embeds / (
                    torch.norm(loss_embeds, dim=1, keepdim=True) + 1e-9)

                # edges of the region adjacency graph of the GT segmentation
                edges = [
                    feats.compute_rag(seg.cpu().numpy()).uvIds() for seg in gt
                ]
                # np.long was removed in NumPy 1.24; np.int64 matches the old
                # behavior on 64-bit platforms
                edges = [
                    torch.from_numpy(e.astype(np.int64)).to(device).T
                    for e in edges
                ]

                loss = criterion(loss_embeds, gt.long(), edges, None, 30)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                print(loss.item())
                iteration += 1
                print(iteration)
                if iteration > wu_cfg.n_iterations:
                    print(self.save_dir)
                    torch.save(model.state_dict(),
                               os.path.join(self.save_dir, "last_model.pth"))
                    break
        return
コード例 #10
0
    def train(self):
        """Train the feature extractor with a rotation-consistency
        regularized RAG-contrastive loss.

        Each step rotates the inputs by a random angle and feeds the
        original and rotated images through the net in one batch; the
        criterion couples the two embedding sub-spaces (split at
        ``cfg.fe.embeddings_separator``) so embeddings stay consistent
        under rotation. Every 100 iterations a full validation pass runs,
        the best validation model is checkpointed to
        ``<save_dir>/best_val_model.pth``, and the final model is saved to
        ``<save_dir>/last_model.pth``. Returns None.
        """
        writer = SummaryWriter(logdir=self.log_dir)
        device = "cuda:0"
        wu_cfg = self.cfg.fe.trainer
        model = UNet2D(**self.cfg.fe.backbone)
        model.cuda(device)
        train_set = SpgDset(self.cfg.gen.data_dir_raw_train, reorder_sp=True)
        val_set = SpgDset(self.cfg.gen.data_dir_raw_val, reorder_sp=True)
        pm = NoPatches2D()
        train_set.length = len(train_set.graph_file_names) * np.prod(
            pm.n_patch_per_dim)
        train_set.n_patch_per_dim = pm.n_patch_per_dim
        val_set.length = len(val_set.graph_file_names)
        # smooths the superpixel contour image used as an extra input channel
        gauss_kernel = GaussianSmoothing(1, 5, 3, device=device)
        train_loader = DataLoader(train_set,
                                  batch_size=wu_cfg.batch_size,
                                  shuffle=True,
                                  pin_memory=True,
                                  num_workers=0)
        val_loader = DataLoader(val_set,
                                batch_size=wu_cfg.batch_size,
                                shuffle=True,
                                pin_memory=True,
                                num_workers=0)
        optimizer = torch.optim.Adam(model.parameters(), lr=self.cfg.fe.lr)
        sheduler = ReduceLROnPlateau(optimizer,
                                     patience=80,
                                     threshold=1e-4,
                                     min_lr=1e-8,
                                     factor=0.1)
        # the embedding channels split into two sub-spaces that the loss
        # normalizes and contrasts separately
        slcs = [
            slice(None, self.cfg.fe.embeddings_separator),
            slice(self.cfg.fe.embeddings_separator, None)
        ]
        criterion = RegRagContrastiveWeights(delta_var=0.1,
                                             delta_dist=0.3,
                                             slices=slcs)
        acc_loss = 0
        valit = 0
        iteration = 0
        best_loss = np.inf

        while iteration <= wu_cfg.n_iterations:
            for it, (raw, gt, sp_seg, affinities, offs,
                     indices) in enumerate(train_loader):
                raw, gt, sp_seg, affinities = raw.to(device), gt.to(
                    device), sp_seg.to(device), affinities.to(device)
                # shift labels so 0 is free to mark rotation padding
                sp_seg = sp_seg + 1
                edge_img = F.pad(get_contour_from_2d_binary(sp_seg),
                                 (2, 2, 2, 2),
                                 mode='constant')
                edge_img = gauss_kernel(edge_img.float())
                # stack every channel so one rotation transforms them jointly
                # (renamed from "all", which shadowed the builtin)
                stacked = torch.cat([raw, gt, sp_seg, edge_img], dim=1)

                angle = float(torch.randint(-180, 180, (1, )).item())
                rot_all = tvF.rotate(stacked, angle, PIL.Image.NEAREST)
                rot_raw = rot_all[:, :1]
                rot_gt = rot_all[:, 1:2]
                rot_sp = rot_all[:, 2:3]
                rot_edge_img = rot_all[:, 3:]
                # normalize the angle to [0, 1] for the regularizer
                angle = abs(angle / 180)
                # keep only superpixels that survived the rotation (label 0 is
                # padding); subsample to bound per-step memory
                valid_sp = []
                for i in range(len(rot_sp)):
                    _valid_sp = torch.unique(rot_sp[i], sorted=True)
                    _valid_sp = _valid_sp[1:] if _valid_sp[
                        0] == 0 else _valid_sp
                    if len(_valid_sp) > self.cfg.gen.sp_samples_per_step:
                        inds = torch.multinomial(
                            torch.ones_like(_valid_sp),
                            self.cfg.gen.sp_samples_per_step,
                            replacement=False)
                        _valid_sp = _valid_sp[inds]
                    valid_sp.append(_valid_sp)

                # relabel both segmentations to consecutive ids 1..len(val_sp)
                _rot_sp, _sp_seg = [], []
                for val_sp, rsp, sp in zip(valid_sp, rot_sp, sp_seg):
                    mask = rsp == val_sp[:, None, None]
                    _rot_sp.append((mask * (torch.arange(
                        len(val_sp), device=rsp.device)[:, None, None] + 1)
                                    ).sum(0))
                    mask = sp == val_sp[:, None, None]
                    _sp_seg.append((mask * (torch.arange(
                        len(val_sp), device=sp.device)[:, None, None] + 1)
                                    ).sum(0))

                rot_sp = torch.stack(_rot_sp)
                sp_seg = torch.stack(_sp_seg)
                valid_sp = [
                    torch.unique(_rot_sp, sorted=True) for _rot_sp in rot_sp
                ]
                valid_sp = [
                    _valid_sp[1:] if _valid_sp[0] == 0 else _valid_sp
                    for _valid_sp in valid_sp
                ]

                # original and rotated views go through the net in one batch
                inp = torch.cat([
                    torch.cat([raw, edge_img], 1),
                    torch.cat([rot_raw, rot_edge_img], 1)
                ], 0)
                offs = offs.numpy().tolist()
                # (loop var renamed from "os", which shadowed the os module)
                edge_feat, edges = tuple(
                    zip(*[
                        get_edge_features_1d(seg.squeeze().cpu().numpy(), off,
                                             affs.squeeze().cpu().numpy())
                        for seg, off, affs in zip(sp_seg, offs, affinities)
                    ]))
                # np.long was removed in NumPy 1.24; np.int64 matches the old
                # behavior on 64-bit platforms
                edges = [
                    torch.from_numpy(e.astype(np.int64)).to(device).T
                    for e in edges
                ]
                edge_weights = [
                    torch.from_numpy(ew.astype(np.float32)).to(device)[:,
                                                                       0][None]
                    for ew in edge_feat
                ]
                # an edge is valid iff both endpoints are valid superpixels
                valid_edges_masks = [
                    (_edges[None] == _valid_sp[:, None,
                                               None]).sum(0).sum(0) == 2
                    for _valid_sp, _edges in zip(valid_sp, edges)
                ]
                # -1 maps labels back to 0-based indices for the criterion
                edges = [
                    _edges[:, valid_edges_mask] - 1
                    for _edges, valid_edges_mask in zip(
                        edges, valid_edges_masks)
                ]
                edge_weights = [
                    _edge_weights[:, valid_edges_mask]
                    for _edge_weights, valid_edges_mask in zip(
                        edge_weights, valid_edges_masks)
                ]

                # put embeddings on unit sphere so we can use cosine distance
                loss_embeds = model(inp[:, :, None]).squeeze(2)
                loss_embeds = criterion.norm_each_space(loss_embeds, 1)

                loss = criterion(loss_embeds,
                                 sp_seg.long(),
                                 rot_sp.long(),
                                 edges,
                                 edge_weights,
                                 valid_sp,
                                 angle,
                                 chunks=int(sp_seg.max().item() //
                                            self.cfg.gen.train_chunk_size))

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                print(f"step {iteration}: {loss.item()}")
                writer.add_scalar("fe_train/lr",
                                  optimizer.param_groups[0]['lr'], iteration)
                writer.add_scalar("fe_train/loss", loss.item(), iteration)
                if (iteration) % 100 == 0:
                    with torch.set_grad_enabled(False):
                        model.eval()
                        print("####start validation####")
                        for it, (raw, gt, sp_seg, affinities, offs,
                                 indices) in enumerate(val_loader):
                            raw, gt, sp_seg, affinities = raw.to(
                                device), gt.to(device), sp_seg.to(
                                    device), affinities.to(device)
                            sp_seg = sp_seg + 1
                            edge_img = F.pad(
                                get_contour_from_2d_binary(sp_seg),
                                (2, 2, 2, 2),
                                mode='constant')
                            edge_img = gauss_kernel(edge_img.float())
                            stacked = torch.cat([raw, gt, sp_seg, edge_img],
                                                dim=1)

                            angle = float(
                                torch.randint(-180, 180, (1, )).item())
                            rot_all = tvF.rotate(stacked, angle,
                                                 PIL.Image.NEAREST)
                            rot_raw = rot_all[:, :1]
                            rot_gt = rot_all[:, 1:2]
                            rot_sp = rot_all[:, 2:3]
                            rot_edge_img = rot_all[:, 3:]
                            angle = abs(angle / 180)
                            # validation keeps all surviving superpixels
                            # (no subsampling, unlike the training loop)
                            valid_sp = [
                                torch.unique(_rot_sp, sorted=True)
                                for _rot_sp in rot_sp
                            ]
                            valid_sp = [
                                _valid_sp[1:]
                                if _valid_sp[0] == 0 else _valid_sp
                                for _valid_sp in valid_sp
                            ]

                            _rot_sp, _sp_seg = [], []
                            for val_sp, rsp, sp in zip(valid_sp, rot_sp,
                                                       sp_seg):
                                mask = rsp == val_sp[:, None, None]
                                _rot_sp.append((mask * (torch.arange(
                                    len(val_sp), device=rsp.device)[:, None,
                                                                    None] + 1)
                                                ).sum(0))
                                mask = sp == val_sp[:, None, None]
                                _sp_seg.append((mask * (torch.arange(
                                    len(val_sp), device=sp.device)[:, None,
                                                                   None] + 1)
                                                ).sum(0))

                            rot_sp = torch.stack(_rot_sp)
                            sp_seg = torch.stack(_sp_seg)
                            valid_sp = [
                                torch.unique(_rot_sp, sorted=True)
                                for _rot_sp in rot_sp
                            ]
                            valid_sp = [
                                _valid_sp[1:]
                                if _valid_sp[0] == 0 else _valid_sp
                                for _valid_sp in valid_sp
                            ]

                            inp = torch.cat([
                                torch.cat([raw, edge_img], 1),
                                torch.cat([rot_raw, rot_edge_img], 1)
                            ], 0)
                            offs = offs.numpy().tolist()
                            edge_feat, edges = tuple(
                                zip(*[
                                    get_edge_features_1d(
                                        seg.squeeze().cpu().numpy(), off,
                                        affs.squeeze().cpu().numpy())
                                    for seg, off, affs in zip(
                                        sp_seg, offs, affinities)
                                ]))
                            edges = [
                                torch.from_numpy(e.astype(
                                    np.int64)).to(device).T for e in edges
                            ]
                            edge_weights = [
                                torch.from_numpy(ew.astype(
                                    np.float32)).to(device)[:, 0][None]
                                for ew in edge_feat
                            ]
                            valid_edges_masks = [
                                (_edges[None] == _valid_sp[:, None, None]
                                 ).sum(0).sum(0) == 2
                                for _valid_sp, _edges in zip(valid_sp, edges)
                            ]
                            edges = [
                                _edges[:, valid_edges_mask] - 1
                                for _edges, valid_edges_mask in zip(
                                    edges, valid_edges_masks)
                            ]
                            edge_weights = [
                                _edge_weights[:, valid_edges_mask]
                                for _edge_weights, valid_edges_mask in zip(
                                    edge_weights, valid_edges_masks)
                            ]

                            # put embeddings on unit sphere so we can use cosine distance
                            embeds = model(inp[:, :, None]).squeeze(2)
                            embeds = criterion.norm_each_space(embeds, 1)

                            ls = criterion(
                                embeds,
                                sp_seg.long(),
                                rot_sp.long(),
                                edges,
                                edge_weights,
                                valid_sp,
                                angle,
                                chunks=int(sp_seg.max().item() //
                                           self.cfg.gen.train_chunk_size))

                            acc_loss += ls
                            writer.add_scalar("fe_val/loss", ls, valit)
                            print(f"step {it}: {ls.item()}")
                            valit += 1

                    acc_loss = acc_loss / len(val_loader)
                    if acc_loss < best_loss:
                        print(self.save_dir)
                        torch.save(
                            model.state_dict(),
                            os.path.join(self.save_dir, "best_val_model.pth"))
                        best_loss = acc_loss
                    sheduler.step(acc_loss)
                    acc_loss = 0
                    # NOTE(review): plt.show() blocks in non-interactive runs;
                    # consider relying on the writer figure only.
                    fig, ((a1, a2), (a3, a4)) = plt.subplots(2,
                                                             2,
                                                             sharex='col',
                                                             sharey='row',
                                                             gridspec_kw={
                                                                 'hspace': 0,
                                                                 'wspace': 0
                                                             })
                    a1.imshow(raw[0].cpu().permute(1, 2, 0).squeeze())
                    a1.set_title('raw')
                    a2.imshow(
                        cm.prism(sp_seg[0].cpu().squeeze() /
                                 sp_seg[0].cpu().squeeze().max()))
                    a2.set_title('sp')
                    a3.imshow(pca_project(embeds[0, slcs[0]].detach().cpu()))
                    a3.set_title('embed', y=-0.01)
                    a4.imshow(pca_project(embeds[0, slcs[1]].detach().cpu()))
                    a4.set_title('embed rot', y=-0.01)
                    plt.show()
                    writer.add_figure("examples", fig, iteration // 100)
                    # restore training mode; previously this was commented
                    # out, leaving batch-norm/dropout in eval mode for all
                    # training steps after the first validation pass
                    model.train()
                    print("####end validation####")
                iteration += 1
                if iteration > wu_cfg.n_iterations:
                    print(self.save_dir)
                    torch.save(model.state_dict(),
                               os.path.join(self.save_dir, "last_model.pth"))
                    break
        return
コード例 #11
0
    def train_step(self, rank, writer):
        device = torch.device("cuda:" +
                              str(rank // self.cfg.gen.n_processes_per_gpu))
        print('Running on device: ', device)
        torch.cuda.set_device(device)
        torch.set_default_tensor_type(torch.FloatTensor)
        self.setup(rank, self.cfg.gen.n_processes_per_gpu * self.cfg.gen.n_gpu)

        env = SpGcnEnv(self.cfg,
                       device,
                       writer=writer,
                       writer_counter=self.global_writer_quality_count)
        # Create shared network

        model = GcnEdgeAC(self.cfg, device, writer=writer)
        model.cuda(device)
        shared_model = DDP(model,
                           device_ids=[device],
                           find_unused_parameters=True)
        if 'extra' in self.cfg.fe.optim:
            # optimizers
            MovSumLosses = namedtuple(
                'mov_avg_losses',
                ('actor', 'embeddings', 'critic', 'temperature'))
            OptimizerContainer = namedtuple(
                'OptimizerContainer',
                ('actor', 'embeddings', 'critic', 'temperature', 'actor_shed',
                 'embed_shed', 'critic_shed', 'temp_shed'))
        else:
            MovSumLosses = namedtuple('mov_avg_losses',
                                      ('actor', 'critic', 'temperature'))
            OptimizerContainer = namedtuple(
                'OptimizerContainer',
                ('actor', 'critic', 'temperature', 'actor_shed', 'critic_shed',
                 'temp_shed'))
        if "rl_loss" == self.cfg.fe.optim:
            actor_optimizer = torch.optim.Adam(
                list(shared_model.module.actor.parameters()) +
                list(shared_model.module.fe_ext.parameters()),
                lr=self.cfg.sac.actor_lr,
                betas=self.cfg.sac.actor_betas)
        else:
            actor_optimizer = torch.optim.Adam(
                shared_model.module.actor.parameters(),
                lr=self.cfg.sac.actor_lr,
                betas=self.cfg.sac.actor_betas)
        if "extra" in self.cfg.fe.optim:
            embeddings_optimizer = torch.optim.Adam(
                shared_model.module.fe_ext.parameters(),
                lr=self.cfg.fe.lr,
                betas=self.cfg.fe.betas)
        critic_optimizer = torch.optim.Adam(
            shared_model.module.critic.parameters(),
            lr=self.cfg.sac.critic_lr,
            betas=self.cfg.sac.critic_betas)
        temp_optimizer = torch.optim.Adam([shared_model.module.log_alpha],
                                          lr=self.cfg.sac.alpha_lr,
                                          betas=self.cfg.sac.alpha_betas)

        if "extra" in self.cfg.fe.optim:
            mov_sum_losses = MovSumLosses(RunningAverage(), RunningAverage(),
                                          RunningAverage(), RunningAverage())
            optimizers = OptimizerContainer(
                actor_optimizer, embeddings_optimizer, critic_optimizer,
                temp_optimizer, ReduceLROnPlateau(actor_optimizer),
                ReduceLROnPlateau(embeddings_optimizer),
                ReduceLROnPlateau(critic_optimizer),
                ReduceLROnPlateau(temp_optimizer))
        else:
            mov_sum_losses = MovSumLosses(RunningAverage(), RunningAverage(),
                                          RunningAverage())
            optimizers = OptimizerContainer(
                actor_optimizer, critic_optimizer, temp_optimizer,
                ReduceLROnPlateau(actor_optimizer),
                ReduceLROnPlateau(critic_optimizer),
                ReduceLROnPlateau(temp_optimizer))

        dist.barrier()

        if self.cfg.gen.resume:
            shared_model.module.load_state_dict(
                torch.load(os.path.join(self.log_dir,
                                        self.cfg.gen.model_name)))
        elif self.cfg.fe.load_pretrained:
            shared_model.module.fe_ext.load_state_dict(
                torch.load(os.path.join(self.save_dir,
                                        self.cfg.fe.model_name)))
        elif 'warmup' in self.cfg.fe and rank == 0:
            print('pretrain fe extractor')
            self.pretrain_embeddings_gt(shared_model.module.fe_ext, device,
                                        writer)
            torch.save(shared_model.module.fe_ext.state_dict(),
                       os.path.join(self.save_dir, self.cfg.fe.model_name))
        dist.barrier()

        if "none" == self.cfg.fe.optim:
            for param in shared_model.module.fe_ext.parameters():
                param.requires_grad = False

        dset = SpgDset(self.cfg.gen.data_dir)
        step = 0
        while self.global_count.value() <= self.cfg.trainer.T_max:
            dloader = DataLoader(dset,
                                 batch_size=self.cfg.trainer.batch_size,
                                 shuffle=True,
                                 pin_memory=True,
                                 num_workers=0)
            for iteration in range(
                    len(dset) * self.cfg.trainer.data_update_frequency):
                # if self.global_count.value() > self.args.T_max:
                #     a=1
                if iteration % self.cfg.trainer.data_update_frequency == 0:
                    self.update_env_data(env, dloader, device)
                # waff_dis = torch.softmax(env.edge_features[:, 0].squeeze() + 1e-30, dim=0)
                # waff_dis = torch.softmax(env.gt_edge_weights + 0.5, dim=0)
                # waff_dis = torch.softmax(torch.ones_like(env.b_gt_edge_weights), dim=0)
                # loss_weight = torch.softmax(env.b_gt_edge_weights + 1, dim=0)
                env.reset()
                self.update_rt_vars(critic_optimizer, actor_optimizer)
                if rank == 0 and self.cfg.rt_vars.safe_model:
                    if self.cfg.gen.model_name != "":
                        torch.save(
                            shared_model.module.state_dict(),
                            os.path.join(self.log_dir,
                                         self.cfg.gen.model_name))
                    else:
                        torch.save(shared_model.module.state_dict(),
                                   os.path.join(self.log_dir, 'agent_model'))

                state = env.get_state()
                while not env.done:
                    # Calculate policy and values
                    post_stats = True if (self.global_writer_count.value() + 1) % self.cfg.trainer.post_stats_frequency == 0 \
                        else False
                    post_model = True if (self.global_writer_count.value() + 1) % self.cfg.trainer.post_model_frequency == 0 \
                        else False
                    post_stats &= self.memory.is_full()
                    post_model &= self.memory.is_full()
                    distr = None
                    if not self.memory.is_full():
                        action = torch.rand_like(env.current_edge_weights)
                    else:
                        distr, _, _, action, _, _ = self.agent_forward(
                            env,
                            shared_model,
                            state=state,
                            grad=False,
                            post_input=post_stats,
                            post_model=post_model)

                    logg_dict = {}
                    if post_stats:
                        for i in range(len(self.cfg.sac.s_subgraph)):
                            logg_dict[
                                'alpha_' +
                                str(i)] = shared_model.module.alpha[i].item()
                        if distr is not None:
                            logg_dict['mean_loc'] = distr.loc.mean().item()
                            logg_dict['mean_scale'] = distr.scale.mean().item()

                    if self.memory.is_full():
                        for i in range(self.cfg.trainer.n_updates_per_step):
                            self._step(self.memory,
                                       optimizers,
                                       mov_sum_losses,
                                       env,
                                       shared_model,
                                       step,
                                       writer=writer)
                            self.global_writer_loss_count.increment()

                    next_state, reward = env.execute_action(
                        action, logg_dict, post_stats=post_stats)
                    # next_state, reward, quality = env.execute_action(torch.sigmoid(distr.loc), logg_dict, post_stats=post_stats)

                    if self.cfg.rt_vars.add_noise:
                        noise = torch.randn_like(reward) * 0.2
                        reward = reward + noise

                    self.memory.push(self.state_to_cpu(state), action, reward,
                                     self.state_to_cpu(next_state), env.done)
                    state = next_state

                self.global_count.increment()
                step += 1
                if rank == 0:
                    self.global_writer_count.increment()
                if step > self.cfg.trainer.T_max:
                    break

        dist.barrier()
        if rank == 0:
            self.memory.clear()
            if not self.cfg.gen.cross_validate_hp and not self.cfg.gen.test_score_only and not self.cfg.gen.no_save:
                # pass
                if self.cfg.gen.model_name != "":
                    torch.save(
                        shared_model.state_dict(),
                        os.path.join(self.log_dir, self.cfg.gen.model_name))
                    print('saved')
                else:
                    torch.save(shared_model.state_dict(),
                               os.path.join(self.log_dir, 'agent_model'))

        self.cleanup()
        return sum(env.acc_reward) / len(env.acc_reward)
コード例 #12
0
    def train(self):
        """Warm-start training of the feature-extractor UNet with an
        affinity-based contrastive loss.

        Runs up to ``wu_cfg.n_iterations`` optimizer steps over the training
        set, validates every 100 iterations, checkpoints the best model by
        mean validation loss and steps an LR scheduler on that loss.
        """
        writer = SummaryWriter(logdir=self.log_dir)
        device = "cuda:0"
        wu_cfg = self.cfg.fe.trainer
        model = UNet2D(**self.cfg.fe.backbone)
        model.cuda(device)
        train_set = SpgDset(self.cfg.gen.data_dir_raw_train, reorder_sp=False)
        val_set = SpgDset(self.cfg.gen.data_dir_raw_val, reorder_sp=False)
        train_loader = DataLoader(train_set,
                                  batch_size=wu_cfg.batch_size,
                                  shuffle=True,
                                  pin_memory=True,
                                  num_workers=0)
        val_loader = DataLoader(val_set,
                                batch_size=wu_cfg.batch_size,
                                shuffle=True,
                                pin_memory=True,
                                num_workers=0)

        optimizer = torch.optim.Adam(model.parameters(), lr=self.cfg.fe.lr)
        criterion = AffinityContrastive(delta_var=0.1, delta_dist=0.3)
        scheduler = ReduceLROnPlateau(optimizer,
                                      patience=5,
                                      threshold=1e-4,
                                      min_lr=1e-5,
                                      factor=0.1)
        valit = 0
        iteration = 0
        best_loss = np.inf

        while iteration <= wu_cfg.n_iterations:
            for it, (raw, gt, sp_seg, affinities, offs,
                     indices) in enumerate(train_loader):
                raw, gt, sp_seg, affinities, offs = raw.to(device), gt.to(
                    device), sp_seg.to(device), affinities.to(
                        device), offs[0].to(device)

                # raw image and affinities are stacked channel-wise as the
                # network input (renamed from `input`, which shadows a builtin)
                net_input = torch.cat([raw, affinities], dim=1)

                # add a singleton depth dim for the 3d-style UNet, drop it again
                embeddings = model(net_input.unsqueeze(2)).squeeze(2)

                # project embeddings onto the unit sphere (cosine distance)
                embeddings = embeddings / torch.norm(
                    embeddings, dim=1, keepdim=True)

                loss = criterion(embeddings, affinities, offs)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                lr = optimizer.param_groups[0]['lr']
                print(f"step {it}; lr({lr}); loss({loss.item()})")
                writer.add_scalar("fe_warm_start/loss", loss.item(), iteration)
                writer.add_scalar("fe_warm_start/lr", lr, iteration)
                if iteration % 100 == 0:
                    # periodic validation pass over the full validation set
                    acc_loss = 0.0
                    with torch.no_grad():
                        for val_it, (raw, gt, sp_seg, affinities, offs,
                                     indices) in enumerate(val_loader):
                            raw, gt, sp_seg, affinities, offs = raw.to(
                                device), gt.to(device), sp_seg.to(
                                    device), affinities.to(device), offs[0].to(
                                        device)

                            net_input = torch.cat([raw, affinities], dim=1)

                            embeddings = model(net_input.unsqueeze(2)).squeeze(2)

                            embeddings = embeddings / torch.norm(
                                embeddings, dim=1, keepdim=True)

                            loss = criterion(embeddings, affinities, offs)
                            # accumulate a python float, not a CUDA tensor,
                            # so we don't keep device memory alive in best_loss
                            acc_loss += loss.item()
                            writer.add_scalar("fe_val/loss", loss.item(), valit)
                            valit += 1
                    acc_loss = acc_loss / len(val_loader)
                    if acc_loss < best_loss:
                        # checkpoint the best model by mean validation loss
                        torch.save(
                            model.state_dict(),
                            os.path.join(self.save_dir, "best_val_model.pth"))
                        best_loss = acc_loss
                    scheduler.step(acc_loss)
                    fig, (a1, a2) = plt.subplots(1,
                                                 2,
                                                 sharex='col',
                                                 sharey='row',
                                                 gridspec_kw={
                                                     'hspace': 0,
                                                     'wspace': 0
                                                 })
                    a1.imshow(raw[0].cpu().permute(1, 2, 0).squeeze())
                    a1.set_title('raw')
                    a2.imshow(pca_project(embeddings[0].detach().cpu()))
                    a2.set_title('embed')
                    # NOTE(review): plt.show() blocks training until the window
                    # is closed -- consider writer.add_figure instead
                    plt.show()
                    # writer.add_figure("examples", fig, iteration // 50)
                iteration += 1
                if iteration > wu_cfg.n_iterations:
                    break
        return
コード例 #13
0
def validate_and_compare_to_clustering(model, env, distance, device, cfg):
    """Validate the trained agent on the validation set and compare it against
    clustering-based baselines (most baselines are currently commented out).

    For every sample the agent's deterministic action (sigmoid of the policy
    mean) is executed in the environment; the resulting segmentation is scored
    against the ground truth with SBD, variation of information, and adjusted
    Rand metrics, alongside the superpixel-projected ground truth as an upper
    baseline. Results are printed; the plotting code at the end is unreachable
    because of the ``exit()`` call above it.
    """

    model.eval()
    # affinity offsets used by the (commented-out) embedding-affinity baselines
    offs = [[1, 0], [0, 1], [2, 0], [0, 2], [4, 0], [0, 4], [16, 0], [0, 16]]
    ex_raws, ex_sps, ex_gts, ex_mc_gts, ex_embeds, ex_clst, ex_clst_sp, ex_mcaff, ex_mc_embed, ex_rl, \
    ex_clst_graph_agglo= [], [], [], [], [], [], [], [], [], [], []
    dset = SpgDset(cfg.val_data_dir, dict_to_attrdict(cfg.patch_manager), dict_to_attrdict(cfg.val_data_keys), max(cfg.s_subgraph))
    dloader = iter(DataLoader(dset))
    acc_reward = 0
    forwarder = Forwarder()
    delta_dist = 0.4

    # one ClusterMetrics accumulator per method (VI / ARE / ARP / ARR)
    # segm_metric = AveragePrecision()
    clst_metric_rl = ClusterMetrics()
    # clst_metric = ClusterMetrics()
    metric_sp_gt = ClusterMetrics()
    # clst_metric_mcaff = ClusterMetrics()
    # clst_metric_mcembed = ClusterMetrics()
    # clst_metric_graphagglo = ClusterMetrics()
    sbd = SBD()

    # map_rl, map_embed, map_sp_gt, map_mcaff, map_mcembed, map_graphagglo = [], [], [], [], [], []
    sbd_rl, sbd_embed, sbd_sp_gt, sbd_mcaff, sbd_mcembed, sbd_graphagglo = [], [], [], [], [], []

    # iterate over every validation sample once
    n_examples = len(dset)
    for it in range(n_examples):
        update_env_data(env, dloader, dset, device, with_gt_edges=False)
        env.reset()
        state = env.get_state()
        distr, _, _, _, _, node_features, embeddings = forwarder.forward(model, state, State,
                                                                              device,
                                                                              grad=False, post_data=False,
                                                                              get_node_feats=True,
                                                                              get_embeddings=True)
        # deterministic action: sigmoid of the policy distribution mean
        action = torch.sigmoid(distr.loc)
        reward = env.execute_action(action, tau=0.0, train=False)
        # NOTE(review): reward[-2] presumably holds the scalar object-level
        # reward -- confirm against the environment's return layout
        acc_reward += reward[-2].item()

        embeds = embeddings[0].cpu()
        # node_features = node_features.cpu().numpy()
        rag = env.rags[0]
        edge_ids = rag.uvIds()
        gt_seg = env.gt_seg[0].cpu().numpy()
        # l2_embeddings = get_angles(embeds[None])[0]
        # l2_node_feats = get_angles(torch.from_numpy(node_features.T[None, ..., None])).squeeze().T.numpy()
        # clst_labels_kmeans = cluster_embeddings(l2_embeddings.permute((1, 2, 0)), len(np.unique(gt_seg)))
        # node_labels = cluster_embeddings(l2_node_feats, len(np.unique(gt_seg)))
        # clst_labels_sp_kmeans = elf.segmentation.features.project_node_labels_to_pixels(rag, node_labels).squeeze()

        # clst_labels_sp_graph_agglo = get_soln_graph_clustering(env.init_sp_seg, torch.from_numpy(edge_ids.astype(np.int)), torch.from_numpy(l2_node_feats), len(np.unique(gt_seg)))[0][0].numpy()
        # mc_labels_aff = env.get_current_soln(edge_weights=env.edge_features[:, 0]).cpu().numpy()[0]
        # ew_embedaffs = 1 - get_edge_features_1d(env.init_sp_seg[0].cpu().numpy(), offs, get_affinities_from_embeddings_2d(embeddings, offs, delta_dist, distance)[0].cpu().numpy())[0][:, 0]
        # mc_labels_embedding_aff = env.get_current_soln(edge_weights=torch.from_numpy(ew_embedaffs).to(device)).cpu().numpy()[0]
        rl_labels = env.current_soln.cpu().numpy()[0]

        # collect example images for the (currently unreachable) plots below
        ex_embeds.append(pca_project(embeds, n_comps=3))
        ex_raws.append(env.raw[0].cpu().permute(1, 2, 0).squeeze())
        # ex_sps.append(cm.prism(env.init_sp_seg[0].cpu() / env.init_sp_seg[0].max().item()))
        ex_sps.append(env.init_sp_seg[0].cpu())
        ex_mc_gts.append(project_overseg_to_seg(env.init_sp_seg[0], torch.from_numpy(gt_seg).to(device)).cpu().numpy())

        ex_gts.append(gt_seg)
        ex_rl.append(rl_labels)
        # ex_clst.append(clst_labels_kmeans)
        # ex_clst_sp.append(clst_labels_sp_kmeans)
        # ex_clst_graph_agglo.append(clst_labels_sp_graph_agglo)
        # ex_mcaff.append(mc_labels_aff)
        # ex_mc_embed.append(mc_labels_embedding_aff)

        # score the agent's segmentation
        # map_rl.append(segm_metric(rl_labels, gt_seg))
        sbd_rl.append(sbd(gt_seg, rl_labels))
        clst_metric_rl(rl_labels, gt_seg)

        # score the superpixel-projected gt (best achievable on this graph)
        # map_sp_gt.append(segm_metric(ex_mc_gts[-1], gt_seg))
        sbd_sp_gt.append(sbd(gt_seg, ex_mc_gts[-1]))
        metric_sp_gt(ex_mc_gts[-1], gt_seg)

        # map_embed.append(segm_metric(clst_labels_kmeans, gt_seg))
        # clst_metric(clst_labels_kmeans, gt_seg)

        # map_mcaff.append(segm_metric(mc_labels_aff, gt_seg))
        # sbd_mcaff.append(sbd(gt_seg, mc_labels_aff))
        # clst_metric_mcaff(mc_labels_aff, gt_seg)
        #
        # map_mcembed.append(segm_metric(mc_labels_embedding_aff, gt_seg))
        # sbd_mcembed.append(sbd(gt_seg, mc_labels_embedding_aff))
        # clst_metric_mcembed(mc_labels_embedding_aff, gt_seg)
        #
        # map_graphagglo.append(segm_metric(clst_labels_sp_graph_agglo, gt_seg))
        # sbd_graphagglo.append(sbd(gt_seg, clst_labels_sp_graph_agglo.astype(np.int)))
        # clst_metric_graphagglo(clst_labels_sp_graph_agglo.astype(np.int), gt_seg)

    # report mean and std of each metric over the validation set
    print("\nSBD: ")
    print(f"sp gt       : {round(np.array(sbd_sp_gt).mean(), 4)}; {round(np.array(sbd_sp_gt).std(), 4)}")
    print(f"ours        : {round(np.array(sbd_rl).mean(), 4)}; {round(np.array(sbd_rl).std(), 4)}")
    # print(f"mc node     : {np.array(sbd_mcembed).mean()}")
    # print(f"mc embed    : {np.array(sbd_mcaff).mean()}")
    # print(f"graph agglo : {np.array(sbd_graphagglo).mean()}")

    # print("\nmAP: ")
    # print(f"sp gt       : {np.array(map_sp_gt).mean()}")
    # print(f"ours        : {np.array(map_rl).mean()}")
    # print(f"mc node     : {np.array(map_mcembed).mean()}")
    # print(f"mc embed    : {np.array(map_mcaff).mean()}")
    # print(f"graph agglo : {np.array(map_graphagglo).mean()}")
    #
    vi_rl_s, vi_rl_m, are_rl, arp_rl, arr_rl = clst_metric_rl.dump()
    vi_spgt_s, vi_spgt_m, are_spgt, arp_spgt, arr_spgt = metric_sp_gt.dump()
    # vi_mcaff_s, vi_mcaff_m, are_mcaff, arp_mcaff, arr_mcaff = clst_metric_mcaff.dump()
    # vi_mcembed_s, vi_mcembed_m, are_mcembed, arp_embed, arr_mcembed = clst_metric_mcembed.dump()
    # vi_graphagglo_s, vi_graphagglo_m, are_graphagglo, arp_graphagglo, arr_graphagglo = clst_metric_graphagglo.dump()
    #
    vi_rl_s_std, vi_rl_m_std, are_rl_std, arp_rl_std, arr_rl_std = clst_metric_rl.dump_std()
    vi_spgt_s_std, vi_spgt_m_std, are_spgt_std, arp_spgt_std, arr_spgt_std = metric_sp_gt.dump_std()

    print("\nVI merge: ")
    print(f"sp gt       : {round(vi_spgt_m, 4)}; {round(vi_spgt_m_std, 4)}")
    print(f"ours        : {round(vi_rl_m, 4)}; {round(vi_rl_m_std, 4)}")
    # print(f"mc affnties : {vi_mcaff_m}")
    # print(f"mc embed    : {vi_mcembed_m}")
    # print(f"graph agglo : {vi_graphagglo_m}")
    #
    print("\nVI split: ")
    print(f"sp gt       : {round(vi_spgt_s, 4)}; {round(vi_spgt_s_std, 4)}")
    print(f"ours        : {round(vi_rl_s, 4)}; {round(vi_rl_s_std, 4)}")
    # print(f"mc affnties : {vi_mcaff_s}")
    # print(f"mc embed    : {vi_mcembed_s}")
    # print(f"graph agglo : {vi_graphagglo_s}")
    #
    print("\nARE: ")
    print(f"sp gt       : {round(are_spgt, 4)}; {round(are_spgt_std, 4)}")
    print(f"ours        : {round(are_rl, 4)}; {round(are_rl_std, 4)}")
    # print(f"mc affnties : {are_mcaff}")
    # print(f"mc embed    : {are_mcembed}")
    # print(f"graph agglo : {are_graphagglo}")
    #
    print("\nARP: ")
    print(f"sp gt       : {round(arp_spgt, 4)}; {round(arp_spgt_std, 4)}")
    print(f"ours        : {round(arp_rl, 4)}; {round(arp_rl_std, 4)}")
    # print(f"mc affnties : {arp_mcaff}")
    # print(f"mc embed    : {arp_embed}")
    # print(f"graph agglo : {arp_graphagglo}")
    #
    print("\nARR: ")
    print(f"sp gt       : {round(arr_spgt, 4)}; {round(arr_spgt_std, 4)}")
    print(f"ours        : {round(arr_rl, 4)}; {round(arr_rl_std, 4)}")
    # print(f"mc affnties : {arr_mcaff}")
    # print(f"mc embed    : {arr_mcembed}")
    # print(f"graph agglo : {arr_graphagglo}")

    # NOTE(review): exit() terminates the process here, so the plotting loop
    # below is unreachable dead code (kept, presumably, for manual debugging);
    # it would also crash as-is since ex_clst_graph_agglo/ex_mc_embed/ex_mcaff
    # are never filled while their baselines are commented out
    exit()
    for i in range(len(ex_gts)):
        fig, axs = plt.subplots(2, 4, figsize=(20, 13), sharex='col', sharey='row', gridspec_kw={'hspace': 0, 'wspace': 0})
        axs[0, 0].imshow(ex_gts[i], cmap=random_label_cmap(), interpolation="none")
        axs[0, 0].set_title('gt')
        axs[0, 0].axis('off')
        axs[0, 1].imshow(ex_embeds[i])
        axs[0, 1].set_title('pc proj')
        axs[0, 1].axis('off')
        # axs[0, 2].imshow(ex_clst[i], cmap=random_label_cmap(), interpolation="none")
        # axs[0, 2].set_title('pix clst')
        # axs[0, 2].axis('off')
        axs[0, 2].imshow(ex_clst_graph_agglo[i], cmap=random_label_cmap(), interpolation="none")
        axs[0, 2].set_title('nagglo')
        axs[0, 2].axis('off')
        axs[0, 3].imshow(ex_mc_embed[i], cmap=random_label_cmap(), interpolation="none")
        axs[0, 3].set_title('mc embed')
        axs[0, 3].axis('off')
        axs[1, 0].imshow(ex_mc_gts[i], cmap=random_label_cmap(), interpolation="none")
        axs[1, 0].set_title('sp gt')
        axs[1, 0].axis('off')
        axs[1, 1].imshow(ex_sps[i], cmap=random_label_cmap(), interpolation="none")
        axs[1, 1].set_title('sp')
        axs[1, 1].axis('off')
        # axs[1, 2].imshow(ex_clst_sp[i], cmap=random_label_cmap(), interpolation="none")
        # axs[1, 2].set_title('sp clst')
        # axs[1, 2].axis('off')
        axs[1, 2].imshow(ex_rl[i], cmap=random_label_cmap(), interpolation="none")
        axs[1, 2].set_title('ours')
        axs[1, 2].axis('off')
        axs[1, 3].imshow(ex_mcaff[i], cmap=random_label_cmap(), interpolation="none")
        axs[1, 3].set_title('mc aff')
        axs[1, 3].axis('off')
        plt.show()
        # wandb.log({"validation/samples": [wandb.Image(fig, caption="sample images")]})
        plt.close('all')
コード例 #14
0
def supervised_policy_pretraining(model,
                                  env,
                                  cfg,
                                  device="cuda:0",
                                  fe_opt=False):
    """Pretrain the actor by regressing its edge actions onto the
    ground-truth edge weights with a BCE loss.

    When ``fe_opt`` is set, the feature extractor is optimized jointly
    with the actor. Keeps the model state that achieved the best mean
    reward and loads it back into ``model`` before returning.
    """
    warmup_cfg = AttrDict()
    add_dict(cfg.policy_warmup, warmup_cfg)
    dataset = SpgDset(cfg.data_dir, warmup_cfg.patch_manager,
                      max(cfg.trn.s_subgraph))
    loader = DataLoader(dataset,
                        batch_size=warmup_cfg.batch_size,
                        shuffle=True,
                        pin_memory=True,
                        num_workers=0)
    # optionally train the embedding network together with the actor
    trainable_params = list(model.actor.parameters())
    if fe_opt:
        trainable_params += list(env.embedding_net.parameters())
    actor_fe_opt = torch.optim.Adam(trainable_params, lr=warmup_cfg.lr)

    dummy_opt = torch.optim.Adam([model.log_alpha], lr=warmup_cfg.lr)
    lr_scheduler = ReduceLROnPlateau(actor_fe_opt, threshold=0.001,
                                     min_lr=1e-6)
    bce = torch.nn.BCELoss()
    running_loss = 0
    step = 0
    best_reward = -np.inf
    # be careful with this, it assumes a one step episode environment
    while step <= warmup_cfg.n_iterations:
        update_env_data(env,
                        loader,
                        device,
                        with_gt_edges=True,
                        fe_grad=fe_opt)
        state = env.get_state()
        # Calculate policy and values; take the transformed policy mean
        # as the deterministic action
        distr, q1, q2, _, _ = agent_forward(env, model, state, policy_opt=True)
        action = distr.transforms[0](distr.loc)
        loss = bce(action.squeeze(1), env.gt_edge_weights)

        # every parameter must participate in backprop, hence the
        # zero-valued dummy terms for alpha and the critic heads
        dummy_loss = (model.alpha * 0).sum(
        )  # not using all parameters in backprop gives error, so add dummy loss
        for sq1, sq2 in zip(q1, q2):
            loss = loss + (sq1.sum() * sq2.sum() * 0)

        actor_fe_opt.zero_grad()
        loss.backward(retain_graph=False)
        actor_fe_opt.step()

        dummy_opt.zero_grad()
        dummy_loss.backward(retain_graph=False)
        dummy_opt.step()
        running_loss += loss.item()

        if step % 10 == 0:
            # evaluate the current action in the environment and step
            # the LR scheduler on the accumulated loss
            _, reward = env.execute_action(action.detach(),
                                           None,
                                           post_images=True,
                                           tau=0.0)
            lr_scheduler.step(running_loss / 10)
            total_reward = sum(r.mean().item() for r in reward) / len(reward)
            wandb.log({"policy_warm_start/acc_loss": running_loss})
            wandb.log({"policy_warm_start/rewards": total_reward})
            running_loss = 0
            if total_reward > best_reward:
                # snapshot the best-performing weights
                best_model = copy.deepcopy(model.state_dict())
                best_reward = total_reward
        wandb.log({"policy_warm_start/loss": loss.item()})
        wandb.log({"policy_warm_start/lr": actor_fe_opt.param_groups[0]['lr']})
        step += 1
    model.load_state_dict(best_model)
    return
コード例 #15
0
    def __init__(self, cfg, global_count):
        """Set up device, replay memory, the frozen pretrained feature
        extractor, the agent model, its optimizer and LR scheduling, and
        the train/validation datasets.
        """
        super(AgentSaTrainerObjLvlReward, self).__init__()
        # single-GPU setup only
        assert torch.cuda.device_count() == 1
        self.device = torch.device("cuda:0")
        torch.cuda.set_device(self.device)
        torch.set_default_tensor_type(torch.FloatTensor)

        self.cfg = cfg
        self.global_count = global_count
        self.memory = TransitionData_ts(capacity=self.cfg.mem_size)
        self.best_val_reward = -np.inf
        self.distance = CosineDistance() if self.cfg.distance == 'cosine' \
            else L2Distance()

        # pretrained feature extractor; frozen below
        self.fe_ext = FeExtractor(dict_to_attrdict(self.cfg.backbone),
                                  self.distance, cfg.fe_delta_dist,
                                  self.device)
        self.fe_ext.embed_model.load_state_dict(
            torch.load(self.cfg.fe_model_name))
        self.fe_ext.cuda(self.device)

        self.model = Agent(self.cfg, State, self.distance, self.device)
        wandb.watch(self.model)
        self.model.cuda(self.device)
        self.model_mtx = Lock()

        # only the actor is optimized in this trainer
        self.optimizer = torch.optim.Adam(self.model.actor.parameters(),
                                          lr=self.cfg.actor_lr)

        sched_cfg = dict_to_attrdict(self.cfg.lr_sched)
        n_avg = sched_cfg.mov_avg_bandwidth
        avg_offset = sched_cfg.mov_avg_offset
        avg_weights = np.linspace(sched_cfg.weight_range[0],
                                  sched_cfg.weight_range[1], n_avg)
        avg_weights = avg_weights / avg_weights.sum()  # normalize to sum 1
        torch_sched_cfg = sched_cfg.torch_sched
        self.shed = ReduceLROnPlateau(self.optimizer,
                                      patience=torch_sched_cfg.patience,
                                      threshold=torch_sched_cfg.threshold,
                                      min_lr=torch_sched_cfg.min_lr,
                                      factor=torch_sched_cfg.factor)

        self.mov_sum_loss = RunningAverage(avg_weights, band_width=n_avg,
                                           offset=avg_offset)
        self.scaler = torch.cuda.amp.GradScaler()
        self.forwarder = Forwarder()

        # optionally resume from a saved agent checkpoint
        if self.cfg.agent_model_name != "":
            self.model.load_state_dict(torch.load(self.cfg.agent_model_name))

        # finished with prepping; freeze the feature extractor
        for param in self.fe_ext.parameters():
            param.requires_grad = False

        self.train_dset = SpgDset(self.cfg.data_dir,
                                  dict_to_attrdict(self.cfg.patch_manager),
                                  dict_to_attrdict(self.cfg.data_keys))
        self.val_dset = SpgDset(self.cfg.val_data_dir,
                                dict_to_attrdict(self.cfg.patch_manager),
                                dict_to_attrdict(self.cfg.data_keys))
コード例 #16
0
        #
        # graph_file.create_dataset("edges", data=edges, chunks=True)
        # graph_file.create_dataset("edge_feat", data=edge_feat, chunks=True)
        # graph_file.create_dataset("diff_to_gt", data=diff_to_gt)
        # graph_file.create_dataset("gt_edge_weights", data=gt_edge_weights, chunks=True)
        # graph_file.create_dataset("node_labeling", data=node_labeling, chunks=True)
        # graph_file.create_dataset("affinities", data=affinities, chunks=True)
        #
        # graph_file.close()
        # pix_file.close()


if __name__ == "__main__":
    # smoke test: load one sample, run multicut on the gt edge weights and
    # compare gt / superpixels / multicut result visually
    # (renamed `dir` -> `data_dir`: `dir` shadows the builtin)
    data_dir = "/g/kreshuk/hilt/projects/fewShotLearning/mutexWtsd/data/storage/sqrs_crclspn/pix_and_graphs"
    # store_all(data_dir)

    dset = SpgDset(data_dir)
    raw, gt, sp_seg, idx = dset[20]
    edges, edge_feat, diff_to_gt, gt_edge_weights = dset.get_graphs(idx)
    gt_seg = get_current_soln(gt_edge_weights[0].numpy().astype(np.float64),
                              sp_seg[0].numpy().astype(np.uint64),
                              edges[0].numpy().transpose().astype(np.int64))
    fig, (ax1, ax2, ax3) = plt.subplots(1, 3)
    ax1.imshow(gt[0])
    ax1.set_title('gt')
    ax2.imshow(cm.prism(sp_seg[0] / sp_seg[0].max()))
    ax2.set_title('sp')
    ax3.imshow(gt_seg)
    ax3.set_title('mc')
    plt.show()
コード例 #17
0
    def train(self):
        """Warm-start training of the feature-extractor UNet with a
        RAG-based contrastive loss on superpixel graphs.

        Validates every 100 iterations, checkpoints the best model by
        accumulated validation loss, steps the LR scheduler on that loss
        and logs example figures to tensorboard.
        """
        writer = SummaryWriter(logdir=self.log_dir)
        device = "cuda:0"
        wu_cfg = self.cfg.fe.trainer
        model = UNet2D(**self.cfg.fe.backbone)
        model.cuda(device)
        train_set = SpgDset(self.cfg.gen.data_dir_raw_train, reorder_sp=False)
        val_set = SpgDset(self.cfg.gen.data_dir_raw_val, reorder_sp=False)
        # pm = StridedPatches2D(wu_cfg.patch_stride, wu_cfg.patch_shape, train_set.image_shape)
        pm = NoPatches2D()
        train_set.length = len(train_set.graph_file_names) * np.prod(pm.n_patch_per_dim)
        train_set.n_patch_per_dim = pm.n_patch_per_dim
        val_set.length = len(val_set.graph_file_names)
        # dset = LeptinDset(self.cfg.gen.data_dir_raw, self.cfg.gen.data_dir_affs, wu_cfg.patch_manager, wu_cfg.patch_stride, wu_cfg.patch_shape, wu_cfg.reorder_sp)
        train_loader = DataLoader(train_set, batch_size=wu_cfg.batch_size, shuffle=True, pin_memory=True,
                             num_workers=0)
        val_loader = DataLoader(val_set, batch_size=wu_cfg.batch_size, shuffle=True, pin_memory=True,
                             num_workers=0)
        # only used by the commented-out edge-image input variant below
        gauss_kernel = GaussianSmoothing(1, 5, 3, device=device)
        optimizer = torch.optim.Adam(model.parameters(), lr=self.cfg.fe.lr)
        scheduler = ReduceLROnPlateau(optimizer,
                                      patience=20,
                                      threshold=1e-4,
                                      min_lr=1e-5,
                                      factor=0.1)
        criterion = RagContrastiveWeights(delta_var=0.1, delta_dist=0.4)
        acc_loss = 0
        valit = 0
        iteration = 0
        best_loss = np.inf

        while iteration <= wu_cfg.n_iterations:
            for it, (raw, gt, sp_seg, affinities, offs, indices) in enumerate(train_loader):
                raw, gt, sp_seg, affinities = raw.to(device), gt.to(device), sp_seg.to(device), affinities.to(device)

                # edge_img = F.pad(get_contour_from_2d_binary(sp_seg), (2, 2, 2, 2), mode='constant')
                # edge_img = gauss_kernel(edge_img.float())
                # input = torch.cat([raw, edge_img], dim=1)

                offs = offs.numpy().tolist()
                loss_embeds = model(raw[:, :, None]).squeeze(2)

                # per-sample RAG edges and edge features from the affinities
                # (loop var renamed from `os`, which shadowed the os module)
                edge_feat, edges = tuple(zip(*[get_edge_features_1d(seg.squeeze().cpu().numpy(), off_, affs.squeeze().cpu().numpy()) for seg, off_, affs in zip(sp_seg, offs, affinities)]))
                # np.long was removed in numpy>=1.24; int64 is the intended dtype
                edges = [torch.from_numpy(e.astype(np.int64)).to(device).T for e in edges]
                edge_weights = [torch.from_numpy(ew.astype(np.float32)).to(device)[:, 0][None] for ew in edge_feat]

                # put embeddings on unit sphere so we can use cosine distance
                loss_embeds = loss_embeds / (torch.norm(loss_embeds, dim=1, keepdim=True) + 1e-9)

                loss = criterion(loss_embeds, sp_seg.long(), edges, edge_weights,
                                 chunks=int(sp_seg.max().item()//self.cfg.gen.train_chunk_size),
                                 sigm_factor=self.cfg.gen.sigm_factor, pull_factor=self.cfg.gen.pull_factor)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                print(loss.item())
                writer.add_scalar("fe_train/lr", optimizer.param_groups[0]['lr'], iteration)
                writer.add_scalar("fe_train/loss", loss.item(), iteration)
                if iteration % 100 == 0:
                    # periodic validation pass
                    with torch.no_grad():
                        # val_it: do not shadow the outer training loop var `it`
                        for val_it, (raw, gt, sp_seg, affinities, offs, indices) in enumerate(val_loader):
                            raw, gt, sp_seg, affinities = raw.to(device), gt.to(device), sp_seg.to(device), affinities.to(device)

                            offs = offs.numpy().tolist()
                            embeddings = model(raw[:, :, None]).squeeze(2)

                            # relabel to consecutive ints starting at 0
                            edge_feat, edges = tuple(zip(
                                *[get_edge_features_1d(seg.squeeze().cpu().numpy(), off_, affs.squeeze().cpu().numpy())
                                  for seg, off_, affs in zip(sp_seg, offs, affinities)]))
                            edges = [torch.from_numpy(e.astype(np.int64)).to(device).T for e in edges]
                            edge_weights = [torch.from_numpy(ew.astype(np.float32)).to(device)[:, 0][None] for ew in edge_feat]

                            # put embeddings on unit sphere so we can use cosine distance
                            embeddings = embeddings / (torch.norm(embeddings, dim=1, keepdim=True) + 1e-9)

                            ls = criterion(embeddings, sp_seg.long(), edges, edge_weights,
                                           chunks=int(sp_seg.max().item()//self.cfg.gen.train_chunk_size),
                                           sigm_factor=self.cfg.gen.sigm_factor, pull_factor=self.cfg.gen.pull_factor)
                            # ls = 0
                            acc_loss += ls
                            writer.add_scalar("fe_val/loss", ls, valit)
                            valit += 1
                    acc_loss = acc_loss / len(val_loader)
                    if acc_loss < best_loss:
                        # checkpoint the best model by validation loss
                        print(self.save_dir)
                        torch.save(model.state_dict(), os.path.join(self.save_dir, "best_val_model.pth"))
                        best_loss = acc_loss
                    scheduler.step(acc_loss)
                    acc_loss = 0
                    fig, ((a1, a2), (a3, a4)) = plt.subplots(2, 2, sharex='col', sharey='row', gridspec_kw={'hspace': 0, 'wspace': 0})
                    a1.imshow(raw[0].cpu().permute(1, 2, 0)[..., 0].squeeze())
                    a1.set_title('raw')
                    a2.imshow(cm.prism(sp_seg[0, 0].cpu().squeeze() / sp_seg[0, 0].cpu().squeeze().max()))
                    a2.set_title('sp')
                    a3.imshow(pca_project(get_angles(embeddings)[0].detach().cpu()))
                    a3.set_title('angle_embed')
                    a4.imshow(pca_project(embeddings[0].detach().cpu()))
                    a4.set_title('embed')
                    # plt.show()
                    writer.add_figure("examples", fig, iteration//100)
                iteration += 1
                print(iteration)
                if iteration > wu_cfg.n_iterations:
                    print(self.save_dir)
                    torch.save(model.state_dict(), os.path.join(self.save_dir, "last_model.pth"))
                    break
        return
コード例 #18
0
    def train(self):
        """Warm-start training of the feature-extraction UNet with a
        contrastive affinity loss on augmented image pairs.

        Two differently augmented views of each raw patch are embedded by
        the same UNet; the loss pulls corresponding embeddings together and
        pushes them apart according to an affinity validity mask. Scalars
        (loss, lr) are logged to TensorBoard every step and a qualitative
        figure is shown every 50 steps. Runs for
        ``cfg.fe.trainer.n_iterations`` optimizer steps, then returns.

        Reads: ``self.cfg`` (model/data/optimizer config), ``self.log_dir``.
        Side effects: writes TensorBoard events, displays matplotlib
        figures, moves model/data to ``cuda:0``.
        """
        writer = SummaryWriter(logdir=self.log_dir)
        writer.add_text("conf", self.cfg.pretty())
        device = "cuda:0"
        wu_cfg = self.cfg.fe.trainer
        model = UNet2D(self.cfg.fe.n_raw_channels,
                       self.cfg.fe.n_embedding_features,
                       final_sigmoid=False,
                       num_levels=5)
        model.cuda(device)
        dset = SpgDset(self.cfg.gen.data_dir, wu_cfg.patch_manager,
                       wu_cfg.patch_stride, wu_cfg.patch_shape,
                       wu_cfg.reorder_sp)
        dloader = DataLoader(dset,
                             batch_size=wu_cfg.batch_size,
                             shuffle=True,
                             pin_memory=True,
                             num_workers=0)
        optimizer = torch.optim.Adam(model.parameters(), lr=self.cfg.fe.lr)
        tfs = RndAugmentationTfs(wu_cfg.patch_shape)
        criterion = AugmentedAffinityContrastive(delta_var=0.1, delta_dist=0.3)
        acc_loss = 0
        iteration = 0

        while iteration <= wu_cfg.n_iterations:
            for it, (raw, gt, sp_seg, indices) in enumerate(dloader):
                raw, gt, sp_seg = raw.to(device), gt.to(device), sp_seg.to(
                    device)
                # this is still not the correct mask calculation as the affinity offsets go in no tf offset direction
                mask = torch.from_numpy(
                    get_valid_edges([len(criterion.offs)] +
                                    list(raw.shape[-2:]),
                                    criterion.offs)).to(device)[None]
                # _, _, _, _, affs = dset.get_graphs(indices, sp_seg, device)
                # sample one spatial and two independent intensity augmentations
                spat_tf, int_tf = tfs.sample(1, 1)
                _, _int_tf = tfs.sample(1, 1)
                # view 0: intensity-augmented + noisy raw, embedded directly
                inp = add_sp_gauss_noise(_int_tf(raw), 0.2, 0.1, 0.3)
                embeddings = model(inp.unsqueeze(2)).squeeze(2)

                # spatially transform mask, raw and embeddings together so the
                # channel layout stays [mask | raw | embeddings] along dim -3
                paired = spat_tf(torch.cat((mask, raw, embeddings), -3))
                embeddings_0, mask = paired[
                    ..., inp.shape[1] + len(criterion.offs):, :, :], paired[
                        ..., :len(criterion.offs), :, :].detach()
                # do intensity transform for spatial transformed input
                aug_inp = int_tf(paired[...,
                                        len(criterion.offs):inp.shape[1] +
                                        len(criterion.offs), :, :]).detach()
                # view 1: embed the spatially+intensity augmented input
                embeddings_1 = model(
                    add_sp_gauss_noise(aug_inp, 0.2, 0.1,
                                       0.3).unsqueeze(2)).squeeze(2)

                # put embeddings on unit sphere so we can use cosine distance
                embeddings_0 = embeddings_0 / (
                    torch.norm(embeddings_0, dim=1, keepdim=True) + 1e-6)
                embeddings_1 = embeddings_1 / (
                    torch.norm(embeddings_1, dim=1, keepdim=True) + 1e-6)

                loss = criterion(embeddings_0, embeddings_1, aug_inp, mask)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                acc_loss += loss.item()

                print(loss.item())
                writer.add_scalar("fe_warm_start/loss", loss.item(), iteration)
                writer.add_scalar("fe_warm_start/lr",
                                  optimizer.param_groups[0]['lr'], iteration)
                if (iteration) % 50 == 0:
                    acc_loss = 0
                    fig, ((a1, a2), (a3, a4)) = plt.subplots(2, 2)
                    a1.imshow(aug_inp[0].cpu().permute(1, 2, 0).squeeze())
                    a1.set_title('tf_raw')
                    a3.imshow(
                        pca_project(
                            get_angles(embeddings_0).squeeze(
                                0).detach().cpu()))
                    a3.set_title('tf_embed')
                    a4.imshow(
                        pca_project(
                            get_angles(embeddings_1).squeeze(
                                0).detach().cpu()))
                    a4.set_title('embed')
                    a2.imshow(raw[0].cpu().permute(1, 2, 0).squeeze())
                    a2.set_title('raw')
                    plt.show()
                    # close the figure explicitly; pyplot keeps figures alive
                    # otherwise, leaking memory over a long training run
                    plt.close(fig)
                    # writer.add_figure("examples", fig, iteration//100)
                iteration += 1
                if iteration > wu_cfg.n_iterations:
                    break
        # flush buffered events so the log is complete on disk
        writer.close()
        return