Пример #1
0
    def find_gate(self):
        """Search for the gate until it is detected or the time budget runs out.

        Reads ``self.config['search']``: ``max_time_sec`` (time budget),
        ``mode`` (``"mvg_avg"`` or ``"simple"``) and, for ``"mvg_avg"``,
        ``number_of_samples``.

        :return: True if the gate was detected, False otherwise.
        """
        self._logger.log("finding the gate")
        config = self.config['search']
        MAX_TIME_SEC = config['max_time_sec']
        MODE = config['mode']

        if MODE not in ("mvg_avg", "simple"):
            # An unknown mode would otherwise spin in the loop below, doing
            # nothing until the time budget ran out.
            self._logger.log("unknown search mode: " + str(MODE))
            return False

        stopwatch = Stopwatch()
        stopwatch.start()
        self._logger.log("started find gate loop")

        while stopwatch.time() < MAX_TIME_SEC:
            if MODE == "mvg_avg":
                # Several samples per round so is_this_gate() can build confidence.
                for _ in range(config['number_of_samples']):
                    if self.is_this_gate():
                        return True
            else:  # "simple": a single detector call decides the result
                predictions = self.darknet_client.predict()
                # predict() may return no detections; indexing [0] would raise.
                bbox = predictions[0].normalize(480, 480) if predictions else None
                if not bbox:
                    self._logger.log("gate not found")
                    return False
                self._logger.log("gate found")
                return True
            # NOTE(review): rotation between sampling rounds is disabled; it was
            # self.movements.rotate_angle(0, 0, config['rotation_angle']).

        self._logger.log("gate not found")
        return False
Пример #2
0
 def __init__(self):
     """Initialise pygame, create the player agent and a timestep stopwatch.

     The environment, screen and FPS clock start out as None.
     """
     pygame.init()
     self.agent = PlayerAgent()
     self.env = self.screen = self.fps_clock = None
     self.timestep_watch = Stopwatch()
Пример #3
0
    def find_gate(self):
        """Rotate in place and look for the gate until it is found or time runs out.

        Reads ``max_ang_speed`` and ``max_time_sec`` from
        ``self.config['search']``. The angular velocity is reset to zero
        before returning, including when an exception propagates.

        :return: True if ``is_this_gate()`` reported the gate, False on timeout.
        """
        self._logger.log("finding the gate")
        config = self.config['search']
        MAX_ANG_SPEED = config['max_ang_speed']
        MAX_TIME_SEC = config['max_time_sec']

        self._control.set_ang_velocity(0, 0, MAX_ANG_SPEED)

        stopwatch = Stopwatch()
        stopwatch.start()
        self._logger.log("started to find gate loop")

        # Diagnostic dump of the first detector result.
        result = self.darknet_client.predict()
        if not result:
            self._logger.log("Empty bounding box")
        else:
            self._logger.log("Num of bb: " + str(len(result)))
            self._logger.log(str(result[0]))

        try:
            while stopwatch.time() < MAX_TIME_SEC:
                # NOTE(review): the previous code passed an undefined name `img`;
                # this assumes is_this_gate() takes no arguments, as in the other
                # find_gate variant -- confirm.
                if self.is_this_gate():
                    # TODO: localize the gate (x, y, angle).
                    # TODO: localize the obstacle (x, y) and plan a path that
                    #       avoids it, e.g. self.create_path(gate, obstacle)
                    #       using is_this_obstacle(bounding_box, img).
                    return True
        finally:
            # Stop the search rotation started above, whatever the outcome.
            self._control.set_ang_velocity(0, 0, 0)
        return False
Пример #4
0
 def center_above_bucket(self):
     '''
     Center the vehicle above a bucket using the bottom camera.

     Waits up to ``self.MAX_TIME_SEC`` for a first detection. Then it applies a
     proportional velocity command (gain ``Kp``) derived from the bounding-box
     position. This continues until both |x| and |y| are within
     ``self.POSITION_THRESHOLD``, giving up after 1000 control steps. Once
     the control loop has started, the linear velocity is zeroed before
     returning.

     :return: 1 if centered, 0 if the bucket was not found or centering failed.
     '''
     def predict_bbox():
         # First detection normalized to a 480x480 frame, or None when the
         # detector returns nothing (indexing [0] would raise IndexError).
         predictions = self.darknet_client.predict()
         if not predictions:
             return None
         return predictions[0].normalize(480, 480)

     self.darknet_client.change_camera("bottom")
     stopwatch = Stopwatch()
     stopwatch.start()
     bbox = predict_bbox()
     while bbox is None and stopwatch.time() < self.MAX_TIME_SEC:
         bbox = predict_bbox()
         sleep(0.3)
     if bbox is None:
         self._logger.log("Could not locate bucket")
         return 0
     position_x = bbox.x
     position_y = bbox.y
     Kp = 0.001  # proportional gain
     for _ in range(1000):
         # Centered only when BOTH offsets are small. abs() lets offsets on
         # either side of the centre be corrected (assumes signed offsets --
         # harmless if they are not).
         if (abs(position_x) <= self.POSITION_THRESHOLD
                 and abs(position_y) <= self.POSITION_THRESHOLD):
             break
         self._control.set_lin_velocity(front=position_y * Kp,
                                        right=position_x * Kp)
         bbox = predict_bbox()
         if bbox is not None:
             position_x = bbox.x
             position_y = bbox.y
     else:
         self._control.set_lin_velocity(front=0, right=0)
         self._logger.log("Could not center above bucket")
         return 0
     self._control.set_lin_velocity(front=0, right=0)
     return 1
Пример #5
0
    def center_on_flare(self):
        """
        Rotate about the vertical axis until the flare is in the middle of the image.

        The last computed flare location is stored in ``self.flare_position``.
        TODO: handle two flares.

        :return: True if the angle to the flare dropped to at most
                 ``max_center_angle_deg`` within ``max_time_sec``, else False.
        """
        config = self.config['centering']
        flare_size = get_config(
            "objects_size")["localization"]["flare"]["height"]

        MAX_CENTER_ANGLE_DEG = config['max_center_angle_deg']
        MAX_TIME_SEC = config['max_time_sec']

        stopwatch = Stopwatch()
        stopwatch.start()

        # Compare the elapsed time, not the Stopwatch object itself.
        while stopwatch.time() <= MAX_TIME_SEC:
            predictions = self.darknet_client.predict()
            if not predictions:
                # Nothing detected in this frame; indexing [0] would raise.
                continue
            bbox = predictions[0].normalize(480, 480)
            self.flare_position = location_calculator(bbox, flare_size,
                                                      "height")
            # Signed horizontal angle (degrees) from atan2(x, distance),
            # negated before being used as the rotation command.
            angle = -m.degrees(
                m.atan2(self.flare_position['x'],
                        self.flare_position['distance']))
            if abs(angle) <= MAX_CENTER_ANGLE_DEG:
                self._logger.log("centered on flare successfully")
                return True
            self._control.rotate_angle(0, 0, angle)
        self._logger.log("couldn't center on flare")
        return False
Пример #6
0
 def __init__(self):
     """Register this instance as ``Context.logger`` and reset timing and histories."""
     Context.logger = self
     config = Context.config
     self._episodes = config['exp.episodes']
     self._work_path = os.path.join(Context.work_path, WORK_DIR)
     self._sw = Stopwatch()
     self._saved_time = 0.
     self._train_history, self._eval_history = [], []
Пример #7
0
    def run(self):
        """Execute the bucket task: find the buckets, then grab the marker.

        First waits (up to ``self.MAX_TIME_SEC``) until ``find_buckets()``
        succeeds. Then, depending on ``self.bucket`` ('pinger', 'blue' or
        'red'), it repeatedly calls the matching finder and grabs the marker
        on the first success.

        :return: 1 if the marker was grabbed, 0 otherwise (including an
                 unknown ``self.bucket`` value).
        """
        self._logger.log("Bucket grab task exector started")

        self.darknet_client.load_model("bucket")
        # TODO: switch to the 'buckets_task' model once it is ready:
        # self.darknet_client.load_model('buckets_task')
        stopwatch = Stopwatch()
        stopwatch.start()

        # This loop should find at least the first bounding box with a bucket.
        while not self.find_buckets():
            self._logger.log("Finding buckets in progress")
            if stopwatch.time() >= self.MAX_TIME_SEC:
                self._logger.log("Finding buckets time expired")
                return 0

        def grab_with_retries(find_bucket, attempts, found_msg, failed_msg):
            # Call find_bucket() up to `attempts` times; grab the marker on
            # the first success. Returns 1 on success, 0 otherwise.
            for _ in range(attempts):
                if find_bucket():
                    self._logger.log(found_msg)
                    self.grab_marker()
                    self._logger.log("Marker grabbed")
                    return 1
            self._logger.log(failed_msg)
            return 0

        if self.bucket == 'pinger':
            # This loop should find the bucket with the pinger.
            return grab_with_retries(self.find_pinger_bucket,
                                     self.PINGER_LOOP_COUNTER,
                                     "Found pinger bucket",
                                     "Marker could not be grabbed")
        if self.bucket == 'blue':
            return grab_with_retries(self.find_blue_bucket,
                                     self.BLUE_LOOP_COUNTER,
                                     "Found blue bucket",
                                     "Marker could not be grabbed")
        if self.bucket == 'red':
            # NOTE(review): 'red' uses find_random_bucket(); confirm intended.
            return grab_with_retries(self.find_random_bucket,
                                     self.ANY_BUCKET_COUNTER,
                                     "Found random bucket",
                                     "Grabbing marker failed")
        self._logger.log("Unknown bucket type: " + str(self.bucket))
        return 0
Пример #8
0
 def worker(proc_args, out_q, i_proc):
     """
     The worker function: this function is assigned to different processes.

     Puts ``multiply_matrices(arg)`` on ``out_q`` for every ``arg`` in
     ``proc_args`` and prints per-task timing. ``i_proc`` is accepted but unused.
     """
     # print(...) with a single parenthesised argument behaves identically on
     # Python 2 (print statement) and Python 3 (print function).
     pid = os.getpid()
     n_tasks = len(proc_args)
     print("Process %i has been assigned %i subtask(s)" % (pid, n_tasks))
     for arg_idx, arg in enumerate(proc_args):
         t = Stopwatch()
         out_q.put(multiply_matrices(arg))
         elapsed = t.finish(milli=False)
         print("Process %i needed %d seconds for task %i/%i"
               % (pid, elapsed, arg_idx + 1, n_tasks))
Пример #9
0
    def center_on_gate(self):
        """Move the vehicle until the gate's bounding box is close to the image centre.

        Reads ``max_time_sec`` and ``max_center_distance`` from
        ``self.config['centering']`` and calls ``center_rov`` with each new
        detection.

        :return: True if centered within the time budget, False otherwise.
        """
        config = self.config['centering']
        MAX_TIME_SEC = config['max_time_sec']
        MAX_CENTER_DISTANCE = config['max_center_distance']

        stopwatch = Stopwatch()
        stopwatch.start()

        while stopwatch.time() <= MAX_TIME_SEC:
            predictions = self.darknet_client.predict()
            if not predictions:
                # Nothing detected in this frame; indexing [0] would raise.
                continue
            bbox = predictions[0].normalize(480, 480)
            # `and`, not `&`: `&` binds tighter than `<=`, which turned the
            # test into a chained comparison on `MAX_CENTER_DISTANCE & bbox.y`.
            # abs() accepts offsets on either side of the centre (assumes
            # signed offsets -- harmless if they are not).
            if (abs(bbox.x) <= MAX_CENTER_DISTANCE
                    and abs(bbox.y) <= MAX_CENTER_DISTANCE):
                self._logger.log("centered on gate successfully")
                return True
            center_rov(move=self._control,
                       Bbox=bbox,
                       depth_sensor=self.depth_sensor)
        self._logger.log("couldn't center on gate")
        return False
Пример #10
0
def index():
    """Render the index page; on failure log the error and return its message instead."""
    stopwatch = Stopwatch()
    try:
        page = render_template("index.html")
    except Exception as err:
        error_handler.log_error(err)
        page = str(err)
    finally:
        # Always record the command timing, success or failure.
        di.service_stats.log_command(name="index",
                                     elapsed_ms=stopwatch.elapsed_ms)
    return page
Пример #11
0
    def run(self):
        """Run the gate task: dive, search for the gate, then go through it.

        A failing step is only logged; all steps always run in order. If the
        gate search fails after at least ``self.config['search']['max_time_sec']``
        has elapsed, the search rotation is stopped.
        """
        # NOTE: this config value currently covers only the gate search itself;
        # it should be adjusted for the whole task.
        MAX_TIME_SEC = self.config['search']['max_time_sec']
        self._logger.log("Gate task executor started")
        self._control.pid_turn_on()

        stopwatch = Stopwatch()
        stopwatch.start()

        if not self.dive():
            self._logger.log("GateTE: diving in progress")

        if not self.find_gate():
            self._logger.log("GateTE: finding the gate in progress")
            # The timeout is checked after the search has had its time; checked
            # immediately after start() it could never fire.
            if stopwatch.time() >= MAX_TIME_SEC:
                self._logger.log("TIME EXPIRED GATE NOT FOUND")
                self._control.set_ang_velocity(0, 0, 0)

        if not self.go_trough_gate():
            self._logger.log("GateTE: going through the gate")
Пример #12
0
def ping() -> Response:
    """Report the service state as plain text: "Up", "LoadingData" or "Down"."""
    stopwatch = Stopwatch()
    try:
        body = di.service_state.state.name
    except Exception as err:
        error_handler.log_error(err)
        body = str(err)
    finally:
        # Always record the command timing, success or failure.
        di.service_stats.log_command(name="ping",
                                     elapsed_ms=stopwatch.elapsed_ms)
    return Response(response=body, mimetype="text/plain")
Пример #13
0
    def find_flare(self):
        """Look for the flare, rotating by ``rotation_angle`` between sampling rounds.

        TODO: handle detection of two flares.

        :return: True as soon as ``is_this_flare()`` reports the flare,
                 False once ``max_time_sec`` has elapsed.
        """
        self._logger.log("finding the flare")
        config = self.config['search']

        stopwatch = Stopwatch()
        stopwatch.start()
        self._logger.log("started find flare loop")

        # Compare the elapsed time, not the Stopwatch object itself.
        while stopwatch.time() < config['max_time_sec']:
            # Check several times to be sure. Per the original author, the
            # confidence only adds up after repeated detections -- with
            # moving_avg_discount=0.9 it must detect 10 times in a row.
            for _ in range(config['number_of_samples']):
                if self.is_this_flare():
                    return True
            self._control.rotate_angle(0, 0, config['rotation_angle'])

        self._logger.log("flare not found")
        return False
Пример #14
0
def get_stats() -> Response:
    """Build the service statistics and return them as a JSON response."""
    errors = []
    stopwatch = Stopwatch()
    try:
        payload = command_writer.get_stats()
    except Exception as err:
        errors.append(err)
        error_handler.log_error(err)
        payload = command_writer.error(errors, "getstats")
    finally:
        # Always record the command timing, success or failure.
        di.service_stats.log_command(name="getstats",
                                     elapsed_ms=stopwatch.elapsed_ms)
    return Response(payload, mimetype="application/json")
Пример #15
0
    def go_to_flare(self):
        """
        Move the distance to the flare plus a little more, to knock it.

        :return: True - if the move finished within ``max_time_sec``
                 False - otherwise
        """
        config = self.config['go']
        MAX_TIME_SEC = config['max_time_sec']

        stopwatch = Stopwatch()
        stopwatch.start()

        self._control.move_distance(
            self.flare_position['distance'] + config['distance_to_add_m'],
            0, 0)

        # Compare the elapsed time, not the Stopwatch object itself.
        if stopwatch.time() <= MAX_TIME_SEC:
            self._logger.log("go_to_flare - traveled whole distance")
            return True
        self._logger.log("go_to_flare - didn't travel whole distance")
        return False
Пример #16
0
    def center_on_pinger(self):
        """
        Rotate about the vertical axis until the pinger signal is in front.

        :return: True if the hydrophone angle dropped below
                 ``max_center_angle_deg`` within ``max_time_sec``; False on
                 timeout or when the hydrophones report no angle (None).
        """
        config = self.config['centering']
        MAX_TIME_SEC = config['max_time_sec']
        MAX_CENTER_ANGLE_DEG = config['max_center_angle_deg']

        self._logger.log("centering on pinger")
        stopwatch = Stopwatch()
        stopwatch.start()

        # Compare the elapsed time, not the Stopwatch object itself.
        while stopwatch.time() < MAX_TIME_SEC:
            angle = self._hydrophones.get_angle()
            if angle is None:
                self._logger.log(
                    "no signal from hydrophones - locating pinger failed")
                return False
            if abs(angle) < MAX_CENTER_ANGLE_DEG:
                self._logger.log("centered on pinger successfully")
                return True
            self._control.rotate_angle(0, 0, angle)
        self._logger.log("couldn't ceneter on pinger")
        return False
Пример #17
0
def main():
    """Solve ``hard_table`` with A* (heuristic ``h1``), replay the path and print stats."""
    puzzle = Puzzle(hard_table)
    timer = Stopwatch()
    timer.tic()
    print("Searching A*")
    solution = AStar(puzzle, h1).search()
    timer.toc()

    for state in solution.path:
        sleep(0.25)  # pause between states of the replay
        print_puzzle_table(state)

    print("Depth =", solution.depth)
    print("Cost =", solution.cost)
    print("Time =", timer)
    return 0
Пример #18
0
def main():
    """Solve a container-assignment CSP by backtracking; print timing and the result."""
    containers, capacity = 6, 6

    def labels(prefix, count):
        # e.g. labels('t', 2) -> ['t0', 't1']
        return [prefix + str(idx) for idx in range(count)]

    wastes = labels('t', 4)
    explosive = labels('e', 4)
    frozen = labels('fz', 5)
    fresh = labels('fs', 6)
    edibles = labels('f', 6)

    stopwatch = Stopwatch()
    stopwatch.tic()
    container = ContainerAssignment(containers, capacity, wastes, edibles,
                                    explosive, frozen, fresh)
    container.backtrack_search()
    stopwatch.toc()
    print("N = {};\ttime={}".format(4, stopwatch))
    print(container)
    return 0
Пример #19
0
def main():
    """Solve ``hardest_table``; on success fill it with the solution and print it.

    Raises if the solver assigns a value that contradicts a pre-filled cell.
    """
    table = hardest_table

    sudoku = Sudoku(table)
    stopwatch = Stopwatch()
    stopwatch.tic()
    res = sudoku.search()
    stopwatch.toc()
    print("Solved={}\ttime={}".format(res.success, stopwatch))

    if not res.success:
        return 0

    for coordinates, number in res.assignment.items():
        row, col = coordinates[0], coordinates[1]
        current = table[row][col]
        if current == 0:
            table[row][col] = number
        elif current != number:
            # A pre-filled cell got a different value: the search is broken.
            raise Exception("Algoritmo di ricerca sminchiato")

    sudoku_pretty_print(table)
    return 0
Пример #20
0
    def is_assignment_consistent(self, var, value):
        """Return whether ``value`` may be assigned to ``var``.

        ``value`` must not already appear in ``self.assignment``, and
        ``__check_diagonals`` must accept the (var, value) pair.
        """
        if value in self.assignment:
            return False
        return self.__check_diagonals((var, value))

    def add_assignment(self, var, value):
        """Store ``value`` as the assignment of variable ``var``."""
        assignment = self.assignment
        assignment[var] = value

    def remove_assignment(self, var):
        """Mark variable ``var`` as unassigned (stored as -1)."""
        unassigned = -1
        self.assignment[var] = unassigned

    def __check_diagonals(self, house):
        """Return False if square ``house`` = (row, col) shares a diagonal with an assigned one.

        ``self.assignment[r]`` holds the value assigned to row ``r``
        (presumably the queen's column), or -1 when unassigned. Two squares
        lie on a common diagonal iff their row distance equals their column
        distance. Distance 0 (the identical square) is rejected as well.
        Distances are not bounded by ``self.size // 2``: diagonals span the
        whole board.
        """
        row, col = house
        for other_row, other_col in enumerate(self.assignment):
            if other_col == -1:
                continue
            if abs(row - other_row) == abs(col - other_col):
                return False
        return True


# Solve the n-queens CSP by backtracking and report timing and the board.
n_queens = 60

stopwatch = Stopwatch()
stopwatch.tic()
queen = QueenCSP(n_queens)
result = queen.backtrack_search()
stopwatch.toc()
# Report the board size actually solved.
print("N = {};\ttime={}".format(n_queens, stopwatch))
if result.success:
    print("result: {}".format(result.assignment))
    print(make_board(result.assignment))
Пример #21
0
class Logger:
    """Track and report training progress (episodes, rewards, elapsed time).

    On construction the instance registers itself as ``Context.logger``.
    ``time_spent`` is the time restored from a saved state plus the elapsed
    time of an internal Stopwatch, which ``on_start()`` starts. State is
    persisted with ``pickle`` to ``<work_path>/state.pickle``.
    """

    def __init__(self):
        Context.logger = self
        self._episodes = Context.config['exp.episodes']
        self._work_path = os.path.join(Context.work_path, WORK_DIR)
        self._sw = Stopwatch()
        self._saved_time = 0.
        # Rows of [episode, time_spent, noise_rate, reward, q_max].
        self._train_history = []
        # Rows of [episode, time_spent, reward].
        self._eval_history = []

    def __str__(self):
        return "%s:\n\t%s\n\t%s\n\t%s" % (
            self.__class__.__name__,
            "saved_time: %s" % hms(self._saved_time),
            "train_history: %d" % len(self._train_history),
            "eval_history: %d" % len(self._eval_history),
        )

    def on_start(self):
        """Start the session stopwatch."""
        self._sw.start()

    def on_train_episode(self, e, n, r, q):
        """Record a training episode and run the periodic hooks.

        Arguments are episode index, noise rate, reward and max Q value.
        Counting ``e + 1``, every ``report.write_every_episodes`` episodes the
        train report is written. Every ``report.summary_every_episodes`` a
        summary is logged, and every ``exp.save_every_episodes`` the state is
        saved.
        """
        t = self.time_spent
        self._train_history.append([e, t, n, r, q])
        self._log_training_episode(e, n, r, q)
        self._update_training_title(self._episodes, e, n, r, q)

        if (e + 1) % Context.config['report.write_every_episodes'] == 0:
            self._write_train_report()

        if (e + 1) % Context.config['report.summary_every_episodes'] == 0:
            self._log_summary(e, self._episodes, self.time_spent)

        if (e + 1) % Context.config['exp.save_every_episodes'] == 0:
            self._save_state()

    @staticmethod
    def on_evaluiation_start():
        """Show the evaluation marker in the window title."""
        Context.window_title['episode'] = "| Evaluating..."

    def on_evaluiation_end(self, e, r):
        """Record an evaluation result and switch the window title back to training."""
        self._log_evaluiation_end(r)
        self._eval_history.append([e, self.time_spent, r])  # e,t,r
        Context.window_title['episode'] = "| Trainig..."

    @property
    def episode(self):
        """Number of recorded training episodes."""
        return len(self._train_history)

    @property
    def time_left(self):
        """Linear extrapolation of the remaining time (same units as time_spent), 0 before any progress."""
        return int((self.time_spent * (1 - self.progress) /
                    self.progress) if self.progress != 0 else 0)

    @property
    def progress(self):
        """Fraction of the configured episodes that have been recorded."""
        return self.episode / float(self._episodes)

    @property
    def time_spent(self):
        """Restored time plus the stopwatch's elapsed time."""
        return self._saved_time + self._sw.time_elapsed

    @property
    def _state_path(self):
        return os.path.join(self._work_path, "state.pickle")

    def _save_state(self):
        # Binary mode: pickle writes bytes (text mode fails on Python 3).
        with open(make_dir_if_not_exists(self._state_path), 'wb') as f:
            pickle.dump(
                [self.time_spent, self._train_history, self._eval_history],
                f,
                protocol=pickle.HIGHEST_PROTOCOL)

    def restore(self):
        """Load saved time and histories from the state file.

        Only load state files written by this program: unpickling untrusted
        data can execute arbitrary code.
        """
        with open(self._state_path, 'rb') as f:
            [self._saved_time, self._train_history,
             self._eval_history] = pickle.load(f)

    def try_restore(self):
        """Restore state if possible; return False if the file is missing, truncated or corrupt."""
        try:
            self.restore()
        except (ValueError, IOError, EOFError, pickle.UnpicklingError):
            return False
        return True

    def _write_train_report(self):
        info = {
            'ep': len(self._train_history),
            'eps': self._episodes,
            'spent': self.time_spent,
            'left': self.time_left,
            'progress': self.progress,
            'train_history': self._train_history,
            'eval_history': self._eval_history,
        }
        report = TrainReport(Context.config, info)
        report.make()
        report.save()

    def _log_evaluiation_end(self, r):
        pass

    def _log_summary(self, ep, eps, spent):
        line = '========================================================================='
        self.log(line + ('\n' * 3))
        self._log_summary_time(ep, eps, spent)
        self.log(('\n' * 3) + line)

    @staticmethod
    def log(text, end='\n'):
        print(text, end=end)

    # noinspection PyStringFormat
    def _log_summary_time(self, ep, eps, spent):
        progress = ep / float(eps)
        left = (spent * (1 - progress) / progress) if progress != 0 else 0
        # Guard against a zero elapsed time (e.g. right after start).
        rate = ep / float(spent) if spent else 0.
        self.log("Time spent:  %s | " % hms(spent), end='')
        self.log("%.0f%% " % (progress * 100, ), end='')
        self.log("+ %s = %s, " % (hms(left), hms(spent + left)), end='')
        self.log("%.1f per sec" % (rate, ))

    @staticmethod
    def _update_training_title(eps, e, n, r, q):
        Context.window_title['episode'] = "| %d%% N=%.2f R=%+.0f Q=%+.0f" % (
            e / float(eps) * 100, n, r, q)

    def _log_training_episode(self, e, n, r, q):
        self.log("Ep: %3d  |  NR: %.2f  |  Reward: %+7.0f  |  Qmax: %+8.1f" %
                 (e, n, r, q))
Пример #22
0
async def _eval(ctx, *, code: str):
    """Developer-only command: run ``code`` as the body of a coroutine and report the result.

    The code is wrapped in ``async def func():`` and executed with bot and
    context helpers plus this module's globals in scope. Captured stdout, the
    return value (stored in ``bot._last_result``), its type and the elapsed
    time are sent to the channel. When a single message cannot be sent, the
    text is split with ``paginate``. The invoking message then gets a
    reaction reflecting the outcome.
    """
    if not dev_check(ctx.author.id):
        return await ctx.send(
            f"Sorry, but you can't run this command because you ain't a developer! {bot.get_emoji(555121740465045516)}"
        )
    env = {
        "bot": bot,
        "ctx": ctx,
        "channel": ctx.channel,
        "author": ctx.author,
        "guild": ctx.guild,
        "message": ctx.message,
        "msg": ctx.message,
        "_": bot._last_result,
        "source": inspect.getsource,
        "src": inspect.getsource,
        "session": bot.session,
        "docs": lambda x: print(x.__doc__)
    }

    env.update(globals())
    body = cleanup_code(code)
    stdout = io.StringIO()
    err = out = None
    to_compile = f"async def func():\n{textwrap.indent(body, '  ')}"
    stopwatch = Stopwatch().start()

    async def send_pages(text, template):
        # Send `text` page by page; return the message holding the last page
        # (None if paginate() yields nothing).
        last = None
        for page in paginate(text):
            last = await ctx.send(template.format(page), edit=False)
        return last

    # SECURITY: exec() runs arbitrary code with the bot's privileges;
    # dev_check() above is the only safeguard.
    try:
        exec(to_compile, env)
    except Exception as e:
        stopwatch.stop()
        err = await ctx.send(
            f"**Error**```py\n{e.__class__.__name__}: {e}\n```\n**Type**```ts\n{Type(e)}```\n⏱ {stopwatch}"
        )
        return await ctx.message.add_reaction(
            bot.get_emoji(522530579627900938))

    func = env["func"]
    stopwatch.restart()
    try:
        with redirect_stdout(stdout):
            ret = await func()
            stopwatch.stop()
    except Exception as e:
        stopwatch.stop()
        value = stdout.getvalue()
        # Report the type of the raised exception `e` (`err` is still None here).
        err = await ctx.send(
            f"**Error**```py\n{value}{traceback.format_exc()}\n```\n**Type**```ts\n{Type(e)}```\n⏱ {stopwatch}"
        )
    else:
        value = stdout.getvalue()
        if ret is None:
            if value:
                try:
                    out = await ctx.send(
                        f"**Output**```py\n{value}```\n⏱ {stopwatch}")
                except Exception:
                    out = await send_pages(value, "```py\n{}\n```")
                    await ctx.send(f"⏱ {stopwatch}", edit=False)
        else:
            bot._last_result = ret
            try:
                out = await ctx.send(
                    f"**Output**```py\n{value}{ret}```\n**Type**```ts\n{Type(ret)}```\n⏱ {stopwatch}"
                )
            except Exception:
                out = await send_pages(f"{value}{ret}", "```py\n{}```")
                await ctx.send(f"**Type**```ts\n{Type(ret)}```\n⏱ {stopwatch}",
                               edit=False)

    # React for both the success and the runtime-error paths.
    if out:
        await ctx.message.add_reaction(bot.get_emoji(522530578860605442))
    elif err:
        await ctx.message.add_reaction(bot.get_emoji(522530579627900938))
    else:
        await ctx.message.add_reaction("\u2705")
Пример #23
0
def word_split() -> Response:
    """Perform the word-split operation; return a JSON response with metadata or plain text.

    Query parameters: ``input`` (items separated by ``,`` or ``|``),
    ``passdisplay``, ``exhaustive``, ``verbosity``, ``output`` (``json`` or
    ``text``) and ``cache``. At most 1000 inputs are processed. Longer items
    are passed through ``remove`` with the configured maximum length.
    """
    errors: List[Exception] = []
    output = "json"
    sw = Stopwatch()
    try:
        # parse params
        inputs: List[str] = (request.args.get("input")
                             or "").replace("|", ",").split(",")
        pass_display = int(request.args.get("passdisplay") or "5")
        exhaustive = (request.args.get("exhaustive") or "0") == "1"
        verbosity = VerbosityLevel(int(request.args.get("verbosity") or "0"))
        output = (request.args.get("output") or "json").lower()
        cache = (request.args.get("cache") or "1") == "1"

        # parse restrictions
        if not exhaustive:
            max_input_chars = config.default_max_input_chars
            max_terms = config.default_max_terms
            max_passes = config.default_max_passes
        else:
            max_input_chars = config.exhaustive_max_input_chars
            max_terms = config.exhaustive_max_terms
            max_passes = config.exhaustive_max_passes

        # limit input: cap the count first so dropped items are not processed
        del inputs[1000:]
        inputs = [
            remove(s, max_input_chars) if len(s) > max_input_chars else s
            for s in inputs
        ]

        # perform splits
        use_simple = (verbosity < VerbosityLevel.High) and (not exhaustive)
        results = []
        for s in inputs:
            if use_simple:
                result = di.word_splitter.simple_split(s, cache, max_terms,
                                                       max_passes, errors)
            else:
                result = di.word_splitter.full_split(s, cache, pass_display,
                                                     max_terms, max_passes,
                                                     errors)
            results.append(result)

        # write response
        if output == "json":
            response = command_writer.word_split(verbosity, inputs,
                                                 pass_display, exhaustive,
                                                 results, sw.elapsed_ms,
                                                 errors)
        elif output == "text":
            response = "".join(r.output + "\n" for r in results)
        else:
            response = ""

    except Exception as ex:
        errors.append(ex)
        error_handler.log_error(ex)
        response = command_writer.error(errors, "wordsplit")
        # The error payload is the same one get_stats() serves as JSON, so
        # label it as JSON even when text output was requested.
        output = "json"

    finally:
        di.service_stats.log_command(name="wordsplit",
                                     elapsed_ms=sw.elapsed_ms)

    return Response(
        response,
        mimetype="application/json" if output == "json" else "text/plain")
Пример #24
0
def gl_FGD_primal(x0: np.ndarray, A: np.ndarray, b: np.ndarray, mu_0,
                  opts: dict):
    """Solve a smoothed group-LASSO problem by accelerated gradient descent.

    The target objective (``real_obj_func``) is
    ``0.5 * ||A @ x - b||_F^2 + mu_0 * sum_i ||x[i, :]||_2``.
    The row-norm term is replaced by the smooth surrogate
    ``sqrt(||x[i, :]||^2 + delta^2) - delta``. The problem is then solved
    with continuation over ``mu`` in ``[100 * mu_0, 10 * mu_0, mu_0]``. Each
    ``mu`` stage runs up to ``opts["maxit"]`` accelerated (FISTA-style)
    iterations and is warm-started from the previous stage's iterate.

    Args:
        x0: initial point; copied, not modified. Must be at least 2-D (row
            norms are taken along axis 1); presumably shape (n, l) so that
            ``A @ x0`` matches ``b`` -- TODO confirm.
        A, b: data of the least-squares term.
        mu_0: final (target) regularisation weight.
        opts: overrides for the entries of ``default_opts``.

    Returns:
        ``(x, k, out)``: the final iterate, the total number of iterations
        over all stages, and a dict with ``"tt"`` (elapsed seconds),
        ``"fval"`` (true objective at ``x``), ``"f_hist"`` and
        ``"f_hist_best"``.

    NOTE(review): if ``opts["maxit"] <= 0`` no iteration runs and ``x`` is
    undefined at the end (NameError).
    """
    default_opts = {
        "maxit": 1500,  # maximum number of iterations per mu stage
        "thres": 1e-3,  # entries with |x| below this are set to 0
        "step_type": "line_search",  # step-size rule (see set_step below)
        "alpha0": 1e-3,  # initial step size
        # Stopping tolerance. The original note says "change of the best
        # objective value so far"; NOTE(review): the loop below actually tests
        # the relative change of consecutive objective and sparsity values.
        "ftol": 1e-6,
        "stable_len_threshold": 70,  # consecutive stable iterations that end a stage
        "line_search_attenuation_coeffi": 0.98,  # step shrink factor in the line search
        "maxit_line_search_iter": 5,  # max shrink steps per line search
        "delta": 1e-6  # smoothing parameter of the row-norm term
    }
    # The second dictionary's values overwrite those from the first.
    opts = {**default_opts, **opts}
    # Fraction of entries whose magnitude exceeds 1e-6 times the largest magnitude.
    sparsity_func = lambda x: np.sum(np.abs(x) > 1e-6 * np.max(np.abs(x))
                                     ) / x.size

    def real_obj_func(x: np.ndarray):
        # Un-smoothed objective with the target weight mu_0.
        fro_term = 0.5 * np.sum((A @ x - b)**2)
        regular_term = np.sum(LA.norm(x, axis=1).reshape(-1, 1))
        return fro_term + mu_0 * regular_term

    # NOTE(review): this dict is never read -- `out` is rebuilt before returning.
    out = {
        "fvec": None,  # LASSO objective value at each iteration
        "grad_hist": None,  # history of the gradient norm of the smooth part
        "f_hist": None,  # history of objective values
        "f_hist_best": None,  # best objective value so far at each iteration
        "tt": None,  # running time
        "flag": None  # convergence flag
    }

    maxit, ftol, alpha0 = opts["maxit"], opts["ftol"], opts["alpha0"]
    stable_len_threshold = opts["stable_len_threshold"]
    thres = opts["thres"]
    step_type = opts['step_type']
    aten_coeffi = opts['line_search_attenuation_coeffi']
    max_line_search_iter = opts['maxit_line_search_iter']
    delta = opts["delta"]

    logger.debug("alpha0= {:10E}".format(alpha0))
    # Per-iteration histories across all stages. sparsity_hist feeds the
    # stopping test; v_hist and t_hist are collected but not returned.
    f_hist, f_hist_best, sparsity_hist = [], [], []
    v_hist, t_hist = [], []
    f_best = np.inf

    x_k = np.copy(x0)

    stopwatch = Stopwatch()
    stopwatch.start()
    k = 0  # global iteration counter (not reset between mu stages)
    # Continuation: solve with decreasing mu, warm-starting from x_k.
    for mu in [100 * mu_0, 10 * mu_0, mu_0]:
        logger.debug("new mu= {:10E}".format(mu))

        # min f(x) = g(x) + h(x)
        # g(x) = 0.5 * |Ax-b|_F^2 + mu * smoothed |x|_{1,2}
        # h(x) = 0

        def g_func(x: np.ndarray):
            fro_term = 0.5 * np.sum((A @ x - b)**2)
            regular_term = np.sum(
                np.sqrt(np.sum(x**2, axis=1).reshape(-1, 1) + delta * delta) -
                delta)
            return fro_term + mu * regular_term

        def grad_g_func(x: np.ndarray):
            fro_term_grad = A.T @ (A @ x - b)
            # Gradient of sum_i sqrt(||x_i||^2 + delta^2), row by row.
            regular_term_grad = x / np.sqrt(
                np.sum(x**2, axis=1).reshape(-1, 1) + delta * delta)
            return fro_term_grad + mu * regular_term_grad

        v_k = np.copy(x_k)
        t_k = alpha0

        def prox_th(x: np.ndarray, t):
            """ Proximal operator of t * mu * h(x).

            Since h(x) = 0 here, this is the identity map.
            """
            return x

        inner_iter = 0

        def set_step(step_type: str):
            """Return the step size for the current inner iteration.

            Reads the enclosing loop's current ``inner_iter``, ``t_k``, ``y``
            and ``grad_g_y`` (closure variables, resolved at call time).
            Diminishing rules only start decaying after 1000 inner
            iterations (``iter_hat`` is 1 until then). 'line_search'
            backtracks from ``t_k``, multiplying by ``aten_coeffi`` at most
            ``max_line_search_iter`` times, until the sufficient-decrease
            test holds. An unsupported type logs an error and returns None.
            """
            iter_hat = max(inner_iter, 1000) - 999
            if step_type == 'fixed':
                return alpha0
            elif step_type == 'diminishing':
                return alpha0 / np.sqrt(iter_hat)
            elif step_type == 'diminishing2':
                return alpha0 / iter_hat
            elif step_type == 'line_search':

                t = t_k
                g_y = g_func(y)

                # Sufficient decrease:
                # g(x) <= g(y) + <grad g(y), x - y> + ||x - y||^2 / (2 t)
                def stop_condition(t):
                    x = prox_th(y - t * grad_g_y, t)
                    g_x = g_func(x)
                    return g_x <= g_y + np.sum(grad_g_y * (x - y)) + np.sum(
                        (x - y)**2) / (2 * t)

                for i in range(max_line_search_iter):
                    if stop_condition(t):
                        break
                    t *= aten_coeffi
                return t

            else:
                logger.error("Unsupported type.")

        stable_len = 0

        while inner_iter < maxit:
            # Record current objective value
            f_now = real_obj_func(x_k)
            f_hist.append(f_now)

            f_best = min(f_best, f_now)
            f_hist_best.append(f_best)

            sparsity_hist.append(sparsity_func(x_k))

            v_hist.append(v_k)

            t_hist.append(t_k)

            k += 1
            inner_iter += 1

            # A stage ends after more than stable_len_threshold consecutive
            # iterations in which both the objective and the sparsity change
            # by a relative amount below ftol. NOTE(review): a zero previous
            # value makes the relative change inf/nan (numpy warning), which
            # counts as "not stable".
            if (k > 1
                    and abs(f_hist[k - 1] - f_hist[k - 2]) / abs(f_hist[k - 2])
                    < ftol
                    and abs(sparsity_hist[k - 1] - sparsity_hist[k - 2]) /
                    abs(sparsity_hist[k - 2]) < ftol):
                stable_len += 1
            else:
                stable_len = 0
            if stable_len > stable_len_threshold:
                break

            # Hard-threshold tiny entries (in place) before the next step.
            x_k[np.abs(x_k) < thres] = 0

            # Accelerated step: extrapolate y between x_k and v_k, take a
            # gradient step from y, then update the auxiliary sequence v.
            theta = 2 / (inner_iter + 1)
            y = (1 - theta) * x_k + theta * v_k
            grad_g_y = grad_g_func(y)

            t = set_step(step_type)
            x = prox_th(y - t * grad_g_y, t)
            v = x_k + (x - x_k) / theta

            x_k, v_k, t_k = x, v, t

            if k % 100 == 0:
                logger.debug(
                    'iter= {:5}, objective= {:10E}, sparsity= {:3f}'.format(
                        k, f_now.item(), sparsity_func(x)))

    # Microseconds -> seconds.
    elapsed_time = stopwatch.elapsed(
        time_format=Stopwatch.TimeFormat.kMicroSecond) / 1e6
    out = {
        "tt": elapsed_time,
        "fval": real_obj_func(x),
        "f_hist": f_hist,
        "f_hist_best": f_hist_best
    }

    return x, k, out
Пример #25
0
def load_gene_expression_profile(gene_list_file_name,
                                 gene_expression_file_name,
                                 gene_filter_file_name=None,
                                 gene_list_path=None,
                                 gene_expression_path=None,
                                 gene_filter_path=None,
                                 by_gene=False,
                                 list_mode="FROM_DISK"):
    """Load the expression rows of a set of genes from a whitespace-separated file.

    Args:
        gene_list_file_name: gene list file name passed to ``load_gene_list``,
            or -- when ``list_mode == "ON_THE_FLY"`` -- the gene list itself.
        gene_expression_file_name: expression file name, resolved under
            ``constants.TCGA_DATA_DIR`` when ``gene_expression_path`` is None.
        gene_filter_file_name: optional gene list file; when given, only genes
            also present in it are kept.
        gene_list_path: forwarded to ``load_gene_list`` for the gene list.
        gene_expression_path: explicit path of the expression file; takes
            precedence over ``gene_expression_file_name``.
        gene_filter_path: forwarded to ``load_gene_list`` for the filter list.
        by_gene: if True, return the filtered rows as a list of token lists;
            otherwise return their transpose as a numpy array.
        list_mode: "ON_THE_FLY" to use ``gene_list_file_name`` directly; any
            other value loads the list from disk.

    Returns:
        The first (header) line of the expression file plus every line whose
        leading gene id -- with anything after the first "." removed -- is in
        the gene list, each split on whitespace.
    """
    stopwatch = Stopwatch()
    stopwatch.start()
    if list_mode == "ON_THE_FLY":
        gene_list = gene_list_file_name
    else:
        gene_list = load_gene_list(gene_list_file_name=gene_list_file_name,
                                   gene_list_path=gene_list_path)
    # Keep only the part before the first "." of every id.
    gene_list = [gene.split(".")[0] for gene in gene_list]
    print(stopwatch.stop("done loading gene list"))

    if gene_filter_file_name:
        stopwatch.start()
        filter_gene_set = set(load_gene_list(
            gene_list_file_name=gene_filter_file_name,
            gene_list_path=gene_filter_path))
        gene_list = [cur for cur in gene_list if cur in filter_gene_set]
        print(stopwatch.stop("done filter gene list"))

    if gene_expression_path is None:
        gene_expression_path = os.path.join(constants.TCGA_DATA_DIR,
                                            gene_expression_file_name)
    # Restart unconditionally so this timing is also correct when the caller
    # supplies gene_expression_path explicitly.
    stopwatch.start()
    # Set lookup: every line of the (potentially large) file is tested.
    gene_set = set(gene_list)
    with open(gene_expression_path, 'r') as f:
        expression_profiles_filtered = [
            l.strip().split() for i, l in enumerate(f)
            if i == 0 or l[:l.strip().find('\t')].split(".")[0] in gene_set
        ]
    print(stopwatch.stop("done filter gene expression"))

    if not by_gene:
        stopwatch.start()
        # Same values as np.flip(np.rot90(rows, k=1, axes=(1, 0)), 1),
        # i.e. a plain transpose.
        expression_profiles_filtered = np.transpose(
            expression_profiles_filtered)
        print(stopwatch.stop("done rotate gene expression"))

    return expression_profiles_filtered
Пример #26
0
def gl_Alm_dual(x0: np.ndarray, A: np.ndarray, b: np.ndarray, mu, opts: dict):
    """Solve group-LASSO with an augmented Lagrangian method on the dual problem.

    The objective reported in the histories is
        0.5 * ||A x - b||_F^2 + mu * sum_i ||x_i||_2   (x_i = i-th row of x).

    Args:
        x0: initial iterate; copied, never modified.
        A, b: problem data; ``A @ x`` must have the shape of ``b``.
        mu: regularization weight.
        opts: overrides for ``default_opts`` (the caller's values win).

    Returns:
        (x, k, out): the final iterate, the number of iterations run, and a
        dict with the elapsed seconds ("tt"), the final objective ("fval"),
        and the per-iteration objective ("f_hist") and best-so-far
        ("f_hist_best") histories.
    """
    default_opts = {
        "maxit": 100,  # maximum number of iterations
        "thres": 1e-3,  # tolerance on the primal/dual residual norms
        "tau": (1 + math.sqrt(5)) * 0.5,  # step factor of the x-update
        "rho": 1e2,  # penalty parameter
        "converge_len": 20,  # consecutive small-residual iterations to stop
    }

    # The second dictionary's values overwrite those from the first.
    opts = {**default_opts, **opts}

    def sparsity_func(x: np.ndarray):
        # Fraction of entries larger than 1e-6 times the largest magnitude.
        return np.sum(np.abs(x) > 1e-6 * np.max(np.abs(x))) / x.size

    def real_obj_func(x: np.ndarray):
        fro_term = 0.5 * np.sum((A @ x - b)**2)
        regular_term = np.sum(LA.norm(x, axis=1).reshape(-1, 1))
        return fro_term + mu * regular_term

    maxit, thres = opts["maxit"], opts["thres"]
    rho, tau = opts['rho'], opts['tau']
    converge_len = opts['converge_len']

    f_hist, f_hist_best = [], []
    f_best = np.inf

    x_k = np.copy(x0)
    u_k = np.zeros_like(x_k)

    stopwatch = Stopwatch()
    stopwatch.start()

    # Factor (I + rho * A A^T) once and reuse it for every z-update.
    # NOTE(review): the two solves below assume LA.cholesky returns the lower
    # factor (numpy.linalg convention) -- confirm what LA is bound to.
    L = LA.cholesky(np.identity(A.shape[0]) + rho * A @ A.T)

    k = 0
    length = 0  # consecutive iterations with both residuals below thres

    while k < maxit:
        k += 1

        u = solve_sub_problem(b, rho, A, x_k, L, mu)
        z = LA.solve(L.T, LA.solve(L, A @ (x_k - rho * u) - b))

        r_k = u + A.T @ z  # primal feasibility residual
        s_k = A @ (u_k - u)  # dual feasibility residual
        x = x_k - tau * rho * r_k

        u_k, x_k = u, x
        f_now = real_obj_func(x_k)
        f_hist.append(f_now)

        f_best = min(f_best, f_now)
        f_hist_best.append(f_best)

        logger.debug(
            'iter= {:5}, objective= {:10E}, sparsity= {:3f}'.format(
                k, f_now.item(), sparsity_func(x_k).item()))

        # For 2-D residuals, ord=2 is the spectral norm (largest singular value).
        r_k_norm = LA.norm(r_k, ord=2)
        s_k_norm = LA.norm(s_k, ord=2)
        if r_k_norm < thres and s_k_norm < thres:
            length += 1
        else:
            length = 0

        if length >= converge_len:
            break

    elapsed_time = stopwatch.elapsed(
        time_format=Stopwatch.TimeFormat.kMicroSecond) / 1e6
    out = {
        "tt": elapsed_time,
        "fval": real_obj_func(x_k),
        "f_hist": f_hist,
        "f_hist_best": f_hist_best
    }

    return x_k, k, out
Пример #27
0
def gl_Admm_primal(x0: np.ndarray, A: np.ndarray, b: np.ndarray, mu,
                   opts: dict):
    """Solve group-LASSO with a linearized ADMM on the primal problem.

    Objective: 0.5 * ||A x - b||_F^2 + mu * sum_i ||x_i||_2 (x_i = i-th row).
    Each iteration solves a linear system for y, takes one proximal step on
    x, and updates the multiplier z. The loop stops after ``converge_len``
    consecutive iterations in which both ||x - y|| and ||y - y_prev|| are
    below ``thres``, or after ``maxit`` iterations.

    Args:
        x0: initial iterate; copied, never modified.
        A, b: problem data; ``A @ x`` must have the shape of ``b``.
        mu: regularization weight.
        opts: overrides for ``default_opts`` (the caller's values win).

    Returns:
        (x, k, out): the final iterate, the number of iterations run, and a
        dict with elapsed seconds ("tt"), the final objective ("fval"), and
        the per-iteration objective ("f_hist") and best-so-far
        ("f_hist_best") histories.

    Raises:
        ValueError: if ``opts["step_type"]`` is not "fixed", "diminishing" or
            "diminishing2".
    """
    default_opts = {
        "maxit": 100,  # maximum number of iterations
        "thres": 1e-3,  # residual tolerance; also guards the prox denominator
        "tau": (1 + math.sqrt(5)) * 0.5,  # step factor of the z-update
        "rho": 1e-2,  # penalty parameter
        "eta_0": 100,  # base step size of the x-update
        "converge_len": 10,  # consecutive small-residual iterations to stop
        "converge_thres": 1e-5,  # accepted for compatibility; currently unused
        "step_type": "fixed",  # "fixed", "diminishing" or "diminishing2"
    }

    # The second dictionary's values overwrite those from the first.
    opts = {**default_opts, **opts}

    maxit, thres = opts["maxit"], opts["thres"]
    rho, tau, eta_0 = opts['rho'], opts['tau'], opts['eta_0']
    converge_len = opts['converge_len']
    step_type = opts['step_type']
    # Fail fast: an unknown type used to yield a None step and a TypeError.
    if step_type not in ('fixed', 'diminishing', 'diminishing2'):
        raise ValueError("Unsupported step_type: {!r}".format(step_type))

    def sparsity_func(x: np.ndarray):
        # Fraction of entries larger than 1e-6 times the largest magnitude.
        return np.sum(np.abs(x) > 1e-6 * np.max(np.abs(x))) / x.size

    def real_obj_func(x: np.ndarray):
        fro_term = 0.5 * np.sum((A @ x - b)**2)
        regular_term = np.sum(LA.norm(x, axis=1).reshape(-1, 1))
        return fro_term + mu * regular_term

    def prox_tf(x: np.ndarray, t):
        """Row-wise shrinkage: proximal operator of t * mu * sum_i ||x_i||_2."""
        row_norms = LA.norm(x, axis=1).reshape(-1, 1)
        # Rows with norm < thres get 1 added to the denominator (avoids 0/0).
        return x * np.clip(row_norms - t * mu, a_min=0, a_max=None) / (
            (row_norms < thres) + row_norms)

    def set_step(k):
        """Step size for iteration k (k >= 1)."""
        if step_type == 'fixed':
            return eta_0
        if step_type == 'diminishing':
            return eta_0 / np.sqrt(k)
        return eta_0 / k  # 'diminishing2'

    f_hist, f_hist_best = [], []
    f_best = np.inf

    x_k = np.copy(x0)
    y_k = x_k
    z_k = x_k

    stopwatch = Stopwatch()
    stopwatch.start()

    # Factor (rho * I + A^T A) once; every y-update solves with it.
    # NOTE(review): assumes LA.cholesky returns the lower factor -- confirm.
    L = LA.cholesky(rho * np.identity(A.shape[1]) + A.T @ A)
    AT_b = A.T @ b

    k = 0
    length = 0  # consecutive iterations with both residuals below thres

    while k < maxit:
        k += 1
        eta = set_step(k)
        # y solves (rho * I + A^T A) y = A^T b - z_k + rho * x_k.
        y = LA.solve(L.T, LA.solve(L, AT_b - z_k + rho * x_k))
        # Linearized proximal step on x; the exact minimization would be
        # prox_tf(y + z_k / rho, 1 / rho).
        x = prox_tf(x_k - eta * rho * (x_k - y - z_k / rho), eta)
        z = z_k - tau * rho * (x - y)

        r_k = x - y  # primal feasibility residual
        s_k = y - y_k  # change in y, used as the dual residual

        x_k, y_k, z_k = x, y, z

        f_now = real_obj_func(x_k)
        f_hist.append(f_now)

        f_best = min(f_best, f_now)
        f_hist_best.append(f_best)

        logger.debug(
            'iter= {:5}, objective= {:10E}, sparsity= {:3f}'.format(
                k, f_now.item(), sparsity_func(x_k)))

        # For 2-D residuals, ord=2 is the spectral norm (largest singular value).
        r_k_norm = LA.norm(r_k, ord=2)
        s_k_norm = LA.norm(s_k, ord=2)
        if r_k_norm < thres and s_k_norm < thres:
            length += 1
        else:
            length = 0

        if length >= converge_len:
            break

    elapsed_time = stopwatch.elapsed(
        time_format=Stopwatch.TimeFormat.kMicroSecond) / 1e6
    out = {
        "tt": elapsed_time,
        "fval": real_obj_func(x_k),
        "f_hist": f_hist,
        "f_hist_best": f_hist_best
    }

    return x_k, k, out
Пример #28
0
def check_group_enrichment_goatools(tested_gene_file_name,
                                    total_gene_file_name,
                                    th=1):
    """Run a goatools GO enrichment study of tested genes against a background.

    The GO DAG and the NCBI gene2go association file are downloaded into
    ``constants.GO_DIR`` on first use (the association file is gunzipped).

    Args:
        tested_gene_file_name: file name passed to ``load_gene_list``, or the
            tested gene list itself.
        total_gene_file_name: file name passed to ``load_gene_list``, or the
            background (population) gene list itself.
        th: keep only terms whose uncorrected p-value is <= th (the default
            of 1 keeps every term).

    Returns:
        A list of dicts with keys HG_GO_ROOT (namespace), HG_GO_ID,
        HG_GO_NAME, HG_VALUE (population count), HG_PVAL (uncorrected
        p-value) and HG_QVAL. The list is sorted by ascending HG_PVAL.
        HG_QVAL is always 1 because the study requests no multiple-testing
        correction (``methods=[]``). Returns [] if either input is empty.
    """
    if len(tested_gene_file_name) == 0 or len(total_gene_file_name) == 0:
        return []

    if isinstance(total_gene_file_name, str):
        total_gene_list = load_gene_list(total_gene_file_name)
    else:
        total_gene_list = total_gene_file_name

    if isinstance(tested_gene_file_name, str):
        tested_gene_list = load_gene_list(tested_gene_file_name)
    else:
        tested_gene_list = tested_gene_file_name

    go_obo_path = os.path.join(constants.GO_DIR, constants.GO_FILE_NAME)
    if not os.path.exists(go_obo_path):
        download(constants.GO_OBO_URL, constants.GO_DIR)
    obo_dag = GODag(go_obo_path)

    assoc_path = os.path.join(constants.GO_DIR,
                              constants.GO_ASSOCIATION_FILE_NAME)
    if not os.path.exists(assoc_path):
        download(constants.GO_ASSOCIATION_GENE2GEO_URL, constants.GO_DIR)
        # The downloaded file is gzip-compressed; decompress it to assoc_path.
        gz_path = os.path.join(
            constants.GO_DIR,
            os.path.basename(constants.GO_ASSOCIATION_GENE2GEO_URL))
        with gzip.open(gz_path, 'rb') as f_in:
            with open(assoc_path, 'wb') as f_out:
                shutil.copyfileobj(f_in, f_out)

    assoc = read_ncbi_gene2go(assoc_path, no_top=True)

    sw = Stopwatch()
    sw.start()
    # goatools works on integer Entrez ids, hence the conversion.
    g = GOEnrichmentStudy(
        [int(cur) for cur in ensembl2entrez_convertor(total_gene_list)],
        assoc,
        obo_dag,
        methods=[],  # no correction -> HG_QVAL below is a placeholder
        log=None)
    g_res = g.run_study(
        [int(cur) for cur in ensembl2entrez_convertor(tested_gene_list)])
    print(sw.stop("done GO analysis in "))

    hg_report = [{
        HG_GO_ROOT: cur.NS,
        HG_GO_ID: cur.GO,
        HG_GO_NAME: cur.goterm.name,
        HG_VALUE: cur.pop_count,
        HG_PVAL: cur.p_uncorrected,
        HG_QVAL: 1
    } for cur in g_res if cur.p_uncorrected <= th]
    hg_report.sort(key=lambda x: x[HG_PVAL])
    return hg_report
Пример #29
0
class GameGUI:
    """Pygame front end that renders a Snake environment and drives an agent.

    A human ``PlayerAgent`` steers with the arrow keys; any other agent is
    asked for its next action whenever a timestep elapses.
    """

    FPS_LIMIT = 60  # upper bound on rendered frames per second
    # Timestep delays, compared against Stopwatch.time() (units as returned
    # by that project class -- presumably milliseconds, TODO confirm).
    AI_TIMESTEP_DELAY = 100
    HUMAN_TIMESTEP_DELAY = 1000
    CELL_SIZE = 20  # side length of one field cell, in pixels

    # Arrow keys listed counter-clockwise: up, left, down, right.
    SNAKE_CONTROL_KEYS = [
        pygame.K_UP, pygame.K_LEFT, pygame.K_DOWN, pygame.K_RIGHT
    ]

    def __init__(self):
        pygame.init()
        self.agent = PlayerAgent()
        self.env = None
        self.screen = None
        self.fps_clock = None
        self.timestep_watch = Stopwatch()

    def load_environment(self, environment):
        """
        Load the RL environment into the GUI and open a window sized to its field.
        """

        self.env = environment
        screen_size = (self.env.field.size * self.CELL_SIZE,
                       self.env.field.size * self.CELL_SIZE)

        self.screen = pygame.display.set_mode(screen_size)
        self.screen.fill(Colors.SCREEN_BACKGROUND)
        pygame.display.set_caption('Snake')

    def load_agent(self, agent):
        """ Load the RL agent into the GUI. """
        self.agent = agent

    def render_cell(self, x, y):
        """
        Draw the cell specified by the field coordinates.

        Empty cells are filled with the background colour; any other cell gets
        a 1-pixel outline plus a smaller filled square in its type's colour.
        """

        cell_coordinates = pygame.Rect(
            x * self.CELL_SIZE,
            y * self.CELL_SIZE,
            self.CELL_SIZE,
            self.CELL_SIZE,
        )

        if self.env.field[x, y] == CellType.EMPTY:
            pygame.draw.rect(self.screen, Colors.SCREEN_BACKGROUND,
                             cell_coordinates)
        else:
            color = Colors.CELL_TYPE[self.env.field[x, y]]
            pygame.draw.rect(self.screen, color, cell_coordinates, 1)

            # inflate() with negative values shrinks the rect by that many
            # pixels in total per axis (half on each side).
            internal_padding = self.CELL_SIZE // 6 * 2
            internal_padding = (-internal_padding, -internal_padding)
            internal_square_coords = cell_coordinates.inflate(
                *internal_padding)
            pygame.draw.rect(self.screen, color, internal_square_coords)

    def render(self):
        """ Draw the entire game frame. """
        for x in range(self.env.field.size):
            for y in range(self.env.field.size):
                self.render_cell(x, y)

    def map_key_to_snake_action(self, key):
        """ Convert an arrow keystroke to a relative environment action.

        Keys are listed counter-clockwise and headings clockwise, so the sum
        of the two indices (mod 4) selects the relative turn. Pressing the key
        that points back against the current heading maps to
        MAINTAIN_DIRECTION.
        """
        actions = [
            SnakeActions.MAINTAIN_DIRECTION,
            SnakeActions.TURN_LEFT,
            SnakeActions.MAINTAIN_DIRECTION,
            SnakeActions.TURN_RIGHT,
        ]

        ALL_SNAKE_DIRECTIONS = [
            SnakeDirections.NORTH,
            SnakeDirections.EAST,
            SnakeDirections.SOUTH,
            SnakeDirections.WEST,
        ]

        key_idx = self.SNAKE_CONTROL_KEYS.index(key)
        direction_idx = ALL_SNAKE_DIRECTIONS.index(self.env.snake.direction)
        # Same element as np.roll(actions, -key_idx)[direction_idx], but keeps
        # the original action object instead of a numpy scalar.
        return actions[(direction_idx + key_idx) % len(actions)]

    def run(self, num_episodes=1):
        """ Run the GUI player for the specified number of episodes.

        Stops early when the user closes the window or presses Escape.
        """
        pygame.display.update()
        self.fps_clock = pygame.time.Clock()

        for _ in range(num_episodes):
            try:
                self.run_episode()
            except QuitRequestedError:
                break

            pygame.time.wait(1500)

    def run_episode(self):
        """ Run the GUI player for a single episode.

        Raises:
            QuitRequestedError: on a QUIT event or when Escape is pressed.
        """

        timestep_result = self.initialize_env()
        is_human_agent = isinstance(self.agent, PlayerAgent)
        timestep_delay = self.get_timestep_delay(is_human_agent)

        default_action = SnakeActions.MAINTAIN_DIRECTION

        while True:
            action = default_action

            for event in pygame.event.get():
                if event.type == KEYDOWN:
                    if is_human_agent and event.key in self.SNAKE_CONTROL_KEYS:
                        action = self.map_key_to_snake_action(event.key)
                    if event.key == pygame.K_ESCAPE:
                        raise QuitRequestedError

                if event.type == QUIT:
                    raise QuitRequestedError

            # Advance the game when the timestep elapsed or the human moved.
            timestep_timed_out = self.timestep_watch.time() >= timestep_delay
            human_made_move = is_human_agent and action != default_action

            if timestep_timed_out or human_made_move:
                self.timestep_watch.reset()

                if not is_human_agent:
                    action = self.agent.next_action(
                        timestep_result.observation, timestep_result.reward)

                self.env.choose_action(action)
                timestep_result = self.env.timestep()

                if timestep_result.is_episode_end:
                    self.agent.end_episode()
                    break

            self.render_scene()

    def initialize_env(self):
        """ Initialize environment for new session and return the first timestep. """
        self.timestep_watch.reset()
        timestep_result = self.env.new_episode()
        self.agent.reset_state()

        return timestep_result

    def get_timestep_delay(self, is_human_agent):
        """ Return the delay between timesteps for the given agent kind. """
        if is_human_agent:
            return self.HUMAN_TIMESTEP_DELAY
        else:
            return self.AI_TIMESTEP_DELAY

    def render_scene(self):
        """ Render current scene, show the score in the caption, cap the FPS. """
        self.render()
        score = self.env.snake.length - self.env.initial_snake_length
        pygame.display.set_caption(f'Snake  [Score: {score:02d}]')
        pygame.display.update()
        self.fps_clock.tick(self.FPS_LIMIT)
Пример #30
0
def gl_SGD_primal(x0: np.ndarray, A: np.ndarray, b: np.ndarray, mu_0, opts: dict):
    """Solve group-LASSO by subgradient descent with continuation on mu.

    Minimizes 0.5 * ||A x - b||_F^2 + mu * sum_i ||x_i||_2 (x_i = i-th row of
    x) for mu in (100 * mu_0, 10 * mu_0, mu_0) in turn, running exactly
    ``maxit`` subgradient steps per value. Early stopping is disabled, so
    ``ftol`` and ``stable_len_threshold`` are accepted but have no effect.

    Args:
        x0: initial iterate; copied, never modified.
        A, b: problem data; ``A @ x`` must have the shape of ``b``.
        mu_0: target regularization weight.
        opts: overrides for ``default_opts`` (the caller's values win).

    Returns:
        (x, k, out): the final iterate, the total number of iterations over
        all phases, and a dict with elapsed seconds ("tt"), the final
        objective at mu_0 ("fval"), and the per-iteration objective
        ("f_hist") and best-so-far ("f_hist_best") histories. Each history
        entry is evaluated with the mu of the phase it was recorded in.

    Raises:
        ValueError: if ``opts["step_type"]`` is not "fixed", "diminishing" or
            "diminishing2".
    """
    default_opts = {
        "maxit": 2100,  # iterations per mu (continuation phase)
        "thres": 1e-3,  # entries below this magnitude are zeroed each step
        "step_type": "diminishing",  # step decay rule (see set_step)
        "alpha0": 1e-3,  # initial step size
        "ftol": 1e-5,  # relative-change tolerance (early stop disabled)
        "stable_len_threshold": 100,  # (early stop disabled)
        "continuous_subgradient_flag": False,  # alpha0 = 1 / lambda_max(A^T A)
    }
    # The second dictionary's values overwrite those from the first.
    opts = {**default_opts, **opts}

    maxit, alpha0 = opts["maxit"], opts["alpha0"]
    thres = opts["thres"]
    step_type = opts["step_type"]
    # Fail fast: an unknown type used to yield a None step and a TypeError.
    if step_type not in ('fixed', 'diminishing', 'diminishing2'):
        raise ValueError("Unsupported step_type: {!r}".format(step_type))

    if opts["continuous_subgradient_flag"]:
        # A^T A is symmetric, so eigvalsh gives real eigenvalues directly.
        alpha0 = 1. / np.max(LA.eigvalsh(A.T @ A))

    def sparsity_func(x: np.ndarray):
        # Fraction of entries larger than 1e-6 times the largest magnitude.
        return np.sum(np.abs(x) > 1e-6 * np.max(np.abs(x))) / x.size

    def obj_func(x: np.ndarray, mu):
        fro_term = 0.5 * np.sum((A @ x - b) ** 2)
        regular_term = np.sum(LA.norm(x, axis=1).reshape(-1, 1))
        return fro_term + mu * regular_term

    def subgrad(x: np.ndarray, mu):
        fro_term_grad = A.T @ (A @ x - b)
        regular_term_norm = LA.norm(x, axis=1).reshape(-1, 1)
        # Rows with norm < thres get 1 added to the denominator (avoids 0/0).
        regular_term_grad = x / ((regular_term_norm < thres) + regular_term_norm)
        return fro_term_grad + mu * regular_term_grad

    def set_step(inn_iter, mu):
        """Step size for inner iteration inn_iter (>= 1) of the phase with mu."""
        if step_type == 'fixed' or mu > mu_0:
            return alpha0
        # alpha0 is held for the first 1000 inner iterations, then decays.
        iter_hat = max(inn_iter, 1000) - 999
        if step_type == 'diminishing':
            return alpha0 / np.sqrt(iter_hat)
        return alpha0 / iter_hat  # 'diminishing2'

    logger.debug("alpha0= {:10E}".format(alpha0))
    f_hist, f_hist_best = [], []
    f_best = np.inf

    x = np.copy(x0)
    stopwatch = Stopwatch()
    stopwatch.start()
    k = 0
    for mu in [100 * mu_0, 10 * mu_0, mu_0]:
        logger.debug("new mu= {:10E}".format(mu))

        inn_iter = 0
        while inn_iter < maxit:
            # Record current objective value
            f_now = obj_func(x, mu)
            f_hist.append(f_now)

            f_best = min(f_best, f_now)
            f_hist_best.append(f_best)
            k += 1
            inn_iter += 1

            x[np.abs(x) < thres] = 0
            sub_g = subgrad(x, mu)
            alpha = set_step(inn_iter, mu)
            x = x - alpha * sub_g

            if k % 100 == 0:
                logger.debug('iter= {:5}, objective= {:10E}, sparsity= {:3f}'.format(k, f_now.item(), sparsity_func(x)))

    elapsed_time = stopwatch.elapsed(time_format=Stopwatch.TimeFormat.kMicroSecond) / 1e6
    out = {
        "tt": elapsed_time,
        "fval": obj_func(x, mu_0),
        "f_hist": f_hist,
        "f_hist_best": f_hist_best
    }

    return x, k, out
def gl_ProxGD_primal(x0: np.ndarray, A: np.ndarray, b: np.ndarray, mu_0, opts: dict):
    """Solve group-LASSO by proximal gradient descent with continuation on mu.

    Minimizes f(x) = g(x) + h(x), with g(x) = 0.5 * ||A x - b||_F^2 and
    h(x) = mu * sum_i ||x_i||_2 (x_i = i-th row of x), for mu in
    (100 * mu_0, 10 * mu_0, mu_0) in turn. Each phase is warm-started from
    the previous one and ends after ``maxit`` iterations, or once the relative
    changes of both the objective and the sparsity stay below ``ftol`` for
    more than ``stable_len_threshold`` consecutive iterations.

    Args:
        x0: initial iterate; copied, never modified.
        A, b: problem data; ``A @ x`` must have the shape of ``b``.
        mu_0: target regularization weight.
        opts: overrides for ``default_opts`` (the caller's values win).

    Returns:
        (x, k, out): the final iterate, the total number of iterations over
        all phases, and a dict with elapsed seconds ("tt"), the final
        objective ("fval"), and the per-iteration objective ("f_hist") and
        best-so-far ("f_hist_best") histories. The objective is always
        evaluated with mu_0, whatever the current phase.

    Raises:
        ValueError: if ``opts["step_type"]`` is not "fixed", "diminishing",
            "diminishing2" or "line_search".
    """
    default_opts = {
        "maxit": 2500,  # maximum iterations per mu (continuation phase)
        "thres": 1e-3,  # entries below this magnitude are zeroed each step
        "step_type": "line_search",  # step rule (see set_step)
        "alpha0": 2e-3,  # initial / base step size
        "ftol": 1e-6,  # relative-change tolerance of the stability test
        "stable_len_threshold": 70,  # stable iterations needed to end a phase
        "line_search_attenuation_coeffi": 0.9,  # step shrink per backtrack
        "maxit_line_search_iter": 5,  # maximum number of backtracking tests
    }
    # The second dictionary's values overwrite those from the first.
    opts = {**default_opts, **opts}

    maxit, ftol, alpha0 = opts["maxit"], opts["ftol"], opts["alpha0"]
    stable_len_threshold = opts["stable_len_threshold"]
    thres = opts["thres"]
    step_type = opts['step_type']
    aten_coeffi = opts['line_search_attenuation_coeffi']
    max_line_search_iter = opts['maxit_line_search_iter']
    # Fail fast: an unknown type used to yield a None step and a TypeError.
    if step_type not in ('fixed', 'diminishing', 'diminishing2', 'line_search'):
        raise ValueError("Unsupported step_type: {!r}".format(step_type))

    def sparsity_func(x: np.ndarray):
        # Fraction of entries larger than 1e-6 times the largest magnitude.
        return np.sum(np.abs(x) > 1e-6 * np.max(np.abs(x))) / x.size

    def g(x: np.ndarray):
        """Smooth part: 0.5 * ||A x - b||_F^2."""
        return 0.5 * np.sum((A @ x - b) ** 2)

    def real_obj_func(x: np.ndarray):
        regular_term = np.sum(LA.norm(x, axis=1).reshape(-1, 1))
        return g(x) + mu_0 * regular_term

    def prox_th(x: np.ndarray, t, mu):
        """Proximal operator of t * mu * sum_i ||x_i||_2 (row-wise shrinkage)."""
        row_norms = LA.norm(x, axis=1).reshape(-1, 1)
        # Rows with norm < thres get 1 added to the denominator (avoids 0/0).
        return x * np.clip(row_norms - t * mu, a_min=0, a_max=None) / (
            (row_norms < thres) + row_norms)

    def line_search(x, grad_g, mu):
        """Backtrack from alpha0 until the sufficient-decrease test holds.

        If all ``max_line_search_iter`` tests fail, the returned step is
        alpha0 * aten_coeffi ** max_line_search_iter (itself untested).
        """
        g_x = g(x)
        t = alpha0
        for _ in range(max_line_search_iter):
            # Gradient mapping G_t(x) = (x - prox(x - t * grad_g)) / t.
            gt_x = (x - prox_th(x - t * grad_g, t, mu)) / t
            if (g(x - t * gt_x)
                    <= g_x - t * np.sum(grad_g * gt_x) + 0.5 * t * np.sum(gt_x ** 2)):
                break
            t *= aten_coeffi
        return t

    def set_step(inner_iter, x, grad_g, mu):
        """Step size for inner iteration inner_iter (>= 1) at iterate x."""
        if step_type == 'fixed':
            return alpha0
        if step_type == 'line_search':
            return line_search(x, grad_g, mu)
        # alpha0 is held for the first 1000 inner iterations, then decays.
        iter_hat = max(inner_iter, 1000) - 999
        if step_type == 'diminishing':
            return alpha0 / np.sqrt(iter_hat)
        return alpha0 / iter_hat  # 'diminishing2'

    logger.debug("alpha0= {:10E}".format(alpha0))
    f_hist, f_hist_best, sparsity_hist = [], [], []
    f_best = np.inf

    x = np.copy(x0)
    stopwatch = Stopwatch()
    stopwatch.start()
    k = 0
    for mu in [100 * mu_0, 10 * mu_0, mu_0]:
        logger.debug("new mu= {:10E}".format(mu))

        inner_iter = 0
        stable_len = 0

        while inner_iter < maxit:
            # Record current objective value
            f_now = real_obj_func(x)
            f_hist.append(f_now)

            f_best = min(f_best, f_now)
            f_hist_best.append(f_best)

            sparsity_hist.append(sparsity_func(x))

            k += 1
            inner_iter += 1

            # k indexes the global histories, so the first step of a new
            # phase is compared with the last step of the previous one.
            if (k > 1
                    and abs(f_hist[k - 1] - f_hist[k - 2]) / abs(f_hist[k - 2]) < ftol
                    and abs(sparsity_hist[k - 1] - sparsity_hist[k - 2]) / abs(sparsity_hist[k - 2]) < ftol):
                stable_len += 1
            else:
                stable_len = 0
            if stable_len > stable_len_threshold:
                break

            x[np.abs(x) < thres] = 0

            grad_g = A.T @ (A @ x - b)
            alpha_k = set_step(inner_iter, x, grad_g, mu)
            x = prox_th(x - alpha_k * grad_g, alpha_k, mu)

            if k % 100 == 0:
                logger.debug(
                    'iter= {:5}, objective= {:10E}, sparsity= {:3f}'.format(k, f_now.item(), sparsity_func(x)))

    elapsed_time = stopwatch.elapsed(time_format=Stopwatch.TimeFormat.kMicroSecond) / 1e6
    out = {
        "tt": elapsed_time,
        "fval": real_obj_func(x),
        "f_hist": f_hist,
        "f_hist_best": f_hist_best
    }

    return x, k, out
Пример #32
0
def gridsearch(nhidden, lrate, wcost, momentum, n_in_train_minibatch,
               sampling_steps, n_temperatures, max_time_secs, save_plots):
    """Benchmark every RBM hyper-parameter combination for a fixed time budget.

    Each combination from the cartesian product of the argument lists is
    trained on an MNIST subset (5000 train / 1000 validation patterns
    requested) until ``max_time_secs`` elapse. Train, validation and test
    errors are recorded after every minibatch step, and the summary is
    printed via ``print_best``.

    Args:
        nhidden, lrate, wcost, momentum, n_in_train_minibatch,
        sampling_steps, n_temperatures: lists of candidate values.
        max_time_secs: training-time budget per combination, in seconds.
        save_plots: forwarded to RbmNetwork as ``plot``; also saves an
            ``error<idx>.png`` plot per attempt.
    """
    import itertools

    params = ['nhidden', 'lrate', 'wcost', 'momentum', 'n_in_train_minibatch', 'n_sampling_steps', 'n_temperatures']
    # product() iterates in the same order as the equivalent nested loops.
    attempts = [{'nhidden': nh,
                 'lrate': lr,
                 'wcost': wc,
                 'momentum': mo,
                 'n_in_train_minibatch': nitm,
                 'n_sampling_steps': st,
                 'n_temperatures': ntemp,
                 'should_plot': False,
                 }
                for nh, lr, wc, mo, nitm, st, ntemp in itertools.product(
                    nhidden, lrate, wcost, momentum, n_in_train_minibatch,
                    sampling_steps, n_temperatures)]

    nattempts = len(attempts)
    pid = os.getpid()
    total_time_secs = max_time_secs * nattempts
    print('Beginning %i attempts (PID=%s, DT=%s), each for %i secs, ETA = %s' %
          (nattempts, pid, dt_str(), max_time_secs, eta_str(total_time_secs)))
    for attempt in attempts:
        print(attempt)

    # Generate data set
    n_trainpatterns = 5000
    n_validpatterns = 1000
    train_pset, valid_pset, test_pset = create_mnist_patternsets(n_trainpatterns=n_trainpatterns, n_validpatterns=n_validpatterns)
    n_trainpatterns, n_validpatterns, n_testpatterns = train_pset.n, valid_pset.n, test_pset.n
    valid_patterns = np.array(valid_pset.patterns).reshape((n_validpatterns, -1))
    test_patterns = np.array(test_pset.patterns).reshape((n_testpatterns, -1))

    for attempt_idx, attempt in enumerate(attempts):
        np.random.seed(1)
        t = Stopwatch()
        net = RbmNetwork(np.prod(train_pset.shape),
                         attempt['nhidden'],
                         attempt['lrate'],
                         attempt['wcost'],
                         attempt['momentum'],
                         attempt['n_temperatures'],
                         attempt['n_sampling_steps'],
                         v_shape=train_pset.shape,
                         plot=save_plots)
        train_errors = []
        valid_errors = []
        test_errors = []
        epochnum = 0
        while True:  # train as long as time limit is not exceeded
            # Use this attempt's minibatch size (previously the leftover value
            # of the attempt-building loop was used for every attempt).
            minibatch_pset = Minibatch(train_pset, attempt['n_in_train_minibatch'])
            net.learn_trial(minibatch_pset.patterns)
            epochnum += 1  # count the step just taken, before any break

            # calculate errors
            train_error = np.mean(net.test_trial(minibatch_pset.patterns)[0])
            train_errors.append(train_error)
            valid_error = np.mean(net.test_trial(valid_patterns)[0])
            valid_errors.append(valid_error)
            test_error = np.mean(net.test_trial(test_patterns)[0])
            test_errors.append(test_error)

            if t.finish(milli=False) > max_time_secs:
                break

        train_elapsed = t.finish(milli=False)

        # add ATTEMPT keys that will vary each time below
        # this line. the HASH should be the same for this
        # set of parameters every time you run the benchmark
        attempt['hash'] = hash(HashableDict(attempt))
        attempt['attempt_idx'] = attempt_idx
        attempt['pid'] = pid
        attempt['n_train_epochs'] = epochnum
        attempt['train_elapsed'] = train_elapsed
        attempt['dt'] = dt_str()
        # Test error from the last iteration (matches the value printed below).
        attempt['test_error'] = test_errors[-1]
        attempt['npatterns'] = n_trainpatterns

        if save_plots:
            filename = 'error%i.png' % attempt_idx
            net.save_error_plots(train_errors, valid_errors, test_errors, filename)

        print('Finished %i of %i: error %.2f in %.1f secs' % (attempt_idx + 1,
                                                              nattempts, test_error, train_elapsed))
        print(attempt)
        print('--------------------')
        print('')

    print_best(attempts, params)