예제 #1
0
def build_config():
    """Assemble a ``parallax.Config`` from the command-line ``FLAGS``.

    Builds the checkpoint, parameter-server, MPI and profiling sub-configs
    from the corresponding flags and attaches them, together with the
    redirect path, to a fresh ``parallax.Config``.

    Returns:
        parallax.Config: the populated configuration object.

    Raises:
        ValueError: if ``FLAGS.profile_steps`` contains a non-integer entry.
    """
    ckpt_config = parallax.CheckPointConfig(
        ckpt_dir=FLAGS.ckpt_dir, save_ckpt_steps=calculate_ckpt_steps())
    ps_config = parallax.PSConfig(
        replicate_variables=FLAGS.replicate_variables,
        protocol=FLAGS.protocol,
        local_aggregation=FLAGS.local_aggregation,
        boundary_among_servers=FLAGS.boundary_among_servers,
        boundary_between_workers_and_servers=(
            FLAGS.boundary_between_workers_and_servers))
    mpi_config = parallax.MPIConfig(use_allgatherv=FLAGS.use_allgatherv,
                                    mpirun_options=FLAGS.mpirun_options)
    parallax_config = parallax.Config()
    parallax_config.run_option = FLAGS.run_option
    parallax_config.average_sparse = False
    parallax_config.communication_config = parallax.CommunicationConfig(
        ps_config, mpi_config)
    parallax_config.ckpt_config = ckpt_config

    def get_profile_steps():
        """Parse ``FLAGS.profile_steps`` (e.g. "10,20,30") into a list of ints.

        Returns an empty list when the flag is unset or blank. Blank entries,
        such as those produced by a trailing comma, are skipped instead of
        crashing on ``int('')``.

        Raises:
            ValueError: if a non-blank entry is not an integer.
        """
        if not FLAGS.profile_steps:
            return []
        # The stripped value is written back to the global flag, as before.
        # NOTE(review): confirm whether anything else reads the normalized
        # flag; if not, this mutation of global state can be dropped.
        FLAGS.profile_steps = FLAGS.profile_steps.strip()
        tokens = [tok.strip() for tok in FLAGS.profile_steps.split(',')]
        try:
            return [int(tok) for tok in tokens if tok]
        except ValueError:
            raise ValueError(
                'Invalid --profile_steps value %r: expected comma-separated '
                'integers' % FLAGS.profile_steps)

    profile_config = parallax.ProfileConfig(profile_dir=FLAGS.profile_dir,
                                            profile_steps=get_profile_steps())
    parallax_config.profile_config = profile_config
    parallax_config.redirect_path = FLAGS.redirect_path

    return parallax_config
예제 #2
0
def build_config():
    """Assemble a ``parallax.Config`` from the command-line ``FLAGS``.

    Builds the checkpoint, parameter-server, MPI and profiling sub-configs
    from the corresponding flags and attaches them to a fresh
    ``parallax.Config``. The redirect and graph-export paths are attached
    as well.

    Returns:
        parallax.Config: the populated configuration object.

    Raises:
        ValueError: if ``FLAGS.profile_steps`` contains a non-integer entry,
            or ``FLAGS.profile_range`` is not exactly two integers.
    """
    ckpt_config = parallax.CheckPointConfig(
        ckpt_dir=FLAGS.ckpt_dir, save_ckpt_steps=calculate_ckpt_steps())
    ps_config = parallax.PSConfig(
        replicate_variables=FLAGS.replicate_variables,
        protocol=FLAGS.protocol,
        local_aggregation=FLAGS.local_aggregation,
        boundary_among_servers=FLAGS.boundary_among_servers,
        boundary_between_workers_and_servers=(
            FLAGS.boundary_between_workers_and_servers))
    mpi_config = parallax.MPIConfig(mpirun_options=FLAGS.mpirun_options)

    def parse_int_list(flag_name, raw):
        """Split ``raw`` on commas into ints, skipping blank entries.

        Raises:
            ValueError: if a non-blank entry is not an integer; the message
                names the offending ``--flag_name``.
        """
        tokens = [tok.strip() for tok in raw.split(',')]
        try:
            return [int(tok) for tok in tokens if tok]
        except ValueError:
            raise ValueError(
                'Invalid --%s value %r: expected comma-separated integers'
                % (flag_name, raw))

    def get_profile_steps():
        """Return the steps to profile as a list of ints, or None if unset."""
        if not FLAGS.profile_steps:
            return None
        # The stripped value is written back to the global flag, as before.
        # NOTE(review): confirm whether anything else reads the normalized
        # flag; if not, this mutation of global state can be dropped.
        FLAGS.profile_steps = FLAGS.profile_steps.strip()
        # A flag containing only whitespace/commas is treated as unset.
        return parse_int_list('profile_steps', FLAGS.profile_steps) or None

    def get_profile_range():
        """Return ``FLAGS.profile_range`` as a pair of ints, or None if unset.

        Raises:
            ValueError: unless exactly two integer values are given.
        """
        if not FLAGS.profile_range:
            return None
        # Write-back kept for the same reason as in get_profile_steps.
        FLAGS.profile_range = FLAGS.profile_range.strip()
        bounds = parse_int_list('profile_range', FLAGS.profile_range)
        if not bounds:
            return None
        # Previously one value raised IndexError and extra values were
        # silently dropped; reject both explicitly.
        if len(bounds) != 2:
            raise ValueError(
                'Invalid --profile_range value %r: expected exactly two '
                'comma-separated integers' % FLAGS.profile_range)
        return (bounds[0], bounds[1])

    profile_config = parallax.ProfileConfig(
        profile_dir=FLAGS.profile_dir,
        profile_steps=get_profile_steps(),
        profile_range=get_profile_range(),
        profile_worker=FLAGS.profile_worker)

    parallax_config = parallax.Config()
    parallax_config.run_option = FLAGS.run_option
    parallax_config.average_sparse = False
    parallax_config.communication_config = parallax.CommunicationConfig(
        ps_config, mpi_config)
    parallax_config.ckpt_config = ckpt_config
    parallax_config.profile_config = profile_config
    parallax_config.redirect_path = FLAGS.redirect_path
    parallax_config.export_graph_path = FLAGS.export_graph_path

    return parallax_config