Example #1
def sample(i):
    seed = i
    np.random.seed(seed)
    milp = get_sample(DATASET, 'train', i % LENGTH_MAP[DATASET]['train'])
    mip = SCIPMIPInstance.fromMIPInstance(milp.mip)
    all_integer_vars = []
    feasible_ass = milp.feasible_solution
    for vname, var in mip.varname2var.items():
        if var.vtype() in ['INTEGER', 'BINARY']:
            # strip SCIP's 't_' transformed-variable prefix
            # (str.lstrip('t_') would strip any leading 't'/'_' characters, not just the prefix)
            all_integer_vars.append(vname[len('t_'):] if vname.startswith('t_') else vname)

    K = min(len(all_integer_vars), np.random.randint(20, 50))
    fixed_ass = {
        all_integer_vars[j]: feasible_ass[all_integer_vars[j]]
        for j in np.random.choice(
            len(all_integer_vars), len(all_integer_vars) - K, replace=False)
    }
    model = mip.fix(fixed_ass)
    model.setBoolParam('randomization/permutevars', True)
    model.setIntParam('randomization/permutationseed', seed)
    model.setIntParam('randomization/randomseedshift', seed)
    model.optimize()
    solving_stats = ConfigDict(model.getSolvingStats())
    results = ConfigDict(
        solving_time=solving_stats.solvingtime,
        deterministic_time=solving_stats.deterministictime,
        nnodes=model.getNNodes(),
    )
    return results
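A usage sketch for the sampler above; the aggregation loop and seed count here are illustrative, not from the source:

stats = [sample(i) for i in range(8)]
# results is a ConfigDict, so attribute access works
mean_nodes = sum(s.nnodes for s in stats) / len(stats)
mean_time = sum(s.solving_time for s in stats) / len(stats)
print(f'avg nodes: {mean_nodes:.1f}, avg solving time: {mean_time:.2f}s')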
Example #2
    def __init__(self,
                 id,
                 seed,
                 graph_seed=-1,
                 graph_start_idx=0,
                 n_graphs=1,
                 dataset='milp-facilities-3',
                 dataset_type='train',
                 max_nodes=-1,
                 max_edges=-1,
                 **env_config):
        """If graph_seed < 0, then use the environment seed.
       max_nodes, max_edges -> Use for padding
    """
        self.config = ConfigDict(env_config)
        self.id = id
        self.seed = seed
        self._max_nodes = max_nodes
        self._max_edges = max_edges
        self.set_seed(seed)
        if graph_seed < 0: graph_seed = seed
        self._setup_graph_random_state(graph_seed)
        self._dataset = dataset
        self._dataset_type = dataset_type
        self._n_graphs = n_graphs
        self._graph_start_idx = graph_start_idx

        self.config.update(NORMALIZATION_CONSTANTS[dataset])
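A minimal construction sketch based on the docstring above; the class name Env and the argument values are assumptions for illustration:

# graph_seed < 0 falls back to the environment seed;
# max_nodes / max_edges are only used for padding.
env = Env(id=0,
          seed=42,
          graph_seed=-1,   # uses seed (42) for graph generation
          dataset='milp-facilities-3',
          dataset_type='train',
          max_nodes=-1,
          max_edges=-1)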
Example #3
def get_config():
    config = get_base_config()

    # required fields.
    config.class_path = "liaison.agents.gcn"
    config.class_name = "Agent"

    config.model = ConfigDict()
    config.model.class_path = "liaison.agents.models.gcn_attn_rins"
    config.model.n_prop_layers = 4
    config.model.node_hidden_layer_sizes = [32]
    config.model.edge_hidden_layer_sizes = [32]
    config.model.key_dim = 32
    config.model.value_dim = 32
    config.model.num_heads = 4
    config.model.node_embed_dim = 32
    config.model.edge_embed_dim = 32
    config.query_key_product_hidden_layer_sizes = [16]

    config.clip_rho_threshold = 1.0
    config.clip_pg_rho_threshold = 1.0

    config.loss = ConfigDict()
    config.loss.vf_loss_coeff = 1.0

    return config
Example #4
File: main.py  Project: aravic/liaison
    def setup(self, argv):
        """After reading the component name this function will be called."""

        args = parser.parse_args(args=argv)

        self.args = args
        self.experiment_id = args.experiment_id
        self.work_id = args.work_id
        self.experiment_name = args.experiment_name
        self.batch_size = args.batch_size
        self.traj_length = args.traj_length
        self.seed = args.seed
        self.results_folder = args.results_folder
        self.hyper_params = args.hyper_configs

        self.env_config = ConfigDict(to_nested_dicts(args.env_config))
        self.sess_config = ConfigDict(to_nested_dicts(args.sess_config))
        self.agent_config = ConfigDict(to_nested_dicts(args.agent_config))

        if hasattr(args, 'eval_config'):
            self.eval_config = ConfigDict(to_nested_dicts(args.eval_config))
        else:
            self.eval_config = ConfigDict()

        check_config_compatibility(self.env_config, self.sess_config,
                                   self.agent_config, self.eval_config)
Example #5
  def build_update_ops(self, obs, targets):
    """
    This function will only be called once to create a TF graph which
    will be run repeatedly during training at the learner.

    All the arguments are tf placeholders (or nested structures of placeholders).

    Args:
      obs: [B, ...]
      targets: [B, N], N is the max node size.
    """
    self._validate_observations(obs)
    obs = ConfigDict(**obs)
    with tf.variable_scope(self._name):
      # flatten graph_features
      obs.graph_features = flatten_graphs(
          gn.graphs.GraphsTuple(**obs.graph_features))

      with tf.variable_scope('target_logits'):
        preds, logits_logged_vals = self._model.get_logits_and_next_state(obs)

      with tf.variable_scope('loss'):
        loss = tf.reduce_sum(obs.node_mask * ((preds - targets)**2))
        loss /= tf.reduce_sum(obs.node_mask)

      with tf.variable_scope('optimize'):
        opt_vals = self._optimize(loss)

      with tf.variable_scope('logged_vals'):
        self._logged_values = {
            'loss/supervised_loss': tf.reduce_sum(loss),
            **opt_vals,
            **logits_logged_vals,
            **self._extract_logged_values(obs),
        }
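The loss above is a node-masked mean squared error; a minimal NumPy sketch of the same computation (names are illustrative, not from the source):

import numpy as np

def masked_mse(preds, targets, node_mask):
  # sum of squared errors over real (unmasked) nodes,
  # normalized by the number of real nodes
  return np.sum(node_mask * (preds - targets)**2) / np.sum(node_mask)

# padded nodes (mask == 0) contribute nothing to the loss
preds = np.array([0.5, 1.0, 2.0])
targets = np.array([0.0, 1.0, 0.0])
mask = np.array([1.0, 1.0, 0.0])
assert masked_mse(preds, targets, mask) == 0.125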
Example #6
  def __init__(self,
               id,
               seed,
               graph_start_idx=0,
               n_graphs=1,
               dataset='',
               dataset_type='train',
               k=5,
               n_local_moves=10,
               max_nodes=-1,
               max_edges=-1,
               sample_every_n_resets=1,
               **env_config):
    """k -> Max number of variables to unfix at a time.
            Informally, this is a bound on the local search
            neighbourhood size.
       max_nodes, max_edges -> Use for padding
    """
    self.config = ConfigDict(env_config)
    self.id = id
    self.k = self._original_k = k
    if self.config.k_schedule.enable:
      self.max_k = max(self.config.k_schedule.values)
    else:
      self.max_k = k
    self.max_local_moves = n_local_moves
    self.seed = seed
    if max_nodes < 0:
      max_nodes = NORMALIZATION_CONSTANTS[dataset]['max_nodes']
    if max_edges < 0:
      max_edges = NORMALIZATION_CONSTANTS[dataset]['max_edges']
    self._max_nodes = max_nodes
    self._max_edges = max_edges
    self.set_seed(seed)
    self._dataset = dataset
    self._dataset_type = dataset_type
    self._max_graphs = n_graphs
    self._graph_start_idx = graph_start_idx
    self._sample_every_n_resets = sample_every_n_resets

    if dataset:
      self.config.update(NORMALIZATION_CONSTANTS[dataset])
    # reset() is called at the end of __init__ so obs_spec() works before the caller resets.
    self._ep_return = None
    self._prev_ep_return = np.nan
    self._prev_avg_quality = np.nan
    self._prev_best_quality = np.nan
    self._prev_final_quality = np.nan
    self._prev_mean_work = np.nan
    self._prev_k = np.nan
    self._reset_next_step = True
    if 'SYMPH_PS_SERVING_HOST' in os.environ:
      self._global_step_fetcher = GlobalStepFetcher(min_request_spacing=4)
    else:
      self._global_step_fetcher = None
    # map from sample to length of the mip
    self._sample_lengths = None
    self._n_resets = 0
    self._vars_unfixed_so_far = []
    self.reset()
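The constructor reads config.k_schedule; a plausible shape for that entry, inferred only from how the fields are accessed above (values are hypothetical):

env_config = dict(
    k_schedule=dict(
        enable=True,
        # candidate values for k; max(values) becomes self.max_k when enabled
        values=[5, 10, 20],
    ),
)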
Example #7
def get_config():
    config = get_base_config()

    # required fields.
    config.class_path = "liaison.agents.gcn_multi_actions"
    config.class_name = "Agent"

    config.model = ConfigDict()
    config.model.class_path = 'liaison.agents.models.transformer_auto_regressive'
    config.model.num_blocks = 4
    config.model.d_ff = 32
    config.model.num_heads = 4
    config.model.d_model = 64
    config.model.dropout_rate = 0.
    config.model.use_mlp_value_func = False

    # The following code is duplicated in gcn_rins.py as well;
    # propagate any changes as needed.
    config.model.model_kwargs = ConfigDict()
    config.model.model_kwargs.class_path = "liaison.agents.models.bipartite_gcn_rins"
    config.model.model_kwargs.n_prop_layers = 4
    config.model.model_kwargs.edge_embed_dim = 32
    config.model.model_kwargs.node_embed_dim = 32
    config.model.model_kwargs.global_embed_dim = 32
    config.model.model_kwargs.policy_torso_hidden_layer_sizes = [16, 16]
    config.model.model_kwargs.value_torso_hidden_layer_sizes = [16, 16]
    config.model.model_kwargs.policy_summarize_hidden_layer_sizes = [16]
    config.model.model_kwargs.value_summarize_hidden_layer_sizes = [16]
    config.model.model_kwargs.supervised_prediction_torso_hidden_layer_sizes = [
        16, 16
    ]
    config.model.model_kwargs.sum_aggregation = False
    config.model.model_kwargs.use_layer_norm = True

    config.clip_rho_threshold = 1.0
    config.clip_pg_rho_threshold = 1.0

    config.loss = ConfigDict()
    config.loss.vf_loss_coeff = 1.0

    config.loss.al_coeff = ConfigDict()
    config.loss.al_coeff.init_val = 0.
    config.loss.al_coeff.min_val = 0.
    config.loss.al_coeff.start_decay_step = int(1e10)
    config.loss.al_coeff.decay_steps = 5000
    # dec_val not used for linear scheme
    config.loss.al_coeff.dec_val = .1
    config.loss.al_coeff.dec_approach = 'linear'

    # applicable for agent 'liaison.agents.gcn_large_batch'
    config.apply_grads_every = 1

    config.log_features_every = -1  # disable
    config.freeze_graphnet_weights_step = int(1e9)

    return config
Example #8
    def _get_model_config(self):
        config = ConfigDict()
        if FLAGS.model == 'mlp':
            config.class_path = "liaison.agents.models.mlp"
            config.hidden_layer_sizes = [32, 32]
        elif FLAGS.model == 'gcn':
            config.class_path = "liaison.agents.models.gcn"
        else:
            raise Exception('Unknown model %s' % FLAGS.model)
        return config
Example #9
def get_config():
  config = ConfigDict()
  config.host_names = dict(surreal_tmux='127.0.0.1')

  config.host_info = dict(
      surreal_tmux=dict(base_dir='/home/ubuntu/ml4opt/liaison',
                        use_ssh=False,
                        shell_setup_commands=[],
                        spy_port=4007))

  assert sorted(config.host_names.keys()) == sorted(config.host_info.keys())
  return config
Example #10
    def debatch_and_stack(self):
        traj_len = self._traj_len

        exps = []
        for i, finished_ts in enumerate(self._finished_timesteps):
            chopping_traj = self._chopping_trajs[i]
            for ts in finished_ts:
                if len(chopping_traj) == 0:
                    # this branch is taken only right after reset() is called on the trajectory.
                    chopping_traj.start(
                        next_state=ts['step_output']['next_state'],
                        # remove step_output from ts
                        **ConfigDict(**{
                            k: v
                            for k, v in ts.items() if k != 'step_output'
                        }))
                    assert ts['step_output']['action'] is None
                    assert ts['step_output']['logits'] is None
                    continue

                chopping_traj.add(**ConfigDict(**ts))
                assert ts['step_output']['action'] is not None
                assert len(chopping_traj) <= traj_len + 1

                if len(chopping_traj) == traj_len + 1:
                    # TODO: Add dummy batch dimension and use debatch_and_stack
                    # for uniformity.
                    exps.append(chopping_traj.stack())
                    chopping_traj.reset()
                    chopping_traj.start(
                        next_state=ts['step_output']['next_state'],
                        # remove step_output from ts
                        **ConfigDict(**{
                            k: v
                            for k, v in ts.items() if k != 'step_output'
                        }))

            self._finished_timesteps[i] = []

        def f(path, spec, v):
            if path[0] == 'step_output' and path[1] != 'next_state':
                assert len(v) == traj_len
                return
            assert len(v) == traj_len + 1

        assert all([
            nest.map_structure_with_tuple_paths_up_to(self.spec, f, self.spec,
                                                      exp) for exp in exps
        ])
        return exps
Example #11
File: mlp.py  Project: aravic/liaison
def get_config():
  config = get_base_config()

  # required fields.
  config.class_path = "liaison.agents.mlp"
  config.class_name = "Agent"

  config.model = ConfigDict()
  config.model.class_path = "liaison.agents.models.mlp"
  config.model.hidden_layer_sizes = [32, 32]

  config.loss = ConfigDict()
  config.loss.vf_loss_coeff = 1.0

  return config
Example #12
def get_config():
  config = ConfigDict()
  config.agent_config = ConfigDict()
  config.agent_config.network = ConfigDict()

  config.shell_config = ConfigDict()
  config.shell_config.use_gpu = False

  config.session_config = ConfigDict()
  config.session_config.sync_period = 100

  return config
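As elsewhere in these examples, ConfigDict supports attribute-style access; a small consumption sketch using the defaults set above:

config = get_config()
assert config.shell_config.use_gpu is False
assert config.session_config.sync_period == 100
# nested sections can be overridden before the config is consumed
config.session_config.sync_period = 50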
Example #13
  def __init__(self, id, seed, discount=1.0, graph_seed=-1, **env_config):
    """if graph_seed < 0, then use the environment seed"""
    self.config = ConfigDict(env_config)
    self.id = id
    self.seed = seed
    self.discount = discount
    self.set_seed(seed)

    if graph_seed < 0:
      graph_seed = seed
    # generate graph with 32 nodes.
    nx_graph, self._path = generate_networkx_graph(graph_seed, [32, 33])
    nx_graph = nx_graph.to_directed()
    # max number of steps in an episode.
    self._max_steps = 3 * len(nx_graph)
    self._nx_graph = nx_graph
    self._src_node = self._path[0]
    self._target_node = self._path[-1]
    self._shortest_path_length = sum([
        nx_graph[u][v][DISTANCE_WEIGHT_NAME] for u, v in pairwise(self._path)
    ])
    self._reset_graph_features = self._networkx_to_graph_features(
        nx_graph, self._src_node, self._target_node)
    self._graph_features = copy.deepcopy(self._reset_graph_features)

    self._curr_node = self._src_node
    self._reset_next_step = True
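The path length above relies on a pairwise helper not shown in these snippets; presumably it matches the standard itertools recipe (shown here as an assumption):

from itertools import tee

def pairwise(iterable):
  # s -> (s0, s1), (s1, s2), (s2, s3), ...
  a, b = tee(iterable)
  next(b, None)
  return zip(a, b)

# e.g. a path [0, 3, 7] yields the edges (0, 3) and (3, 7), whose
# DISTANCE_WEIGHT_NAME weights are summed to get the shortest path length.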
Example #14
File: base.py  Project: aravic/liaison
    def __init__(self,
                 seed,
                 evict_interval,
                 compress_before_send,
                 load_balanced=True,
                 index=0,
                 **kwargs):
        self.config = ConfigDict(kwargs)
        self.index = index

        if load_balanced:
            collector_port = os.environ['SYMPH_COLLECTOR_BACKEND_PORT']
            sampler_port = os.environ['SYMPH_SAMPLER_BACKEND_PORT']
        else:
            collector_port = os.environ['SYMPH_COLLECTOR_FRONTEND_PORT']
            sampler_port = os.environ['SYMPH_SAMPLER_FRONTEND_PORT']
        self._collector_server = ExperienceCollectorServer(
            host='localhost' if load_balanced else '*',
            port=collector_port,
            exp_handler=self._insert_wrapper,
            load_balanced=load_balanced,
            compress_before_send=compress_before_send)
        self._sampler_server = ZmqServer(
            host='localhost' if load_balanced else '*',
            port=sampler_port,
            bind=not load_balanced,
            serializer=get_serializer(compress_before_send),
            deserializer=get_deserializer(compress_before_send))
        self._sampler_server_thread = None

        self._evict_interval = evict_interval
        self._evict_thread = None

        self._setup_logging()
Example #15
    def __init__(self, name, action_spec, seed, model=None, **kwargs):

        self.set_seed(seed)
        self.config = ConfigDict(**kwargs)
        self._name = name
        self._action_spec = action_spec
        self._load_model(name, action_spec=action_spec, **(model or {}))
Example #16
  def _pad_graph_features(self, features: dict):
    features = ConfigDict(**features)
    features.update(nodes=pad_first_dim(features.nodes, self._max_nodes),
                    edges=pad_first_dim(features.edges, self._max_edges),
                    senders=pad_first_dim(features.senders, self._max_edges),
                    receivers=pad_first_dim(features.receivers, self._max_edges))
    return dict(**features)
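pad_first_dim is not defined in these snippets; a plausible NumPy sketch, assuming it zero-pads the leading dimension up to a fixed size:

import numpy as np

def pad_first_dim(arr, size):
  # zero-pad the first dimension up to `size`; leave longer arrays unchanged
  pad = size - arr.shape[0]
  if pad <= 0:
    return arr
  return np.pad(arr, [(0, pad)] + [(0, 0)] * (arr.ndim - 1))

# nodes of shape [n_nodes, d] become [max_nodes, d];
# senders/receivers of shape [n_edges] become [max_edges].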
Example #17
File: rins_v2.py  Project: aravic/liaison
    def _init_ds(self):
        # Initialize data structures
        milp = self.milp
        self._ep_return = 0
        self._n_steps = 0
        self._n_local_moves = 0
        self._reset_next_step = False
        self._mip_works = []
        # mip stats for the current step
        self._mip_stats = ConfigDict(mip_work=0,
                                     n_cuts=0,
                                     n_cuts_applied=0,
                                     n_lps=0,
                                     solving_time=0.,
                                     pre_solving_time=0.,
                                     time_elapsed=0.)

        self._varnames2varidx = {
            var_name: i
            for i, var_name in enumerate(self._var_names)
        }
        # optimal solution can be used for supervised auxiliary tasks.
        self._optimal_soln = np.float32(
            [milp.optimal_solution[v] for v in self._var_names])
        self._optimal_lp_soln = np.float32(
            [milp.optimal_lp_sol[v] for v in self._var_names])

        globals_ = np.zeros(Env.N_GLOBAL_FIELDS, dtype=np.float32)
        globals_[Env.GLOBAL_STEP_NUMBER] = self._n_steps / np.sqrt(
            self.k * self.max_local_moves)
        globals_[Env.GLOBAL_UNFIX_LEFT] = self.k
        globals_[Env.GLOBAL_N_LOCAL_MOVES] = self._n_local_moves
        self._globals = globals_
        self._n_steps_in_this_local_move = 0
        self._set_stop_switch_mask()
Example #18
  def _scip_solve(self, solver):
    """solves a mip/lp using scip"""
    if solver is None:
      solver = Model()
    solver.hideOutput()
    if self.config.disable_maxcuts:
      for param in [
          'separating/maxcuts', 'separating/maxcutsroot', 'propagating/maxrounds',
          'propagating/maxroundsroot', 'presolving/maxroundsroot'
      ]:
        solver.setIntParam(param, 0)

      solver.setBoolParam('conflict/enable', False)
      solver.setPresolve(SCIP_PARAMSETTING.OFF)

    solver.setBoolParam('randomization/permutevars', True)
    # seed is set to 0 permanently.
    solver.setIntParam('randomization/permutationseed', 0)
    solver.setIntParam('randomization/randomseedshift', 0)

    with U.Timer() as timer:
      solver.optimize()
    assert solver.getStatus() == 'optimal', solver.getStatus()
    obj = float(solver.getObjVal())
    ass = {var.name: solver.getVal(var) for var in solver.getVars()}
    mip_stats = ConfigDict(mip_work=solver.getNNodes(),
                           n_cuts=solver.getNCuts(),
                           n_cuts_applied=solver.getNCutsApplied(),
                           n_lps=solver.getNLPs(),
                           solving_time=solver.getSolvingTime(),
                           pre_solving_time=solver.getPresolvingTime(),
                           time_elapsed=timer.to_seconds())
    return ass, obj, mip_stats
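A usage sketch for the solver call above, assuming pyscipopt and an instance file on disk (the path is illustrative):

from pyscipopt import Model

# from inside an Env method:
solver = Model()
solver.readProblem('instance.lp')  # illustrative path
ass, obj, mip_stats = self._scip_solve(solver)
# ass: {var_name: value}, obj: optimal objective,
# mip_stats: ConfigDict of solver counters such as mip_work and n_lps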
Example #19
def get_config():
    config = get_base_config()

    # required fields.
    config.class_path = "liaison.agents.mlp"
    config.class_name = "Agent"

    config.model = ConfigDict()
    config.model.class_path = "liaison.agents.models.transformer_rins"

    config.clip_rho_threshold = 1.0
    config.clip_pg_rho_threshold = 1.0

    config.loss = ConfigDict()
    config.loss.vf_loss_coeff = 1.0

    return config
Example #20
def get_config():
    config = get_base_config()

    # required fields.
    config.class_path = "liaison.agents.gcn"
    config.class_name = "Agent"

    # The following code is duplicated in gcn_ar.py as well;
    # propagate any changes as needed.
    config.model = ConfigDict()
    config.model.class_path = "liaison.agents.models.gcn_rins"
    config.model.n_prop_layers = 4
    config.model.edge_embed_dim = 16
    config.model.node_embed_dim = 16
    config.model.global_embed_dim = 16
    config.model.node_hidden_layer_sizes = [16]
    config.model.edge_hidden_layer_sizes = [16]
    config.model.policy_torso_hidden_layer_sizes = [16, 16]
    config.model.value_torso_hidden_layer_sizes = [16, 16]
    config.model.policy_summarize_hidden_layer_sizes = [16]
    config.model.value_summarize_hidden_layer_sizes = [16]
    config.model.supervised_prediction_torso_hidden_layer_sizes = [16, 16]

    config.model.sum_aggregation = False
    config.model.use_layer_norm = True

    config.clip_rho_threshold = 1.0
    config.clip_pg_rho_threshold = 1.0

    config.loss = ConfigDict()
    config.loss.vf_loss_coeff = 1.0

    config.loss.al_coeff = ConfigDict()
    config.loss.al_coeff.init_val = 0.
    config.loss.al_coeff.min_val = 0.
    config.loss.al_coeff.start_decay_step = int(1e10)
    config.loss.al_coeff.decay_steps = 5000
    # dec_val not used for linear scheme
    config.loss.al_coeff.dec_val = .1
    config.loss.al_coeff.dec_approach = 'linear'

    # applicable for agent 'liaison.agents.gcn_large_batch'
    config.apply_grads_every = 1
    config.choose_stop_switch = False

    return config
Example #21
File: gcn.py  Project: aravic/liaison
def get_config():
    config = get_base_config()

    # required fields.
    config.class_path = "liaison.agents.gcn"
    config.class_name = "Agent"

    config.model = ConfigDict()
    config.model.class_path = "liaison.agents.models.gcn"
    config.model.n_prop_layers = 8
    config.model.node_embed_dim = 32

    config.clip_rho_threshold = 1.0
    config.clip_pg_rho_threshold = 1.0

    config.loss = ConfigDict()
    config.loss.vf_loss_coeff = 1.0

    return config
Example #22
    def run_evaluator(self, id: str):
        env_config, sess_config, agent_config = (self.env_config,
                                                 self.sess_config,
                                                 self.agent_config)
        eval_config = self.eval_config

        agent_class = U.import_obj(agent_config.class_name,
                                   agent_config.class_path)
        shell_class = U.import_obj(sess_config.shell.class_name,
                                   sess_config.shell.class_path)
        env_class = U.import_obj(env_config.class_name, env_config.class_path)
        agent_config = copy.deepcopy(agent_config)
        agent_config.update(evaluation_mode=True)
        shell_config = dict(agent_class=agent_class,
                            agent_config=agent_config,
                            **self.sess_config.shell)

        env_configs = []

        for i in range(eval_config.batch_size):
            env_config = ConfigDict(**self.env_config)
            env_config.update({
                eval_config.dataset_type_field: id,
                'graph_start_idx': i,
                **eval_config.env_config
            })
            env_configs.append(env_config)

        evaluator_config = dict(
            shell_class=shell_class,
            shell_config=shell_config,
            env_class=env_class,
            env_configs=env_configs,
            loggers=self._setup_evaluator_loggers(id),
            heuristic_loggers=self._setup_evaluator_loggers(f'heuristic-{id}'),
            seed=self.seed,
            **eval_config)
        from liaison.distributed import Evaluator
        evaluator = Evaluator(**evaluator_config)
        t = evaluator.get_heuristic_thread()
        t.start()
        evaluator.run_loop(int(1e9))
        t.join()
Example #23
def get_shell_config():
    config = ConfigDict()
    agent_config = get_agent_config()
    # shell class path defaults to the distributed folder.
    config.class_path = 'liaison.distributed.shell_for_test'
    config.class_name = 'Shell'
    config.agent_scope = 'shell'
    config.use_gpu = True
    config.agent_class = U.import_obj(agent_config.class_name,
                                      agent_config.class_path)
    config.agent_config = agent_config
    config.agent_config.update(evaluation_mode=True)
    return config
Example #24
  def __init__(self, name, action_spec, seed, model=None, choose_stop_switch=False, **kwargs):

    self.set_seed(seed)
    self.config = ConfigDict(**kwargs)
    self._name = name
    self._action_spec = action_spec
    self.choose_stop_switch = choose_stop_switch
    self._load_model(name,
                     action_spec=action_spec,
                     choose_stop_switch=choose_stop_switch,
                     **(model or {}))
Example #25
File: tsp.py  Project: aravic/liaison
def get_config():
    config = ConfigDict()

    # required fields.
    config.class_path = "liaison.env.tsp"  # should be rel to the parent directory.
    config.class_name = "Env"

    # makes observations suitable for the MLP model.
    config.make_obs_for_mlp = False
    """if graph_seed < 0, then use the environment seed"""
    config.graph_seed = 42

    config.dataset = 'tsp-20'
    config.dataset_type = 'train'
    config.graph_idx = 0

    return config
Example #26
  def __init__(self, name, action_spec, seed, model=None, **kwargs):

    self.set_seed(seed)
    self.config = ConfigDict(**kwargs)
    self._name = name
    self._action_spec = action_spec
    self._load_model(name, action_spec=action_spec, **(model or {}))
    self._global_step = tf.train.get_or_create_global_step()
    self._total_steps = tf.Variable(0,
                                    trainable=False,
                                    collections=[tf.GraphKeys.LOCAL_VARIABLES],
                                    name='total_steps')
Example #27
  def batch_and_preprocess_trajs(self, l):
    traj = Trajectory.batch(l, self._traj_spec)
    # run the agent's update-time preprocessing on the batched trajectory and
    # overwrite the corresponding fields in place.
    (traj['step_output'], traj['step_output']['next_state'], traj['step_type'],
     traj['reward'], traj['observation'],
     traj['discount']) = self._agent.update_preprocess(
         step_outputs=ConfigDict(traj['step_output']),
         prev_states=traj['step_output']['next_state'],
         step_types=traj['step_type'],
         rewards=traj['reward'],
         observations=traj['observation'],
         discounts=traj['discount'])
    return traj
Example #28
File: worker.py  Project: aravic/liaison
    def __init__(self, serving_host, serving_port, checkpoint_folder,
                 profile_folder, kvstream_folder, **kwargs):
        Thread.__init__(self)
        self.config = ConfigDict(**kwargs)
        self.checkpoint_folder = checkpoint_folder
        self.profile_folder = profile_folder
        self.kvstream_folder = kvstream_folder
        self.serving_host = serving_host
        self.serving_port = serving_port

        # Attributes
        self._server = None
Example #29
File: cluster.py  Project: aravic/liaison
def get_config():
    config = ConfigDict()
    config.host_names = dict(
        surreal_tmux='ec2-52-14-254-34.us-east-2.compute.amazonaws.com',
        surreal_tmux2='ec2-13-58-35-146.us-east-2.compute.amazonaws.com')

    config.host_info = dict(
        surreal_tmux=dict(base_dir='/home/ubuntu/ml4opt/liaison',
                          use_ssh=False,
                          shell_setup_commands=[],
                          spy_port=4007),
        surreal_tmux2=dict(
            base_dir='/home/ubuntu/ml4opt/liaison',
            use_ssh=True,
            ssh_username='******',
            ssh_key_file='/home/ubuntu/.ssh/temp',
            shell_setup_commands=[
                'source ${HOME}/.bashrc && cd ${HOME}/nfs/liaison'
            ],
            spy_port=4007))

    assert sorted(config.host_names.keys()) == sorted(config.host_info.keys())
    return config
Example #30
  def __init__(self,
               name,
               action_spec,
               seed,
               mlp_model,
               model,
               evaluation_mode=False,
               **kwargs):

    self.set_seed(seed)
    self.config = ConfigDict(evaluation_mode=evaluation_mode, **kwargs)
    self._name = name
    self._action_spec = action_spec
    mlp_model.update(action_spec=action_spec)
    self._load_models(name, model, mlp_model)