Code example #1
import shutil

from sacred.observers import MongoObserver


def sacred_experiment_with_config(config, name, main_fcn, db_name, base_dir,
                                  checkpoint_dir, sources=(), tune_config=None):
    """Launch a sacred experiment."""
    tune_config = tune_config or {}
    # creating a sacred experiment
    # https://github.com/IDSIA/sacred/issues/492
    from sacred import Experiment, SETTINGS
    SETTINGS.CONFIG.READ_ONLY_CONFIG = False

    ex = Experiment(name, base_dir=base_dir)
    ex.observers.append(MongoObserver(db_name=db_name))

    for f in sources:
        if isinstance(f, str):
            f_py = f + '.py'
            shutil.copy(f, f_py)
            ex.add_source_file(f_py)

    # export the merged config: the base config overridden by tuned values
    export_config = dict(config)
    export_config.update(tune_config)
    ex.add_config(config=export_config, **export_config)

    @ex.main
    def run_train():
        return main_fcn(config=config, checkpoint_dir=checkpoint_dir, ex=ex)

    return ex.run()
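
A minimal usage sketch for the helper above, assuming a local MongoDB; the toy config, my_main, and paths are hypothetical, not part of the original snippet:

def my_main(config, checkpoint_dir, ex):
    # train here; the return value becomes the sacred run result
    return config["lr"]

run = sacred_experiment_with_config(
    config={"lr": 1e-3}, name="demo", main_fcn=my_main,
    db_name="sacred", base_dir=".", checkpoint_dir="./checkpoints")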
Code example #2
File: train.py Project: dstnluong/tf2multiagentrl
import os

from sacred.observers import FileStorageObserver, MongoObserver

# train_ex is the module-level sacred.Experiment, defined earlier in this file


def main():
    file_observer = FileStorageObserver.create(os.path.join('results', 'sacred'))
    # use this code to attach a mongo database for logging
    mongo_observer = MongoObserver(url='localhost:27017', db_name='sacred')
    train_ex.observers.append(mongo_observer)
    train_ex.observers.append(file_observer)
    train_ex.run_commandline()
Code example #3
import os

from sacred import Experiment
from sacred.observers import MongoObserver, SlackObserver


def setup_experiment(experiment_name):
    mongo_uri = 'mongodb://*****:*****@localhost:27017/sacred?authSource=admin'
    ex = Experiment(experiment_name, save_git_info=False)
    ex.observers.append(MongoObserver(url=mongo_uri, db_name='sacred'))
    slack_obs = SlackObserver.from_config(os.environ['SLACK_CONFIG'])
    ex.observers.append(slack_obs)
    return ex
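
SlackObserver.from_config reads a JSON file, so SLACK_CONFIG must point at one; a minimal sketch of the setup, with a hypothetical path and placeholder webhook URL:

import os

# the file at this (hypothetical) path would contain:
# {"webhook_url": "https://hooks.slack.com/services/..."}
os.environ['SLACK_CONFIG'] = '/path/to/slack.json'
ex = setup_experiment('my_experiment')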
Code example #4
import warnings

from sacred import Experiment
from sacred.observers import MongoObserver

# base_config, element_world_v4 and PaddedMDPWarning come from the surrounding project


def run(config_updates, mongodb_url="localhost:27017"):
    """Run a single experiment with the given configuration

    Args:
        config_updates (dict): Configuration updates
        mongodb_url (str): MongoDB URL, or None if no Mongo observer should be used for
            this run
    """

    # Dynamically bind experiment config and main function
    ex = Experiment()
    ex.config(base_config)
    ex.main(element_world_v4)

    # Attach MongoDB observer if necessary
    if mongodb_url is not None and not ex.observers:
        ex.observers.append(MongoObserver(url=mongodb_url))

    # Suppress warnings about padded MDPs
    with warnings.catch_warnings():
        warnings.filterwarnings(action="ignore", category=PaddedMDPWarning)

        # Run the experiment
        run = ex.run(config_updates=config_updates)
        # options={"--loglevel": "ERROR"} could be passed here to quiet sacred's logging

    # Return the result
    return run.result
Code example #5
import os

from sacred import Experiment
from sacred.observers import FileStorageObserver, MongoObserver


def experiment(name):
    ex = Experiment(name)
    mongo_observer = MongoObserver(url="mongodb://*****:*****@localhost:27017")
    file_observer = FileStorageObserver(
        basedir=os.path.join("sacred/experiments", name))
    ex.observers.append(mongo_observer)
    ex.observers.append(file_observer)
    return ex
Code example #6
File: config.py Project: Ryanrenqian/PPNet-PyTorch
def add_observer(config, command_name, logger):
    """A hook fucntion to add observer"""
    exp_name = f'{ex.path}_{config["exp_str"]}'
    observer = MongoObserver()
    # observer = FileStorageObserver.create(os.path.join(config['path']['log_dir'], exp_name))
    ex.observers.append(observer)
    config['workspace'] = os.path.join(config['path']['log_dir'], exp_name)
    return config
Code example #7
File: sacred_config.py Project: sts-sadr/GTN-1
    def local(cls, docker: bool = True) -> SacredConfig:
        def observer():
            return MongoObserver(
                url=f'mongodb://*****:*****@localhost:27017/?authMechanism=SCRAM-SHA-1',
                db_name='db')

        return SacredConfigImpl(
            observer if docker else lambda: MongoObserver())
Code example #8
from sacred.observers import MongoObserver


def ready_mongo_observer(ex, db_name='sacred', url='localhost:27017'):
    """Readies a mongo observer for use with sacred.

    Args:
        ex (sacred.Experiment): Sacred experiment to track.
        db_name (str): Name of the mongo database.
        url (str): Host location.
    """
    ex.observers.append(MongoObserver(url=url, db_name=db_name))
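
A minimal usage sketch, assuming a MongoDB server on the default local port; the experiment and database names are hypothetical:

from sacred import Experiment

ex = Experiment('my_experiment')
ready_mongo_observer(ex, db_name='my_db')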
Code example #9
import os
from typing import Optional

from sacred.observers import MongoObserver

# ResumableMongoObserver is a project-specific observer, not part of sacred


def observer_from_env(resume_key: Optional[str] = None):
    user = os.environ['SACRED_USER']
    password = os.environ['SACRED_PASSWORD']
    database = os.environ['SACRED_DATABASE']
    host = os.environ['SACRED_HOST']
    url = (
        f'mongodb+srv://{user}:{password}@{host}/{database}?retryWrites=true&'
        'w=majority')
    if resume_key is None:
        return MongoObserver(url, db_name=database)
    return ResumableMongoObserver(resume_key, url, db_name=database)
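
A sketch of the environment this helper expects; every value below is a placeholder:

os.environ.update({
    'SACRED_USER': 'user',
    'SACRED_PASSWORD': 'secret',
    'SACRED_DATABASE': 'sacred',
    'SACRED_HOST': 'cluster0.example.mongodb.net',
})
observer = observer_from_env()  # pass resume_key=... to get the resumable variant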
Code example #10
import sacred
from sacred import SETTINGS
from sacred.observers import FileStorageObserver, MongoObserver


def get_sacred_experiment(name, observer='mongo', capture_output=True):
    ex = sacred.Experiment(name)
    if observer == 'mongo':
        ex.observers.append(
            MongoObserver(url='mongodb://{{cookiecutter.mongo_user}}:'
                          '{{cookiecutter.mongo_password}}@127.0.0.1:27017',
                          db_name='sacred'))
    else:
        ex.observers.append(FileStorageObserver('data/sacred/'))

    if not capture_output:
        SETTINGS.CAPTURE_MODE = 'no'
    return ex
Code example #11
    def default_observers(self):
        observers = []
        if socket.gethostname() in self.mongo_hostnames:
            observers.append(
                MongoObserver(
                    url=f"mongodb://*****:*****@localhost:27017/?authMechanism=SCRAM-SHA-1",
                    db_name="db",
                ))
        observers.append(
            FileStorageObserver(
                self.exp_config.get("storage_dir", "./sacred_storage")))
        return observers
Code example #12
def config():
    algorithm = "mcts"  # {"mcts", "lfd", "lfd-mcts", "acer", "greedy", "random", "truth", "mle", "beamsearch"}
    policy = "nn"  # {"nn", "random", "likelihood"}
    teacher = "truth"  # {"truth", "mle"}

    env_type = "1d"  # for now, only {"1d"}
    seed = 24927  # 1971248 was used for first round

    name = f"{algorithm}_{policy}" if algorithm == "mcts" else algorithm
    run_name = f"{name}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}_{seed}"

    database = True
    ex.observers.append(FileStorageObserver(f"./data/runs/{run_name}"))
    if database:
        ex.observers.append(MongoObserver())
Code example #13
import os

from sacred import Experiment
from sacred.observers import MongoObserver


def create_experiment(name='exp', database=None):
    """Create a Sacred experiment object for experiment logging."""
    ex = Experiment(name)

    atlas_user = os.environ.get('MONGO_DB_USER')
    atlas_password = os.environ.get('MONGO_DB_PASS')
    atlas_host = os.environ.get('MONGO_DB_HOST')

    # Add remote MongoDB observer, only if environment variables are set
    if atlas_user and atlas_password and atlas_host:
        ex.observers.append(
            MongoObserver(
                url=f"mongodb+srv://{atlas_user}:{atlas_password}@{atlas_host}",
                db_name=database))
    return ex
Code example #14
    def to_mongo(self, base_dir, remove_sources=False,
                 overwrite=None, *args, **kwargs):
        """
        Exports the file log into a mongo database.
        Requires sacred to be installed.
        Args:
            base_dir: root path to the source files
            remove_sources: remove source files if they are too complicated to match
            overwrite: whether to overwrite an existing experiment
            *args: positional args passed to the MongoObserver
            **kwargs: keyword args passed to the MongoObserver
        """
        from sacred.observers import MongoObserver

        observer = MongoObserver(*args, overwrite=overwrite, **kwargs)
        self.export(observer, base_dir, remove_sources, overwrite)
Code example #15
import warnings

from sacred import Experiment
from sacred.observers import MongoObserver

# base_config, canonical_puddle_world and PaddedMDPWarning come from the surrounding project


def run(config, mongodb_url="localhost:27017"):
    """Run a single experiment with the given configuration"""

    # Dynamically bind experiment config and main function
    ex = Experiment()
    ex.config(base_config)
    ex.main(canonical_puddle_world)

    # Attach MongoDB observer if necessary
    if not ex.observers:
        ex.observers.append(MongoObserver(url=mongodb_url))

    # Suppress warnings about padded MDPs
    with warnings.catch_warnings():
        warnings.filterwarnings(action="ignore", category=PaddedMDPWarning)

        # Run the experiment
        run = ex.run(config_updates=config, options={"--loglevel": "ERROR"})

    # Return the result
    return run.result
Code example #16
def evaluate(args, args_eval, model_file):

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    ex = None
    if args_eval.sacred:
        from sacred import Experiment
        from sacred.observers import MongoObserver
        ex = Experiment(args_eval.sacred_name)
        ex.observers.append(
            MongoObserver(url=constants.MONGO_URI, db_name=constants.DB_NAME))
        ex.add_config({
            "batch_size": args.batch_size,
            "epochs": args.epochs,
            "learning_rate": args.learning_rate,
            "encoder": args.encoder,
            "num_objects": args.num_objects,
            "custom_neg": args.custom_neg,
            "in_ep_prob": args.in_ep_prob,
            "seed": args.seed,
            "dataset": args.dataset,
            "save_folder": args.save_folder,
            "eval_dataset": args_eval.dataset,
            "num_steps": args_eval.num_steps,
            "use_action_attention": args.use_action_attention
        })

    device = torch.device('cuda' if args.cuda else 'cpu')

    dataset = utils.PathDatasetStateIds(hdf5_file=args.dataset,
                                        path_length=args_eval.num_steps)
    eval_loader = data.DataLoader(dataset,
                                  batch_size=args.batch_size,
                                  shuffle=False,
                                  num_workers=4)

    # Get a data sample to infer the input shape
    obs = next(iter(eval_loader))[0]
    input_shape = obs[0][0].size()

    model = modules.ContrastiveSWM(
        embedding_dim=args.embedding_dim,
        hidden_dim=args.hidden_dim,
        action_dim=args.action_dim,
        input_dims=input_shape,
        num_objects=args.num_objects,
        sigma=args.sigma,
        hinge=args.hinge,
        ignore_action=args.ignore_action,
        copy_action=args.copy_action,
        split_mlp=args.split_mlp,
        same_ep_neg=args.same_ep_neg,
        only_same_ep_neg=args.only_same_ep_neg,
        immovable_bit=args.immovable_bit,
        split_gnn=args.split_gnn,
        no_loss_first_two=args.no_loss_first_two,
        bisim_model=make_pairwise_encoder() if args.bisim_model_path else None,
        encoder=args.encoder,
        use_action_attention=args.use_action_attention).to(device)

    model.load_state_dict(torch.load(model_file))
    model.eval()

    # topk = [1, 5, 10]
    topk = [1]
    hits_at = defaultdict(int)
    num_samples = 0
    rr_sum = 0

    pred_states = []
    next_states = []
    next_ids = []

    with torch.no_grad():

        for batch_idx, data_batch in enumerate(eval_loader):
            data_batch = [[t.to(device) for t in tensor]
                          for tensor in data_batch]
            observations, actions, state_ids = data_batch

            if observations[0].size(0) != args.batch_size:
                continue

            obs = observations[0]
            next_obs = observations[-1]
            next_id = state_ids[-1]

            state = model.obj_encoder(model.obj_extractor(obs))
            next_state = model.obj_encoder(model.obj_extractor(next_obs))

            pred_state = state
            for i in range(args_eval.num_steps):
                pred_state = model.forward_transition(pred_state, actions[i])

            pred_states.append(pred_state.cpu())
            next_states.append(next_state.cpu())
            next_ids.append(next_id.cpu().numpy())

        pred_state_cat = torch.cat(pred_states, dim=0)
        next_state_cat = torch.cat(next_states, dim=0)
        next_ids_cat = np.concatenate(next_ids, axis=0)

        full_size = pred_state_cat.size(0)

        # Flatten object/feature dimensions
        next_state_flat = next_state_cat.view(full_size, -1)
        pred_state_flat = pred_state_cat.view(full_size, -1)

        dist_matrix = utils.pairwise_distance_matrix(next_state_flat,
                                                     pred_state_flat)

        #num_digits = 1
        #dist_matrix = (dist_matrix * 10 ** num_digits).round() / (10 ** num_digits)
        #dist_matrix = dist_matrix.float()

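        # Prepend each row's ground-truth distance as column 0: after the
        # stable sort below, the rank at which column 0 appears yields the
        # Hits@k and MRR statistics.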
        dist_matrix_diag = torch.diag(dist_matrix).unsqueeze(-1)
        dist_matrix_augmented = torch.cat([dist_matrix_diag, dist_matrix],
                                          dim=1)

        # Workaround to get a stable sort in numpy.
        dist_np = dist_matrix_augmented.numpy()
        indices = []
        for row in dist_np:
            keys = (np.arange(len(row)), row)
            indices.append(np.lexsort(keys))
        indices = np.stack(indices, axis=0)

        if args_eval.dedup:
            mask_mistakes = indices[:, 0] != 0
            closest_next_ids = next_ids_cat[indices[:, 0] - 1]

            if len(next_ids_cat.shape) == 2:
                equal_mask = np.all(closest_next_ids == next_ids_cat, axis=1)
            else:
                equal_mask = closest_next_ids == next_ids_cat

            indices[:, 0][np.logical_and(equal_mask, mask_mistakes)] = 0

        indices = torch.from_numpy(indices).long()

        #print('Processed {} batches of size {}'.format(
        #    batch_idx + 1, args.batch_size))

        labels = torch.zeros(indices.size(0),
                             device=indices.device,
                             dtype=torch.int64).unsqueeze(-1)

        num_samples += full_size
        #print('Size of current topk evaluation batch: {}'.format(
        #    full_size))

        for k in topk:
            match = indices[:, :k] == labels
            num_matches = match.sum()
            hits_at[k] += num_matches.item()

        match = indices == labels
        _, ranks = match.max(1)

        reciprocal_ranks = torch.reciprocal(ranks.double() + 1)
        rr_sum += reciprocal_ranks.sum().item()

        pred_states = []
        next_states = []
        next_ids = []

    hits = hits_at[topk[0]] / float(num_samples)
    mrr = rr_sum / float(num_samples)

    if ex is not None:
        # ugly hack
        @ex.main
        def sacred_main():
            ex.log_scalar("hits", hits)
            ex.log_scalar("mrr", mrr)

        ex.run()

    print('Hits @ {}: {}'.format(topk[0], hits))
    print('MRR: {}'.format(mrr))

    return hits, mrr
Code example #17
File: experiment.py Project: Kotzly/BCI_MsC
from sacred.observers import MongoObserver, FileStorageObserver
from sacred import Experiment

import json
from pathlib import Path

from sklearn.preprocessing import StandardScaler
from sklearn.base import BaseEstimator

from mne.time_frequency import psd_multitaper

ex = Experiment("experiment")
ex.observers.append(FileStorageObserver('my_runs'))
ex.observers.append(
    MongoObserver(
        url='mongodb://*****:*****@localhost:27017',
        db_name='sacred',
    ))

train_filepaths = list(
    Path("/home/paulo/Documents/datasets/BCI_Comp_IV_2a/gdf").glob(
        "*0[1,3,4,7,8]T.gdf"))
test_filepaths = list(
    Path("/home/paulo/Documents/datasets/BCI_Comp_IV_2a/gdf").glob(
        "*0[2,9,6,5]T.gdf"))

ICA_N_COMPONENTS = None
CSP_N_COMPONENTS = 12
PICKS = ["EEG-C3", "EEG-C4", "EEG-Cz"]
PSD_PICKS = [
    "ICA{}".format(str(i).rjust(3, "0")) for i in range(
        len(PICKS) if ICA_N_COMPONENTS is None else ICA_N_COMPONENTS)
]
Code example #18
from sacred import Experiment
from sacred.observers import MongoObserver
from ....run.online.fruits_seq.RunDQN import RunDQN
from ....constants import Constants
from ....utils.logger import Logger
from ....utils import sacred_utils
from .... import constants
from .... import paths

ex = Experiment("fruits_seq_DQN")
if constants.MONGO_URI is not None and constants.DB_NAME is not None:
    ex.observers.append(
        MongoObserver(url=constants.MONGO_URI, db_name=constants.DB_NAME))
else:
    print(
        "WARNING: results are not being saved. See 'Setup MongoDB' in README.")
ex.add_config(paths.CFG_ONLINE_FRUITS_DQN)


@ex.config
def config():

    num_expert_steps = 0
    num_random_steps = 0
    num_pretraining_steps = 0


@ex.automain
def main(dueling, double_learning, prioritized_replay, learning_rate,
         weight_decay, discount, goal, batch_size, max_steps, max_episodes,
         exploration_steps, prioritized_replay_max_steps, buffer_size,
Code example #19
# encoding: utf-8
from sacred import Experiment
from sacred.observers import MongoObserver
import numpy as np
from sklearn.linear_model import Ridge
from sklearn.metrics import mean_squared_error

ex = Experiment("linear 1")
obv = MongoObserver(url="localhost", port=8888, db_name="ml")
ex.observers.append(obv)


@ex.config
def config():
    alpha = 1.0


@ex.automain
def run(alpha):
    x = np.random.rand(200, 8)
    y = np.random.rand(200)
    # capture alpha from the sacred config instead of hardcoding it
    model = Ridge(alpha=alpha)
    model.fit(x, y)
    mse = mean_squared_error(model.predict(x), y)
    ex.log_scalar("mse", mse)
    return float(mse)
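
Since the script ends in @ex.automain, config entries can be overridden from the command line and sacred records the override; with run() capturing alpha, the override reaches the Ridge model. Assuming the file is saved as linear1.py (a hypothetical name):

python linear1.py with alpha=0.5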
Code example #20
File: eval_ids_b_inep.py Project: ondrejba/c-swm
def evaluate(args, args_eval, model_file):

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    ex = None
    if args_eval.sacred:
        from sacred import Experiment
        from sacred.observers import MongoObserver
        ex = Experiment(args_eval.sacred_name)
        ex.observers.append(MongoObserver(url=constants.MONGO_URI, db_name=constants.DB_NAME))
        ex.add_config({
            "batch_size": args.batch_size,
            "epochs": args.epochs,
            "learning_rate": args.learning_rate,
            "encoder": args.encoder,
            "num_objects": args.num_objects,
            "custom_neg": args.custom_neg,
            "in_ep_prob": args.in_ep_prob,
            "seed": args.seed,
            "dataset": args.dataset,
            "save_folder": args.save_folder,
            "eval_dataset": args_eval.dataset,
            "num_steps": args_eval.num_steps,
            "use_action_attention": args.use_action_attention
        })

    device = torch.device('cuda' if args.cuda else 'cpu')

    dataset = utils.PathDatasetStateIds(
        hdf5_file=args.dataset, path_length=10)
    eval_loader = data.DataLoader(
        dataset, batch_size=100, shuffle=False, num_workers=4)

    # Get a data sample to infer the input shape
    obs = next(iter(eval_loader))[0]
    input_shape = obs[0][0].size()

    model = modules.ContrastiveSWM(
        embedding_dim=args.embedding_dim,
        hidden_dim=args.hidden_dim,
        action_dim=args.action_dim,
        input_dims=input_shape,
        num_objects=args.num_objects,
        sigma=args.sigma,
        hinge=args.hinge,
        ignore_action=args.ignore_action,
        copy_action=args.copy_action,
        split_mlp=args.split_mlp,
        same_ep_neg=args.same_ep_neg,
        only_same_ep_neg=args.only_same_ep_neg,
        immovable_bit=args.immovable_bit,
        split_gnn=args.split_gnn,
        no_loss_first_two=args.no_loss_first_two,
        bisim_model=make_pairwise_encoder() if args.bisim_model_path else None,
        encoder=args.encoder,
        use_action_attention=args.use_action_attention
    ).to(device)

    model.load_state_dict(torch.load(model_file))
    model.eval()

    hits_list = []

    with torch.no_grad():

        for batch_idx, data_batch in enumerate(eval_loader):

            data_batch = [[t.to(
                device) for t in tensor] for tensor in data_batch]

            observations, actions, state_ids = data_batch

            if observations[0].size(0) != args.batch_size:
                continue

            states = []
            for obs in observations:
                states.append(model.obj_encoder(model.obj_extractor(obs)))
            states = torch.stack(states, dim=0)
            state_ids = torch.stack(state_ids, dim=0)

            pred_state = states[0]
            if not args_eval.no_transition:
                for i in range(args_eval.num_steps):
                    pred_state = model.forward_transition(pred_state, actions[i])

            # pred_state: [100, |O|, D]
            # states: [10, 100, |O|, D]
            # pred_state_flat: [100, X]
            # states_flat: [10, 100, X]
            pred_state_flat = pred_state.reshape((pred_state.size(0), pred_state.size(1) * pred_state.size(2)))
            states_flat = states.reshape((states.size(0), states.size(1), states.size(2) * states.size(3)))

            # dist_matrix: [10, 100]
            dist_matrix = (states_flat - pred_state_flat[None]).pow(2).sum(2)
            indices = torch.argmin(dist_matrix, dim=0)
            correct = indices == args_eval.num_steps

            # print(indices[0], args_eval.num_steps)
            # observations = torch.stack(observations, dim=0)
            # correct_obs = observations[args_eval.num_steps, 0]
            # pred_obs = observations[indices[0], 0]
            # import matplotlib
            # matplotlib.use("TkAgg")
            # import matplotlib.pyplot as plt
            # plt.subplot(1, 2, 1)
            # plt.imshow(correct_obs.cpu().numpy()[3:].transpose((1, 2, 0)))
            # plt.subplot(1, 2, 2)
            # plt.imshow(pred_obs.cpu().numpy()[3:].transpose((1, 2, 0)))
            # plt.show()

            # check for duplicates
            if args_eval.dedup:
                equal_mask = torch.all(state_ids[indices, list(range(100))] == state_ids[args_eval.num_steps], dim=1)
                correct = correct + equal_mask

            # hits
            hits_list.append(correct.float().mean().item())

    hits = np.mean(hits_list)

    if ex is not None:
        # ugly hack
        @ex.main
        def sacred_main():
            ex.log_scalar("hits", hits)
            ex.log_scalar("mrr", 0.)

        ex.run()

    print('Hits @ 1: {}'.format(hits))

    return hits, 0.
Code example #21
import frcnn_utils
import init_frcnn_utils
from experiments.exp_utils import get_config_var, LoggerForSacred, Args
from init_frcnn_utils import init_dataloaders_1s_1t, init_val_dataloaders_mt, init_val_dataloaders_1t, \
    init_htcn_model_optimizer

from sacred import Experiment
ex = Experiment()
from sacred.observers import MongoObserver
enable_mongo_observer = False
if enable_mongo_observer:
    vars = get_config_var()
    ex.observers.append(
        MongoObserver(
            url='mongodb://{}:{}@{}/admin?authMechanism=SCRAM-SHA-1'.format(
                vars["SACRED_USER"], vars["SACRED_PWD"], vars["SACRED_URL"]),
            db_name=vars["SACRED_DB"]))
    ex.captured_out_filter = lambda text: 'Output capturing turned off.'

from dataclasses import dataclass

import numpy as np

import torch
import torch.nn as nn

from model.utils.config import cfg, cfg_from_file, cfg_from_list
from model.utils.net_utils import adjust_learning_rate, save_checkpoint, FocalLoss, EFocalLoss

from model.utils.parser_func import set_dataset_args
Code example #22
import os, site, socket, sys
from os.path import dirname, join, splitext
from sacred import Experiment
from sacred.observers import MongoObserver
from sacred.utils import apply_backspaces_and_linefeeds

from derive_conceptualspace.settings import get_setting
from fb_classifier.preprocess_data import preprocess_data, create_traintest
from fb_classifier.dataset import load_data
from fb_classifier.train import TrainPipeline
# from src.fb_classifier.util.misc import get_all_debug_confs, clear_checkpoints_summary
from fb_classifier.util.misc import get_all_configs
from fb_classifier.settings import CLASSIFIER_CHECKPOINT_PATH, SUMMARY_PATH, MONGO_URI, DATA_BASE, DEBUG
import fb_classifier

ex = Experiment("Fachbereich_Classifier")
ex.observers.append(MongoObserver(url=MONGO_URI, db_name=os.environ["MONGO_DATABASE"]))
ex.captured_out_filter = apply_backspaces_and_linefeeds
ex.add_config(get_all_configs(as_string=False))
ex.add_config(DEBUG=DEBUG)
for pyfile in [join(path, name) for path, subdirs, files in os.walk(dirname(fb_classifier.__file__)) for name in files if splitext(name)[1] == ".py"]:
    ex.add_source_file(pyfile)

@ex.main
def run_experiment(_run):
    args = parse_command_line_args()
    setup_logging(args.loglevel, args.logfile)

    # if args.restart: #delete checkpoint- and summary-dir
    #     clear_checkpoints_summary()

    if args.no_continue or not os.listdir(CLASSIFIER_CHECKPOINT_PATH):
Code example #23
from sacred import Experiment
from sacred.observers import MongoObserver

ex = Experiment('hello_config')
# When executing from a Jupyter notebook, @ex.automain does not work; use
# @ex.main and create the experiment in interactive mode instead:
# ex = Experiment('hello_config', interactive=True)

# MongoObserver
ex.observers.append(MongoObserver(url='localhost:27017',
                                  db_name='MY_DB'))

@ex.config
def my_config():
    recipient = "world"
    message = "Hello %s!" % recipient

@ex.automain
def my_main(message):
    print(message)
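
With @ex.automain, any config entry can be overridden at launch; sacred re-runs my_config, so message is rebuilt from the new recipient and the run prints "Hello Sacred!" (the filename is hypothetical):

python hello_config.py with recipient=Sacred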
Code example #24
# -*- coding: utf-8 -*-
from network import Network
from PIL import Image
from scale import size, load_batches
import numpy as np
import os
import sys
from sacred import Experiment
from sacred.observers import MongoObserver
from sacred.utils import apply_backspaces_and_linefeeds
from gpu_helpers import init_all_gpu
init_all_gpu()

ex = Experiment('Superresolution', ingredients=[])
ex.observers.append(MongoObserver())
ex.captured_out_filter = apply_backspaces_and_linefeeds


@ex.config
def my_config():
    image_size = (320, 240)
    batch_size = 5
    no_epochs = 500
    lr = 0.0001
    lr_stair_width = 10
    lr_decay = 0.95


@ex.capture
def log_training_performance(_run, loss, lr):
    _run.log_scalar("loss", float(loss))
Code example #25
File: classify.py Project: Flyfoxs/lung_classify
    conf_name_base = backbone_name
    oof_file = f'./output/stacking/{version}_{host_name[:5]}_s{best_score:6.5f}_{conf_name_base}_{conf.model_type}_f{valid_fold}_val{val_len}_trn{train_len}.h5'

    print(f'Stacking file save to:{oof_file}')
    save_stack_feature(oof_val, oof_test, oof_file)

###### sacred begin
from sacred import Experiment
from easydict import EasyDict as edict
from sacred.observers import MongoObserver
from sacred import SETTINGS

ex = Experiment('lung')
db_url = 'mongodb://*****:*****@10.10.20.103:27017/db?authSource=admin'
ex.observers.append(MongoObserver(url=db_url, db_name='db'))
#SETTINGS.CAPTURE_MODE = 'sys'

@ex.config
def my_config():
    conf_name = None
    fold = -1
    image_size = 200
    model_type = 'raw'


@ex.command()
def main(_config):
    config = edict(_config)
    print('=====', config)
    train(config)
Code example #26
import tensorflow as tf
import tqdm
import ubelt as ub
from sacred import Experiment
from sacred.observers import MongoObserver

import utils
from EvalNet import EvalNet
from GGNNPolyModel import GGNNPolygonModel
from PolygonModel import PolygonModel
from utils import get_all_files, save_to_json
from vis_predictions import vis_single

ex = Experiment()

ex.observers.append(MongoObserver(url="localhost:27017", db_name="sacred"))
#
tf.logging.set_verbosity(tf.logging.INFO)
# --
flags = tf.flags
FLAGS = flags.FLAGS
# ---
flags.DEFINE_string(
    "PolyRNN_metagraph",
    "models/poly/polygonplusplus.ckpt.meta",
    "PolygonRNN++ MetaGraph ",
)
flags.DEFINE_string("PolyRNN_checkpoint", "models/poly/polygonplusplus.ckpt",
                    "PolygonRNN++ checkpoint ")
flags.DEFINE_string("EvalNet_checkpoint", "models/evalnet/evalnet.ckpt",
                    "Evaluator checkpoint ")
Code example #27
File: train.py Project: MikeyQiu/DELFT-Extend
import os

import joblib
import torch
from sacred import Experiment
from sacred.observers import MongoObserver
from data import CATEGORY_IDS
from data import GraphDataset, TextGraphDataset, GloVeTokenizer
import models
import utils

OUT_PATH = 'output/'
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

ex = Experiment()
ex.logger = utils.get_logger()
# Set up database logs
uri = os.environ.get('DB_URI')
database = os.environ.get('DB_NAME')
if all([uri, database]):
    ex.observers.append(MongoObserver(uri, database))


@ex.config
def config():
    dataset = 'entities'
    inductive = True
    dim = 128
    model = 'blp'
    rel_model = 'complex'
    loss_fn = 'margin'
    encoder_name = 'bert-base-cased'
    regularizer = 1e-2
    max_len = 32
    num_negatives = 64
    lr = 2e-5
Code example #28
File: train.py Project: talesa/amci
import os

import torch
import yaml
from sacred import Experiment
from sacred.observers import MongoObserver

import amci.utils

ex = Experiment()

if __name__ == '__main__':
    with open(
            os.path.join(os.path.dirname(os.path.abspath(__file__)),
                         "config_sacred_mongodb.yaml"), 'r') as stream:
        config_sacred_mongodb = yaml.safe_load(stream)
        if config_sacred_mongodb['sacred_mongo_url']:
            ex.observers.append(
                MongoObserver(
                    url=config_sacred_mongodb['sacred_mongo_url'],
                    db_name=config_sacred_mongodb['sacred_mongo_db_name']))
            print(
                f"Adding Sacred MongoDB observer at {config_sacred_mongodb['sacred_mongo_url']}."
            )


def train(args, _run, _writer):
    model = amci.utils.get_model(args)

    q = model.Training.get_proposal_model(args)

    optimizer = torch.optim.Adam(q.parameters(), lr=args.learning_rate)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        patience=args.scheduler_patience,
Code example #29
mongo_enabled = os.environ.get('MONGO_SACRED_ENABLED')
mongo_user = os.environ.get('MONGO_SACRED_USER')
mongo_pass = os.environ.get('MONGO_SACRED_PASS')
mongo_host = os.environ.get('MONGO_SACRED_HOST')
mongo_port = os.environ.get('MONGO_SACRED_PORT', '27017')

if mongo_enabled == 'true':
    assert mongo_user, 'Setting $MONGO_SACRED_USER is required'
    assert mongo_pass, 'Setting $MONGO_SACRED_PASS is required'
    assert mongo_host, 'Setting $MONGO_SACRED_HOST is required'

    mongo_url = 'mongodb://{0}:{1}@{2}:{3}/' \
                'sacred?authMechanism=SCRAM-SHA-1'.format(mongo_user, mongo_pass, mongo_host, mongo_port)

    ex.observers.append(MongoObserver(url=mongo_url, db_name='sacred'))


@ex.config
def cfg():
    n_epochs = 20
    lr = 2e-5
    batch_size = 16
    base_model = "distilbert-base-uncased"
    clustering_loss_weight = 1.0
    embedding_extractor = concat_cls_n_hidden_states
    annealing_alphas = np.arange(1, n_epochs + 1)
    dataset = "../datasets/ags_news/ag_news.csv"
    train_idx_file = "../datasets/ag_news/splits/train"
    result_dir = f"../results/ag_news-distilbert/{strftime('%Y-%m-%d_%H:%M:%S', gmtime())}"
    early_stopping = True
Code example #30
import datetime

import torch
from gym.envs.registration import register  # assuming OpenAI Gym's registry
from sacred import Experiment
from sacred.observers import MongoObserver

from utils import calc_loss, load_ckp

register(
    id='SimpleHighway-v1',
    entry_point='env.simple_highway.simple_highway_env:SimpleHighway',
)



experiment_name = "driving_behavior"
sacred_ex = Experiment(experiment_name)
now = datetime.datetime.now()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

try:
    sacred_ex.observers.append(MongoObserver(url='localhost:27017',
                                             db_name='sacred'))
except ConnectionError:
    print("MongoDB instance should be running")

@sacred_ex.config
def dqn_cfg():
    seed = 123523
    num_lane = 3

    level_k = 2

    agent_level_k = level_k - 1
    TRAIN = True
    LOAD_SAVED_MODEL = False
    MODEL_PATH_FINAL = "best_"+str(agent_level_k)