Example #1
    def __init__(self,
                 B,
                 T,
                 g_E,
                 g_H,
                 d_E,
                 d_H,
                 d_dropout,
                 path_pos,
                 path_neg,
                 g_lr=1e-3,
                 d_lr=1e-3,
                 n_sample=10000,
                 generate_samples=5,
                 init_eps=0.1):

        self.B, self.T = B, T
        self.g_E, self.g_H = g_E, g_H
        self.d_E, self.d_H = d_E, d_H
        self.d_dropout = d_dropout
        self.generate_samples = generate_samples
        self.g_lr, self.d_lr = g_lr, d_lr
        self.eps = init_eps
        self.init_eps = init_eps
        self.top = os.getcwd()
        self.path_pos = path_pos
        self.path_neg = path_neg

        self.g_data = GeneratorPretrainingGenerator(
            self.path_pos, B=B, T=T, min_count=1
        )  # next() yields x and y_true built from the same sentence, e.g. [BOS, 8, 10, 6, 3, EOS] predicting [8, 10, 6, 3, EOS]
        self.d_data = DiscriminatorGenerator(
            path_pos=self.path_pos,
            path_neg=self.path_neg,
            B=self.B,
            shuffle=True)  # next() yields positive and negative samples

        self.V = self.g_data.V

        self.agent = Agent(B, self.V, g_E, g_H, g_lr)
        self.g_beta = Agent(B, self.V, g_E, g_H, g_lr)

        self.discriminator = Discriminator(self.V, d_E, d_H, d_dropout)

        self.env = Environment(self.discriminator,
                               self.g_data,
                               self.g_beta,
                               n_sample=n_sample)

        self.generator_pre = GeneratorPretraining(self.V, g_E, g_H)
        print("para is ", self.V, g_E, g_H)
Example #2
    def __init__(self,
                 B,
                 T,
                 g_E,
                 g_H,
                 d_E,
                 d_H,
                 d_dropout,
                 g_lr=1e-3,
                 d_lr=1e-3,
                 n_sample=16,
                 generate_samples=20,
                 init_eps=0.1,
                 real_packets_file="real_packet_sizes.txt"):
        self.top = os.getcwd()
        self.B, self.T = B, T
        self.g_E, self.g_H = g_E, g_H
        self.d_E, self.d_H = d_E, d_H
        self.d_dropout = d_dropout
        self.generate_samples = generate_samples
        self.g_lr, self.d_lr = g_lr, d_lr
        self.eps = init_eps
        self.init_eps = init_eps
        self.real_packets = real_packets_file
        self.pos_sequences, self.V = extract_all(real_packets_file)
        self.all_signatures = signatureExtractionAll(self.pos_sequences, 2, 6,
                                                     5, 4)
        self.pos_sequences = split_sequences(self.pos_sequences, T - 2)
        self.neg_sequences = []
        self.agent = Agent(B, self.V, g_E, g_H, g_lr)
        self.g_beta = Agent(B, self.V, g_E, g_H, g_lr)
        self.discriminator = Discriminator(self.V, d_E, d_H, d_dropout)
        self.signature_discriminator = SignatureDiscriminator(signatureCount(
            self.all_signatures),
                                                              H=200)
        self.env = Environment(self.discriminator,
                               self.signature_discriminator,
                               self.all_signatures, self.B, self.V,
                               self.g_beta, n_sample)
        self.generator_pre = GeneratorPretraining(self.V, g_E, g_H)
        self.g_pre_path, self.d_pre_path = None, None
Example #3
    def __init__(self, run_name, state_file, actions, input_shape, n_variables,
                 start_epsilon, complexities, learning_steps_per_epoch,
                 testing_episodes, max_complexity, ticks, still_action,
                 still_ticks, log_file, load_net_file, save_net_file,
                 save_net_file_format):
        self.actions = actions
        self.state_file = state_file
        self.run_name = run_name
        self.log_file = log_file or run_name + ".log"
        self.load_net_file = load_net_file or run_name + ".npy"
        self.save_net_file = save_net_file
        self.save_net_file_format = save_net_file or save_net_file_format

        self.agent = Agent(len(self.actions),
                           input_shape,
                           n_variables=n_variables,
                           start_epsilon=start_epsilon)

        self.tester = Tester(self.actions,
                             input_shape,
                             n_variables,
                             ticks=ticks,
                             still_action=still_action,
                             still_ticks=still_ticks,
                             episodes=testing_episodes)
        self.tutor = Tutor(self.agent,
                           self.actions,
                           input_shape,
                           n_variables,
                           ticks=ticks,
                           still_action=still_action,
                           still_ticks=still_ticks)

        self.epochs_done = 0
        self.complexities = complexities
        self.max_complexity = max_complexity
        self.complexity_done = -1
        self.learning_steps_per_epoch = learning_steps_per_epoch
Example #4
def train_mario_model(epochs=500,
                      gamma=0.9,
                      learning_rate=0.004,
                      memory_size=100000,
                      sample_size=32,
                      level=LEVELS[0],
                      actions_history_size=16,
                      frame_history_size=4):
    working_dir = create_working_dir(epochs, gamma, learning_rate, memory_size,
                                     sample_size, actions_history_size,
                                     frame_history_size)

    env = gym.make(level)

    env = CustomEnv(env,
                    frame_width=84,
                    frame_height=84,
                    history_width=32,
                    history_height=32,
                    actions_history_size=actions_history_size,
                    frame_history_size=frame_history_size)

    model = build_model(actions_history_size=actions_history_size,
                        frame_history_size=frame_history_size,
                        learning_rate=learning_rate)

    agent = Agent(model, gamma=gamma)
    memory = ExperienceReplay(max_size=memory_size,
                              sample_size=sample_size,
                              database_file='memory.db',
                              should_pop_oldest=True,
                              reuse_db=False,
                              verbose=True)
    policy = RandomPolicy(action_mapper,
                          epsilon=1.,
                          epsilon_decay_step=0.00001,
                          epsilon_min=0.05,
                          dropout=0.01)

    train(agent,
          env,
          policy,
          memory,
          epochs=epochs,
          test_interval=5,
          working_dir=working_dir)
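
A hedged invocation sketch for the function above; every value shown is an illustrative assumption, and LEVELS is assumed to be the module's list of Gym level ids:

# Short smoke-test run with made-up hyperparameters.
train_mario_model(epochs=50,
                  gamma=0.95,
                  learning_rate=0.001,
                  memory_size=20000,
                  sample_size=32,
                  level=LEVELS[0])
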
Example #5
def main():
    model_path = '../rl/reinforce/models/model_045500_537.pth'
    agent = Agent(LARGEST_CARD, HAND_SIZE, model_path)
    game = Game(N_PLAYERS, LARGEST_CARD, HAND_SIZE, N_ROUNDS, agent)

    scores_per_game = []
    for i in trange(N_GAMES):
        scores = game.play_game()
        scores_per_game.append(scores)

    # initially [n_games, n_players, n_rounds]
    weights = [GAMMA**i for i in range(N_ROUNDS)]
    scores_per_player_per_game = np.asarray(scores_per_game).transpose(1, 0, 2)
    scores_per_player = np.average(scores_per_player_per_game, axis=2, weights=weights) * sum(weights)
    cum_scores = np.cumsum(scores_per_player, axis=1)
    
    plot_cum_scores(cum_scores, title='Strategy: best vs. random')
    plot_reward_freq(scores_per_player_per_game,
                     player_index=0, round_index=-1)
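
The scoring step above is worth unpacking: since np.average(x, weights=w) returns sum(w*x)/sum(w), multiplying by sum(weights) recovers the plain discounted return sum_i GAMMA**i * score_i for each player and game. A minimal numeric sketch (values are made up):

import numpy as np

GAMMA = 0.9
round_scores = np.array([1.0, 2.0, 3.0])                  # one player's scores over three rounds
w = np.array([GAMMA**i for i in range(len(round_scores))])
discounted = np.average(round_scores, weights=w) * w.sum()
assert np.isclose(discounted, (w * round_scores).sum())   # same as the explicit discounted sum
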
Example #6

    env = FrewEnv(
            wall_depth_bounds=[5, 45, 1], 
            pile_d=pile_diameters, 
            max_deflection=1000
    )

    n_episodes = int(1.5e3)
    n_actions=len(env.action_space)

    agent = Agent(
        alpha=0.01, 
        gamma=0.9,  
        eps=1, 
        batch_size=256, 
        n_actions=n_actions,
        input_dim=1,
        eps_dec=0.998,
        nn_arch=[256, 256, 256, 256]
    )

    scores = []
    normalized_rewards = []
    eps_history = []
    steps = 0
    total_steps = 0

    fig = plt.figure(figsize=(7,7))
    plt.get_current_fig_manager().window.wm_geometry("+10+10")
    steps_ax = fig.add_subplot(2, 1, 1)
    episode_ax = fig.add_subplot(2, 1, 2)
Example #7
class Trainer(object):
    def __init__(self,
                 B,
                 T,
                 g_E,
                 g_H,
                 d_E,
                 d_H,
                 d_dropout,
                 g_lr=1e-3,
                 d_lr=1e-3,
                 n_sample=16,
                 generate_samples=20,
                 init_eps=0.1,
                 real_packets_file="real_packet_sizes.txt"):
        self.top = os.getcwd()
        self.B, self.T = B, T
        self.g_E, self.g_H = g_E, g_H
        self.d_E, self.d_H = d_E, d_H
        self.d_dropout = d_dropout
        self.generate_samples = generate_samples
        self.g_lr, self.d_lr = g_lr, d_lr
        self.eps = init_eps
        self.init_eps = init_eps
        self.real_packets = real_packets_file
        self.pos_sequences, self.V = extract_all(real_packets_file)
        self.all_signatures = signatureExtractionAll(self.pos_sequences, 2, 6,
                                                     5, 4)
        self.pos_sequences = split_sequences(self.pos_sequences, T - 2)
        self.neg_sequences = []
        self.agent = Agent(B, self.V, g_E, g_H, g_lr)
        self.g_beta = Agent(B, self.V, g_E, g_H, g_lr)
        self.discriminator = Discriminator(self.V, d_E, d_H, d_dropout)
        self.signature_discriminator = SignatureDiscriminator(signatureCount(
            self.all_signatures),
                                                              H=200)
        self.env = Environment(self.discriminator,
                               self.signature_discriminator,
                               self.all_signatures, self.B, self.V,
                               self.g_beta, n_sample)
        self.generator_pre = GeneratorPretraining(self.V, g_E, g_H)
        self.g_pre_path, self.d_pre_path = None, None

    def pre_train(self,
                  g_epochs=3,
                  d_epochs=1,
                  g_pre_path=None,
                  d_pre_path=None,
                  g_lr=1e-3,
                  d_lr=1e-3):
        self.pre_train_generator(g_epochs=g_epochs,
                                 g_pre_path=g_pre_path,
                                 lr=g_lr)
        self.pre_train_discriminator(d_epochs=d_epochs,
                                     d_pre_path=d_pre_path,
                                     lr=d_lr)

    def pre_train_generator(self, g_epochs=3, lr=1e-3, g_pre_path=None):
        if g_pre_path is None:
            self.g_pre_path = os.path.join(self.top, 'data', 'save',
                                           'generator_pre.hdf5')
        else:
            self.g_pre_path = g_pre_path

        g_adam = Adam(lr)
        self.generator_pre.compile(g_adam, 'categorical_crossentropy')
        print('Generator pre-training')
        self.generator_pre.summary()
        X, Y, self.T = build_generator_pretraining_datasets(
            self.pos_sequences, self.V)
        self.generator_pre.fit(x=X, y=Y, batch_size=self.B, epochs=g_epochs)
        self.generator_pre.save_weights(self.g_pre_path)
        self.reflect_pre_train()

    def pre_train_discriminator(self, d_epochs=1, lr=1e-3, d_pre_path=None):
        if d_pre_path is None:
            self.d_pre_path = os.path.join(self.top, 'data', 'save',
                                           'discriminator_pre.hdf5')
        else:
            self.d_pre_path = d_pre_path

        neg_sequences = self.agent.generator.generate_samples(
            self.T, self.generate_samples)
        X, Y, _ = build_discriminator_datasets(self.pos_sequences,
                                               neg_sequences)
        d_adam = Adam(lr)
        self.discriminator.compile(d_adam, 'binary_crossentropy')
        self.discriminator.summary()
        print('Discriminator pre-training')
        self.discriminator.fit(x=X, y=Y, batch_size=self.B, epochs=d_epochs)
        self.discriminator.save(self.d_pre_path)

    def reflect_pre_train(self):
        for layer in self.generator_pre.layers:
            print(len(layer.get_weights()))
        for layer in self.agent.generator.packet_size_policy.layers:
            print(len(layer.get_weights()))
        w1 = self.generator_pre.layers[1].get_weights()
        w2 = self.generator_pre.layers[2].get_weights()
        w3 = self.generator_pre.layers[3].get_weights()
        self.agent.generator.packet_size_policy.layers[1].set_weights(w1)
        self.g_beta.generator.packet_size_policy.layers[1].set_weights(w1)
        self.agent.generator.packet_size_policy.layers[4].set_weights(w2)
        self.g_beta.generator.packet_size_policy.layers[4].set_weights(w2)
        self.agent.generator.packet_size_policy.layers[5].set_weights(w3)
        self.g_beta.generator.packet_size_policy.layers[5].set_weights(w3)

    def load_pre_train(self, g_pre_path, d_pre_path):
        self.load_pre_train_g(g_pre_path)
        self.load_pre_train_d(d_pre_path)

    def load_pre_train_g(self, g_pre_path):
        self.generator_pre.load_weights(g_pre_path)
        self.reflect_pre_train()

    def load_pre_train_d(self, d_pre_path):
        self.discriminator.load_weights(d_pre_path)

    def save(self, g_path, d_path):
        self.agent.save(g_path)
        self.discriminator.save(d_path)

    def load(self, g_path, d_path):
        self.agent.load(g_path)
        self.g_beta.load(g_path)
        self.discriminator.load_weights(d_path)

    def train(self,
              steps=10,
              g_steps=1,
              d_steps=1,
              d_epochs=3,
              g_weights_path='data/save/generator.pkl',
              d_weights_path='data/save/discriminator.hdf5',
              use_sig=False):
        d_adam = Adam(self.d_lr)
        self.discriminator.compile(d_adam,
                                   'binary_crossentropy',
                                   metrics=['accuracy'])
        self.eps = self.init_eps
        for step in range(steps):
            print("Adverserial Training - Generator")
            for _ in range(g_steps):
                rewards = np.zeros([self.B, self.T])
                self.agent.reset()
                self.env.reset()
                for t in range(self.T):
                    state = self.env.get_state()
                    action = self.agent.act(state, epsilon=0.0, stateful=False)
                    next_state, reward, is_episode_end, info = self.env.step(
                        action, use_sig)
                    self.agent.generator.update(state, action, reward)
                    rewards[:, t] = reward.reshape([
                        self.B,
                    ])
            print("Adverserial Training - Discriminator")
            for _ in range(d_steps):
                if use_sig:
                    neg_sequences = self.agent.generator.generate_samples(
                        self.T, self.generate_samples)
                    print("generated sequences")
                    print(neg_sequences)
                    X_pos = featureExtractionAll(self.pos_sequences,
                                                 self.all_signatures)
                    X_neg = featureExtractionAll(neg_sequences,
                                                 self.all_signatures)
                    X = np.array(X_pos + X_neg)
                    y = np.array([1] * len(self.pos_sequences) +
                                 [0] * len(neg_sequences))
                    X_train, X_test, y_train, y_test = train_test_split(
                        X, y, test_size=0.33, random_state=25)
                    self.signature_discriminator.fit(x=X_train,
                                                     y=y_train,
                                                     batch_size=self.B,
                                                     epochs=d_epochs,
                                                     validation_data=(X_test,
                                                                      y_test))
                else:
                    neg_sequences = self.agent.generator.generate_samples(
                        self.T, self.generate_samples)
                    print("generated sequences")
                    print(neg_sequences)
                    X, Y, _ = build_discriminator_datasets(
                        self.pos_sequences, neg_sequences)
                    X_train, X_test, y_train, y_test = train_test_split(
                        X, Y, test_size=0.33, random_state=25)
                    self.discriminator.fit(x=X_train,
                                           y=y_train,
                                           batch_size=self.B,
                                           epochs=d_epochs,
                                           validation_data=(X_test, y_test))

            # Update env.g_beta to agent
            self.agent.save(g_weights_path)
            self.g_beta.load(g_weights_path)

            self.discriminator.save(d_weights_path)
            self.eps = max(self.eps * (1 - float(step) / steps * 4), 1e-4)
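
A hedged driver sketch for this Trainer; the batch size, dimensions, and file names below are placeholders rather than values from the source project:

trainer = Trainer(B=32, T=20,
                  g_E=64, g_H=64, d_E=64, d_H=64,
                  d_dropout=0.1,
                  real_packets_file='real_packet_sizes.txt')
trainer.pre_train(g_epochs=3, d_epochs=1)        # MLE warm-up for generator and discriminator
trainer.train(steps=10, use_sig=True)            # adversarial phase against the signature discriminator
trainer.save('data/save/generator.pkl', 'data/save/discriminator.hdf5')
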
Example #8
class Runner:
    def __init__(self, run_name, state_file, actions, input_shape, n_variables,
                 start_epsilon, complexities, learning_steps_per_epoch,
                 testing_episodes, max_complexity, ticks, still_action,
                 still_ticks, log_file, load_net_file, save_net_file,
                 save_net_file_format):
        self.actions = actions
        self.state_file = state_file
        self.run_name = run_name
        self.log_file = log_file or run_name + ".log"
        self.load_net_file = load_net_file or run_name + ".npy"
        self.save_net_file = save_net_file
        self.save_net_file_format = save_net_file or save_net_file_format

        self.agent = Agent(len(self.actions),
                           input_shape,
                           n_variables=n_variables,
                           start_epsilon=start_epsilon)

        self.tester = Tester(self.actions,
                             input_shape,
                             n_variables,
                             ticks=ticks,
                             still_action=still_action,
                             still_ticks=still_ticks,
                             episodes=testing_episodes)
        self.tutor = Tutor(self.agent,
                           self.actions,
                           input_shape,
                           n_variables,
                           ticks=ticks,
                           still_action=still_action,
                           still_ticks=still_ticks)

        self.epochs_done = 0
        self.complexities = complexities
        self.max_complexity = max_complexity
        self.complexity_done = -1
        self.learning_steps_per_epoch = learning_steps_per_epoch

    #@retry(stop_max_attempt_number=10, wait_random_min=1000, wait_random_max=10000)
    def run(self, game, esp_factor=2):
        logging.basicConfig(filename=self.log_file,
                            level=logging.INFO,
                            filemode='a')
        if os.path.exists(self.load_net_file):
            logging.info("Loading net file: %s", self.load_net_file)
            self.agent.load_params(self.load_net_file)
        for complexity, epochs in [(x, y) for x, y in self.complexities
                                   if x > self.complexity_done]:
            logging.info('Complexity: %d', complexity)
            if self.epochs_done == 0:
                self.agent.reset_epsilon(epsilon_change_steps=epochs *
                                         self.learning_steps_per_epoch /
                                         esp_factor)
            if self.save_net_file_format:
                self.save_net_file = self.save_net_file_format.format(
                    self.run_name, complexity)
            for i in trange(self.epochs_done + 1, epochs + 1, desc='epochs'):
                logging.info("epoch: %d", i)
                self.tutor.epoch(game, self.learning_steps_per_epoch,
                                 complexity)
                self.tester.test(game, self.agent, complexity)
                if self.max_complexity is not None and complexity != self.max_complexity:
                    self.tester.test(game, self.agent, self.max_complexity)
                self.epochs_done = i
                self._persist()
            self.epochs_done = 0
            self.complexity_done = complexity

    def _persist(self):
        self.load_net_file = self.save_net_file
        tmp = tempfile.mkdtemp(dir='.') + '/'
        self.agent.save_params(tmp + self.save_net_file)
        pickle.dump(self, open(tmp + self.state_file, 'wb'))
        os.rename(tmp + self.save_net_file, self.save_net_file + '.saved')
        os.rename(tmp + self.state_file, self.state_file + '.saved')
        shutil.rmtree(tmp)
        os.rename(self.save_net_file + '.saved', self.save_net_file)
        os.rename(self.state_file + '.saved', self.state_file)

    @staticmethod
    def prepare_runner(run_name,
                       state_file,
                       actions,
                       input_shape,
                       complexities,
                       n_variables=0,
                       start_epsilon=1.0,
                       learning_steps_per_epoch=50000,
                       testing_episodes=300,
                       max_complexity=None,
                       ticks=4,
                       still_action=None,
                       still_ticks=0,
                       log_file=None,
                       load_net_file=None,
                       save_net_file=None,
                       save_net_file_format="{}_{}.npy"):
        state_file_saved = state_file + '.saved'
        if os.path.exists(state_file_saved):
            os.rename(state_file_saved, state_file)
            runner = pickle.load(open(state_file, 'rb'))
            if os.path.exists(runner.load_net_file + '.saved'):
                os.rename(runner.load_net_file + '.saved',
                          runner.load_net_file)
        elif os.path.exists(state_file):
            runner = pickle.load(open(state_file, 'rb'))
        else:
            runner = Runner(run_name, state_file, actions, input_shape,
                            n_variables, start_epsilon, complexities,
                            learning_steps_per_epoch, testing_episodes,
                            max_complexity, ticks, still_action, still_ticks,
                            log_file, load_net_file, save_net_file,
                            save_net_file_format)
        return runner
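
A hedged sketch of how prepare_runner and run are meant to fit together; the action set, input shape, and complexity schedule below are placeholder assumptions, and `game` stands for whatever environment object Tutor and Tester expect (not shown in the source):

actions = [[1, 0, 0], [0, 1, 0], [0, 0, 1]]          # placeholder one-hot action set
runner = Runner.prepare_runner(run_name='basic',     # unpickles an interrupted run from state_file if one exists
                               state_file='basic.state',
                               actions=actions,
                               input_shape=(30, 45),
                               complexities=[(0, 20), (1, 40)],  # (complexity, epochs) pairs iterated in run()
                               learning_steps_per_epoch=5000,
                               testing_episodes=100)
runner.run(game)
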
Example #9
class Trainer(object):
    def __init__(self,
                 B,
                 T,
                 g_E,
                 g_H,
                 d_E,
                 d_H,
                 d_dropout,
                 path_pos,
                 path_neg,
                 g_lr=1e-3,
                 d_lr=1e-3,
                 n_sample=10000,
                 generate_samples=5,
                 init_eps=0.1):

        self.B, self.T = B, T
        self.g_E, self.g_H = g_E, g_H
        self.d_E, self.d_H = d_E, d_H
        self.d_dropout = d_dropout
        self.generate_samples = generate_samples
        self.g_lr, self.d_lr = g_lr, d_lr
        self.eps = init_eps
        self.init_eps = init_eps
        self.top = os.getcwd()
        self.path_pos = path_pos
        self.path_neg = path_neg

        self.g_data = GeneratorPretrainingGenerator(
            self.path_pos, B=B, T=T, min_count=1
        )  # next() yields x and y_true built from the same sentence, e.g. [BOS, 8, 10, 6, 3, EOS] predicting [8, 10, 6, 3, EOS]
        self.d_data = DiscriminatorGenerator(
            path_pos=self.path_pos,
            path_neg=self.path_neg,
            B=self.B,
            shuffle=True)  # next() yields positive and negative samples

        self.V = self.g_data.V

        self.agent = Agent(B, self.V, g_E, g_H, g_lr)
        self.g_beta = Agent(B, self.V, g_E, g_H, g_lr)

        self.discriminator = Discriminator(self.V, d_E, d_H, d_dropout)

        self.env = Environment(self.discriminator,
                               self.g_data,
                               self.g_beta,
                               n_sample=n_sample)

        self.generator_pre = GeneratorPretraining(self.V, g_E, g_H)
        print("para is ", self.V, g_E, g_H)

    def pre_train(self,
                  g_epochs=3,
                  d_epochs=1,
                  g_pre_path=None,
                  d_pre_path=None,
                  g_lr=1e-3,
                  d_lr=1e-3):
        self.pre_train_generator(g_epochs=g_epochs,
                                 g_pre_path=g_pre_path,
                                 lr=g_lr)

        self.pre_train_discriminator(d_epochs=d_epochs,
                                     d_pre_path=d_pre_path,
                                     lr=d_lr)
        print("end pretrain")

    def pre_train_generator(self, g_epochs=3, g_pre_path=None, lr=1e-3):
        if g_pre_path is None:
            self.g_pre_path = os.path.join(self.top, 'data', 'save',
                                           'generator_pre.hdf5')
        else:
            self.g_pre_path = g_pre_path

        g_adam = keras.optimizers.Adam(lr)
        self.generator_pre.compile(g_adam, 'categorical_crossentropy')
        print('Generator pre-training')
        self.generator_pre.summary()

        self.generator_pre.fit_generator(self.g_data,
                                         steps_per_epoch=None,
                                         epochs=g_epochs)
        self.generator_pre.save_weights(self.g_pre_path)
        self.reflect_pre_train()

    def pre_train_discriminator(self, d_epochs=1, d_pre_path=None, lr=1e-3):
        if d_pre_path is None:
            self.d_pre_path = os.path.join(self.top, 'data', 'save',
                                           'discriminator_pre.hdf5')
        else:
            self.d_pre_path = d_pre_path

        print('Start Generating sentences')
        #fix
        # self.agent.generator.generate_samples(self.T, self.g_data,
        #                                       self.generate_samples, self.path_neg)
        self.agent.generator.generate_samples(self.T, self.g_data,
                                              self.generate_samples,
                                              self.path_neg)
        print("generating sentences")
        self.d_data = DiscriminatorGenerator(path_pos=self.path_pos,
                                             path_neg=self.path_neg,
                                             B=self.B,
                                             shuffle=True)

        d_adam = keras.optimizers.Adam(lr)
        self.discriminator.compile(d_adam, 'binary_crossentropy')
        self.discriminator.summary()
        print('Discriminator pre-training')

        self.discriminator.fit_generator(self.d_data,
                                         steps_per_epoch=None,
                                         epochs=d_epochs)

        self.discriminator.save(self.d_pre_path)
        print("end dis pre_training")

    def load_pre_train(self, g_pre_path, d_pre_path):
        self.generator_pre.load_weights(g_pre_path)
        self.reflect_pre_train()
        self.discriminator.load_weights(d_pre_path)
        print("end load pre train")

    def load_pre_train_g(self, g_pre_path):
        self.generator_pre.load_weights(g_pre_path)
        self.reflect_pre_train()

    def load_pre_train_d(self, d_pre_path):
        self.discriminator.load_weights(d_pre_path)

    def reflect_pre_train(self):
        i = 0
        print("relfecting")

        st = self.env.get_state()
        h, c = self.agent.generator.get_rnn_state()
        h, c = self.g_beta.generator.get_rnn_state()

        self.agent.generator(h, c, st)

        self.g_beta.generator(h, c, st)

        l_embeding = self.generator_pre.layers[1]
        l_mask = self.generator_pre.layers[2]
        l_lstm = self.generator_pre.layers[3]
        l_td = self.generator_pre.layers[4]
        # TimeDistributed --> Dense: these weights must be copied onto the agent's dense layer
        if len(l_embeding.get_weights()) != 0:
            #have pretrained
            self.agent.generator.embedding_layer.set_weights(
                l_embeding.get_weights())
            self.agent.generator.mask_layer.set_weights(l_mask.get_weights())
            self.agent.generator.lstm_layer.set_weights(l_lstm.get_weights())
            self.agent.generator.dense_layer.set_weights(l_td.get_weights())

            self.g_beta.generator.embedding_layer.set_weights(
                l_embeding.get_weights())
            self.g_beta.generator.mask_layer.set_weights(l_mask.get_weights())
            self.g_beta.generator.lstm_layer.set_weights(l_lstm.get_weights())
            self.g_beta.generator.dense_layer.set_weights(l_td.get_weights())

        # for layer in self.generator_pre.layers:
        #     if len(layer.get_weights()) != 0:
        #         w = layer.get_weights()
        #         # print(w[0].shape)
        #         # just build a graph.
        #         main_layers=self.agent.generator.layers
        #         pre_layers=self.generator_pre.layers
        #
        #         self.agent.generator.layers[i].set_weights(w)
        #         self.g_beta.generator.layers[i].set_weights(w)
        #         i += 1
        self.agent.reset()
        self.g_beta.reset()
        self.env.reset()
        print("end reflect")
        # return

    def train(self,
              steps=30,
              g_steps=1,
              d_steps=1,
              d_epochs=1,
              g_weights_path='data/save/generator.pkl',
              d_weights_path='data/save/discriminator.hdf5',
              verbose=True,
              head=1):
        print("start adv train")
        d_adam = keras.optimizers.Adam(self.d_lr)
        # print("start adv train1")
        self.discriminator.compile(d_adam, 'binary_crossentropy')
        self.eps = self.init_eps
        # print("start adv train2")
        debug_flag = 0

        log = open("data/log.txt", 'w')

        for step in range(steps):

            # Generator training
            for _ in range(g_steps):
                rewards = np.zeros([self.B, self.T])
                self.agent.reset()
                self.env.reset()
                # print("start adv train4")
                global tflag
                avg_loss = 0
                for t in range(self.T):
                    state = self.env.get_state()
                    # print("start adv train5")
                    # debug_flag = 1 + debug_flag
                    # if debug_flag==2:
                    #     asdfsdfa=23
                    action = self.agent.act(state, epsilon=0.0)
                    # print("start adv train6")

                    _next_state, reward, is_episode_end, _info = self.env.step(
                        action)
                    # print("start adv train7")
                    # print(step,_,"before update")

                    cur_loss = self.agent.generator.update(
                        state, action, reward)
                    avg_loss += tf.reduce_mean(cur_loss)

                    log.write("epoch" + str(step) + "g step in cur epoch " +
                              str(t) + "loss" + str(tf.reduce_mean(cur_loss)) +
                              '\n')

                    print("epoch", step, "g step in cur epoch ", t, "loss",
                          tf.reduce_mean(cur_loss))

                    # print("start adv train8")
                    rewards[:, t] = reward.reshape([
                        self.B,
                    ])
                    if is_episode_end:
                        if verbose:
                            print('Reward: {:.3f}, Episode end'.format(
                                np.average(rewards)))
                            self.env.render(head=head)
                        break
                log.write("avg loss=" + str(avg_loss / self.T) + '\n')
                print("avg loss=", (avg_loss / self.T))
            print("train d")
            # Discriminator training
            for _ in range(d_steps):
                self.agent.generator.generate_samples(self.T, self.g_data,
                                                      self.generate_samples,
                                                      self.path_neg)
                self.d_data = DiscriminatorGenerator(path_pos=self.path_pos,
                                                     path_neg=self.path_neg,
                                                     B=self.B,
                                                     shuffle=True)
                self.discriminator.fit_generator(self.d_data,
                                                 steps_per_epoch=None,
                                                 epochs=d_epochs)

            # Update env.g_beta to agent
            self.agent.save(g_weights_path)
            self.g_beta.load(g_weights_path)

            self.discriminator.save(d_weights_path)
            self.eps = max(self.eps * (1 - float(step) / steps * 4), 1e-4)
        log.close()  # close the training log opened before the adversarial loop

    def save(self, g_path, d_path):
        self.agent.save(g_path)
        self.discriminator.save(d_path)

    def load(self, g_path, d_path):
        self.agent.load(g_path)
        self.g_beta.load(g_path)
        self.discriminator.load_weights(d_path)

    def test(self):
        x, y = self.d_data.next()
        pred = self.discriminator.predict(x)

        for i in range(self.B):
            txt = [self.g_data.id2word[id] for id in x[i].tolist()]

            label = y[i]
            print('{}, {:.3f}: {}'.format(label, pred[i, 0], ''.join(txt)))

    def generate_txt(self, file_name, generate_samples):
        path_neg = os.path.join(self.top, 'data', 'save', file_name)

        self.agent.generator.generate_samples(self.T, self.g_data,
                                              generate_samples, path_neg)
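
Finally, a hedged end-to-end sketch for this text-generation Trainer; paths and hyperparameters are placeholders only:

trainer = Trainer(B=32, T=25,
                  g_E=64, g_H=64, d_E=64, d_H=64,
                  d_dropout=0.1,
                  path_pos='data/positive.txt',
                  path_neg='data/save/generated.txt')
trainer.pre_train(g_epochs=3, d_epochs=1)                 # MLE pre-training for both networks
trainer.train(steps=30, g_steps=1, d_steps=1)             # adversarial (policy-gradient) phase
trainer.generate_txt('generated_final.txt', 100)          # write 100 sampled sentences under data/save/
trainer.test()                                            # print discriminator scores for one batch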