Example #1
    def __init__(self, cfg, global_count, global_writer_loss_count,
                 global_writer_quality_count, action_stats_count,
                 global_writer_count, save_dir):
        super(AgentSacTrainer_test_sg_global, self).__init__()

        self.cfg = cfg
        self.global_count = global_count
        self.global_writer_loss_count = global_writer_loss_count
        self.global_writer_quality_count = global_writer_quality_count
        self.action_stats_count = action_stats_count
        self.global_writer_count = global_writer_count
        self.contr_trpl_loss = ContrastiveTripletLoss(delta_var=0.5)
        self.contr_loss = ContrastiveLoss(delta_var=0.5, delta_dist=1.5)

        self.memory = TransitionData_ts(capacity=self.cfg.trainer.t_max)
        # self.eps = self.args.init_epsilon
        self.save_dir = save_dir
Example #2
    def train_step(self, rank, start_time, return_dict, writer):
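        # pin this worker to the GPU matching its distributed rank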
        device = torch.device("cuda:" + str(rank))
        print('Running on device: ', device)
        torch.cuda.set_device(device)
        torch.set_default_tensor_type(torch.FloatTensor)

        self.setup(rank, self.args.num_processes)
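        # Monte-Carlo DQL stores whole episodes; otherwise the buffer holds
        # standard (state, action, reward, next_state, done) transitions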
        if self.cfg.MC_DQL:
            transition = namedtuple('Transition', ('episode',))
        else:
            transition = namedtuple(
                'Transition',
                ('state', 'action', 'reward', 'next_state', 'done'))
        memory = TransitionData_ts(capacity=self.args.t_max,
                                   storage_object=transition)

        env = SpGcnEnv(self.args,
                       device,
                       writer=writer,
                       writer_counter=self.global_writer_quality_count,
                       win_event_counter=self.global_win_event_count)
        # Create shared network

        # model = GcnEdgeAC_1(self.cfg, self.args.n_raw_channels, self.args.n_embedding_features, 1, device, writer=writer)
        model = GcnEdgeAC(self.cfg, self.args, device, writer=writer)
        # model = GcnEdgeAC(self.cfg, self.args.n_raw_channels, self.args.n_embedding_features, 1, device, writer=writer)

        model.cuda(device)
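        # DistributedDataParallel all-reduces gradients across ranks;
        # find_unused_parameters tolerates branches inactive in a forward pass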
        shared_model = DDP(model,
                           device_ids=[model.device],
                           find_unused_parameters=True)

        # dloader = DataLoader(MultiDiscSpGraphDsetBalanced(no_suppix=False, create=False), batch_size=1, shuffle=True, pin_memory=True,
        #                      num_workers=0)
        dloader = DataLoader(SpgDset(),
                             batch_size=self.cfg.batch_size,
                             shuffle=True,
                             pin_memory=True,
                             num_workers=0)
        # Create optimizer for shared network parameters with shared statistics
        # optimizer = CstmAdam(shared_model.parameters(), lr=self.args.lr, betas=self.args.Adam_betas,
        #                      weight_decay=self.args.Adam_weight_decay)
        ######################
        self.action_range = 1
        self.device = torch.device(device)
        self.discount = 0.5
        self.critic_tau = self.cfg.critic_tau
        self.actor_update_frequency = self.cfg.actor_update_frequency
        self.critic_target_update_frequency = self.cfg.critic_target_update_frequency
        self.batch_size = self.cfg.batch_size

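        # learn the SAC temperature in log space so alpha = exp(log_alpha) stays positive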
        self.log_alpha = torch.tensor(np.log(self.cfg.init_temperature)).to(
            self.device)
        self.log_alpha.requires_grad = True
        # set target entropy to -|A|
        ######################
        # optimizers
        OptimizerContainer = namedtuple('OptimizerContainer',
                                        ('actor', 'critic', 'temperature'))
        actor_optimizer = torch.optim.Adam(
            shared_model.module.actor.parameters(),
            lr=self.cfg.actor_lr,
            betas=self.cfg.actor_betas)

        critic_optimizer = torch.optim.Adam(
            shared_model.module.critic.parameters(),
            lr=self.cfg.critic_lr,
            betas=self.cfg.critic_betas)

        temp_optimizer = torch.optim.Adam([self.log_alpha],
                                          lr=self.cfg.alpha_lr,
                                          betas=self.cfg.alpha_betas)

        optimizers = OptimizerContainer(actor_optimizer, critic_optimizer,
                                        temp_optimizer)

        if self.args.fe_extr_warmup and rank == 0 and not self.args.test_score_only:
            fe_extr = shared_model.module.fe_ext
            fe_extr.cuda(device)
            self.fe_extr_warm_start_1(fe_extr, writer=writer)
            # self.fe_extr_warm_start(fe_extr, writer=writer)
            if self.args.model_name == "" and not self.args.no_save:
                torch.save(fe_extr.state_dict(),
                           os.path.join(self.save_dir, 'agent_model_fe_extr'))
            elif not self.args.no_save:
                torch.save(fe_extr.state_dict(),
                           os.path.join(self.save_dir, self.args.model_name))

        dist.barrier()
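        # with the warm-started feature extractor in place, freeze it so only
        # actor and critic receive gradients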
        for param in model.fe_ext.parameters():
            param.requires_grad = False

        if self.args.model_name != "":
            shared_model.load_state_dict(
                torch.load(os.path.join(self.save_dir, self.args.model_name)))
        elif self.args.model_fe_name != "":
            shared_model.module.fe_ext.load_state_dict(
                torch.load(os.path.join(self.save_dir,
                                        self.args.model_fe_name)))
        elif self.args.fe_extr_warmup:
            print('loaded fe extractor')
            shared_model.module.fe_ext.load_state_dict(
                torch.load(os.path.join(self.save_dir, 'agent_model_fe_extr')))

        if not self.args.test_score_only:
            quality = self.args.stop_qual_scaling + self.args.stop_qual_offset
            best_quality = np.inf
            last_quals = []
            while self.global_count.value() <= self.args.T_max:
                self.update_env_data(env, dloader, device)
                # waff_dis = torch.softmax(env.edge_features[:, 0].squeeze() + 1e-30, dim=0)
                # waff_dis = torch.softmax(env.gt_edge_weights + 0.5, dim=0)
                waff_dis = torch.softmax(torch.ones_like(
                    env.b_gt_edge_weights),
                                         dim=0)
                loss_weight = torch.softmax(env.b_gt_edge_weights + 1, dim=0)
                env.reset()
                # self.target_entropy = - float(env.gt_edge_weights.shape[0])
                self.target_entropy = -8.0

                env.stop_quality = self.stop_qual_rule.apply(
                    self.global_count.value(), quality)
                if self.cfg.temperature_regulation == 'follow_quality':
                    self.alpha = self.eps_rule.apply(self.global_count.value(),
                                                     quality)
                    print(self.alpha.item())

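                # hot-reload runtime flags (model saving, noise, learning rates)
                # from runtime_cfg.yaml so training can be steered while it runs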
                with open(os.path.join(self.save_dir,
                                       'runtime_cfg.yaml')) as info:
                    args_dict = yaml.full_load(info)
                    if args_dict is not None:
                        if 'safe_model' in args_dict:
                            self.args.safe_model = args_dict['safe_model']
                            args_dict['safe_model'] = False
                        if 'add_noise' in args_dict:
                            self.args.add_noise = args_dict['add_noise']
                        if 'critic_lr' in args_dict and args_dict[
                                'critic_lr'] != self.cfg.critic_lr:
                            self.cfg.critic_lr = args_dict['critic_lr']
                            adjust_learning_rate(critic_optimizer,
                                                 self.cfg.critic_lr)
                        if 'actor_lr' in args_dict and args_dict[
                                'actor_lr'] != self.cfg.actor_lr:
                            self.cfg.actor_lr = args_dict['actor_lr']
                            adjust_learning_rate(actor_optimizer,
                                                 self.cfg.actor_lr)
                        if 'alpha_lr' in args_dict and args_dict[
                                'alpha_lr'] != self.cfg.alpha_lr:
                            self.cfg.alpha_lr = args_dict['alpha_lr']
                            adjust_learning_rate(temp_optimizer,
                                                 self.cfg.alpha_lr)
                with open(os.path.join(self.save_dir, 'runtime_cfg.yaml'),
                          "w") as info:
                    yaml.dump(args_dict, info)

                if self.args.safe_model:
                    best_quality = quality
                    if rank == 0:
                        if self.args.model_name_dest != "":
                            torch.save(
                                shared_model.state_dict(),
                                os.path.join(self.save_dir,
                                             self.args.model_name_dest))
                        else:
                            torch.save(
                                shared_model.state_dict(),
                                os.path.join(self.save_dir, 'agent_model'))

                state = env.get_state()
                while not env.done:
                    # Calculate policy and values
                    post_input = (self.global_count.value() + 1) % 15 == 0 and env.counter == 0
                    round_n = env.counter
                    # sample action for data collection
                    distr = None
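                    # exploration warm-up: act uniformly at random until num_seed_steps is reached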
                    if self.global_count.value() < self.cfg.num_seed_steps:
                        action = torch.rand_like(env.b_current_edge_weights)
                    else:
                        distr, _, _, action = self.agent_forward(
                            env,
                            shared_model,
                            state=state,
                            grad=False,
                            post_input=post_input)

                    logg_dict = {'temperature': self.alpha.item()}
                    if distr is not None:
                        logg_dict['mean_loc'] = distr.loc.mean().item()
                        logg_dict['mean_scale'] = distr.scale.mean().item()

                    if self.global_count.value() >= self.cfg.num_seed_steps and memory.is_full():
                        self._step(memory,
                                   optimizers,
                                   env,
                                   shared_model,
                                   self.global_count.value(),
                                   writer=writer)
                        self.global_writer_loss_count.increment()

                    next_state, reward, quality = env.execute_action(
                        action, logg_dict)

                    last_quals.append(quality)
                    if len(last_quals) > 10:
                        last_quals.pop(0)

                    if self.args.add_noise:
                        noise = torch.randn_like(reward) * self.alpha.item()
                        reward = reward + noise

                    memory.push(self.state_to_cpu(state), action, reward,
                                self.state_to_cpu(next_state), env.done)

                    # Train the network
                    # self._step(memory, shared_model, env, optimizer, loss_weight, off_policy=True, writer=writer)

                    # reward = self.args.reward_clip and min(max(reward, -1), 1) or reward  # Optionally clamp rewards
                    # done = done or episode_length >= self.args.max_episode_length  # Stop episodes at a max length
                    state = next_state

                self.global_count.increment()

        dist.barrier()
        if rank == 0:
            if not self.args.cross_validate_hp and not self.args.test_score_only and not self.args.no_save:
                # pass
                if self.args.model_name_dest != "":
                    torch.save(
                        shared_model.state_dict(),
                        os.path.join(self.save_dir, self.args.model_name_dest))
                    print('saved')
                else:
                    torch.save(shared_model.state_dict(),
                               os.path.join(self.save_dir, 'agent_model'))

        self.cleanup()
        return sum(last_quals) / len(last_quals)  # mean of the most recent (up to 10) qualities
Example #3
class AgentSacTrainer_test_sg_global(object):
    def __init__(self, cfg, global_count, global_writer_loss_count,
                 global_writer_quality_count, action_stats_count,
                 global_writer_count, save_dir):
        super(AgentSacTrainer_test_sg_global, self).__init__()

        self.cfg = cfg
        self.global_count = global_count
        self.global_writer_loss_count = global_writer_loss_count
        self.global_writer_quality_count = global_writer_quality_count
        self.action_stats_count = action_stats_count
        self.global_writer_count = global_writer_count
        self.contr_trpl_loss = ContrastiveTripletLoss(delta_var=0.5)
        self.contr_loss = ContrastiveLoss(delta_var=0.5, delta_dist=1.5)

        self.memory = TransitionData_ts(capacity=self.cfg.trainer.t_max)
        # self.eps = self.args.init_epsilon
        self.save_dir = save_dir

    def setup(self, rank, world_size):
        # BLAS setup
        os.environ['OMP_NUM_THREADS'] = '10'
        os.environ['MKL_NUM_THREADS'] = '10'

        # assert torch.cuda.device_count() == 1
        torch.set_default_tensor_type('torch.FloatTensor')

        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = self.cfg.gen.master_port

        # initialize the process group
        dist.init_process_group("gloo", rank=rank, world_size=world_size)

    def pretrain_embeddings_gt(self, model, device, writer=None):
        dset = SpgDset(root_dir=self.cfg.gen.data_dir)
        dloader = DataLoader(dset,
                             batch_size=self.cfg.fe.warmup.batch_size,
                             shuffle=True,
                             pin_memory=True,
                             num_workers=0)
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
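        # 3x3 max-pool with stride 1 and padding keeps the spatial size; where the
        # pooled label map disagrees with the raw one marks superpixel boundaries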
        max_p = torch.nn.MaxPool2d(3, padding=1, stride=1)
        scheduler = ReduceLROnPlateau(optimizer)
        acc_loss = 0
        iteration = 0

        while iteration <= self.cfg.fe.warmup.n_iterations:
            for it, (raw, gt, sp_seg, indices) in enumerate(dloader):
                raw, gt, sp_seg = raw.to(device), gt.to(device), sp_seg.to(
                    device)
                sp_seg_edge = torch.cat([(-max_p(-sp_seg) != sp_seg).float(),
                                         (max_p(sp_seg) != sp_seg).float()], 1)
                embeddings = model(torch.cat([raw, sp_seg_edge], 1))
                loss = self.contr_loss(embeddings, gt.long().squeeze(1))

                optimizer.zero_grad()
                loss.backward(retain_graph=False)
                optimizer.step()
                acc_loss += loss.item()

                if writer is not None:
                    writer.add_scalar("fe_warm_start/loss", loss.item(),
                                      iteration)
                    writer.add_scalar("fe_warm_start/lr",
                                      optimizer.param_groups[0]['lr'],
                                      iteration)
                    if it % 50 == 0:
                        plt.clf()
                        fig = plt.figure(frameon=False)
                        plt.imshow(sp_seg[0].detach().squeeze().cpu().numpy())
                        plt.colorbar()
                        writer.add_figure("image/sp_seg", fig, iteration // 50)
                if it % 10 == 0:
                    scheduler.step(acc_loss / 10)
                    acc_loss = 0
                iteration += 1
                if iteration > self.cfg.fe.warmup.n_iterations:
                    break

                del loss
                del embeddings
        return

    def pretrain_embeddings_sp(self, model, device, writer=None):
        dset = SpgDset(self.cfg.gen.data_dir, self.cfg.fe.patch_manager,
                       self.cfg.fe.patch_stride, self.cfg.fe.patch_shape)
        dloader = DataLoader(dset,
                             batch_size=self.cfg.fe.warmup.batch_size,
                             shuffle=True,
                             pin_memory=True,
                             num_workers=0)
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
        max_p = torch.nn.MaxPool2d(3, padding=1, stride=1)
        scheduler = ReduceLROnPlateau(optimizer)
        acc_loss = 0

        for i in range(self.cfg.fe.warmup.n_iterations):
            print(f"fe ext wu iter: {i}")
            for it, (raw, gt, sp_seg, indices) in enumerate(dloader):
                raw, gt, sp_seg = raw.to(device), gt.to(device), sp_seg.to(
                    device)
                sp_seg_edge = torch.cat([(-max_p(-sp_seg) != sp_seg).float(),
                                         (max_p(sp_seg) != sp_seg).float()], 1)
                embeddings = model(torch.cat([raw, sp_seg_edge], 1),
                                   it % 500 == 0)

                loss = self.contr_loss(embeddings, sp_seg.long().squeeze(1))

                optimizer.zero_grad()
                loss.backward(retain_graph=False)
                optimizer.step()
                acc_loss += loss.item()

                if writer is not None:
                    writer.add_scalar("fe_warm_start/loss", loss.item(),
                                      (len(dloader) * i) + it)
                    writer.add_scalar("fe_warm_start/lr",
                                      optimizer.param_groups[0]['lr'],
                                      (len(dloader) * i) + it)
                    if it % 500 == 0:
                        plt.clf()
                        fig = plt.figure(frameon=False)
                        plt.imshow(sp_seg[0].detach().squeeze().cpu().numpy())
                        plt.colorbar()
                        writer.add_figure("image/sp_seg", fig,
                                          ((len(dloader) * i) + it) // 500)

                if it % 10 == 0:
                    scheduler.step(acc_loss / 10)
                    acc_loss = 0

    def cleanup(self):
        dist.destroy_process_group()

    def update_env_data(self, env, dloader, device):
        raw, gt, sp_seg, indices = next(iter(dloader))
        raw, gt, sp_seg = raw.to(device), gt.to(device), sp_seg.to(device)
        edges, edge_feat, diff_to_gt, gt_edge_weights = dloader.dataset.get_graphs(
            indices, device)
        env.update_data(edges, edge_feat, diff_to_gt, gt_edge_weights, sp_seg,
                        raw, gt)

    def agent_forward(self,
                      env,
                      model,
                      state,
                      actions=None,
                      grad=True,
                      post_input=False,
                      post_model=False,
                      policy_opt=False,
                      embeddings_opt=False):
        with torch.set_grad_enabled(grad):
            raw, sp_seg, edge_ids, sp_indices, edge_angles, sub_graphs, sep_subgraphs, counter, b_gt_edge_weights, edge_offsets = self.state_to_cuda(
                state, env.device)
            if actions is not None:
                actions = actions.to(model.module.device)
            counter /= self.cfg.trainer.max_episode_length
            # model.module.writer.add_graph(model.module, (raw, b_gt_edge_weights, sp_indices, edge_ids, edge_angles, counter, sub_graphs, sep_subgraphs, actions, post_input, policy_opt), verbose=False)

            ret = model(raw,
                        sp_seg,
                        edge_index=edge_ids,
                        angles=edge_angles,
                        round_n=counter,
                        actions=actions,
                        sp_indices=sp_indices,
                        gt_edges=b_gt_edge_weights,
                        sub_graphs=sub_graphs,
                        sep_subgraphs=sep_subgraphs,
                        post_input=post_input,
                        policy_opt=policy_opt and grad,
                        embeddings_opt=embeddings_opt)

            if post_model and grad:
                for name, value in model.module.actor.named_parameters():
                    model.writer.add_histogram(name,
                                               value.data.cpu().numpy(),
                                               self.global_count.value())
                    model.writer.add_histogram(name + '/grad',
                                               value.grad.data.cpu().numpy(),
                                               self.global_count.value())
                for name, value in model.module.critic_tgt.named_parameters():
                    model.writer.add_histogram(name,
                                               value.data.cpu().numpy(),
                                               self.global_count.value())
                    model.writer.add_histogram(name + '/grad',
                                               value.grad.data.cpu().numpy(),
                                               self.global_count.value())

        return ret

    def _get_connected_paths(self, edges, weights, size, get_repulsive=False):
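        # dense adjacency matrix with np.inf marking absent edges; Dijkstra then
        # yields all-pairs shortest path distances over the weighted edges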
        graph = np.ones((size, size)) * np.inf
        graph[edges[0], edges[1]] = weights
        graph[edges[1], edges[0]] = weights
        dists = dijkstra(graph, directed=False)
        if get_repulsive:
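            # pairs at infinite distance lie in different connected components
            # and become repulsive edges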
            tril = np.tril(np.ones_like(dists), 0).astype(bool)
            dists[tril] = 0
            repulsive_edges = np.nonzero(dists == np.inf)
            if repulsive_edges[0].size > 0:
                return torch.stack([
                    torch.from_numpy(repulsive_edges[0]),
                    torch.from_numpy(repulsive_edges[1])
                ], 0)
            return None
        tril = np.tril(np.ones_like(dists) * np.inf, 0)
        dists += tril
        attr_edges = np.nonzero(dists < np.inf)
        if attr_edges[0].size > 0:
            return torch.stack([
                torch.from_numpy(attr_edges[0]),
                torch.from_numpy(attr_edges[1])
            ], 0)
        return None

    def get_embed_loss_contr_trpl(self, weights, obs, embeddings):
        b_attr_edges = []
        b_rep_edges = []
        for i in range(obs[1].shape[0]):
            edges = obs[2][:, obs[-1][i]:obs[-1][i + 1]]
            edges = edges - edges.min()

            attr_weight_del = weights[
                obs[-1][i]:obs[-1][i + 1]] < self.cfg.sac.weight_tolerance_attr
            attr_weights = weights[obs[-1][i]:obs[-1][i + 1]][attr_weight_del]
            if len(attr_weights) != 0:
                attr_weights = attr_weights - attr_weights.min()
                max_w = attr_weights.max()
                max_w = max_w if max_w != 0 else 1e-16
                attr_weights = attr_weights / max_w
                attr_weights += 1e-16  # make sure all edges exist in graph
                direct_attr = edges[:, attr_weight_del].numpy()
                b_attr_edges.append(
                    self._get_connected_paths(direct_attr, attr_weights,
                                              edges.max() + 1))
            else:
                b_attr_edges.append(None)

            rep_weight_del = weights[
                obs[-1][i]:obs[-1][i + 1]] <= self.cfg.sac.weight_tolerance_rep
            rep_weights = weights[obs[-1][i]:obs[-1][i + 1]][rep_weight_del]
            if len(rep_weights) != 0:
                rep_weights = rep_weights - rep_weights.min()
                max_w = rep_weights.max()
                max_w = max_w if max_w != 0 else 1e-16
                rep_weights = rep_weights / max_w
                rep_weights += 1e-16  # make sure all edges exist in graph
                direct_rep = edges[:, rep_weight_del].numpy()
                b_rep_edges.append(
                    self._get_connected_paths(direct_rep,
                                              rep_weights,
                                              edges.max() + 1,
                                              get_repulsive=True))
            else:
                b_rep_edges.append(None)

        return self.contr_trpl_loss(embeddings,
                                    obs[1].long().to(embeddings.device),
                                    (b_attr_edges, b_rep_edges))

    def get_embed_loss_contr(self, weights, env, embeddings):
        segs = env.get_current_soln(weights)
        return self.contr_loss(embeddings, segs.long().to(embeddings.device))

    def update_embeddings(self, obs, env, model, optimizers):
        distribution, actor_Q1, actor_Q2, action, embeddings, side_loss = \
            self.agent_forward(env, model, grad=False, state=obs, policy_opt=False, embeddings_opt=True)
        weights = distribution.loc.detach()
        loss = self.get_embed_loss_contr(weights, env, embeddings)

        optimizers.embeddings.zero_grad()
        loss.backward()
        optimizers.embeddings.step()
        return loss.item()

    def update_critic(self, obs, action, reward, next_obs, not_done, env,
                      model, optimizers):
        distribution, target_Q1, target_Q2, next_action, _, side_loss = self.agent_forward(
            env, model, state=next_obs)
        current_Q1, current_Q2, side_loss = self.agent_forward(env,
                                                               model,
                                                               state=obs,
                                                               actions=action)

        log_prob = distribution.log_prob(next_action)
        critic_loss = torch.tensor([0.0], device=target_Q1[0].device)
        mean_reward = 0

        for i, sz in enumerate(self.cfg.sac.s_subgraph):
            _log_prob = log_prob[next_obs[5][i].view(-1, sz)].sum(-1)

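            # clipped double-Q target per subgraph: V(s') = min(Q1, Q2) - alpha * log pi(a'|s')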
            target_V = torch.min(
                target_Q1[i],
                target_Q2[i]) - model.module.alpha[i].detach() * _log_prob

            target_Q = reward[i] + (not_done * self.cfg.sac.discount *
                                    target_V)
            target_Q = target_Q.detach()

            critic_loss = critic_loss + (
                F.mse_loss(current_Q1[i], target_Q) +
                F.mse_loss(current_Q2[i], target_Q)
            )  # / 2) + self.cfg.sac.sl_beta * side_loss
            mean_reward += reward[i].mean()
        # critic_loss = critic_loss / len(self.cfg.sac.s_subgraph)
        optimizers.critic.zero_grad()
        critic_loss.backward()
        optimizers.critic.step()

        return critic_loss.item(), mean_reward / len(self.cfg.sac.s_subgraph)

    def update_actor_and_alpha(self,
                               obs,
                               env,
                               model,
                               optimizers,
                               embeddings_opt=False):
        distribution, actor_Q1, actor_Q2, action, side_loss = \
            self.agent_forward(env, model, state=obs, policy_opt=True, embeddings_opt=embeddings_opt)

        log_prob = distribution.log_prob(action)
        actor_loss = torch.tensor([0.0], device=actor_Q1[0].device)
        alpha_loss = torch.tensor([0.0], device=actor_Q1[0].device)

        for i, sz in enumerate(self.cfg.sac.s_subgraph):
            actor_Q = torch.min(actor_Q1[i], actor_Q2[i])

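            # SAC policy objective per subgraph: minimize E[alpha * log pi - min(Q1, Q2)]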
            _log_prob = log_prob[obs[5][i].view(-1, sz)].sum(-1)
            loss = (model.module.alpha[i].detach() * _log_prob -
                    actor_Q).mean()

            actor_loss = actor_loss + loss  # + self.cfg.sac.sl_beta * side_loss

        actor_loss = actor_loss / len(self.cfg.sac.s_subgraph)
        alpha_loss = alpha_loss / len(self.cfg.sac.s_subgraph)

        optimizers.actor.zero_grad()
        actor_loss.backward()
        optimizers.actor.step()

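        # temperature update: adjust each alpha so the summed subgraph log-prob
        # tracks its target (here the subgraph size)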
        for i, sz in enumerate(self.cfg.sac.s_subgraph):
            _log_prob = log_prob[obs[5][i].view(-1, sz)].sum(-1)
            alpha_loss = alpha_loss + (
                model.module.alpha[i] *
                (-_log_prob + self.cfg.sac.s_subgraph[i]).detach()).mean()

        optimizers.temperature.zero_grad()
        alpha_loss.backward()
        optimizers.temperature.step()

        return actor_loss.item(), alpha_loss.item()

    def _step(self,
              replay_buffer,
              optimizers,
              mov_sum_loss,
              env,
              model,
              step,
              writer=None):

        (obs, action, reward, next_obs,
         done), sample_idx = replay_buffer.sample()
        not_done = int(not done)
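        # run extra embedding-network updates only after the configured prep
        # phase has passed, at the given frequency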
        n_prep_steps = self.cfg.trainer.t_max - self.cfg.fe.update_after_steps
        embeddings_opt = step - n_prep_steps > 0 and (
            step - n_prep_steps) % self.cfg.fe.update_frequency == 0

        if "extra" in self.cfg.fe.optim:
            if embeddings_opt:
                embedd_loss = self.update_embeddings(obs, env, model,
                                                     optimizers)
                mov_sum_loss.embeddings.apply(embedd_loss)
                optimizers.embed_shed.step(mov_sum_loss.embeddings.avg)
                if writer is not None:
                    writer.add_scalar("loss/embedd", embedd_loss,
                                      self.global_writer_loss_count.value())
                return

        critic_loss, mean_reward = self.update_critic(obs, action, reward,
                                                      next_obs, not_done, env,
                                                      model, optimizers)
        mov_sum_loss.critic.apply(critic_loss)
        # optimizers.critic_shed.step(mov_sum_loss.critic.avg)
        replay_buffer.report_sample_loss(critic_loss + mean_reward, sample_idx)

        if step % self.cfg.sac.actor_update_frequency == 0:
            actor_loss, alpha_loss = self.update_actor_and_alpha(
                obs, env, model, optimizers, embeddings_opt)
            mov_sum_loss.actor.apply(actor_loss)
            mov_sum_loss.temperature.apply(alpha_loss)
            # optimizers.temp_shed.step(mov_sum_loss.actor.avg)
            # optimizers.temp_shed.step(mov_sum_loss.temperature.avg)
            if writer is not None:
                writer.add_scalar("loss/actor", actor_loss,
                                  self.global_writer_loss_count.value())
                writer.add_scalar("loss/temperature", alpha_loss,
                                  self.global_writer_loss_count.value())

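        # Polyak-average the online critic into the target critic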
        if step % self.cfg.sac.critic_target_update_frequency == 0:
            soft_update_params(model.module.critic, model.module.critic_tgt,
                               self.cfg.sac.critic_tau)

        if writer is not None:
            writer.add_scalar("loss/critic", critic_loss,
                              self.global_writer_loss_count.value())

    # Acts and trains model
    def train(self, rank, start_time, return_dict, rn):

        self.log_dir = os.path.join(self.save_dir, 'logs', '_' + str(rn))
        writer = None
        if rank == 0:
            writer = SummaryWriter(logdir=self.log_dir)
            writer.add_text("config", self.cfg.pretty(), 0)
            copyfile(os.path.join(self.save_dir, 'runtime_cfg.yaml'),
                     os.path.join(self.log_dir, 'runtime_cfg.yaml'))

            self.global_count.reset()
            self.global_writer_loss_count.reset()
            self.global_writer_quality_count.reset()
            self.action_stats_count.reset()
            self.global_writer_count.reset()

        set_seed_everywhere(rn)
        if rank == 0:
            print('training with seed: ' + str(rn))
        score = self.train_step(rank, writer)
        if rank == 0:
            return_dict['score'] = score
            del self.memory
        return

    def train_step(self, rank, writer):
        device = torch.device("cuda:" +
                              str(rank // self.cfg.gen.n_processes_per_gpu))
        print('Running on device: ', device)
        torch.cuda.set_device(device)
        torch.set_default_tensor_type(torch.FloatTensor)
        self.setup(rank, self.cfg.gen.n_processes_per_gpu * self.cfg.gen.n_gpu)

        env = SpGcnEnv(self.cfg,
                       device,
                       writer=writer,
                       writer_counter=self.global_writer_quality_count)
        # Create shared network

        model = GcnEdgeAC(self.cfg, device, writer=writer)
        model.cuda(device)
        shared_model = DDP(model,
                           device_ids=[device],
                           find_unused_parameters=True)
        if 'extra' in self.cfg.fe.optim:
            # optimizers
            MovSumLosses = namedtuple(
                'mov_avg_losses',
                ('actor', 'embeddings', 'critic', 'temperature'))
            OptimizerContainer = namedtuple(
                'OptimizerContainer',
                ('actor', 'embeddings', 'critic', 'temperature', 'actor_shed',
                 'embed_shed', 'critic_shed', 'temp_shed'))
        else:
            MovSumLosses = namedtuple('mov_avg_losses',
                                      ('actor', 'critic', 'temperature'))
            OptimizerContainer = namedtuple(
                'OptimizerContainer',
                ('actor', 'critic', 'temperature', 'actor_shed', 'critic_shed',
                 'temp_shed'))
        if "rl_loss" == self.cfg.fe.optim:
            actor_optimizer = torch.optim.Adam(
                list(shared_model.module.actor.parameters()) +
                list(shared_model.module.fe_ext.parameters()),
                lr=self.cfg.sac.actor_lr,
                betas=self.cfg.sac.actor_betas)
        else:
            actor_optimizer = torch.optim.Adam(
                shared_model.module.actor.parameters(),
                lr=self.cfg.sac.actor_lr,
                betas=self.cfg.sac.actor_betas)
        if "extra" in self.cfg.fe.optim:
            embeddings_optimizer = torch.optim.Adam(
                shared_model.module.fe_ext.parameters(),
                lr=self.cfg.fe.lr,
                betas=self.cfg.fe.betas)
        critic_optimizer = torch.optim.Adam(
            shared_model.module.critic.parameters(),
            lr=self.cfg.sac.critic_lr,
            betas=self.cfg.sac.critic_betas)
        temp_optimizer = torch.optim.Adam([shared_model.module.log_alpha],
                                          lr=self.cfg.sac.alpha_lr,
                                          betas=self.cfg.sac.alpha_betas)

        if "extra" in self.cfg.fe.optim:
            mov_sum_losses = MovSumLosses(RunningAverage(), RunningAverage(),
                                          RunningAverage(), RunningAverage())
            optimizers = OptimizerContainer(
                actor_optimizer, embeddings_optimizer, critic_optimizer,
                temp_optimizer, ReduceLROnPlateau(actor_optimizer),
                ReduceLROnPlateau(embeddings_optimizer),
                ReduceLROnPlateau(critic_optimizer),
                ReduceLROnPlateau(temp_optimizer))
        else:
            mov_sum_losses = MovSumLosses(RunningAverage(), RunningAverage(),
                                          RunningAverage())
            optimizers = OptimizerContainer(
                actor_optimizer, critic_optimizer, temp_optimizer,
                ReduceLROnPlateau(actor_optimizer),
                ReduceLROnPlateau(critic_optimizer),
                ReduceLROnPlateau(temp_optimizer))

        dist.barrier()

        if self.cfg.gen.resume:
            shared_model.module.load_state_dict(
                torch.load(os.path.join(self.log_dir,
                                        self.cfg.gen.model_name)))
        elif self.cfg.fe.load_pretrained:
            shared_model.module.fe_ext.load_state_dict(
                torch.load(os.path.join(self.save_dir,
                                        self.cfg.fe.model_name)))
        elif 'warmup' in self.cfg.fe and rank == 0:
            print('pretrain fe extractor')
            self.pretrain_embeddings_gt(shared_model.module.fe_ext, device,
                                        writer)
            torch.save(shared_model.module.fe_ext.state_dict(),
                       os.path.join(self.save_dir, self.cfg.fe.model_name))
        dist.barrier()

        if "none" == self.cfg.fe.optim:
            for param in shared_model.module.fe_ext.parameters():
                param.requires_grad = False

        dset = SpgDset(self.cfg.gen.data_dir)
        step = 0
        while self.global_count.value() <= self.cfg.trainer.T_max:
            dloader = DataLoader(dset,
                                 batch_size=self.cfg.trainer.batch_size,
                                 shuffle=True,
                                 pin_memory=True,
                                 num_workers=0)
            for iteration in range(
                    len(dset) * self.cfg.trainer.data_update_frequency):
                # if self.global_count.value() > self.args.T_max:
                #     a=1
                if iteration % self.cfg.trainer.data_update_frequency == 0:
                    self.update_env_data(env, dloader, device)
                # waff_dis = torch.softmax(env.edge_features[:, 0].squeeze() + 1e-30, dim=0)
                # waff_dis = torch.softmax(env.gt_edge_weights + 0.5, dim=0)
                # waff_dis = torch.softmax(torch.ones_like(env.b_gt_edge_weights), dim=0)
                # loss_weight = torch.softmax(env.b_gt_edge_weights + 1, dim=0)
                env.reset()
                self.update_rt_vars(critic_optimizer, actor_optimizer)
                if rank == 0 and self.cfg.rt_vars.safe_model:
                    if self.cfg.gen.model_name != "":
                        torch.save(
                            shared_model.module.state_dict(),
                            os.path.join(self.log_dir,
                                         self.cfg.gen.model_name))
                    else:
                        torch.save(shared_model.module.state_dict(),
                                   os.path.join(self.log_dir, 'agent_model'))

                state = env.get_state()
                while not env.done:
                    # Calculate policy and values
                    post_stats = (self.global_writer_count.value() + 1) % self.cfg.trainer.post_stats_frequency == 0
                    post_model = (self.global_writer_count.value() + 1) % self.cfg.trainer.post_model_frequency == 0
                    post_stats &= self.memory.is_full()
                    post_model &= self.memory.is_full()
                    distr = None
                    if not self.memory.is_full():
                        action = torch.rand_like(env.current_edge_weights)
                    else:
                        distr, _, _, action, _, _ = self.agent_forward(
                            env,
                            shared_model,
                            state=state,
                            grad=False,
                            post_input=post_stats,
                            post_model=post_model)

                    logg_dict = {}
                    if post_stats:
                        for i in range(len(self.cfg.sac.s_subgraph)):
                            logg_dict[
                                'alpha_' +
                                str(i)] = shared_model.module.alpha[i].item()
                        if distr is not None:
                            logg_dict['mean_loc'] = distr.loc.mean().item()
                            logg_dict['mean_scale'] = distr.scale.mean().item()

                    if self.memory.is_full():
                        for i in range(self.cfg.trainer.n_updates_per_step):
                            self._step(self.memory,
                                       optimizers,
                                       mov_sum_losses,
                                       env,
                                       shared_model,
                                       step,
                                       writer=writer)
                            self.global_writer_loss_count.increment()

                    next_state, reward = env.execute_action(
                        action, logg_dict, post_stats=post_stats)
                    # next_state, reward, quality = env.execute_action(torch.sigmoid(distr.loc), logg_dict, post_stats=post_stats)

                    if self.cfg.rt_vars.add_noise:
                        noise = torch.randn_like(reward) * 0.2
                        reward = reward + noise

                    self.memory.push(self.state_to_cpu(state), action, reward,
                                     self.state_to_cpu(next_state), env.done)
                    state = next_state

                self.global_count.increment()
                step += 1
                if rank == 0:
                    self.global_writer_count.increment()
                if step > self.cfg.trainer.T_max:
                    break

        dist.barrier()
        if rank == 0:
            self.memory.clear()
            if not self.cfg.gen.cross_validate_hp and not self.cfg.gen.test_score_only and not self.cfg.gen.no_save:
                # pass
                if self.cfg.gen.model_name != "":
                    torch.save(
                        shared_model.state_dict(),
                        os.path.join(self.log_dir, self.cfg.gen.model_name))
                    print('saved')
                else:
                    torch.save(shared_model.state_dict(),
                               os.path.join(self.log_dir, 'agent_model'))

        self.cleanup()
        return sum(env.acc_reward) / len(env.acc_reward)

    def state_to_cpu(self, state):
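        # recursively move all tensors in the (possibly nested) state to CPU
        # before storing it in replay memory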
        state = list(state)
        for i in range(len(state)):
            if torch.is_tensor(state[i]):
                state[i] = state[i].cpu()
            elif isinstance(state[i], list) or isinstance(state[i], tuple):
                state[i] = self.state_to_cpu(state[i])
        return state

    def state_to_cuda(self, state, device):
        state = list(state)
        for i in range(len(state)):
            if torch.is_tensor(state[i]):
                state[i] = state[i].to(device)
            elif isinstance(state[i], list) or isinstance(state[i], tuple):
                state[i] = self.state_to_cuda(state[i], device)
        return state

    def update_rt_vars(self, critic_optimizer, actor_optimizer):
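        # a file lock serializes the read-modify-write of runtime_cfg.yaml
        # across concurrent workers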
        with portalocker.Lock(os.path.join(self.log_dir, 'runtime_cfg.yaml'),
                              'rb+',
                              timeout=60) as fh:
            with open(os.path.join(self.log_dir, 'runtime_cfg.yaml')) as info:
                args_dict = yaml.full_load(info)
                if args_dict is not None:
                    if 'safe_model' in args_dict:
                        self.cfg.rt_vars.safe_model = args_dict['safe_model']
                        args_dict['safe_model'] = False
                    if 'add_noise' in args_dict:
                        self.cfg.rt_vars.add_noise = args_dict['add_noise']
                    if 'critic_lr' in args_dict and args_dict[
                            'critic_lr'] != self.cfg.sac.critic_lr:
                        self.cfg.sac.critic_lr = args_dict['critic_lr']
                        adjust_learning_rate(critic_optimizer,
                                             self.cfg.sac.critic_lr)
                    if 'actor_lr' in args_dict and args_dict[
                            'actor_lr'] != self.cfg.sac.actor_lr:
                        self.cfg.sac.actor_lr = args_dict['actor_lr']
                        adjust_learning_rate(actor_optimizer,
                                             self.cfg.sac.actor_lr)
            with open(os.path.join(self.log_dir, 'runtime_cfg.yaml'),
                      "w") as info:
                yaml.dump(args_dict, info)

            # flush and sync to filesystem
            fh.flush()
            os.fsync(fh.fileno())
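

# --- Not part of the original listing: minimal sketches of two helpers the
# examples call but do not define (soft_update_params, adjust_learning_rate).
# These assume the common implementations; the originals may differ.
def soft_update_params(net, target_net, tau):
    # Polyak averaging: target <- tau * online + (1 - tau) * target
    for p, tp in zip(net.parameters(), target_net.parameters()):
        tp.data.copy_(tau * p.data + (1 - tau) * tp.data)


def adjust_learning_rate(optimizer, lr):
    # overwrite the learning rate of every parameter group in place
    for group in optimizer.param_groups:
        group['lr'] = lr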