Example #1
def run_workers_in_parallel(task_queue: mp.Queue, worker):
    NUMBER_OF_PROCESSES = min(int(mp.cpu_count() * 1.1), task_queue.qsize())

    # TODO: We've noticed that on certain 2-core machines, parallelizing the tests
    # makes the LLVM backend's legacy pass manager 20x slower than using a
    # single process. Need to investigate the root cause eventually. This is a
    # hack to work around that issue.
    if mp.cpu_count() == 2:
        NUMBER_OF_PROCESSES = 1

    processes = []
    for i in range(NUMBER_OF_PROCESSES):
        p = mp.get_context("fork").Process(target=worker, args=(task_queue, ))
        p.start()
        processes.append(p)
    for i in range(NUMBER_OF_PROCESSES):
        task_queue.put(queue_sentinel)
    for p in processes:
        p.join()
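The worker passed to run_workers_in_parallel is expected to drain the queue until it sees queue_sentinel, which the function enqueues once per process. A minimal sketch of such a worker, assuming the sentinel is a module-level marker (None here) and that each queued task is a callable; the real test runner defines its own task type and sentinel:

import multiprocessing as mp

queue_sentinel = None  # assumption: the real module defines its own sentinel object


def example_worker(task_queue: mp.Queue):
    # Keep pulling tasks until the sentinel pushed by run_workers_in_parallel arrives.
    while True:
        task = task_queue.get()
        if task is queue_sentinel:
            break
        task()  # hypothetical: each queued task is a callable test job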
Example #2
def evaluate_blueprints(blueprint_q: mp.Queue,
                        input_size: List[int]) -> List[BlueprintGenome]:
    """
    Consumes blueprints off the blueprints queue, evaluates them and adds them back to the queue if all of their
    evaluations have not been completed for the current generation. If all their evaluations have been completed, add
    them to the completed_blueprints list.

    :param blueprint_q: A thread safe queue of blueprints
    :param input_size: The shape of the input to each network
    :param num_epochs: the number of epochs to train each model for
    :return: A list of evaluated blueprints
    """
    completed_blueprints: List[BlueprintGenome] = []
    print(f'Process {mp.current_process().name} - epochs: {config.epochs_in_evolution}')
    while blueprint_q.qsize() != 0:
        blueprint = blueprint_q.get()
        blueprint = evaluate_blueprint(blueprint, input_size)
        if blueprint.n_evaluations == config.n_evals_per_bp:
            completed_blueprints.append(blueprint)
        else:
            blueprint_q.put(blueprint)

    return completed_blueprints
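One hypothetical way to drive this consumer from a parent process is sketched below. It assumes a Manager queue, whose proxy can be handed to pool workers and also supports qsize(); BlueprintGenome, evaluate_blueprint and config come from the surrounding project. Note that the qsize()/get() pattern above can still block if another worker drains the queue between the check and the get.

import multiprocessing as mp


def parallel_evaluate(blueprints, input_size, n_procs=4):
    # Hypothetical driver: fill a shared Manager queue, then let a pool of
    # processes run evaluate_blueprints against it and merge their results.
    manager = mp.Manager()
    blueprint_q = manager.Queue()
    for bp in blueprints:
        blueprint_q.put(bp)

    with mp.Pool(n_procs) as pool:
        chunks = pool.starmap(evaluate_blueprints,
                              [(blueprint_q, input_size)] * n_procs)
    return [bp for chunk in chunks for bp in chunk]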
Example #3
        for r in train_rewards:
            outputs.put(r)


if __name__ == '__main__':
    args, env, agent, opt = get_setup()
    num_processes = args.n_proc
    processes = []

    # Share parameters of the policy (and opt)
    agent.share_memory()

    exp = ro.Experiment(args.env + '-dev-async', params={})
    train_rewards = Queue()
    for rank in range(num_processes):
        sleep(1.0)
        p = mp.Process(target=async_update, args=(agent, opt, rank, train_rewards))
        p.start()
        processes.append(p)

    for p in processes:
        p.join()

    test_rewards = test(args, env, agent)
    data = {p: getattr(args, p) for p in vars(args)}
    data['train_rewards'] = [train_rewards.get() for _ in range(train_rewards.qsize())]
    data['test_rewards'] = test_rewards
    data['timestamp'] = time()
    exp.add_result(result=sum(test_rewards) / len(test_rewards),
                   data=data)
Example #4
def generate_kernel_parallel(kernel_cfg,
                             x,
                             y,
                             batch_size=32,
                             num_gpus=4,
                             symmetric=False,
                             model_uuid=None,
                             checkpoint_K=None,
                             checkpoint_rows_done=None,
                             cache_path="tc_cache",
                             float32=False,
                             extra_info={},
                             verbose=False,
                             use_tqdm=True):
    ''' Takes two numpy arrays x and y of shapes N x H x W x C and M x H x W x C
        and produces a kernel matrix K of shape N x M.
    '''

    #TODO fixme
    print("Batch Size ", batch_size)
    assert num_gpus <= torch.cuda.device_count()
    N = x.shape[0]
    M = y.shape[0]
    if float32:
        K = np.memmap("/dev/shm/kernel",
                      mode="w+",
                      dtype="float32",
                      shape=(N, M))
    else:
        K = np.memmap("/dev/shm/kernel",
                      mode="w+",
                      dtype="float64",
                      shape=(N, M))

    K.fill(np.inf)
    rows_done = np.memmap("/dev/shm/rowsdone",
                          mode="w+",
                          dtype="uint16",
                          shape=(1, ))
    if checkpoint_rows_done is not None:
        rows_done[:] = np.copy(utils.bytes_to_numpy(checkpoint_rows_done))
        K[:rows_done[0], :] = np.copy(utils.bytes_to_numpy(checkpoint_K))

    n = 0
    done_q = Queue()
    data_q = Queue()
    done = Value('i', 0)
    num_column_blocks = int(N / batch_size)

    x_idxs = torch.arange(x.shape[0])
    y_idxs = torch.arange(y.shape[0])

    x_data = TensorDataset(x_idxs, torch.from_numpy(x))
    x_loader = DataLoader(x_data, batch_size=batch_size)

    y_data = TensorDataset(y_idxs, torch.from_numpy(y))
    y_loader = DataLoader(y_data, batch_size=batch_size)

    processes = []

    x_data = [x for x in x_loader]
    y_data = [y for y in y_loader]
    count = 0
    start_time = time.time()
    for x_idxs, x_b in x_data:
        for y_idxs, y_b in y_data:
            count += 1
            start_x = int(min(x_idxs))
            end_x = int(max(x_idxs) + 1)
            start_y = int(min(y_idxs))
            end_y = int(max(y_idxs) + 1)
            if end_x > rows_done[0]:
                data_q.put(((x_idxs, x_b), (y_idxs, y_b)))
            #print(start_x, start_y)
            if count % 1000 == 0:
                print("Current Count Is: ", count)
    os.environ["OMP_NUM_THREADS"] = str(1)

    for gpu_idx in range(num_gpus):
        p = Process(target=_kernel_gen_help,
                    args=(done_q, data_q, kernel_cfg, batch_size, symmetric,
                          gpu_idx, K.shape, cache_path, float32, done,
                          verbose))
        processes.append(p)

    for i, p in enumerate(processes):
        p.start()
    if symmetric:
        done_work = rows_done[0] * M + (N - rows_done[0]) * (rows_done[0])
    else:
        done_work = rows_done[0] * M
    work_left = N * M - done_work
    last_checkpoint = work_left
    print("Data_q size start", data_q.qsize())
    if use_tqdm:
        pbar = tqdm(total=N * M)
    else:
        pbar = None
    total_progress = 0
    while work_left > 0:
        progress = done_q.get()
        total_progress += progress
        work_left -= progress
        elapsed = time.time() - start_time
        avg_speed = total_progress / elapsed
        time_left = utils.pretty_time_delta(work_left / avg_speed)
        if pbar is not None:
            pbar.update(progress)
        else:
            print(
                f"Work Left: {work_left}, Work done so far: {total_progress}, Time Left: {time_left}"
            )
    if pbar is not None:
        pbar.close()
    print("Data_q size end", data_q.qsize())
    done.value = 1
    for i, p in enumerate(processes):
        p.join()
    np.save("/tmp/K_train_full.npy", K)
    if symmetric:
        _symmetric_fill(K, x, y, batch_size)
    K_copy = np.zeros(K.shape)
    np.copyto(K_copy, K)
    assert np.all(np.isfinite(K_copy))
    return K_copy
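The GPU workers (_kernel_gen_help) are launched above but not shown here. A rough sketch of what such a consumer could look like, assuming it re-opens the shared memmap, pulls index/tensor block pairs from data_q, and reports progress on done_q; the actual kernel computation driven by kernel_cfg is stood in for by a placeholder matrix product:

import queue

import numpy as np
import torch


def _kernel_gen_help_sketch(done_q, data_q, kernel_cfg, batch_size, symmetric,
                            gpu_idx, K_shape, cache_path, float32, done,
                            verbose):
    # Hypothetical worker: re-open the memmap created by the parent process
    # and keep consuming (x block, y block) pairs until the parent sets `done`.
    dtype = "float32" if float32 else "float64"
    K = np.memmap("/dev/shm/kernel", mode="r+", dtype=dtype, shape=K_shape)
    device = torch.device(f"cuda:{gpu_idx}")
    while not done.value:
        try:
            (x_idxs, x_b), (y_idxs, y_b) = data_q.get(timeout=1)
        except queue.Empty:
            continue
        x_b, y_b = x_b.to(device), y_b.to(device)
        # Placeholder for the real kernel evaluation configured by kernel_cfg.
        block = (x_b.flatten(1).float() @ y_b.flatten(1).float().t()).cpu().numpy()
        K[int(x_idxs[0]):int(x_idxs[-1]) + 1,
          int(y_idxs[0]):int(y_idxs[-1]) + 1] = block
        done_q.put(block.size)  # progress = number of kernel entries just filled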
Example #5
class Synthetic(Process):
    def __init__(self, agent, dataloader, settings):
        super().__init__()

        self.agent = agent
        self.dataloader = dataloader
        self.settings = settings

        self.queue = Queue(maxsize=settings.QUEUE_LEN)
        self.put_flag = Queue(maxsize=1)
        self.get_flag = Queue(maxsize=1)
        self.done = False

    def update_settings(self, settings):
        self.settings = settings

    def update_agent(self, target_agent):
        self.agent.load_state_dict(target_agent.state_dict())

    def fetch_data(self):
        num_batch = self.settings.NUM_BATCH_WHILE_SYNTHETIC
        while self.put_flag.empty():
            out = []
            for _ in range(num_batch):
                d = self.queue.get()
                if self.queue.qsize() < num_batch:
                    self.queue.put(d)
                out.append(d)
            yield utils.cat_namedtuple_list(out, dim=0)

        # Put a signal on the flag queue
        self.get_flag.put(True)

    def run(self):
        """ Generate Data Queue
        """
        settings = self.settings
        for d in self.dataloader:
            episode_data, episode_interpolate_ratio, episode_source_pose = [], [], []

            mesh = d["mesh"].to(settings.SYNTHETIC_DEVICE)
            raw_data = utils.variable_namedtuple(d["data"],
                                                 settings.SYNTHETIC_DEVICE)

            source_pose = raw_data.init_pose
            target_pose = raw_data.target_pose
            intrinsic = raw_data.Intrinsic
            settings.set_intrinsic(intrinsic)

            for _ in range(settings.SYNTHETIC_EPISODE_LEN):
                episode_source_pose.append(source_pose)
                center_points, center_depth = utils.translation_to_voxel_and_depth(
                    source_pose.Translation.translation, intrinsic,
                    self.settings)
                try:
                    syn_data, interpolate_ratio = self.agent.synthetic(
                        observed_image=raw_data.image,
                        observed_depth=raw_data.depth,
                        observed_mask=raw_data.mask,
                        init_pose=source_pose,
                        mesh=mesh,
                        center_points=center_points,
                        center_depth=center_depth,
                        settings=settings)
                    if settings.SYNTHETIC_EPISODE_LEN > 1:
                        state_feature, mask, flow = self.agent.state_encoding(
                            syn_data)
                        action = self.agent.action_encoding(
                            state_feature, interpolate_ratio)
                        source_pose = utils.apply_action_to_pose(
                            action, source_pose, settings)
                        source_pose = utils.detach_namedtuple(source_pose)
                    episode_data.append(syn_data)
                    episode_interpolate_ratio.append(interpolate_ratio)
                except Exception as e:
                    print(e)
            if len(episode_data) != settings.SYNTHETIC_EPISODE_LEN or len(
                    episode_interpolate_ratio
            ) != settings.SYNTHETIC_EPISODE_LEN:
                # Something may be wrong while generating data
                continue
            # append data to queue
            for i in range(settings.SYNTHETIC_EPISODE_LEN):
                syn_raw_data = utils.SynRawData(
                    data=episode_data[i],
                    Intrinsic=intrinsic,
                    target_pose=target_pose,
                    init_pose=episode_source_pose[i],
                    model_points=raw_data.model_points,
                    interpolate_ratio=episode_interpolate_ratio[i])
                syn_raw_data = utils.variable_namedtuple(syn_raw_data,
                                                         device="cpu")
                self.queue.put(syn_raw_data)
        # Put a signal on the flag queue
        self.put_flag.put(True)
        # Wait for the main process to finish its last data fetch
        while self.get_flag.empty():
            time.sleep(2)
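A hypothetical parent-side view of how the Synthetic process above might be consumed; agent, dataloader and settings come from the surrounding project, and agent.update() is only a stand-in for the real training step:

synthetic = Synthetic(agent, dataloader, settings)
synthetic.start()  # run() starts filling self.queue in a child process

for batch in synthetic.fetch_data():
    # Each batch is NUM_BATCH_WHILE_SYNTHETIC queue items concatenated along dim 0.
    agent.update(batch)  # hypothetical training step in the parent process

synthetic.join()  # run() returns once fetch_data() has set get_flag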
Example #6
class DataQueue(object):
    '''Queue for data prefetching.
       DataQueue launches a subprocess to avoid Python's GIL.
       # Arguments
            generator: instance of a generator which feeds data infinitely
            max_queue_size: maximum queue size
            nb_worker: controls concurrency;
                       only takes effect during preprocessing
    '''
    def __init__(self, generator, max_queue_size=5, nb_worker=1):
        self.generator = generator
        self.nb_worker = nb_worker
        self.max_queue_size = max_queue_size

        self._queue = Queue()
        self._signal = Event()
        self._available_cv = Condition()
        self._full_cv = Condition()

        args = (generator, self._queue, self._signal, self._available_cv,
                self._full_cv, self.nb_worker, self.max_queue_size)
        self.working_process = Process(target=self.generator_process,
                                       args=args)
        self.working_process.daemon = True
        self.working_process.start()

    def get(self, timeout=None):
        with self._available_cv:
            if not self._signal.is_set() and self._queue.qsize() == 0:
                self._available_cv.wait()

        if self._signal.is_set():
            raise Exception("prefetch process terminated!")

        try:
            data = self._queue.get()
            with self._full_cv:
                self._full_cv.notify()
        except Exception as e:
            with self._full_cv:
                self._signal.set()
                self._full_cv.notify_all()
                raise e

        return data

    def qsize(self):
        return self._queue.qsize()

    def __del__(self):
        with self._full_cv:
            self._signal.set()
            self._full_cv.notify_all()
        #self.working_process.terminate()
        self.working_process.join()

    @staticmethod
    def generator_process(generator, queue, signal, available_cv, full_cv,
                          nb_worker, max_qsize):
        preprocess = generator.preprocess
        generator = BackgroundGenerator(generator())  # invoke call()

        # put data in the queue
        def enqueue_fn(generator, preprocess, queue, signal, available_cv,
                       full_cv, lock, max_qsize):
            while True:
                try:
                    with lock:
                        data = next(generator)
                    data = preprocess(data)

                    if not isinstance(data, types.GeneratorType):
                        data = [data]

                    for ele in data:
                        ele = np2tensor(ele)  # numpy array to pytorch's tensor
                        with full_cv:
                            while not signal.is_set(
                            ) and queue.qsize() >= max_qsize:
                                full_cv.wait()

                        if signal.is_set(): return

                        queue.put(ele)

                        with available_cv:
                            available_cv.notify()
                except Exception as e:
                    print("Error Message", e, file=sys.stderr)
                    with full_cv:
                        signal.set()
                        full_cv.notify_all()
                    with available_cv:
                        signal.set()
                        available_cv.notify_all()
                    raise Exception("generator thread went wrong!")

        # start threads
        lock = threading.Lock()
        args = (generator, preprocess, queue, signal, available_cv, full_cv,
                lock, max_qsize)
        generator_threads = [
            threading.Thread(target=enqueue_fn, args=args)
            for _ in range(nb_worker)
        ]

        for thread in generator_threads:
            thread.daemon = True
            thread.start()

        for thread in generator_threads:
            thread.join()
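A minimal sketch of a generator object that satisfies DataQueue's expectations: it is callable (it is invoked as generator()) and carries a preprocess() method, so the prefetch process can pull, preprocess and enqueue batches in the background. The random-array generator is purely illustrative:

import numpy as np


class RandomBatches:
    def preprocess(self, batch):
        # Augmentation / normalisation would go here.
        return batch

    def __call__(self):
        while True:  # feed data infinitely, as the docstring requires
            yield np.random.rand(32, 3, 64, 64).astype("float32")


prefetcher = DataQueue(RandomBatches(), max_queue_size=5, nb_worker=1)
batch = prefetcher.get()  # blocks until the background process has enqueued data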
Example #7
class ModelInferenceServer(Process):
    """
    Model Inference Server
    Checks for latest model
    Pull from backlog
    Process batched states
    Return values via return dict
    """
    def __init__(self,
                 model_shared_dict,
                 buffer_queue,
                 device,
                 batch_size=16,
                 max_workers=140,
                 episode_length=1000):
        super(ModelInferenceServer, self).__init__()
        self.model_shared_dict = model_shared_dict
        self.workers = []
        self.pipes = []

        # Communication channels with workers
        self.backlog = Queue()
        self.return_dict = Manager().dict()
        self.buffer_queue = buffer_queue

        self.model = None
        self.device = device
        self.batch_size = batch_size
        self.max_workers = max_workers  # 140 seems to be the optimum
        self.episode_length = episode_length

        self.shutdown = False
        self.empty_count = 0

        self.avg_queue_size = 0
        self.avg_batch_size = 0
        self.avg_total_time = 0
        self.avg_batch_pull_time = 0
        self.avg_batch_infr_time = 0
        self.avg_batch_retu_time = 0

    def update_model(self):
        print("Updated model")
        if self.model is None:
            self.model = create_model(self.model_shared_dict,
                                      device=self.device)
        else:
            self.model.load_state_dict(self.model_shared_dict)
        self.model.eval()

    def get_batch(self):
        start_batch_pull_time = time.time()
        batch = []
        return_ids = []
        try:
            self.avg_queue_size = 0.9 * self.avg_queue_size + 0.1 * self.backlog.qsize(
            )
            return_id, element = self.backlog.get(True, 60)
            return_ids.append(return_id)
            batch.append(element)
            while len(batch) < self.batch_size:
                return_id, element = self.backlog.get(True, 0.1)
                return_ids.append(return_id)
                batch.append(element)
                del return_id
                del element

        except Empty:
            self.empty_count += 1
            if len(self.workers) < self.max_workers:
                self.add_worker(1)
        except TimeoutError:
            print(
                "60 seconds without anything being put in the backlog, cleaning up"
            )
            self.cleanup()
        finally:
            self.avg_batch_size = 0.9 * self.avg_batch_size + 0.1 * len(batch)
            self.avg_batch_pull_time = 0.9 * self.avg_batch_pull_time + 0.1 * (
                time.time() - start_batch_pull_time)
            if batch:  # torch.cat would fail on an empty list
                batch = torch.cat(batch, 0)

        return return_ids, batch

    def threaded_return_batch(self, return_ids, batch: torch.Tensor):
        x = Thread(target=self.return_batch, args=(
            return_ids,
            batch,
        ))
        x.start()

    def return_batch(self, return_ids, batch: torch.Tensor):
        self.return_dict.update({
            return_id: element.cpu()
            for return_id, element in zip(return_ids, batch.unbind())
        })

    def add_worker(self, n):
        for i in range(n):
            worker = EnvWorker(self.backlog,
                               self.return_dict, self.buffer_queue,
                               len(self.workers), self.episode_length)
            worker.start()
            self.workers.append(worker)

    def cleanup(self):
        for worker in self.workers:
            del worker

    def run(self):
        """
        Eternal loop that runs the inference
        """
        try:
            print("Model Run Server started!")
            self.add_worker(self.max_workers)
            while not self.shutdown:
                for _ in range(100):
                    # Refresh the model once per 1000-iteration inner loop below
                    self.update_model()
                    for _ in range(1000):
                        total_start_time = time.time()
                        return_ids, batch = self.get_batch()
                        start_inf_time = time.time()
                        with torch.no_grad():
                            batch = batch.to(self.device)
                            out = self.model.forward(batch)
                        del batch
                        self.avg_batch_infr_time = 0.9 * self.avg_batch_infr_time + 0.1 * (
                            time.time() - start_inf_time)
                        start_return_time = time.time()
                        self.threaded_return_batch(return_ids, out)
                        self.avg_batch_retu_time = 0.9 * self.avg_batch_retu_time + 0.1 * (
                            time.time() - start_return_time)
                        self.avg_total_time = 0.9 * self.avg_total_time + 0.1 * (
                            time.time() - total_start_time)
                print(f"Averages with {len(self.workers)}\n"
                      f"queue size: {self.avg_queue_size}\n"
                      f"batch size: {self.avg_batch_size}\n"
                      f"pull  time: {self.avg_batch_pull_time}\n"
                      f"infer time: {self.avg_batch_infr_time}\n"
                      f"return time: {self.avg_batch_retu_time}\n"
                      f"total time: {self.avg_total_time}")
        except Exception as e:
            traceback.print_exc()
            print(f"{e} in model run server")
        finally:
            for worker in self.workers:
                worker.shutdown = True
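EnvWorker is referenced by add_worker() but not shown. Below is a heavily hedged sketch of what such a worker might look like, based only on the constructor call above and the backlog / return_dict / buffer_queue protocol the server implements; make_env() and the action selection are placeholders:

import time

import torch
from torch.multiprocessing import Process


class EnvWorker(Process):
    def __init__(self, backlog, return_dict, buffer_queue, worker_id,
                 episode_length):
        super().__init__()
        self.backlog = backlog
        self.return_dict = return_dict
        self.buffer_queue = buffer_queue
        self.worker_id = worker_id
        self.episode_length = episode_length
        self.shutdown = False  # mirrors the flag the server flips in its finally block

    def run(self):
        env = make_env()  # hypothetical environment factory
        state = env.reset()
        for _ in range(self.episode_length):
            # Ask the inference server for a prediction: one row, keyed by worker id.
            self.backlog.put((self.worker_id, torch.as_tensor(state)[None]))
            while self.worker_id not in self.return_dict:
                time.sleep(0.001)
            prediction = self.return_dict.pop(self.worker_id)
            state, reward, done, _ = env.step(int(prediction.argmax()))
            self.buffer_queue.put((state, reward, done))
            if done:
                state = env.reset()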
Example #8
def controller_train_proc(ctrl_dir,
                          controller,
                          vae,
                          mdrnn,
                          target_return=950,
                          skip_train=False,
                          display=True):
    step_log('4-2. controller_train_proc START!!')
    # define current best and load parameters
    cur_best = None
    if not os.path.exists(ctrl_dir):
        os.mkdir(ctrl_dir)
    ctrl_file = os.path.join(ctrl_dir, 'best.tar')

    p_queue = Queue()
    r_queue = Queue()
    #e_queue = Queue()   # pipaek : not necessary if not multiprocessing

    print("Attempting to load previous best...")
    if os.path.exists(ctrl_file):
        #state = torch.load(ctrl_file, map_location={'cuda:0': 'cpu'})
        state = torch.load(ctrl_file)
        cur_best = -state['reward']
        controller.load_state_dict(state['state_dict'])
        print("Previous best was {}...".format(-cur_best))

    if skip_train:
        return  # pipaek: for when we want to skip improving the model through training

    def evaluate(solutions,
                 results,
                 rollouts=100):  # pipaek: rollouts 100 -> 10, originally 100
        """ Give current controller evaluation.

        Evaluation is minus the cumulated reward averaged over rollout runs.

        :args solutions: CMA set of solutions
        :args results: corresponding results
        :args rollouts: number of rollouts

        :returns: minus averaged cumulated reward
        """
        index_min = np.argmin(results)
        best_guess = solutions[index_min]
        restimates = []

        for s_id in range(rollouts):
            print('p_queue.put(), s_id=%d' % s_id)
            p_queue.put((s_id, best_guess))
            print('>>>rollout_routine!!')
            rollout_routine()  # pipaek: here too, process it immediately right after p_queue.put

        print(">>>Evaluating...")
        for _ in tqdm(range(rollouts)):
            #while r_queue.empty():
            #    sleep(.1)   # pipaek: not needed since this is not multi-process
            if not r_queue.empty(
            ):  # pipaek: 2018-07-18 check to avoid getting stuck in r_queue.get()!!
                #print('r_queue.get()')
                #restimates.append(r_queue.get()[1])
                r_s_id, r = r_queue.get()
                print(
                    'in evaluate r_queue.get() r_s_id=%d, r_queue remain=%d' %
                    (r_s_id, r_queue.qsize()))
                restimates.append(r)
            else:
                print('r_queue.empty() -> break!!')
                break

        return best_guess, np.mean(restimates), np.std(restimates)

    def rollout_routine():
        """ Thread routine.

        Threads interact with p_queue, the parameters queue, r_queue, the result
        queue and e_queue the end queue. They pull parameters from p_queue, execute
        the corresponding rollout, then place the result in r_queue.

        Each parameter has its own unique id. Parameters are pulled as tuples
        (s_id, params) and results are pushed as (s_id, result).  The same
        parameter can appear multiple times in p_queue, displaying the same id
        each time.

        As soon as e_queue is non-empty, the thread terminates.

        When multiple gpus are involved, the assigned gpu is determined by the
        process index p_index (gpu = p_index % n_gpus).

        :args p_queue: queue containing couples (s_id, parameters) to evaluate
        :args r_queue: where to place results (s_id, results)
        :args e_queue: as soon as not empty, terminate
        :args p_index: the process index
        """
        # init routine
        #gpu = p_index % torch.cuda.device_count()
        #device = torch.device('cuda:{}'.format(gpu) if torch.cuda.is_available() else 'cpu')

        # redirect streams
        #if not os.path.exists(tmp_dir):
        #    os.mkdir(tmp_dir)

        #sys.stdout = open(os.path.join(tmp_dir, 'rollout.out'), 'a')
        #sys.stderr = open(os.path.join(tmp_dir, 'rollout.err'), 'a')

        with torch.no_grad():
            r_gen = RolloutGenerator(vae, mdrnn, controller, device,
                                     rollout_time_limit)

            while not p_queue.empty():
                print('in rollout_routine, p_queue.get()')
                s_id, params = p_queue.get()
                print('r_queue.put() sid=%d' % s_id)
                r_queue.put((s_id, r_gen.rollout(params)))
                print('r_gen.rollout OK, r_queue.put()')
                #r_queue.qsize()

    parameters = controller.parameters()
    es = cma.CMAEvolutionStrategy(flatten_parameters(parameters), 0.1,
                                  {'popsize': C_POP_SIZE})
    print("CMAEvolutionStrategy start OK!!")

    epoch = 0
    log_step = 3
    while not es.stop():
        print("--------------------------------------")
        print("CURRENT EPOCH = %d" % epoch)
        if cur_best is not None and -cur_best > target_return:
            print("Already better than target, breaking...")
            break

        r_list = [0] * C_POP_SIZE  # result list
        solutions = es.ask()
        print("CMAEvolutionStrategy-ask")

        # push parameters to queue
        for s_id, s in enumerate(
                solutions):  # pipaek: this loop repeats C_POP_SIZE times
            #for _ in range(C_POP_SIZE * C_N_SAMPLES):
            for _ in range(C_N_SAMPLES):
                print('in controller_train_proc p_queue.put() s_id : %d' %
                      s_id)
                p_queue.put((s_id, s))
                #print("p_queue.put %d" % s_id)
                rollout_routine(
                )  # pipaek: right after p_queue.put, get it immediately, run the rollout, then put the result into r_queue.
                print("rollout_routine OK, r_queue size=%d" % r_queue.qsize())

        # retrieve results
        if display:
            pbar = tqdm(total=C_POP_SIZE * C_N_SAMPLES)
        #for idx in range(C_POP_SIZE * C_N_SAMPLES):
        while not r_queue.empty(
        ):  # pipaek: 2018-07-18 changed the for loop to a while loop to avoid hanging forever when r_queue.get is impossible.
            #while r_queue.empty():
            #    sleep(.1)
            try:
                r_s_id, r = r_queue.get()
                print(
                    'in controller_train_proc r_queue.get() r_s_id=%d, r_queue remain=%d'
                    % (r_s_id, r_queue.qsize()))
                r_list[r_s_id] += r / C_N_SAMPLES
                if display:
                    pbar.update(1)
            except IndexError as err:
                print('IndexError during r_queue.get()')
                print('cur r_list size:%d, index:%d' % (len(r_list), r_s_id))
        if display:
            pbar.close()

        es.tell(solutions,
                r_list)  # pipaek: update the solution array with the r_list results.
        es.disp()

        # evaluation and saving
        if epoch % log_step == log_step - 1:
            print(">>>> TRYING EVALUATION, CURRENT EPOCH = %d" % epoch)
            best_params, best, std_best = evaluate(
                solutions, r_list, rollouts=100
            )  # pipaek: do only 10 rollouts for evaluation.. originally 100
            print("Current evaluation: {}".format(best))
            if not cur_best or cur_best > best:
                cur_best = best
                print("Saving new best with value {}+-{}...".format(
                    -cur_best, std_best))
                load_parameters(best_params, controller)
                torch.save(
                    {
                        'epoch': epoch,
                        'reward': -cur_best,
                        'state_dict': controller.state_dict()
                    }, os.path.join(ctrl_dir, 'best.tar'))
            if -best > target_return:
                print(
                    "Terminating controller training with value {}...".format(
                        best))
                break

        epoch += 1

    print("es.stop!!")
    es.result_pretty()
Example #9
class VariationalAutoEncoder(nn.Module):
    '''
    Implementation of a Variational AutoEncoder in PyTorch. Currently two
    encoder/decoder units are supported: the first is a two-layer
    dense neural network and the second a deep convolutional net.
    The number of latent units can be specified and is usually around 4-12.
    The implementation supports CUDA. If CUDA is not used, the multiprocessing
    framework can be used to send the computation to the background, so the
    Jupyter notebook it runs in will not be blocked.
    '''
    def __init__(self, n_latent_units, drop_ratio, convolutional=False):
        '''
        Constructor
        :param n_latent_units:
        :param drop_ratio:
        '''
        super(VariationalAutoEncoder, self).__init__()
        self.encoder = Encoder.Encoder(n_latent_units, drop_ratio) if not convolutional \
            else ConvEncoder.Encoder(n_latent_units, drop_ratio)
        self.decoder = Decoder.Decoder(n_latent_units, drop_ratio) if not convolutional \
            else ConvDecoder.Decoder(n_latent_units, drop_ratio)
        self.proc = None

        self.counter_epoch = Counter()
        self.counter_interation = Counter()
        self.loss_queue = Queue()
        self.stop_signal = Signal()
        self.losses = []

    def forward(self, x):
        '''
        The forward method; calls the encoder and then the decoder.
        :param x:
        :return:
        '''
        z, mu, log_std = self.encoder.forward(x)
        self.mu = mu
        self.log_std = log_std
        return self.decoder.forward(z)

    def loss(self, _in, _out, mu, log_std):
        '''
        The loss function: the loss is the reconstruction error plus the error
        given by the deviation of the latent variables from the normal distribution.
        :param _in:
        :param _out:
        :param mu:
        :param log_std:
        :return:
        '''
        # img_loss = self.img_loss_func(_in, _out)
        # img_loss = F.mse_loss(_in, _out)
        img_loss = _in.sub(_out).pow(2).sum()
        mean_sq = mu * mu
        # -0.5 * tf.reduce_sum(1.0 + 2.0 * logsd - tf.square(mn) - tf.exp(2.0 * logsd), 1)
        latent_loss = -0.5 * torch.sum(1.0 + 2.0 * log_std - mean_sq -
                                       torch.exp(2.0 * log_std))
        return img_loss + latent_loss, img_loss, latent_loss

    def start(self, train=None):
        '''
        Runs the training in the background. Currently this only works with the CPU
        version (CUDA is not supported at the moment).
        :param train:
        :return:
        '''
        if self.proc is not None:
            raise Exception("Process already started.")
        self.share_memory()
        self.losses = []
        if train is None:
            train = VariationalAutoEncoder._get_training_test_method()
        self.proc = mp.Process(target=train,
                               args=(self, self.train_loader, self.test_loader,
                                     self.counter_epoch,
                                     self.counter_interation, self.loss_queue,
                                     self.stop_signal))
        self.proc.start()

    def restart(self, train=None):
        '''
        Background training can be stopped; this method should be used to resume
        the computation. As with start(), it currently does not work with CUDA.
        :param train:
        :return:
        '''
        if self.proc is None:
            raise Exception("Process has not been started before.")
        if self.proc.is_alive():
            raise Exception("Process is still active.")
        self.stop_signal.set_signal(False)
        if train is None:
            train = VariationalAutoEncoder._get_training_test_method()
        self.proc = mp.Process(target=train,
                               args=(self, self.train_loader, self.test_loader,
                                     self.counter_epoch,
                                     self.counter_interation, self.loss_queue,
                                     self.stop_signal))
        self.proc.start()

    def stop(self):
        '''
        This function sends a stop signal to the background process.
        :return:
        '''
        if self.proc is None:
            raise Exception("Process has been started.")
        if not self.proc.is_alive():
            raise Exception("Process is not alive.")
        self.stop_signal.set_signal(True)
        self.proc.join()
        self.stop_signal.set_signal(False)

    def get_progress(self):
        '''
        Gets the progress of the computation running in the background.
        :return:
        '''
        while self.loss_queue.qsize() > 0:
            self.losses.append(self.loss_queue.get())
        return self.losses

    def set_train_loader(self, train_loader, test_loader=None):
        self.train_loader = train_loader
        self.test_loader = test_loader

    def cuda(self):
        super(VariationalAutoEncoder, self).cuda()
        self.decoder.cuda()
        self.encoder.cuda()

    @staticmethod
    def _get_training_test_method():
        def train(model, train_loader, test_loader, counter_epoch,
                  counter_iterations, loss_queue, stop_signal):
            print("started", stop_signal.value)
            train_op = optim.Adam(model.parameters(), lr=0.0005)
            while not stop_signal.value:
                loss_train = []
                loss_test = []
                n_train = []
                n_test = []
                for _, data in enumerate(train_loader):
                    # data = Variable(data.view(-1,784))
                    data = Variable(data)
                    train_op.zero_grad()
                    dec = model(data)
                    loss, loss_1, loss_2 = model.loss(dec, data, model.mu,
                                                      model.log_std)
                    loss_train.append(
                        (loss.data[0], loss_1.data[0], loss_2.data[0]))
                    n_train.append(len(data))
                    loss.backward()
                    train_op.step()
                    counter_iterations.increment()

                for _, data in enumerate(test_loader):
                    # data = Variable(data.view(-1,784))
                    data = Variable(data)
                    dec = model(data)
                    loss, _, _ = model.loss(dec, data, model.mu, model.log_std)
                    loss_test.append(loss.data[0])
                    n_test.append(len(data))

                counter_epoch.increment()

                epoch = counter_epoch.value
                loss_train_mean = numpy.mean(loss_train,
                                             axis=0)  # / numpy.sum(n_train)
                loss_test_mean = numpy.mean(loss_test)  # / numpy.sum(n_test)
                loss_queue.put((epoch, loss_train_mean, loss_test_mean))
                #print("{}: ".format(epoch),  loss_train_mean, loss_test_mean)

        return train

    @staticmethod
    def get_MNIST_train_loader(batch_size=32, keep_classes=False):
        train_loader = torch.utils.data.DataLoader(datasets.MNIST(
            './data/datasets/MNIST',
            train=True,
            download=True,
            transform=transforms.Compose([transforms.ToTensor()])),
                                                   batch_size=batch_size)

        test_loader = torch.utils.data.DataLoader(datasets.MNIST(
            './data/datasets/MNIST',
            train=False,
            transform=transforms.Compose([transforms.ToTensor()])),
                                                  batch_size=batch_size)

        if keep_classes:
            return train_loader, test_loader
        return DataIterator(train_loader), DataIterator(test_loader)

    @staticmethod
    def get_FashionMNIST_train_loader(batch_size=32, keep_classes=False):
        train_loader = torch.utils.data.DataLoader(datasets.FashionMNIST(
            './data/datasets/FMNIST',
            train=True,
            download=True,
            transform=transforms.Compose([transforms.ToTensor()])),
                                                   batch_size=batch_size)

        test_loader = torch.utils.data.DataLoader(datasets.FashionMNIST(
            './data/datasets/FMNIST',
            train=False,
            transform=transforms.Compose([transforms.ToTensor()])),
                                                  batch_size=batch_size)

        if keep_classes:
            return train_loader, test_loader
        return DataIterator(train_loader), DataIterator(test_loader)
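A hypothetical notebook-style driver for the background-training workflow defined above (CPU only, as the class docstring notes); the hyperparameters here are arbitrary:

import time

vae = VariationalAutoEncoder(n_latent_units=8, drop_ratio=0.2)
train_loader, test_loader = VariationalAutoEncoder.get_MNIST_train_loader(batch_size=32)
vae.set_train_loader(train_loader, test_loader)

vae.start()                     # training runs in a separate process
time.sleep(60)                  # the notebook stays responsive in the meantime
print(vae.get_progress()[-5:])  # most recent (epoch, train_loss, test_loss) tuples
vae.stop()                      # signal the background process and join it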