예제 #1
0
 def test_concretize(self):
     """Check `concretize` on a list, a tuple and a `TestClass` instance.

     The list and tuple results must compare equal to the inputs, and the
     result for an instantiated `TestClass` must itself be a `TestClass`.
     """
     list_result = concretize([5])
     assert list_result == [5]

     tuple_result = concretize((5, ))
     assert tuple_result == (5, )

     # Look the class up through a namespace mapping rather than by name.
     namespace = {'TestClass': TestClass}
     looked_up_cls = namespace['TestClass']

     # The tuple check is repeated after the lookup, as in the original.
     tuple_result_again = concretize((5, ))
     assert tuple_result_again == (5, )

     instance_result = concretize(looked_up_cls())
     assert isinstance(instance_result, TestClass)
예제 #2
0
 def test_concretize():
     """Check `concretize` on a list, a tuple and a `TestClass` instance.

     Assertions are reported through the `it` test context.
     """
     list_result = concretize([5])
     it.assertEqual(list_result, [5])

     tuple_result = concretize((5, ))
     it.assertEqual(tuple_result, (5, ))

     # Look the class up through a namespace mapping rather than by name.
     namespace = {"TestClass": TestClass}
     looked_up_cls = namespace["TestClass"]

     # The tuple check is repeated after the lookup, as in the original.
     tuple_result_again = concretize((5, ))
     it.assertEqual(tuple_result_again, (5, ))

     instance_result = concretize(looked_up_cls())
     it.assertIsInstance(instance_result, TestClass)
예제 #3
0
 def test_chained_call(self):
     """Check that concretizing `TestClass().arr[0]` yields 1."""
     # Look the class up through a namespace mapping rather than by name.
     namespace = {'TestClass': TestClass}
     instance = namespace['TestClass']()

     # Attribute access followed by indexing forms the "chained" expression.
     first_item = instance.arr[0]
     assert concretize(first_item) == 1
예제 #4
0
def _build_experiment_parser(default_exp_name):
    """Build the command-line parser used by `run_experiment`.

    Args:
        default_exp_name: Value used for `--exp_name` when the flag is
            not supplied.

    Returns:
        argparse.ArgumentParser: Parser with every experiment option
            registered.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--n_parallel',
        type=int,
        default=1,
        help=("Number of parallel workers to perform rollouts. "
              "0 => don't start any workers"))
    parser.add_argument(
        '--exp_name',
        type=str,
        default=default_exp_name,
        help='Name of the experiment.')
    parser.add_argument(
        '--log_dir',
        type=str,
        default=None,
        help='Path to save the log and iteration snapshot.')
    parser.add_argument(
        '--snapshot_mode',
        type=str,
        default='all',
        help='Mode to save the snapshot. Can be either "all" '
        '(all iterations will be saved), "last" (only '
        'the last iteration will be saved), "gap" (every '
        '`snapshot_gap` iterations are saved), or "none" '
        '(do not save snapshots)')
    parser.add_argument(
        '--snapshot_gap',
        type=int,
        default=1,
        help='Gap between snapshot iterations.')
    parser.add_argument(
        '--tabular_log_file',
        type=str,
        default='progress.csv',
        help='Name of the tabular log file (in csv).')
    parser.add_argument(
        '--text_log_file',
        type=str,
        default='debug.log',
        help='Name of the text log file (in pure text).')
    parser.add_argument(
        '--tensorboard_step_key',
        type=str,
        default=None,
        help=("Name of the step key in tensorboard_summary."))
    parser.add_argument(
        '--params_log_file',
        type=str,
        default='params.json',
        help='Name of the parameter log file (in json).')
    parser.add_argument(
        '--variant_log_file',
        type=str,
        default='variant.json',
        help='Name of the variant log file (in json).')
    parser.add_argument(
        '--resume_from',
        type=str,
        default=None,
        help='Name of the pickle file to resume experiment from.')
    # Boolean-like flags are parsed with ast.literal_eval, so they are
    # passed on the command line as Python literals (e.g. "True").
    parser.add_argument(
        '--plot',
        type=ast.literal_eval,
        default=False,
        help='Whether to plot the iteration results')
    parser.add_argument(
        '--log_tabular_only',
        type=ast.literal_eval,
        default=False,
        help='Print only the tabular log information (in a horizontal format)')
    parser.add_argument('--seed', type=int, help='Random seed for numpy')
    parser.add_argument(
        '--args_data', type=str, help='Pickled data for objects')
    parser.add_argument(
        '--variant_data',
        type=str,
        help='Pickled data for variant configuration')
    parser.add_argument(
        '--use_cloudpickle', type=ast.literal_eval, default=False)
    return parser


def run_experiment(argv):
    """Run (or resume) an experiment described by command-line arguments.

    The function parses `argv`, optionally seeds the random generators and
    starts the parallel sampler, points the logger at the experiment's log
    directory, executes the experiment, and finally restores the logger's
    previous snapshot settings and removes the outputs it added.

    The experiment itself comes from one of three sources:
      * `--resume_from`: a file loaded with joblib that must contain an
        `'algo'` entry, whose `train()` method is called;
      * `--args_data` with `--use_cloudpickle True`: a base64-encoded
        cloudpickle payload that is called with the variant data;
      * `--args_data` otherwise: a base64-encoded pickle payload that is
        passed to `concretize` and, if the result is iterable, exhausted.

    Args:
        argv: Argument list in `sys.argv` form; the first element is
            skipped and the rest is handed to the argument parser.

    Raises:
        BaseException: Whatever the experiment raises is propagated. In
            the cloudpickle path, plotter and sampler child processes are
            shut down first. Logger state is restored in every case.
    """
    default_log_dir = config.LOG_DIR
    now = datetime.datetime.now(dateutil.tz.tzlocal())

    # avoid name clashes when running distributed jobs
    rand_id = str(uuid.uuid4())[:5]
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')

    default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id)
    parser = _build_experiment_parser(default_exp_name)
    args = parser.parse_args(argv[1:])

    if args.seed is not None:
        set_seed(args.seed)

    # SIGINT is blocked for all processes created in parallel_sampler to avoid
    # the creation of sleeping and zombie processes.
    #
    # If the user interrupts run_experiment, there's a chance some processes
    # won't die due to a deadlock condition where one of the children in the
    # parallel sampler exits without releasing a lock once after it catches
    # SIGINT.
    #
    # Later the parent tries to acquire the same lock to proceed with its
    # cleanup, but it remains sleeping waiting for the lock to be released.
    # In the meantime, all the processes in the parallel sampler remain in the
    # zombie state since the parent cannot proceed with their clean up.
    with mask_signals([signal.SIGINT]):
        if args.n_parallel > 0:
            parallel_sampler.initialize(n_parallel=args.n_parallel)
            if args.seed is not None:
                parallel_sampler.set_seed(args.seed)

    if not args.plot:
        garage.plotter.Plotter.disable()
        garage.tf.plotter.Plotter.disable()

    if args.log_dir is None:
        log_dir = osp.join(default_log_dir, args.exp_name)
    else:
        log_dir = args.log_dir
    tabular_log_file = osp.join(log_dir, args.tabular_log_file)
    text_log_file = osp.join(log_dir, args.text_log_file)
    params_log_file = osp.join(log_dir, args.params_log_file)

    # NOTE(review): pickle/cloudpickle payloads below are taken straight from
    # the command line. Unpickling can execute arbitrary code, so these
    # arguments must only ever be produced by a trusted launcher.
    if args.variant_data is not None:
        variant_data = pickle.loads(base64.b64decode(args.variant_data))
        variant_log_file = osp.join(log_dir, args.variant_log_file)
        logger.log_variant(variant_log_file, variant_data)
    else:
        variant_data = None

    if not args.use_cloudpickle:
        logger.log_parameters_lite(params_log_file, args)

    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    logger.set_tensorboard_dir(log_dir)
    # Remember the previous snapshot settings so they can be restored once
    # this experiment is over.
    prev_snapshot_dir = logger.get_snapshot_dir()
    prev_mode = logger.get_snapshot_mode()
    logger.set_snapshot_dir(log_dir)
    logger.set_snapshot_mode(args.snapshot_mode)
    logger.set_snapshot_gap(args.snapshot_gap)
    logger.set_log_tabular_only(args.log_tabular_only)
    logger.set_tensorboard_step_key(args.tensorboard_step_key)
    logger.push_prefix("[%s] " % args.exp_name)

    # Everything the logger was given above is undone in the `finally`
    # clause, so a failing experiment no longer leaves the logger pointing
    # at this experiment's directory with its outputs still attached.
    try:
        if args.resume_from is not None:
            data = joblib.load(args.resume_from)
            assert 'algo' in data
            algo = data['algo']
            algo.train()
        elif args.use_cloudpickle:
            # The experiment is a cloudpickled callable in --args_data.
            import cloudpickle
            method_call = cloudpickle.loads(base64.b64decode(args.args_data))
            try:
                method_call(variant_data)
            except BaseException:
                # Shut down plotter (and sampler) children before
                # propagating, including on KeyboardInterrupt/SystemExit.
                children = garage.plotter.Plotter.get_plotters()
                children += garage.tf.plotter.Plotter.get_plotters()
                if args.n_parallel > 0:
                    children += [parallel_sampler]
                child_proc_shutdown(children)
                raise
        else:
            # The experiment is a pickled object in --args_data.
            data = pickle.loads(base64.b64decode(args.args_data))
            maybe_iter = concretize(data)
            if is_iterable(maybe_iter):
                # Exhaust the iterable; only its side effects matter here.
                for _ in maybe_iter:
                    pass
    finally:
        logger.set_snapshot_mode(prev_mode)
        logger.set_snapshot_dir(prev_snapshot_dir)
        logger.remove_tabular_output(tabular_log_file)
        logger.remove_text_output(text_log_file)
        logger.pop_prefix()
예제 #5
0
 def test_chained_call():
     """Check that concretizing `TestClass().arr[0]` yields 1.

     The assertion is reported through the `it` test context.
     """
     # Look the class up through a namespace mapping rather than by name.
     namespace = {"TestClass": TestClass}
     instance = namespace["TestClass"]()

     # Attribute access followed by indexing forms the "chained" expression.
     first_item = instance.arr[0]
     it.assertEqual(concretize(first_item), 1)