コード例 #1
0
ファイル: base.py プロジェクト: anonyfizcu1/UP
def flag_defaults(FLAGS, load_log_flags=False):
    """Fill in derived default values on ``FLAGS`` in place.

    Flags that already hold a truthy value are left alone; empty ones are
    derived from other flags, from git, or from the environment.

    Args:
        FLAGS: Flag container whose attributes are read and assigned here.
            It must also be callable with an argv list; that call is used to
            re-apply command-line values after flags are loaded from a log.
        load_log_flags: When true, and ``FLAGS.load_log_path`` is set and the
            path returned by ``log_path(FLAGS, load=True)`` exists, every
            entry returned by ``parse_flags`` for that path is copied onto
            ``FLAGS`` before the defaults below are computed.

    Returns:
        None. ``FLAGS`` is mutated.

    Raises:
        AssertionError: If ``FLAGS.full_trees`` is set while
            ``FLAGS.transition_mode`` is anything other than ``'default'``.
    """
    if (load_log_flags and FLAGS.load_log_path
            and os.path.exists(log_path(FLAGS, load=True))):
        log_flags = parse_flags(log_path(FLAGS, load=True))
        for key in log_flags.keys():
            setattr(FLAGS, key, log_flags[key])

        # Optionally override flags from log file.
        FLAGS(sys.argv)

    if not FLAGS.experiment_name:
        FLAGS.experiment_name = "{}-{}-{}".format(
            FLAGS.data_type,
            FLAGS.model_type,
            int(time.time()),
        )

    # The with-blocks close each pipe as soon as its output has been read
    # instead of leaving that to garbage collection.
    if not FLAGS.git_branch_name:
        with os.popen('git rev-parse --abbrev-ref HEAD') as pipe:
            FLAGS.git_branch_name = pipe.read().strip()

    if not FLAGS.git_sha:
        with os.popen('git rev-parse HEAD') as pipe:
            FLAGS.git_sha = pipe.read().strip()

    if not FLAGS.slurm_job_id:
        # Read the variable directly rather than spawning a shell to `echo`
        # it; an unset variable still yields the empty string.
        FLAGS.slurm_job_id = os.environ.get('SLURM_JOB_ID', '').strip()

    if not FLAGS.load_log_path:
        FLAGS.load_log_path = FLAGS.log_path

    if not FLAGS.load_experiment_name:
        FLAGS.load_experiment_name = FLAGS.experiment_name

    # Must follow the load_log_path default above, which it falls back to.
    if not FLAGS.ckpt_path:
        FLAGS.ckpt_path = FLAGS.load_log_path

    if not FLAGS.sample_interval_steps:
        FLAGS.sample_interval_steps = FLAGS.statistics_interval_steps

    if not FLAGS.metrics_path:
        FLAGS.metrics_path = FLAGS.log_path

    if FLAGS.model_type in ("CBOW", "RNN", "Pyramid", "ChoiPyramid"):
        FLAGS.num_samples = 0

    if not torch.cuda.is_available():
        FLAGS.gpu = -1

    if FLAGS.full_trees:
        print('Using deprecated flag full_trees. Use transition_mode instead.')
        assert FLAGS.transition_mode == 'default', 'If full_trees is set, then do not use transition_mode.'
        FLAGS.transition_mode = 'full'
コード例 #2
0
ファイル: base.py プロジェクト: sayingandparsing/spinn
def flag_defaults(FLAGS, load_log_flags=False):
    """Fill in derived default values on ``FLAGS`` in place.

    Flags that already hold a truthy value are left alone; empty ones are
    derived from other flags, from git, or from the environment.

    Args:
        FLAGS: Flag container whose attributes are read and assigned here.
            It must also be callable with an argv list; that call is used to
            re-apply command-line values after flags are loaded from a log.
        load_log_flags: When true, and ``FLAGS.load_log_path`` is set and the
            path returned by ``log_path(FLAGS, load=True)`` exists, every
            entry returned by ``parse_flags`` for that path is copied onto
            ``FLAGS`` before the defaults below are computed.

    Returns:
        None. ``FLAGS`` is mutated.
    """
    if (load_log_flags and FLAGS.load_log_path
            and os.path.exists(log_path(FLAGS, load=True))):
        log_flags = parse_flags(log_path(FLAGS, load=True))
        # No snapshot of the keys is needed: the loop only writes to FLAGS.
        for key in log_flags.keys():
            setattr(FLAGS, key, log_flags[key])

        # Optionally override flags from log file.
        FLAGS(sys.argv)

    if not FLAGS.experiment_name:
        FLAGS.experiment_name = "{}-{}-{}".format(
            FLAGS.data_type,
            FLAGS.model_type,
            int(time.time()),
        )

    # The with-blocks close each pipe as soon as its output has been read
    # instead of leaving that to garbage collection.
    if not FLAGS.git_branch_name:
        with os.popen('git rev-parse --abbrev-ref HEAD') as pipe:
            FLAGS.git_branch_name = pipe.read().strip()

    if not FLAGS.git_sha:
        with os.popen('git rev-parse HEAD') as pipe:
            FLAGS.git_sha = pipe.read().strip()

    if not FLAGS.slurm_job_id:
        # Read the variable directly rather than spawning a shell to `echo`
        # it; an unset variable still yields the empty string.
        FLAGS.slurm_job_id = os.environ.get('SLURM_JOB_ID', '').strip()

    if not FLAGS.load_log_path:
        FLAGS.load_log_path = FLAGS.log_path

    if not FLAGS.load_experiment_name:
        FLAGS.load_experiment_name = FLAGS.experiment_name

    # Must follow the load_log_path default above, which it falls back to.
    if not FLAGS.ckpt_path:
        FLAGS.ckpt_path = FLAGS.load_log_path

    if not FLAGS.sample_interval_steps:
        FLAGS.sample_interval_steps = FLAGS.statistics_interval_steps

    if FLAGS.model_type in (
            "CBOW", "RNN", "ChoiPyramid", "LMS", "Maillard", "CatalanPyramid"
    ):
        FLAGS.num_samples = 0

    if FLAGS.model_type == "LMS":
        FLAGS.reduce = "lms"

    if not torch.cuda.is_available():
        FLAGS.gpu = -1
コード例 #3
0
def flag_defaults(FLAGS, load_log_flags=False):
    """Fill in derived default values on ``FLAGS`` in place.

    Flags that already hold a truthy value are left alone; empty ones are
    derived from other flags or from git.

    Args:
        FLAGS: Flag container whose attributes are read and assigned here.
            It must also be callable with an argv list; that call is used to
            re-apply command-line values after flags are loaded from a log.
        load_log_flags: When true, and ``FLAGS.load_log_path`` is set and the
            path returned by ``log_path(FLAGS, load=True)`` exists, every
            entry returned by ``parse_flags`` for that path is copied onto
            ``FLAGS`` before the defaults below are computed.

    Returns:
        None. ``FLAGS`` is mutated.
    """
    if (load_log_flags and FLAGS.load_log_path
            and os.path.exists(log_path(FLAGS, load=True))):
        log_flags = parse_flags(log_path(FLAGS, load=True))
        for key in log_flags.keys():
            setattr(FLAGS, key, log_flags[key])

        # Optionally override flags from log file.
        FLAGS(sys.argv)

    if not FLAGS.experiment_name:
        FLAGS.experiment_name = "{}-{}-{}".format(
            FLAGS.data_type,
            FLAGS.model_type,
            int(time.time()),
        )

    # The with-blocks close each pipe as soon as its output has been read
    # instead of leaving that to garbage collection.
    if not FLAGS.branch_name:
        with os.popen('git rev-parse --abbrev-ref HEAD') as pipe:
            FLAGS.branch_name = pipe.read().strip()

    if not FLAGS.sha:
        with os.popen('git rev-parse HEAD') as pipe:
            FLAGS.sha = pipe.read().strip()

    if not FLAGS.load_log_path:
        FLAGS.load_log_path = FLAGS.log_path

    if not FLAGS.load_experiment_name:
        FLAGS.load_experiment_name = FLAGS.experiment_name

    # Must follow the load_log_path default above, which it falls back to.
    if not FLAGS.ckpt_path:
        FLAGS.ckpt_path = FLAGS.load_log_path

    if not FLAGS.sample_interval_steps:
        FLAGS.sample_interval_steps = FLAGS.statistics_interval_steps

    if not FLAGS.metrics_path:
        FLAGS.metrics_path = FLAGS.log_path

    if FLAGS.model_type in ("CBOW", "RNN"):
        FLAGS.num_samples = 0

    if not torch.cuda.is_available():
        FLAGS.gpu = -1