コード例 #1
0
ファイル: experiment.py プロジェクト: yolenan/garage
    def __call__(self, *args, **kwargs):
        """Wrap a function to turn it into an ExperimentTemplate.

        Note that this docstring will be overridden to match the function's
        docstring on the ExperimentTemplate once a function is passed in.

        Args:
            args (list): If no function has been set yet, must be a list
                containing a single callable. If the function has been set, may
                be a single value, a dictionary containing overrides for the
                original arguments to `wrap_experiment`.
            kwargs (dict): Arguments passed onto the wrapped function.

        Returns:
            object: This object itself when it is being applied as a
                decorator (no function set yet); otherwise the returned value
                of the wrapped function.

        Raises:
            ValueError: If not passed a single callable argument.

        """
        if self.function is None:
            if len(args) != 1 or len(kwargs) != 0 or not callable(args[0]):
                raise ValueError('Please apply the result of '
                                 'wrap_experiment() to a single function')
            # Apply ourselves as a decorator
            self.function = args[0]
            self._update_wrap_params()
            return self

        ctxt = self._make_context(self._get_options(*args), **kwargs)
        try:
            return self.function(ctxt, **kwargs)
        finally:
            # Run the teardown even when the experiment raises, so a failed
            # run does not leave its outputs and prefix attached to the
            # logger (presumably installed by `_make_context` -- confirm).
            logger.remove_all()
            logger.pop_prefix()
            gc.collect()  # See dowel issue #44
コード例 #2
0
ファイル: experiment_wrapper.py プロジェクト: qin1921/garage
def run_experiment(argv):
    """Run an experiment described by command line arguments.

    Parses `argv`, applies the seed, optionally initializes the parallel
    sampler, attaches the dowel logger outputs under the log directory and
    then invokes the method call that was serialized into `--args_data`.

    Note:
        `--tensorboard_step_key` and `--log_tabular_only` are accepted by the
        parser but are not read directly by this function; they only travel
        inside `args` to `log_parameters`.

    Args:
        argv (list[str]): Command line arguments. The first element is
            skipped and everything after it is handed to the parser.

    Raises:
        BaseException: Propagate any exception in the experiment, after the
            plotter (and, if started, sampler) children have been shut down.

    """
    now = datetime.datetime.now(dateutil.tz.tzlocal())

    # avoid name clashes when running distributed jobs
    rand_id = str(uuid.uuid4())[:5]
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')

    default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id)
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--n_parallel',
        type=int,
        default=1,
        help=('Number of parallel workers to perform rollouts. '
              "0 => don't start any workers"))
    parser.add_argument('--exp_name',
                        type=str,
                        default=default_exp_name,
                        help='Name of the experiment.')
    parser.add_argument('--log_dir',
                        type=str,
                        default=None,
                        help='Path to save the log and iteration snapshot.')
    # The trailing space after "(every" keeps the adjacent string literals
    # from running together in the rendered help text.
    parser.add_argument('--snapshot_mode',
                        type=str,
                        default='last',
                        help='Mode to save the snapshot. Can be either "all" '
                        '(all iterations will be saved), "last" (only '
                        'the last iteration will be saved), "gap" (every '
                        '`snapshot_gap` iterations are saved), or "none" '
                        '(do not save snapshots)')
    parser.add_argument('--snapshot_gap',
                        type=int,
                        default=1,
                        help='Gap between snapshot iterations.')
    parser.add_argument(
        '--resume_from_dir',
        type=str,
        default=None,
        help='Directory of the pickle file to resume experiment from.')
    parser.add_argument('--resume_from_epoch',
                        type=str,
                        default=None,
                        help='Index of iteration to restore from. '
                        'Can be "first", "last" or a number. '
                        'Not applicable when snapshot_mode="last"')
    parser.add_argument('--tabular_log_file',
                        type=str,
                        default='progress.csv',
                        help='Name of the tabular log file (in csv).')
    parser.add_argument('--text_log_file',
                        type=str,
                        default='debug.log',
                        help='Name of the text log file (in pure text).')
    parser.add_argument('--tensorboard_step_key',
                        type=str,
                        default=None,
                        help='Name of the step key in tensorboard_summary.')
    parser.add_argument('--params_log_file',
                        type=str,
                        default='params.json',
                        help='Name of the parameter log file (in json).')
    parser.add_argument('--variant_log_file',
                        type=str,
                        default='variant.json',
                        help='Name of the variant log file (in json).')
    parser.add_argument('--plot',
                        type=ast.literal_eval,
                        default=False,
                        help='Whether to plot the iteration results')
    parser.add_argument(
        '--log_tabular_only',
        type=ast.literal_eval,
        default=False,
        help='Print only the tabular log information (in a horizontal format)')
    parser.add_argument('--seed',
                        type=int,
                        default=None,
                        help='Random seed for numpy')
    parser.add_argument('--args_data',
                        type=str,
                        help='Pickled data for objects')
    parser.add_argument('--variant_data',
                        type=str,
                        help='Pickled data for variant configuration')

    args = parser.parse_args(argv[1:])

    if args.seed is not None:
        deterministic.set_seed(args.seed)

    if args.n_parallel > 0:
        parallel_sampler.initialize(n_parallel=args.n_parallel)
        if args.seed is not None:
            parallel_sampler.set_seed(args.seed)

    if not args.plot:
        garage.plotter.Plotter.disable()
        garage.tf.plotter.Plotter.disable()

    # Without an explicit --log_dir, log to <cwd>/data/<exp_name>.
    if args.log_dir is None:
        log_dir = os.path.join(os.getcwd(), 'data', args.exp_name)
    else:
        log_dir = args.log_dir

    tabular_log_file = os.path.join(log_dir, args.tabular_log_file)
    text_log_file = os.path.join(log_dir, args.text_log_file)
    params_log_file = os.path.join(log_dir, args.params_log_file)

    if args.variant_data is not None:
        # NOTE(security): unpickling executes arbitrary code. This is only
        # safe if --variant_data comes from a trusted launcher.
        variant_data = pickle.loads(base64.b64decode(args.variant_data))
        variant_log_file = os.path.join(log_dir, args.variant_log_file)
        dump_variant(variant_log_file, variant_data)
    else:
        variant_data = None

    log_parameters(params_log_file, args)

    logger.add_output(dowel.TextOutput(text_log_file))
    logger.add_output(dowel.CsvOutput(tabular_log_file))
    logger.add_output(dowel.TensorBoardOutput(log_dir, x_axis='TotalEnvSteps'))
    logger.add_output(dowel.StdOutput())

    logger.push_prefix('[%s] ' % args.exp_name)

    snapshot_config = SnapshotConfig(snapshot_dir=log_dir,
                                     snapshot_mode=args.snapshot_mode,
                                     snapshot_gap=args.snapshot_gap)

    # NOTE(security): same trust requirement as --variant_data above.
    method_call = cloudpickle.loads(base64.b64decode(args.args_data))
    try:
        method_call(snapshot_config, variant_data, args.resume_from_dir,
                    args.resume_from_epoch)
    except BaseException:
        # Shut down the plotter (and sampler) children before propagating,
        # whatever interrupted the experiment.
        children = garage.plotter.Plotter.get_plotters()
        children += garage.tf.plotter.Plotter.get_plotters()
        if args.n_parallel > 0:
            children += [parallel_sampler]
        child_proc_shutdown(children)
        raise

    logger.remove_all()
    logger.pop_prefix()
コード例 #3
0
ファイル: log_progress.py プロジェクト: ziyiwu9494/dowel
"""This example demonstrates how to log a simple progress metric using dowel.

The metric is simultaneously sent to the screen, a CSV file, a text log file,
and TensorBoard.
"""
import time

import dowel
from dowel import logger, tabular

# Destinations for the metric, as (output class, constructor args) pairs.
# They are built and attached one at a time, in this order: screen, CSV file,
# text log file, TensorBoard.
_SINKS = (
    (dowel.StdOutput, ()),
    (dowel.CsvOutput, ('progress.csv',)),
    (dowel.TextOutput, ('progress.txt',)),
    (dowel.TensorBoardOutput, ('tensorboard_logdir',)),
)
for make_sink, sink_args in _SINKS:
    logger.add_output(make_sink(*sink_args))

logger.log('Starting up...')
for step in range(1000):
    # Everything logged during this iteration carries the iteration prefix.
    logger.push_prefix('itr %d: ' % step)
    logger.log('Running training step')

    # Throttle the loop: Tensorboard doesn't like output to be too fast.
    time.sleep(0.01)

    # A made-up loss that shrinks as the iteration count grows.
    loss = 100.0 / (2 + step)
    tabular.record('itr', step)
    tabular.record('loss', loss)
    logger.log(tabular)

    logger.pop_prefix()
    logger.dump_all()

logger.remove_all()
コード例 #4
0
def run_experiment(argv):
    """Run an experiment described by command line arguments.

    Parses `argv`, applies the seed, optionally initializes the parallel
    sampler with SIGINT masked, attaches the dowel logger outputs under the
    log directory and then runs the payload serialized into `--args_data`.
    With `--use_cloudpickle` the payload is loaded with cloudpickle and
    called directly; otherwise it is loaded with pickle, passed through
    `concretize`, and the result is iterated to exhaustion if it is iterable.

    Note:
        `--tensorboard_step_key` and `--log_tabular_only` are accepted by the
        parser but are not read directly by this function.

    Args:
        argv (list[str]): Command line arguments. The first element is
            skipped and everything after it is handed to the parser.

    Raises:
        BaseException: Propagate any exception in the experiment. On the
            cloudpickle path the plotter (and, if started, sampler) children
            are shut down before the exception is re-raised.

    """
    now = datetime.datetime.now(dateutil.tz.tzlocal())

    # avoid name clashes when running distributed jobs
    rand_id = str(uuid.uuid4())[:5]
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')

    default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id)
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--n_parallel',
        type=int,
        default=1,
        help=('Number of parallel workers to perform rollouts. '
              "0 => don't start any workers"))
    parser.add_argument('--exp_name',
                        type=str,
                        default=default_exp_name,
                        help='Name of the experiment.')
    parser.add_argument('--log_dir',
                        type=str,
                        default=None,
                        help='Path to save the log and iteration snapshot.')
    # The trailing space after "(every" keeps the adjacent string literals
    # from running together in the rendered help text.
    parser.add_argument('--snapshot_mode',
                        type=str,
                        default='last',
                        help='Mode to save the snapshot. Can be either "all" '
                        '(all iterations will be saved), "last" (only '
                        'the last iteration will be saved), "gap" (every '
                        '`snapshot_gap` iterations are saved), or "none" '
                        '(do not save snapshots)')
    parser.add_argument('--snapshot_gap',
                        type=int,
                        default=1,
                        help='Gap between snapshot iterations.')
    parser.add_argument(
        '--resume_from_dir',
        type=str,
        default=None,
        help='Directory of the pickle file to resume experiment from.')
    parser.add_argument('--resume_from_epoch',
                        type=str,
                        default=None,
                        help='Index of iteration to restore from. '
                        'Can be "first", "last" or a number. '
                        'Not applicable when snapshot_mode="last"')
    parser.add_argument('--tabular_log_file',
                        type=str,
                        default='progress.csv',
                        help='Name of the tabular log file (in csv).')
    parser.add_argument('--text_log_file',
                        type=str,
                        default='debug.log',
                        help='Name of the text log file (in pure text).')
    parser.add_argument('--tensorboard_step_key',
                        type=str,
                        default=None,
                        help='Name of the step key in tensorboard_summary.')
    parser.add_argument('--params_log_file',
                        type=str,
                        default='params.json',
                        help='Name of the parameter log file (in json).')
    parser.add_argument('--variant_log_file',
                        type=str,
                        default='variant.json',
                        help='Name of the variant log file (in json).')
    parser.add_argument('--plot',
                        type=ast.literal_eval,
                        default=False,
                        help='Whether to plot the iteration results')
    parser.add_argument(
        '--log_tabular_only',
        type=ast.literal_eval,
        default=False,
        help='Print only the tabular log information (in a horizontal format)')
    parser.add_argument('--seed',
                        type=int,
                        default=None,
                        help='Random seed for numpy')
    parser.add_argument('--args_data',
                        type=str,
                        help='Pickled data for objects')
    parser.add_argument('--variant_data',
                        type=str,
                        help='Pickled data for variant configuration')
    parser.add_argument('--use_cloudpickle',
                        type=ast.literal_eval,
                        default=False,
                        help='Whether to load args_data with cloudpickle and '
                        'call the result directly.')

    args = parser.parse_args(argv[1:])

    if args.seed is not None:
        deterministic.set_seed(args.seed)

    # SIGINT is blocked for all processes created in parallel_sampler to avoid
    # the creation of sleeping and zombie processes.
    #
    # If the user interrupts run_experiment, there's a chance some processes
    # won't die due to a deadlock condition where one of the children in the
    # parallel sampler exits without releasing a lock after it catches
    # SIGINT.
    #
    # Later the parent tries to acquire the same lock to proceed with its
    # cleanup, but it remains sleeping, waiting for the lock to be released.
    # In the meantime, all the processes in the parallel sampler remain in the
    # zombie state since the parent cannot proceed with their cleanup.
    with mask_signals([signal.SIGINT]):
        if args.n_parallel > 0:
            parallel_sampler.initialize(n_parallel=args.n_parallel)
            if args.seed is not None:
                parallel_sampler.set_seed(args.seed)

    if not args.plot:
        garage.plotter.Plotter.disable()
        garage.tf.plotter.Plotter.disable()

    # Without an explicit --log_dir, log to <cwd>/data/<exp_name>.
    if args.log_dir is None:
        log_dir = os.path.join(os.getcwd(), 'data', args.exp_name)
    else:
        log_dir = args.log_dir

    tabular_log_file = os.path.join(log_dir, args.tabular_log_file)
    text_log_file = os.path.join(log_dir, args.text_log_file)
    params_log_file = os.path.join(log_dir, args.params_log_file)

    if args.variant_data is not None:
        # NOTE(security): unpickling executes arbitrary code. This is only
        # safe if --variant_data comes from a trusted launcher.
        variant_data = pickle.loads(base64.b64decode(args.variant_data))
        variant_log_file = os.path.join(log_dir, args.variant_log_file)
        dump_variant(variant_log_file, variant_data)
    else:
        variant_data = None

    if not args.use_cloudpickle:
        log_parameters(params_log_file, args)

    logger.add_output(dowel.TextOutput(text_log_file))
    logger.add_output(dowel.CsvOutput(tabular_log_file))
    logger.add_output(dowel.TensorBoardOutput(log_dir))
    logger.add_output(dowel.StdOutput())

    logger.push_prefix('[%s] ' % args.exp_name)

    snapshot_config = SnapshotConfig(snapshot_dir=log_dir,
                                     snapshot_mode=args.snapshot_mode,
                                     snapshot_gap=args.snapshot_gap)

    # The experiment payload arrives base64-encoded in --args_data.
    # NOTE(security): same trust requirement as --variant_data above.
    if args.use_cloudpickle:
        import cloudpickle
        method_call = cloudpickle.loads(base64.b64decode(args.args_data))
        try:
            method_call(snapshot_config, variant_data, args.resume_from_dir,
                        args.resume_from_epoch)
        except BaseException:
            # Shut down the plotter (and sampler) children before
            # propagating, whatever interrupted the experiment.
            children = garage.plotter.Plotter.get_plotters()
            children += garage.tf.plotter.Plotter.get_plotters()
            if args.n_parallel > 0:
                children += [parallel_sampler]
            child_proc_shutdown(children)
            raise
    else:
        data = pickle.loads(base64.b64decode(args.args_data))
        maybe_iter = concretize(data)
        if is_iterable(maybe_iter):
            # Exhaust the iterable; only its side effects are wanted.
            for _ in maybe_iter:
                pass

    logger.remove_all()
    logger.pop_prefix()
コード例 #5
0
 def __exit__(self, type, value, traceback):
     """Tear down the logger when the context is left.

     Calls `logger.remove_all()` and then `logger.pop_prefix()`. The
     exception arguments are ignored and nothing is returned, so an
     exception raised inside the `with` body keeps propagating.
     """
     for teardown in (logger.remove_all, logger.pop_prefix):
         teardown()