Example #1
def train(env):
    value_net = Critic(1290, 128, 256, params['critic_weight_init']).to(device)
    policy_net = Actor(1290, 128, 256, params['actor_weight_init']).to(device)
    target_value_net = Critic(1290, 128, 256).to(device)
    target_policy_net = Actor(1290, 128, 256).to(device)

    # Switching off dropout layers in the target networks (inference only)
    target_value_net.eval()
    target_policy_net.eval()

    softUpdate(value_net, target_value_net, soft_tau=1.0)
    softUpdate(policy_net, target_policy_net, soft_tau=1.0)

    value_optimizer = optimizer.Ranger(value_net.parameters(),
                                       lr=params['value_lr'],
                                       weight_decay=1e-2)
    policy_optimizer = optimizer.Ranger(policy_net.parameters(),
                                        lr=params['policy_lr'],
                                        weight_decay=1e-5)
    value_criterion = nn.MSELoss()
    loss = {
        'test': {
            'value': [],
            'policy': [],
            'step': []
        },
        'train': {
            'value': [],
            'policy': [],
            'step': []
        }
    }

    plotter = Plotter(
        loss,
        [['value', 'policy']],
    )

    step = 0
    plot_every = 10
    for epoch in range(100):
        print("Epoch: {}".format(epoch + 1))
        for batch in env.train_dataloader:
            loss, value_net, policy_net, target_value_net, target_policy_net, \
                value_optimizer, policy_optimizer = ddpg(
                    value_net, policy_net, target_value_net, target_policy_net,
                    value_optimizer, policy_optimizer, batch, params, step=step)
            # print(loss)
            plotter.log_losses(loss)
            step += 1
            if step % plot_every == 0:
                print('step', step)
                test_loss = run_tests(env, step, value_net, policy_net,
                                      target_value_net, target_policy_net,
                                      value_optimizer, policy_optimizer, plotter)
                plotter.log_losses(test_loss, test=True)
                plotter.plot_loss()
            if step > 1500:
                return  # hard stop after 1500 training steps
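
Both target networks above are initialized with softUpdate(..., soft_tau=1.0), i.e. a hard copy of the online weights. The helper itself is not shown in this excerpt; a minimal sketch of a Polyak-averaging update matching that call signature (an assumption, the project's own implementation may differ) could look like:

import torch

def softUpdate(net, target_net, soft_tau=1e-2):
    # Polyak averaging: target <- tau * online + (1 - tau) * target.
    # soft_tau=1.0 degenerates to a hard copy, as used at initialization.
    for param, target_param in zip(net.parameters(), target_net.parameters()):
        target_param.data.copy_(soft_tau * param.data
                                + (1.0 - soft_tau) * target_param.data)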
Example #2
	perturbator_optimizer = optimizer.Ranger(perturbator_net.parameters(), lr=params['value_lr'], weight_decay=1e-3, k=10)
	generator_optimizer = optimizer.Ranger(generator_net.parameters(), lr=params['generator_lr'], k=10)
	
	loss = {
		'train': {'value': [], 'perturbator': [], 'generator': [], 'step': []},
		'test': {'value': [], 'perturbator': [], 'generator': [], 'step': []},
		}
	
	plotter = Plotter(loss, [['generator'], ['value', 'perturbator']])


	for epoch in range(n_epochs):
		print("Epoch: {}".format(epoch+1))
		for batch in env.train_dataloader:
			loss = bcq_update(batch, params, writer, debug, step=step)
			plotter.log_losses(loss)
			step += 1
			print("Loss:{}".format(loss))
			if step % plot_every == 0:
				print('step', step)
				test_loss = run_tests(env,params,writer,debug)
				print(test_loss)
				plotter.log_losses(test_loss, test=True)
				plotter.plot_loss()
			if step > 1500:
				break
		if step > 1500:
			break  # also stop the epoch loop once the step budget is reached
	gen_actions = debug['perturbed_actions']
	true_actions = env.embeddings.numpy()


	ad = AnomalyDetector().to(device)
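
The examples on this page funnel their metrics through a Plotter built from the nested loss dict. The project's Plotter is not included in these excerpts; the stand-in below is purely illustrative (class name, constructor, and methods are assumptions inferred from the calls in the snippets) and shows the contract the training loops rely on:

import matplotlib.pyplot as plt

class Plotter:
    # Hypothetical stand-in for the project's Plotter, inferred from usage:
    # log_losses() appends each reported value to the nested loss dict,
    # plot_loss() draws one subplot per configured key group.
    def __init__(self, loss, key_groups):
        self.loss = loss
        self.key_groups = key_groups

    def log_losses(self, losses, test=False):
        split = 'test' if test else 'train'
        for key, value in losses.items():
            self.loss[split][key].append(value)

    def plot_loss(self):
        fig, axes = plt.subplots(1, len(self.key_groups), figsize=(12, 4))
        axes = [axes] if len(self.key_groups) == 1 else axes
        for ax, group in zip(axes, self.key_groups):
            for key in group:
                ax.plot(self.loss['train'][key], label=key)
            ax.legend()
        plt.show()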
Example #3
    loss = {
        'test': {
            'value': [],
            'policy': [],
            'step': []
        },
        'train': {
            'value': [],
            'policy': [],
            'step': []
        }
    }
    plotter = Plotter(
        loss,
        [['value', 'policy']],
    )
    step = 0
    for epoch in range(n_epochs):
        for batch in env.train_dataloader:
            loss = reinforce_update(batch,
                                    params,
                                    nets,
                                    optimizer,
                                    writer=writer,
                                    device=device,
                                    debug=debug,
                                    learn=True,
                                    step=step)
            if loss:
                plotter.log_losses(loss)
            step += 1
            if step % plot_every == 0:
                print('step', step)
                # plotter.log_losses(test_loss, test=False)
                plotter.plot_loss()
            if step > 1000:
                assert False  # hard stop after 1000 training steps
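
The reinforce_update function driving this last loop is not part of the excerpt. For reference, a generic score-function (REINFORCE) step is sketched below; this is the textbook policy-gradient update, not necessarily the exact computation performed by reinforce_update here:

import torch

def reinforce_step(log_probs, rewards, policy_optimizer, gamma=0.99):
    # Discounted returns, accumulated backwards over the trajectory.
    returns, R = [], 0.0
    for r in reversed(rewards):
        R = r + gamma * R
        returns.insert(0, R)
    returns = torch.tensor(returns)
    # Normalizing returns is a common variance-reduction trick.
    returns = (returns - returns.mean()) / (returns.std() + 1e-8)
    # Score-function estimator: minimize -log pi(a|s) * return.
    loss = -(torch.stack(log_probs) * returns).sum()
    policy_optimizer.zero_grad()
    loss.backward()
    policy_optimizer.step()
    return loss.item()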