Example #1
class QlAgent1(object):
    def __init__(self,
                 gamma=1,
                 eps=0.9,
                 eps_min=0.000001,
                 replace_cnt=2,
                 mem_size=10):
        super(QlAgent1, self).__init__()
        self.gamma = gamma
        self.init_eps = eps
        self.eps = self.init_eps
        self.eps_min = eps_min

        self.mem = TransitionData(capacity=mem_size)
        self.learn_steps = 0
        self.steps = 0
        self.replace_cnt = replace_cnt
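
The fields above only record the exploration settings (init_eps, eps, eps_min); the snippet does not show how eps is actually decayed. Below is a minimal sketch of one possible schedule, purely for illustration (the decay rule and the decay_eps helper are assumptions, not part of the project):

def decay_eps(eps, eps_min, decay=0.999):
    # hypothetical multiplicative decay toward the lower bound eps_min
    return max(eps * decay, eps_min)


eps = 0.9
for _ in range(5):
    eps = decay_eps(eps, eps_min=1e-6)
print(round(eps, 6))  # ~0.895509
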
Example #2
class QlAgent1(object):
    # different mem management
    def __init__(self,
                 gamma=1,
                 eps=0.9,
                 eps_min=0.000001,
                 replace_cnt=2,
                 mem_size=10):
        super(QlAgent1, self).__init__()
        self.gamma = gamma
        self.init_eps = eps
        self.eps = self.init_eps
        self.eps_min = eps_min

        self.mem = TransitionData(capacity=mem_size)
        self.learn_steps = 0
        self.steps = 0
        self.replace_cnt = replace_cnt

    def reset_eps(self, eps):
        self.eps = eps

    def safe_models(self, directory):
        # stub: model saving is not implemented for this agent
        return

    def load_model(self, directory):
        # stub: model loading is not implemented for this agent
        return

    def store_transit(self, state, action, reward, next_state_, time,
                      behav_probs, terminal):
        # store a single transition in the replay memory
        self.mem.push(state, action, reward, next_state_, time, behav_probs,
                      terminal)

    def get_action(self, state):
        # stub: no action selection implemented here
        return

    def learn(self, batch_size):
        # stub: no learning step implemented here
        pass
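
QlAgent1 above is only a skeleton. The following sketch is an addition for illustration: it exercises the class with a stand-in TransitionData that mimics the push/capacity interface used in these examples, assuming both live in the same module so the constructor picks up the stand-in.

from collections import deque


class TransitionData:
    # stand-in replay buffer with the same constructor/push interface
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def push(self, *transition):
        self.buffer.append(transition)

    def __len__(self):
        return len(self.buffer)


agent = QlAgent1(gamma=0.99, eps=0.9, eps_min=1e-6, mem_size=10)
agent.store_transit(state=None, action=0, reward=1.0, next_state_=None,
                    time=0, behav_probs=[1.0], terminal=True)
print(len(agent.mem))  # 1
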
Example #3
    def train(self, rank, start_time, return_dict):
        device = torch.device("cuda:" + str(rank))
        print('Running on device: ', device)
        torch.cuda.set_device(device)
        torch.set_default_tensor_type(torch.FloatTensor)

        writer = None
        if not self.args.cross_validate_hp:
            writer = SummaryWriter(logdir=os.path.join(self.save_dir, 'logs'))
            # posting parameters
            param_string = ""
            for k, v in vars(self.args).items():
                param_string += ' ' * 10 + k + ': ' + str(v) + '\n'
            writer.add_text("params", param_string)

        self.setup(rank, self.args.num_processes)

        transition = namedtuple('Transition',
                                ('state', 'action', 'reward', 'state_',
                                 'behav_policy_proba', 'time', 'terminal'))
        memory = TransitionData(capacity=self.args.t_max,
                                storage_object=transition)

        env = SpGcnEnv(self.args,
                       device,
                       writer=writer,
                       writer_counter=self.global_writer_quality_count,
                       win_event_counter=self.global_win_event_count)
        dloader = DataLoader(MultiDiscSpGraphDset(no_suppix=False),
                             batch_size=1,
                             shuffle=True,
                             pin_memory=True,
                             num_workers=0)
        # Create shared network
        model = GcnEdgeAngle1dPQV(self.args.n_raw_channels,
                                  self.args.n_embedding_features,
                                  self.args.n_edge_features,
                                  self.args.n_actions, device)
        model.cuda(device)
        shared_model = DDP(model, device_ids=[model.device])
        # Create optimizer for shared network parameters with shared statistics
        optimizer = CstmAdam(shared_model.parameters(),
                             lr=self.args.lr,
                             betas=self.args.Adam_betas,
                             weight_decay=self.args.Adam_weight_decay)

        if self.args.fe_extr_warmup and rank == 0:
            fe_extr = SpVecsUnet(self.args.n_raw_channels,
                                 self.args.n_embedding_features, device)
            fe_extr.cuda(device)
            self.fe_extr_warm_start(fe_extr, writer=writer)
            shared_model.module.fe_ext.load_state_dict(fe_extr.state_dict())
            if self.args.model_name == "":
                torch.save(fe_extr.state_dict(),
                           os.path.join(self.save_dir, 'agent_model'))
            else:
                torch.save(shared_model.state_dict(),
                           os.path.join(self.save_dir, self.args.model_name))
        dist.barrier()
        if self.args.model_name != "":
            shared_model.load_state_dict(
                torch.load(os.path.join(self.save_dir, self.args.model_name)))
        elif self.args.fe_extr_warmup:
            print('loaded fe extractor')
            shared_model.load_state_dict(
                torch.load(os.path.join(self.save_dir, 'agent_model')))

        self.shared_damped_model.load_state_dict(shared_model.state_dict())
        env.done = True  # Start new episode
        while self.global_count.value() <= self.args.T_max:
            if env.done:
                (edges, edge_feat, diff_to_gt, gt_edge_weights, node_labeling,
                 raw, nodes, angles, affinities, gt) = next(iter(dloader))
                edges = edges.squeeze().to(device)
                edge_feat = edge_feat.squeeze()[:, 0:self.args.n_edge_features].to(device)
                diff_to_gt = diff_to_gt.squeeze().to(device)
                gt_edge_weights = gt_edge_weights.squeeze().to(device)
                node_labeling = node_labeling.squeeze().to(device)
                raw = raw.squeeze().to(device)
                nodes = nodes.squeeze().to(device)
                angles = angles.squeeze().to(device)
                affinities = affinities.squeeze().numpy()
                gt = gt.squeeze()
                env.update_data(edges, edge_feat, diff_to_gt, gt_edge_weights,
                                node_labeling, raw, nodes, angles, affinities,
                                gt)
                env.reset()
                state = [env.state[0].clone(), env.state[1].clone()]
                episode_length = 0

                self.eps = self.eps_rule.apply(self.global_count.value())
                env.stop_quality = self.stop_qual_rule.apply(
                    self.global_count.value())
                if writer is not None:
                    writer.add_scalar("step/epsilon", self.eps,
                                      env.writer_counter.value())

            while not env.done:
                # Calculate policy and values
                policy_proba, q, v = self.agent_forward(env,
                                                        shared_model,
                                                        grad=False)
                # average_policy_proba, _, _ = self.agent_forward(env, self.shared_average_model)
                # q_ret = v.detach()

                # Sample action
                # action = torch.multinomial(policy, 1)[0, 0]

                # Step
                action, behav_policy_proba = self.get_action(
                    policy_proba, q, v, policy='off_uniform', device=device)
                state_, reward = env.execute_action(action,
                                                    self.global_count.value())

                memory.push(state, action,
                            reward.to(shared_model.module.device), state_,
                            behav_policy_proba, episode_length, env.done)

                # Train the network
                self._step(memory,
                           shared_model,
                           env,
                           optimizer,
                           off_policy=True,
                           writer=writer)

                # reward = self.args.reward_clip and min(max(reward, -1), 1) or reward  # Optionally clamp rewards
                # done = done or episode_length >= self.args.max_episode_length  # Stop episodes at a max length
                episode_length += 1  # Increase episode counter
                state = state_

            # Break graph for last values calculated (used for targets, not directly as model outputs)
            self.global_count.increment()
            # Qret = 0 for terminal s

            while len(memory) > 0:
                self._step(memory,
                           shared_model,
                           env,
                           optimizer,
                           off_policy=True,
                           writer=writer)
                memory.pop(0)

        dist.barrier()
        if rank == 0:
            if not self.args.cross_validate_hp:
                if self.args.model_name != "":
                    torch.save(
                        shared_model.state_dict(),
                        os.path.join(self.save_dir, self.args.model_name))
                else:
                    torch.save(shared_model.state_dict(),
                               os.path.join(self.save_dir, 'agent_model'))
            else:
                test_score = 0
                env.writer = None
                for i in range(20):
                    self.update_env_data(env, dloader, device)
                    env.reset()
                    self.eps = 0
                    while not env.done:
                        # Calculate policy and values
                        policy_proba, q, v = self.agent_forward(env,
                                                                shared_model,
                                                                grad=False)
                        action, behav_policy_proba = self.get_action(
                            policy_proba,
                            q,
                            v,
                            policy='off_uniform',
                            device=device)
                        _, _ = env.execute_action(action,
                                                  self.global_count.value())
                    if env.win:
                        test_score += 1
                return_dict['test_score'] = test_score
                if writer is not None:  # writer is None when cross-validating
                    writer.add_text("time_needed", str(time.time() - start_time))
        self.cleanup()
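
Both train() variants rely on the same multi-process idiom: rank 0 alone produces the warm-started weights, every process waits at dist.barrier(), and then all ranks load the identical checkpoint. Below is a stripped-down sketch of that idiom with a stand-in model and file name; it assumes the default process group is already initialised (as self.setup presumably does).

import os

import torch
import torch.distributed as dist
import torch.nn as nn


def warm_start_and_sync(rank, save_dir):
    model = nn.Linear(8, 2)  # stand-in for the shared network
    ckpt = os.path.join(save_dir, 'warm_start.pt')
    if rank == 0:
        # only rank 0 runs the (expensive) warm start and writes the checkpoint
        torch.save(model.state_dict(), ckpt)
    dist.barrier()  # everyone waits until the checkpoint exists
    # all ranks start from identical weights
    model.load_state_dict(torch.load(ckpt))
    return model
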
Example #4
    def train(self, rank, start_time, return_dict):

        device = torch.device("cuda:" + str(rank))
        print('Running on device: ', device)
        torch.cuda.set_device(device)
        torch.set_default_tensor_type(torch.FloatTensor)

        writer = None
        if not self.args.cross_validate_hp:
            writer = SummaryWriter(logdir=os.path.join(self.save_dir, 'logs'))
            # posting parameters
            param_string = ""
            for k, v in vars(self.args).items():
                param_string += ' ' * 10 + k + ': ' + str(v) + '\n'
            writer.add_text("params", param_string)

        self.setup(rank, self.args.num_processes)
        transition = namedtuple(
            'Transition',
            ('state', 'action', 'reward', 'behav_policy_proba', 'done'))
        memory = TransitionData(capacity=self.args.t_max,
                                storage_object=transition)

        env = SpGcnEnv(self.args,
                       device,
                       writer=writer,
                       writer_counter=self.global_writer_quality_count,
                       win_event_counter=self.global_win_event_count,
                       discrete_action_space=False)
        # Create shared network
        model = GcnEdgeAngle1dPQA_dueling_1(self.args.n_raw_channels,
                                            self.args.n_embedding_features,
                                            self.args.n_edge_features,
                                            1,
                                            self.args.exp_steps,
                                            self.args.p_sigma,
                                            device,
                                            self.args.density_eval_range,
                                            writer=writer)
        if self.args.no_fe_extr_optim:
            for param in model.fe_ext.parameters():
                param.requires_grad = False

        model.cuda(device)
        shared_model = DDP(model, device_ids=[model.device])
        dloader = DataLoader(MultiDiscSpGraphDset(no_suppix=False),
                             batch_size=1,
                             shuffle=True,
                             pin_memory=True,
                             num_workers=0)
        # Create optimizer for shared network parameters with shared statistics
        optimizer = torch.optim.Adam(shared_model.parameters(),
                                     lr=self.args.lr,
                                     betas=self.args.Adam_betas,
                                     weight_decay=self.args.Adam_weight_decay)

        if self.args.fe_extr_warmup and rank == 0 and not self.args.test_score_only:
            fe_extr = SpVecsUnet(self.args.n_raw_channels,
                                 self.args.n_embedding_features, device)
            fe_extr.cuda(device)
            self.fe_extr_warm_start(fe_extr, writer=writer)
            shared_model.module.fe_ext.load_state_dict(fe_extr.state_dict())
            if self.args.model_name == "":
                torch.save(shared_model.state_dict(),
                           os.path.join(self.save_dir, 'agent_model'))
            else:
                torch.save(shared_model.state_dict(),
                           os.path.join(self.save_dir, self.args.model_name))

        dist.barrier()

        if self.args.model_name != "":
            shared_model.load_state_dict(
                torch.load(os.path.join(self.save_dir, self.args.model_name)))
        elif self.args.fe_extr_warmup:
            shared_model.load_state_dict(
                torch.load(os.path.join(self.save_dir, 'agent_model')))

        self.shared_average_model.load_state_dict(shared_model.state_dict())

        if not self.args.test_score_only:
            quality = self.args.stop_qual_scaling + self.args.stop_qual_offset
            while self.global_count.value() <= self.args.T_max:
                self.update_env_data(env, dloader, device)
                env.reset()
                state = [env.state[0].clone(), env.state[1].clone()]

                self.b_sigma = self.b_sigma_rule.apply(
                    self.global_count.value(), quality)
                env.stop_quality = self.stop_qual_rule.apply(
                    self.global_count.value(), quality)

                # hot-reload selected hyperparameters from config.yaml so they
                # can be changed while training is running
                with open(os.path.join(self.save_dir, 'config.yaml')) as info:
                    args_dict = yaml.full_load(info)
                    if args_dict is not None:
                        if 'eps' in args_dict:
                            if self.args.eps != args_dict['eps']:
                                self.eps = args_dict['eps']
                        if 'safe_model' in args_dict:
                            self.args.safe_model = args_dict['safe_model']
                        if 'add_noise' in args_dict:
                            self.args.add_noise = args_dict['add_noise']

                if writer is not None:
                    writer.add_scalar("step/b_variance", self.b_sigma,
                                      env.writer_counter.value())

                if self.args.safe_model:
                    if rank == 0:
                        if self.args.model_name_dest != "":
                            torch.save(
                                shared_model.state_dict(),
                                os.path.join(self.save_dir,
                                             self.args.model_name_dest))
                        else:
                            torch.save(
                                shared_model.state_dict(),
                                os.path.join(self.save_dir, 'agent_model'))

                while not env.done:
                    # post inputs at episode start on all but every 50th global step
                    post_input = bool(self.global_count.value() % 50) and env.counter == 0
                    # Calculate policy and values
                    policy_means, p_dis = self.agent_forward(
                        env,
                        shared_model,
                        grad=False,
                        stats_only=True,
                        post_input=post_input)

                    # Step
                    action, b_rvs = self.get_action(policy_means, p_dis,
                                                    device)
                    state_, reward, quality = env.execute_action(action)

                    if self.args.add_noise:
                        # after step 110, perturb the reward with Gaussian noise
                        # on all but every 5th global step
                        if self.global_count.value() > 110 and self.global_count.value() % 5:
                            noise = torch.randn_like(reward) * 0.8
                            reward = reward + noise

                    memory.push(state, action, reward, b_rvs, env.done)

                    # Train the network
                    # self._step(memory, shared_model, env, optimizer, off_policy=True, writer=writer)

                    # reward = self.args.reward_clip and min(max(reward, -1), 1) or reward  # Optionally clamp rewards
                    # done = done or episode_length >= self.args.max_episode_length  # Stop episodes at a max length
                    state = state_

                # Break graph for last values calculated (used for targets, not directly as model outputs)
                self.global_count.increment()

                self._step(memory,
                           shared_model,
                           env,
                           optimizer,
                           off_policy=True,
                           writer=writer)
                memory.clear()
                # while len(memory) > 0:
                #     self._step(memory, shared_model, env, optimizer, off_policy=True, writer=writer)
                #     memory.pop(0)

        dist.barrier()
        if rank == 0:
            if not self.args.cross_validate_hp and not self.args.test_score_only and not self.args.no_save:
                if self.args.model_name != "":
                    torch.save(
                        shared_model.state_dict(),
                        os.path.join(self.save_dir, self.args.model_name))
                    print('saved')
                else:
                    torch.save(shared_model.state_dict(),
                               os.path.join(self.save_dir, 'agent_model'))
            if self.args.cross_validate_hp or self.args.test_score_only:
                test_score = 0
                env.writer = None
                for i in range(20):
                    self.update_env_data(env, dloader, device)
                    env.reset()
                    self.b_sigma = self.args.p_sigma
                    env.stop_quality = 40
                    while not env.done:
                        # Calculate policy and values
                        policy_means, p_dis = self.agent_forward(
                            env, shared_model, grad=False, stats_only=True)
                        action, b_rvs = self.get_action(
                            policy_means, p_dis, device)
                        # execute_action takes only the action in this version
                        # (cf. the training loop above); the returns are unused here
                        env.execute_action(action)

                    # import matplotlib.pyplot as plt;
                    # plt.imshow(env.get_current_soln());
                    # plt.show()
                    if env.win:
                        test_score += 1
                return_dict['test_score'] = test_score
                if writer is not None:  # writer is None when cross-validating
                    writer.add_text("time_needed", str(time.time() - start_time))
        self.cleanup()
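
Example #4 also re-reads config.yaml between episodes so a few hyperparameters (eps, safe_model, add_noise) can be changed while training runs. The sketch below isolates that hot-reload step; the reload_overrides helper is an illustrative name, not project code.

import yaml


def reload_overrides(config_path, args, current_eps):
    # re-read the YAML config and apply only the keys the training loop
    # allows to change at runtime; everything else keeps its launch value
    with open(config_path) as info:
        overrides = yaml.full_load(info) or {}
    eps = overrides.get('eps', current_eps)
    args.safe_model = overrides.get('safe_model', args.safe_model)
    args.add_noise = overrides.get('add_noise', args.add_noise)
    return eps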