def agent(self):
        nb_actions = self.env.action_space.n
        model = self.build()
        model.summary()

        memory = SequentialMemory(limit=50000, window_length=1)
        dqn = DQNAgent(model=model,
                       nb_actions=nb_actions,
                       memory=memory,
                       nb_steps_warmup=32,
                       enable_dueling_network=True,
                       target_model_update=1e-2,
                       policy=InformedBoltzmannGumbelQPolicy(self.env),
                       test_policy=InformedGreedyQPolicy(self.env),
                       batch_size=32,
                       train_interval=32)
        dqn.compile(Adam(lr=1e-3), metrics=['mae'])

        if self.initial_weights_file is not None:
            try:
                dqn.load_weights(self.initial_weights_file)
            except Exception:
                # weights file missing or incompatible; start from scratch
                pass

        return dqn
    def agent(self):
        nb_actions = self.env.action_space.n
        obs_dim = self.env.observation_space.shape
        model = Sequential()
        model.add(Flatten(input_shape=(1,) + obs_dim))
        model.add(Dense(nb_actions, activation='linear'))
        model.summary()

        memory = SequentialMemory(limit=50000, window_length=1)
        dqn = DQNAgent(model=model,
                       nb_actions=nb_actions,
                       memory=memory,
                       nb_steps_warmup=256,
                       enable_dueling_network=True,
                       target_model_update=1e-2,
                       policy=InformedBoltzmannGumbelQPolicy(self.env),
                       test_policy=InformedGreedyQPolicy(self.env),
                       batch_size=128,
                       train_interval=128)
        dqn.compile(Adam(lr=1e-3), metrics=['mae'])

        if self.initial_weights_file is not None:
            dqn.load_weights(self.initial_weights_file)
            self.train_episodes = 0

        return dqn
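
# Hedged usage sketch (not part of the original snippets): once agent() has
# returned a compiled keras-rl DQNAgent, training and evaluation typically look
# like this. `trainer` stands for an instance of the surrounding class and
# `nb_training_steps` is a hypothetical step budget.
dqn = trainer.agent()
dqn.fit(trainer.env, nb_steps=nb_training_steps, visualize=False, verbose=2)
dqn.save_weights('dqn_weights.h5f', overwrite=True)
dqn.test(trainer.env, nb_episodes=5, visualize=False)
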
Example #3
class Player:
    """Mandatory class with the player methods"""
    def __init__(self, name='DQN', load_model=None, env=None):
        """Initiaization of an agent"""
        self.equity_alive = 0
        self.actions = []
        self.last_action_in_stage = ''
        self.temp_stack = []
        self.name = name
        self.autoplay = True

        self.dqn = None
        self.model = None
        self.env = env

        # if load_model:
        #     self.model = self.load_model(load_model)

    def initiate_agent(self,
                       env,
                       model_name=None,
                       load_memory=None,
                       load_model=None,
                       load_optimizer=None,
                       load_dqn=None,
                       batch_size=500,
                       learn_rate=1e-3):
        """initiate a deep Q agent"""
        # tf.compat.v1.disable_eager_execution()

        self.env = env

        nb_actions = self.env.action_space.n

        # if load_model:
        #     self.model, trainable_model, target_model = self.load_model(load_model)
        #     print(self.model.history)

        self.model = Sequential()
        self.model.add(
            Dense(512, activation='relu', input_shape=env.observation_space))
        self.model.add(Dropout(0.2))
        self.model.add(Dense(512, activation='relu'))
        self.model.add(Dropout(0.2))
        self.model.add(Dense(512, activation='relu'))
        self.model.add(Dropout(0.2))
        self.model.add(Dense(nb_actions, activation='linear'))

        # Finally, we configure and compile our agent. You can use every built-in Keras optimizer and
        # even the metrics!
        if load_memory:
            try:
                memory = self.load_memory(load_memory)
            except (OSError, pickle.UnpicklingError):
                # fall back to a fresh replay buffer if the saved memory cannot be loaded
                memory = SequentialMemory(limit=memory_limit,
                                          window_length=window_length)
        else:
            memory = SequentialMemory(limit=memory_limit,
                                      window_length=window_length)

        self.batch_size = batch_size
        self.policy = CustomEpsGreedyQPolicy()
        self.policy.env = self.env

        self.test_policy = CustomEpsGreedyQPolicy()
        self.test_policy.eps = 0.05
        self.test_policy.env = self.env

        self.reduce_lr = ReduceLROnPlateau(monitor='loss',
                                           factor=0.2,
                                           patience=5,
                                           min_lr=1e-4)

        nb_actions = env.action_space.n
        self.dqn = DQNAgent(model=self.model,
                            nb_actions=nb_actions,
                            memory=memory,
                            nb_steps_warmup=nb_steps_warmup,
                            target_model_update=1e-2,
                            policy=self.policy,
                            test_policy=self.test_policy,
                            processor=CustomProcessor(),
                            batch_size=self.batch_size,
                            train_interval=train_interval,
                            enable_double_dqn=enable_double_dqn)

        # timestr = time.strftime("%Y%m%d-%H%M%S") + "_" + str(model_name)
        # self.tensorboard = MyTensorBoard(log_dir='./Graph/{}'.format(timestr), player=self)
        self.dqn.compile(Adam(lr=learn_rate), metrics=['mae'])

        if load_model:
            self.load_model(load_model)
            # self.dqn.trainable_model = trainable_model
            # self.dqn.target_model = target_model

        # self.reduce_lr = ReduceLROnPlateau

        if load_optimizer:
            self.load_optimizer_weights(load_optimizer)

    def start_step_policy(self, observation):
        """Custom policy for random decisions for warm up."""
        log.info("Random action")
        _ = observation
        legal_moves_limit = [
            move.value for move in self.env.info['legal_moves']
        ]
        action = np.random.choice(legal_moves_limit)

        return action

    def train(self, env_name, batch_size=500, policy_epsilon=0.2):
        """Train a model"""
        # initiate training loop

        train_vars = {
            'batch_size': batch_size,
            'policy_epsilon': policy_epsilon
        }

        timestr = time.strftime("%Y%m%d-%H%M%S") + "_" + str(env_name)
        tensorboard = TensorBoard(log_dir='./Graph/{}'.format(timestr),
                                  histogram_freq=0,
                                  write_graph=True,
                                  write_images=False)
        self.dqn.fit(self.env,
                     nb_max_start_steps=nb_max_start_steps,
                     nb_steps=nb_steps,
                     visualize=False,
                     verbose=2,
                     start_step_policy=self.start_step_policy,
                     callbacks=[tensorboard])

        self.policy.eps = policy_epsilon

        self.dqn.save_weights("dqn_{}_model.h5".format(env_name),
                              overwrite=True)

        # Save memory
        pickle.dump(self.dqn.memory,
                    open("train_memory_{}.p".format(env_name), "wb"))

        # Save optimizer weights
        symbolic_weights = getattr(self.dqn.trainable_model.optimizer,
                                   'weights')
        optim_weight_values = K.batch_get_value(symbolic_weights)
        pickle.dump(optim_weight_values,
                    open('optimizer_weights_{}.p'.format(env_name), "wb"))

        # # Dump dqn
        # pickle.dump(self.dqn, open( "dqn_{}.p".format(env_name), "wb" ))

        # Finally, evaluate our algorithm for 5 episodes.
        self.dqn.test(self.env, nb_episodes=5, visualize=False)

    def load_model(self, env_name):
        """Load a model"""

        # Load the architecture
        # with open('dqn_{}_json.json'.format(env_name), 'r') as architecture_json:
        #     dqn_json = json.load(architecture_json)

        self.dqn.load_weights("dqn_{}_model.h5".format(env_name))
        # model = keras.models.load_model("dqn_{}_model.h5".format(env_name))
        # trainable_model = keras.models.load_model("dqn_{}_trainable_model.h5".format(env_name))
        # target_model = keras.models.load_model("dqn_{}_target_model.h5".format(env_name), overwrite=True)

        # return model, trainable_model, target_model

    def load_memory(self, model_name):
        memory = pickle.load(open('train_memory_{}.p'.format(model_name),
                                  "rb"))
        return memory

    def load_optimizer_weights(self, env_name):
        optim_weights = pickle.load(
            open('optimizer_weights_{}.p'.format(env_name), "rb"))
        self.dqn.trainable_model.optimizer.set_weights(optim_weights)

    def play(self, nb_episodes=5, render=False):
        """Let the agent play"""
        memory = SequentialMemory(limit=memory_limit,
                                  window_length=window_length)
        policy = CustomEpsGreedyQPolicy()

        class CustomProcessor(Processor):  # pylint: disable=redefined-outer-name
            """The agent and the environment"""
            def process_state_batch(self, batch):
                """
                Given a state batch, I want to remove the second dimension, because it's
                useless and prevents me from feeding the tensor into my CNN
                """
                return np.squeeze(batch, axis=1)

            def process_info(self, info):
                processed_info = info['player_data']
                if 'stack' in processed_info:
                    processed_info = {'x': 1}
                return processed_info

        nb_actions = self.env.action_space.n

        self.dqn = DQNAgent(model=self.model,
                            nb_actions=nb_actions,
                            memory=memory,
                            nb_steps_warmup=nb_steps_warmup,
                            target_model_update=1e-2,
                            policy=policy,
                            processor=CustomProcessor(),
                            batch_size=batch_size,
                            train_interval=train_interval,
                            enable_double_dqn=enable_double_dqn)
        self.dqn.compile(Adam(lr=1e-3), metrics=['mae'])  # pylint: disable=no-member

        self.dqn.test(self.env, nb_episodes=nb_episodes, visualize=render)

    def action(self, action_space, observation, info):  # pylint: disable=no-self-use
        """Mandatory method that calculates the move based on the observation array and the action space."""
        _ = observation  # not using the observation for random decision
        _ = info

        this_player_action_space = {
            Action.FOLD, Action.CHECK, Action.CALL, Action.RAISE_POT,
            Action.RAISE_HALF_POT, Action.RAISE_2POT
        }
        _ = this_player_action_space.intersection(set(action_space))

        action = None
        return action
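
# Hedged sketch: the Player snippet above references module-level hyperparameters
# that are defined elsewhere in the source project (memory_limit, window_length,
# nb_steps_warmup, train_interval, enable_double_dqn, nb_max_start_steps, nb_steps,
# batch_size). Plausible values are shown here only to make the example
# self-contained; the project's actual settings may differ.
window_length = 1            # observations stacked per state
memory_limit = int(1e5)      # replay buffer size
nb_max_start_steps = 20      # random warm-up steps per episode
nb_steps_warmup = 75         # steps collected before training starts
train_interval = 100         # train every n steps
nb_steps = 100000            # total training steps
batch_size = 500             # minibatch size
enable_double_dqn = False    # standard vs. double DQN
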
Example #4
## Set up the agent for training ##
memory = SequentialMemory(limit=params.REPLAY_BUFFER_SIZE, window_length=1)
agent = DQNAgent(model=model,
                 policy=BoltzmannQPolicy(),
                 memory=memory,
                 nb_actions=action_size)

agent.compile(Adam(lr=params.LR_MODEL), metrics=[params.METRICS])

## Train ##
if args.train:
    check_overwrite('DQN', params.ENV, args.model)
    history = agent.fit(env,
                        nb_steps=params.N_STEPS_TRAIN,
                        visualize=args.visualize,
                        verbose=1,
                        nb_max_episode_steps=env._max_episode_steps,
                        log_interval=params.LOG_INTERVAL)
    agent.save_weights(WEIGHTS_FILES, overwrite=True)
    save_plot_reward('DQN', params.ENV, history, args.model, params.PARAMS)

## Test ##
if not args.train:
    agent.load_weights(WEIGHTS_FILES)
    history = agent.test(env,
                         nb_episodes=params.N_EPISODE_TEST,
                         visualize=args.visualize,
                         nb_max_episode_steps=env._max_episode_steps)
    save_result('DQN', params.ENV, history, args.model, params.PARAMS)
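
# Hedged sketch of the pieces the script above assumes already exist: a Gym
# environment, its action count, and a small dense Q-network. The `params`
# module, `args`, `WEIGHTS_FILES` and the check_overwrite/save_plot_reward/
# save_result helpers come from the source project and are not reproduced here.
import gym
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten

env = gym.make('CartPole-v1')          # assumed environment
action_size = env.action_space.n

model = Sequential([
    # window_length=1 in SequentialMemory gives states of shape (1,) + obs shape
    Flatten(input_shape=(1,) + env.observation_space.shape),
    Dense(24, activation='relu'),
    Dense(24, activation='relu'),
    Dense(action_size, activation='linear'),
])
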
Example #5
class Player:
    """Mandatory class with the player methods"""

    def __init__(self, name='DQN', load_model=None):
        """Initiaization of an agent"""
        self.equity_alive = 0
        self.actions = []
        self.last_action_in_stage = ''
        self.temp_stack = []
        self.name = name
        self.autoplay = True

        self.dqn = None
        self.env = None

        if load_model:
            self.load(load_model)

    def initiate_agent(self, env):
        """initiate a deep Q agent"""
        from keras import Sequential
        from keras.optimizers import Adam
        from keras.layers import Dense, Dropout
        from rl.memory import SequentialMemory
        from rl.agents import DQNAgent

        self.env = env

        nb_actions = self.env.action_space.n

        model = Sequential()
        model.add(Dense(512, activation='relu', input_shape=env.observation_space))
        model.add(Dropout(0.2))
        model.add(Dense(512, activation='relu'))
        model.add(Dropout(0.2))
        model.add(Dense(512, activation='relu'))
        model.add(Dropout(0.2))
        model.add(Dense(nb_actions, activation='linear'))

        # Finally, we configure and compile our agent. You can use every built-in Keras optimizer and
        # even the metrics!
        memory = SequentialMemory(limit=memory_limit, window_length=window_length)
        policy = TrumpPolicy()
        from rl.core import Processor

        class CustomProcessor(Processor):
            """he agent and the environment"""

            def process_state_batch(self, batch):
                """
                Given a state batch, I want to remove the second dimension, because it's
                useless and prevents me from feeding the tensor into my CNN
                """
                return np.squeeze(batch, axis=1)

            def process_info(self, info):
                processed_info = info['player_data']
                if 'stack' in processed_info:
                    processed_info = {'x': 1}
                return processed_info

        nb_actions = env.action_space.n

        self.dqn = DQNAgent(model=model, nb_actions=nb_actions, memory=memory, nb_steps_warmup=nb_steps_warmup,
                            target_model_update=1e-2, policy=policy,
                            processor=CustomProcessor(),
                            batch_size=batch_size, train_interval=train_interval, enable_double_dqn=enable_double_dqn)
        self.dqn.compile(Adam(lr=1e-3), metrics=['mae'])

    def start_step_policy(self, observation):
        """Custom policy for random decisions for warm up."""
        log.info("Random action")
        _ = observation
        action = self.env.action_space.sample()
        return action

    def train(self, env_name):
        """Train a model"""
        # initiate training loop
        timestr = time.strftime("%Y%m%d-%H%M%S") + "_" + str(env_name)
        tensorboard = TensorBoard(log_dir='./Graph/{}'.format(timestr), histogram_freq=0, write_graph=True,
                                  write_images=False)

        self.dqn.fit(self.env, nb_max_start_steps=nb_max_start_steps, nb_steps=nb_steps, visualize=False, verbose=2,
                     start_step_policy=self.start_step_policy, callbacks=[tensorboard])

        # After training is done, we save the final weights.
        self.dqn.save_weights('dqn_{}_weights.h5'.format(env_name), overwrite=True)

        # Finally, evaluate our algorithm for 5 episodes.
        self.dqn.test(self.env, nb_episodes=5, visualize=False)

    def load(self, env_name):
        """Load a model"""
        self.dqn.load_weights('dqn_{}_weights.h5'.format(env_name))

    def action(self, action_space, observation, info):  # pylint: disable=no-self-use
        """Mandatory method that calculates the move based on the observation array and the action space."""
        _ = observation  # not using the observation for random decision
        _ = info

        this_player_action_space = {Action.FOLD, Action.CHECK, Action.CALL, Action.RAISE_POT, Action.RAISE_HALF_POT,
                                    Action.RAISE_2POT}
        _ = this_player_action_space.intersection(set(action_space))

        action = None
        return action
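
# Hedged sketch of a TrumpPolicy-style policy (used in initiate_agent above): the
# source project's TrumpPolicy behaves like a Boltzmann/softmax exploration policy
# over Q-values; the actual implementation may differ in details such as logging.
import numpy as np
from rl.policy import Policy

class TrumpPolicy(Policy):
    """Sample actions from a softmax over the Q-values."""
    def __init__(self, tau=1., clip=(-500., 500.)):
        super().__init__()
        self.tau = tau
        self.clip = clip

    def select_action(self, q_values):
        q_values = q_values.astype('float64')
        exp_values = np.exp(np.clip(q_values / self.tau, self.clip[0], self.clip[1]))
        probs = exp_values / np.sum(exp_values)
        return np.random.choice(len(q_values), p=probs)
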
x = Conv2D(64, 4, strides=2, activation='relu')(x)
x = Conv2D(64, 2, strides=3, activation='relu')(x)
x = Flatten()(x)
x = Dense(512, activation='relu')(x)
action = Dense(nb_actions, activation='linear')(x)
model = K.Model(inputs=inputs, outputs=action)
model.summary()

memory = SequentialMemory(limit=1000000, window_length=4)
policy = LinearAnnealedPolicy(EpsGreedyQPolicy(),
                              attr='eps',
                              value_max=1.,
                              value_min=.1,
                              value_test=0.05,
                              nb_steps=1000000)
dqn = DQNAgent(model=model,
               nb_actions=nb_actions,
               memory=memory,
               nb_steps_warmup=50000,
               target_model_update=10000,
               policy=policy,
               processor=AtariProcessor(),
               gamma=.99,
               train_interval=4,
               delta_clip=1.)
dqn.compile(K.optimizers.Adam(lr=.00025), metrics=['mae'])

dqn.load_weights('policy.h5')

dqn.test(env, nb_episodes=10, visualize=True)
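
# Hedged sketch of an AtariProcessor compatible with the agent above, modelled on
# the standard keras-rl Atari example; the processor the original author used may
# differ.
from PIL import Image
import numpy as np
from rl.core import Processor

INPUT_SHAPE = (84, 84)

class AtariProcessor(Processor):
    def process_observation(self, observation):
        # resize to 84x84 grayscale, stored as uint8 to keep the replay buffer small
        img = Image.fromarray(observation).resize(INPUT_SHAPE).convert('L')
        return np.array(img, dtype='uint8')

    def process_state_batch(self, batch):
        # rescale lazily at training time instead of storing floats in memory
        return batch.astype('float32') / 255.

    def process_reward(self, reward):
        # clip rewards to [-1, 1] as in the original DQN paper
        return np.clip(reward, -1., 1.)
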
    """ build the keras model for deep learning """
    # inputs = layers.Input(shape=(84, 84, 4,))
    inputs = layers.Input(shape=(4, ) + state_size)
    layer1 = layers.Conv2D(32, 8, strides=4, activation="relu")(inputs)
    layer2 = layers.Conv2D(64, 4, strides=2, activation="relu")(layer1)
    layer3 = layers.Conv2D(64, 3, strides=1, activation="relu")(layer2)
    layer4 = layers.Flatten()(layer3)
    layer5 = layers.Dense(512, activation="relu")(layer4)
    action = layers.Dense(num_actions, activation="linear")(layer5)
    return k.Model(inputs=inputs, outputs=action)


model = build_model(state_size, num_actions)
model.summary()
"""
policy = LinearAnnealedPolicy(EpsGreedyQPolicy(), attr='eps', value_max=1.,
                              value_min=.1, value_test=0.1, nb_steps=1000000)
"""
memory = SequentialMemory(limit=1000000, window_length=4)
agent = DQNAgent(model=model,
                 policy=GreedyQPolicy(),
                 nb_actions=num_actions,
                 memory=memory,
                 nb_steps_warmup=50000)
agent.compile(k.optimizers.Adam(learning_rate=.00025), metrics=['mae'])
"""
agent.fit(env, nb_steps=10000, log_interval=1000, visualize=False, verbose=2)
"""
agent.load_weights('policy.h5')
agent.test(env, nb_episodes=10, visualize=False)
Example #8
def run():
    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', choices=['train', 'test'], default='train')
    parser.add_argument('--env-name', type=str, default='iemocap-rl-v3.1')
    parser.add_argument('--weights', type=str, default=None)
    parser.add_argument('--policy', type=str, default='EpsGreedyQPolicy')
    parser.add_argument('--data-version', choices=[DataVersions.IEMOCAP, DataVersions.SAVEE, DataVersions.IMPROV],
                        type=str2dataset, default=DataVersions.IEMOCAP)
    parser.add_argument('--disable-wandb', type=str2bool, default=False)
    parser.add_argument('--zeta-nb-steps', type=int, default=1000000)
    parser.add_argument('--nb-steps', type=int, default=500000)
    parser.add_argument('--max-train-steps', type=int, default=440000)
    parser.add_argument('--eps', type=float, default=0.1)
    parser.add_argument('--pre-train', type=str2bool, default=False)
    parser.add_argument('--pre-train-dataset',
                        choices=[DataVersions.IEMOCAP, DataVersions.IMPROV, DataVersions.SAVEE], type=str2dataset,
                        default=DataVersions.IEMOCAP)
    parser.add_argument('--warmup-steps', type=int, default=50000)
    parser.add_argument('--pretrain-epochs', type=int, default=64)
    parser.add_argument('--gpu', type=int, default=1)
    args = parser.parse_args()

    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.1)
    config = tf.ConfigProto(gpu_options=gpu_options)
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    tf.compat.v1.keras.backend.set_session(sess)

    policy = parse_policy(args)
    data_version = args.data_version

    env: gym.Env = None

    if data_version == DataVersions.IEMOCAP:
        env = IemocapEnv(data_version)

    if data_version == DataVersions.SAVEE:
        env = SaveeEnv(data_version)

    if data_version == DataVersions.IMPROV:
        env = ImprovEnv(data_version)

    for k in args.__dict__.keys():
        print("\t{} :\t{}".format(k, args.__dict__[k]))
        env.__setattr__("_" + k, args.__dict__[k])

    experiment_name = "P-{}-S-{}-e-{}-pt-{}".format(args.policy, args.zeta_nb_steps, args.eps, args.pre_train)
    if args.pre_train:
        experiment_name = "P-{}-S-{}-e-{}-pt-{}-pt-w-{}".format(args.policy, args.zeta_nb_steps, args.eps,
                                                                args.pre_train,
                                                                args.pre_train_dataset.name)
    env.__setattr__("_experiment", experiment_name)

    nb_actions = env.action_space.n

    input_layer = Input(shape=(1, NUM_MFCC, NO_features))

    model = models.get_model_9_rl(input_layer, model_name_prefix='mfcc')

    memory = SequentialMemory(limit=1000000, window_length=WINDOW_LENGTH)

    dqn = DQNAgent(model=model, nb_actions=nb_actions, policy=policy, memory=memory,
                   nb_steps_warmup=args.warmup_steps, gamma=.99, target_model_update=10000,
                   train_interval=4, delta_clip=1., train_max_steps=args.max_train_steps)
    dqn.compile(Adam(lr=.00025), metrics=['mae'])

    if args.pre_train:
        from feature_type import FeatureType

        datastore: Datastore = None

        if args.pre_train_dataset == DataVersions.IEMOCAP:
            from datastore_iemocap import IemocapDatastore
            datastore = IemocapDatastore(FeatureType.MFCC)

        if args.pre_train_dataset == DataVersions.IMPROV:
            from datastore_improv import ImprovDatastore
            datastore = ImprovDatastore(22)

        if args.pre_train_dataset == DataVersions.SAVEE:
            from datastore_savee import SaveeDatastore
            datastore = SaveeDatastore(FeatureType.MFCC)

        assert datastore is not None

        x_train, y_train, y_gen_train = datastore.get_pre_train_data()

        dqn.pre_train(x=x_train.reshape((len(x_train), 1, NUM_MFCC, NO_features)), y=y_train,
                      EPOCHS=args.pretrain_epochs, batch_size=128)

    if args.mode == 'train':
        # Okay, now it's time to learn something! We capture the interrupt exception so that training
        # can be prematurely aborted. Notice that now you can use the built-in Keras callbacks!
        weights_filename = 'rl-files/models/dqn_{}_weights.h5f'.format(args.env_name)
        checkpoint_weights_filename = 'rl-files/models/dqn_' + args.env_name + '_weights_{step}.h5f'
        log_filename = 'rl-files/logs/dqn_{}_log.json'.format(args.env_name)
        callbacks = [ModelIntervalCheckpoint(checkpoint_weights_filename, interval=250000)]
        callbacks += [FileLogger(log_filename, interval=100)]

        if not args.disable_wandb:
            wandb_project_name = 'zeta-policy'
            callbacks += [WandbLogger(project=wandb_project_name, name=args.env_name)]

        dqn.fit(env, callbacks=callbacks, nb_steps=args.nb_steps, log_interval=10000)

        # After training is done, we save the final weights one more time.
        dqn.save_weights(weights_filename, overwrite=True)

        # Finally, evaluate our algorithm for 10 episodes.
        dqn.test(env, nb_episodes=10, visualize=False)

    elif args.mode == 'test':
        weights_filename = 'rl-files/models/dqn_{}_weights.h5f'.format(args.env_name)
        if args.weights:
            weights_filename = args.weights
        dqn.load_weights(weights_filename)
        dqn.test(env, nb_episodes=10, visualize=True)
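
# Hedged sketch of the parse_policy helper used above: one plausible mapping from
# the --policy string to a keras-rl policy. The real helper in the source project
# (and its zeta-policy variants) may differ.
from rl.policy import (EpsGreedyQPolicy, GreedyQPolicy, BoltzmannQPolicy,
                       LinearAnnealedPolicy)

def parse_policy(args):
    if args.policy == 'EpsGreedyQPolicy':
        return EpsGreedyQPolicy(eps=args.eps)
    if args.policy == 'GreedyQPolicy':
        return GreedyQPolicy()
    if args.policy == 'BoltzmannQPolicy':
        return BoltzmannQPolicy()
    if args.policy == 'LinearAnnealedPolicy':
        # anneal epsilon from 1.0 down to args.eps over the zeta step budget
        return LinearAnnealedPolicy(EpsGreedyQPolicy(), attr='eps',
                                    value_max=1., value_min=args.eps,
                                    value_test=0.05,
                                    nb_steps=args.zeta_nb_steps)
    raise ValueError('Unknown policy: {}'.format(args.policy))
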
Example #9
        with open('_experiments/history_' + filename + '.pickle',
                  'wb') as handle:
            pickle.dump(hist.history, handle, protocol=pickle.HIGHEST_PROTOCOL)

        # After training is done, we save the final weights.
        agent.save_weights('h5f_files/dqn_{}_weights.h5f'.format(filename),
                           overwrite=True)

        # Finally, evaluate our algorithm for 5 episodes.
        agent.test(env,
                   nb_episodes=5,
                   visualize=True,
                   nb_max_episode_steps=500)

    if mode == 'test':
        agent.load_weights('h5f_files/dqn_{}_weights.h5f'.format(filename))
        agent.test(env,
                   nb_episodes=10,
                   visualize=True,
                   nb_max_episode_steps=400)  # 40 seconds episodes

    if mode == 'real':

        # set the heading target
        env.target = 0.

        agent.load_weights('h5f_files/dqn_{}_weights.h5f'.format(filename))
        agent.test(env,
                   nb_episodes=10,
                   visualize=True,
                   nb_max_episode_steps=400)  # 40 seconds episodes
Example #10
def run():
    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', choices=['train', 'test'], default='train')
    parser.add_argument('--env-name', type=str, default='iemocap-rl-v3.1')
    parser.add_argument('--weights', type=str, default=None)
    parser.add_argument('--policy', type=str, default='EpsGreedyQPolicy')
    parser.add_argument(
        '--data-version',
        nargs='+',
        choices=[
            DataVersions.IEMOCAP, DataVersions.SAVEE, DataVersions.IMPROV,
            DataVersions.ESD, DataVersions.EMODB, DataVersions.KITCHEN_EMODB,
            DataVersions.KITCHEN_ESD, DataVersions.KITCHEN_ESD_DB0,
            DataVersions.KITCHEN_ESD_DBn5, DataVersions.KITCHEN_ESD_DBn10,
            DataVersions.KITCHEN_ESD_DBp5, DataVersions.KITCHEN_ESD_DBp10
        ],
        type=str2dataset,
        default=DataVersions.IEMOCAP)
    parser.add_argument('--data-split', nargs='+', type=float, default=None)
    parser.add_argument('--zeta-nb-steps', type=int, default=100000)
    parser.add_argument('--nb-steps', type=int, default=500000)
    parser.add_argument('--eps', type=float, default=0.1)
    parser.add_argument('--pre-train', type=str2bool, default=False)
    parser.add_argument('--pre-train-dataset',
                        choices=[
                            DataVersions.IEMOCAP, DataVersions.IMPROV,
                            DataVersions.SAVEE, DataVersions.ESD,
                            DataVersions.EMODB
                        ],
                        type=str2dataset,
                        default=DataVersions.IEMOCAP)
    parser.add_argument('--pre-train-data-split', type=float, default=None)
    parser.add_argument('--warmup-steps', type=int, default=50000)
    parser.add_argument('--pretrain-epochs', type=int, default=64)
    parser.add_argument(
        '--testing-dataset',
        type=str2dataset,
        default=None,
        choices=[
            DataVersions.IEMOCAP, DataVersions.IMPROV, DataVersions.SAVEE,
            DataVersions.ESD, DataVersions.COMBINED, DataVersions.EMODB,
            DataVersions.KITCHEN_EMODB, DataVersions.KITCHEN_ESD,
            DataVersions.KITCHEN_ESD_DB0, DataVersions.KITCHEN_ESD_DBn5,
            DataVersions.KITCHEN_ESD_DBn10, DataVersions.KITCHEN_ESD_DBp5,
            DataVersions.KITCHEN_ESD_DBp10
        ])
    parser.add_argument('--gpu', type=int, default=1)
    parser.add_argument('--wandb-disable',
                        type=str2bool,
                        default=False,
                        choices=[True, False])
    parser.add_argument('--wandb-mode',
                        type=str,
                        default='online',
                        choices=['online', 'offline'])
    parser.add_argument('--double-dqn',
                        type=str2bool,
                        default=False,
                        choices=[True, False])
    parser.add_argument('--dueling-network',
                        type=str2bool,
                        default=False,
                        choices=[True, False])
    parser.add_argument('--dueling-type',
                        type=str,
                        default='avg',
                        choices=['avg', 'max', 'naive'])
    parser.add_argument('--schedule-csv', type=str, default=None)
    parser.add_argument('--schedule-idx', type=int, default=None)

    args = parser.parse_args()

    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
    print("Tensorflow version:", tf.__version__)

    if os.path.exists(f'{RESULTS_ROOT}/{time_str}'):
        raise RuntimeError(
            f'Results directory {RESULTS_ROOT}/{time_str} already exists')

    gpus = tf.config.list_physical_devices('GPU')
    if gpus:
        try:
            # Currently, memory growth needs to be the same across GPUs
            for gpu in gpus:
                tf.config.experimental.set_memory_growth(gpu, True)
            logical_gpus = tf.config.list_logical_devices('GPU')
            print(len(gpus), "Physical GPUs,", len(logical_gpus),
                  "Logical GPUs")
        except RuntimeError as e:
            # Memory growth must be set before GPUs have been initialized
            print(e)
    tf.compat.v1.experimental.output_all_intermediates(True)
    policy = parse_policy(args)

    data_version_map = {}
    custom_data_split = []
    if args.data_split is not None:
        if len(args.data_split) == 1 and len(args.data_version) > 1:
            for i in range(len(args.data_version)):
                custom_data_split.append(args.data_split[0])
        elif (len(args.data_split) > 1
              and len(args.data_version) > 1
              and len(args.data_split) != len(args.data_version)):
            raise RuntimeError(
                "--data-split must have either a single value or the same "
                "number of values as --data-version"
            )
        else:
            custom_data_split = args.data_split
    else:
        for i in range(len(args.data_version)):
            custom_data_split.append(None)

    if len(args.data_version) == 1:
        target_datastore = get_datastore(
            data_version=args.data_version[0],
            custom_split=None
            if args.data_split is None else args.data_split[0])
        data_version_map[args.data_version[0]] = target_datastore
        env = get_environment(data_version=args.data_version[0],
                              datastore=target_datastore,
                              custom_split=None if args.data_split is None else
                              args.data_split[0])
    else:
        ds = []
        for i in range(len(args.data_version)):
            d = get_datastore(data_version=args.data_version[i],
                              custom_split=custom_data_split[i])
            data_version_map[args.data_version[i]] = d
            ds.append(d)
        target_datastore = combine_datastores(ds)
        env = get_environment(data_version=DataVersions.COMBINED,
                              datastore=target_datastore,
                              custom_split=None)

    for k in args.__dict__.keys():
        print("\t{} :\t{}".format(k, args.__dict__[k]))
        env.__setattr__("_" + k, args.__dict__[k])

    experiment_name = "P-{}-S-{}-e-{}-pt-{}".format(args.policy,
                                                    args.zeta_nb_steps,
                                                    args.eps, args.pre_train)
    if args.pre_train:
        experiment_name = "P-{}-S-{}-e-{}-pt-{}-pt-w-{}".format(
            args.policy, args.zeta_nb_steps, args.eps, args.pre_train,
            args.pre_train_dataset.name)
    env.__setattr__("_experiment", experiment_name)

    nb_actions = env.action_space.n

    input_layer = Input(shape=(1, NUM_MFCC, NO_features))

    model = models.get_model_9_rl(input_layer, model_name_prefix='mfcc')

    memory = SequentialMemory(limit=1000000, window_length=WINDOW_LENGTH)

    dqn = DQNAgent(model=model,
                   nb_actions=nb_actions,
                   policy=policy,
                   memory=memory,
                   nb_steps_warmup=args.warmup_steps,
                   gamma=.99,
                   target_model_update=10000,
                   train_interval=4,
                   delta_clip=1.,
                   enable_double_dqn=args.double_dqn,
                   enable_dueling_network=args.dueling_network,
                   dueling_type=args.dueling_type)
    # dqn.compile(Adam(learning_rate=.00025), metrics=['mae', 'accuracy'])
    dqn.compile('adam', metrics=['mae', 'accuracy'])

    pre_train_datastore: Datastore = None
    if args.pre_train:

        if args.pre_train_dataset in args.data_version:
            raise RuntimeError(
                "Pre-Train and Target datasets cannot be the same")
        else:
            pre_train_datastore = get_datastore(
                data_version=args.pre_train_dataset,
                custom_split=args.pre_train_data_split)

        assert pre_train_datastore is not None

        (x_train, y_train, y_gen_train), _ = pre_train_datastore.get_data()

        pre_train_log_dir = f'{RESULTS_ROOT}/{time_str}/logs/pre_train'
        if not os.path.exists(pre_train_log_dir):
            os.makedirs(pre_train_log_dir)

        dqn.pre_train(x=x_train.reshape(
            (len(x_train), 1, NUM_MFCC, NO_features)),
                      y=y_train,
                      epochs=args.pretrain_epochs,
                      batch_size=128,
                      log_base_dir=pre_train_log_dir)

    if args.mode == 'train':

        models_dir = f'{RESULTS_ROOT}/{time_str}/models'
        log_dir = f'{RESULTS_ROOT}/{time_str}/logs'

        if not os.path.exists(models_dir):
            os.makedirs(models_dir)
        if not os.path.exists(log_dir):
            os.makedirs(log_dir)

        print(f"Models: {models_dir}")
        # Okay, now it's time to learn something! We capture the interrupt exception so that training
        # can be prematurely aborted. Notice that now you can use the built-in Keras callbacks!
        weights_filename = f'{models_dir}/dqn_{args.env_name}_weights.h5f'
        checkpoint_weights_filename = models_dir + '/dqn_' + args.env_name + '_weights_{step}.h5f'
        log_filename = log_dir + '/dqn_{}_log.json'.format(args.env_name)
        callbacks = [
            ModelIntervalCheckpoint(checkpoint_weights_filename,
                                    interval=250000)
        ]
        callbacks += [FileLogger(log_filename, interval=10)]

        if not args.wandb_disable:
            wandb_project_name = 'zeta-policy'
            wandb_dir = f'{RESULTS_ROOT}/{time_str}/wandb'
            if not os.path.exists(wandb_dir):
                os.makedirs(wandb_dir)
            callbacks += [
                WandbLogger(project=wandb_project_name,
                            name=args.env_name,
                            mode=args.wandb_mode,
                            dir=wandb_dir)
            ]

        dqn.fit(env,
                callbacks=callbacks,
                nb_steps=args.nb_steps,
                log_interval=10000)
        model = dqn.model

        # After training is done, we save the final weights one more time.
        dqn.save_weights(weights_filename, overwrite=True)

        # Testing with Labelled Data
        testing_dataset = args.testing_dataset
        if testing_dataset is not None:
            if testing_dataset == DataVersions.COMBINED:
                if pre_train_datastore is not None:
                    testing_datastore = combine_datastores(
                        [target_datastore, pre_train_datastore])
                else:
                    testing_datastore = target_datastore
            else:
                testing_datastore = data_version_map[testing_dataset]
        else:
            # testing dataset is not defined
            if pre_train_datastore is not None:
                testing_datastore = combine_datastores(
                    [target_datastore, pre_train_datastore])
            else:
                testing_datastore = target_datastore

        x_test, y_test, _ = testing_datastore.get_testing_data()
        test_loss, test_mae, test_acc, test_mean_q = model.evaluate(
            x_test.reshape((len(x_test), 1, NUM_MFCC, NO_features)),
            y_test,
            verbose=1)

        print(f"Test\n\t Accuracy: {test_acc}")

        store_results(f"{log_dir}/results.txt",
                      args=args,
                      experiment=experiment_name,
                      time_str=time_str,
                      test_loss=test_loss,
                      test_acc=test_acc)

        # # Finally, evaluate our algorithm for 10 episodes.
        # dqn.test(env, nb_episodes=10, visualize=False)

    elif args.mode == 'test':
        weights_filename = f'rl-files/models/dqn_{args.env_name}_weights.h5f'
        if args.weights:
            weights_filename = args.weights
        dqn.load_weights(weights_filename)
        dqn.test(env, nb_episodes=10, visualize=True)

    if args.schedule_csv is not None:
        from scheduler_callback import callback
        callback(args.schedule_csv, args.schedule_idx)
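
# Hedged sketch of the argparse helpers both run() variants rely on. str2bool is
# the common argparse boolean-parsing idiom; str2dataset is assumed to look up a
# member of the project's DataVersions enum by name. The real implementations live
# in the source project and may differ.
import argparse

def str2bool(value):
    if isinstance(value, bool):
        return value
    if str(value).lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    if str(value).lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')

def str2dataset(name):
    # DataVersions is the dataset enum defined in the source project
    for member in DataVersions:
        if member.name.lower() == str(name).lower():
            return member
    raise argparse.ArgumentTypeError('Unknown dataset: {}'.format(name))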