Code example #1
	def one_step_sample(self, one_step_token, init_cell_state, domain=None):
		action_probs, cell_state = self.sess.run([self.update_ops[domain_key("one_step_action_probs", domain)],
												  self.update_ops[domain_key("one_step_cell_state", domain)]],
												 feed_dict={
													 self.inputs[domain_key("one_step_token", domain)]: one_step_token,
													 self.inputs[domain_key("init_cell_state", domain)]: init_cell_state
												 })
		return action_probs, np.asarray(cell_state)
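
Read together with Code example #6 below, this suggests a token-by-token decoding loop: given_input_encode produces an initial cell state from an encoded network sequence, and one_step_sample then advances the decoder one token at a time. The following is only a hypothetical usage sketch; the instance name agent, the start_token value, max_decode_steps, and the sample_from helper are assumptions, not part of the original source.

# Hypothetical usage sketch: encode the parent network once, then decode step by step.
_, cell_state = agent.given_input_encode(net_seq, net_seq_len)
token = start_token                        # assumed start symbol for the decoder
for _ in range(max_decode_steps):          # assumed decoding horizon
    probs, cell_state = agent.one_step_sample(token, cell_state)
    token = sample_from(probs)             # hypothetical sampling helper
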
Code example #2
	def nas_reinforce(self, net_seq, net_seq_len, reward, probs_mask, action, action_start, domain=None):
		self.sess.run(self.update_ops[domain_key("reinforce", domain)], feed_dict={
			self.inputs["input_seq"]: net_seq,
			self.inputs["seq_len"]: net_seq_len,
			self.inputs[domain_key("reward", domain)]: reward,
			self.inputs[domain_key("probs_mask", domain)]: probs_mask,
			self.inputs[domain_key("action", domain)]: action,
			self.inputs[domain_key("action_start", domain)]: action_start,
		})
Code example #3
File: rnn_seq2v.py  Project: yenchih/RL4AS_NetTrans
	def _seq2v_train_loop_(self, train_data_gen, valid_data_maker, statistics_loop, domain=None, save_config=None):
		RUNNING_FLAG = False
		self.sess.run(self.initializers[domain_key("train", domain)])
		for i, (input_seq, seq_len, labels) in zip(range(statistics_loop), train_data_gen):
			RUNNING_FLAG = True
			self.sess.run(self.update_ops[domain_key("train", domain)], feed_dict={
				self.inputs["input_seq"]: input_seq,
				self.inputs["is_training"]: True,
				self.inputs["seq_len"]: seq_len,
				self.inputs[domain_key("labels", domain)]: labels
			})
		if RUNNING_FLAG:
			global_step = self.sess.run(self.overheads["global_step"])
			self.summary_writer.add_summary(self.sess.run(self.summaries[domain_key("train", domain)]), global_step)
			
			self.sess.run(self.initializers[domain_key("validate", domain)])
			for input_seq, seq_len, labels in valid_data_maker():
				self.sess.run(self.update_ops[domain_key("validate", domain)], feed_dict={
					self.inputs["input_seq"]: input_seq,
					self.inputs["is_training"]: False,
					self.inputs["seq_len"]: seq_len,
					self.inputs[domain_key("labels", domain)]: labels
				})
			print("validation metric: {}.".format(self.sess.run(self.overheads[domain_key("validate_metric", domain)])))
			self.summary_writer.add_summary(self.sess.run(self.summaries[domain_key("validate", domain)]), global_step)
			if save_config:
				save_model_path, save_step_size = save_config
				if global_step % save_step_size == 0:
					self.save_model(save_model_path)
		return RUNNING_FLAG
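
_seq2v_train_loop_ returns RUNNING_FLAG, which is False only when train_data_gen yields no batches, so an outer driver can simply call it repeatedly until the generator is exhausted. A minimal driver sketch, assuming a model instance named model and already-constructed data generators (all names and values here are assumptions):

# Hypothetical outer loop: each call trains up to statistics_loop batches, then validates.
while model._seq2v_train_loop_(train_gen, make_valid_data, statistics_loop=100,
                               save_config=("./checkpoints/seq2v", 1000)):
    pass
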
Code example #4
File: rnn_seq2v.py  Project: yenchih/RL4AS_NetTrans
	def seq2v_test(self, test_data_generator, domain=None):
		self.sess.run(self.initializers[domain_key("test", domain)])
		for input_seq, seq_len, labels in test_data_generator:
			self.sess.run(self.update_ops[domain_key("test", domain)], feed_dict={
				self.inputs["input_seq"]: input_seq,
				self.inputs["is_training"]: False,
				self.inputs["seq_len"]: seq_len,
				self.inputs[domain_key("labels", domain)]: labels
			})
		print("test metric: {}.".format(self.sess.run(self.overheads[domain_key("test_metric", domain)])))
		self.summary_writer.add_summary(self.sess.run(self.summaries[domain_key("test", domain)]))
Code example #5
File: rnn_seq2v.py  Project: yenchih/RL4AS_NetTrans
	def seq2v_query(self, input_seq, seq_len, domain=None):
		pVals = self.sess.run(self.overheads[domain_key("predictions", domain)], feed_dict={
			self.inputs["input_seq"]: input_seq,
			self.inputs["is_training"]: False,
			self.inputs["seq_len"]: seq_len
		})
		return pVals
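
seq2v_query runs only the prediction op, with is_training pinned to False, so it can score candidate architectures outside the training loop. A hypothetical call, assuming a trained model instance and already-encoded inputs (all names are assumptions):

# Hypothetical query: batch of encoded network sequences in, predicted values out.
predicted_values = model.seq2v_query(encoded_net_seqs, encoded_seq_lens)
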
Code example #6
	def given_input_encode(self, net_seq, net_seq_len, domain=None):
		action_probs, cell_state = self.sess.run([self.update_ops[domain_key("given_input_action_probs", domain)],
												  self.update_ops["cell_state"]], feed_dict={
			self.inputs["input_seq"]: net_seq,
			self.inputs["seq_len"]: net_seq_len
		})
		return action_probs, np.asarray(cell_state)
Code example #7
    def build_graph(self):
        encoder_outputs, _, _ = self.build_encoder()

        states = tf.stack(encoder_outputs,
                          axis=0)  # (num_steps, states_num, units)
        states = tf.transpose(states,
                              [1, 0, 2])  # (states_num, num_steps, units)

        self.inputs["state_seg"] = tf.placeholder(tf.int32,
                                                  shape=(),
                                                  name="state_seg")
        n2w_states = states[:self.inputs[
            "state_seg"]]  # (n2w_states_num, num_steps, units)
        n2d_states = states[
            self.inputs["state_seg"]:]  # (n2d_states_num, num_steps, units)

        if self.config.get("mode") is None:
            n2w_build_train, n2d_build_train, n2n_build_train = False, False, False
        elif self.config["mode"] == "net2wider":
            n2w_build_train, n2d_build_train, n2n_build_train = True, False, False
        elif self.config["mode"] == "net2deeper":
            n2w_build_train, n2d_build_train, n2n_build_train = False, True, False
        else:
            n2w_build_train, n2d_build_train, n2n_build_train = False, False, True

        n2w_obj = self.net2wider_decoder(n2w_states,
                                         "net2wider",
                                         build_train=n2w_build_train)
        n2d_obj = self.net2deeper_decoder(n2d_states,
                                          "net2deeper",
                                          build_train=n2d_build_train)
        if n2n_build_train:
            optimizer = TFUtils.build_optimizer(self.config["optimizer"])
            domain = "net2net"
            with tf.variable_scope("Net2Net"):
                n2n_obj = n2w_obj + n2d_obj
                self.inputs[domain_key("episode_num",
                                       domain)] = tf.placeholder(
                                           tf.float32, (), name="episode_num")
                self.overheads[domain_key(
                    "loss", domain)] = -n2n_obj / self.inputs[domain_key(
                        "episode_num", domain)]
                self.update_ops[domain_key("reinforce", domain)] = \
                 optimizer.minimize(self.overheads[domain_key("loss", domain)], self.overheads["global_step"])
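
A note on the control flow above: the shared encoder states are split at state_seg, with rows before the split feeding the net2wider decoder and rows after it feeding the net2deeper decoder. config["mode"] decides which training path is built; "net2wider" and "net2deeper" each train one decoder, while any other non-None value builds the joint net2net path, which negates the sum of both decoder objectives, scaled by 1/episode_num, to form a single REINFORCE loss minimized by the configured optimizer.
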
Code example #8
 def net2deeper_sample_action(self,
                              net_seq,
                              net_seq_len,
                              domain=None,
                              _random=False):
     place_probs, param_probs = self.sess.run(
         [
             self.update_ops[domain_key("place_probs", domain)],
             self.update_ops[domain_key("param_probs", domain)]
         ],
         feed_dict={
             self.inputs["input_seq"]: net_seq,
             self.inputs["seq_len"]: net_seq_len,
             self.inputs["state_seg"]: 0,
         })
     if _random:
         place_probs = np.ones_like(place_probs)
         param_probs = np.ones_like(param_probs)
     return place_probs, param_probs
Code example #9
 def net2wider_sample_action(self,
                             net_seq,
                             net_seq_len,
                             valid_action,
                             domain=None,
                             _random=False):
     action_probs = self.sess.run(
         self.update_ops[domain_key("action_probs", domain)],
         feed_dict={
             self.inputs["input_seq"]: net_seq,
             self.inputs["seq_len"]: net_seq_len,
             self.inputs["state_seg"]: len(net_seq),
             self.inputs[domain_key("valid_action", domain)]: valid_action
         })
     if _random:
         action = np.random.randint(0, 2, action_probs.shape)
     else:
         action = np.random.random_sample(
             action_probs.shape) <= action_probs
     return action.astype(np.int32)
Code example #10
 def net2wider_reinforce(self,
                         net_seq,
                         net_seq_len,
                         action,
                         action_mask,
                         valid_action,
                         reward,
                         episode_num,
                         domain=None):
     self.sess.run(self.update_ops[domain_key("reinforce", domain)],
                   feed_dict={
                       self.inputs["input_seq"]:
                       net_seq,
                       self.inputs["seq_len"]:
                       net_seq_len,
                       self.inputs["state_seg"]:
                       len(net_seq),
                       self.inputs[domain_key("valid_action", domain)]:
                       valid_action,
                       self.inputs[domain_key("reward", domain)]:
                       reward,
                       self.inputs[domain_key("action", domain)]:
                       action,
                       self.inputs[domain_key("action_mask", domain)]:
                       action_mask,
                       self.inputs[domain_key("episode_num", domain)]:
                       episode_num
                   })
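
An end-to-end Net2Wider update pairs the sampler from Code example #9 with this method. The sketch below is hypothetical usage only; the agent instance and the apply_wider_action / evaluate_reward helpers are assumptions, and the real masks come from the surrounding search loop.

# Hypothetical REINFORCE episode for the Net2Wider decoder.
# net_seq, net_seq_len, valid_action and action_mask are assumed to come from the search loop.
actions = agent.net2wider_sample_action(net_seq, net_seq_len, valid_action)
child_net = apply_wider_action(parent_net, actions)   # hypothetical helper: widen the parent network
reward = evaluate_reward(child_net)                    # hypothetical helper: train/evaluate the child
agent.net2wider_reinforce(net_seq, net_seq_len, actions, action_mask,
                          valid_action, reward, episode_num=1)
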
Code example #11
 def net2deeper_reinforce(self,
                          net_seq,
                          net_seq_len,
                          place_action,
                          place_probs_mask,
                          param_action,
                          param_probs_mask,
                          reward,
                          episode_num,
                          place_loss_mask,
                          param_loss_mask,
                          domain=None):
     self.sess.run(self.update_ops[domain_key("reinforce", domain)],
                   feed_dict={
                       self.inputs["input_seq"]:
                       net_seq,
                       self.inputs["seq_len"]:
                       net_seq_len,
                       self.inputs["state_seg"]:
                       0,
                       self.inputs[domain_key("reward", domain)]:
                       reward,
                       self.inputs[domain_key("place_action", domain)]:
                       place_action,
                       self.inputs[domain_key("place_probs_mask", domain)]:
                       place_probs_mask,
                       self.inputs[domain_key("param_action", domain)]:
                       param_action,
                       self.inputs[domain_key("param_probs_mask", domain)]:
                       param_probs_mask,
                       self.inputs[domain_key("episode_num", domain)]:
                       episode_num,
                       self.inputs[domain_key("place_loss_mask", domain)]:
                       place_loss_mask,
                       self.inputs[domain_key("param_loss_mask", domain)]:
                       param_loss_mask,
                   })
Code example #12
 def net2net_reinforce(self,
                       n2w_net_seq,
                       n2w_net_seq_len,
                       n2w_action,
                       n2w_action_mask,
                       n2w_valid_action,
                       n2w_reward,
                       n2d_net_seq,
                       n2d_net_seq_len,
                       n2d_place_action,
                       n2d_place_probs_mask,
                       n2d_param_action,
                       n2d_param_probs_mask,
                       n2d_reward,
                       n2d_place_loss_mask,
                       n2d_param_loss_mask,
                       episode_num,
                       n2w_domain=None,
                       n2d_domain=None,
                       n2n_domain=None):
     state_seg = len(n2w_net_seq)
     net_seq = np.concatenate([n2w_net_seq, n2d_net_seq], axis=0)
     net_seq_len = np.concatenate([n2w_net_seq_len, n2d_net_seq_len],
                                  axis=0)
     self.sess.run(
         self.update_ops[domain_key("reinforce", n2n_domain)],
         feed_dict={
             self.inputs["input_seq"]: net_seq,
             self.inputs["seq_len"]: net_seq_len,
             self.inputs["state_seg"]: state_seg,
             self.inputs[domain_key("valid_action", n2w_domain)]:
             n2w_valid_action,
             self.inputs[domain_key("reward", n2w_domain)]: n2w_reward,
             self.inputs[domain_key("action", n2w_domain)]: n2w_action,
             self.inputs[domain_key("action_mask", n2w_domain)]:
             n2w_action_mask,
             self.inputs[domain_key("reward", n2d_domain)]: n2d_reward,
             self.inputs[domain_key("place_action", n2d_domain)]:
             n2d_place_action,
             self.inputs[domain_key("place_probs_mask", n2d_domain)]:
             n2d_place_probs_mask,
             self.inputs[domain_key("param_action", n2d_domain)]:
             n2d_param_action,
             self.inputs[domain_key("param_probs_mask", n2d_domain)]:
             n2d_param_probs_mask,
             self.inputs[domain_key("place_loss_mask", n2d_domain)]:
             n2d_place_loss_mask,
             self.inputs[domain_key("param_loss_mask", n2d_domain)]:
             n2d_param_loss_mask,
             self.inputs[domain_key("episode_num", n2n_domain)]:
             episode_num,
         })
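
net2net_reinforce concatenates the Net2Wider and Net2Deeper batches along the batch axis and passes state_seg = len(n2w_net_seq), so that build_graph (Code example #7) can split the shared encoder states back into the two decoder branches; a single run of the net2net "reinforce" op then updates both decoders from the combined loss.
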