import gym
import torch
import torch.optim as optim
from copy import deepcopy
# The project-specific classes used below (MultiLayerNetLogSoftmax,
# MultiLayerNetRegression, ActorCriticTwoNet, DQNAgent, DQNA3CMaster,
# SharedAdam, StablizerOneD) and the pre-populated `config` dict are
# assumed to come from the project's own modules.

config['device'] = 'cpu'

# Get the environment and extract the number of actions.
# env = CartPoleEnvCustom()
trainEnv = gym.make("CartPole-v0")
testEnv = gym.make("CartPole-v0")
N_S = trainEnv.observation_space.shape[0]
N_A = trainEnv.action_space.n
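
# For CartPole-v0 these sizes are fixed: a 4-dimensional observation (cart
# position/velocity, pole angle/angular velocity) and 2 discrete actions.
assert N_S == 4 and N_A == 2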

netParameter = dict()
netParameter['n_feature'] = N_S
netParameter['n_hidden'] = [40, 40]
netParameter['n_output'] = N_A
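
# MultiLayerNetLogSoftmax and MultiLayerNetRegression are project-specific;
# as a point of reference, a minimal sketch of what a log-softmax policy net
# over the [40, 40] hidden layout above might look like (an assumption, not
# the project's actual definition):
class LogSoftmaxMLPSketch(torch.nn.Module):
    def __init__(self, n_feature, n_hidden, n_output):
        super().__init__()
        layers, width = [], n_feature
        for h in n_hidden:
            layers += [torch.nn.Linear(width, h), torch.nn.ReLU()]
            width = h
        layers += [torch.nn.Linear(width, n_output),
                   torch.nn.LogSoftmax(dim=-1)]
        self.model = torch.nn.Sequential(*layers)

    def forward(self, x):
        # returns log pi(a|s); sample via Categorical(logits=...) if needed
        return self.model(x)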

actorNet = MultiLayerNetLogSoftmax(netParameter['n_feature'],
                                   netParameter['n_hidden'],
                                   netParameter['n_output'])

criticNet = MultiLayerNetRegression(netParameter['n_feature'],
                                    netParameter['n_hidden'],
                                    1)

optimizer1 = optim.Adam(actorNet.parameters(), lr=config['learningRate'])
optimizer2 = optim.Adam(criticNet.parameters(), lr=config['learningRate'])

agent = ActorCriticTwoNet(actorNet,
                          criticNet,
                          [trainEnv, testEnv],
                          [optimizer1, optimizer2],
                          torch.nn.MSELoss(),
                          N_A,
                          config)
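
# A minimal sketch of one advantage actor-critic update with separate actor
# and critic nets, roughly what ActorCriticTwoNet presumably does per batch
# (its internals are not shown here); states/actions/returns are assumed to
# be already-prepared tensors.
def a2cUpdateSketch(actor, critic, actorOpt, criticOpt, states, actions, returns):
    values = critic(states).squeeze(-1)                # V(s) from the critic
    advantage = (returns - values).detach()            # advantage estimate
    # critic: regress V(s) toward the observed returns (MSE, as configured)
    criticLoss = torch.nn.functional.mse_loss(values, returns)
    criticOpt.zero_grad()
    criticLoss.backward()
    criticOpt.step()
    # actor: the log-softmax net gives log pi(a|s); take a policy-gradient step
    logProbs = actor(states).gather(1, actions.unsqueeze(1)).squeeze(1)
    actorLoss = -(logProbs * advantage).mean()
    actorOpt.zero_grad()
    actorLoss.backward()
    actorOpt.step()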


agent.train()

agent.test(100)

# Switch to a grid-world (maze) environment and extract its dimensions.
# NOTE: the environment construction is not shown here; `env` is assumed to
# expose `stateDim`, `nbActions`, and a 2-D `map` array. The state-dimension
# line mirrors the StablizerOneD example further down.
N_S = env.stateDim
N_A = env.nbActions

netParameter = dict()
netParameter['n_feature'] = N_S
netParameter['n_hidden'] = [100]
netParameter['n_output'] = N_A

policyNet = MultiLayerNetRegression(netParameter['n_feature'],
                                    netParameter['n_hidden'],
                                    netParameter['n_output'])

# Inspect the freshly initialised weights, then clone them into the target
# network; DQN training periodically re-syncs the target net from the
# policy net.
print(policyNet.state_dict())

targetNet = deepcopy(policyNet)

optimizer = optim.Adam(policyNet.parameters(), lr=config['learningRate'])

agent = DQNAgent(policyNet,
                 targetNet,
                 env,
                 optimizer,
                 torch.nn.MSELoss(),
                 N_S,
                 N_A,
                 config=config)
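
# A hedged sketch of the one-step TD target the agent presumably regresses
# Q(s, a) toward with the MSE loss above (illustrative names, not DQNAgent's
# actual internals):
def dqnTargetSketch(targetNet, rewards, nextStates, dones, gamma):
    with torch.no_grad():
        nextQ = targetNet(nextStates).max(dim=1).values
    return rewards + gamma * nextQ * (1.0 - dones)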

# Build a policy map for visualisation: cells where env.map is 0 are marked
# -1 (excluded); the remaining entries are left to be filled with greedy
# actions.
policy = deepcopy(env.map)
for i in range(policy.shape[0]):
    for j in range(policy.shape[1]):
        if env.map[i, j] == 0:
            policy[i, j] = -1
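
# A hedged sketch of filling the remaining cells with greedy actions from
# the trained net; it assumes a state is the (row, col) pair as floats,
# which this script does not confirm.
with torch.no_grad():
    for i in range(policy.shape[0]):
        for j in range(policy.shape[1]):
            if policy[i, j] != -1:
                q = policyNet(torch.tensor([[float(i), float(j)]]))
                policy[i, j] = int(q.argmax(dim=1).item())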

# Configure the multi-process (A3C-style) DQN run on the 1-D stabilizer task.
config['trainBatchSize'] = 32
config['gamma'] = 0.9
config['learningRate'] = 0.001
config['netGradClip'] = 1
config['logFlag'] = False
config['logFrequency'] = 100
config['priorityMemoryOption'] = False
config['netUpdateOption'] = 'doubleQ'
config['netUpdateFrequency'] = 1
config['priorityMemory_absErrUpper'] = 5
config['numWorkers'] = 4
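
# With netUpdateOption == 'doubleQ', the target presumably takes the
# double-DQN form: the online net picks the next action and a target copy
# evaluates it (a sketch under that assumption, not DQNA3CMaster's confirmed
# internals; contrast with the plain max-target sketch above).
def doubleQTargetSketch(policyNet, targetNet, rewards, nextStates, dones, gamma):
    with torch.no_grad():
        bestActions = policyNet(nextStates).argmax(dim=1, keepdim=True)
        nextQ = targetNet(nextStates).gather(1, bestActions).squeeze(1)
    return rewards + gamma * nextQ * (1.0 - dones)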

env = StablizerOneD()
N_S = env.stateDim
N_A = env.nbActions

netParameter = dict()
netParameter['n_feature'] = N_S
netParameter['n_hidden'] = [4]
netParameter['n_output'] = N_A

policyNet = MultiLayerNetRegression(netParameter['n_feature'],
                                    netParameter['n_hidden'],
                                    netParameter['n_output'])

optimizer = SharedAdam(policyNet.parameters(), lr=config['learningRate'])

agent = DQNA3CMaster(config, policyNet, env, optimizer,
                     torch.nn.MSELoss(reduction='none'), N_A)

agent.test_multiProcess()
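
# test_multiProcess presumably follows the usual A3C layout: the master
# net's tensors are placed in shared memory and each of the
# config['numWorkers'] worker processes interacts with its own environment
# copy. A minimal sketch of that pattern (worker logic elided):
import torch.multiprocessing as mp

def workerSketch(sharedNet, workerId):
    workerEnv = StablizerOneD()  # each worker builds its own env instance
    # roll out episodes in workerEnv and apply updates to sharedNet through
    # the shared optimizer; the real logic lives inside DQNA3CMaster
    pass

if __name__ == '__main__':
    policyNet.share_memory()
    workers = [mp.Process(target=workerSketch, args=(policyNet, rank))
               for rank in range(config['numWorkers'])]
    for w in workers:
        w.start()
    for w in workers:
        w.join()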