def sacred_experiment_with_config(config, name, main_fcn, db_name, base_dir,
                                  checkpoint_dir, sources=None,
                                  tune_config=None):
    """Launch a sacred experiment.

    Creates a fresh ``Experiment`` named *name*, attaches a ``MongoObserver``
    writing to *db_name*, registers the given source files, adds
    *tune_config* to the experiment config and runs *main_fcn* as the
    experiment's main function.

    Args:
        config: Object forwarded unchanged to ``main_fcn(config=...)``.
        name: Experiment name.
        main_fcn: Callable invoked as
            ``main_fcn(config=config, checkpoint_dir=checkpoint_dir, ex=ex)``;
            its return value becomes the run's result.
        db_name: MongoDB database name for the observer.
        base_dir: ``base_dir`` passed to ``Experiment``.
        checkpoint_dir: Forwarded to *main_fcn*.
        sources: Optional iterable of file paths. Each ``str`` entry is copied
            to ``<path>.py`` and the copy is registered as a source file;
            non-string entries are ignored. Defaults to no sources.
        tune_config: Optional mapping added to the experiment config.
            Defaults to an empty mapping.

    Returns:
        Whatever ``ex.run()`` returns (the sacred run object).
    """
    # creating a sacred experiment
    # https://github.com/IDSIA/sacred/issues/492
    from sacred import Experiment, SETTINGS

    # None sentinels instead of shared mutable default arguments.
    sources = [] if sources is None else sources
    tune_config = {} if tune_config is None else tune_config

    # Allow main_fcn / observers to mutate the config (see issue above).
    SETTINGS.CONFIG.READ_ONLY_CONFIG = False
    ex = Experiment(name, base_dir=base_dir)
    ex.observers.append(MongoObserver(db_name=db_name))

    for f in sources:
        if isinstance(f, str):
            # Register a copy with a .py suffix; presumably so sacred treats
            # it as a Python source file — confirm.
            f_py = f + '.py'
            shutil.copy(f, f_py)
            ex.add_source_file(f_py)

    # Stores tune_config both nested under the key "config" and flattened at
    # the top level of the experiment config. NOTE(review): a "config" key
    # inside tune_config would collide with the explicit keyword (TypeError).
    ex.add_config(config=tune_config, **tune_config)

    @ex.main
    def run_train():
        return main_fcn(config=config, checkpoint_dir=checkpoint_dir, ex=ex)

    return ex.run()
def main():
    """Attach file-storage and MongoDB observers to ``train_ex`` and run its CLI.

    Runs are written to ``results/sacred`` on disk and to the ``sacred``
    database of the MongoDB server at ``localhost:27017``.
    """
    storage_obs = FileStorageObserver.create(os.path.join('results', 'sacred'))
    # use this code to attach a mongo database for logging
    db_obs = MongoObserver(url='localhost:27017', db_name='sacred')
    # Mongo first, then file storage: same observer order as before.
    train_ex.observers.extend([db_obs, storage_obs])
    train_ex.run_commandline()
def setup_experiment(experiment_name):
    """Create an experiment that reports to MongoDB and Slack.

    Args:
        experiment_name: Name given to the sacred ``Experiment``.

    Returns:
        The experiment with a ``MongoObserver`` (database ``sacred``) and a
        ``SlackObserver`` built from the file named by ``$SLACK_CONFIG``.

    Raises:
        KeyError: if the ``SLACK_CONFIG`` environment variable is unset.
    """
    ex = Experiment(experiment_name, save_git_info=False)
    ex.observers.append(
        MongoObserver(
            url='mongodb://*****:*****@localhost:27017/sacred?authSource=admin',
            db_name='sacred',
        )
    )
    ex.observers.append(SlackObserver.from_config(os.environ['SLACK_CONFIG']))
    return ex
def run(config_updates, mongodb_url="localhost:27017"):
    """Run a single ``element_world_v4`` experiment and return its result.

    Args:
        config_updates (dict): Updates applied on top of ``base_config``.
        mongodb_url (str): MongoDB URL, or None if no Mongo observer should
            be used for this run.

    Returns:
        The ``result`` of the finished run.
    """
    # Bind config and main function onto a fresh experiment each call.
    experiment = Experiment()
    experiment.config(base_config)
    experiment.main(element_world_v4)

    use_mongo = mongodb_url is not None
    if use_mongo and len(experiment.observers) == 0:
        experiment.observers.append(MongoObserver(url=mongodb_url))

    # Padded-MDP warnings are expected here and would only add noise.
    with warnings.catch_warnings():
        warnings.filterwarnings(action="ignore", category=PaddedMDPWarning)
        finished = experiment.run(config_updates=config_updates)

    return finished.result
def experiment(name):
    """Return a sacred experiment logging to MongoDB and to a per-name run folder.

    Args:
        name: Experiment name; also the sub-directory under
            ``sacred/experiments`` used by the file-storage observer.
    """
    ex = Experiment(name)
    observers = (
        MongoObserver(url="mongodb://*****:*****@localhost:27017"),
        FileStorageObserver(basedir=os.path.join("sacred/experiments", name)),
    )
    ex.observers.extend(observers)
    return ex
def add_observer(config, command_name, logger):
    """Config hook: attach a MongoObserver to ``ex`` and set the run workspace.

    Args:
        config: The run's config dict. Reads ``config["exp_str"]`` and
            ``config["path"]["log_dir"]``; writes ``config["workspace"]``.
        command_name: Unused.
        logger: Unused.

    Returns:
        The same ``config`` dict, updated in place.
    """
    run_label = f'{ex.path}_{config["exp_str"]}'
    # A FileStorageObserver under log_dir/run_label was an alternative here.
    ex.observers.append(MongoObserver())
    config['workspace'] = os.path.join(config['path']['log_dir'], run_label)
    return config
def local(cls, docker: bool = True) -> SacredConfig:
    """Build a ``SacredConfigImpl`` with an observer factory chosen by *docker*.

    With ``docker=True`` the factory creates a MongoObserver for the ``db``
    database on ``localhost:27017`` using SCRAM-SHA-1 credentials; otherwise
    it creates a ``MongoObserver()`` with no arguments.
    """
    def make_docker_observer():
        return MongoObserver(
            url=f'mongodb://*****:*****@localhost:27017/?authMechanism=SCRAM-SHA-1',
            db_name='db',
        )

    def make_default_observer():
        return MongoObserver()

    factory = make_docker_observer if docker else make_default_observer
    return SacredConfigImpl(factory)
def ready_mongo_observer(ex, db_name='sacred', url='localhost:27017'):
    """Attach a MongoDB observer to a sacred experiment.

    Args:
        ex (sacred.Experiment): Experiment to track.
        db_name (str): MongoDB database name.
        url (str): MongoDB host location.
    """
    observer = MongoObserver(url=url, db_name=db_name)
    ex.observers.append(observer)
def observer_from_env(resume_key: str = None):
    """Build a MongoDB (Atlas, ``mongodb+srv``) observer from environment variables.

    Reads ``SACRED_USER``, ``SACRED_PASSWORD``, ``SACRED_DATABASE`` and
    ``SACRED_HOST``. User name and password are percent-escaped before being
    placed in the connection URI. Without escaping, characters such as
    ``@``, ``:``, ``/`` or ``%`` in a password would corrupt the URI.

    Args:
        resume_key: When None, a plain ``MongoObserver`` is returned.
            Otherwise a ``ResumableMongoObserver`` keyed by this value.

    Returns:
        The configured observer, targeting database ``SACRED_DATABASE``.

    Raises:
        KeyError: if any of the four environment variables is unset.
    """
    from urllib.parse import quote_plus

    user = quote_plus(os.environ['SACRED_USER'])
    password = quote_plus(os.environ['SACRED_PASSWORD'])
    database = os.environ['SACRED_DATABASE']
    host = os.environ['SACRED_HOST']
    url = (
        f'mongodb+srv://{user}:{password}@{host}/{database}?retryWrites=true&'
        'w=majority')
    if resume_key is None:
        return MongoObserver(url, db_name=database)
    return ResumableMongoObserver(resume_key, url, db_name=database)
def get_sacred_experiment(name, observer='mongo', capture_output=True):
    """Create a sacred experiment with a single observer.

    Args:
        name: Experiment name.
        observer: ``'mongo'`` for the templated MongoDB observer on
            127.0.0.1:27017; any other value selects a file-storage observer
            under ``data/sacred/``.
        capture_output: When False, sets the global sacred
            ``SETTINGS.CAPTURE_MODE`` to ``'no'``.
    """
    ex = sacred.Experiment(name)
    if observer != 'mongo':
        chosen = FileStorageObserver('data/sacred/')
    else:
        chosen = MongoObserver(
            url='mongodb://{{cookiecutter.mongo_user}}:'
                '{{cookiecutter.mongo_password}}@127.0.0.1:27017',
            db_name='sacred')
    ex.observers.append(chosen)

    if not capture_output:
        SETTINGS.CAPTURE_MODE = 'no'
    return ex
def default_observers(self):
    """Return the observers to attach by default.

    A MongoObserver (database ``db``) is included only when the current host
    name is listed in ``self.mongo_hostnames``. A FileStorageObserver rooted
    at ``exp_config["storage_dir"]`` (default ``./sacred_storage``) is
    always included, last.
    """
    on_mongo_host = socket.gethostname() in self.mongo_hostnames
    mongo_part = [
        MongoObserver(
            url=f"mongodb://*****:*****@localhost:27017/?authMechanism=SCRAM-SHA-1",
            db_name="db",
        )
    ] if on_mongo_host else []
    storage_dir = self.exp_config.get("storage_dir", "./sacred_storage")
    return mongo_part + [FileStorageObserver(storage_dir)]
def config(): algorithm = "mcts" # {"mcts", "lfd", "lfd-mcts", "acer", "greedy", "random", "truth", "mle", "beamsearch"} policy = "nn" # {"nn", "random", "likelihood"} teacher = "truth" # {"truth", "mle"} env_type = "1d" # for now, only {"1d"} seed = 24927 # 1971248 was used for first round name = f"{algorithm}_{policy}" if algorithm == "mcts" else algorithm run_name = f"{name}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}_{seed}" database = True ex.observers.append(FileStorageObserver(f"./data/runs/{run_name}")) if database: ex.observers.append(MongoObserver())
def create_experiment(name='exp', database=None):
    """Create a Sacred experiment object for experiment logging.

    A remote MongoDB (``mongodb+srv``) observer is attached only when
    ``MONGO_DB_USER``, ``MONGO_DB_PASS`` and ``MONGO_DB_HOST`` are all set and
    non-empty. User name and password are percent-escaped so that special
    characters cannot corrupt the connection URI.

    Args:
        name: Experiment name.
        database: MongoDB database name for the observer. When None, no
            ``db_name`` is passed, so the observer's own default applies.
            Previously ``db_name=None`` was forwarded, which the Mongo client
            rejects.

    Returns:
        The configured ``Experiment``.
    """
    from urllib.parse import quote_plus

    ex = Experiment(name)
    atlas_user = os.environ.get('MONGO_DB_USER')
    atlas_password = os.environ.get('MONGO_DB_PASS')
    atlas_host = os.environ.get('MONGO_DB_HOST')

    # Add remote MongoDB observer, only if environment variables are set
    if atlas_user and atlas_password and atlas_host:
        url = (f"mongodb+srv://{quote_plus(atlas_user)}:"
               f"{quote_plus(atlas_password)}@{atlas_host}")
        observer_kwargs = {} if database is None else {'db_name': database}
        ex.observers.append(MongoObserver(url=url, **observer_kwargs))
    return ex
def to_mongo(self, base_dir, remove_sources=False, overwrite=None, *args, **kwargs):
    """Export this file log into a MongoDB database (requires sacred).

    Args:
        base_dir: Root path of the recorded sources.
        remove_sources: Drop sources, e.g. when they are too complicated to match.
        overwrite: Id of an experiment to overwrite; passed both to the
            observer and to ``self.export``.
        *args: Positional arguments for ``MongoObserver``.
        **kwargs: Keyword arguments for ``MongoObserver``.
    """
    from sacred.observers import MongoObserver

    target = MongoObserver(*args, overwrite=overwrite, **kwargs)
    self.export(target, base_dir, remove_sources, overwrite)
def run(config, mongodb_url="localhost:27017"):
    """Run a single ``canonical_puddle_world`` experiment with the given configuration.

    Args:
        config: Config updates applied on top of ``base_config``.
        mongodb_url: URL of the MongoDB server for the run's observer.

    Returns:
        The ``result`` of the finished run.
    """
    experiment = Experiment()
    experiment.config(base_config)
    experiment.main(canonical_puddle_world)

    if len(experiment.observers) == 0:
        experiment.observers.append(MongoObserver(url=mongodb_url))

    # Padded-MDP warnings are expected here and would only add noise.
    with warnings.catch_warnings():
        warnings.filterwarnings(action="ignore", category=PaddedMDPWarning)
        finished = experiment.run(config_updates=config,
                                  options={"--loglevel": "ERROR"})
    return finished.result
def evaluate(args, args_eval, model_file):
    """Evaluate a trained ContrastiveSWM model with Hits@k and MRR.

    For every full evaluation batch, the first and last observation of each
    path are encoded, and the first encoding is rolled forward through
    ``args_eval.num_steps`` transitions. All predictions are then ranked
    against all true final states by pairwise distance. The rank of each
    sample's own target gives Hits@k and the mean reciprocal rank.

    Args:
        args: Training-time arguments (model hyper-parameters, ``seed``,
            ``cuda``, ``dataset``, ``batch_size``, ...).
        args_eval: Evaluation arguments (``num_steps``, ``dedup``,
            ``sacred``, ``sacred_name``, ``dataset``).
        model_file: Path of the saved ``state_dict`` to load.

    Returns:
        Tuple ``(hits, mrr)`` of Python floats. ``hits`` is Hits@``topk[0]``.

    Raises:
        RuntimeError: if the loader yields no batch of exactly
            ``args.batch_size`` samples, i.e. there is nothing to score.
    """
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    ex = None
    if args_eval.sacred:
        from sacred import Experiment
        from sacred.observers import MongoObserver
        ex = Experiment(args_eval.sacred_name)
        ex.observers.append(
            MongoObserver(url=constants.MONGO_URI, db_name=constants.DB_NAME))
        ex.add_config({
            "batch_size": args.batch_size,
            "epochs": args.epochs,
            "learning_rate": args.learning_rate,
            "encoder": args.encoder,
            "num_objects": args.num_objects,
            "custom_neg": args.custom_neg,
            "in_ep_prob": args.in_ep_prob,
            "seed": args.seed,
            "dataset": args.dataset,
            "save_folder": args.save_folder,
            "eval_dataset": args_eval.dataset,
            "num_steps": args_eval.num_steps,
            "use_action_attention": args.use_action_attention
        })

    device = torch.device('cuda' if args.cuda else 'cpu')

    dataset = utils.PathDatasetStateIds(
        hdf5_file=args.dataset, path_length=args_eval.num_steps)
    eval_loader = data.DataLoader(
        dataset, batch_size=args.batch_size, shuffle=False, num_workers=4)

    # Peek at one batch only to infer the model's input shape.
    # `next(iter(...))` is the Python 3 protocol; the legacy `.next()` method
    # is not guaranteed to exist on DataLoader iterators.
    obs = next(iter(eval_loader))[0]
    input_shape = obs[0][0].size()

    model = modules.ContrastiveSWM(
        embedding_dim=args.embedding_dim,
        hidden_dim=args.hidden_dim,
        action_dim=args.action_dim,
        input_dims=input_shape,
        num_objects=args.num_objects,
        sigma=args.sigma,
        hinge=args.hinge,
        ignore_action=args.ignore_action,
        copy_action=args.copy_action,
        split_mlp=args.split_mlp,
        same_ep_neg=args.same_ep_neg,
        only_same_ep_neg=args.only_same_ep_neg,
        immovable_bit=args.immovable_bit,
        split_gnn=args.split_gnn,
        no_loss_first_two=args.no_loss_first_two,
        bisim_model=make_pairwise_encoder() if args.bisim_model_path else None,
        encoder=args.encoder,
        use_action_attention=args.use_action_attention).to(device)

    model.load_state_dict(torch.load(model_file))
    model.eval()

    topk = [1]  # e.g. [1, 5, 10] for more cut-offs; `hits` reports topk[0]
    hits_at = defaultdict(int)
    num_samples = 0
    rr_sum = 0

    pred_states = []
    next_states = []
    next_ids = []

    with torch.no_grad():

        for data_batch in eval_loader:
            data_batch = [[t.to(device) for t in tensor] for tensor in data_batch]
            observations, actions, state_ids = data_batch

            # Only full batches are scored (skips the final partial batch).
            if observations[0].size(0) != args.batch_size:
                continue

            obs = observations[0]
            next_obs = observations[-1]
            next_id = state_ids[-1]

            state = model.obj_encoder(model.obj_extractor(obs))
            next_state = model.obj_encoder(model.obj_extractor(next_obs))

            # Roll the initial state forward through the path's actions.
            pred_state = state
            for i in range(args_eval.num_steps):
                pred_state = model.forward_transition(pred_state, actions[i])

            pred_states.append(pred_state.cpu())
            next_states.append(next_state.cpu())
            next_ids.append(next_id.cpu().numpy())

        if not pred_states:
            raise RuntimeError(
                'No evaluation batch of exactly {} samples was produced; '
                'nothing to score.'.format(args.batch_size))

        pred_state_cat = torch.cat(pred_states, dim=0)
        next_state_cat = torch.cat(next_states, dim=0)
        next_ids_cat = np.concatenate(next_ids, axis=0)

        full_size = pred_state_cat.size(0)

        # Flatten object/feature dimensions
        next_state_flat = next_state_cat.view(full_size, -1)
        pred_state_flat = pred_state_cat.view(full_size, -1)

        dist_matrix = utils.pairwise_distance_matrix(
            next_state_flat, pred_state_flat)
        # Prepend each row's diagonal entry (sample i's target vs. its own
        # prediction) as column 0, so the correct answer is always label 0.
        dist_matrix_diag = torch.diag(dist_matrix).unsqueeze(-1)
        dist_matrix_augmented = torch.cat(
            [dist_matrix_diag, dist_matrix], dim=1)

        # Stable row-wise sort: on equal distances the lower column wins, so
        # column 0 (the true target) ranks before ties. This is the same
        # ordering as a per-row np.lexsort((np.arange(len(row)), row)).
        dist_np = dist_matrix_augmented.numpy()
        indices = np.argsort(dist_np, axis=1, kind='stable')

        if args_eval.dedup:
            # If the top-ranked candidate is a different sample whose state id
            # equals the target's id (a duplicate state), count it as a hit.
            mask_mistakes = indices[:, 0] != 0
            # `- 1` undoes the extra diagonal column prepended above.
            closest_next_ids = next_ids_cat[indices[:, 0] - 1]
            if len(next_ids_cat.shape) == 2:
                equal_mask = np.all(closest_next_ids == next_ids_cat, axis=1)
            else:
                equal_mask = closest_next_ids == next_ids_cat
            indices[:, 0][np.logical_and(equal_mask, mask_mistakes)] = 0

        indices = torch.from_numpy(indices).long()

        labels = torch.zeros(
            indices.size(0), device=indices.device,
            dtype=torch.int64).unsqueeze(-1)

        num_samples += full_size

        for k in topk:
            match = indices[:, :k] == labels
            num_matches = match.sum()
            hits_at[k] += num_matches.item()

        # Position of the first match = 0-based rank of the true target.
        match = indices == labels
        _, ranks = match.max(1)

        reciprocal_ranks = torch.reciprocal(ranks.double() + 1)
        rr_sum += reciprocal_ranks.sum().item()

    hits = hits_at[topk[0]] / float(num_samples)
    mrr = rr_sum / float(num_samples)

    if ex is not None:
        # ugly hack: log_scalar needs an active sacred run, so the metrics are
        # logged from a throw-away main function that is run once.
        @ex.main
        def sacred_main():
            ex.log_scalar("hits", hits)
            ex.log_scalar("mrr", mrr)

        ex.run()

    print('Hits @ {}: {}'.format(topk[0], hits))
    print('MRR: {}'.format(mrr))

    return hits, mrr
from sacred.observers import MongoObserver, FileStorageObserver
from sacred import Experiment
import json
from sklearn.preprocessing import StandardScaler
from sklearn.base import BaseEstimator
from mne.time_frequency import psd_multitaper

# Module-level experiment; runs are recorded to the local 'my_runs' directory
# and to the 'sacred' database of a MongoDB server on localhost:27017.
ex = Experiment("experiment")
ex.observers.append(FileStorageObserver('my_runs'))
ex.observers.append(
    MongoObserver(
        url='mongodb://*****:*****@localhost:27017',
        db_name='sacred',
    ))

# NOTE(review): absolute, user-specific dataset directory; consider making it
# configurable. `Path` is not imported in the visible lines — presumably
# pathlib.Path imported elsewhere; confirm.
# NOTE(review): in the glob character classes below the ',' is itself a
# member of the class; harmless here but presumably unintended.
train_filepaths = list(
    Path("/home/paulo/Documents/datasets/BCI_Comp_IV_2a/gdf").glob(
        "*0[1,3,4,7,8]T.gdf"))
test_filepaths = list(
    Path("/home/paulo/Documents/datasets/BCI_Comp_IV_2a/gdf").glob(
        "*0[2,9,6,5]T.gdf"))

ICA_N_COMPONENTS = None
CSP_N_COMPONENTS = 12
PICKS = ["EEG-C3", "EEG-C4", "EEG-Cz"]
# Channel names "ICA000", "ICA001", ...: one per ICA component, or one per
# entry of PICKS when ICA_N_COMPONENTS is None.
PSD_PICKS = [
    "ICA{}".format(str(i).rjust(3, "0")) for i in range(
        len(PICKS) if ICA_N_COMPONENTS is None else ICA_N_COMPONENTS)
from sacred import Experiment
from sacred.observers import MongoObserver
from ....run.online.fruits_seq.RunDQN import RunDQN
from ....constants import Constants
from ....utils.logger import Logger
from ....utils import sacred_utils
from .... import constants
from .... import paths

ex = Experiment("fruits_seq_DQN")

# Attach the MongoDB observer only when both the URI and the database name are
# configured; otherwise runs are not persisted and a warning is printed.
if constants.MONGO_URI is not None and constants.DB_NAME is not None:
    ex.observers.append(
        MongoObserver(url=constants.MONGO_URI, db_name=constants.DB_NAME))
else:
    print(
        "WARNING: results are not being saved. See 'Setup MongoDB' in README.")

# Base configuration (presumably a config file path — confirm in `paths`).
ex.add_config(paths.CFG_ONLINE_FRUITS_DQN)


@ex.config
def config():
    # Config-scope entries added on top of the base configuration.
    num_expert_steps = 0
    num_random_steps = 0
    num_pretraining_steps = 0


@ex.automain
def main(dueling, double_learning, prioritized_replay, learning_rate, weight_decay, discount, goal,
         batch_size, max_steps, max_episodes, exploration_steps, prioritized_replay_max_steps,
         buffer_size,
# encoding: utf-8
from sacred import Experiment
from sacred.observers import MongoObserver
import numpy as np
from sklearn.linear_model import Ridge
from sklearn.metrics import mean_squared_error

ex = Experiment("linear 1")
# `port` is an extra keyword; presumably forwarded by MongoObserver to
# pymongo's MongoClient — confirm for the installed sacred version.
obv = MongoObserver(url="localhost", port=8888, db_name="ml")
ex.observers.append(obv)


@ex.config
def config():
    alpha = 1.0  # Ridge regularization strength


@ex.automain
def run(alpha):
    """Fit a Ridge regressor on random data and log/return its training MSE.

    ``alpha`` is injected by Sacred from the config entry above. It was
    previously ignored in favour of a hard-coded ``1.0``, so overriding it
    (e.g. ``with alpha=0.1``) had no effect.
    """
    x = np.random.rand(200, 8)
    y = np.random.rand(200)
    model = Ridge(alpha=alpha)
    model.fit(x, y)
    # scikit-learn metrics take (y_true, y_pred).
    mse = mean_squared_error(y, model.predict(x))
    ex.log_scalar("mse", mse)
    return float(mse)
def evaluate(args, args_eval, model_file):
    """Evaluate Hits@1 of multi-step prediction by retrieval along each path.

    Every observation of each path is encoded. The first encoding is rolled
    forward through ``args_eval.num_steps`` transitions (unless
    ``args_eval.no_transition``). The prediction is counted as a hit when the
    closest encoded state on its own path, by squared Euclidean distance, is
    the one at position ``args_eval.num_steps``. With ``args_eval.dedup`` it
    is also a hit when the closest state has the same state id as that
    target.

    Args:
        args: Training-time arguments (model hyper-parameters, ``seed``,
            ``cuda``, ``dataset``, ...).
        args_eval: Evaluation arguments (``num_steps``, ``no_transition``,
            ``dedup``, ``sacred``, ``sacred_name``, ``dataset``).
        model_file: Path of the saved ``state_dict`` to load.

    Returns:
        Tuple ``(hits, 0.)``. MRR is not computed by this evaluator and is
        reported as 0.

    Raises:
        RuntimeError: if the loader yields no full batch, so there is
            nothing to score (previously this silently produced NaN).
    """
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    ex = None
    if args_eval.sacred:
        from sacred import Experiment
        from sacred.observers import MongoObserver
        ex = Experiment(args_eval.sacred_name)
        ex.observers.append(MongoObserver(url=constants.MONGO_URI,
                                          db_name=constants.DB_NAME))
        ex.add_config({
            "batch_size": args.batch_size,
            "epochs": args.epochs,
            "learning_rate": args.learning_rate,
            "encoder": args.encoder,
            "num_objects": args.num_objects,
            "custom_neg": args.custom_neg,
            "in_ep_prob": args.in_ep_prob,
            "seed": args.seed,
            "dataset": args.dataset,
            "save_folder": args.save_folder,
            "eval_dataset": args_eval.dataset,
            "num_steps": args_eval.num_steps,
            "use_action_attention": args.use_action_attention
        })

    device = torch.device('cuda' if args.cuda else 'cpu')

    # Fixed evaluation geometry. The same batch size is used both by the
    # loader and by the full-batch filter below, so they cannot disagree.
    path_length = 10
    eval_batch_size = 100
    dataset = utils.PathDatasetStateIds(
        hdf5_file=args.dataset, path_length=path_length)
    eval_loader = data.DataLoader(
        dataset, batch_size=eval_batch_size, shuffle=False, num_workers=4)

    # Peek at one batch only to infer the model's input shape
    # (`next(iter(...))` instead of the legacy Python-2 `.next()`).
    obs = next(iter(eval_loader))[0]
    input_shape = obs[0][0].size()

    model = modules.ContrastiveSWM(
        embedding_dim=args.embedding_dim,
        hidden_dim=args.hidden_dim,
        action_dim=args.action_dim,
        input_dims=input_shape,
        num_objects=args.num_objects,
        sigma=args.sigma,
        hinge=args.hinge,
        ignore_action=args.ignore_action,
        copy_action=args.copy_action,
        split_mlp=args.split_mlp,
        same_ep_neg=args.same_ep_neg,
        only_same_ep_neg=args.only_same_ep_neg,
        immovable_bit=args.immovable_bit,
        split_gnn=args.split_gnn,
        no_loss_first_two=args.no_loss_first_two,
        bisim_model=make_pairwise_encoder() if args.bisim_model_path else None,
        encoder=args.encoder,
        use_action_attention=args.use_action_attention
    ).to(device)

    model.load_state_dict(torch.load(model_file))
    model.eval()

    hits_list = []

    with torch.no_grad():

        for data_batch in eval_loader:
            data_batch = [[t.to(device) for t in tensor] for tensor in data_batch]
            observations, actions, state_ids = data_batch

            # Skip the final partial batch.
            if observations[0].size(0) != eval_batch_size:
                continue

            states = torch.stack(
                [model.obj_encoder(model.obj_extractor(step_obs))
                 for step_obs in observations], dim=0)
            state_ids = torch.stack(state_ids, dim=0)

            pred_state = states[0]
            if not args_eval.no_transition:
                for i in range(args_eval.num_steps):
                    pred_state = model.forward_transition(pred_state, actions[i])

            # pred_state: [100, |O|, D]
            # states: [10, 100, |O|, D]
            # pred_state_flat: [100, X]
            # states_flat: [10, 100, X]
            pred_state_flat = pred_state.reshape(
                (pred_state.size(0), pred_state.size(1) * pred_state.size(2)))
            states_flat = states.reshape(
                (states.size(0), states.size(1), states.size(2) * states.size(3)))

            # dist_matrix: [10, 100] — squared distance from each sample's
            # prediction to every encoded state on its own path.
            dist_matrix = (states_flat - pred_state_flat[None]).pow(2).sum(2)
            indices = torch.argmin(dist_matrix, dim=0)
            correct = indices == args_eval.num_steps

            # check for duplicates
            if args_eval.dedup:
                # Index each sample's own column instead of a hard-coded 100.
                batch_positions = torch.arange(indices.size(0),
                                               device=indices.device)
                equal_mask = torch.all(
                    state_ids[indices, batch_positions]
                    == state_ids[args_eval.num_steps], dim=1)
                correct = correct | equal_mask

            # hits
            hits_list.append(correct.float().mean().item())

    if not hits_list:
        raise RuntimeError(
            'No evaluation batch of exactly {} samples was produced; '
            'nothing to score.'.format(eval_batch_size))

    hits = np.mean(hits_list)

    if ex is not None:
        # ugly hack: log_scalar needs an active sacred run, so the metrics are
        # logged from a throw-away main function that is run once.
        @ex.main
        def sacred_main():
            ex.log_scalar("hits", hits)
            ex.log_scalar("mrr", 0.)

        ex.run()

    print('Hits @ 1: {}'.format(hits))

    return hits, 0.
import frcnn_utils
import init_frcnn_utils
from experiments.exp_utils import get_config_var, LoggerForSacred, Args
from init_frcnn_utils import init_dataloaders_1s_1t, init_val_dataloaders_mt, init_val_dataloaders_1t, \
    init_htcn_model_optimizer

from sacred import Experiment
ex = Experiment()
from sacred.observers import MongoObserver

# Manual switch: MongoDB logging is disabled unless this is set to True.
enable_mongo_observer = False
if enable_mongo_observer:
    # Credentials, host and database name come from get_config_var().
    # NOTE(review): `vars` shadows the builtin of the same name at module level.
    vars = get_config_var()
    ex.observers.append(
        MongoObserver(
            url='mongodb://{}:{}@{}/admin?authMechanism=SCRAM-SHA-1'.format(
                vars["SACRED_USER"], vars["SACRED_PWD"], vars["SACRED_URL"]),
            db_name=vars["SACRED_DB"]))

# Every captured output is replaced by this fixed placeholder text
# (presumably to keep console logs out of the observers' records).
ex.captured_out_filter = lambda text: 'Output capturing turned off.'
from dataclasses import dataclass

import numpy as np

import torch
import torch.nn as nn

from model.utils.config import cfg, cfg_from_file, cfg_from_list
from model.utils.net_utils import adjust_learning_rate, save_checkpoint, FocalLoss, EFocalLoss

from model.utils.parser_func import set_dataset_args
import os, site, socket, sys
from sacred import Experiment
from sacred.observers import MongoObserver
from sacred.utils import apply_backspaces_and_linefeeds
from derive_conceptualspace.settings import get_setting
from fb_classifier.preprocess_data import preprocess_data, create_traintest
from fb_classifier.dataset import load_data
from fb_classifier.train import TrainPipeline
# from src.fb_classifier.util.misc import get_all_debug_confs, clear_checkpoints_summary
from fb_classifier.util.misc import get_all_configs
from fb_classifier.settings import CLASSIFIER_CHECKPOINT_PATH, SUMMARY_PATH, MONGO_URI, DATA_BASE, DEBUG
import fb_classifier

ex = Experiment("Fachbereich_Classifier")
# Requires the MONGO_DATABASE environment variable (KeyError if unset).
ex.observers.append(MongoObserver(url=MONGO_URI, db_name=os.environ["MONGO_DATABASE"]))
# Apply backspaces/carriage returns in captured output before it is stored.
ex.captured_out_filter = apply_backspaces_and_linefeeds
ex.add_config(get_all_configs(as_string=False))
ex.add_config(DEBUG=DEBUG)
# Register every .py file inside the fb_classifier package as a source file.
# NOTE(review): join, dirname and splitext are not among the visible imports
# (presumably from os.path) — confirm they are imported elsewhere.
for pyfile in [join(path, name) for path, subdirs, files in os.walk(dirname(fb_classifier.__file__)) for name in files if splitext(name)[1] == ".py"]:
    ex.add_source_file(pyfile)


@ex.main
def run_experiment(_run):
    args = parse_command_line_args()
    setup_logging(args.loglevel, args.logfile)
    # if args.restart: #delete checkpoint- and summary-dir
    #     clear_checkpoints_summary()
    if args.no_continue or not os.listdir(CLASSIFIER_CHECKPOINT_PATH):
from sacred import Experiment
from sacred.observers import MongoObserver

ex = Experiment('hello_config')
# When executing inside Jupyter: @ex.automain does not work (x); use @ex.main (o)
# together with an interactive experiment instead:
# ex = Experiment('hello_config', interactive=True)

# Record runs in the 'MY_DB' database of the MongoDB server at localhost:27017.
ex.observers.append(MongoObserver(url='localhost:27017', db_name='MY_DB'))


@ex.config
def my_config():
    recipient = "world"
    message = "Hello %s!" % recipient


@ex.automain
def my_main(message):
    # `message` is injected by Sacred from the config entry of the same name.
    print(message)
# -*- coding: utf-8 -*- from network import Network from PIL import Image from scale import size, load_batches import numpy as np import os import sys from sacred import Experiment from sacred.observers import MongoObserver from sacred.utils import apply_backspaces_and_linefeeds from gpu_helpers import init_all_gpu init_all_gpu() ex = Experiment('Superresolution', ingredients=[]) ex.observers.append(MongoObserver()) ex.captured_out_filter = apply_backspaces_and_linefeeds @ex.config def my_config(): image_size = (320, 240) batch_size = 5 no_epochs = 500 lr = 0.0001 lr_stair_width = 10 lr_decay = 0.95 @ex.capture def log_training_performance(_run, loss, lr): _run.log_scalar("loss", float(loss))
# NOTE(review): the next four statements use names (backbone_name, version,
# host_name, best_score, conf, valid_fold, val_len, train_len, oof_val,
# oof_test) that are not defined here. They appear to be the tail of a
# function that begins outside this view.
conf_name_base = backbone_name
oof_file = f'./output/stacking/{version}_{host_name[:5]}_s{best_score:6.5f}_{conf_name_base}_{conf.model_type}_f{valid_fold}_val{val_len}_trn{train_len}.h5'
print(f'Stacking file save to:{oof_file}')
save_stack_feature(oof_val, oof_test, oof_file)

###### sacred begin
from sacred import Experiment
from easydict import EasyDict as edict
from sacred.observers import MongoObserver
from sacred import SETTINGS

ex = Experiment('lung')
# Remote MongoDB (credentials masked) authenticating against the admin database.
db_url = 'mongodb://*****:*****@10.10.20.103:27017/db?authSource=admin'
ex.observers.append(MongoObserver(url=db_url, db_name='db'))
#SETTINGS.CAPTURE_MODE = 'sys'


@ex.config
def my_config():
    conf_name = None
    fold = -1
    image_size = 200
    model_type = 'raw'


@ex.command()
def main(_config):
    # Wrap the full run config for attribute-style access, then train.
    config = edict(_config)
    print('=====', config)
    train(config)
import tensorflow as tf
import tqdm
import ubelt as ub
from sacred import Experiment
from sacred.observers import MongoObserver

import utils
from EvalNet import EvalNet
from GGNNPolyModel import GGNNPolygonModel
from PolygonModel import PolygonModel
from utils import get_all_files, save_to_json
from vis_predictions import vis_single

ex = Experiment()
# Record runs in the 'sacred' database of the MongoDB server at localhost:27017.
ex.observers.append(MongoObserver(url="localhost:27017", db_name="sacred"))

# tf.logging.set_verbosity(tf.logging.INFO)
# --
# Command-line flags via the TensorFlow 1.x `tf.flags` API.
flags = tf.flags
FLAGS = flags.FLAGS
# ---
flags.DEFINE_string(
    "PolyRNN_metagraph",
    "models/poly/polygonplusplus.ckpt.meta",
    "PolygonRNN++ MetaGraph ",
)
flags.DEFINE_string("PolyRNN_checkpoint", "models/poly/polygonplusplus.ckpt",
                    "PolygonRNN++ checkpoint ")
flags.DEFINE_string("EvalNet_checkpoint", "models/evalnet/evalnet.ckpt",
                    "Evaluator checkpoint ")
import joblib

from data import CATEGORY_IDS
from data import GraphDataset, TextGraphDataset, GloVeTokenizer
import models
import utils

OUT_PATH = 'output/'
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

ex = Experiment()
ex.logger = utils.get_logger()
# Set up database logs
# A MongoDB observer is attached only when both DB_URI and DB_NAME are set and non-empty.
uri = os.environ.get('DB_URI')
database = os.environ.get('DB_NAME')
if all([uri, database]):
    # NOTE(review): positional arguments are assumed to map to (url, db_name);
    # confirm against the installed sacred MongoObserver signature.
    ex.observers.append(MongoObserver(uri, database))


@ex.config
def config():
    dataset = 'entities'
    inductive = True
    dim = 128
    model = 'blp'
    rel_model = 'complex'
    loss_fn = 'margin'
    encoder_name = 'bert-base-cased'
    regularizer = 1e-2
    max_len = 32
    num_negatives = 64
    lr = 2e-5
from sacred import Experiment
from sacred.observers import MongoObserver

import amci.utils

ex = Experiment()

# Only when run as a script: optionally attach a MongoDB observer configured by
# config_sacred_mongodb.yaml, located next to this file.
if __name__ == '__main__':
    with open(
            os.path.join(os.path.dirname(os.path.abspath(__file__)),
                         "config_sacred_mongodb.yaml"), 'r') as stream:
        config_sacred_mongodb = yaml.safe_load(stream)
    # An empty/falsy 'sacred_mongo_url' disables the observer.
    if config_sacred_mongodb['sacred_mongo_url']:
        ex.observers.append(
            MongoObserver(
                url=config_sacred_mongodb['sacred_mongo_url'],
                db_name=config_sacred_mongodb['sacred_mongo_db_name']))
        print(
            f"Adding Sacred MongoDB observer at {config_sacred_mongodb['sacred_mongo_url']}."
        )


def train(args, _run, _writer):
    model = amci.utils.get_model(args)
    q = model.Training.get_proposal_model(args)
    optimizer = torch.optim.Adam(q.parameters(), lr=args.learning_rate)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, patience=args.scheduler_patience,
# Optional MongoDB observer, configured entirely from MONGO_SACRED_* environment
# variables. It is attached only when MONGO_SACRED_ENABLED is exactly 'true'.
mongo_enabled = os.environ.get('MONGO_SACRED_ENABLED')
mongo_user = os.environ.get('MONGO_SACRED_USER')
mongo_pass = os.environ.get('MONGO_SACRED_PASS')
mongo_host = os.environ.get('MONGO_SACRED_HOST')
mongo_port = os.environ.get('MONGO_SACRED_PORT', '27017')
if mongo_enabled == 'true':
    from urllib.parse import quote_plus

    # Explicit checks (not `assert`, which `python -O` strips). The messages
    # name the variables that are actually read above.
    if not mongo_user:
        raise RuntimeError('Setting $MONGO_SACRED_USER is required')
    if not mongo_pass:
        raise RuntimeError('Setting $MONGO_SACRED_PASS is required')
    if not mongo_host:
        raise RuntimeError('Setting $MONGO_SACRED_HOST is required')
    # User name and password must be percent-escaped inside a MongoDB URI.
    mongo_url = 'mongodb://{0}:{1}@{2}:{3}/' \
                'sacred?authMechanism=SCRAM-SHA-1'.format(
                    quote_plus(mongo_user), quote_plus(mongo_pass),
                    mongo_host, mongo_port)
    ex.observers.append(MongoObserver(url=mongo_url, db_name='sacred'))


@ex.config
def cfg():
    n_epochs = 20
    lr = 2e-5
    batch_size = 16
    base_model = "distilbert-base-uncased"
    clustering_loss_weight = 1.0
    embedding_extractor = concat_cls_n_hidden_states
    annealing_alphas = np.arange(1, n_epochs + 1)
    # NOTE(review): "ags_news" here vs "ag_news" below — confirm the directory name.
    dataset = "../datasets/ags_news/ag_news.csv"
    train_idx_file = "../datasets/ag_news/splits/train"
    result_dir = f"../results/ag_news-distilbert/{strftime('%Y-%m-%d_%H:%M:%S', gmtime())}"
    early_stopping = True
from utils import calc_loss,load_ckp

# Register the custom highway environment under the id 'SimpleHighway-v1'
# (`register` is presumably gym's environment registry — confirm).
register(
    id='SimpleHighway-v1',
    entry_point='env.simple_highway.simple_highway_env:SimpleHighway',
)

experiment_name = "driving_behavior"
sacred_ex = Experiment(experiment_name)
now = datetime.datetime.now()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE(review): this handler probably never fires. The Mongo client presumably
# connects lazily and reports failures with pymongo.errors exceptions, which
# are not subclasses of the builtin ConnectionError — confirm.
try:
    sacred_ex.observers.append(MongoObserver(url='localhost:27017', db_name='sacred'))
except ConnectionError :
    print("MongoDB instance should be running")


@sacred_ex.config
def dqn_cfg():
    seed = 123523
    num_lane = 3
    level_k = 2
    agent_level_k = level_k -1
    TRAIN = True
    LOAD_SAVED_MODEL = False
    MODEL_PATH_FINAL = "best_"+str(agent_level_k)