Exemplo n.º 1
0
    def _build_pserver_programs(self, compiled_config):
        """Assemble the parameter-server main and startup programs.

        Both programs start out as fresh ``fluid.Program`` objects and are
        threaded through a fixed sequence of ``server`` passes. Geo and
        non-geo mode share that sequence except for two points:

        * geo mode calls ``add_geo_optimizer_pass`` where non-geo mode calls
          ``add_optimizer_pass``;
        * ``delete_unused_in_main_pass`` is applied only in non-geo mode, and
          only when ``compiled_config.is_sync_mode()`` is false.

        Args:
            compiled_config: Object queried through ``is_geo_mode()`` and
                ``is_sync_mode()`` and forwarded unchanged to every pass.

        Returns:
            A ``(main_program, startup_program)`` tuple holding the results
            of the pass pipeline.
        """
        main_prog = fluid.Program()
        startup_prog = fluid.Program()
        geo_mode = compiled_config.is_geo_mode()

        # Passes common to both modes that set up the main program.
        main_prog = server.add_listen_and_serv_pass(main_prog, compiled_config)
        main_prog = server.add_rpc_global_flags_pass(main_prog,
                                                     compiled_config)

        # The optimizer pass is the one step whose name depends on the mode.
        if geo_mode:
            main_prog = server.add_geo_optimizer_pass(main_prog,
                                                      compiled_config)
        else:
            main_prog = server.add_optimizer_pass(main_prog, compiled_config)

        main_prog = server.large_scale_sparse_pass(
            main_prog, main_prog, compiled_config, False)

        # The startup program is derived from the (already built) main one.
        startup_prog = server.build_pserver_startup_program_pass(
            startup_prog, main_prog, compiled_config)
        startup_prog = server.large_scale_sparse_pass(
            startup_prog, main_prog, compiled_config, True)

        # is_sync_mode() is consulted only outside geo mode, as before.
        if not geo_mode and not compiled_config.is_sync_mode():
            main_prog = server.delete_unused_in_main_pass(
                main_prog, compiled_config)

        startup_prog = server.delete_unused_in_startup_pass(
            startup_prog, main_prog, compiled_config)

        return main_prog, startup_prog
Exemplo n.º 2
0
    def _build_pserver_programs(self, compiled_config):
        """Assemble the parameter-server main and startup programs.

        Both programs start out as fresh ``fluid.Program`` objects.

        In geo mode they are run through the listen-and-serv, RPC-flags and
        geo-optimizer passes, followed by the startup build and startup
        clean-up passes.

        Outside geo mode the optimize ops of the origin main program are
        inspected first:

        * if there are none, the two untouched programs are returned;
        * otherwise the LR-decay table pass is applied to the origin main
          program, and if any optimize op has type ``"sgd"`` or ``"adam"``
          the two untouched programs are returned;
        * otherwise the full pass pipeline is executed, with
          ``delete_unused_in_main_pass`` applied only when
          ``compiled_config.is_sync_mode()`` is false.

        Args:
            compiled_config: Object queried through ``is_geo_mode()``,
                ``is_sync_mode()`` and ``get_origin_main_program()`` and
                forwarded unchanged to every pass.

        Returns:
            A ``(main_program, startup_program)`` tuple.
        """
        main_prog = fluid.Program()
        startup_prog = fluid.Program()

        from paddle.fluid.incubate.fleet.parameter_server.ir import pserver_pass as server

        if compiled_config.is_geo_mode():
            # Geo mode: shorter pipeline, no sparse or main clean-up passes.
            main_prog = server.add_listen_and_serv_pass(main_prog,
                                                        compiled_config)
            main_prog = server.add_rpc_global_flags_pass(main_prog,
                                                         compiled_config)
            main_prog = server.add_geo_optimizer_pass(main_prog,
                                                      compiled_config)
            startup_prog = server.build_pserver_startup_program_pass(
                startup_prog, main_prog, compiled_config)
            startup_prog = server.delete_unused_in_startup_pass(
                startup_prog, main_prog, compiled_config)
            return main_prog, startup_prog

        from paddle.fluid.incubate.fleet.parameter_server.ir.public import _get_optimize_ops

        origin_main = compiled_config.get_origin_main_program()
        optimize_ops = _get_optimize_ops(origin_main)

        # Nothing to optimize: hand back the two empty programs.
        if len(optimize_ops) == 0:
            return main_prog, startup_prog

        from paddle.fluid.incubate.fleet.parameter_server.ir.public import _add_lr_decay_table_pass

        decay_steps = self.user_defined_strategy.a_sync_configs[
            "lr_decay_steps"]
        _add_lr_decay_table_pass(origin_main, compiled_config, decay_steps)

        # sgd/adam optimizers also short-circuit to the empty programs, but
        # only after the LR-decay table pass above has been applied.
        if any(op.type in ("sgd", "adam") for op in optimize_ops):
            return main_prog, startup_prog

        main_prog = server.add_listen_and_serv_pass(main_prog, compiled_config)
        main_prog = server.add_rpc_global_flags_pass(main_prog,
                                                     compiled_config)
        main_prog = server.add_optimizer_pass(main_prog, compiled_config)
        main_prog = server.large_scale_sparse_pass(
            main_prog, main_prog, compiled_config, False)

        startup_prog = server.build_pserver_startup_program_pass(
            startup_prog, main_prog, compiled_config)
        startup_prog = server.large_scale_sparse_pass(
            startup_prog, main_prog, compiled_config, True)

        if not compiled_config.is_sync_mode():
            main_prog = server.delete_unused_in_main_pass(
                main_prog, compiled_config)

        startup_prog = server.delete_unused_in_startup_pass(
            startup_prog, main_prog, compiled_config)

        return main_prog, startup_prog