Example #1
    def __init__(self, args, shared_damped_model, global_count,
                 global_writer_loss_count, global_writer_quality_count,
                 global_win_event_count, save_dir):
        super(AgentOffpac, self).__init__()
        self.args = args
        self.shared_damped_model = shared_damped_model
        self.global_count = global_count
        self.global_writer_loss_count = global_writer_loss_count
        self.global_writer_quality_count = global_writer_quality_count
        self.global_win_event_count = global_win_event_count
        self.writer_idx_warmup_loss = 0
        # self.eps = self.args.init_epsilon
        self.save_dir = save_dir
        if args.stop_qual_rule == 'naive':
            self.stop_qual_rule = NaiveDecay(initial_eps=args.init_stop_qual,
                                             episode_shrinkage=1,
                                             change_after_n_episodes=5)
        elif args.stop_qual_rule == 'gaussian':
            self.stop_qual_rule = GaussianDecay(args.stop_qual_final,
                                                args.stop_qual_scaling,
                                                args.stop_qual_offset,
                                                args.T_max)
        else:
            self.stop_qual_rule = NaiveDecay(args.init_stop_qual)

        if self.args.eps_rule == "treesearch":
            self.eps_rule = ActionPathTreeNodes()
        elif self.args.eps_rule == "sawtooth":
            self.eps_rule = ExpSawtoothEpsDecay()
        elif self.args.eps_rule == 'gaussian':
            self.eps_rule = GaussianDecay(args.eps_final, args.eps_scaling,
                                          args.eps_offset, args.T_max)
        else:
            self.eps_rule = NaiveDecay(self.eps, 0.00005, 1000, 1)
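Example #1 only selects between decay-rule classes (NaiveDecay, GaussianDecay, ActionPathTreeNodes, ExpSawtoothEpsDecay); their implementations are not part of this listing. Note also that the final fallback passes self.eps, which is only assigned by the commented-out init_epsilon line above, so that branch would fail unless self.eps is set elsewhere. Purely as an illustration, a Gaussian-shaped schedule with the constructor arguments used here (final, scaling, offset, T_max) and an apply(step) method might look like the sketch below; the exact formula is an assumption, not the project's implementation.

import numpy as np


class GaussianDecaySketch:
    """Hypothetical Gaussian-shaped decay from roughly scaling + offset at
    step 0 towards final at T_max; the precise shape used in the project
    is unknown and this is only a guess."""

    def __init__(self, final, scaling, offset, T_max):
        self.final = final
        self.scaling = scaling
        self.offset = offset
        self.T_max = T_max

    def apply(self, step):
        # Normalised progress through training, clipped at T_max
        progress = min(step, self.T_max) / max(self.T_max, 1)
        return self.final + (self.scaling + self.offset - self.final) * np.exp(-4 * progress ** 2)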
Example #2
    def train_eps_greedy(self,
                         n_iterations=150,
                         batch_size=1,
                         showInterm=True):
        with torch.set_grad_enabled(True):
            state = self.env.state
            eps_rule = NaiveDecay(initial_eps=1,
                                  episode_shrinkage=1 / (n_iterations / 5),
                                  limiting_epsiode=n_iterations - 10,
                                  change_after_n_episodes=5)
            self.agent.q_eval.train()
            for step in tqdm(range(self.agent.mem.capacity)):
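                # Fill the replay memory up to its capacity with transitions before any learning starts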
                action = self.agent.get_action(state)
                state_, reward, keep_old_state = self.env.execute_action(
                    action)
                self.agent.store_transit(state, action, reward, state_,
                                         int(self.env.done))
                if keep_old_state:
                    self.env.state = state.copy()
                else:
                    state = state_
                if self.env.done:
                    self.env.reset()
                    state = self.env.state

            print("----Fnished mem init----")
            eps_hist = []
            scores = []
            for episode in tqdm(range(n_iterations)):
                eps_hist.append(self.agent.eps)
                state = self.env.state.copy()
                steps = 0
                while not self.env.done:
                    self.agent.reset_eps(eps_rule.apply(episode, steps))
                    action = self.agent.get_action(state)
                    state_, reward, keep_old_state = self.env.execute_action(
                        action)
                    self.agent.store_transit(state.copy(), action, reward,
                                             state_.copy(), int(self.env.done))
                    if keep_old_state:
                        self.env.state = state.copy()
                    else:
                        state = state_
                    # print(f'reward:{reward[1]}')
                    self.agent.learn(batch_size)
                    steps += 1
                if showInterm:
                    self.env.show_current_soln()
                if self.dloader is not None:
                    raw, affinities, gt_affinities = next(iter(self.dloader))
                    affinities = affinities.squeeze().detach().cpu().numpy()
                    gt_affinities = gt_affinities.squeeze().detach().cpu(
                    ).numpy()
                    self.env.update_data(affinities, gt_affinities)
                scores.append(self.env.acc_reward)
                print("score: ", self.env.acc_reward, "; eps: ",
                      self.agent.eps, "; steps: ", steps)
                self.env.reset()
        return scores, eps_hist, self.env.get_current_soln()
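The epsilon schedule above is a NaiveDecay constructed from initial_eps, episode_shrinkage, limiting_epsiode (spelled as in the call) and change_after_n_episodes, and queried via apply(episode, steps). The class itself is not shown in this listing; the following is only a plausible sketch of its behaviour under the values used here, not the project's code.

class NaiveDecaySketch:
    """Hypothetical piecewise-linear epsilon decay."""

    def __init__(self, initial_eps, episode_shrinkage=1,
                 limiting_epsiode=float('inf'), change_after_n_episodes=5):
        self.initial_eps = initial_eps
        self.episode_shrinkage = episode_shrinkage
        self.limiting_epsiode = limiting_epsiode   # spelling kept from the call sites
        self.change_after_n_episodes = change_after_n_episodes

    def apply(self, episode, steps=0):
        # steps is accepted for interface compatibility but unused in this sketch
        if episode >= self.limiting_epsiode:
            return 0.0
        n_shrinks = episode // self.change_after_n_episodes
        return max(self.initial_eps - n_shrinks * self.episode_shrinkage, 0.0)

With episode_shrinkage = 1 / (n_iterations / 5) and change_after_n_episodes = 5, such a rule would walk epsilon roughly linearly from 1 to 0 over the run, which matches how it is used in the loop above.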
Example #3
    def train(self, n_iterations=150, showInterm=False):
        with torch.set_grad_enabled(True):
            numsteps = []
            avg_numsteps = []
            scores = []
            self.agent.policy.train()
            eps_rule = NaiveDecay(initial_eps=1,
                                  episode_shrinkage=1 / (n_iterations / 5),
                                  limiting_epsiode=n_iterations - 10,
                                  change_after_n_episodes=5)

            print("----Fnished mem init----")
            for episode in tqdm(range(n_iterations)):
                log_probs = []
                rewards = []
                values = []
                policy_entropy = 0
                steps = 0
                state = self.env.state

                while not self.env.done:
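                    # One environment step: refresh epsilon from the decay rule, act, and record
                    # the log-prob, value, reward and entropy for the episode update below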
                    self.agent.reset_eps(eps_rule.apply(episode, steps))
                    action, log_prob, value, entropy = self.agent.get_action(
                        state)
                    state_, reward, _ = self.env.execute_action(action)
                    state = state_
                    log_probs.append(log_prob)
                    rewards.append(reward)
                    values.append(value)
                    policy_entropy += entropy
                    steps += 1

                _, _, q_val, _ = self.agent.get_action(state)
                self.agent.learn(rewards, log_probs, values, policy_entropy,
                                 q_val)

                numsteps.append(steps)
                avg_numsteps.append(np.mean(numsteps[-10:]))

                raw, affinities, gt_affinities = next(iter(self.dloader))
                affinities = affinities.squeeze().detach().cpu().numpy()
                gt_affinities = gt_affinities.squeeze().detach().cpu().numpy()
                scores.append(self.env.acc_reward)
                print("score: ", scores[-1], "; eps: ", self.agent.eps,
                      "; steps: ", self.env.ttl_cnt)
                self.env.update_data(affinities, gt_affinities)
                self.env.reset()

                if showInterm:
                    self.env.show_current_soln()

        return scores, steps, self.env.get_current_soln()
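agent.learn receives the collected rewards, log-probs and values, the summed policy entropy and a bootstrap value for the final state (q_val), but its body is not part of this listing. The sketch below shows one common way such an update could be written; the discount, loss weights and overall structure are assumptions, not the project's implementation.

import torch
import torch.nn.functional as F


def actor_critic_update(rewards, log_probs, values, policy_entropy, last_value,
                        optimizer, discount=0.99, value_weight=0.5,
                        entropy_weight=1e-3):
    """Hypothetical advantage actor-critic update for one finished episode."""
    # Discounted returns, bootstrapped from the value of the final state
    returns, R = [], last_value.detach()
    for r in reversed(rewards):
        R = r + discount * R
        returns.insert(0, R)
    returns = torch.stack(returns)
    values = torch.stack(values)
    log_probs = torch.stack(log_probs)

    advantages = returns - values.detach()
    policy_loss = -(log_probs * advantages).mean()
    value_loss = F.mse_loss(values, returns)
    loss = policy_loss + value_weight * value_loss - entropy_weight * policy_entropy

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()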
Example #4
    def __init__(self, cfg, args, global_count, global_writer_loss_count,
                 global_writer_quality_count, global_win_event_count,
                 action_stats_count, save_dir):
        super(AgentSacTrainer_sg_lg, self).__init__()

        self.cfg = cfg
        self.args = args
        self.global_count = global_count
        self.global_writer_loss_count = global_writer_loss_count
        self.global_writer_quality_count = global_writer_quality_count
        self.global_win_event_count = global_win_event_count
        self.action_stats_count = action_stats_count
        # self.eps = self.args.init_epsilon
        self.save_dir = save_dir
        if args.stop_qual_rule == 'naive':
            self.stop_qual_rule = NaiveDecay(initial_eps=args.init_stop_qual,
                                             episode_shrinkage=1,
                                             change_after_n_episodes=5)
        elif args.stop_qual_rule == 'gaussian':
            self.stop_qual_rule = GaussianDecay(args.stop_qual_final,
                                                args.stop_qual_scaling,
                                                args.stop_qual_offset,
                                                args.T_max)
        elif args.stop_qual_rule == 'running_average':
            self.stop_qual_rule = RunningAverage(
                args.stop_qual_ra_bw,
                args.stop_qual_scaling + args.stop_qual_offset,
                args.stop_qual_ra_off)
        else:
            self.stop_qual_rule = Constant(args.stop_qual_final)

        if self.cfg.temperature_regulation == 'follow_quality':
            self.beta_rule = FollowLeadAvg(1, 80, 1)
        elif self.cfg.temperature_regulation == 'constant':
            self.beta_rule = Constant(cfg.init_temperature)
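Here the SAC temperature is regulated either by a constant or by a FollowLeadAvg rule that follows the observed quality signal (see the apply(step, quality) calls in the later examples). Neither class appears in the listing; the sketch below is only one hypothetical reading of their interfaces, with the constructor meaning (scale, history length, initial value) assumed.

from collections import deque


class ConstantSketch:
    """Hypothetical constant schedule: always returns the configured value."""

    def __init__(self, value):
        self.value = value

    def apply(self, step, quality=None):
        return self.value


class FollowLeadAvgSketch:
    """Hypothetical quality-following rule: tracks the mean of the last
    `history` quality values and scales it by `scale`."""

    def __init__(self, scale, history, init_value):
        self.scale = scale
        self.values = deque([init_value], maxlen=history)

    def apply(self, step, quality):
        self.values.append(quality)
        return self.scale * sum(self.values) / len(self.values)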
Example #5
class AgentOffpac(object):
    def __init__(self, args, shared_damped_model, global_count,
                 global_writer_loss_count, global_writer_quality_count,
                 global_win_event_count, save_dir):
        super(AgentOffpac, self).__init__()
        self.args = args
        self.shared_damped_model = shared_damped_model
        self.global_count = global_count
        self.global_writer_loss_count = global_writer_loss_count
        self.global_writer_quality_count = global_writer_quality_count
        self.global_win_event_count = global_win_event_count
        self.writer_idx_warmup_loss = 0
        # self.eps = self.args.init_epsilon
        self.save_dir = save_dir
        if args.stop_qual_rule == 'naive':
            self.stop_qual_rule = NaiveDecay(initial_eps=args.init_stop_qual,
                                             episode_shrinkage=1,
                                             change_after_n_episodes=5)
        elif args.stop_qual_rule == 'gaussian':
            self.stop_qual_rule = GaussianDecay(args.stop_qual_final,
                                                args.stop_qual_scaling,
                                                args.stop_qual_offset,
                                                args.T_max)
        else:
            self.stop_qual_rule = NaiveDecay(args.init_stop_qual)

        if self.args.eps_rule == "treesearch":
            self.eps_rule = ActionPathTreeNodes()
        elif self.args.eps_rule == "sawtooth":
            self.eps_rule = ExpSawtoothEpsDecay()
        elif self.args.eps_rule == 'gaussian':
            self.eps_rule = GaussianDecay(args.eps_final, args.eps_scaling,
                                          args.eps_offset, args.T_max)
        else:
            self.eps_rule = NaiveDecay(self.eps, 0.00005, 1000, 1)

    def setup(self, rank, world_size):
        # BLAS setup
        os.environ['OMP_NUM_THREADS'] = '1'
        os.environ['MKL_NUM_THREADS'] = '1'

        # os.environ["CUDA_VISIBLE_DEVICES"] = "6"
        assert torch.cuda.device_count() == 1
        torch.set_default_tensor_type('torch.FloatTensor')
        # Detect if we have a GPU available
        device = torch.device("cuda:0")
        torch.cuda.set_device(device)

        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = '12355'
        # os.environ['GLOO_SOCKET_IFNAME'] = 'eno1'

        # initialize the process group
        dist.init_process_group("gloo", rank=rank, world_size=world_size)

        # Explicitly set the seed to make sure that models created in the two processes
        # start from the same random weights and biases.
        torch.manual_seed(self.args.seed)

    def cleanup(self):
        dist.destroy_process_group()

    # Updates networks
    def _update_networks(self, loss, optimizer, shared_model, writer=None):
        # Zero shared and local grads
        optimizer.zero_grad()
        """
        Calculate gradients for gradient descent on loss functions
        Note that math comments follow the paper, which is formulated for gradient ascent
        """
        loss.backward()
        # Clip the gradient L2 norm
        nn.utils.clip_grad_norm_(shared_model.parameters(),
                                 self.args.max_gradient_norm)
        optimizer.step()
        if self.args.min_lr != 0:
            # Linearly decay learning rate
            new_lr = self.args.lr - (
                (self.args.lr - self.args.min_lr) * (1 - max(
                    (self.args.T_max - self.global_count.value()) /
                    self.args.T_max, 1e-32)))
            adjust_learning_rate(optimizer, new_lr)

            if writer is not None:
                writer.add_scalar("loss/learning_rate", new_lr,
                                  self.global_writer_loss_count.value())

    def warm_start(self, transition_data):
        # Warm start with data from the initial behaviour policy, which is assumed to be a uniform distribution
        qloss = self.get_qLoss(transition_data)
        ploss = 0
        for t in transition_data:
            ploss += self.policy.loss(self.policy(t.state),
                                      torch.ones(self.action_shape) / 2)
        ploss /= len(transition_data)
        self.q_val.optimizer.zero_grad()
        qloss.backward()
        for param in self.q_eval.parameters():
            param.grad.data.clamp_(-1, 1)
        self.q_val.optimizer.step()
        self.policy.optimizer.zero_grad()
        ploss.backward()
        for param in self.policy.parameters():
            param.grad.data.clamp_(-1, 1)
        self.policy.optimizer.step()

    def agent_forward(self, env, model, state=None, grad=True):
        with torch.set_grad_enabled(grad):
            if state is None:
                state = env.state
            return model([
                obj.float().to(model.module.device)
                for obj in state + [env.raw]
            ],
                         sp_indices=env.sp_indices,
                         edge_index=env.edge_ids.to(model.module.device),
                         angles=env.edge_angles.to(model.module.device),
                         edge_features_1d=env.edge_features.to(
                             model.module.device))

    def get_action(self, action_probs, q, v, policy, device):
        if policy == 'off_sampled':
            behav_probs = action_probs.detach() + self.eps * (
                1 / self.args.n_actions - action_probs.detach())
            actions = torch.multinomial(behav_probs, 1).squeeze()
        elif policy == 'off_uniform':
            randm_draws = int(self.eps * len(action_probs))
            if randm_draws > 0:
                actions = action_probs.max(-1)[1].squeeze()
                randm_indices = torch.multinomial(
                    torch.ones(len(action_probs)) / len(action_probs),
                    randm_draws)
                actions[randm_indices] = torch.randint(
                    0, self.args.n_actions, (randm_draws, )).to(device)
                behav_probs = action_probs.detach()
                behav_probs[randm_indices] = torch.Tensor([
                    1 / self.args.n_actions for i in range(self.args.n_actions)
                ]).to(device)
            else:
                actions = action_probs.max(-1)[1].squeeze()
                behav_probs = action_probs.detach()
        elif policy == 'on':
            actions = action_probs.max(-1)[1].squeeze()
            behav_probs = action_probs.detach()
        elif policy == 'q_val':
            actions = q.max(-1)[1].squeeze()
            behav_probs = action_probs.detach()

        # log_probs = torch.log(sel_behav_probs)
        # entropy = - (behav_probs * torch.log(behav_probs)).sum()
        return actions, behav_probs

    def fe_extr_warm_start(self, sp_feature_ext, writer=None):
        dataloader = DataLoader(MultiDiscSpGraphDset(length=100),
                                batch_size=10,
                                shuffle=True,
                                pin_memory=True)
        criterion = ContrastiveLoss(delta_var=0.5, delta_dist=1.5)
        optimizer = torch.optim.Adam(sp_feature_ext.parameters())
        for i, (data, gt) in enumerate(dataloader):
            data, gt = data.to(sp_feature_ext.device), gt.to(
                sp_feature_ext.device)
            pred = sp_feature_ext(data)
            loss = criterion(pred, gt)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if writer is not None:
                writer.add_scalar("loss/fe_warm_start", loss.item(),
                                  self.writer_idx_warmup_loss)
                self.writer_idx_warmup_loss += 1

    def _step(self,
              memory,
              shared_model,
              env,
              optimizer,
              off_policy=True,
              writer=None):
        if self.args.qnext_replace_cnt is not None and self.global_count.value(
        ) % self.args.qnext_replace_cnt == 0:
            self.shared_damped_model.load_state_dict(shared_model.state_dict())
        # according to thm3 in https://arxiv.org/pdf/1606.02647.pdf
        c_loss, a_loss = 0, 0
        transition_data = memory.memory
        correction = 0
        current = transition_data[0].time
        importance_weight = 1
        m = 0
        l2_reg = None
        # self.train_a2c = not self.train_a2c
        # self.opt_fe_extr = not self.train_a2c
        # self.dist_correction.update_density(transition_data, self.gamma)
        for i, t in enumerate(transition_data):
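            # Critic pass: build TD targets from the damped target network and weight each
            # transition's squared error by the accumulated truncated importance weight m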
            if not t.terminal:
                _, q_, v_ = self.agent_forward(env, self.shared_damped_model,
                                               t.state_, False)
            pvals, q, v = self.agent_forward(env, shared_model, t.state)
            # pvals = nn.functional.softmax(qvals, -1).detach()  # this alternatively
            q_t = q.gather(-1, t.action.unsqueeze(-1)).squeeze()
            behav_policy_proba_t = t.behav_policy_proba.gather(
                -1, t.action.unsqueeze(-1)).squeeze().detach()
            pvals_t = pvals.gather(-1,
                                   t.action.unsqueeze(-1)).squeeze().detach()

            m = m + self.args.discount**(t.time - current) * importance_weight
            importance_weight = importance_weight * self.args.lbd * \
                                torch.min(torch.ones(t.action.shape).to(shared_model.module.device),
                                          pvals_t / behav_policy_proba_t)

            if self.args.weight_l2_reg_params_weight != 0:
                for W in list(shared_model.parameters()):
                    if l2_reg is None:
                        l2_reg = W.norm(2)
                    else:
                        l2_reg = l2_reg + W.norm(2)
            if l2_reg is None:
                l2_reg = 0

            if t.terminal:
                c_loss = c_loss + nn.functional.mse_loss(t.reward * m, q_t * m)
            else:
                c_loss = c_loss + nn.functional.mse_loss(
                    (t.reward + self.args.discount * v_.detach()) * m, q_t * m)

        c_loss = c_loss / len(
            transition_data) + l2_reg * self.args.weight_l2_reg_params_weight

        # Sample states according to the discounted state distribution
        discount_distribution = [
            self.args.discount**i for i in range(len(transition_data))
        ]
        discount_distribution = np.exp(discount_distribution) / sum(
            np.exp(discount_distribution))  # softmax
        batch_ind = np.random.choice(len(transition_data),
                                     size=len(transition_data),
                                     p=discount_distribution)
        z = 0
        l2_reg = None
        for i in batch_ind:
            t = transition_data[i]
            # w = self.dist_correction.density_ratio(t.state.unsqueeze(0).unsqueeze(0).to(self.dist_correction.density_ratio.device)).detach().squeeze()
            w = 1
            z += w
            policy_proba, q, v = self.agent_forward(env, shared_model, t.state)

            policy_proba_t = policy_proba.gather(
                -1, t.action.unsqueeze(-1)).squeeze()
            q_t = q.gather(-1, t.action.unsqueeze(-1)).squeeze().detach()
            advantage_t = q_t - v.detach()
            behav_policy_proba_t = t.behav_policy_proba.gather(
                -1, t.action.unsqueeze(-1)).squeeze().detach()

            if self.args.weight_l2_reg_params_weight != 0:
                for W in list(shared_model.parameters()):
                    if l2_reg is None:
                        l2_reg = W.norm(2)
                    else:
                        l2_reg = l2_reg + W.norm(2)
            if l2_reg is None:
                l2_reg = 0

            a_loss = a_loss - (policy_proba_t.detach() / behav_policy_proba_t
                               ) * w * torch.log(policy_proba_t) * advantage_t
        z = z / len(batch_ind)
        a_loss = a_loss / z
        a_loss = a_loss / len(batch_ind)
        a_loss = a_loss / len(t.state)
        a_loss = torch.sum(
            a_loss) + l2_reg * self.args.weight_l2_reg_params_weight

        if writer is not None:
            writer.add_scalar("loss/critic", c_loss.item(),
                              self.global_writer_loss_count.value())
            writer.add_scalar("loss/actor", a_loss.item(),
                              self.global_writer_loss_count.value())
            print("c: ", c_loss.item())
            print("a: ", a_loss.item())
            self.global_writer_loss_count.increment()

        self._update_networks(a_loss + c_loss, optimizer, shared_model, writer)
        return

    # Acts and trains model
    def train(self, rank, start_time, return_dict):
        device = torch.device("cuda:" + str(rank))
        print('Running on device: ', device)
        torch.cuda.set_device(device)
        torch.set_default_tensor_type(torch.FloatTensor)

        writer = None
        if not self.args.cross_validate_hp:
            writer = SummaryWriter(logdir=os.path.join(self.save_dir, 'logs'))
            # posting parameters
            param_string = ""
            for k, v in vars(self.args).items():
                param_string += ' ' * 10 + k + ': ' + str(v) + '\n'
            writer.add_text("params", param_string)

        self.setup(rank, self.args.num_processes)

        transition = namedtuple('Transition',
                                ('state', 'action', 'reward', 'state_',
                                 'behav_policy_proba', 'time', 'terminal'))
        memory = TransitionData(capacity=self.args.t_max,
                                storage_object=transition)

        env = SpGcnEnv(self.args,
                       device,
                       writer=writer,
                       writer_counter=self.global_writer_quality_count,
                       win_event_counter=self.global_win_event_count)
        dloader = DataLoader(MultiDiscSpGraphDset(no_suppix=False),
                             batch_size=1,
                             shuffle=True,
                             pin_memory=True,
                             num_workers=0)
        # Create shared network
        model = GcnEdgeAngle1dPQV(self.args.n_raw_channels,
                                  self.args.n_embedding_features,
                                  self.args.n_edge_features,
                                  self.args.n_actions, device)
        model.cuda(device)
        shared_model = DDP(model, device_ids=[model.device])
        # Create optimizer for shared network parameters with shared statistics
        optimizer = CstmAdam(shared_model.parameters(),
                             lr=self.args.lr,
                             betas=self.args.Adam_betas,
                             weight_decay=self.args.Adam_weight_decay)

        if self.args.fe_extr_warmup and rank == 0:
            fe_extr = SpVecsUnet(self.args.n_raw_channels,
                                 self.args.n_embedding_features, device)
            fe_extr.cuda(device)
            self.fe_extr_warm_start(fe_extr, writer=writer)
            shared_model.module.fe_ext.load_state_dict(fe_extr.state_dict())
            if self.args.model_name == "":
                torch.save(fe_extr.state_dict(),
                           os.path.join(self.save_dir, 'agent_model'))
            else:
                torch.save(shared_model.state_dict(),
                           os.path.join(self.save_dir, self.args.model_name))
        dist.barrier()
        if self.args.model_name != "":
            shared_model.load_state_dict(
                torch.load(os.path.join(self.save_dir, self.args.model_name)))
        elif self.args.fe_extr_warmup:
            print('loaded fe extractor')
            shared_model.load_state_dict(
                torch.load(os.path.join(self.save_dir, 'agent_model')))

        self.shared_damped_model.load_state_dict(shared_model.state_dict())
        env.done = True  # Start new episode
        while self.global_count.value() <= self.args.T_max:
            if env.done:
                edges, edge_feat, diff_to_gt, gt_edge_weights, node_labeling, raw, nodes, angles, affinities, gt = \
                    next(iter(dloader))
                edges, edge_feat, diff_to_gt, gt_edge_weights, node_labeling, raw, nodes, angles, affinities, gt = \
                    edges.squeeze().to(device), edge_feat.squeeze()[:, 0:self.args.n_edge_features].to(
                        device), diff_to_gt.squeeze().to(device), \
                    gt_edge_weights.squeeze().to(device), node_labeling.squeeze().to(device), raw.squeeze().to(
                        device), nodes.squeeze().to(device), \
                    angles.squeeze().to(device), affinities.squeeze().numpy(), gt.squeeze()
                env.update_data(edges, edge_feat, diff_to_gt, gt_edge_weights,
                                node_labeling, raw, nodes, angles, affinities,
                                gt)
                env.reset()
                state = [env.state[0].clone(), env.state[1].clone()]
                episode_length = 0

                self.eps = self.eps_rule.apply(self.global_count.value())
                env.stop_quality = self.stop_qual_rule.apply(
                    self.global_count.value())
                if writer is not None:
                    writer.add_scalar("step/epsilon", self.eps,
                                      env.writer_counter.value())

            while not env.done:
                # Calculate policy and values
                policy_proba, q, v = self.agent_forward(env,
                                                        shared_model,
                                                        grad=False)
                # average_policy_proba, _, _ = self.agent_forward(env, self.shared_average_model)
                # q_ret = v.detach()

                # Sample action
                # action = torch.multinomial(policy, 1)[0, 0]

                # Step
                action, behav_policy_proba = self.get_action(
                    policy_proba, q, v, policy='off_uniform', device=device)
                state_, reward = env.execute_action(action,
                                                    self.global_count.value())

                memory.push(state, action,
                            reward.to(shared_model.module.device), state_,
                            behav_policy_proba, episode_length, env.done)

                # Train the network
                self._step(memory,
                           shared_model,
                           env,
                           optimizer,
                           off_policy=True,
                           writer=writer)

                # reward = self.args.reward_clip and min(max(reward, -1), 1) or reward  # Optionally clamp rewards
                # done = done or episode_length >= self.args.max_episode_length  # Stop episodes at a max length
                episode_length += 1  # Increase episode counter
                state = state_

            # Break graph for last values calculated (used for targets, not directly as model outputs)
            self.global_count.increment()
            # Qret = 0 for terminal s

            while len(memory) > 0:
                self._step(memory,
                           shared_model,
                           env,
                           optimizer,
                           off_policy=True,
                           writer=writer)
                memory.pop(0)

        dist.barrier()
        if rank == 0:
            if not self.args.cross_validate_hp:
                if self.args.model_name != "":
                    torch.save(
                        shared_model.state_dict(),
                        os.path.join(self.save_dir, self.args.model_name))
                else:
                    torch.save(shared_model.state_dict(),
                               os.path.join(self.save_dir, 'agent_model'))
            else:
                test_score = 0
                env.writer = None
                for i in range(20):
                    self.update_env_data(env, dloader, device)
                    env.reset()
                    self.eps = 0
                    while not env.done:
                        # Calculate policy and values
                        policy_proba, q, v = self.agent_forward(env,
                                                                shared_model,
                                                                grad=False)
                        action, behav_policy_proba = self.get_action(
                            policy_proba,
                            q,
                            v,
                            policy='off_uniform',
                            device=device)
                        _, _ = env.execute_action(action,
                                                  self.global_count.value())
                    if env.win:
                        test_score += 1
                return_dict['test_score'] = test_score
                writer.add_text("time_needed", str((time.time() - start_time)))
        self.cleanup()
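train(rank, start_time, return_dict) sets up a gloo process group in setup() and is therefore meant to run once per worker process. How the project actually launches it is not shown here; a typical torch.multiprocessing launch would look roughly like the sketch below, with args, the shared model and the counters assumed to be constructed elsewhere exactly as the constructor expects.

import time
import torch.multiprocessing as mp

# Hypothetical launch code; `args`, `shared_damped_model` and the shared counters
# are assumed to be created beforehand with the interfaces the trainer uses.
if __name__ == '__main__':
    trainer = AgentOffpac(args, shared_damped_model, global_count,
                          global_writer_loss_count, global_writer_quality_count,
                          global_win_event_count, save_dir)
    return_dict = mp.Manager().dict()        # e.g. receives 'test_score' when cross-validating
    mp.spawn(trainer.train,                  # spawn passes the rank as the first argument
             args=(time.time(), return_dict),
             nprocs=args.num_processes,
             join=True)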
Example #6
    def __init__(self, args, shared_average_model, global_count,
                 global_writer_loss_count, global_writer_quality_count,
                 global_win_event_count, save_dir):
        super(AgentAcerContinuousTrainer, self).__init__()

        self.args = args
        self.shared_average_model = shared_average_model
        self.global_count = global_count
        self.global_writer_loss_count = global_writer_loss_count
        self.global_writer_quality_count = global_writer_quality_count
        self.global_win_event_count = global_win_event_count
        self.writer_idx_warmup_loss = 0
        # self.eps = self.args.init_epsilon
        self.save_dir = save_dir
        if args.stop_qual_rule == 'naive':
            self.stop_qual_rule = NaiveDecay(initial_eps=args.init_stop_qual,
                                             episode_shrinkage=1,
                                             change_after_n_episodes=5)
        elif args.stop_qual_rule == 'gaussian':
            self.stop_qual_rule = GaussianDecay(args.stop_qual_final,
                                                args.stop_qual_scaling,
                                                args.stop_qual_offset,
                                                args.T_max)
        elif args.stop_qual_rule == 'running_average':
            self.stop_qual_rule = RunningAverage(
                args.stop_qual_ra_bw,
                args.stop_qual_scaling + args.stop_qual_offset,
                args.stop_qual_ra_off)
        else:
            self.stop_qual_rule = NaiveDecay(args.init_stop_qual)

        if self.args.eps_rule == "treesearch":
            self.b_sigma_rule = ActionPathTreeNodes()
        elif self.args.eps_rule == "sawtooth":
            self.b_sigma_rule = ExpSawtoothEpsDecay()
        elif self.args.eps_rule == 'gaussian':
            self.b_sigma_rule = GaussianDecay(args.b_sigma_final,
                                              args.b_sigma_scaling,
                                              args.p_sigma, args.T_max)
        elif self.args.eps_rule == "self_reg_min":
            self.args.T_max = np.inf
            self.b_sigma_rule = FollowLeadMin(
                (args.stop_qual_scaling + args.stop_qual_offset), 1)
        elif self.args.eps_rule == "self_reg_avg":
            self.args.T_max = np.inf
            self.b_sigma_rule = FollowLeadAvg(
                (args.stop_qual_scaling + args.stop_qual_offset) / 4, 2, 1)
        elif self.args.eps_rule == "self_reg_exp_avg":
            self.args.T_max = np.inf
            self.b_sigma_rule = ExponentialAverage(
                (args.stop_qual_scaling + args.stop_qual_offset) / 4, 0.9, 1)
        else:
            self.b_sigma_rule = NaiveDecay(self.eps, 0.00005, 1000, 1)
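The self-regulating rules (FollowLeadMin, FollowLeadAvg, ExponentialAverage) are driven by the current segmentation quality rather than by the step count (see the apply(self.global_count.value(), quality) calls in the next example). As an illustration only, an exponential-moving-average variant could look like the sketch below; the constructor semantics (initial value, smoothing factor, lower bound) are assumptions.

class ExponentialAverageSketch:
    """Hypothetical schedule: exponential moving average of the quality
    signal, never falling below a configured minimum."""

    def __init__(self, init_value, alpha, minimum):
        self.avg = init_value
        self.alpha = alpha        # smoothing factor in (0, 1)
        self.minimum = minimum

    def apply(self, step, quality):
        # step is ignored; the rule reacts only to the observed quality
        self.avg = self.alpha * self.avg + (1 - self.alpha) * quality
        return max(self.avg, self.minimum)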
Example #7
class AgentAcerContinuousTrainer(object):
    def __init__(self, args, shared_average_model, global_count,
                 global_writer_loss_count, global_writer_quality_count,
                 global_win_event_count, save_dir):
        super(AgentAcerContinuousTrainer, self).__init__()

        self.args = args
        self.shared_average_model = shared_average_model
        self.global_count = global_count
        self.global_writer_loss_count = global_writer_loss_count
        self.global_writer_quality_count = global_writer_quality_count
        self.global_win_event_count = global_win_event_count
        self.writer_idx_warmup_loss = 0
        # self.eps = self.args.init_epsilon
        self.save_dir = save_dir
        if args.stop_qual_rule == 'naive':
            self.stop_qual_rule = NaiveDecay(initial_eps=args.init_stop_qual,
                                             episode_shrinkage=1,
                                             change_after_n_episodes=5)
        elif args.stop_qual_rule == 'gaussian':
            self.stop_qual_rule = GaussianDecay(args.stop_qual_final,
                                                args.stop_qual_scaling,
                                                args.stop_qual_offset,
                                                args.T_max)
        elif args.stop_qual_rule == 'running_average':
            self.stop_qual_rule = RunningAverage(
                args.stop_qual_ra_bw,
                args.stop_qual_scaling + args.stop_qual_offset,
                args.stop_qual_ra_off)
        else:
            self.stop_qual_rule = NaiveDecay(args.init_stop_qual)

        if self.args.eps_rule == "treesearch":
            self.b_sigma_rule = ActionPathTreeNodes()
        elif self.args.eps_rule == "sawtooth":
            self.b_sigma_rule = ExpSawtoothEpsDecay()
        elif self.args.eps_rule == 'gaussian':
            self.b_sigma_rule = GaussianDecay(args.b_sigma_final,
                                              args.b_sigma_scaling,
                                              args.p_sigma, args.T_max)
        elif self.args.eps_rule == "self_reg_min":
            self.args.T_max = np.inf
            self.b_sigma_rule = FollowLeadMin(
                (args.stop_qual_scaling + args.stop_qual_offset), 1)
        elif self.args.eps_rule == "self_reg_avg":
            self.args.T_max = np.inf
            self.b_sigma_rule = FollowLeadAvg(
                (args.stop_qual_scaling + args.stop_qual_offset) / 4, 2, 1)
        elif self.args.eps_rule == "self_reg_exp_avg":
            self.args.T_max = np.inf
            self.b_sigma_rule = ExponentialAverage(
                (args.stop_qual_scaling + args.stop_qual_offset) / 4, 0.9, 1)
        else:
            self.b_sigma_rule = NaiveDecay(self.eps, 0.00005, 1000, 1)

    def setup(self, rank, world_size):
        # BLAS setup
        os.environ['OMP_NUM_THREADS'] = '1'
        os.environ['MKL_NUM_THREADS'] = '1'

        # os.environ["CUDA_VISIBLE_DEVICES"] = "6"
        assert torch.cuda.device_count() == 1
        torch.set_default_tensor_type('torch.FloatTensor')
        # Detect if we have a GPU available
        device = torch.device("cuda:0")
        torch.cuda.set_device(device)

        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = '12356'
        # os.environ['GLOO_SOCKET_IFNAME'] = 'eno1'

        # initialize the process group
        dist.init_process_group("gloo", rank=rank, world_size=world_size)

        # Explicitly set the seed to make sure that models created in the two processes
        # start from the same random weights and biases.
        torch.manual_seed(self.args.seed)

    def cleanup(self):
        dist.destroy_process_group()

    def update_env_data(self, env, dloader, device):
        edges, edge_feat, diff_to_gt, gt_edge_weights, node_labeling, raw, nodes, affinities, gt = \
            next(iter(dloader))
        angles = None
        edges, edge_feat, diff_to_gt, gt_edge_weights, node_labeling, raw, nodes, affinities, gt = \
            edges.squeeze().to(device), edge_feat.squeeze()[:, 0:self.args.n_edge_features].to(
                device), diff_to_gt.squeeze().to(device), \
            gt_edge_weights.squeeze().to(device), node_labeling.squeeze().to(device), raw.squeeze().to(
                device), nodes.squeeze().to(device), \
            affinities.squeeze().numpy(), gt.squeeze()
        env.update_data(edges, edge_feat, diff_to_gt, gt_edge_weights,
                        node_labeling, raw, nodes, angles, affinities, gt)

    # Updates networks
    def _update_networks(self, loss, optimizer, shared_model, writer=None):
        # Zero shared and local grads
        optimizer.zero_grad()
        """
        Calculate gradients for gradient descent on loss functions
        Note that math comments follow the paper, which is formulated for gradient ascent
        """
        loss.backward()
        # Gradient L2 normalisation
        # nn.utils.clip_grad_norm_(shared_model.parameters(), self.args.max_gradient_norm)
        optimizer.step()

        with open(os.path.join(self.save_dir, 'config.yaml')) as info:
            args_dict = yaml.full_load(info)
            if args_dict is not None:
                if 'lr' in args_dict:
                    if self.args.lr != args_dict['lr']:
                        print("lr changed from ", self.args.lr, " to ",
                              args_dict['lr'], " at loss step ",
                              self.global_writer_loss_count.value())
                        self.args.lr = args_dict['lr']
        new_lr = self.args.lr
        if self.args.min_lr != 0 and self.eps <= 0.6:
            # Linearly decay learning rate
            # new_lr = self.args.lr - ((self.args.lr - self.args.min_lr) * (1 - (self.eps * 2))) # (1 - max((self.args.T_max - self.global_count.value()) / self.args.T_max, 1e-32)))
            new_lr = self.args.lr * 10**(-(0.6 - self.eps))

        adjust_learning_rate(optimizer, new_lr)
        if writer is not None:
            writer.add_scalar("loss/learning_rate", new_lr,
                              self.global_writer_loss_count.value())

        # Update shared_average_model
        for shared_param, shared_average_param in zip(
                shared_model.parameters(),
                self.shared_average_model.parameters()):
            shared_average_param.data = self.args.trust_region_decay * shared_average_param.data + (
                1 - self.args.trust_region_decay) * shared_param.data

    # Computes an "efficient trust region" loss (policy head only) based on an existing loss and two distributions
    def _trust_region(self, g, k):
        # Compute dot products of gradients
        k_dot_g = (k * g).sum(0)
        k_dot_k = (k**2).sum(0)
        # Compute trust region update
        trust_factor = ((k_dot_g - self.args.trust_region_threshold) /
                        k_dot_k).clamp(min=0)
        z = g - trust_factor * k
        return z

    def get_action(self, policy_means, p_dis, device, policy='off'):
        if policy == 'off':
            # Use a truncated normal distribution here: https://en.wikipedia.org/wiki/Truncated_normal_distribution
            b_dis = TruncNorm(policy_means, self.b_sigma, 0, 1,
                              self.args.density_eval_range)
            # rho is calculated as the ratio of the densities of the two normal distributions, as described here:
            # https://www.researchgate.net/publication/257406150_On_the_existence_of_a_normal_approximation_to_the_distribution_of_the_ratio_of_two_independent_normal_random_variables
        elif policy == 'on':
            b_dis = p_dis
        else:
            assert False
        # Sample actions; alternatively, consider an unsampled approach by taking the mean
        actions = b_dis.sample()
        # test = torch.stack([torch.from_numpy(actions).float().to(device),
        #              torch.from_numpy(policy_means).float().to(device)]).cpu().numpy()

        # print('sample sigma:', torch.sqrt(((actions - policy_means) ** 2).mean()).item())

        return actions, b_dis

    def fe_extr_warm_start(self, sp_feature_ext, writer=None):
        dataloader = DataLoader(
            MultiDiscSpGraphDset(length=10 * self.args.fe_warmup_iterations),
            batch_size=10,
            shuffle=True,
            pin_memory=True)
        criterion = ContrastiveLoss(delta_var=0.5, delta_dist=1.5)
        optimizer = torch.optim.Adam(sp_feature_ext.parameters())
        for i, (data, gt) in enumerate(dataloader):
            data, gt = data.to(sp_feature_ext.device), gt.to(
                sp_feature_ext.device)
            pred = sp_feature_ext(data)
            loss = criterion(pred, gt)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if writer is not None:
                writer.add_scalar("loss/fe_warm_start", loss.item(),
                                  self.writer_idx_warmup_loss)
                self.writer_idx_warmup_loss += 1

    def agent_forward(self,
                      env,
                      model,
                      action=None,
                      state=None,
                      grad=True,
                      stats_only=False,
                      post_input=False):
        with torch.set_grad_enabled(grad):
            if state is None:
                state = env.state
            inp = [
                obj.float().to(model.module.device)
                for obj in state + [env.raw, env.init_sp_seg]
            ]
            return model(inp,
                         action,
                         sp_indices=env.sp_indices,
                         edge_index=env.edge_ids.to(model.module.device),
                         angles=env.edge_angles,
                         edge_features_1d=env.edge_features.to(
                             model.module.device),
                         stats_only=stats_only,
                         round_n=env.counter,
                         post_input=post_input)

    # Trains model
    def _step(self,
              memory,
              shared_model,
              env,
              optimizer,
              off_policy=True,
              writer=None):
        torch.autograd.set_detect_anomaly(True)
        # starter code from https://github.com/Kaixhin/ACER/
        action_size = memory.memory[0].action.size(0)
        policy_loss, value_loss = 0, 0
        l2_reg = None

        # Calculate n-step returns in forward view, stepping backwards from the last state
        t = len(memory)
        if t <= 1:
            return
        for state, action, reward, b_dis, done in reversed(memory.memory):
            p, q, v, a, p_dis, sampled_action, q_prime = self.agent_forward(
                env, shared_model, action, state)
            average_p, average_action_rvs = self.agent_forward(
                env,
                self.shared_average_model,
                action,
                state,
                grad=False,
                stats_only=True)

            if done and t == len(memory):
                q_ret = torch.zeros_like(env.state[0]).to(
                    shared_model.module.device)
                q_opc = q_ret.clone()
            elif t == len(memory):
                q_ret = v.detach()
                q_opc = q_ret.clone()
                t -= 1
                continue  # here q_ret is for current step, need one more step for estimation

            if off_policy:
                # could also try relation of variances here
                rho = (p_dis.prob(action).detach()) \
                      / (b_dis.prob(action).detach())
                rho_prime = (p_dis.prob(sampled_action).detach()) \
                            / (b_dis.prob(sampled_action).detach())
                c = rho.pow(1 / action_size).clamp(max=1)
                # c = rho.clamp(max=1)
            else:
                rho = torch.ones(1, action_size).to(shared_model.module.device)
                rho_prime = torch.ones(1, action_size).to(
                    shared_model.module.device)

            # Qret ← r_i + γQret
            q_ret = reward + self.args.discount * q_ret
            q_opc = reward + self.args.discount * q_opc

            bias_weight = (1 - (self.args.trace_max / rho_prime)).clamp(min=0)

            # KL divergence k ← ∇θ0∙DKL[π(∙|s_i; θ_a) || π(∙|s_i; θ)]
            k = (p.detach() - average_p) / (self.args.p_sigma**4)
            g = rho.clamp(max=self.args.trace_max) * (
                q_opc - v.detach()) * p_dis.grad_pdf_mu(action).detach()
            if off_policy:
                g = g + bias_weight * (q_prime - v.detach(
                )) * p_dis.grad_pdf_mu(sampled_action).detach()
            # Policy update dθ ← dθ + ∂θ/∂θ∙z*
            z_star = self._trust_region(g, k)
            tr_loss = (z_star * p * self.args.trust_region_weight).mean()

            # policy_loss = policy_loss - tr_loss

            # vanilla policy gradient with importance sampling
            lp = p_dis.log_prob(action)
            policy_loss = policy_loss - (rho.clamp(max=self.args.trace_max) *
                                         lp *
                                         (q.detach() - v.detach())).mean()

            # Value update dθ ← dθ - ∇θ∙1/2∙(Qret - Q(s_i, a_i; θ))^2
            value_loss = value_loss + (-(q_ret - q.detach()) *
                                       q).mean()  # Least squares loss
            value_loss = value_loss + (-(rho.clamp(max=1) *
                                         (q_ret - q.detach()) * v)).mean()

            # Qret ← ρ¯_a_i∙(Qret - Q(s_i, a_i; θ)) + V(s_i; θ)
            q_ret = c * (q_ret - q.detach()) + v.detach()
            q_opc = (q_ret - q.detach()) + v.detach()
            t -= 1

            if self.args.l2_reg_params_weight != 0:
                for W in list(shared_model.parameters()):
                    if l2_reg is None:
                        l2_reg = W.norm(2)
                    else:
                        l2_reg = l2_reg + W.norm(2)
            if l2_reg is None:
                l2_reg = 0

        if writer is not None:
            writer.add_scalar("loss/critic", value_loss.item(),
                              self.global_writer_loss_count.value())
            writer.add_scalar("loss/actor", policy_loss.item(),
                              self.global_writer_loss_count.value())
            self.global_writer_loss_count.increment()

        # Update networks
        loss = (self.args.p_loss_weight * policy_loss +
                self.args.v_loss_weight * value_loss +
                l2_reg * self.args.l2_reg_params_weight) / len(memory)

        self._update_networks(loss, optimizer, shared_model, writer)

    # Acts and trains model
    def train(self, rank, start_time, return_dict):

        device = torch.device("cuda:" + str(rank))
        print('Running on device: ', device)
        torch.cuda.set_device(device)
        torch.set_default_tensor_type(torch.FloatTensor)

        writer = None
        if not self.args.cross_validate_hp:
            writer = SummaryWriter(logdir=os.path.join(self.save_dir, 'logs'))
            # posting parameters
            param_string = ""
            for k, v in vars(self.args).items():
                param_string += ' ' * 10 + k + ': ' + str(v) + '\n'
            writer.add_text("params", param_string)

        self.setup(rank, self.args.num_processes)
        transition = namedtuple(
            'Transition',
            ('state', 'action', 'reward', 'behav_policy_proba', 'done'))
        memory = TransitionData(capacity=self.args.t_max,
                                storage_object=transition)

        env = SpGcnEnv(self.args,
                       device,
                       writer=writer,
                       writer_counter=self.global_writer_quality_count,
                       win_event_counter=self.global_win_event_count,
                       discrete_action_space=False)
        # Create shared network
        model = GcnEdgeAngle1dPQA_dueling_1(self.args.n_raw_channels,
                                            self.args.n_embedding_features,
                                            self.args.n_edge_features,
                                            1,
                                            self.args.exp_steps,
                                            self.args.p_sigma,
                                            device,
                                            self.args.density_eval_range,
                                            writer=writer)
        if self.args.no_fe_extr_optim:
            for param in model.fe_ext.parameters():
                param.requires_grad = False

        model.cuda(device)
        shared_model = DDP(model, device_ids=[model.device])
        dloader = DataLoader(MultiDiscSpGraphDset(no_suppix=False),
                             batch_size=1,
                             shuffle=True,
                             pin_memory=True,
                             num_workers=0)
        # Create optimizer for shared network parameters with shared statistics
        optimizer = torch.optim.Adam(shared_model.parameters(),
                                     lr=self.args.lr,
                                     betas=self.args.Adam_betas,
                                     weight_decay=self.args.Adam_weight_decay)

        if self.args.fe_extr_warmup and rank == 0 and not self.args.test_score_only:
            fe_extr = SpVecsUnet(self.args.n_raw_channels,
                                 self.args.n_embedding_features, device)
            fe_extr.cuda(device)
            self.fe_extr_warm_start(fe_extr, writer=writer)
            shared_model.module.fe_ext.load_state_dict(fe_extr.state_dict())
            if self.args.model_name == "":
                torch.save(shared_model.state_dict(),
                           os.path.join(self.save_dir, 'agent_model'))
            else:
                torch.save(shared_model.state_dict(),
                           os.path.join(self.save_dir, self.args.model_name))

        dist.barrier()

        if self.args.model_name != "":
            shared_model.load_state_dict(
                torch.load(os.path.join(self.save_dir, self.args.model_name)))
        elif self.args.fe_extr_warmup:
            shared_model.load_state_dict(
                torch.load(os.path.join(self.save_dir, 'agent_model')))

        self.shared_average_model.load_state_dict(shared_model.state_dict())

        if not self.args.test_score_only:
            quality = self.args.stop_qual_scaling + self.args.stop_qual_offset
            while self.global_count.value() <= self.args.T_max:
                self.update_env_data(env, dloader, device)
                env.reset()
                state = [env.state[0].clone(), env.state[1].clone()]

                self.b_sigma = self.b_sigma_rule.apply(
                    self.global_count.value(), quality)
                env.stop_quality = self.stop_qual_rule.apply(
                    self.global_count.value(), quality)

                with open(os.path.join(self.save_dir, 'config.yaml')) as info:
                    args_dict = yaml.full_load(info)
                    if args_dict is not None:
                        if 'eps' in args_dict:
                            if self.args.eps != args_dict['eps']:
                                self.eps = args_dict['eps']
                        if 'safe_model' in args_dict:
                            self.args.safe_model = args_dict['safe_model']
                        if 'add_noise' in args_dict:
                            self.args.add_noise = args_dict['add_noise']

                if writer is not None:
                    writer.add_scalar("step/b_variance", self.b_sigma,
                                      env.writer_counter.value())

                if self.args.safe_model:
                    if rank == 0:
                        if self.args.model_name_dest != "":
                            torch.save(
                                shared_model.state_dict(),
                                os.path.join(self.save_dir,
                                             self.args.model_name_dest))
                        else:
                            torch.save(
                                shared_model.state_dict(),
                                os.path.join(self.save_dir, 'agent_model'))

                while not env.done:
                    post_input = (self.global_count.value() % 50 == 0
                                  and env.counter == 0)
                    # Calculate policy and values
                    policy_means, p_dis = self.agent_forward(
                        env,
                        shared_model,
                        grad=False,
                        stats_only=True,
                        post_input=post_input)

                    # Step
                    action, b_rvs = self.get_action(policy_means, p_dis,
                                                    device)
                    state_, reward, quality = env.execute_action(action)

                    if self.args.add_noise:
                        if self.global_count.value(
                        ) > 110 and self.global_count.value() % 5:
                            noise = torch.randn_like(reward) * 0.8
                            reward = reward + noise

                    memory.push(state, action, reward, b_rvs, env.done)

                    # Train the network
                    # self._step(memory, shared_model, env, optimizer, off_policy=True, writer=writer)

                    # reward = self.args.reward_clip and min(max(reward, -1), 1) or reward  # Optionally clamp rewards
                    # done = done or episode_length >= self.args.max_episode_length  # Stop episodes at a max length
                    state = state_

                # Break graph for last values calculated (used for targets, not directly as model outputs)
                self.global_count.increment()

                self._step(memory,
                           shared_model,
                           env,
                           optimizer,
                           off_policy=True,
                           writer=writer)
                memory.clear()
                # while len(memory) > 0:
                #     self._step(memory, shared_model, env, optimizer, off_policy=True, writer=writer)
                #     memory.pop(0)

        dist.barrier()
        if rank == 0:
            if not self.args.cross_validate_hp and not self.args.test_score_only and not self.args.no_save:
                # pass
                if self.args.model_name != "":
                    torch.save(
                        shared_model.state_dict(),
                        os.path.join(self.save_dir, self.args.model_name))
                    print('saved')
                else:
                    torch.save(shared_model.state_dict(),
                               os.path.join(self.save_dir, 'agent_model'))
            if self.args.cross_validate_hp or self.args.test_score_only:
                test_score = 0
                env.writer = None
                for i in range(20):
                    self.update_env_data(env, dloader, device)
                    env.reset()
                    self.b_sigma = self.args.p_sigma
                    env.stop_quality = 40
                    while not env.done:
                        # Calculate policy and values
                        policy_means, p_dis = self.agent_forward(
                            env, shared_model, grad=False, stats_only=True)
                        action, b_rvs = self.get_action(
                            policy_means, p_dis, device)
                        _, _ = env.execute_action(action,
                                                  self.global_count.value())

                    # import matplotlib.pyplot as plt;
                    # plt.imshow(env.get_current_soln());
                    # plt.show()
                    if env.win:
                        test_score += 1
                return_dict['test_score'] = test_score
                if writer is not None:
                    writer.add_text("time_needed", str(time.time() - start_time))
        self.cleanup()
Example #8
0
class AgentAcerTrainer(object):
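    """ACER-style off-policy actor-critic trainer for the SpGcnEnv edge-agglomeration environment.

    Each worker process wraps a GcnEdgeAngle1dPQV model in DistributedDataParallel,
    rolls out episodes into a TransitionData buffer and performs truncated
    importance-sampled policy/value updates (see _step). A slowly updated copy of the
    policy is kept in shared_average_model for the (currently disabled) trust-region term.
    """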
    def __init__(self, args, shared_average_model, global_count,
                 global_writer_loss_count, global_writer_quality_count,
                 global_win_event_count, save_dir):
        super(AgentAcerTrainer, self).__init__()

        self.args = args
        self.shared_average_model = shared_average_model
        self.global_count = global_count
        self.global_writer_loss_count = global_writer_loss_count
        self.global_writer_quality_count = global_writer_quality_count
        self.global_win_event_count = global_win_event_count
        self.writer_idx_warmup_loss = 0
        # Initialise epsilon so the NaiveDecay fallback below starts from a defined value
        self.eps = self.args.init_epsilon
        self.save_dir = save_dir
        if args.stop_qual_rule == 'naive':
            self.stop_qual_rule = NaiveDecay(initial_eps=args.init_stop_qual,
                                             episode_shrinkage=1,
                                             change_after_n_episodes=5)
        elif args.stop_qual_rule == 'gaussian':
            self.stop_qual_rule = GaussianDecay(args.stop_qual_final,
                                                args.stop_qual_scaling,
                                                args.stop_qual_offset,
                                                args.T_max)
        elif args.stop_qual_rule == 'running_average':
            self.stop_qual_rule = RunningAverage(
                args.stop_qual_ra_bw,
                args.stop_qual_scaling + args.stop_qual_offset,
                args.stop_qual_ra_off)
        else:
            self.stop_qual_rule = NaiveDecay(args.init_stop_qual)

        if self.args.eps_rule == "treesearch":
            self.eps_rule = ActionPathTreeNodes()
        elif self.args.eps_rule == "sawtooth":
            self.eps_rule = ExpSawtoothEpsDecay()
        elif self.args.eps_rule == 'gaussian':
            self.eps_rule = GaussianDecay(args.eps_final, args.eps_scaling,
                                          args.eps_offset, args.T_max)
        elif self.args.eps_rule == "self_reg_min":
            self.args.T_max = np.inf
            self.eps_rule = FollowLeadMin(
                (args.stop_qual_scaling + args.stop_qual_offset), 1)
        elif self.args.eps_rule == "self_reg_avg":
            self.args.T_max = np.inf
            self.eps_rule = FollowLeadAvg(
                1.5 * (args.stop_qual_scaling + args.stop_qual_offset), 2, 1)
        elif self.args.eps_rule == "self_reg_exp_avg":
            self.args.T_max = np.inf
            self.eps_rule = ExponentialAverage(
                1.5 * (args.stop_qual_scaling + args.stop_qual_offset), 0.9, 1)
        else:
            self.eps_rule = NaiveDecay(self.eps, 0.00005, 1000, 1)

    def setup(self, rank, world_size):
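        """Initialise the 'gloo' process group and select a CUDA device for this worker.

        A minimal launch sketch, assuming the usual torch.multiprocessing entry point
        (the actual spawning code is not part of this example):

            import torch.multiprocessing as mp
            trainer = AgentAcerTrainer(args, shared_average_model, global_count,
                                       global_writer_loss_count, global_writer_quality_count,
                                       global_win_event_count, save_dir)
            mp.spawn(trainer.train, args=(start_time, return_dict), nprocs=args.num_processes)
        """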
        # BLAS setup
        os.environ['OMP_NUM_THREADS'] = '1'
        os.environ['MKL_NUM_THREADS'] = '1'

        # os.environ["CUDA_VISIBLE_DEVICES"] = "6"
        # assert torch.cuda.device_count() == 1
        torch.set_default_tensor_type('torch.FloatTensor')
        # Detect if we have a GPU available
        device = torch.device("cuda:0")
        torch.cuda.set_device(device)

        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = '12354'
        # os.environ['GLOO_SOCKET_IFNAME'] = 'eno1'

        # initialize the process group
        dist.init_process_group("gloo", rank=rank, world_size=world_size)

        # Explicitly setting seed to make sure that models created in two processes
        # start from same random weights and biases.
        torch.manual_seed(self.args.seed)

    def cleanup(self):
        dist.destroy_process_group()

    def update_env_data(self, env, dloader, device):
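        """Draw the next sample from `dloader`, move its tensors to `device` and push it into the environment."""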
        edges, edge_feat, diff_to_gt, gt_edge_weights, node_labeling, raw, nodes, affinities, gt = \
            next(iter(dloader))
        angles = None
        edges = edges.squeeze().to(device)
        edge_feat = edge_feat.squeeze()[:, 0:self.args.n_edge_features].to(device)
        diff_to_gt = diff_to_gt.squeeze().to(device)
        gt_edge_weights = gt_edge_weights.squeeze().to(device)
        node_labeling = node_labeling.squeeze().to(device)
        raw = raw.squeeze().to(device)
        nodes = nodes.squeeze().to(device)
        affinities = affinities.squeeze().numpy()
        gt = gt.squeeze()
        env.update_data(edges, edge_feat, diff_to_gt, gt_edge_weights,
                        node_labeling, raw, nodes, angles, affinities, gt)

    # Updates networks
    def _update_networks(self, loss, optimizer, shared_model, writer=None):
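        """Backpropagate `loss`, clip gradients, step the optimizer and update the average policy.

        The learning rate can be changed at runtime via config.yaml and is additionally
        decayed (if args.min_lr is set) once epsilon drops to 0.6 or below;
        shared_average_model is updated as an exponential moving average controlled by
        args.trust_region_decay.
        """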
        # Zero shared and local grads
        optimizer.zero_grad()
        """
        Calculate gradients for gradient descent on loss functions
        Note that math comments follow the paper, which is formulated for gradient ascent
        """
        loss.backward()
        # Gradient L2 normalisation
        nn.utils.clip_grad_norm_(shared_model.parameters(),
                                 self.args.max_gradient_norm)
        optimizer.step()
        # Allow hot-reloading the learning rate from config.yaml during training
        with open(os.path.join(self.save_dir, 'config.yaml')) as info:
            args_dict = yaml.full_load(info)
            if args_dict is not None and 'lr' in args_dict:
                if self.args.lr != args_dict['lr']:
                    print("lr changed from ", self.args.lr, " to ",
                          args_dict['lr'], " at loss step ",
                          self.global_writer_loss_count.value())
                self.args.lr = args_dict['lr']
        new_lr = self.args.lr
        if self.args.min_lr != 0 and self.eps <= 0.6:
            # Linearly decay learning rate
            # new_lr = self.args.lr - ((self.args.lr - self.args.min_lr) * (1 - (self.eps * 2))) # (1 - max((self.args.T_max - self.global_count.value()) / self.args.T_max, 1e-32)))
            new_lr = self.args.lr * 10**(-(0.6 - self.eps))

        adjust_learning_rate(optimizer, new_lr)
        if writer is not None:
            writer.add_scalar("loss/learning_rate", new_lr,
                              self.global_writer_loss_count.value())

        # Update shared_average_model
        for shared_param, shared_average_param in zip(
                shared_model.parameters(),
                self.shared_average_model.parameters()):
            shared_average_param.data = self.args.trust_region_decay * shared_average_param.data + (
                1 - self.args.trust_region_decay) * shared_param.data

    # Computes an "efficient trust region" loss (policy head only) based on an existing loss and two distributions
    def _trust_region(self, g, k):
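        """Project the policy gradient g onto the trust region defined by the KL gradient k.

        Implements z* = g - max(0, (k.g - delta) / ||k||^2) * k with
        delta = args.trust_region_threshold (cf. the ACER paper); only used by the
        currently commented-out trust-region branch in _step.
        """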
        # Compute dot products of gradients
        k_dot_g = (k * g).sum(0)
        k_dot_k = (k**2).sum(0)
        # Compute trust region update
        trust_factor = ((k_dot_g - self.args.trust_region_threshold) /
                        k_dot_k).clamp(min=0)
        z = g - trust_factor * k
        return z

    def get_action(self, action_probs, q, v, policy, waff_dis, device):
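        """Select actions for every edge according to the chosen behaviour policy.

        'off_sampled' samples from an eps-smoothed copy of the current policy,
        'off_uniform' replaces an eps-fraction of the greedy actions (indices drawn via
        waff_dis) with uniform random actions, 'on' acts greedily w.r.t. the policy and
        'q_val' greedily w.r.t. the Q-values. Returns the actions together with the
        behaviour probabilities used for importance sampling in _step.
        """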
        if policy == 'off_sampled':
            behav_probs = action_probs.detach() + self.eps * (
                1 / self.args.n_actions - action_probs.detach())
            actions = torch.multinomial(behav_probs, 1).squeeze()
        elif policy == 'off_uniform':
            randm_draws = int(self.eps * len(action_probs))
            if randm_draws > 0:
                actions = action_probs.max(-1)[1].squeeze()
                # randm_indices = torch.multinomial(torch.ones(len(action_probs)) / len(action_probs), randm_draws)
                randm_indices = torch.multinomial(waff_dis, randm_draws)
                actions[randm_indices] = torch.randint(
                    0, self.args.n_actions, (randm_draws, )).to(device)
                behav_probs = action_probs.detach()
                behav_probs[randm_indices] = torch.full(
                    (self.args.n_actions, ), 1 / self.args.n_actions).to(device)
            else:
                actions = action_probs.max(-1)[1].squeeze()
                behav_probs = action_probs.detach()
        elif policy == 'on':
            actions = action_probs.max(-1)[1].squeeze()
            behav_probs = action_probs.detach()
        elif policy == 'q_val':
            actions = q.max(-1)[1].squeeze()
            behav_probs = action_probs.detach()
        else:
            raise ValueError("unknown behaviour policy: " + str(policy))

        # log_probs = torch.log(sel_behav_probs)
        # entropy = - (behav_probs * torch.log(behav_probs)).sum()
        return actions, behav_probs

    def fe_extr_warm_start(self, sp_feature_ext, writer=None):
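        """Pre-train the superpixel feature extractor with a contrastive embedding loss
        on MultiDiscSpGraphDset before the actual RL training starts."""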
        dataloader = DataLoader(
            MultiDiscSpGraphDset(length=self.args.fe_warmup_iterations * 10),
            batch_size=10,
            shuffle=True,
            pin_memory=True)
        criterion = ContrastiveLoss(delta_var=0.5, delta_dist=1.5)
        optimizer = torch.optim.Adam(sp_feature_ext.parameters())
        for i, (data, gt) in enumerate(dataloader):
            data, gt = data.to(sp_feature_ext.device), gt.to(
                sp_feature_ext.device)
            pred = sp_feature_ext(data)
            loss = criterion(pred, gt)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if writer is not None:
                writer.add_scalar("loss/fe_warm_start", loss.item(),
                                  self.writer_idx_warmup_loss)
                self.writer_idx_warmup_loss += 1

    def agent_forward(self,
                      env,
                      model,
                      state=None,
                      grad=True,
                      post_input=False):
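        """Run the shared model on the current (or a supplied) environment state and return its outputs."""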
        with torch.set_grad_enabled(grad):
            if state is None:
                state = env.state
            inp = [
                obj.float().to(model.module.device)
                for obj in state + [env.raw, env.init_sp_seg]
            ]
            return model(inp,
                         sp_indices=env.sp_indices,
                         edge_index=env.edge_ids.to(model.module.device),
                         angles=env.edge_angles,
                         edge_features_1d=env.edge_features.to(
                             model.module.device),
                         round_n=env.counter,
                         post_input=post_input)

    # Trains model
    def _step(self,
              memory,
              shared_model,
              env,
              optimizer,
              loss_weight,
              off_policy=True,
              writer=None):
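        """One ACER update over a finished episode stored in `memory`.

        Walks backwards through the transitions and, following the starter code from
        https://github.com/Kaixhin/ACER/, builds the off-policy return
        Qret <- r + gamma * Qret, the truncated importance weights min(c, rho) and the
        bias-correction term, accumulates the loss_weight-weighted policy and value
        losses plus the model's side loss, and finally calls _update_networks.
        """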
        # starter code from https://github.com/Kaixhin/ACER/
        action_size = memory.memory[0].action.size(0)
        policy_loss, value_loss, cum_side_loss = 0, 0, 0
        l2_reg = None

        # Calculate n-step returns in forward view, stepping backwards from the last state
        t = len(memory)
        if t <= 1:
            return
        for state, action, reward, behav_policy_proba, done in reversed(
                memory.memory):
            policy_proba, q, v, side_loss = self.agent_forward(
                env, shared_model, state)
            average_policy_proba, _, _, _ = self.agent_forward(
                env, self.shared_average_model, state, grad=False)
            tr_loss, p_loss = 0, 0
            if done and t == len(memory):
                q_ret_t = torch.zeros_like(env.state[0]).to(
                    shared_model.module.device)
            elif t == len(memory):
                q_ret_t = v.detach()
                t -= 1
                continue  # here q_ret is for current step, need one more step for estimation
            # Importance sampling weights ρ ← π(∙|s_i) / µ(∙|s_i); 1 for on-policy
            if off_policy:
                rho = policy_proba.detach() / behav_policy_proba.detach()
            else:
                rho = torch.ones(1, action_size)

            # Qret ← r_i + γQret
            q_ret_t = reward + self.args.discount * q_ret_t
            # Advantage A ← Qret - V(s_i; θ)
            adv = q.detach() - v.unsqueeze(-1).detach()
            adv_ret_t = q_ret_t.detach() - v.detach()

            policy_proba_t = policy_proba.gather(
                -1, action.unsqueeze(-1)).squeeze()
            rho_t = rho.gather(-1, action.unsqueeze(-1)).squeeze()
            bias_weight = (1 - self.args.trace_max / rho).clamp(min=0)

            # if not self.args.trust_region:
            # g ← min(c, ρ_a_i)∙∇θ∙log(π(a_i|s_i; θ))∙A
            p_loss = rho_t.clamp(
                max=self.args.trace_max) * policy_proba_t.log() * adv_ret_t
            # Off-policy bias correction
            if off_policy:
                # g ← g + Σ_a [1 - c/ρ_a]_+∙π(a|s_i; θ)∙∇θ∙log(π(a|s_i; θ))∙(Q(s_i, a; θ) - V(s_i; θ)
                p_loss = p_loss + ((bias_weight * policy_proba.log() * adv) *
                                   policy_proba.detach()).sum(-1)
            # if self.args.trust_region:
            #     # KL divergence k ← ∇θ0∙DKL[π(∙|s_i; θ_a) || π(∙|s_i; θ)]
            #     k = (- average_policy_proba / (policy_proba.detach() + 1e-10)).sum(-1)
            #     g = rho_t.clamp(max=self.args.trace_max) * adv_ret_t / policy_proba_t.detach()
            #     if off_policy:
            #         g = g + (bias_weight * adv).sum(-1)
            #     # Policy update dθ ← dθ + ∂θ/∂θ∙z*
            #     z_star = self._trust_region(g, k)
            #     tr_loss = (z_star * policy_proba_t * self.args.trust_region_weight).mean()
            #
            # # Entropy regularisation dθ ← dθ + β∙∇θH(π(s_i; θ))
            # entropy_loss = (-(self.args.entropy_weight * (policy_proba.log() * policy_proba).sum(-1))).mean()
            # policy_loss = policy_loss - tr_loss - p_loss - entropy_loss

            policy_loss = policy_loss - (loss_weight * p_loss).sum()

            # Value update dθ ← dθ - ∇θ∙1/2∙(Qret - Q(s_i, a_i; θ))^2
            q_t = q.gather(-1, action.unsqueeze(-1)).squeeze()
            value_loss = value_loss + (loss_weight * (
                (q_ret_t - q_t)**2 / 2)).sum()  # Least squares loss

            cum_side_loss = cum_side_loss + side_loss

            # Qret ← ρ¯_a_i∙(Qret - Q(s_i, a_i; θ)) + V(s_i; θ)
            q_ret_t = rho_t.clamp(max=self.args.trace_max) * (
                q_ret_t - q_t.detach()) + v.detach()
            t -= 1

            if self.args.l2_reg_params_weight != 0:
                for W in list(shared_model.parameters()):
                    if l2_reg is None:
                        l2_reg = W.norm(2)
                    else:
                        l2_reg = l2_reg + W.norm(2)
            if l2_reg is None:
                l2_reg = 0

        if writer is not None:
            writer.add_scalar("loss/critic", value_loss.item(),
                              self.global_writer_loss_count.value())
            writer.add_scalar("loss/actor", policy_loss.item(),
                              self.global_writer_loss_count.value())
            self.global_writer_loss_count.increment()

        # Update networks
        loss = (self.args.p_loss_weight * policy_loss
                + self.args.v_loss_weight * value_loss
                + l2_reg * self.args.l2_reg_params_weight) \
               / len(memory) + cum_side_loss * 5
        torch.autograd.set_detect_anomaly(True)  # debugging aid; slows down every backward pass
        self._update_networks(loss, optimizer, shared_model, writer)

    # Acts and trains model
    def train(self, rank, start_time, return_dict):
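        """Main per-process training loop: builds env, model and optimizer, rolls out
        episodes with an epsilon-controlled behaviour policy, performs one ACER update
        per episode and, on rank 0, saves the model and optionally computes a test score."""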

        device = torch.device("cuda:" + str(rank))
        print('Running on device: ', device)
        torch.cuda.set_device(device)
        torch.set_default_tensor_type(torch.FloatTensor)

        writer = None
        if not self.args.cross_validate_hp:
            writer = SummaryWriter(logdir=os.path.join(self.save_dir, 'logs'))
            # posting parameters
            param_string = ""
            for k, v in vars(self.args).items():
                param_string += ' ' * 10 + k + ': ' + str(v) + '\n'
            writer.add_text("params", param_string)

        self.setup(rank, self.args.num_processes)
        transition = namedtuple(
            'Transition',
            ('state', 'action', 'reward', 'behav_policy_proba', 'done'))
        memory = TransitionData(capacity=self.args.t_max,
                                storage_object=transition)

        env = SpGcnEnv(self.args,
                       device,
                       writer=writer,
                       writer_counter=self.global_writer_quality_count,
                       win_event_counter=self.global_win_event_count)
        # Create shared network
        model = GcnEdgeAngle1dPQV(self.args.n_raw_channels,
                                  self.args.n_embedding_features,
                                  self.args.n_edge_features,
                                  self.args.n_actions,
                                  device,
                                  writer=writer)

        if self.args.no_fe_extr_optim:
            for param in model.fe_ext.parameters():
                param.requires_grad = False

        model.cuda(device)
        shared_model = DDP(model, device_ids=[model.device])
        dloader = DataLoader(MultiDiscSpGraphDset(no_suppix=False),
                             batch_size=1,
                             shuffle=True,
                             pin_memory=True,
                             num_workers=0)
        # Create optimizer for shared network parameters with shared statistics
        # optimizer = CstmAdam(shared_model.parameters(), lr=self.args.lr, betas=self.args.Adam_betas,
        #                      weight_decay=self.args.Adam_weight_decay)
        optimizer = Adam(shared_model.parameters(), lr=self.args.lr)

        if self.args.fe_extr_warmup and rank == 0 and not self.args.test_score_only:
            fe_extr = SpVecsUnet(self.args.n_raw_channels,
                                 self.args.n_embedding_features, device)
            fe_extr.cuda(device)
            self.fe_extr_warm_start(fe_extr, writer=writer)
            shared_model.module.fe_ext.load_state_dict(fe_extr.state_dict())
            if self.args.model_name == "":
                torch.save(shared_model.state_dict(),
                           os.path.join(self.save_dir, 'agent_model'))
            else:
                torch.save(shared_model.state_dict(),
                           os.path.join(self.save_dir, self.args.model_name))

        dist.barrier()

        if self.args.model_name != "":
            shared_model.load_state_dict(
                torch.load(os.path.join(self.save_dir, self.args.model_name)))
        elif self.args.fe_extr_warmup:
            print('loaded fe extractor')
            shared_model.load_state_dict(
                torch.load(os.path.join(self.save_dir, 'agent_model')))

        self.shared_average_model.load_state_dict(shared_model.state_dict())

        if not self.args.test_score_only:
            quality = self.args.stop_qual_scaling + self.args.stop_qual_offset
            while self.global_count.value() <= self.args.T_max:
                self.update_env_data(env, dloader, device)
                # waff_dis = torch.softmax(env.edge_features[:, 0].squeeze() + 1e-30, dim=0)
                # waff_dis = torch.softmax(env.gt_edge_weights + 0.5, dim=0)
                waff_dis = torch.softmax(torch.ones_like(env.gt_edge_weights),
                                         dim=0)
                loss_weight = torch.softmax(env.gt_edge_weights + 1, dim=0)
                env.reset()
                state = [env.state[0].clone(), env.state[1].clone()]

                self.eps = self.eps_rule.apply(self.global_count.value(),
                                               quality)
                env.stop_quality = self.stop_qual_rule.apply(
                    self.global_count.value(), quality)

                with open(os.path.join(self.save_dir, 'config.yaml')) as info:
                    args_dict = yaml.full_load(info)
                    if args_dict is not None:
                        if 'eps' in args_dict:
                            if self.args.eps != args_dict['eps']:
                                self.eps = args_dict['eps']
                        if 'safe_model' in args_dict:
                            self.args.safe_model = args_dict['safe_model']
                        if 'add_noise' in args_dict:
                            self.args.add_noise = args_dict['add_noise']

                if writer is not None:
                    writer.add_scalar("step/epsilon", self.eps,
                                      env.writer_counter.value())

                if self.args.safe_model:
                    if rank == 0:
                        if self.args.model_name_dest != "":
                            torch.save(
                                shared_model.state_dict(),
                                os.path.join(self.save_dir,
                                             self.args.model_name_dest))
                        else:
                            torch.save(
                                shared_model.state_dict(),
                                os.path.join(self.save_dir, 'agent_model'))

                while not env.done:
                    # Post intermediate results every 1000th episode (at the first step only)
                    post_input = (self.global_count.value() + 1) % 1000 == 0 and env.counter == 0
                    # Calculate policy and values
                    policy_proba, q, v, _ = self.agent_forward(
                        env, shared_model, grad=False, post_input=post_input)

                    # Step
                    action, behav_policy_proba = self.get_action(
                        policy_proba, q, v, 'off_uniform', waff_dis, device)
                    state_, reward, quality = env.execute_action(action)

                    if self.args.add_noise:
                        if self.global_count.value(
                        ) > 110 and self.global_count.value() % 5:
                            noise = torch.randn_like(reward) * 0.8
                            reward = reward + noise

                    memory.push(state, action, reward, behav_policy_proba,
                                env.done)

                    # Train the network
                    # self._step(memory, shared_model, env, optimizer, loss_weight, off_policy=True, writer=writer)

                    # reward = self.args.reward_clip and min(max(reward, -1), 1) or reward  # Optionally clamp rewards
                    # done = done or episode_length >= self.args.max_episode_length  # Stop episodes at a max length
                    state = state_

                self.global_count.increment()
                if self.args.eps_rule == "self_regulating" and quality <= 0:
                    break

                if len(memory) > 0:
                    self._step(memory,
                               shared_model,
                               env,
                               optimizer,
                               loss_weight,
                               off_policy=True,
                               writer=writer)
                    memory.clear()

        dist.barrier()
        if rank == 0:
            if not self.args.cross_validate_hp and not self.args.test_score_only and not self.args.no_save:
                # pass
                if self.args.model_name != "":
                    torch.save(
                        shared_model.state_dict(),
                        os.path.join(self.save_dir, self.args.model_name))
                    print('saved')
                else:
                    torch.save(shared_model.state_dict(),
                               os.path.join(self.save_dir, 'agent_model'))
            if self.args.cross_validate_hp or self.args.test_score_only:
                test_score = 0
                env.writer = None
                for i in range(20):
                    self.update_env_data(env, dloader, device)
                    env.reset()
                    self.eps = 0
                    env.stop_quality = 40
                    while not env.done:
                        # Calculate policy and values
                        policy_proba, q, v, _ = self.agent_forward(
                            env, shared_model, grad=False)
                        # eps == 0 here, so 'off_uniform' takes no random draws and waff_dis is not used
                        action, behav_policy_proba = self.get_action(
                            policy_proba,
                            q,
                            v,
                            policy='off_uniform',
                            waff_dis=None,
                            device=device)
                        _, _ = env.execute_action(action,
                                                  self.global_count.value())

                    # import matplotlib.pyplot as plt;
                    # plt.imshow(env.get_current_soln());
                    # plt.show()
                    if env.win:
                        test_score += 1
                return_dict['test_score'] = test_score
                if writer is not None:
                    writer.add_text("time_needed", str(time.time() - start_time))

        self.cleanup()