Example #1
import os
import argparse

import pds3
from catch import Catch, Config

parser = argparse.ArgumentParser("add-neat-maui-geodss")
parser.add_argument("path")
parser.add_argument("-r", action="store_true", help="recursive search")
parser.add_argument("--config", help="CATCH configuration file")
# parser.add_argument('-u', action='store_true', help='update')

args = parser.parse_args()


def product_id_to_int_id(pid):
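    # Worked example (hypothetical ID, not from the original source):
    # "xxx_010100A" -> "010100A" -> "010100" + str(ord("A") - 65) -> 101000.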
    s = pid.split("_")[-1]
    s = s[:-1] + str(ord(s[-1]) - 65)
    return int(s)


with Catch.with_config(Config.from_file(args.config)) as catch:
    for path, dirnames, filenames in os.walk(args.path):
        catch.logger.info("inspecting " + path)
        observations = []
        labels = [f for f in filenames if f.endswith(".lbl")]
        for labelfn in labels:
            try:
                label = pds3.PDS3Label(os.path.join(path, labelfn))
            except Exception:
                catch.logger.error("unable to read " + labelfn)
                continue

            if label["PRODUCT_NAME"] != "NEAT GEODSS IMAGE":
                catch.logger.warning("not a GEODSS image label: " + labelfn)
                continue
Example #2
import argparse

import numpy as np
from astropy.io import fits
from astropy.wcs import WCS
from catch import Catch, Config
from catch.model import STScIDSS  # model class used below; import path assumed
from sbsearch.util import FieldOfView, RADec

parser = argparse.ArgumentParser('add-stsci-dss')
parser.add_argument('headers', nargs='*')

args = parser.parse_args()


def pltlabel_to_int_id(pltlabel):
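    # Worked example (hypothetical plate label, not from the original source):
    # "A123" -> str(ord("A")) + "123" -> "65123" -> 65123.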
    s = ''
    for c in pltlabel:
        s += c if c.isdigit() else str(ord(c))
    return int(s)


with Catch(Config.from_file(), save_log=True) as catch:
    obs = []
    for fn in args.headers:
        h = fits.getheader(fn)
        shape = np.array((h['XPIXELS'], h['YPIXELS']))

        wcs = WCS(h)

        v = wcs.all_pix2world([[0, 0], [0, shape[1]], [shape[0], shape[1]],
                               [shape[0], 0]], 0)
        fov = str(FieldOfView(RADec(v, unit='deg')))

        obs.append(STScIDSS(
            id=pltlabel_to_int_id(h['PLTLABEL']),
            label=h['PLTLABEL'],
Example #3
logger = setup_logger(args.log)

if not os.path.exists(f"{args.base_path}/gbo.ast.spacewatch.survey"):
    raise ValueError(
        f"gbo.ast.spacewatch.survey not found in {args.base_path}")

if args.t:
    logger.info("Testing for existence of all files.")

if args.dry_run or args.t:
    logger.info("Dry run, databases will not be updated.")

if args.v:
    logger.setLevel(logging.DEBUG)

with Catch.with_config(args.config) as catch:
    observations = []
    failed = 0

    tri = ProgressTriangle(1, logger=logger, base=2)
    for fn, label in inventory(args.base_path):
        tri.update()

        if args.t:
            if not os.path.exists(fn):
                logger.error("Missing %s", fn)
            continue

        try:
            observations.append(process(fn, label))
            msg = "added"
Example #4
def next_state(self):
    return Catch(self.next_input())
Example #5
def main():
    STEP_LOG_RATE = 1000
    TENSORBOARD_ROOT_PATH = "tensorboard"
    CHECKPOINT_ROOT_PATH = "checkpoints_test"
    CHECKPOINTS_STEPS = 100000
    EXPERIENCE_MEMORY_CAPACITY = 6400000
    MINIBATCH_SIZE = 32
    GAMMA = 0.99
    EPSILON_START = 1.0
    EPSILON_END = 0.1
    EPSILON_DECAY_STEPS = 1000000
    LEARNING_RATE = 0.0001
    FIELD_WIDTH = 5
    FIELD_HEIGHT = 5
    USE_TARGET_NETWORK = True
    TARGET_NETWORK_UPDATE_STEPS = 10000
    STATE_AS_COORDINATES = True
    STATE_NORMALISATION = True

    descriptiveString = buildDescriptiveString(EXPERIENCE_MEMORY_CAPACITY, \
        MINIBATCH_SIZE, GAMMA, EPSILON_START, EPSILON_END, EPSILON_DECAY_STEPS, \
        LEARNING_RATE, STATE_AS_COORDINATES, STATE_NORMALISATION, \
        FIELD_WIDTH, FIELD_HEIGHT, USE_TARGET_NETWORK, TARGET_NETWORK_UPDATE_STEPS)

    tensorboardDirectory = os.path.join(TENSORBOARD_ROOT_PATH,
                                        descriptiveString)
    checkpointDirectory = os.path.join(CHECKPOINT_ROOT_PATH, descriptiveString)

    # create catch environment
    catch = Catch(FIELD_WIDTH, FIELD_HEIGHT, STATE_AS_COORDINATES,
                  STATE_NORMALISATION)
    numberOfActions = catch.getNumberOfActions()
    stateSize = catch.getStateSize()
    # create experience memory
    experienceMemory = ExperienceMemory(EXPERIENCE_MEMORY_CAPACITY, stateSize)

    ########################################################################################################################################################
    input, output, outputLabel, onlineSummary = createModel(stateSize, \
        numberOfActions, isTargetNetwork=False)

    if USE_TARGET_NETWORK:
        targetInput, targetOutput, _, targetSummary = createModel(stateSize, \
            numberOfActions, isTargetNetwork=True)

    with tf.name_scope("train"):
        optimizer = tf.train.RMSPropOptimizer(learning_rate=LEARNING_RATE)
        loss = tf.losses.huber_loss(labels=outputLabel, predictions=output)
        train = optimizer.minimize(loss)

        tf.summary.scalar("loss", loss)

    episodicStepsSummary = tf.Summary()
    episodicRewardSummary = tf.Summary()
    explorationSummary = tf.Summary()
    experienceMemorySizeSummary = tf.Summary()

    episodicStepsSummary.value.add(tag="episodic_steps", simple_value=None)
    episodicRewardSummary.value.add(tag="episodic_reward", simple_value=None)
    explorationSummary.value.add(tag="exploration", simple_value=None)
    experienceMemorySizeSummary.value.add(tag="experience_memory_size",
                                          simple_value=None)

    trainSummary = tf.summary.merge_all(scope="train")

    sess = tf.Session()
    init = tf.global_variables_initializer()

    sess.run(init)

    writer = tf.summary.FileWriter(tensorboardDirectory, sess.graph)

    updateTargetNetwork(sess, writer, targetSummary, 0)
    ########################################################################################################################################################

    step = 0
    episode = 0
    epsilon = EPSILON_START

    while step < EPSILON_DECAY_STEPS:
        episode += 1

        catch.reset()
        state = catch.getState()
        done = False
        episodeReward = 0
        episodeSteps = 0

        while not done and step < EPSILON_DECAY_STEPS:
            step += 1
            # select next action
            if np.random.random() <= epsilon:
                actionNumber = np.random.randint(numberOfActions)
            else:
                prediction = sess.run(
                    output,
                    feed_dict={input: np.reshape(state, (-1, stateSize))})
                actionNumber = np.argmax(prediction[0])
            # convert action number to action
            action = list(Actions)[actionNumber]
            # execute selected action
            reward, nextState, done = catch.move(action)
            # store experience to memory
            experienceMemory.store(state, actionNumber, reward, nextState,
                                   done)
            # replace current state by next state
            state = nextState
            # replay experiences
            if experienceMemory.size() > MINIBATCH_SIZE:
                # sample from experience memory
                ids, states, actions, rewards, nextStates, nextStateTerminals = experienceMemory.sample(
                    MINIBATCH_SIZE)

                if USE_TARGET_NETWORK:
                    statePredictions = sess.run(output,
                                                feed_dict={input: states})
                    nextStatePredictions = sess.run(
                        targetOutput, feed_dict={targetInput: nextStates})
                else:
                    predictions = sess.run(output,
                                           feed_dict={
                                               input:
                                               np.concatenate(
                                                   (states, nextStates))
                                           })
                    statePredictions = predictions[:MINIBATCH_SIZE]
                    nextStatePredictions = predictions[MINIBATCH_SIZE:]

                statePredictions[np.arange(MINIBATCH_SIZE), actions] = \
                    rewards + np.invert(nextStateTerminals) * GAMMA * \
                    nextStatePredictions.max(axis=1)

                # update online network
                _, onlineSummaryResult, trainSummaryResult = sess.run(
                    [train, onlineSummary, trainSummary],
                    feed_dict={
                        input: states,
                        outputLabel: statePredictions
                    })
                # write summary
                if step % STEP_LOG_RATE == 0:
                    writer.add_summary(onlineSummaryResult, step)
                    writer.add_summary(trainSummaryResult, step)

            episodeReward += reward
            episodeSteps += 1
            # update target network
            if USE_TARGET_NETWORK and step % TARGET_NETWORK_UPDATE_STEPS == 0:
                updateTargetNetwork(sess, writer, targetSummary, step)
            # write exploration summary
            if step % STEP_LOG_RATE == 0:
                explorationSummary.value[0].simple_value = epsilon
                experienceMemorySizeSummary.value[0].simple_value = \
                    experienceMemory.size()
                writer.add_summary(explorationSummary, step)
                writer.add_summary(experienceMemorySizeSummary, step)
            # save checkpoint
            if step % CHECKPOINTS_STEPS == 0:
                saveModel(checkpointDirectory, step, sess)
            # linearly decay epsilon for the next step
            epsilon = EPSILON_START - (EPSILON_START - EPSILON_END) * \
                step / EPSILON_DECAY_STEPS

        # write episodic summary
        episodicStepsSummary.value[0].simple_value = episodeSteps
        episodicRewardSummary.value[0].simple_value = episodeReward
        writer.add_summary(episodicStepsSummary, step)
        writer.add_summary(episodicRewardSummary, step)
Example #6
import uuid

from astropy.time import Time
from catch import Catch, Config
from catch.model import SkyMapper

# Find 65P in SkyMapper DR2
#
# Catch v0 result:
#
# * JD: 2457971.9152
# * Product ID: 20170806095706-22
# * https://api.skymapper.nci.org.au/public/siap/dr2/get_image?IMAGE=20170806095706-22&SIZE=0.08333333333333333&POS=237.22441,-23.40757&FORMAT=fits
#
# For CATCH with min_edge_length = 3e-4 rad, spatial index terms are:
# $9e8c1,9e8c1,9e8c4,9e8d,9e8c,9e9,9ec,$9e8c7,9e8c7,$9e8ea04,9e8ea04,9e8ea1,9e8ea4,9e8eb,9e8ec,9e8f,$9e8ea0c,9e8ea0c,$9e8ea74,9e8ea74,9e8ea7

config = Config.from_file('../catch.config', debug=True)
with Catch.with_config(config) as catch:
    catch.db.engine.echo = False  # set to true to see SQL statements

    expected = (catch.db.session.query(SkyMapper)
                .filter(SkyMapper.product_id == '20170806095706-22')
                .all())[0]

    # benchmark queries
    t = []

    # full survey search
    t.append(Time.now())
    job_id = uuid.uuid4()
    count = catch.query('65P', job_id, sources=['skymapper'], cached=False,
                        debug=True)
    full = catch.caught(job_id)
Example #7
    optimizer.param_groups[1]['lr'] = lr


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_batches = 350
im_sz = 24
glimps_width = 6
scale = 3
batch_size = 64
lr = 0.0001
std = 0.25

model = MODEL(im_sz=im_sz, channel=1, glimps_width=glimps_width, scale=scale, std=std).to(device)
loss_fn = LOSS(gamma=0).to(device)
optimizer = optim.Adam([{'params': model.parameters(), 'lr': lr}, {'params': loss_fn.parameters(), 'lr': lr}])
env = Catch(batch_size=batch_size, device=device)


for epoch in range(1, 701):
    adjust_learning_rate(optimizer, epoch, lr)
    model.train()
    train_aloss, train_lloss, train_bloss, train_reward = 0, 0, 0, 0
    for batch_idx in range(n_batches):
        optimizer.zero_grad()
        model.initialize(batch_size)
        loss_fn.initialize(batch_size)
        Done = 0
        while not Done:
            data = env.getframe()                      # get the current batch of frames from the environment
            action, logpi_a, logpi_l = model(data)     # pass frames through the model to generate actions
            Done, reward = env.step(action)            # apply the actions and receive rewards
Example #8
def catch_manager(save_log: bool = True) -> Iterator[Catch]:
    """Catch library session manager."""
    with Catch(catch_config, session=db_session(), save_log=save_log) as catch:
        yield catch
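
# Usage sketch (an assumption, not part of the original source): as written
# above, catch_manager is a plain generator function, so a caller can drive it
# directly.  The query()/caught() calls mirror the CATCH API used in the other
# examples; the designation "65P" is only an illustrative value.
import uuid

gen = catch_manager(save_log=False)
catch = next(gen)                  # enters the Catch session
try:
    job_id = uuid.uuid4()
    catch.query("65P", job_id, sources=["skymapper"], cached=True)
    rows = catch.caught(job_id)
finally:
    gen.close()                    # leaves the session, closing the Catch context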
Example #9
def catch_cli(*args):
    """CATCH command-line script."""
    import sys
    import argparse
    import uuid
    from astropy.time import Time
    from astropy.table import Table
    from catch import Catch, Config
    from catch.config import _config_example

    parser = argparse.ArgumentParser(
        "catch",
        epilog=f"Configuration files are JSON-formatted:\n{_config_example}")
    parser.add_argument("--config", help="CATCH configuration file")
    parser.add_argument("--database", help="use this database URI")
    parser.add_argument("--log", help="save log messages to this file")
    parser.add_argument("--arc-limit",
                        type=float,
                        help="maximal arc length to search, radians")
    parser.add_argument("--time-limit",
                        type=float,
                        help="maximal time length to search, days")
    parser.add_argument("--debug", action="store_true", help="debug mode")
    subparsers = parser.add_subparsers(help="sub-command help")

    verify = subparsers.add_parser(
        "verify", help="connect to database and verify and create tables")
    verify.set_defaults(command="verify")

    list_sources = subparsers.add_parser("sources",
                                         help="show available data sources")
    list_sources.set_defaults(command="sources")

    search = subparsers.add_parser("search", help="search for an object")
    search.set_defaults(command="search")
    search.add_argument("desg", help="object designation")
    search.add_argument(
        "--source",
        dest="sources",
        action="append",
        help="search this observation source (may be used multiple times)",
    )
    search.add_argument("--force",
                        dest="cached",
                        action="store_false",
                        help="do not use cached results")
    search.add_argument("-o", help="write table to this file")

    args = parser.parse_args()

    try:
        getattr(args, "command")
    except AttributeError:
        parser.print_help()
        sys.exit()

    if args.command == "verify":
        print("Verify databases and create as needed.\n")

    rows = []
    config = Config.from_args(args)
    with Catch.with_config(config) as catch:
        if args.command == "verify":
            pass
        elif args.command == "sources":
            print("Available sources:\n  *",
                  "\n  * ".join(catch.sources.keys()))
        elif args.command == "search":
            job_id = uuid.uuid4()
            catch.query(args.desg,
                        job_id,
                        sources=args.sources,
                        cached=args.cached)
            columns = set()
            # catch.caught returns a list of rows.
            for row in catch.caught(job_id):
                r = {}
                # Each row consists of a Found and an Observation object.  The
                # Observation object will be a subclass, e.g.,
                # NeatPalomarTricam, or SkyMapper.
                for data_object in row:
                    # Aggregate fields and values from each data object
                    for k, v in _serialize_object(data_object):
                        r[k] = v

                columns = columns.union(set(r.keys()))

                r["cutout_url"] = row.Observation.cutout_url(
                    row.Found.ra, row.Found.dec)

                r["date"] = Time(row.Found.mjd, format="mjd").iso

                rows.append(r)

    if args.command == "search":
        if rows == []:
            print("# none found")
        else:
            # make sure all rows have all columns
            for i in range(len(rows)):
                for col in columns:
                    rows[i][col] = rows[i].get(col)
            tab = Table(rows=rows)
            if args.o:
                tab.write(args.o,
                          format="ascii.fixed_width_two_line",
                          overwrite=True)
            else:
                tab.pprint(-1, -1)
Example #10
def main():
    args: argparse.Namespace = _parse_args()

    logger = logging.getLogger("add-css")
    for handler in list(logger.handlers):
        handler.close()
        logger.removeHandler(handler)
    logger.addHandler(logging.StreamHandler())
    logger.addHandler(logging.FileHandler(args.log))
    formatter = logging.Formatter(
        "%(levelname)s %(asctime)s (%(name)s): %(message)s")
    for handler in logger.handlers:
        handler.setFormatter(formatter)
    logger.setLevel(logging.INFO)
    logger.info("Initialized.")
    logger.debug(f"astropy {astropy_version}")
    logger.debug(f"catch {catch_version}")
    logger.debug(f"pds4_tools {pds4_tools_version}")
    logger.debug(f"requests {requests_version}")
    logger.debug(f"sbpy {sbpy_version}")
    logger.debug(f"sbsearch {sbsearch_version}")

    if args.dry_run:
        logger.info("Dry run, databases will not be updated.")

    if args.v:
        logger.setLevel(logging.DEBUG)

    if args.f is None:
        listfile = sync_list()
    else:
        listfile = args.f
        logger.info("Checking user-specified file list.")

    with harvester_db(args.db) as db:
        with Catch.with_config(args.config) as catch:
            observations = []
            failed = 0

            tri = ProgressTriangle(1, logger=logger, base=2)
            for path in new_labels(db, listfile):
                try:
                    observations.append(process(path))
                    msg = "added"
                except ValueError as e:
                    failed += 1
                    msg = str(e)
                except:
                    logger.error("A fatal error occurred processing %s",
                                 path,
                                 exc_info=True)
                    raise

                logger.debug("%s: %s", path, msg)
                tri.update()

                if args.dry_run:
                    continue

                db.execute("INSERT INTO labels VALUES (?,?,?)",
                           (path, Time.now().iso, msg))

                if len(observations) >= 10000:
                    catch.add_observations(observations)
                    db.commit()
                    observations = []

            # add any remaining files
            if not args.dry_run and (len(observations) > 0):
                catch.add_observations(observations)
                db.commit()

            if failed > 0:
                logger.warning("Failed processing %d files", failed)

            logger.info("Updating survey statistics.")
            for source in ("catalina_bigelow", "catalina_lemmon",
                           "catalina_kittpeak"):
                catch.update_statistics(source=source)
Example #11
from getkey import getkey, keys
from catch import Catch, Actions

env = Catch(5, 5, True, False)

state = env.getState()
done = False

print(state)
print()

while not done:
    key = getkey()

    if key == keys.LEFT:
        action = Actions.LEFT
    elif key == keys.RIGHT:
        action = Actions.RIGHT
    elif key == keys.UP:
        action = Actions.UP
    elif key == keys.DOWN:
        action = Actions.DOWN
    else:
        continue  # ignore other keys so `action` is never undefined

    reward, state, done = env.move(action)
    print(state)
    print("reward: {}".format(reward))
    print("done: {}".format(done))
    print()
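
# A non-interactive sketch (an assumption, not part of the original source):
# the same environment API can be driven by a random agent, reusing the
# constructor, reset()/getState()/move() and getNumberOfActions() calls that
# appear in the examples above.
import numpy as np

env = Catch(5, 5, True, False)
env.reset()
state = env.getState()
done = False
total_reward = 0

while not done:
    # sample a random action from the Actions enumeration
    action = list(Actions)[np.random.randint(env.getNumberOfActions())]
    reward, state, done = env.move(action)
    total_reward += reward

print("total reward: {}".format(total_reward))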