def benchmark_adversarial_policy(args=None):
    """Run the critical-point attack once and print the resulting statistics.

    Builds the Atari evaluation env (optionally wrapped in a video
    ``gym.wrappers.Monitor``), loads ``args.policy`` from
    ``args.resume_path`` and crafts a targeted image attack either on
    ``args.target_policy`` (transfer setting) or on the attacked policy
    itself. It then collects ``args.test_num`` episodes with
    ``critical_point_attack_collector``.

    Args:
        args: parsed command-line namespace. When omitted, ``get_args()`` is
            called at call time rather than at import time. The namespace is
            mutated: ``state_shape`` and ``action_shape`` are set on it.
    """
    if args is None:
        args = get_args()
    env = make_atari_env_watch(args)
    if args.save_video:
        # str() so that a missing target policy (None) does not raise TypeError
        log_path = os.path.join(
            args.logdir, args.task, args.policy,
            "critical_point_attack_eps-" + str(args.eps) + "_n-" +
            str(args.n) + "_m-" + str(args.m) + "_" + str(args.target_policy))
        env = gym.wrappers.Monitor(env, log_path, force=True)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.env.action_space.shape or env.env.action_space.n
    # should be N_FRAMES x H x W
    print("Observations shape: ", args.state_shape)
    print("Actions shape: ", args.action_shape)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    # make policy
    policy = make_policy(args, args.policy, args.resume_path)
    # Adversarial observations are crafted on the target policy when one is
    # given (transfer attack), otherwise on the attacked policy itself.
    if args.target_policy is not None:
        victim_policy = make_policy(args, args.target_policy,
                                    args.target_policy_path)
        adv_net = make_victim_network(args, victim_policy)
    else:
        adv_net = make_victim_network(args, policy)
    # define observations adversarial attack
    obs_adv_atk, atk_type = make_img_adv_attack(args, adv_net, targeted=True)
    print("Attack type:", atk_type)

    # Game-specific action mask and damage function; other games get None.
    acts_mask = None
    dam = None
    # delta used to be assigned only inside the game branches below, so any
    # other task failed with UnboundLocalError when building the collector.
    delta = 100
    if "Pong" in args.task:
        acts_mask = [3, 4]
        dam = dam_pong
    if "Breakout" in args.task:
        acts_mask = [1, 2, 3]
        dam = dam_breakout
    collector = critical_point_attack_collector(
        policy,
        env,
        obs_adv_atk,
        perfect_attack=args.perfect_attack,
        acts_mask=acts_mask,
        device=args.device,
        full_search=args.full_search,
        repeat_adv_act=args.repeat_act,
        dam=dam,
        delta=delta)
    # n and m are scaled by the number of repeated adversarial actions
    collector.n = int(args.n * args.repeat_act)
    collector.m = int(args.m * args.repeat_act)
    start_time = time.time()
    test_adversarial_policy = collector.collect(n_episode=args.test_num)
    print("Attack finished in %s seconds" % (time.time() - start_time))
    atk_freq_ = test_adversarial_policy['atk_rate(%)']
    reward = test_adversarial_policy['rew']
    n_attacks = test_adversarial_policy['n_atks']
    print("attack frequency =", atk_freq_, "| n_attacks =", n_attacks,
          "| n_succ_atks (%)", test_adversarial_policy['succ_atks(%)'],
          "| reward: ", reward)
def benchmark_adversarial_policy(args=None):
    """Sweep the uniform attack over a range of attack frequencies.

    For every frequency in ``np.linspace(args.min, args.max, args.steps)``,
    the function collects ``args.test_num`` episodes with
    ``uniform_attack_collector`` and prints the statistics. The frequencies,
    attack counts and rewards are then saved, in that order, as three
    consecutive ``np.save`` arrays in one ``.npy`` file under
    ``args.logdir/args.task/args.policy``.

    Args:
        args: parsed command-line namespace. When omitted, ``get_args()`` is
            called at call time rather than at import time. The namespace is
            mutated: ``state_shape`` and ``action_shape`` are set on it.
    """
    if args is None:
        args = get_args()
    env = make_atari_env_watch(args)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.env.action_space.shape or env.env.action_space.n
    # should be N_FRAMES x H x W
    print("Observations shape: ", args.state_shape)
    print("Actions shape: ", args.action_shape)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    # make policy
    policy = make_policy(args, args.policy, args.resume_path)
    # suffix added to the output file name in the transfer-attack setting
    transferability_type = ""
    # THIS PART MAY BE REMOVED
    if "def" in args.logdir and args.target_policy is None:
        # The trailing space in the first fragment is required: adjacent
        # string literals are concatenated verbatim.
        warnings.warn(
            "You are generating adversarial observation on the defended model, you may want to craft them on "
            "the undefended version instead")
    if args.target_policy is not None:
        victim_policy = make_policy(args, args.target_policy,
                                    args.target_policy_path)
        transferability_type = "_transf_" + str(args.target_policy)
        adv_net = make_victim_network(args, victim_policy)
    else:
        adv_net = make_victim_network(args, policy)
    # define observations adversarial attack
    obs_adv_atk, atk_type = make_img_adv_attack(args, adv_net, targeted=False)
    print("Attack type:", atk_type)

    # define adversarial collector
    collector = uniform_attack_collector(policy,
                                         env,
                                         obs_adv_atk,
                                         perfect_attack=args.perfect_attack,
                                         device=args.device)
    atk_freq = np.linspace(args.min, args.max, args.steps, endpoint=True)
    n_attacks = []
    rewards = []
    for freq in atk_freq:
        collector.atk_frequency = freq
        test_adversarial_policy = collector.collect(n_episode=args.test_num)
        atk_freq_ = test_adversarial_policy['atk_rate(%)']
        rewards.append(test_adversarial_policy['rew'])
        n_attacks.append(test_adversarial_policy['n_atks'])
        print("attack frequency =", atk_freq_, "| n_attacks =", n_attacks[-1],
              "| n_succ_atks (%)", test_adversarial_policy['succ_atks(%)'],
              "| reward: ", rewards[-1])
    log_path = os.path.join(
        args.logdir, args.task, args.policy,
        "uniform_attack_" + atk_type + transferability_type + ".npy")

    # save results; the log directory may not exist yet (e.g. a fresh logdir)
    os.makedirs(os.path.dirname(log_path), exist_ok=True)
    with open(log_path, 'wb') as out_file:
        np.save(out_file, atk_freq)
        np.save(out_file, n_attacks)
        np.save(out_file, rewards)
    print("Results saved to", log_path)
def benchmark_adversarial_policy(args=get_args()):
    """Sweep the strategically-timed attack over a range of ``beta`` values.

    For every threshold in ``np.linspace(args.min, args.max, args.steps)``,
    the function collects ``args.test_num`` episodes with
    ``strategically_timed_attack_collector`` and prints the statistics. The
    attack rates, attack counts and rewards are then written, in that order,
    as three consecutive ``np.save`` arrays to a single ``.npy`` file.

    Note: ``args`` defaults to a namespace built once at definition time, and
    this function mutates it (``state_shape``/``action_shape``).
    """
    env = make_atari_env_watch(args)
    obs_space = env.observation_space
    act_space = env.env.action_space
    args.state_shape = obs_space.shape or obs_space.n
    args.action_shape = act_space.shape or act_space.n
    # should be N_FRAMES x H x W
    print("Observations shape: ", args.state_shape)
    print("Actions shape: ", args.action_shape)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    policy = make_policy(args, args.policy, args.resume_path)

    # Without a target policy the attack is crafted on the attacked policy;
    # otherwise it is a transfer attack and the output name is tagged.
    if args.target_policy is None:
        transferability_type = ""
        adv_net = make_victim_network(args, policy)
    else:
        victim_policy = make_policy(args, args.target_policy,
                                    args.target_policy_path)
        transferability_type = "_transf_" + str(args.target_policy)
        adv_net = make_victim_network(args, victim_policy)

    obs_adv_atk, atk_type = make_img_adv_attack(args, adv_net, targeted=True)
    print("Attack type:", atk_type)

    collector = strategically_timed_attack_collector(
        policy, env, obs_adv_atk,
        perfect_attack=args.perfect_attack,
        softmax=not args.no_softmax,
        device=args.device)

    atk_freq, n_attacks, rewards = [], [], []
    for beta_value in np.linspace(args.min, args.max, args.steps,
                                  endpoint=True):
        collector.beta = beta_value
        stats = collector.collect(n_episode=args.test_num)
        rewards.append(stats['rew'])
        atk_freq.append(stats['atk_rate(%)'])
        n_attacks.append(stats['n_atks'])
        print("attack frequency =", atk_freq[-1],
              "| n_attacks =", n_attacks[-1],
              "| n_succ_atks (%)", stats['succ_atks(%)'],
              "| reward: ", rewards[-1])

    file_name = ("strategically_timed_attack_" + atk_type +
                 transferability_type + ".npy")
    log_path = os.path.join(args.logdir, args.task, args.policy, file_name)

    with open(log_path, 'wb') as out_file:
        for series in (atk_freq, n_attacks, rewards):
            np.save(out_file, series)
    print("Results saved to", log_path)
def benchmark_adversarial_policy(args=None):
    """Run the adversarial-policy attack once and print the resulting statistics.

    Builds the Atari evaluation env (optionally wrapped in a video
    ``gym.wrappers.Monitor``), loads ``args.policy`` and crafts a targeted
    image attack on the target policy (transfer setting) or on the attacked
    policy. When ``args.adv_policy`` is set, the adversarial policy is loaded
    and passed to the collector. The function then collects
    ``args.test_num`` episodes with ``adversarial_policy_attack_collector``
    using ``collector.beta = args.beta``.

    Args:
        args: parsed command-line namespace. When omitted, ``get_args()`` is
            called at call time rather than at import time. The namespace is
            mutated: ``state_shape`` and ``action_shape`` are set on it.
    """
    if args is None:
        args = get_args()
    env = make_atari_env_watch(args)
    if args.save_video:
        # str() so that a missing target policy (None) does not raise TypeError
        log_path = os.path.join(
            args.logdir, args.task, args.policy,
            "adversarial_policy_attack_eps-" + str(args.eps) + "_beta-" +
            str(args.beta) + "_" + str(args.target_policy))
        env = gym.wrappers.Monitor(env, log_path, force=True)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.env.action_space.shape or env.env.action_space.n
    # should be N_FRAMES x H x W
    print("Observations shape: ", args.state_shape)
    print("Actions shape: ", args.action_shape)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    # make policy
    policy = make_policy(args, args.policy, args.resume_path)
    # craft on the target policy when given (transfer), else on the policy
    if args.target_policy is not None:
        victim_policy = make_policy(args, args.target_policy,
                                    args.target_policy_path)
        adv_net = make_victim_network(args, victim_policy)
    else:
        adv_net = make_victim_network(args, policy)
    # define observations adversarial attack
    obs_adv_atk, atk_type = make_img_adv_attack(args, adv_net, targeted=True)
    print("Attack type:", atk_type)

    # optional adversarial policy; None is passed through when not given
    adv_policy = None
    if args.adv_policy is not None:
        adv_policy = make_policy(args, args.adv_policy, args.adv_policy_path)
    # define adversarial collector
    collector = adversarial_policy_attack_collector(
        policy,
        env,
        obs_adv_atk,
        perfect_attack=args.perfect_attack,
        softmax=not args.no_softmax,
        device=args.device,
        adv_policy=adv_policy)
    collector.beta = args.beta
    start_time = time.time()
    test_adversarial_policy = collector.collect(n_episode=args.test_num)
    print("Attack finished in %s seconds" % (time.time() - start_time))
    atk_freq_ = test_adversarial_policy['atk_rate(%)']
    reward = test_adversarial_policy['rew']
    n_attacks = test_adversarial_policy['n_atks']
    print("attack frequency =", atk_freq_, "| n_attacks =", n_attacks,
          "| n_succ_atks (%)", test_adversarial_policy['succ_atks(%)'],
          "| reward: ", reward)
 def watch():
     """Evaluate ``actor_critic`` on a freshly built watch environment.

     Plays until ``args.test_num`` episodes have finished, then prints the
     mean episodic reward. ``args.task`` and ``args.frames_stack`` are
     overwritten before the env is built. The loop only exits when the
     finished-episode count equals ``args.test_num`` exactly.
     """
     print("Testing agent ...")
     actor_critic.eval()
     args.task, args.frames_stack = args.env_name, 4
     env = make_atari_env_watch(args)
     observation = env.reset()
     episodes = 0
     total_reward = 0
     while True:
         # add a leading batch dimension of size 1 for the network
         batch = Batch(obs=np.expand_dims(observation, axis=0))
         with torch.no_grad():
             act = actor_critic(batch).act
         observation, rew, finished, _ = env.step(act)
         total_reward += rew
         if not finished:
             continue
         episodes += 1
         observation = env.reset()
         if episodes == args.test_num:
             break
     print("Evaluation using {} episodes: mean reward {:.5f}\n".format(
         episodes, total_reward / episodes))
示例#6
0
def benchmark_adversarial_policy(args=None):
    """Grid-search the critical-strategy attack over (n, m) and save results.

    ``n`` ranges over ``np.arange(args.min, args.max)`` plus ``args.max``.
    For each ``n``, ``m`` is set to ``n`` extended by 0%, 25%, 50%, 75% or
    100%; both values are scaled by ``args.repeat_act``. Every combination is
    evaluated over ``args.test_num`` episodes. The attack rates, attack
    counts and rewards are then saved, in that order, as three consecutive
    ``np.save`` arrays in one ``.npy`` file.

    Args:
        args: parsed command-line namespace. When omitted, ``get_args()`` is
            called at call time rather than at import time. The namespace is
            mutated: ``state_shape`` and ``action_shape`` are set on it.
    """
    if args is None:
        args = get_args()
    env = make_atari_env_watch(args)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.env.action_space.shape or env.env.action_space.n
    # should be N_FRAMES x H x W
    print("Observations shape: ", args.state_shape)
    print("Actions shape: ", args.action_shape)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    # make policy
    policy = make_policy(args, args.policy, args.resume_path)
    # suffix added to the output file name in the transfer-attack setting
    transferability_type = ""
    if args.target_policy is not None:
        victim_policy = make_policy(args, args.target_policy,
                                    args.target_policy_path)
        transferability_type = "_transf_" + str(args.target_policy)
        adv_net = make_victim_network(args, victim_policy)
    else:
        adv_net = make_victim_network(args, policy)
    # define observations adversarial attack
    obs_adv_atk, atk_type = make_img_adv_attack(args, adv_net, targeted=True)
    print("Attack type:", atk_type)

    # Game-specific action mask; other games use None.
    acts_mask = None
    # delta used to be assigned only inside the game branches, so any other
    # task failed with UnboundLocalError when building the collector.
    delta = 0
    if "Pong" in args.task:
        acts_mask = [3, 4]
    if "Breakout" in args.task:
        acts_mask = [1, 2, 3]
    collector = critical_strategy_attack_collector(
        policy,
        env,
        obs_adv_atk,
        perfect_attack=args.perfect_attack,
        acts_mask=acts_mask,
        device=args.device,
        full_search=args.full_search,
        repeat_adv_act=args.repeat_act,
        delta=delta)
    n_range = list(np.arange(args.min, args.max)) + [args.max]
    m_range = [0., 0.25, 0.5, 0.75, 1.]
    atk_freq = []
    n_attacks = []
    rewards = []
    for n in n_range:
        for m in m_range:
            # m is n extended by a fraction m of itself; both are scaled by
            # the number of repeated adversarial actions
            n_steps = int(n * args.repeat_act)
            m_steps = int(n * args.repeat_act + n * args.repeat_act * m)
            collector.n = n_steps
            collector.m = m_steps
            test_adversarial_policy = collector.collect(
                n_episode=args.test_num)
            rewards.append(test_adversarial_policy['rew'])
            atk_freq.append(test_adversarial_policy['atk_rate(%)'])
            n_attacks.append(test_adversarial_policy['n_atks'])
            print("n =", str(n_steps), "m =", str(m_steps),
                  "| attack frequency =", atk_freq[-1], "| n_attacks =",
                  n_attacks[-1], "| n_succ_atks (%)",
                  test_adversarial_policy['succ_atks(%)'], "| reward: ",
                  rewards[-1])
    log_path = os.path.join(
        args.logdir, args.task, args.policy,
        "critical_strategy_attack_" + atk_type + transferability_type + ".npy")

    # the log directory may not exist yet (e.g. a fresh logdir)
    os.makedirs(os.path.dirname(log_path), exist_ok=True)
    with open(log_path, 'wb') as out_file:
        np.save(out_file, atk_freq)
        np.save(out_file, n_attacks)
        np.save(out_file, rewards)
    print("Results saved to", log_path)
示例#7
0
def test_dqn(args=None):
    """Adversarially train (or only evaluate) a DQN agent on an Atari task.

    Training uses ``adversarial_training_collector`` with an untargeted image
    attack applied at frequency ``args.atk_freq``. The attack is crafted on
    the victim policy loaded from ``args.target_model_path`` when given,
    otherwise on the trained policy itself. With ``args.watch`` the function
    only evaluates the agent and exits; this mode requires
    ``args.target_model_path``.

    Args:
        args: parsed command-line namespace. When omitted, ``get_args()`` is
            called at call time rather than at import time. The namespace is
            mutated (shapes, ``target_policy``, ``policy``, ``perfect_attack``).

    Raises:
        ValueError: ``args.watch`` is set but ``args.target_model_path`` is
            None.
    """
    import sys

    if args is None:
        args = get_args()
    env = make_atari_env(args)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.env.action_space.shape or env.env.action_space.n
    # should be N_FRAMES x H x W
    print("Observations shape: ", args.state_shape)
    print("Actions shape: ", args.action_shape)
    # make environments
    train_envs = SubprocVectorEnv(
        [lambda: make_atari_env(args) for _ in range(args.training_num)])
    test_envs = SubprocVectorEnv(
        [lambda: make_atari_env_watch(args) for _ in range(1)])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    test_envs.seed(args.seed)
    # define model
    net = DQN(*args.state_shape, args.action_shape,
              args.device).to(args.device)
    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
    # define policy
    policy = DQNPolicy(net,
                       optim,
                       args.gamma,
                       args.n_step,
                       target_update_freq=args.target_update_freq)
    # load a previous policy; map_location lets checkpoints saved on a GPU
    # be loaded on a machine with a different (e.g. CPU-only) device
    if args.resume_path:
        policy.load_state_dict(
            torch.load(args.resume_path, map_location=args.device))
        print("Loaded agent from: ", args.resume_path)

    # the adversarial examples are crafted on this policy
    if args.target_model_path:
        victim_policy = copy.deepcopy(policy)
        victim_policy.load_state_dict(
            torch.load(args.target_model_path, map_location=args.device))
        print("Loaded victim agent from: ", args.target_model_path)
    else:
        victim_policy = policy

    args.target_policy, args.policy = "dqn", "dqn"
    args.perfect_attack = False
    adv_net = make_victim_network(args, victim_policy)
    adv_atk, _ = make_img_adv_attack(args, adv_net, targeted=False)

    buffer = ReplayBuffer(args.buffer_size, ignore_obs_next=True)
    # collector
    train_collector = adversarial_training_collector(
        policy,
        train_envs,
        adv_atk,
        buffer,
        atk_frequency=args.atk_freq,
        device=args.device)
    test_collector = adversarial_training_collector(
        policy,
        test_envs,
        adv_atk,
        buffer,
        atk_frequency=args.atk_freq,
        test=True,
        device=args.device)
    # log
    log_path = os.path.join(args.logdir, args.task, 'dqn')
    writer = SummaryWriter(log_path)

    def save_fn(policy, policy_name='policy.pth'):
        torch.save(policy.state_dict(), os.path.join(log_path, policy_name))

    def stop_fn(x):
        # never stop early: train for the full number of epochs
        return 0

    def train_fn(epoch, env_step):
        # nature DQN setting, linear decay in the first 1M steps
        if env_step <= 1e6:
            eps = args.eps_train - env_step / 1e6 * \
                  (args.eps_train - args.eps_train_final)
        else:
            eps = args.eps_train_final
        policy.set_eps(eps)
        writer.add_scalar('train/eps', eps, global_step=env_step)
        print("set eps =", policy.eps)

    def test_fn(epoch, env_step):
        policy.set_eps(args.eps_test)

    # watch agent's performance
    def watch():
        print("Testing agent ...")
        policy.eval()
        policy.set_eps(args.eps_test)
        test_envs.seed(args.seed)
        test_collector.reset()
        result = test_collector.collect(n_episode=[args.test_num],
                                        render=args.render)
        pprint.pprint(result)

    if args.watch:
        # Standalone evaluation requires an explicit victim model. The check
        # used to live inside watch(), which also made the post-training
        # evaluation below crash after a full run without a victim model.
        if args.target_model_path is None:
            raise ValueError(
                "--watch requires target_model_path to be set")
        watch()
        sys.exit(0)

    # test train_collector and start filling replay buffer
    train_collector.collect(n_step=args.batch_size * 4)
    # trainer
    result = offpolicy_trainer(policy,
                               train_collector,
                               test_collector,
                               args.epoch,
                               args.step_per_epoch,
                               args.collect_per_step,
                               args.test_num,
                               args.batch_size,
                               train_fn=train_fn,
                               test_fn=test_fn,
                               stop_fn=stop_fn,
                               save_fn=save_fn,
                               writer=writer,
                               test_in_train=False)

    pprint.pprint(result)
    watch()
示例#8
0
                        type=str,
                        default='log_perturbation_benchmark')
    parser.add_argument('--attack_freq', type=float, default=0.5)
    parser.add_argument('--sample_points', type=int, default=10)
    parser.add_argument('--targeted', default=False, action='store_true')
    args = parser.parse_known_args()[0]
    return args


if __name__ == '__main__':
    args = get_args()
    args.resume_path = os.path.join("log", args.task, args.policy,
                                    "policy.pth")
    args.perfect_attack = False
    args.target_policy = args.policy
    env = make_atari_env_watch(args)
    # comment the attacks you don't need
    img_attacks = [  #"No Attack",
        "GradientSignAttack",  # ok
        #"LinfPGDAttack",  # ok
        "MomentumIterativeAttack",  # ok
        #"DeepfoolLinfAttack"
    ]
    # you can change attack parameters in the utils.py file

    attack_labels = {
        "No Attack": "No Attack",
        "GradientAttack": "FGM",
        "GradientSignAttack": "FGSM-Linf",
        "LinfPGDAttack": "PGD-Linf",
        "L2PGDAttack": "PGD-L2",