Code Example #1
File: trainer_worker.py    Project: eddyJ/fedlearner
def train(role, args, input_fn, model_fn, serving_input_receiver_fn):
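    # Bridge is the communication channel to the peer party; the listen
    # port is parsed from --local-addr.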
    bridge = Bridge(role, int(args.local_addr.split(':')[1]),
                    args.peer_addr)

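    # The trainer master and TF cluster spec come from one of three sources:
    # a full --cluster-spec JSON, an explicit --master-addr, or a local
    # --data-path (no cluster).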
    if args.cluster_spec:
        cluster_spec = json.loads(args.cluster_spec)
        assert 'clusterSpec' in cluster_spec, \
            "cluster_spec do not meet legal format"
        assert 'Master' in cluster_spec['clusterSpec'],\
            "cluster_spec must include Master"
        assert isinstance(cluster_spec['clusterSpec']['Master'], list), \
            "Master must be list"
        assert 'Worker' in cluster_spec['clusterSpec'],\
            "cluster_spec must include Worker"
        assert isinstance(cluster_spec['clusterSpec']['Worker'], list), \
            "Worker must be list"
        trainer_master = TrainerMasterClient(
            cluster_spec['clusterSpec']['Master'][0], role, args.worker_rank)
        cluster_spec = tf.train.ClusterSpec({
            'ps': cluster_spec['clusterSpec']['PS'],
            'worker': {args.worker_rank: args.tf_addr}})

    elif args.master_addr:
        assert args.tf_addr is not None, \
            "--tf-addr must be set when master_addr is set."
        trainer_master = TrainerMasterClient(
            args.master_addr, role, args.worker_rank)
        ps_addrs = args.ps_addrs.split(",")
        cluster_spec = tf.train.ClusterSpec({
            'ps': ps_addrs,
            'worker': {args.worker_rank: args.tf_addr}})
    elif args.data_path:
        trainer_master = LocalTrainerMasterClient(role, args.data_path)
        cluster_spec = None
    else:
        raise ValueError("Either --master-addr or --data-path must be set")

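    # Build the federated Estimator; checkpointing is enabled only when a
    # --checkpoint-path is given.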
    estimator = FLEstimator(
        model_fn, bridge, trainer_master, role, worker_rank=args.worker_rank,
        cluster_spec=cluster_spec)
    if args.checkpoint_path:
        estimator.train(input_fn,
                        checkpoint_path=args.checkpoint_path,
                        save_checkpoint_steps=args.save_checkpoint_steps)
    else:
        estimator.train(input_fn)

    if args.export_path:
        estimator.export_saved_model(args.export_path,
                                     serving_input_receiver_fn,
                                     checkpoint_path=args.checkpoint_path)
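
A minimal sketch of how this entry point might be called from a worker's main script is shown below. The flag names mirror the arguments referenced inside train(); the role string (e.g. 'leader') and the my_input_fn / my_model_fn / my_serving_input_receiver_fn callables are hypothetical placeholders, not part of the project.

# Hypothetical driver sketch for the train() entry point above.
import argparse

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--local-addr', dest='local_addr', required=True)
    parser.add_argument('--peer-addr', dest='peer_addr', required=True)
    parser.add_argument('--cluster-spec', dest='cluster_spec', default=None)
    parser.add_argument('--master-addr', dest='master_addr', default=None)
    parser.add_argument('--ps-addrs', dest='ps_addrs', default=None)
    parser.add_argument('--tf-addr', dest='tf_addr', default=None)
    parser.add_argument('--data-path', dest='data_path', default=None)
    parser.add_argument('--worker-rank', dest='worker_rank', type=int, default=0)
    parser.add_argument('--checkpoint-path', dest='checkpoint_path', default=None)
    parser.add_argument('--save-checkpoint-steps', dest='save_checkpoint_steps',
                        type=int, default=None)
    parser.add_argument('--export-path', dest='export_path', default=None)
    args = parser.parse_args()

    # my_input_fn, my_model_fn and my_serving_input_receiver_fn are
    # user-defined Estimator callables (placeholders here).
    train('leader', args, my_input_fn, my_model_fn,
          my_serving_input_receiver_fn)

if __name__ == '__main__':
    main()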
Code Example #2
def _run_worker(role, args, input_fn, model_fn):
    if not args.local_addr:
        raise ValueError("local-addr is required")
    if not args.peer_addr:
        raise ValueError("peer-addr is required")
    if not args.master_addr:
        raise ValueError("master-addr is required")
    mode = args.mode.lower()

    cluster_spec = _create_cluster_spec(args, require_ps=True)
    cluster_server = ClusterServer(cluster_spec,
                                   "worker",
                                   task_index=args.worker_rank)

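    # Register this worker with the trainer master before training; give up
    # early if registration is rejected.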
    trainer_master = TrainerMasterClient(args.master_addr, args.worker_rank)
    if not trainer_master.worker_register(cluster_spec.as_cluster_def()):
        return

    bridge = Bridge(role, int(args.local_addr.split(':')[1]), args.peer_addr,
                    args.application_id, args.worker_rank)

    estimator_factory = SparseFLEstimator \
        if args.sparse_estimator else FLEstimator
    estimator = estimator_factory(cluster_server,
                                  trainer_master,
                                  bridge,
                                  role,
                                  model_fn,
                                  is_chief=args.worker_rank == 0)

    if mode == 'train':
        estimator.train(input_fn)
    elif mode == 'eval':
        estimator.evaluate(input_fn)

    trainer_master.worker_complete(bridge.terminated_at)
    trainer_master.wait_master_complete()
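
The _create_cluster_spec helper is referenced above but not included in this snippet. A plausible sketch, assuming comma-separated --ps-addrs and the 'ps'/'worker' job names used elsewhere on this page, is given below; the project's actual implementation may differ.

# Hypothetical sketch of the helper used above; not the project's actual code.
import tensorflow as tf

def _create_cluster_spec(args, require_ps=False):
    cluster = {'worker': {args.worker_rank: args.tf_addr}}
    if args.ps_addrs:
        cluster['ps'] = args.ps_addrs.split(',')
    elif require_ps:
        raise ValueError("ps-addrs is required")
    return tf.train.ClusterSpec(cluster)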
Code Example #3
File: trainer_worker.py    Project: guotie/fedlearner
def train(role, args, input_fn, model_fn, serving_input_receiver_fn):
    logging.basicConfig(
        format="%(asctime)-15s [%(filename)s:%(lineno)d] " \
               "%(levelname)s : %(message)s")
    if args.verbosity == 0:
        logging.getLogger().setLevel(logging.WARNING)
    elif args.verbosity == 1:
        logging.getLogger().setLevel(logging.INFO)
    elif args.verbosity > 1:
        logging.getLogger().setLevel(logging.DEBUG)

    if args.application_id:
        bridge = Bridge(role, int(args.local_addr.split(':')[1]),
                        args.peer_addr, args.application_id, args.worker_rank)
    else:
        bridge = Bridge(role, int(args.local_addr.split(':')[1]),
                        args.peer_addr)

    if args.data_path:
        trainer_master = LocalTrainerMasterClient(role,
                                                  args.data_path,
                                                  epoch_num=args.epoch_num)
        if args.ps_addrs is not None:
            ps_addrs = args.ps_addrs.split(",")
            cluster_spec = tf.train.ClusterSpec({
                'ps': ps_addrs,
                'worker': {
                    args.worker_rank: args.tf_addr
                }
            })
        else:
            cluster_spec = None
    elif args.cluster_spec:
        cluster_spec = json.loads(args.cluster_spec)
        assert 'clusterSpec' in cluster_spec, \
            "cluster_spec do not meet legal format"
        assert 'Master' in cluster_spec['clusterSpec'],\
            "cluster_spec must include Master"
        assert isinstance(cluster_spec['clusterSpec']['Master'], list), \
            "Master must be list"
        assert 'Worker' in cluster_spec['clusterSpec'],\
            "cluster_spec must include Worker"
        assert isinstance(cluster_spec['clusterSpec']['Worker'], list), \
            "Worker must be list"
        trainer_master = TrainerMasterClient(
            cluster_spec['clusterSpec']['Master'][0], role, args.worker_rank)
        cluster_spec = tf.train.ClusterSpec({
            'ps':
            cluster_spec['clusterSpec']['PS'],
            'worker': {
                args.worker_rank: args.tf_addr
            }
        })
    elif args.master_addr:
        assert args.tf_addr is not None, \
            "--tf-addr must be set when master_addr is set."
        trainer_master = TrainerMasterClient(args.master_addr, role,
                                             args.worker_rank)
        ps_addrs = args.ps_addrs.split(",")
        cluster_spec = tf.train.ClusterSpec({
            'ps': ps_addrs,
            'worker': {
                args.worker_rank: args.tf_addr
            }
        })
    elif args.data_source:
        if args.start_time is None or args.end_time is None:
            raise ValueError(
                "data source must be set with start-date and end-date")
        trainer_master = LocalTrainerMasterClient(role,
                                                  args.data_source,
                                                  start_time=args.start_time,
                                                  end_time=args.end_time,
                                                  epoch_num=args.epoch_num)
        cluster_spec = None
    else:
        raise ValueError("Either --master-addr or --data-path must be set")

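    # Summary output is configured through SummaryHook class attributes set
    # before the estimator is built.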
    if args.summary_path:
        SummaryHook.summary_path = args.summary_path
        SummaryHook.worker_rank = args.worker_rank
        SummaryHook.role = role
    if args.summary_save_steps:
        SummaryHook.save_steps = args.summary_save_steps

    if args.sparse_estimator:
        estimator = SparseFLEstimator(model_fn,
                                      bridge,
                                      trainer_master,
                                      role,
                                      worker_rank=args.worker_rank,
                                      application_id=args.application_id,
                                      cluster_spec=cluster_spec)
    else:
        estimator = FLEstimator(model_fn,
                                bridge,
                                trainer_master,
                                role,
                                worker_rank=args.worker_rank,
                                application_id=args.application_id,
                                cluster_spec=cluster_spec)

    run_mode = args.mode.lower()
    if run_mode == 'train':
        estimator.train(input_fn,
                        checkpoint_path=args.checkpoint_path,
                        save_checkpoint_steps=args.save_checkpoint_steps,
                        save_checkpoint_secs=args.save_checkpoint_secs)
        if args.export_path and args.worker_rank == 0:
            export_path = '%s/%d' % (args.export_path, bridge.terminated_at)
            estimator.export_saved_model(export_path,
                                         serving_input_receiver_fn,
                                         checkpoint_path=args.checkpoint_path)
            fsuccess = tf.io.gfile.GFile('%s/_SUCCESS' % export_path, 'w')
            fsuccess.write('%d' % bridge.terminated_at)
            fsuccess.close()

    elif run_mode == 'eval':
        estimator.evaluate(input_fn, checkpoint_path=args.checkpoint_path)
    else:
        raise ValueError('Allowed values are: --mode=train|eval')
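
For reference, a --cluster-spec value that satisfies the assertions above must carry Master and Worker lists under a top-level clusterSpec key, and the branch additionally reads clusterSpec['PS'] for the parameter servers even though PS is not asserted. A made-up example:

# Illustrative --cluster-spec payload; all addresses are placeholders.
cluster_spec_json = '''
{
  "clusterSpec": {
    "Master": ["master-0.example.com:50051"],
    "PS": ["ps-0.example.com:2222", "ps-1.example.com:2222"],
    "Worker": ["worker-0.example.com:2222", "worker-1.example.com:2222"]
  }
}
'''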
Code Example #4
def train(role, args, input_fn, model_fn, serving_input_receiver_fn):
    if args.application_id:
        bridge = Bridge(role, int(args.local_addr.split(':')[1]),
                        args.peer_addr, args.application_id, args.worker_rank)
    else:
        bridge = Bridge(role, int(args.local_addr.split(':')[1]),
                        args.peer_addr)

    if args.data_path:
        trainer_master = LocalTrainerMasterClient(role, args.data_path)
        if args.ps_addrs is not None:
            ps_addrs = args.ps_addrs.split(",")
            cluster_spec = tf.train.ClusterSpec({
                'ps': ps_addrs,
                'worker': {
                    args.worker_rank: args.tf_addr
                }
            })
        else:
            cluster_spec = None
    elif args.cluster_spec:
        cluster_spec = json.loads(args.cluster_spec)
        assert 'clusterSpec' in cluster_spec, \
            "cluster_spec do not meet legal format"
        assert 'Master' in cluster_spec['clusterSpec'],\
            "cluster_spec must include Master"
        assert isinstance(cluster_spec['clusterSpec']['Master'], list), \
            "Master must be list"
        assert 'Worker' in cluster_spec['clusterSpec'],\
            "cluster_spec must include Worker"
        assert isinstance(cluster_spec['clusterSpec']['Worker'], list), \
            "Worker must be list"
        trainer_master = TrainerMasterClient(
            cluster_spec['clusterSpec']['Master'][0], role, args.worker_rank)
        cluster_spec = tf.train.ClusterSpec({
            'ps':
            cluster_spec['clusterSpec']['PS'],
            'worker': {
                args.worker_rank: args.tf_addr
            }
        })
    elif args.master_addr:
        assert args.tf_addr is not None, \
            "--tf-addr must be set when master_addr is set."
        trainer_master = TrainerMasterClient(args.master_addr, role,
                                             args.worker_rank)
        ps_addrs = args.ps_addrs.split(",")
        cluster_spec = tf.train.ClusterSpec({
            'ps': ps_addrs,
            'worker': {
                args.worker_rank: args.tf_addr
            }
        })
    elif args.data_source:
        if args.start_time is None or args.end_time is None:
            raise ValueError(
                "data source must be set with start-date and end-date")
        trainer_master = LocalTrainerMasterClient(role,
                                                  args.data_source,
                                                  start_time=args.start_time,
                                                  end_time=args.end_time)
        cluster_spec = None
    else:
        raise ValueError("Either --master-addr or --data-path must be set")

    if args.summary_path:
        SummaryHook.summary_path = args.summary_path
        SummaryHook.worker_rank = args.worker_rank
        SummaryHook.role = role
    if args.summary_save_steps:
        SummaryHook.save_steps = args.summary_save_steps

    if args.sparse_estimator:
        estimator = SparseFLEstimator(model_fn,
                                      bridge,
                                      trainer_master,
                                      role,
                                      worker_rank=args.worker_rank,
                                      cluster_spec=cluster_spec)
    else:
        estimator = FLEstimator(model_fn,
                                bridge,
                                trainer_master,
                                role,
                                worker_rank=args.worker_rank,
                                cluster_spec=cluster_spec)

    run_mode = args.mode.lower()
    if run_mode == 'train':
        estimator.train(input_fn,
                        checkpoint_path=args.checkpoint_path,
                        save_checkpoint_steps=args.save_checkpoint_steps,
                        save_checkpoint_secs=args.save_checkpoint_secs)
    elif run_mode == 'eval':
        estimator.evaluate(input_fn, checkpoint_path=args.checkpoint_path)
    else:
        raise ValueError('Allowed values are: --mode=train|eval')

    if args.export_path:
        estimator.export_saved_model(args.export_path,
                                     serving_input_receiver_fn,
                                     checkpoint_path=args.checkpoint_path)
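
When --export-path is set, the final step writes a SavedModel through the serving_input_receiver_fn passed into train(). A minimal sketch of such a function, assuming TensorFlow 1.x APIs and a single made-up feature named 'x', could look like this; the real feature spec depends on the model.

# Hypothetical serving input receiver for export_saved_model (TF 1.x style).
import tensorflow as tf

def my_serving_input_receiver_fn():
    serialized = tf.placeholder(dtype=tf.string, shape=[None], name='examples')
    features = tf.parse_example(
        serialized, features={'x': tf.FixedLenFeature([10], tf.float32)})
    return tf.estimator.export.ServingInputReceiver(
        features, {'examples': serialized})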