Code Example #1
File: unroll_methods.py  Project: phate09/SafeDRL
def compute_remaining_intervals4_multi(current_intervals: List[HyperRectangle], tree: index.Index, rounding: int, debug=True) -> Tuple[
    List[HyperRectangle], List[Tuple[HyperRectangle, List[HyperRectangle]]]]:
    intervals_with_relevants = []  #: List[Tuple[HyperRectangle, HyperRectangle_action]]
    for i, interval in enumerate(current_intervals):
        relevant_intervals = filter_relevant_intervals3(tree, interval)
        intervals_with_relevants.append((interval, relevant_intervals))
    remain_list = []  #: List[HyperRectangle]
    proc_ids = []
    chunk_size = 200
    intersection_list = []  # list with intervals and associated intervals with action assigned: List[Tuple[HyperRectangle, List[HyperRectangle]]]
    for i, chunk in enumerate(utils.chunks(intervals_with_relevants, chunk_size)):
        proc_ids.append(compute_remaining_intervals_remote.remote(chunk, False))
    with StandardProgressBar(prefix="Computing remaining intervals ", max_value=len(proc_ids)) if debug else nullcontext() as bar:
        while len(proc_ids) != 0:
            ready_ids, proc_ids = ray.wait(proc_ids)
            results = ray.get(ready_ids[0])
            for result in results:
                if result is not None:
                    (remain, intersection), previous_interval = result
                    remain_list.extend(remain)
                    assigned = []  #: List[HyperRectangle]
                    assigned.extend(intersection)
                    intersection_list.append((previous_interval, assigned))
            if debug:
                bar.update(bar.value + 1)

    return remain_list, intersection_list
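Note: this example (and the ones below) slices its workload with utils.chunks before dispatching it to Ray tasks. The helper's body is not reproduced on this page; as a rough sketch, and assuming it is an ordinary fixed-size slicer, it could look like this:

from typing import Iterator, List, TypeVar

T = TypeVar("T")

def chunks(items: List[T], chunk_size: int) -> Iterator[List[T]]:
    # Yield successive slices of at most chunk_size elements.
    for start in range(0, len(items), chunk_size):
        yield items[start:start + chunk_size]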
Code Example #2
File: unroll_methods.py  Project: phate09/SafeDRL
def abstract_step(abstract_states_normalised: List[HyperRectangle_action], env_class, n_workers: int, rounding: int, probabilistic=False):
    """
    Given some abstract states, compute the next abstract states taking the action passed as parameter
    :param env:
    :param abstract_states_normalised: the abstract states from which to start, list of tuples of intervals
    :return: the next abstract states after taking the action (array)
    """
    next_states = []
    terminal_states = defaultdict(bool)
    half_terminal_states = defaultdict(bool)
    chunk_size = 1000
    n_chunks = ceil(len(abstract_states_normalised) / chunk_size)
    workers = cycle([AbstractStepWorker.remote(rounding, env_class, probabilistic) for _ in range(min(n_workers, n_chunks))])
    proc_ids = []
    with StandardProgressBar(prefix="Preparing AbstractStepWorkers ", max_value=n_chunks) as bar:
        for i, intervals in enumerate(utils.chunks(abstract_states_normalised, chunk_size)):
            proc_ids.append(next(workers).work.remote(intervals))
            bar.update(i)
    with StandardProgressBar(prefix="Performing abstract step ", max_value=len(proc_ids)) as bar:
        while len(proc_ids) != 0:
            ready_ids, proc_ids = ray.wait(proc_ids, num_returns=min(10, len(proc_ids)), timeout=0.5)
            results = ray.get(ready_ids)  #: Tuple[List[Tuple[HyperRectangle_action, List[HyperRectangle]]], dict, dict]
            bar.update(bar.value + len(results))
            for next_states_local, half_terminal_states_local, terminal_states_local in results:
                for next_state_key in next_states_local:
                    next_states.append((next_state_key, next_states_local[next_state_key]))
                terminal_states.update(terminal_states_local)
                half_terminal_states.update(half_terminal_states_local)
    return next_states, half_terminal_states, terminal_states
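The dispatch pattern in abstract_step is worth isolating: workers are created up front, chunks are assigned round-robin via itertools.cycle, and the pending object refs are drained with ray.wait using a small num_returns and a timeout so the progress bar can advance incrementally. The following stand-alone toy (SquareWorker and the integer payload are invented for illustration, not part of SafeDRL) shows the same skeleton:

from itertools import cycle
from math import ceil

import ray

@ray.remote
class SquareWorker:
    def work(self, batch):
        return [x * x for x in batch]

ray.init(ignore_reinit_error=True)
data = list(range(100))
chunk_size = 10
n_chunks = ceil(len(data) / chunk_size)
workers = cycle([SquareWorker.remote() for _ in range(min(4, n_chunks))])
proc_ids = [next(workers).work.remote(data[i:i + chunk_size]) for i in range(0, len(data), chunk_size)]
results = []
while len(proc_ids) != 0:
    ready_ids, proc_ids = ray.wait(proc_ids, num_returns=min(10, len(proc_ids)), timeout=0.5)
    results.extend(ray.get(ready_ids))  # each entry is one chunk's list of squares
print(sum(len(r) for r in results))  # 100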
Code Example #3
File: unroll_methods.py  Project: phate09/SafeDRL
def merge_supremum3(starting_intervals: List[HyperRectangle], n_workers: int, precision: int, positional_method=False, show_bar=True) -> List[HyperRectangle]:
    if len(starting_intervals) <= 1:
        return starting_intervals
    dimensions = len(starting_intervals[0])
    # generate tree
    intervals_dummy_action = [(x, True) for x in starting_intervals]
    tree = utils.create_tree(intervals_dummy_action)
    # find bounds
    boundaries = compute_boundaries(intervals_dummy_action)
    if positional_method:
        # split
        split_list = [boundaries]
        n_splits = 10
        for i in range(n_splits):
            domain = split_list.pop(0)
            splitted_domains = DomainExplorer.box_split_tuple(domain, precision)
            split_list.extend(splitted_domains)

        # find relevant intervals
        working_list = []
        for domain in split_list:
            relevant_list = list(tree.intersection(utils.flatten_interval(domain), objects='raw'))
            local_working_list = []
            # resize intervals
            for relevant, action in relevant_list:
                resized = utils.shrink(relevant, domain)
                local_working_list.append((resized, action))
            working_list.append(local_working_list)
    else:
        working_list = list(utils.chunks(starting_intervals, min(1000, max(int(len(starting_intervals) / n_workers), 1))))
    # intervals = starting_intervals
    merged_list = []  #: List[HyperRectangle]
    proc_ids = []
    with StandardProgressBar(prefix="Merging intervals", max_value=len(working_list)) if show_bar else nullcontext() as bar:
        while len(working_list) != 0 or len(proc_ids) != 0:
            while len(proc_ids) < n_workers and len(working_list) != 0:
                intervals = working_list.pop()
                proc_ids.append(merge_supremum2_remote.remote(intervals))
            ready_ids, proc_ids = ray.wait(proc_ids, num_returns=len(proc_ids), timeout=0.5)
            results = ray.get(ready_ids)
            for result in results:
                if result is not None:
                    merged_list.extend(result)
            if show_bar:
                bar.update(bar.value + len(ready_ids))
    new_merged_list = merge_supremum2(merged_list)
    # show_plot(merged_list, new_merged_list)
    return new_merged_list
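In the positional_method branch, the rtree index built by utils.create_tree is queried with tree.intersection(utils.flatten_interval(domain), objects='raw'). flatten_interval is not reproduced on this page; assuming each interval is stored as one (low, high) pair per dimension and the index uses rtree's default interleaved coordinate order, a sketch of the conversion could be:

from typing import Tuple

def flatten_interval(interval: Tuple[Tuple[float, float], ...]) -> Tuple[float, ...]:
    # ((l0, h0), ..., (ln, hn)) -> (l0, ..., ln, h0, ..., hn),
    # the bounding-box form an rtree intersection query expects by default.
    lows = tuple(low for low, _ in interval)
    highs = tuple(high for _, high in interval)
    return lows + highs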
Code Example #4
File: unroll_methods.py  Project: phate09/SafeDRL
def premerge(intersected_intervals, n_workers, rounding: int, show_bar=True):
    proc_ids = []
    merged_list = [(interval_noaction, successors) for interval_noaction, successors in intersected_intervals if len(successors) <= 1]
    working_list = list(utils.chunks([(interval_noaction, successors) for interval_noaction, successors in intersected_intervals if len(successors) > 1], 200))
    with StandardProgressBar(prefix="Premerging intervals ", max_value=len(working_list)) if show_bar else nullcontext() as bar:
        while len(working_list) != 0 or len(proc_ids) != 0:
            while len(proc_ids) < n_workers and len(working_list) != 0:
                intervals = working_list.pop(0)
                proc_ids.append(merge_successors.remote(intervals, rounding))
            ready_ids, proc_ids = ray.wait(proc_ids, num_returns=len(proc_ids), timeout=0.5)
            results = ray.get(ready_ids)
            for result in results:
                if result is not None:
                    merged_list.extend(result)
            if show_bar:
                bar.update(bar.value + len(ready_ids))

    return merged_list
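premerge only dispatches work for intervals that actually have more than one successor; pairs with zero or one successor bypass the workers entirely and are kept as-is. A toy illustration of that partition (the string data here is made up):

intersected_intervals = [
    ("A", ["s1"]),              # single successor: kept as-is
    ("B", ["s1", "s2", "s3"]),  # multiple successors: sent to merge_successors
    ("C", []),                  # no successors: kept as-is
]
merged_list = [(iv, succ) for iv, succ in intersected_intervals if len(succ) <= 1]
to_merge = [(iv, succ) for iv, succ in intersected_intervals if len(succ) > 1]
print(len(merged_list), len(to_merge))  # 2 1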
Code Example #5
def dqn(n_episodes=2000, max_t=1000, eps_start=1.0, eps_end=MIN_EPS):
    scores = []  # list containing scores from each episode
    timesteps = []  # list containing number of timesteps from each episode
    scores_window = deque(maxlen=100)  # last 100 scores
    timesteps_window = deque(maxlen=100)  # last 100 timesteps
    n_traces = 100
    betas = Scheduler(STARTING_BETA, 1.0, n_episodes)
    eps = Scheduler(eps_start, eps_end, int(n_episodes * EPS_DECAY))
    # val_data = GridSearchDataset(shuffle=True)
    for i_episode in range(n_episodes):
        invariant_dataset = []
        for trace in range(n_traces):
            # state = env.reset()  # reset the environment
            # state = val_data[i_episode%len(val_data)]
            state = np.array(
                [np.random.uniform(-10, 30),
                 np.random.uniform(-10, 40)])
            t = 0
            temp_dataset = []
            action = None
            agent_error = None
            done = False
            done_once = False
            for t in range(max_t):
                action = agent.act(state)

                # next_state, reward, done, _ = env.step(action)  # send the action to the environment
                next_state, reward, done, _ = env.compute_successor(
                    state, action)
                temp_dataset.append((state, action, next_state))
                # agent_error = agent.step(state, action, reward, next_state, done, beta=betas.get(i_episode))
                # if agent_error is not None:
                #     writer.add_scalar('loss/agent_loss', agent_error.item(), i_episode)
                state = next_state
                if done:
                    break
            if done:
                for state_np, action, next_state_np in temp_dataset:
                    invariant_dataset.append(
                        (state_np.astype(dtype=np.float32), action,
                         next_state_np.astype(dtype=np.float32), -1))
            else:
                for state_np, action, next_state_np in temp_dataset:
                    invariant_dataset.append(
                        (state_np.astype(dtype=np.float32), action,
                         next_state_np.astype(dtype=np.float32), 1))

        losses = []
        for chunk in chunks(invariant_dataset, 256):
            agent2.ioptimizer.zero_grad()  # reset the gradient before each minibatch
            states, actions, successors, flags = zip(*chunk)
            loss = SafetyLoss(agent2.inetwork_local, agent2.inetwork_target)
            loss_value = loss(
                torch.tensor(states).to(device),
                torch.tensor(actions).to(device),
                torch.tensor(successors).to(device),
                torch.tensor(flags, dtype=torch.float32).to(device))
            loss_value.backward()
            agent2.ioptimizer.step()  # updates the invariant
            agent2.soft_update(agent2.inetwork_local, agent2.inetwork_target,
                               TAU)
            losses.append(loss_value.item())
        invariant_loss = np.array(losses).mean()
        scores_window.append(invariant_loss)  # track the loss over the last 100 episodes
        scores.append(invariant_loss)  # record this episode's loss
        writer.add_scalar('data/score', invariant_loss, i_episode)
        writer.add_scalar('data/score_average', np.mean(scores_window),
                          i_episode)
        writer.add_scalar('data/epsilon', eps.get(i_episode), i_episode)
        writer.add_scalar('data/beta', betas.get(i_episode), i_episode)
        if invariant_loss is not None:
            writer.add_scalar('loss/invariant_loss', invariant_loss, i_episode)

        # eps = max(eps_end, eps_decay * eps)  # decrease epsilon
        print(
            f'\rEpisode {i_episode + 1}\tAverage Score: {np.mean(scores_window):.4f}\t eps={eps.get(i_episode):.3f} beta={betas.get(i_episode):.3f}',
            end="")
        if (i_episode + 1) % 100 == 0:
            print(
                f'\rEpisode {i_episode + 1}\tAverage Score: {np.mean(scores_window):.4f} eps={eps.get(i_episode):.3f} beta={betas.get(i_episode):.3f}'
            )
            agent2.save(
                os.path.join(log_dir, f"checkpoint_{i_episode + 1}.pth"),
                i_episode)
    agent2.save(os.path.join(log_dir, "checkpoint_final.pth"), i_episode)
    return scores
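Both eps and betas in this training loop are Scheduler objects queried with .get(i_episode). The class is not shown on this page; judging from how it is constructed (start value, end value, number of steps), it is presumably a clamped linear interpolation, roughly:

class Scheduler:
    def __init__(self, start: float, end: float, n_steps: int):
        self.start = start
        self.end = end
        self.n_steps = max(n_steps, 1)

    def get(self, step: int) -> float:
        # Move linearly from start to end over n_steps, then hold at end.
        fraction = min(max(step / self.n_steps, 0.0), 1.0)
        return self.start + (self.end - self.start) * fraction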