Example #1
class Counter(object):
    '''
    A counter used for multiprocessing, simple wrapper around multiprocessing.Value
    '''
    def __init__(self):
        from torch.multiprocessing import Value
        self.val = Value('i', 0)

    def increment(self, n=1):
        with self.val.get_lock():
            self.val.value += n

    def reset(self):
        with self.val.get_lock():
            self.val.value = 0

    @property
    def value(self):
        return self.val.value
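
A minimal usage sketch for this Counter; the worker function and process count below are illustrative additions, not part of the original snippet.

def _bump(counter, n=1000):
    # every increment goes through Value's lock, so concurrent updates are safe
    for _ in range(n):
        counter.increment()

if __name__ == "__main__":
    from torch.multiprocessing import Process

    counter = Counter()
    workers = [Process(target=_bump, args=(counter,)) for _ in range(4)]
    for w in workers:
        w.start()
    for w in workers:
        w.join()
    print(counter.value)  # expected: 4000
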
def _worker(
    reader: DatasetReader,
    input_queue: Queue,
    output_queue: Queue,
    num_active_workers: Value,
    num_inflight_items: Value,
    worker_id: int,
) -> None:
    """
    A worker that pulls filenames off the input queue, uses the dataset reader
    to read them, and places the generated instances on the output queue.  When
    there are no filenames left on the input queue, it decrements
    num_active_workers to signal completion.
    """
    logger.info(f"Reader worker: {worker_id} PID: {os.getpid()}")
    # Keep going until you get a file_path that's None.
    while True:
        file_path = input_queue.get()
        if file_path is None:
            # It's important that we close and join the queue here before
            # decrementing num_active_workers. Otherwise our parent may join us
            # before the queue's feeder thread has passed all buffered items to
            # the underlying pipe resulting in a deadlock.
            #
            # See:
            # https://docs.python.org/3.6/library/multiprocessing.html?highlight=process#pipes-and-queues
            # https://docs.python.org/3.6/library/multiprocessing.html?highlight=process#programming-guidelines
            output_queue.close()
            output_queue.join_thread()
            # Decrementing is not atomic.
            # See https://docs.python.org/2/library/multiprocessing.html#multiprocessing.Value.
            with num_active_workers.get_lock():
                num_active_workers.value -= 1
            logger.info(f"Reader worker {worker_id} finished")
            break

        logger.info(f"reading instances from {file_path}")
        for instance in reader.read(file_path):
            with num_inflight_items.get_lock():
                num_inflight_items.value += 1
            output_queue.put(instance)
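
A hedged sketch of the parent-side wiring this worker expects; `reader` and `file_paths` are assumptions, and only the queue / poison-pill / shared-counter protocol is taken from the worker above.

def read_in_parallel(reader, file_paths, num_workers=4):
    # illustrative driver, not from the original source
    from queue import Empty
    from torch.multiprocessing import Process, Queue, Value

    input_queue, output_queue = Queue(), Queue()
    num_active_workers = Value('i', num_workers)
    num_inflight_items = Value('i', 0)

    workers = [Process(target=_worker,
                       args=(reader, input_queue, output_queue,
                             num_active_workers, num_inflight_items, worker_id))
               for worker_id in range(num_workers)]
    for w in workers:
        w.start()

    for file_path in file_paths:
        input_queue.put(file_path)
    for _ in workers:
        input_queue.put(None)  # one poison pill per worker

    # drain instances until every worker has signalled completion and nothing is in flight
    while num_active_workers.value > 0 or num_inflight_items.value > 0:
        try:
            instance = output_queue.get(timeout=1.0)
        except Empty:
            continue
        with num_inflight_items.get_lock():
            num_inflight_items.value -= 1
        yield instance

    for w in workers:
        w.join()
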
Example #3
class Signal(object):
    '''
    A signal used for multiprocessing, simple wrapper around multiprocessing.Value
    '''
    def __init__(self):
        from torch.multiprocessing import Value
        self.val = Value('i', False)

    def set_signal(self, boolean):
        with self.val.get_lock():
            self.val.value = boolean

    @property
    def value(self):
        return bool(self.val.value)
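
A minimal sketch of using Signal as a shared shutdown flag; the polling worker below is an illustrative addition.

import time
from torch.multiprocessing import Process

def _poll_until_stopped(stop_signal):
    # hypothetical worker that exits once the shared flag is flipped
    while not stop_signal.value:
        time.sleep(0.1)

if __name__ == "__main__":
    stop = Signal()
    p = Process(target=_poll_until_stopped, args=(stop,))
    p.start()
    time.sleep(1.0)
    stop.set_signal(True)  # the worker observes the new value and returns
    p.join()
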
Example #4
class Agent_sync(Agent):
    """
    An agent class that maintains multiple policy nets and environments; each worker has one environment and one policy.
    Useful for most single-agent RL/IL settings.
    """
    def __init__(self, config: ParamDict, environment: Environment,
                 policy: Policy, filter_op: Filter):
        threads, gpu = config.require("threads", "gpu")
        super(Agent_sync, self).__init__(config, environment, policy,
                                         filter_op)

        # sync signal, -1: terminate, 0: normal running, >0 restart and waiting for parameter update
        self._sync_signal = Value('i', 0)

        # sampler sub-process list
        self._sampler_proc = []

        # used for synchronize commands
        self._cmd_pipe = None
        self._param_pipe = None
        self._cmd_lock = Lock()

        cmd_pipe_child, cmd_pipe_parent = Pipe(duplex=True)
        param_pipe_child, param_pipe_parent = Pipe(duplex=False)
        self._cmd_pipe = cmd_pipe_parent
        self._param_pipe = param_pipe_parent
        for i_thread in range(threads):
            child_name = f"sampler_{i_thread}"
            worker_cfg = ParamDict({
                "seed": self.seed + 1024 + i_thread,
                "gpu": gpu
            })
            child = Process(target=Agent_sync._sampler_worker,
                            name=child_name,
                            args=(worker_cfg, cmd_pipe_child, param_pipe_child,
                                  self._cmd_lock, self._sync_signal,
                                  deepcopy(policy), deepcopy(environment),
                                  deepcopy(filter_op)))
            self._sampler_proc.append(child)
            child.start()

    def __del__(self):
        """
        Terminate all child processes here.
        """
        self._sync_signal.value = -1
        sleep(1)
        for _proc in self._sampler_proc:
            _proc.join(2)
            if _proc.is_alive():
                _proc.terminate()

        self._cmd_pipe.close()
        self._param_pipe.close()

    def broadcast(self, config: ParamDict):
        policy_state, filter_state, max_step, self._batch_size, fixed_env, fixed_policy, fixed_filter = \
            config.require("policy state dict", "filter state dict", "trajectory max step", "batch size",
                           "fixed environment", "fixed policy", "fixed filter")

        self._replay_buffer = []
        policy_state["fixed policy"] = fixed_policy
        filter_state["fixed filter"] = fixed_filter
        cmd = ParamDict({
            "trajectory max step": max_step,
            "fixed environment": fixed_env,
            "filter state dict": filter_state
        })

        assert self._sync_signal.value < 1, "Last sync event did not finish; some sub-processes may have died, aborting"
        # tell sub-process to reset
        with self._sync_signal.get_lock():
            self._sync_signal.value = len(self._sampler_proc)

        # sync net parameters
        with self._cmd_lock:
            for _ in range(len(self._sampler_proc)):
                self._param_pipe.send(policy_state)

        # wait for all agents' ready feedback
        while self._sync_signal.value > 0:
            sleep(0.01)

        # sync commands
        for _ in range(self._batch_size):
            self._cmd_pipe.send(cmd)

    def collect(self):
        if self._cmd_pipe.poll(0.1):
            self._replay_buffer.append(self._cmd_pipe.recv())
        if len(self._replay_buffer) < self._batch_size:
            return None
        else:
            batch = self._filter.operate_trajectoryList(self._replay_buffer)
            return batch

    @staticmethod
    def _sampler_worker(setups: ParamDict, pipe_cmd, pipe_param, read_lock,
                        sync_signal, policy, environment, filter_op):
        gpu, seed = setups.require("gpu", "seed")

        device = decide_device(gpu)
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        environment.init(display=False)
        filter_op.init()
        filter_op.to_device(device)
        policy.init()
        policy.to_device(device)

        # -1: syncing, 0: waiting for new command, 1: sampling
        local_state = 0
        current_step = None
        step_buffer = []
        cmd = None

        def _get_piped_data(pipe):
            with read_lock:
                if pipe.poll(0.001):
                    return pipe.recv()
                else:
                    return None

        while sync_signal.value >= 0:
            # check sync counter for sync event, and waiting for new parameters
            if sync_signal.value > 0 and local_state >= 0:
                # receive sync signal, reset all workspace settings, decrease sync counter,
                # and set state machine to -1 for not init again
                while _get_piped_data(pipe_cmd) is not None:
                    pass
                step_buffer.clear()
                _policy_state = _get_piped_data(pipe_param)
                if _policy_state is not None:
                    # set new parameters
                    policy.reset(_policy_state)
                    with sync_signal.get_lock():
                        sync_signal.value -= 1
                    local_state = -1

            # if sync ends, tell state machine to recover from syncing state, and reset environment
            elif sync_signal.value == 0 and local_state == -1:
                local_state = 0

            # waiting for states (states are list of dicts)
            elif sync_signal.value == 0 and local_state == 0:
                cmd = _get_piped_data(pipe_cmd)
                if cmd is not None:
                    step_buffer.clear()
                    cmd.require("filter state dict", "fixed environment",
                                "trajectory max step")
                    current_step = environment.reset(
                        random=not cmd["fixed environment"])
                    filter_op.reset(cmd["filter state dict"])
                    local_state = 1

            # sampling
            elif sync_signal.value == 0 and local_state == 1:
                with torch.no_grad():
                    policy_step = filter_op.operate_currentStep(current_step)
                    last_step = policy.step([policy_step])[0]
                last_step, current_step, done = environment.step(last_step)
                record_step = filter_op.operate_recordStep(last_step)
                step_buffer.append(record_step)

                if len(step_buffer) >= cmd["trajectory max step"] or done:
                    traj = filter_op.operate_stepList(step_buffer, done=done)
                    with read_lock:
                        pipe_cmd.send(traj)
                    local_state = 0

        # finalization
        filter_op.finalize()
        policy.finalize()
        environment.finalize()
        pipe_cmd.close()
        pipe_param.close()
        print("Sampler sub-process exited")
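
A hedged driver sketch for this agent: the config keys mirror the require() calls above, but the state-dict accessors, the iteration count, and the update step are assumptions.

# Illustrative driver loop (not from the original source).
agent = Agent_sync(config, environment, policy, filter_op)
for iteration in range(100):
    agent.broadcast(ParamDict({
        "policy state dict": policy.state_dict(),     # assumed accessor
        "filter state dict": filter_op.state_dict(),  # assumed accessor
        "trajectory max step": 1000,
        "batch size": 32,
        "fixed environment": False,
        "fixed policy": False,
        "fixed filter": False,
    }))
    batch = None
    while batch is None:  # collect() returns None until a full batch of trajectories has arrived
        batch = agent.collect()
    # ... update the policy with `batch` ...
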
Example #5
        server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        server.bind((address, port))
        server.listen()

        param_queue = Queue()
        param_queue.put(net.state_dict())

        shutdown_val = Value('b', 0)

        receiver_proc = Process(target=HandleWorkers,
                                args=(server, replay_memory, mem_lock,
                                      param_queue, shutdown_val))
        receiver_proc.start()

    while True:
        try:
            Train(net, replay_memory, mem_lock, args.output_file)
            if param_queue is not None:
                param_queue.put(net.state_dict())  # state_dict() must be called, not passed as a bound method
            torch.save(net.state_dict(), args.output_file)
        except KeyboardInterrupt:
            if server is not None:
                assert (shutdown_val is not None and receiver_proc is not None)
                print("Shutting down...")

                with shutdown_val.get_lock():
                    shutdown_val.value = 1
                receiver_proc.join()
                server.close()
            break
Example #6
class ProgressiveResize(ResizeNative):
    """Resize data to sizes specified by scheduler"""
    def __init__(self,
                 scheduler: scheduler_type,
                 mode: str = 'nearest',
                 align_corners: bool = None,
                 preserve_range: bool = False,
                 keys: Sequence = ('data', ),
                 grad: bool = False,
                 **kwargs):
        """
        Args:
            scheduler: scheduler which determines the current size.
                The scheduler is called with the current iteration of the
                transform
            mode: one of ``nearest``, ``linear``, ``bilinear``, ``bicubic``,
                    ``trilinear``, ``area`` (for more information see
                    :func:`torch.nn.functional.interpolate`)
            align_corners: input and output tensors are aligned by the center
                points of their corners pixels, preserving the values at the
                corner pixels.
            preserve_range: output tensor has same range as input tensor
            keys: keys which should be augmented
            grad: enable gradient computation inside transformation
            **kwargs: keyword arguments passed to augment_fn

        Warnings:
            When this transformation is used in combination with
            multiprocessing, the step counter is not perfectly synchronized
            between multiple processes.
            As a result the step count may jump between values
            in a range of the number of processes used.
        """
        super().__init__(size=0,
                         mode=mode,
                         align_corners=align_corners,
                         preserve_range=preserve_range,
                         keys=keys,
                         grad=grad,
                         **kwargs)
        self.scheduler = scheduler
        self._step = Value('i', 0)

    def reset_step(self) -> ResizeNative:
        """
        Reset step to 0

        Returns:
            ResizeNative: returns self to allow chaining
        """
        with self._step.get_lock():
            self._step.value = 0
        return self

    def increment(self) -> ResizeNative:
        """
        Increment step by 1

        Returns:
            ResizeNative: returns self to allow chaining
        """
        with self._step.get_lock():
            self._step.value += 1
        return self

    @property
    def step(self) -> int:
        """
        Current step

        Returns:
            int: number of steps
        """
        return self._step.value

    def forward(self, **data) -> dict:
        """
        Resize data

        Args:
            **data: input batch

        Returns:
            dict: augmented batch
        """
        self.kwargs["size"] = self.scheduler(self.step)
        self.increment()
        return super().forward(**data)
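
The scheduler is just a callable from the transform's step count to a target size; a minimal sketch, assuming a step-wise growth schedule.

import torch

def size_scheduler(step: int) -> int:
    # hypothetical schedule: grow from 16 to 64 pixels, one jump of 16 every 250 calls
    return min(64, 16 + (step // 250) * 16)

transform = ProgressiveResize(scheduler=size_scheduler, mode='nearest')
out = transform(data=torch.rand(4, 1, 128, 128))  # resized to the currently scheduled size
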
Example #7
def train():
    np.random.seed(random_seed)
    torch.manual_seed(random_seed)
    writer = SummaryWriter()
    ac = AC(latent_num, cnn_chanel_num, stat_dim)
    writer.add_graph(ac, (torch.zeros([1, 1, img_shape[0], img_shape[1]
                                       ]), torch.zeros([1, stat_dim])))
    optim = GlobalAdam([{
        'params': ac.encode_img.parameters(),
        'lr': 2.5e-5
    }, {
        'params': ac.encode_stat.parameters(),
        'lr': 2.5e-5
    }, {
        'params': ac.pi.parameters(),
        'lr': 2.5e-5
    }, {
        'params': ac.actor.parameters(),
        'lr': 2.5e-5
    }, {
        'params': ac.f.parameters()
    }, {
        'params': ac.V.parameters()
    }],
                       lr=5e-3,
                       weight_decay=weight_decay)

    if os.path.exists('S3_state_dict.pt'):
        ac.load_state_dict(torch.load('S3_state_dict.pt'))
        optim.load_state_dict(torch.load('S3_Optim_state_dict.pt'))
    else:
        ac.load_state_dict(torch.load('../stage2/S2_state_dict.pt'),
                           strict=False)

    result_queue = Queue()
    validate_queue = Queue()
    gradient_queue = Queue()
    loss_queue = Queue()
    ep_cnt = Value('i', 0)
    optimizer_lock = Lock()
    processes = []
    ac.share_memory()

    optimizer_worker = Process(target=update_shared_model,
                               args=(gradient_queue, optimizer_lock, optim,
                                     ac))
    optimizer_worker.start()

    for no in range(mp.cpu_count() - 3):
        worker = Worker(no, ac, ep_cnt, optimizer_lock, result_queue,
                        gradient_queue, loss_queue)
        worker.start()
        processes.append(worker)
    validater = Validate(ac, ep_cnt, optimizer_lock, validate_queue)
    validater.start()

    best_reward = 0
    while True:
        with ep_cnt.get_lock():
            if not result_queue.empty():
                ep_cnt.value += 1
                reward, money, win_rate = result_queue.get()
                objective_actor, loss_critic, loss_f = loss_queue.get()

                writer.add_scalar('Interaction/Reward', reward, ep_cnt.value)
                writer.add_scalar('Interaction/Money', money, ep_cnt.value)
                writer.add_scalar('Interaction/win_rate', win_rate,
                                  ep_cnt.value)

                writer.add_scalar('Update/objective_actor', objective_actor,
                                  ep_cnt.value)
                writer.add_scalar('Update/loss_critic', loss_critic,
                                  ep_cnt.value)
                writer.add_scalar('Update/loss_f', loss_f, ep_cnt.value)

                with optimizer_lock:
                    if reward > best_reward:
                        best_reward = reward
                        torch.save(ac.state_dict(), 'S3_BEST_state_dict.pt')
                    if ep_cnt.value % save_every == 0:
                        torch.save(ac.state_dict(), 'S3_state_dict.pt')
                        torch.save(optim.state_dict(),
                                   'S3_Optim_state_dict.pt')

            if not validate_queue.empty():
                val_reward, val_money, val_win_rate = validate_queue.get()

                writer.add_scalar('Validation/reward', val_reward,
                                  ep_cnt.value)
                writer.add_scalar('Validation/money', val_money, ep_cnt.value)
                writer.add_scalar('Validation/win_rate', val_win_rate,
                                  ep_cnt.value)

    for worker in processes:
        worker.join()
    optimizer_worker.kill()
Example #8
def create_worker(gnet_actor: Actor, gnet_critic: Critic, opt: SharedAdam,
                  global_episode: mp.Value, global_results_queue: mp.Queue,
                  name: int) -> None:
    """
    This is our main function.

    It is in a function so that it can be spread over multiple processes.

    :param gnet_actor: Our global Actor network.
    :param gnet_critic: Our global Critic network.
    :param opt: Our shared Adam optimizer.
    :param global_episode: A shared value that tells us what episode we are on over all workers.
    :param global_results_queue: A shared queue that workers can put rewards onto.
    :param name: A number identifying this worker.
    :return: None
    """
    lnet_actor, lnet_critic = Actor(), Critic()
    lnet_critic.load_state_dict(gnet_critic.state_dict())
    lnet_actor.load_state_dict(gnet_actor.state_dict())
    lenv = gym.make('CartPole-v0')
    if name == 0:
        print("Creating Video Recorder")
        video_recorder = VideoRecorder(lenv,
                                       './output/06_Cartpole_A3C_Q_Critic.mp4',
                                       enabled=True)
    else:
        video_recorder = None

    print(f"Worker {name} starting run...")

    total_step = 1
    while global_episode.value < N_ITERS:
        buffer_state, buffer_log_probs, buffer_rewards, buffer_policy_dist, buffer_action_one_hot = [], [], [], [], []
        episode_reward = 0
        state = lenv.reset()

        for _ in count():
            # Capture a video frame every 10 steps (only worker 0 has a recorder)
            if (total_step + 1) % 10 == 0 and video_recorder is not None:
                video_recorder.capture_frame()

            state = torch.FloatTensor(state).to(device)
            policy_dist = lnet_actor(state)
            policy = Categorical(policy_dist)
            action = policy.sample()
            action_one_hot = to_onehot(action, ACTION_DIM)
            next_state, reward, done, _ = lenv.step(action.cpu().numpy())

            log_prob = policy.log_prob(action).unsqueeze(0)

            episode_reward += reward
            buffer_policy_dist.append(policy_dist[None, :])
            buffer_action_one_hot.append(action_one_hot[None, :])
            buffer_log_probs.append(log_prob[None, :])
            buffer_state.append(state[None, :])
            buffer_rewards.append(
                torch.FloatTensor([reward])[None, :].to(device))

            state = next_state

            if total_step % UPDATE_GLOBAL_ITER == 0 or done:
                # sync
                next_state = torch.FloatTensor(next_state).to(device)
                next_policy_dist = lnet_actor(next_state)
                next_policy = Categorical(next_policy_dist)
                next_action = next_policy.sample()
                next_action_one_hot = to_onehot(next_action, ACTION_DIM)
                final_value = lnet_critic(next_state,
                                          next_policy_dist) if not done else 0

                # Concatenate buffers
                buffer_state = torch.cat(buffer_state, dim=0)
                buffer_policy_dist = torch.cat(buffer_policy_dist, dim=0)
                buffer_action_one_hot = torch.cat(buffer_action_one_hot, dim=0)
                buffer_log_probs = torch.cat(buffer_log_probs, dim=0)
                buffer_rewards = torch.cat(buffer_rewards, dim=0)

                # Calculate the cumulative rewards using the final predicted value as the terminal value
                cum_reward = final_value
                discounted_future_rewards = torch.FloatTensor(
                    len(buffer_rewards)).to(device)
                # walk the buffer from the most recent reward to the oldest
                for i in range(1, len(buffer_rewards) + 1):
                    cum_reward = buffer_rewards[-i] + GAMMA * cum_reward
                    discounted_future_rewards[-i] = cum_reward

                # Calculate the local losses for the states in the buffer
                values = lnet_critic(buffer_state, buffer_policy_dist)

                # Now we calculate the advantage function
                advantage = discounted_future_rewards - values

                # And the loss for both the actor and the critic
                actor_loss = -(buffer_log_probs * advantage.detach())
                critic_loss = advantage.pow(2)

                # calculate local gradients and push local parameters to global
                # We are going to couple these losses so that on each episode they are related together
                opt.zero_grad()
                (actor_loss + critic_loss).mean().backward()
                for lp, gp in zip(lnet_actor.parameters(),
                                  gnet_actor.parameters()):
                    gp._grad = lp.grad
                for lp, gp in zip(lnet_critic.parameters(),
                                  gnet_critic.parameters()):
                    gp._grad = lp.grad
                opt.step()

                # pull global parameters
                lnet_critic.load_state_dict(gnet_critic.state_dict())
                lnet_actor.load_state_dict(gnet_actor.state_dict())
                buffer_state, buffer_log_probs, buffer_rewards, buffer_policy_dist, buffer_action_one_hot = [], [], [], [], []

                if done:
                    # Increment the global episode
                    with global_episode.get_lock():
                        global_episode.value += 1

                    # Update the results queue
                    # print(episode_reward)
                    global_results_queue.put(episode_reward)

                    # End this batch
                    break
        total_step += 1

    # This indicates its time to join all workers
    global_results_queue.put(None)
    print("DONE!")

    if video_recorder is not None:
        video_recorder.close()
    lenv.close()
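
A hedged sketch of the main process that would drive this worker; the SharedAdam construction and the worker count are assumptions.

if __name__ == "__main__":
    gnet_actor, gnet_critic = Actor(), Critic()
    gnet_actor.share_memory()   # expose global parameters to all worker processes
    gnet_critic.share_memory()
    opt = SharedAdam(list(gnet_actor.parameters()) + list(gnet_critic.parameters()),
                     lr=1e-4)   # assumed signature of the shared optimizer
    global_episode = mp.Value('i', 0)
    global_results_queue = mp.Queue()

    workers = [mp.Process(target=create_worker,
                          args=(gnet_actor, gnet_critic, opt,
                                global_episode, global_results_queue, name))
               for name in range(4)]
    for w in workers:
        w.start()

    # collect episode rewards until every worker has put its None sentinel
    finished, episode_rewards = 0, []
    while finished < len(workers):
        r = global_results_queue.get()
        if r is None:
            finished += 1
        else:
            episode_rewards.append(r)
    for w in workers:
        w.join()
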
Example #9
class Pipeline():
    def __init__(self,
                 config,
                 share_batches=True,
                 manager=None,
                 new_process=True):
        if new_process == True and manager is None:
            manager = Manager()
        self.knows = Semaphore(0)  # > 0 if we know if any are coming
        # == 0 if DatasetReader is processing a command
        self.working = Semaphore(1 if new_process else 100)
        self.finished_reading = Lock()  # locked if we're still reading from file
        # number of molecules that have been sent to the pipe:
        self.in_pipe = Value('i', 0)

        # Tracking what's already been sent through the pipe:
        self._example_number = Value('i', 0)

        # The final kill switch:
        self._close = Value('i', 0)

        self.command_queue = manager.Queue(10)
        self.molecule_pipeline = None
        self.batch_queue = Queue(config.data.batch_queue_cap)  # manager.Queue(config.data.batch_queue_cap)
        self.share_batches = share_batches

        self.dataset_reader = DatasetReader("dataset_reader",
                                            self,
                                            config,
                                            new_process=new_process)
        if new_process:
            self.dataset_reader.start()

    def __getstate__(self):
        self_dict = self.__dict__.copy()
        self_dict['dataset_reader'] = None
        return self_dict

    # methods for pipeline user/consumer:
    def start_reading(self,
                      examples_to_read,
                      make_molecules=True,
                      batch_size=None,
                      wait=False):
        #print("Start reading...")
        assert check_semaphore(
            self.finished_reading
        ), "Tried to start reading file, but already reading!"
        with self.in_pipe.get_lock():
            assert self.in_pipe.value == 0, "Tried to start reading, but examples already in pipe!"
        set_semaphore(self.finished_reading, False)
        set_semaphore(self.knows, False)
        self.working.acquire()
        self.command_queue.put(
            StartReading(examples_to_read, make_molecules, batch_size))
        if wait:
            self.wait_till_done()

    def wait_till_done(self):
        # wait_semaphore(self.knows)
        # wait_semaphore(self.finished_reading)
        self.working.acquire()
        self.working.release()
        if self.any_coming():
            with self.in_pipe.get_lock():
                ip = self.in_pipe.value
            raise Exception(f"Waiting with {ip} examples in pipe!")

    def scan_to(self, index):
        assert check_semaphore(
            self.knows), "Tried to scan to index, but don't know if finished!"
        assert check_semaphore(
            self.finished_reading
        ), "Tried to scan to index, but not finished reading!"
        assert not self.any_coming(
        ), "Tried to scan to index, but pipeline not empty!"
        self.working.acquire()
        self.command_queue.put(ScanTo(index))
        with self._example_number.get_lock():
            self._example_number.value = index
        # What to do if things are still in the pipe???

    def set_indices(self, test_set_indices):
        self.working.acquire()
        self.command_queue.put(SetIndices(torch.tensor(test_set_indices)))
        self.working.acquire()
        self.command_queue.put(ScanTo(0))

    def set_shuffle(self, shuffle):
        self.command_queue.put(SetShuffle(shuffle))

    def any_coming(self):  # returns True if at least one example is coming
        wait_semaphore(self.knows)
        with self.in_pipe.get_lock():
            return self.in_pipe.value > 0

    def get_batch(self, timeout=None):
        #assert self.any_coming(verbose=verbose), "Tried to get data from an empty pipeline!"
        x = self.batch_queue.get(True, timeout)
        #print(f"{type(x)} : {x}")
        #for b in x:
        #    print(f" --{type(b)} : {b}")

        with self.in_pipe.get_lock():
            self.in_pipe.value -= x.n_examples
            if self.in_pipe.value == 0 and not check_semaphore(
                    self.finished_reading):
                set_semaphore(self.knows, False)
        with self._example_number.get_lock():
            self._example_number.value += x.n_examples
        return x

    @property
    def example_number(self):
        with self._example_number.get_lock():
            return self._example_number.value

    def close(self):
        self.command_queue.put(CloseReader())
        with self._close.get_lock():
            self._close.value = True
        self.dataset_reader.join(4)
        self.dataset_reader.kill()

    # methods for DatasetReader:
    def get_command(self):
        return self.command_queue.get()

    def put_molecule_to_ext(self, m, block=True):
        r = self.molecule_pipeline.put_molecule(m, block)
        if not r:
            return False
        with self.in_pipe.get_lock():
            if self.in_pipe.value == 0:
                set_semaphore(self.knows, True)
            self.in_pipe.value += 1
        return True

    def put_molecule_data(self, data, atomic_numbers, weights, ID, block=True):
        r = self.molecule_pipeline.put_molecule_data(data, atomic_numbers,
                                                     weights, ID, block)
        if not r:
            return False
        with self.in_pipe.get_lock():
            if self.in_pipe.value == 0:
                set_semaphore(self.knows, True)
            if data.ndim == 3:
                self.in_pipe.value += data.shape[0]
            else:
                self.in_pipe.value += 1
        return True

    def get_batch_from_ext(self, block=True):
        return self.molecule_pipeline.get_next_batch(block)

    def ext_batch_ready(self):
        return self.molecule_pipeline.batch_ready()

    # !!! Call only after you've put the molecules !!!
    def set_finished_reading(self):
        set_semaphore(self.finished_reading, True)
        set_semaphore(self.knows, True)
        self.molecule_pipeline.notify_finished()

    def put_batch(self, x):
        if False:  #self.share_batches:
            print("[P] Sharing memory... ")
            try:
                x.share_memory_()
            except Exception as e:
                print("[P] Failed when moving tensor to shared memory")
                print(e)
            print("[P] Done sharing memory")
        self.batch_queue.put(x)

    def time_to_close(self):
        with self._close.get_lock():
            return self._close.value
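
A minimal consumer-side sketch for this Pipeline, assuming `config` is already built; the numbers are illustrative.

# Illustrative consumer loop (not from the original source).
pipeline = Pipeline(config)
pipeline.start_reading(examples_to_read=1000, batch_size=32, wait=True)
while pipeline.any_coming():  # blocks until it is known whether more examples are coming
    batch = pipeline.get_batch(timeout=60)
    # ... train on `batch` ...
pipeline.close()
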
Example #10
            receiver_proc = Process(target=HandleWorkers,
                                    args=(server, out_queue, param_queue,
                                          shutdown_server))
            receiver_proc.start()

            state_dict = resnet.state_dict()
            print(type(state_dict))
            param_queue.put(state_dict)

            tensor = out_queue.get()
            net_output = resnet(tensor)
            print(net_output)

            print("Shutting down server...")
            with shutdown_server.get_lock():
                shutdown_server.value = 1
        else:
            param_queue = Queue()
            server = communication.WorkerSocket(address, port)
            print("Connected to server")

            receiver_proc = Process(target=ReceiveParams,
                                    args=(server, param_queue),
                                    daemon=True)
            receiver_proc.start()

            # for _ in range(3):
            #     SendPlayout(server)
            #     time.sleep(1)
Example #11
class CustomDataset(Dataset):
    def __init__(self,
                 dataset,
                 debug_mode=False,
                 build_fn=None,
                 num_workers=3,
                 pca=None,
                 params=None,
                 cache_path=None,
                 logger=None,
                 debug=False):
        """
        Converts the KDD Dataset files into a PyTorch-Dataset object. The Dataset object is to be
        passed to a PyTorch-DataLoader. The DataLoader then uses __getitem__ to retrieve batches
        in the form {'data': input_vector, 'label': attack_type}. Textual fields are one hot encoded.

        Args:
            dataset (DataFrame): A Pandas dataset that will be used
            debug_mode: If True only a small sample for debugging will be returned

            build_fn: A function that takes a single row as input and returns a dict with keys 'data' and 'label', both of which hold a list of numpy objects
            
        """
        self.dataset = dataset

        self.params = params
        self.preprocessed = []

        self.build_fn = build_fn

        enter_build_fn_block = True
        if cache_path is not None:
            try:
                pickle_in = open(cache_path, "rb")
            except FileNotFoundError:
                pass
            else:
                enter_build_fn_block = False
                self.preprocessed = pickle.load(pickle_in)
                pickle_in.close()

        if self.build_fn and enter_build_fn_block:
            self.num_workers = num_workers
            self.messagePipes = {
                'proc_' + str(id): Pipe()
                for id in range(self.num_workers)
            }
            self.finished = Value('i', 0)
            self._build(dataset)

            if cache_path is not None:
                pickle_out = open(cache_path, "wb")
                pickle.dump(self.preprocessed, pickle_out)
                pickle_out.close()

        # Creating the logger before the multiprocessing step
        # causes errors as spawned processes try to pickle it
        self.logger = logger
        if not self.build_fn:
            if self.logger:
                self.logger.warning(
                    'No build function specified. __getitem__() will return the raw dataset.'
                )
        else:
            if self.logger:
                self.logger.debug('Build complete with {} failure(s)'.format(
                    abs(len(self.preprocessed) - len(dataset))))

        if pca is not None: self._do_pca(pca)

    def _do_pca(self, n_components):
        msg = 'Performing PCA with {} components.'.format(n_components)
        if self.logger: self.logger.info(msg)
        else: print(msg)

        data_only = [entry['data'].tolist() for entry in self.preprocessed]
        pca = PCA(n_components=n_components)
        transformed = pca.fit_transform(data_only)
        for i, elem in enumerate(transformed):
            self.preprocessed[i]['data'] = torch.tensor(elem)

        msg = 'Retained variance: {:.4f}'.format(
            pca.explained_variance_ratio_.cumsum()[-1])
        if self.logger: self.logger.info(msg)
        else: print(msg)

    def __call__(self,
                 batch_size,
                 shuffle=False,
                 num_workers=0,
                 pin_memory=False):
        return DataLoader(self,
                          batch_size=batch_size,
                          shuffle=shuffle,
                          num_workers=num_workers,
                          pin_memory=pin_memory)

    def __len__(self):
        if self.build_fn: return len(self.preprocessed)
        else: return len(self.dataset)

    def _helper(self, dataset, sender_end, position):
        '''
          Args: 
            dataset (DataFrame) - part of the whole dataset to be preprocessed by a worker process
        '''
        every_n = 10000
        batch = []
        for count, (_, row) in enumerate(dataset.iterrows()):

            batch.append(self.build_fn(row, self.params))

            if count % every_n == 0:
                sender_end.send(batch)
                batch = []

            count += 1

        sender_end.send(batch)
        sender_end.send(-1)

        with self.finished.get_lock():
            self.finished.value += 1

    def _receiver(self, receiver_end, id):
        finished = 0
        batch_2 = []
        while True:
            batch = receiver_end.recv()
            if batch == -1:
                break

            for item in batch:

                label_tensor = torch.tensor(item['label'])
                data_tensor = torch.tensor(item['data'])

                batch_2.append({'data': data_tensor, 'label': label_tensor})

                if len(batch_2) >= 200:
                    with lock:
                        self.preprocessed.extend(batch_2)
                    batch_2 = []

        with lock:
            self.preprocessed.extend(batch_2)

    def _single_thread_build(self, dataset):
        for _, row in dataset.iterrows():
            item = self.build_fn(row, self.params)

            label_tensor = torch.tensor(item['label'])
            data_tensor = torch.tensor(item['data'])

            self.preprocessed.append({
                'data': data_tensor,
                'label': label_tensor
            })

    def _build(self, dataset):
        if self.num_workers == 1:
            self._single_thread_build(dataset)
        else:
            # Divide the dataset into equal parts and send to the worker processes
            processes = []
            threads = []
            ds_size = len(dataset)
            chunk_size = ds_size // self.num_workers

            for i in range(self.num_workers):
                if i == self.num_workers - 1:
                    ds_chunk = dataset.iloc[i * chunk_size:]
                else:
                    ds_chunk = dataset.iloc[i * chunk_size:i * chunk_size +
                                            chunk_size]
                # thread = Thread(target=self._receiver, args=(self.messagePipes['proc_' + str(i)][0],i))
                process = Process(target=self._helper,
                                  name='helper_' + str(i),
                                  args=(ds_chunk,
                                        self.messagePipes['proc_' + str(i)][1],
                                        i))
                processes.append(process)
                # thread.start()
                process.start()

            for i in range(self.num_workers):
                thread = Thread(target=self._receiver,
                                args=(self.messagePipes['proc_' + str(i)][0],
                                      i))
                threads.append(thread)
                thread.start()

            for process in processes:
                process.join()

            for thread in threads:
                thread.join()

    def output_size(self):
        return len(set([elem['label'][0].item()
                        for elem in self.preprocessed]))

    def input_size(self):
        return len(self.preprocessed[0]['data'])

    def __getitem__(self, idx):
        if self.build_fn: return self.preprocessed[idx]
        elif isinstance(self.dataset, pd.DataFrame):
            return self.dataset.iloc[idx, :]
        else:
            return self.dataset[idx]
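
A hedged usage sketch; the build_fn below only illustrates the expected contract (a dict with 'data' and 'label'), and `df` stands in for the loaded KDD DataFrame.

import numpy as np

def my_build_fn(row, params):
    # hypothetical: all but the last column are features, the last column is the label
    features = row.iloc[:-1].to_numpy(dtype=np.float32)
    return {'data': features, 'label': [int(row.iloc[-1])]}

dataset = CustomDataset(df, build_fn=my_build_fn, num_workers=3, cache_path='kdd_cache.pkl')
loader = dataset(batch_size=64, shuffle=True)  # __call__ wraps the dataset in a DataLoader
for batch in loader:
    inputs, labels = batch['data'], batch['label']
    # ... forward/backward pass ...
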
Example #12
class HogwildWorld(World):
    """Creates a separate world for each thread (process).

    Maintains a few shared objects to keep track of state:

    - A Semaphore which represents queued examples to be processed. Every call
      of parley increments this counter; every time a Process claims an
      example, it decrements this counter.

    - A Condition variable which notifies when there are no more queued
      examples.

    - A boolean Value which represents whether the inner worlds should shutdown.

    - An integer Value which contains the number of unprocessed examples queued
      (acquiring the semaphore only claims them--this counter is decremented
      once the processing is complete).
    """
    def __init__(self, world_class, opt, agents):
        super().__init__(opt)
        self.inner_world = world_class(opt, agents)

        self.queued_items = Semaphore(0)  # counts num exs to be processed
        self.epochDone = Condition()  # notifies when exs are finished
        self.terminate = Value('b', False)  # tells threads when to shut down
        self.cnt = Value('i', 0)  # number of exs that remain to be processed

        self.threads = []
        for i in range(opt['numthreads']):
            self.threads.append(
                HogwildProcess(i, world_class, opt, agents, self.queued_items,
                               self.epochDone, self.terminate, self.cnt))
        for t in self.threads:
            t.start()

    def display(self):
        self.shutdown()
        raise NotImplementedError('Hogwild does not support displaying in-run'
                                  ' task data. Use `--numthreads 1`.')

    def episode_done(self):
        return False

    def parley(self):
        """Queue one item to be processed."""
        with self.cnt.get_lock():
            self.cnt.value += 1
        self.queued_items.release()
        self.total_parleys += 1

    def getID(self):
        return self.inner_world.getID()

    def report(self, compute_time=False):
        return self.inner_world.report(compute_time)

    def save_agents(self):
        self.inner_world.save_agents()

    def synchronize(self):
        """Sync barrier: will wait until all queued examples are processed."""
        with self.epochDone:
            self.epochDone.wait_for(lambda: self.cnt.value == 0)

    def shutdown(self):
        """Set shutdown flag and wake threads up to close themselves"""
        # set shutdown flag
        with self.terminate.get_lock():
            self.terminate.value = True
        # wake up each thread by queueing fake examples
        for _ in self.threads:
            self.queued_items.release()
        # wait for threads to close
        for t in self.threads:
            t.join()
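
The docstring above implies a worker loop roughly like the following; this is a hedged reconstruction, not HogwildProcess's actual implementation.

def _hogwild_run(world_class, opt, agents, queued_items, epoch_done, terminate, cnt):
    # hypothetical body of HogwildProcess.run(), inferred from the shared objects above
    world = world_class(opt, agents)
    while True:
        queued_items.acquire()      # claim one queued example
        if terminate.value:
            break                   # shutdown() set the flag and released the semaphore
        world.parley()              # process the example
        with cnt.get_lock():
            cnt.value -= 1
            finished = (cnt.value == 0)
        if finished:
            with epoch_done:
                epoch_done.notify_all()  # wake anyone blocked in synchronize()
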
Example #13
def train(args,
          worker_id: int,
          global_model: Union[ActorNetwork, ActorCriticNetwork],
          T: Value,
          global_reward: Value,
          optimizer: torch.optim.Optimizer = None,
          global_model_critic: CriticNetwork = None,
          optimizer_critic: torch.optim.Optimizer = None,
          lr_scheduler: torch.optim.lr_scheduler = None,
          lr_scheduler_critic: torch.optim.lr_scheduler = None):
    """
    Start worker in training mode, i.e. training the shared model with backprop
    loosely based on https://github.com/ikostrikov/pytorch-a3c/blob/master/train.py
    :param args: console arguments
    :param worker_id: id of worker to differentiate them and init different seeds
    :param global_model: global model, which is optimized/ for split models: actor
    :param T: global counter of steps
    :param global_reward: global running reward value
    :param optimizer: optimizer for shared model/ for split models: actor model
    :param global_model_critic: optional global critic model for split networks
    :param optimizer_critic: optional critic optimizer for split networks
    :param lr_scheduler: optional learning rate scheduler instance for shared model
    / for fixed model: actor learning rate scheduler
    :param lr_scheduler_critic: optional learning rate scheduler instance for critic model
    :return: None
    """
    torch.manual_seed(args.seed + worker_id)

    if args.worker == 1:
        logging.info(f"Running A2C with {args.n_envs} environments.")
        if "RR" not in args.env_name:
            env = SubprocVecEnv([
                make_env(args.env_name, args.seed, i, args.log_dir)
                for i in range(args.n_envs)
            ])
        else:
            env = DummyVecEnv(
                [make_env(args.env_name, args.seed, worker_id, args.log_dir)])
    else:
        logging.info(f"Running A3C: training worker {worker_id} started.")
        env = DummyVecEnv(
            [make_env(args.env_name, args.seed, worker_id, args.log_dir)])
        # avoid any issues if this is not 1
        args.n_envs = 1

    normalizer = get_normalizer(args.normalizer, env)

    # init local NN instance for worker thread
    model = copy.deepcopy(global_model)
    model.train()

    model_critic = None

    if global_model_critic:
        model_critic = copy.deepcopy(global_model_critic)
        model_critic.train()

    # if no shared optimizer is provided use individual one
    if not optimizer:
        optimizer, optimizer_critic = get_optimizer(
            args.optimizer,
            global_model,
            args.lr,
            model_critic=global_model_critic,
            lr_critic=args.lr_critic)
        if args.lr_scheduler == "exponential":
            lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer,
                                                                  gamma=0.99)
            if optimizer_critic:
                lr_scheduler_critic = torch.optim.lr_scheduler.ExponentialLR(
                    optimizer_critic, gamma=0.99)

    state = torch.Tensor(env.reset())

    t = np.zeros(args.n_envs)
    global_iter = 0
    episode_reward = np.zeros(args.n_envs)

    if worker_id == 0:
        writer = SummaryWriter(log_dir='experiments/runs/')

    while True:
        # Get state of the global model
        model.load_state_dict(global_model.state_dict())
        if not args.shared_model:
            model_critic.load_state_dict(global_model_critic.state_dict())

        # containers for computing loss
        values = []
        log_probs = []
        rewards = []
        entropies = []
        # container to check whether a terminal state was reached from one of the envs
        terminals = []

        # reward_sum = 0
        for step in range(args.rollout_steps):
            t += 1

            if args.shared_model:
                value, mu, std = model(normalizer(state))
            else:
                mu, std = model(normalizer(state))
                value = model_critic(normalizer(state))

            dist = torch.distributions.Normal(mu, std)

            # ------------------------------------------
            # # select action
            action = dist.sample()

            # ------------------------------------------
            # Compute statistics for loss
            entropy = dist.entropy().sum(-1).unsqueeze(-1)
            log_prob = dist.log_prob(action).sum(-1).unsqueeze(-1)

            # make selected move
            action = np.clip(action.detach().numpy(), -args.max_action,
                             args.max_action)
            state, reward, dones, _ = env.step(
                action[0]
                if not args.worker == 1 or "RR" in args.env_name else action)

            reward = shape_reward(args, reward)

            episode_reward += reward

            # note: episodes that hit max_episode_length are marked done here, though truncation arguably should not count as a true terminal state
            dones = np.logical_or(dones, t >= args.max_episode_length)

            values.append(value)
            log_probs.append(log_prob)
            rewards.append(torch.Tensor(reward).unsqueeze(-1))
            entropies.append(entropy)
            terminals.append(torch.Tensor(1 - dones).unsqueeze(-1))

            for i, done in enumerate(dones):
                if done:
                    # keep track of the avg overall global reward
                    with global_reward.get_lock():
                        if global_reward.value == -np.inf:
                            global_reward.value = episode_reward[i]
                        else:
                            global_reward.value = .99 * global_reward.value + .01 * episode_reward[
                                i]
                    if worker_id == 0 and T.value % args.log_frequency == 0:
                        writer.add_scalar("reward/global", global_reward.value,
                                          T.value)

                    episode_reward[i] = 0
                    t[i] = 0
                    if args.worker != 1 or "RR" in args.env_name:
                        env.reset()

            with T.get_lock():
                # this is one for a3c and n for A2C (actually the lock is not needed for A2C)
                T.value += args.n_envs

            if lr_scheduler and worker_id == 0 and T.value % args.lr_scheduler_step and global_iter != 0:
                lr_scheduler.step(T.value / args.lr_scheduler_step)

                if lr_scheduler_critic:
                    lr_scheduler_critic.step(T.value / args.lr_scheduler_step)

            state = torch.Tensor(state)

        if args.shared_model:
            v, _, _ = model(normalizer(state))
            G = v.detach()
        else:
            G = model_critic(normalizer(state)).detach()

        values.append(G)

        # compute loss and backprop
        advantages = torch.zeros((args.n_envs, 1))

        ret = torch.zeros((args.rollout_steps, args.n_envs, 1))
        adv = torch.zeros((args.rollout_steps, args.n_envs, 1))

        # iterate over all time steps from most recent to the starting one
        for i in reversed(range(args.rollout_steps)):
            # G can be seen essentially as the return over the course of the rollout
            G = rewards[i] + args.discount * terminals[i] * G
            if not args.no_gae:
                # Generalized Advantage Estimation
                td_error = rewards[i] + args.discount * terminals[i] * values[
                    i + 1] - values[i]
                # terminals here "reset" advantages to 0, because reset is called internally in the env
                # and new trajectory started
                advantages = advantages * args.discount * args.tau * terminals[
                    i] + td_error
            else:
                advantages = G - values[i].detach()

            adv[i] = advantages.detach()
            ret[i] = G.detach()

        policy_loss = -(torch.stack(log_probs) * adv).mean()
        # minus 1 in order to remove the last element, which is only necessary for next timestep value
        value_loss = .5 * (ret - torch.stack(values[:-1])).pow(2).mean()
        entropy_loss = torch.stack(entropies).mean()

        # zero grads to reset the gradients
        optimizer.zero_grad()

        if args.shared_model:
            # combined loss for shared architecture
            total_loss = policy_loss + args.value_loss_weight * value_loss - args.entropy_loss_weight * entropy_loss
            total_loss.backward()
        else:
            optimizer_critic.zero_grad()

            value_loss.backward()
            (policy_loss - args.entropy_loss_weight * entropy_loss).backward()

            # this is just used for plotting in tensorboard
            total_loss = policy_loss + args.value_loss_weight * value_loss - args.entropy_loss_weight * entropy_loss

        torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
        sync_grads(model, global_model)
        optimizer.step()

        if not args.shared_model:
            torch.nn.utils.clip_grad_norm_(model_critic.parameters(),
                                           args.max_grad_norm)
            sync_grads(model_critic, global_model_critic)
            optimizer_critic.step()

        global_iter += 1

        if worker_id == 0 and T.value % args.log_frequency == 0:
            log_to_tensorboard(writer,
                               model,
                               optimizer,
                               rewards,
                               values,
                               total_loss,
                               policy_loss,
                               value_loss,
                               entropy_loss,
                               T.value,
                               model_critic=model_critic,
                               optimizer_critic=optimizer_critic)