def train(env, n_epochs=100, plot_every=10, max_steps=1500):
    """Train a DDPG actor/critic pair on batches from ``env.train_dataloader``.

    Builds the online and target networks, hard-copies the online weights into
    the targets, then runs the DDPG update over mini-batches. Every
    ``plot_every`` steps the test suite is run and train/test losses are
    (re)plotted. Training stops once ``max_steps`` optimization steps have
    been taken, or after ``n_epochs`` passes over the data.

    Args:
        env: environment object exposing ``train_dataloader``
            (and whatever ``run_tests`` needs — TODO confirm its contract).
        n_epochs: maximum number of passes over the training data.
        plot_every: run tests and refresh the loss plot every this many steps.
        max_steps: hard cap on total optimization steps (early stop).
    """
    value_net = Critic(1290, 128, 256, params['critic_weight_init']).to(device)
    policy_net = Actor(1290, 128, 256, params['actor_weight_init']).to(device)
    target_value_net = Critic(1290, 128, 256).to(device)
    target_policy_net = Actor(1290, 128, 256).to(device)

    # Switching off dropout layers in the target networks: targets are only
    # ever evaluated, never trained directly.
    target_value_net.eval()
    target_policy_net.eval()

    # soft_tau=1.0 makes this a hard copy — targets start identical to the
    # online networks.
    softUpdate(value_net, target_value_net, soft_tau=1.0)
    softUpdate(policy_net, target_policy_net, soft_tau=1.0)

    value_optimizer = optimizer.Ranger(value_net.parameters(),
                                       lr=params['value_lr'],
                                       weight_decay=1e-2)
    policy_optimizer = optimizer.Ranger(policy_net.parameters(),
                                        lr=params['policy_lr'],
                                        weight_decay=1e-5)
    # NOTE(review): value_criterion is never used in this function —
    # presumably ddpg() computes its own loss; consider removing.
    value_criterion = nn.MSELoss()

    # Loss history shared (by reference) with the Plotter; log_losses mutates
    # it in place. Renamed from `loss` so the per-batch loss below does not
    # shadow it.
    loss_history = {
        'test':  {'value': [], 'policy': [], 'step': []},
        'train': {'value': [], 'policy': [], 'step': []},
    }
    plotter = Plotter(
        loss_history,
        [['value', 'policy']],
    )

    step = 0
    for epoch in range(n_epochs):
        print("Epoch: {}".format(epoch + 1))
        for batch in env.train_dataloader:
            (batch_loss, value_net, policy_net,
             target_value_net, target_policy_net,
             value_optimizer, policy_optimizer) = ddpg(
                value_net, policy_net, target_value_net, target_policy_net,
                value_optimizer, policy_optimizer, batch, params, step=step)
            plotter.log_losses(batch_loss)
            step += 1
            if step % plot_every == 0:
                print('step', step)
                test_loss = run_tests(env, step, value_net, policy_net,
                                      target_value_net, target_policy_net,
                                      value_optimizer, policy_optimizer,
                                      plotter)
                plotter.log_losses(test_loss, test=True)
                plotter.plot_loss()
            # Bug fix: the original used `assert False` here, which is
            # stripped under `python -O` and otherwise aborts with a
            # traceback; it also only broke the inner loop by crashing.
            # A plain return exits both loops cleanly.
            if step > max_steps:
                return
} plotter = Plotter(loss, [['generator'], ['value', 'perturbator']]) for epoch in range(n_epochs): print("Epoch: {}".format(epoch+1)) for batch in env.train_dataloader: loss = bcq_update(batch, params, writer, debug, step=step) plotter.log_losses(loss) step += 1 print("Loss:{}".format(loss)) if step % plot_every == 0: print('step', step) test_loss = run_tests(env,params,writer,debug) print(test_loss) plotter.log_losses(test_loss, test=True) plotter.plot_loss() if step > 1500: break gen_actions = debug['perturbed_actions'] true_actions = env.embeddings.numpy() ad = AnomalyDetector().to(device) ad.load_state_dict(torch.load('trained/anomaly.pt')) ad.eval() plotter.plot_kde_reconstruction_error(ad, gen_actions, true_actions, device)