Example #1
def create_logger(args):
    """Use hyperparameters to set a directory to output diagnostic files."""
    # Imports needed by this snippet; `color` (used below for terminal
    # formatting) is assumed to be a project-specific helper class.
    import hashlib
    import os
    import pickle
    from collections import OrderedDict

    from torch.utils.tensorboard import SummaryWriter

    arg_dict = args.__dict__
    assert "seed" in arg_dict, \
    "You must provide a 'seed' key in your command line arguments"
    assert "logdir" in arg_dict, \
    "You must provide a 'logdir' key in your command line arguments."
    assert "env_name" in arg_dict, \
    "You must provide a 'env_name' key in your command line arguments."

    # sort the keys so the same hyperparameters will always have the same hash
    arg_dict = OrderedDict(sorted(arg_dict.items(), key=lambda t: t[0]))

    # remove seed so it doesn't get hashed, store value for filename
    # same for logging directory
    run_name = arg_dict.pop('run_name')
    seed = str(arg_dict.pop("seed"))
    logdir = str(arg_dict.pop('logdir'))
    env_name = str(arg_dict['env_name'])

    # If this run has a unique name, use it as the folder name,
    # even if it overrides an existing run.
    if run_name is not None:
        logdir = os.path.join(logdir, env_name)
        output_dir = os.path.join(logdir, run_name)
    else:
        # See if we are resuming a previous run; if so, mark the directory as continued.
        if args.previous is not None:
            if args.exchange_reward is not None:
                output_dir = args.previous[0:-1] + "_NEW-" + args.reward
            else:
                print(args.previous[0:-1])
                output_dir = args.previous[0:-1] + '-cont'
        else:
            # get a unique hash for the hyperparameter settings, truncated to 6 chars
            arg_hash   = hashlib.md5(str(arg_dict).encode('ascii')).hexdigest()[0:6] + '-seed' + seed
            logdir     = os.path.join(logdir, env_name)
            output_dir = os.path.join(logdir, arg_hash)

    # Create a directory with the hyperparameter hash as its name, if it
    # doesn't already exist.
    os.makedirs(output_dir, exist_ok=True)

    # Create a file with all the hyperparameter settings in human-readable
    # plaintext, plus a pickle file so training can be resumed easily.
    info_path = os.path.join(output_dir, "experiment.info")
    pkl_path = os.path.join(output_dir, "experiment.pkl")
    with open(pkl_path, 'wb') as file:
        pickle.dump(args, file)
    with open(info_path, 'w') as file:
        for key, val in arg_dict.items():
            file.write("%s: %s" % (key, val))
            file.write('\n')

    logger = SummaryWriter(output_dir, flush_secs=0.1)  # flush_secs=0.1 slows things down quite a bit, even on parallelized setups
    print("Logging to " + color.BOLD + color.ORANGE + str(output_dir) + color.END)

    logger.dir = output_dir
    return logger
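
For context, a minimal sketch of how this first variant might be called. The argparse setup and the values below are illustrative assumptions, not part of the original; only the attribute names (seed, logdir, env_name, run_name, previous, exchange_reward, reward) are dictated by what the function reads.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--logdir", type=str, default="./logs")
parser.add_argument("--env_name", type=str, default="Walker2d-v3")  # placeholder env id
parser.add_argument("--run_name", type=str, default=None)
parser.add_argument("--previous", type=str, default=None)
parser.add_argument("--exchange_reward", type=str, default=None)
parser.add_argument("--reward", type=str, default=None)
args = parser.parse_args()

logger = create_logger(args)               # e.g. ./logs/Walker2d-v3/ab12cd-seed0
logger.add_scalar("train/return", 0.0, 0)  # standard SummaryWriter method
logger.close()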
Example #2
def create_logger(args):
    """Use hyperparameters to set a directory to output diagnostic files."""
    # Imports needed by this snippet; `color` (used below for terminal
    # formatting) is assumed to be a project-specific helper class.
    import hashlib
    import os
    from collections import OrderedDict

    from torch.utils.tensorboard import SummaryWriter

    arg_dict = args.__dict__
    #assert "seed" in arg_dict, \
    #  "You must provide a 'seed' key in your command line arguments"
    assert "logdir" in arg_dict, \
      "You must provide a 'logdir' key in your command line arguments."
    #assert "env" in arg_dict, \
    #  "You must provide a 'env' key in your command line arguments."

    # sort the keys so the same hyperparameters will always have the same hash
    arg_dict = OrderedDict(sorted(arg_dict.items(), key=lambda t: t[0]))

    # remove seed so it doesn't get hashed, store value for filename
    # same for logging directory
    if 'seed' in arg_dict:
        seed = str(arg_dict.pop("seed"))
    else:
        seed = None

    if 'env' in arg_dict:
        task_name = str(arg_dict.pop('env'))
    elif 'policy' in arg_dict:
        task_name = os.path.normpath(arg_dict.pop('policy')).split(os.path.sep)
        task_name = '-'.join(task_name[-4:-1])

    logdir = str(arg_dict.pop('logdir'))

    # get a unique hash for the hyperparameter settings, truncated to 6 chars
    if seed is None:
        arg_hash = hashlib.md5(str(arg_dict).encode('ascii')).hexdigest()[0:6]
    else:
        arg_hash = hashlib.md5(
            str(arg_dict).encode('ascii')).hexdigest()[0:6] + '-seed' + seed

    logdir = os.path.join(logdir, task_name)
    output_dir = os.path.join(logdir, arg_hash)

    # Create a directory with the hyperparameter hash as its name, if it
    # doesn't already exist.
    os.makedirs(output_dir, exist_ok=True)

    # Create a file with all the hyperparameter settings in plaintext.
    info_path = os.path.join(output_dir, "experiment.info")
    with open(info_path, 'w') as file:
        for key, val in arg_dict.items():
            file.write("%s: %s" % (key, val))
            file.write('\n')

    logger = SummaryWriter(output_dir, flush_secs=0.1)
    print("Logging to " + color.BOLD + color.ORANGE + str(output_dir) +
          color.END)

    logger.taskname = task_name
    logger.dir = output_dir
    logger.arg_hash = arg_hash
    return logger
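
And a minimal sketch of calling the second variant, which derives the task name either from an env argument or from the last few directory components of a policy checkpoint path. The SimpleNamespace and the paths/values below are illustrative assumptions:

from types import SimpleNamespace

# Hypothetical arguments: no 'env', so the task name comes from the policy
# path, here "walker-expert-run1" (the three components before the filename).
args = SimpleNamespace(
    logdir="./logs",
    policy="./trained_models/walker/expert/run1/actor.pt",
    lr=3e-4,
)

logger = create_logger(args)
print(logger.taskname, logger.arg_hash)  # e.g. walker-expert-run1 3f9a1c
logger.add_scalar("eval/return", 0.0, 0)
logger.close()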