def _mean_or_nan(values):
    """Return the mean of ``values`` as a Python float, or nan when ``values`` is empty."""
    # np.mean([]) also returns nan but emits a RuntimeWarning; be explicit instead.
    return float(np.mean(values)) if values else float("nan")


def _evaluate_fetch(agent, eval_env, episodes):
    """Run ``episodes`` evaluation episodes of ``agent.eval_action`` on ``eval_env``.

    An episode ends when ``eval_env.needs_reset`` becomes true, or early as soon
    as a step returns a reward greater than -1, which is counted as a success.

    Returns:
        (success_count, action_norms): the number of successful episodes and
        the L2 norm of every action taken across all episodes.
    """
    action_norms = []
    successes = 0
    for _ in range(episodes):
        state = eval_env.reset()
        while not eval_env.needs_reset:
            action = agent.eval_action(state)
            action_norms.append(np.linalg.norm(action))
            state, reward, _, _ = eval_env.step(action)
            # NOTE(review): presumably the sparse Fetch reward is -1 until the
            # goal is reached, so anything above -1 means success — confirm.
            if reward > -1.:
                successes += 1
                break
    return successes, action_norms


def train_fetch(experiment: sacred.Experiment, agent: Any, eval_env: FetchEnv, progressive_noise: bool, small_goal: bool,
                *, num_iterations: int = 2000000, eval_interval: int = 10000, eval_episodes: int = 50,
                checkpoint_interval: int = 20000, checkpoint_dir: str = "/tmp"):
    """Train ``agent`` for ``num_iterations`` updates with periodic evaluation and checkpoints.

    Every ``eval_interval`` iterations the agent is evaluated for ``eval_episodes``
    episodes on ``eval_env``. Two fields are recorded through ``reporting``:
    ``eval_success_rate`` and ``action_norm``. ``eval_success_rate`` is the NUMBER
    of successful episodes out of ``eval_episodes``, not a fraction. ``action_norm``
    is the mean action L2 norm.

    Every ``checkpoint_interval`` iterations the frozen CPU policy is written with
    ``torch.save`` to ``<checkpoint_dir>/policy_<iteration>`` and added as a Sacred
    artifact of ``experiment``.

    Args:
        experiment: Sacred experiment that receives the policy checkpoints.
        agent: Agent exposing ``eval_action``, ``update`` and ``freeze_policy``.
        eval_env: Environment used only for evaluation.
        progressive_noise: Accepted for signature compatibility. It does not
            change the training length.
        small_goal: Accepted for signature compatibility. It does not change
            the training length.
        num_iterations: Total number of ``agent.update()`` calls.
        eval_interval: Evaluate whenever ``iteration % eval_interval == 0``.
        eval_episodes: Number of evaluation episodes per evaluation.
        checkpoint_interval: Checkpoint whenever
            ``iteration % checkpoint_interval == 0``.
        checkpoint_dir: Directory the checkpoint files are written to.
    """
    import os

    reporting.register_field("eval_success_rate")
    reporting.register_field("action_norm")
    reporting.finalize_fields()

    progress = tqdm.trange(num_iterations)
    for iteration in progress:
        if iteration % eval_interval == 0:
            successes, action_norms = _evaluate_fetch(agent, eval_env, eval_episodes)
            reporting.iter_record("eval_success_rate", successes)
            reporting.iter_record("action_norm", _mean_or_nan(action_norms))

        if iteration % checkpoint_interval == 0:
            policy_path = os.path.join(checkpoint_dir, f"policy_{iteration}")
            with open(policy_path, 'wb') as f:
                torch.save(agent.freeze_policy(torch.device('cpu')), f)
            experiment.add_artifact(policy_path)

        agent.update()
        reporting.iterate()
        progress.set_description(f"{iteration} -- " + reporting.get_description(["return", "td_loss", "env_steps"]))
示例#2
0
def save_lang_to_idx(lang_to_idx: dict, ex: Experiment):
    """Saves the lang_to_idx dict as an artifact named ``lang_to_idx.pkl``

    The dict is pickled into a temporary ``.pkl`` file. The file is created with
    ``dir=""``, i.e. relative to the current working directory. It is registered
    with ``ex.add_artifact`` and then deleted, also when pickling or the upload
    fails.

    Arguments:
        lang_to_idx {dict} -- The dict to save in a file
        ex {Experiment} -- The Sacred experiment that receives the artifact
    """
    tmpf = tempfile.NamedTemporaryFile(dir="", delete=False, suffix=".pkl")
    try:
        # Closing (not just flushing) before handing over the path guarantees the
        # data is on disk and the file can be reopened by name on every platform.
        with tmpf:
            pickle.dump(lang_to_idx, tmpf)
        ex.add_artifact(tmpf.name, "lang_to_idx.pkl")
    finally:
        os.unlink(tmpf.name)
示例#3
0
def save_probs(pred_prob, ex: Experiment, file_ending=""):
    """Saves probabilities as a .npy file and adds it as artifact

    The array is written to a temporary ``.npy`` file and attached to ``ex`` as
    ``"prediction_probabilities" + file_ending + ".npy"``. The temporary file is
    then deleted, also when saving or the upload fails.

    Arguments:
        pred_prob  -- list or numpy array to save as .npy file
        ex         -- Sacred experiment that receives the artifact
        file_ending -- suffix inserted before ``.npy`` in the artifact name
    """
    tmpf = tempfile.NamedTemporaryFile(delete=False, suffix=".npy")
    # Release our handle first: np.save reopens the file by name, which fails on
    # Windows while another handle still holds it open.
    tmpf.close()
    try:
        np.save(tmpf.name, pred_prob)
        fname = "prediction_probabilities" + file_ending + ".npy"
        ex.add_artifact(tmpf.name, fname)
    finally:
        os.unlink(tmpf.name)
示例#4
0
class SacredExperiment(object):
    """Wrapper that creates a Sacred ``Experiment`` and attaches one observer to it.

    ``observer_type="file_storage"`` records runs under ``experiment_dir``.
    ``observer_type="mongo"`` records them in the MongoDB given by ``mongo_url`` /
    ``db_name``. Any other value raises ``NotImplementedError`` during construction.
    """

    def __init__(
        self,
        experiment_name,
        experiment_dir,
        observer_type="file_storage",
        mongo_url=None,
        db_name=None,
    ):
        """__init__

        :param experiment_name: The name of the experiment.
        :param experiment_dir:  The directory that stores the experiment results (used by `file_storage`).
        :param observer_type:   The observer that records the results: `file_storage` or `mongo`.
        :param mongo_url:       The MongoDB url (used by `mongo`).
        :param db_name:         The MongoDB database name (used by `mongo`).
        :raises NotImplementedError: if `observer_type` is not supported.
        """
        self.experiment_name = experiment_name
        self.experiment = Experiment(self.experiment_name)
        self.experiment_dir = experiment_dir
        self.experiment.logger = get_module_logger("Sacred")

        self.observer_type = observer_type
        self.mongo_db_url = mongo_url
        self.mongo_db_name = db_name

        self._setup_experiment()

    def _setup_experiment(self):
        """Create the observer selected by ``self.observer_type`` and append it to the experiment."""
        # NOTE(review): ``Observer.create`` is deprecated in newer Sacred releases
        # in favour of calling the constructor; kept for compatibility with older ones.
        if self.observer_type == "file_storage":
            file_storage_observer = FileStorageObserver.create(
                basedir=self.experiment_dir)
            self.experiment.observers.append(file_storage_observer)
        elif self.observer_type == "mongo":
            mongo_observer = MongoObserver.create(url=self.mongo_db_url,
                                                  db_name=self.mongo_db_name)
            self.experiment.observers.append(mongo_observer)
        else:
            raise NotImplementedError("Unsupported observer type: {}".format(
                self.observer_type))

    def add_artifact(self, filename, name=None):
        """Attach the file at ``filename`` to the current run.

        :param filename: Path of the file to attach.
        :param name:     Optional artifact name; ``None`` lets Sacred pick its default.
        """
        self.experiment.add_artifact(filename, name=name)

    def add_info(self, key, value):
        """Store ``value`` under ``key`` in the experiment's ``info`` dict."""
        self.experiment.info[key] = value

    def main_wrapper(self, func):
        """Register ``func`` as the experiment's main function (``Experiment.main``)."""
        return self.experiment.main(func)

    def config_wrapper(self, func):
        """Register ``func`` as a config function of the experiment (``Experiment.config``)."""
        return self.experiment.config(func)
示例#5
0
    return data_path, full_path


@ex.capture
def save_and_add_artifact(path, arr):
    """Save ``arr`` with ``np.save`` and register the written file as an artifact of ``ex``.

    ``np.save`` appends ``.npy`` to a string/path-like target that does not
    already end with it. The path is normalized the same way here so that
    ``ex.add_artifact`` receives the file that was actually written.

    Returns:
        The path that was written (with the ``.npy`` suffix when applicable).
    """
    import os

    if isinstance(path, (str, os.PathLike)):
        path = os.fspath(path)
        if not path.endswith(".npy"):
            path += ".npy"
    np.save(path, arr)
    ex.add_artifact(path)
    return path


@ex.capture
def save_table(case, WV_path):
    """Write ``case.table_latex`` to ``WV_path + '_latex_table.txt'`` and add it as an artifact.

    ``case.table()`` is called first.
    NOTE(review): presumably ``case.table()`` is what fills ``case.table_latex`` — confirm.
    """
    case.table()
    table_path = WV_path + '_latex_table.txt'
    # ``with`` closes (and flushes) the file even if write() raises, and always
    # before the artifact is handed to Sacred.
    with open(table_path, 'w') as latex_table:
        latex_table.write(case.table_latex)
    ex.add_artifact(table_path)


@ex.capture
def log_variables(results, Q, RT):
    """Return ``Q`` and ``RT`` extended with the values carried by ``results``.

    ``results.Q`` is appended to ``Q`` along axis 0, and ``results.rec_time`` is
    appended to the flattened ``RT``. ``np.append`` returns new arrays, so the
    inputs are left untouched; callers must use the returned pair.
    """
    grown_q = np.append(Q, results.Q, axis=0)
    grown_times = np.append(RT, results.rec_time)
    return grown_q, grown_times


#%%
@ex.automain
def main(nTD, nVD, train, bpath, stop_crit, specifics):
    # Specific phantom
    # %%
    t1 = time.time()
示例#6
0
    total_time += t1-t0
    _run.info['time'][''] = t1 - t0
    _run.info['matrix_shape'] = (mat.width, mat.height)

    t0 = time.time()
    mat.get_USV(k)
    mat.compute_amat()
    mat.compute_candidates(show=False, percent=percent, width=width)
    t1 = time.time()
    total_time += t1-t0
    _run.info['time']['SVD'] = t1 - t0
    _run.info['candidates_num'] = len(mat.candidates)

    plot.dots(mat, mat.candidates)
    plt.savefig('')
    ex.add_artifact()

    return
    t0 = time.time()
    mat.pagerank(alpha=alpha, trustrank=True, inverse=True, show=False)
    inv_rank = mat.rank
    inv_rank_norm = mat.rank_norm
    mat.pagerank(alpha=alpha, trustrank=True, inverse=False, show=False)
    rank = mat.rank
    rank_norm = mat.rank_norm
    rank_max = {}
    for key in rank_norm.keys():
        rank_max[key] = max(rank_norm[key], inv_rank_norm[key])
    t1 = time.time()
    total_time += t1-t0
    _run.info['time']['pagerank'] = t1-t0